# Copyright (c) 2021, NVIDIA CORPORATION.  All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import contextlib
from dataclasses import dataclass, field
from pathlib import Path
from typing import List, Optional

import torch
from hydra.utils import instantiate
from lightning.pytorch import Trainer
from lightning.pytorch.loggers import TensorBoardLogger
from omegaconf import DictConfig, OmegaConf, open_dict

from nemo.collections.asr.parts.utils.rnnt_utils import Hypothesis
from nemo.collections.common.parts.preprocessing import parsers
from nemo.collections.tts.losses.aligner_loss import BinLoss, ForwardSumLoss
from nemo.collections.tts.losses.fastpitchloss import DurationLoss, EnergyLoss, MelLoss, PitchLoss
from nemo.collections.tts.models.base import SpectrogramGenerator
from nemo.collections.tts.modules.fastpitch import FastPitchModule
from nemo.collections.tts.parts.mixins import FastPitchAdapterModelMixin
from nemo.collections.tts.parts.utils.callbacks import LoggingCallback
from nemo.collections.tts.parts.utils.helpers import (
    batch_from_ragged,
    g2p_backward_compatible_support,
    plot_alignment_to_numpy,
    plot_spectrogram_to_numpy,
    process_batch,
    sample_tts_input,
)
from nemo.core.classes import Exportable
from nemo.core.classes.common import PretrainedModelInfo, typecheck
from nemo.core.neural_types.elements import (
    Index,
    LengthsType,
    MelSpectrogramType,
    ProbsType,
    RegressionValuesType,
    TokenDurationType,
    TokenIndex,
    TokenLogDurationType,
)
from nemo.core.neural_types.neural_type import NeuralType
from nemo.utils import logging, model_utils


@dataclass
class G2PConfig:
    """Default settings for the grapheme-to-phoneme (G2P) entry of the default text tokenizer config."""

    # Dotted path of the G2P class (Hydra-style `_target_` key).
    _target_: str = "nemo.collections.tts.g2p.models.en_us_arpabet.EnglishG2p"
    # Relative path of the default phoneme dictionary file.
    phoneme_dict: str = "scripts/tts_dataset_files/cmudict-0.7b_nv22.10"
    # Relative path of the default heteronyms file.
    heteronyms: str = "scripts/tts_dataset_files/heteronyms-052722"
    # Presumably the probability of emitting phonemes instead of graphemes for a word — TODO confirm.
    phoneme_probability: float = 0.5


@dataclass
class TextTokenizer:
    """Default tokenizer settings: the target tokenizer class, its boolean switches and a nested G2P config."""

    # Dotted path of the tokenizer class (Hydra-style `_target_` key).
    _target_: str = "nemo.collections.common.tokenizers.text_to_speech.tts_tokenizers.EnglishPhonemesTokenizer"
    # Boolean switches passed through to the tokenizer; their exact meaning is defined by the target class.
    punct: bool = True
    stresses: bool = True
    chars: bool = True
    apostrophe: bool = True
    pad_with_space: bool = True
    add_blank_at: bool = True
    # A fresh G2PConfig per instance (dataclasses forbid sharing one mutable default).
    g2p: G2PConfig = field(default_factory=G2PConfig)


@dataclass
class TextTokenizerConfig:
    """Top-level wrapper that places the default ``TextTokenizer`` settings under the ``text_tokenizer`` key."""

    # A fresh TextTokenizer per instance.
    text_tokenizer: TextTokenizer = field(default_factory=TextTokenizer)


class FastPitchModel(SpectrogramGenerator, Exportable, FastPitchAdapterModelMixin):
    """FastPitch model (https://arxiv.org/abs/2006.06873) that is used to generate mel spectrogram from text."""

    def __init__(self, cfg: DictConfig, trainer: Trainer = None):
        """Build the FastPitch model, its losses and all of its sub-modules from ``cfg``.

        Args:
            cfg: Model configuration. It is converted to a ``DictConfig`` and passed through the config-version
                updater before anything is read from it.
            trainer: Optional Lightning ``Trainer``, forwarded to the parent constructor.

        Raises:
            ValueError: If ``learn_alignment`` is enabled and ``cfg.train_ds.dataset._target_`` is not one of the
                supported dataset classes.
        """
        # Convert to Hydra 1.0 compatible DictConfig
        cfg = model_utils.convert_model_config_to_dict_config(cfg)
        cfg = model_utils.maybe_update_config_version(cfg)

        # Setup normalizer
        self.normalizer = None
        self.text_normalizer_call = None
        self.text_normalizer_call_kwargs = {}
        self._setup_normalizer(cfg)

        self.learn_alignment = cfg.get("learn_alignment", False)

        # Setup vocabulary (=tokenizer) and input_fft_kwargs (supported only with self.learn_alignment=True)
        input_fft_kwargs = {}
        if self.learn_alignment:
            self.vocab = None

            # NOTE(review): ds_class / ds_class_name are only assigned inside this branch, while the
            # setup_*_data methods read self.ds_class unconditionally — confirm non-alignment configs are fine.
            self.ds_class = cfg.train_ds.dataset._target_
            self.ds_class_name = self.ds_class.split(".")[-1]
            if not self.ds_class in [
                "nemo.collections.tts.data.dataset.TTSDataset",
                "nemo.collections.tts.data.text_to_speech_dataset.TextToSpeechDataset",
                "nemo.collections.tts.torch.data.TTSDataset",
            ]:
                raise ValueError(f"Unknown dataset class: {self.ds_class}.")

            # _setup_tokenizer stores the tokenizer in self.vocab; its size and pad id parameterize input_fft.
            self._setup_tokenizer(cfg)
            assert self.vocab is not None
            input_fft_kwargs["n_embed"] = len(self.vocab.tokens)
            input_fft_kwargs["padding_idx"] = self.vocab.pad

        # Lazily-built caches backing the `parser` and `tb_logger` properties.
        self._parser = None
        self._tb_logger = None
        super().__init__(cfg=cfg, trainer=trainer)

        self.bin_loss_warmup_epochs = cfg.get("bin_loss_warmup_epochs", 100)
        self.log_images = cfg.get("log_images", False)
        # Flipped to True at the end of a validation epoch so the next training step logs images once.
        self.log_train_images = False

        # Prosody losses default to a smaller scale (0.1) when the alignment is learned, 1.0 otherwise.
        default_prosody_loss_scale = 0.1 if self.learn_alignment else 1.0
        dur_loss_scale = cfg.get("dur_loss_scale", default_prosody_loss_scale)
        pitch_loss_scale = cfg.get("pitch_loss_scale", default_prosody_loss_scale)
        energy_loss_scale = cfg.get("energy_loss_scale", default_prosody_loss_scale)

        self.mel_loss_fn = MelLoss()
        self.pitch_loss_fn = PitchLoss(loss_scale=pitch_loss_scale)
        self.duration_loss_fn = DurationLoss(loss_scale=dur_loss_scale)
        self.energy_loss_fn = EnergyLoss(loss_scale=energy_loss_scale)

        # The aligner and its two losses exist only when the alignment is learned.
        self.aligner = None
        if self.learn_alignment:
            aligner_loss_scale = cfg.get("aligner_loss_scale", 1.0)
            self.aligner = instantiate(self._cfg.alignment_module)
            self.forward_sum_loss_fn = ForwardSumLoss(loss_scale=aligner_loss_scale)
            self.bin_loss_fn = BinLoss(loss_scale=aligner_loss_scale)

        # Instantiate the sub-modules from the (parent-stored) config; speaker/energy parts are optional.
        self.preprocessor = instantiate(self._cfg.preprocessor)
        input_fft = instantiate(self._cfg.input_fft, **input_fft_kwargs)
        output_fft = instantiate(self._cfg.output_fft)
        duration_predictor = instantiate(self._cfg.duration_predictor)
        pitch_predictor = instantiate(self._cfg.pitch_predictor)
        speaker_encoder = instantiate(self._cfg.get("speaker_encoder", None))
        energy_embedding_kernel_size = cfg.get("energy_embedding_kernel_size", 0)
        energy_predictor = instantiate(self._cfg.get("energy_predictor", None))

        # [TODO] may remove if we change the pre-trained config
        # cfg: condition_types = [ "add" ]
        n_speakers = cfg.get("n_speakers", 0)
        speaker_emb_condition_prosody = cfg.get("speaker_emb_condition_prosody", False)
        speaker_emb_condition_decoder = cfg.get("speaker_emb_condition_decoder", False)
        speaker_emb_condition_aligner = cfg.get("speaker_emb_condition_aligner", False)
        min_token_duration = cfg.get("min_token_duration", 0)
        use_log_energy = cfg.get("use_log_energy", True)
        # Append the "add" condition type to each sub-module selected by the flags above.
        # Only the input_fft case checks for an existing "add" entry before appending.
        if n_speakers > 1 and "add" not in input_fft.cond_input.condition_types:
            input_fft.cond_input.condition_types.append("add")
        if speaker_emb_condition_prosody:
            duration_predictor.cond_input.condition_types.append("add")
            pitch_predictor.cond_input.condition_types.append("add")
        if speaker_emb_condition_decoder:
            output_fft.cond_input.condition_types.append("add")
        if speaker_emb_condition_aligner and self.aligner is not None:
            self.aligner.cond_input.condition_types.append("add")

        self.fastpitch = FastPitchModule(
            input_fft,
            output_fft,
            duration_predictor,
            pitch_predictor,
            energy_predictor,
            self.aligner,
            speaker_encoder,
            n_speakers,
            cfg.symbols_embedding_dim,
            cfg.pitch_embedding_kernel_size,
            energy_embedding_kernel_size,
            cfg.n_mel_channels,
            min_token_duration,
            cfg.max_token_duration,
            use_log_energy,
        )
        # Export-related state: cached neural types plus the options read by the export code.
        self._input_types = self._output_types = None
        self.export_config = {
            "emb_range": (0, self.fastpitch.encoder.word_emb.num_embeddings),
            "enable_volume": False,
            "enable_ragged_batches": False,
        }
        if self.fastpitch.speaker_emb is not None:
            self.export_config["num_speakers"] = cfg.n_speakers

        # Optional config for the sample-logging callback (see configure_callbacks).
        self.log_config = cfg.get("log_config", None)

        # Adapter modules setup (from FastPitchAdapterModelMixin)
        self.setup_adapters()

    def _get_default_text_tokenizer_conf(self):
        """Return the default text-tokenizer settings as an OmegaConf config.

        The ``TextTokenizerConfig`` dataclass defaults are dumped to YAML and parsed back into a config object.
        """
        default_yaml = OmegaConf.to_yaml(TextTokenizerConfig())
        return OmegaConf.create(default_yaml)

    def _setup_tokenizer(self, cfg):
        """Instantiate the tokenizer described by ``cfg.text_tokenizer`` and store it in ``self.vocab``.

        If the tokenizer config has a ``g2p`` section, the G2P object is built first. Its ``phoneme_dict`` and
        ``heteronyms`` entries, when present, are registered as model artifacts, and the resulting object is
        passed to the tokenizer as the ``g2p`` keyword argument.
        """
        tokenizer_cfg = cfg.text_tokenizer
        tokenizer_overrides = {}

        if "g2p" in tokenizer_cfg:
            g2p_cfg = tokenizer_cfg.g2p

            # Backward compatibility: while restoring a model, rewrite G2P targets that still point to the
            # old `nemo_text_processing.g2p` location.
            if self._is_model_being_restored():
                saved_target = g2p_cfg.get('_target_', None)
                if saved_target is not None and saved_target.startswith("nemo_text_processing.g2p"):
                    g2p_cfg["_target_"] = g2p_backward_compatible_support(saved_target)

            # Register the G2P resource files as artifacts and pass the values returned by
            # register_artifact to the G2P class instead of the raw config values.
            artifact_overrides = {}
            for key, artifact_name in (
                ("phoneme_dict", 'text_tokenizer.g2p.phoneme_dict'),
                ("heteronyms", 'text_tokenizer.g2p.heteronyms'),
            ):
                if key in g2p_cfg:
                    artifact_overrides[key] = self.register_artifact(artifact_name, g2p_cfg[key])

            # Build the G2P object here and pass it to the tokenizer explicitly (kept for backward compatibility).
            tokenizer_overrides["g2p"] = instantiate(g2p_cfg, **artifact_overrides)

        # TODO @xueyang: rename the instance of tokenizer because vocab is misleading.
        self.vocab = instantiate(tokenizer_cfg, **tokenizer_overrides)

    @property
    def tb_logger(self):
        """Return the experiment object used for TensorBoard image logging, resolving it on first use.

        The experiment of the first ``TensorBoardLogger`` found in ``self.trainer.loggers`` is preferred;
        otherwise ``self.logger.experiment`` is used. The result is cached in ``self._tb_logger``.

        Returns:
            The cached experiment object, or ``None`` when the model has no logger or the logger has no
            experiment (nothing is cached in that case, so a later call can still resolve it).
        """
        if self._tb_logger is None:
            # Either missing piece means there is nothing to log to. The logger is checked first so that
            # `.experiment` is never read from None.
            if self.logger is None or self.logger.experiment is None:
                return None
            tb_logger = self.logger.experiment
            for logger in self.trainer.loggers:
                if isinstance(logger, TensorBoardLogger):
                    tb_logger = logger.experiment
                    break
            self._tb_logger = tb_logger
        return self._tb_logger

    @property
    def parser(self):
        """Return the callable that turns text into token ids, building and caching it on first access.

        With ``learn_alignment`` enabled this is the tokenizer's ``encode`` method; otherwise a parser is
        created from ``self._cfg.labels`` via ``parsers.make_parser``.
        """
        if self._parser is None:
            if not self.learn_alignment:
                built_parser = parsers.make_parser(
                    labels=self._cfg.labels,
                    name='en',
                    unk_id=-1,
                    blank_id=-1,
                    do_normalize=True,
                    abbreviation_version="fastpitch",
                    make_table=False,
                )
            else:
                built_parser = self.vocab.encode
            self._parser = built_parser
        return self._parser

    def parse(self, str_input: str, normalize=True) -> torch.tensor:
        """Convert one text input into a batch of token ids on the model's device.

        Args:
            str_input: Text to tokenize. A ``Hypothesis`` is also accepted; its ``text`` attribute is used.
            normalize: When True and a text normalizer call is configured, normalize the text first.

        Returns:
            A long tensor of token ids with a leading dimension of size 1 added.
        """
        if self.training:
            logging.warning("parse() is meant to be called in eval mode.")

        text = str_input.text if isinstance(str_input, Hypothesis) else str_input
        if normalize and self.text_normalizer_call is not None:
            text = self.text_normalizer_call(text, **self.text_normalizer_call_kwargs)

        # Disable mixed g2p representation if necessary: tokenizers exposing `set_phone_prob` are run with
        # prob=1.0 for the duration of the call. All other cases use a no-op context.
        phone_mode = contextlib.nullcontext()
        if self.learn_alignment and hasattr(self.vocab, "set_phone_prob"):
            phone_mode = self.vocab.set_phone_prob(prob=1.0)

        with phone_mode:
            token_ids = self.parser(text)

        return torch.tensor(token_ids).unsqueeze_(0).long().to(self.device)

    @typecheck(
        input_types={
            "text": NeuralType(('B', 'T_text'), TokenIndex()),
            "durs": NeuralType(('B', 'T_text'), TokenDurationType()),
            "pitch": NeuralType(('B', 'T_audio'), RegressionValuesType()),
            "energy": NeuralType(('B', 'T_audio'), RegressionValuesType(), optional=True),
            "speaker": NeuralType(('B'), Index(), optional=True),
            "pace": NeuralType(optional=True),
            "spec": NeuralType(('B', 'D', 'T_spec'), MelSpectrogramType(), optional=True),
            "attn_prior": NeuralType(('B', 'T_spec', 'T_text'), ProbsType(), optional=True),
            "mel_lens": NeuralType(('B'), LengthsType(), optional=True),
            "input_lens": NeuralType(('B'), LengthsType(), optional=True),
            # reference_* data is used for multi-speaker FastPitch training
            "reference_spec": NeuralType(('B', 'D', 'T_spec'), MelSpectrogramType(), optional=True),
            "reference_spec_lens": NeuralType(('B'), LengthsType(), optional=True),
        }
    )
    def forward(
        self,
        *,
        text,
        durs=None,
        pitch=None,
        energy=None,
        speaker=None,
        pace=1.0,
        spec=None,
        attn_prior=None,
        mel_lens=None,
        input_lens=None,
        reference_spec=None,
        reference_spec_lens=None,
    ):
        """Run the wrapped ``FastPitchModule`` on keyword-only inputs and return its outputs unchanged.

        Every argument is forwarded under the same name; the expected neural types are declared in the
        ``typecheck`` decorator above.
        """
        module_inputs = {
            "text": text,
            "durs": durs,
            "pitch": pitch,
            "energy": energy,
            "speaker": speaker,
            "pace": pace,
            "spec": spec,
            "attn_prior": attn_prior,
            "mel_lens": mel_lens,
            "input_lens": input_lens,
            "reference_spec": reference_spec,
            "reference_spec_lens": reference_spec_lens,
        }
        return self.fastpitch(**module_inputs)

    @typecheck(output_types={"spect": NeuralType(('B', 'D', 'T_spec'), MelSpectrogramType())})
    def generate_spectrogram(
        self,
        tokens: 'torch.tensor',
        speaker: Optional[int] = None,
        pace: float = 1.0,
        reference_spec: Optional['torch.tensor'] = None,
        reference_spec_lens: Optional['torch.tensor'] = None,
    ) -> torch.tensor:
        """Generate a mel spectrogram from token ids, letting the model predict durations and pitch.

        Args:
            tokens: Token ids, as produced by ``parse``.
            speaker: Optional speaker id. A plain ``int`` is wrapped into a one-element tensor on the model's
                device; any other value is passed through as is.
            pace: Pace value forwarded to the model.
            reference_spec: Optional reference spectrogram forwarded to the model.
            reference_spec_lens: Optional lengths of ``reference_spec``.

        Returns:
            The first output of the forward pass (the predicted spectrogram).
        """
        if self.training:
            logging.warning("generate_spectrogram() is meant to be called in eval mode.")

        speaker_ids = torch.tensor([speaker]).to(self.device) if isinstance(speaker, int) else speaker

        # No ground-truth durations or pitch are supplied at inference time.
        model_outputs = self(
            text=tokens,
            durs=None,
            pitch=None,
            speaker=speaker_ids,
            pace=pace,
            reference_spec=reference_spec,
            reference_spec_lens=reference_spec_lens,
        )
        spect, *_ = model_outputs
        return spect

    def training_step(self, batch, batch_idx):
        """Run one training step: forward pass, loss computation, scalar logging and optional image logging.

        Args:
            batch: One training batch. With ``learn_alignment`` it is a dict (``TextToSpeechDataset``) or is
                converted to one by ``process_batch``; otherwise it is a 7-tuple of
                ``(audio, audio_lens, text, text_lens, durs, pitch, speaker)``.
            batch_idx: Index of the batch (unused here).

        Returns:
            The total loss: mel + duration + pitch + energy, plus the two aligner losses when the alignment
            is learned.
        """
        # Optional inputs default to None; only some dataset layouts provide them.
        attn_prior, durs, speaker, energy, reference_audio, reference_audio_len = (
            None,
            None,
            None,
            None,
            None,
            None,
        )
        if self.learn_alignment:
            if self.ds_class == "nemo.collections.tts.data.text_to_speech_dataset.TextToSpeechDataset":
                batch_dict = batch
            else:
                batch_dict = process_batch(batch, self._train_dl.dataset.sup_data_types_set)
            audio = batch_dict.get("audio")
            audio_lens = batch_dict.get("audio_lens")
            text = batch_dict.get("text")
            text_lens = batch_dict.get("text_lens")
            attn_prior = batch_dict.get("align_prior_matrix", None)
            pitch = batch_dict.get("pitch", None)
            energy = batch_dict.get("energy", None)
            speaker = batch_dict.get("speaker_id", None)
            reference_audio = batch_dict.get("reference_audio", None)
            reference_audio_len = batch_dict.get("reference_audio_lens", None)
        else:
            audio, audio_lens, text, text_lens, durs, pitch, speaker = batch

        # Audio -> mel spectrogram (and the same for the optional reference audio).
        mels, spec_len = self.preprocessor(input_signal=audio, length=audio_lens)
        reference_spec, reference_spec_len = None, None
        if reference_audio is not None:
            reference_spec, reference_spec_len = self.preprocessor(
                input_signal=reference_audio, length=reference_audio_len
            )

        # Note: `pitch` is rebound to the value returned by the model, which is then used as the loss target.
        (
            mels_pred,
            _,
            _,
            log_durs_pred,
            pitch_pred,
            attn_soft,
            attn_logprob,
            attn_hard,
            attn_hard_dur,
            pitch,
            energy_pred,
            energy_tgt,
        ) = self(
            text=text,
            durs=durs,
            pitch=pitch,
            energy=energy,
            speaker=speaker,
            pace=1.0,
            spec=mels if self.learn_alignment else None,
            reference_spec=reference_spec,
            reference_spec_lens=reference_spec_len,
            attn_prior=attn_prior,
            mel_lens=spec_len,
            input_lens=text_lens,
        )
        # Without ground-truth durations, use the durations returned by the model as the target.
        if durs is None:
            durs = attn_hard_dur

        mel_loss = self.mel_loss_fn(spect_predicted=mels_pred, spect_tgt=mels)
        dur_loss = self.duration_loss_fn(log_durs_predicted=log_durs_pred, durs_tgt=durs, len=text_lens)
        loss = mel_loss + dur_loss
        if self.learn_alignment:
            ctc_loss = self.forward_sum_loss_fn(attn_logprob=attn_logprob, in_lens=text_lens, out_lens=spec_len)
            # Bin loss weight ramps linearly from 0 to 1 over `bin_loss_warmup_epochs` epochs.
            bin_loss_weight = min(self.current_epoch / self.bin_loss_warmup_epochs, 1.0) * 1.0
            bin_loss = self.bin_loss_fn(hard_attention=attn_hard, soft_attention=attn_soft) * bin_loss_weight
            loss += ctc_loss + bin_loss

        pitch_loss = self.pitch_loss_fn(pitch_predicted=pitch_pred, pitch_tgt=pitch, len=text_lens)
        energy_loss = self.energy_loss_fn(energy_predicted=energy_pred, energy_tgt=energy_tgt, length=text_lens)
        loss += pitch_loss + energy_loss

        self.log("t_loss", loss)
        self.log("t_mel_loss", mel_loss)
        self.log("t_dur_loss", dur_loss)
        self.log("t_pitch_loss", pitch_loss)
        # The energy loss is only reported when an energy target exists.
        if energy_tgt is not None:
            self.log("t_energy_loss", energy_loss)
        if self.learn_alignment:
            self.log("t_ctc_loss", ctc_loss)
            self.log("t_bin_loss", bin_loss)

        # Log images to tensorboard
        # `log_train_images` is set at the end of a validation epoch and cleared here, so images are written
        # for a single training step after each validation run.
        if self.log_images and self.log_train_images and isinstance(self.logger, TensorBoardLogger):
            self.log_train_images = False

            self.tb_logger.add_image(
                "train_mel_target",
                plot_spectrogram_to_numpy(mels[0].data.cpu().float().numpy()),
                self.global_step,
                dataformats="HWC",
            )
            spec_predict = mels_pred[0].data.cpu().float().numpy()
            self.tb_logger.add_image(
                "train_mel_predicted",
                plot_spectrogram_to_numpy(spec_predict),
                self.global_step,
                dataformats="HWC",
            )
            if self.learn_alignment:
                # Hard and soft attention maps of the first sample, transposed for plotting.
                attn = attn_hard[0].data.cpu().float().numpy().squeeze()
                self.tb_logger.add_image(
                    "train_attn",
                    plot_alignment_to_numpy(attn.T),
                    self.global_step,
                    dataformats="HWC",
                )
                soft_attn = attn_soft[0].data.cpu().float().numpy().squeeze()
                self.tb_logger.add_image(
                    "train_soft_attn",
                    plot_alignment_to_numpy(soft_attn.T),
                    self.global_step,
                    dataformats="HWC",
                )

        return loss

    def validation_step(self, batch, batch_idx):
        """Run one validation step and record its losses for aggregation at the end of the epoch.

        Args:
            batch: One validation batch, in the same layout as in ``training_step``.
            batch_idx: Index of the batch; the target and predicted spectrograms are kept only for batch 0.

        Returns:
            A dict with ``val_loss``, ``mel_loss``, ``dur_loss``, ``pitch_loss``, ``energy_loss`` (``None``
            when there is no energy target), ``mel_target`` and ``mel_pred``. The same dict is appended to
            ``self.validation_step_outputs``.
        """
        # Optional inputs default to None; only some dataset layouts provide them.
        attn_prior, durs, speaker, energy, reference_audio, reference_audio_len = (
            None,
            None,
            None,
            None,
            None,
            None,
        )
        if self.learn_alignment:
            if self.ds_class == "nemo.collections.tts.data.text_to_speech_dataset.TextToSpeechDataset":
                batch_dict = batch
            else:
                # NOTE(review): this reads the supplementary data types from the *training* dataloader's
                # dataset while processing a validation batch — confirm this is intended.
                batch_dict = process_batch(batch, self._train_dl.dataset.sup_data_types_set)
            audio = batch_dict.get("audio")
            audio_lens = batch_dict.get("audio_lens")
            text = batch_dict.get("text")
            text_lens = batch_dict.get("text_lens")
            attn_prior = batch_dict.get("align_prior_matrix", None)
            pitch = batch_dict.get("pitch", None)
            energy = batch_dict.get("energy", None)
            speaker = batch_dict.get("speaker_id", None)
            reference_audio = batch_dict.get("reference_audio", None)
            reference_audio_len = batch_dict.get("reference_audio_lens", None)
        else:
            audio, audio_lens, text, text_lens, durs, pitch, speaker = batch

        # Audio -> mel spectrogram (and the same for the optional reference audio).
        mels, mel_lens = self.preprocessor(input_signal=audio, length=audio_lens)
        reference_spec, reference_spec_len = None, None
        if reference_audio is not None:
            reference_spec, reference_spec_len = self.preprocessor(
                input_signal=reference_audio, length=reference_audio_len
            )

        # Calculate val loss on ground truth durations to better align L2 loss in time
        # Note: `pitch` is rebound to the value returned by the model, which is then used as the loss target.
        (
            mels_pred,
            _,
            _,
            log_durs_pred,
            pitch_pred,
            _,
            _,
            _,
            attn_hard_dur,
            pitch,
            energy_pred,
            energy_tgt,
        ) = self(
            text=text,
            durs=durs,
            pitch=pitch,
            energy=energy,
            speaker=speaker,
            pace=1.0,
            spec=mels if self.learn_alignment else None,
            reference_spec=reference_spec,
            reference_spec_lens=reference_spec_len,
            attn_prior=attn_prior,
            mel_lens=mel_lens,
            input_lens=text_lens,
        )
        # Without ground-truth durations, use the durations returned by the model as the target.
        if durs is None:
            durs = attn_hard_dur

        # Unlike training, the validation loss does not include the aligner (ctc / bin) losses.
        mel_loss = self.mel_loss_fn(spect_predicted=mels_pred, spect_tgt=mels)
        dur_loss = self.duration_loss_fn(log_durs_predicted=log_durs_pred, durs_tgt=durs, len=text_lens)
        pitch_loss = self.pitch_loss_fn(pitch_predicted=pitch_pred, pitch_tgt=pitch, len=text_lens)
        energy_loss = self.energy_loss_fn(energy_predicted=energy_pred, energy_tgt=energy_tgt, length=text_lens)
        loss = mel_loss + dur_loss + pitch_loss + energy_loss

        # Spectrograms are stored for the first batch only; they are used for image logging at epoch end.
        val_outputs = {
            "val_loss": loss,
            "mel_loss": mel_loss,
            "dur_loss": dur_loss,
            "pitch_loss": pitch_loss,
            "energy_loss": energy_loss if energy_tgt is not None else None,
            "mel_target": mels if batch_idx == 0 else None,
            "mel_pred": mels_pred if batch_idx == 0 else None,
        }
        self.validation_step_outputs.append(val_outputs)
        return val_outputs

    def on_validation_epoch_end(self):
        """Average the per-batch validation losses, log them, optionally log images, then free the outputs.

        The mean of each loss over ``self.validation_step_outputs`` is logged with ``sync_dist=True``. The
        energy loss is logged only when the first recorded output has one. When image logging is enabled and
        the logger is a ``TensorBoardLogger``, the first sample of the target and predicted spectrograms stored
        in the first output is written to TensorBoard, and ``log_train_images`` is set so that the next
        training step logs its images too.
        """
        outputs = self.validation_step_outputs

        def epoch_mean(key):
            """Mean over all recorded validation steps of the scalar stored under ``key``."""
            return torch.stack([step_output[key] for step_output in outputs]).mean()

        # Compute every mean before logging any of them, so a missing key fails before partial logging.
        averaged_losses = [
            (log_name, epoch_mean(key))
            for key, log_name in (
                ("val_loss", "val_loss"),
                ("mel_loss", "val_mel_loss"),
                ("dur_loss", "val_dur_loss"),
                ("pitch_loss", "val_pitch_loss"),
            )
        ]
        for log_name, value in averaged_losses:
            self.log(log_name, value, sync_dist=True)

        first_output = outputs[0]
        if first_output["energy_loss"] is not None:
            self.log("val_energy_loss", epoch_mean("energy_loss"), sync_dist=True)

        if self.log_images and isinstance(self.logger, TensorBoardLogger):
            # Look the spectrograms up by key rather than by position in the dict, so the order and number
            # of entries in the validation outputs do not matter.
            spec_target = first_output["mel_target"]
            spec_predict = first_output["mel_pred"]
            self.tb_logger.add_image(
                "val_mel_target",
                plot_spectrogram_to_numpy(spec_target[0].data.cpu().float().numpy()),
                self.global_step,
                dataformats="HWC",
            )
            spec_predict = spec_predict[0].data.cpu().float().numpy()
            self.tb_logger.add_image(
                "val_mel_predicted",
                plot_spectrogram_to_numpy(spec_predict),
                self.global_step,
                dataformats="HWC",
            )
            self.log_train_images = True
        self.validation_step_outputs.clear()  # free memory

    def _setup_train_dataloader(self, cfg):
        """Build the training ``DataLoader`` from ``cfg.dataset`` and ``cfg.dataloader_params``.

        The dataset is instantiated with ``self.vocab`` as its text tokenizer. If the tokenizer exposes
        ``set_phone_prob``, instantiation happens inside that context using the tokenizer's own
        ``phoneme_probability``. The sampler is obtained from the dataset itself.
        """
        tokenizer = self.vocab
        if hasattr(tokenizer, "set_phone_prob"):
            phone_mode = tokenizer.set_phone_prob(tokenizer.phoneme_probability)
        else:
            phone_mode = contextlib.nullcontext()

        with phone_mode:
            train_dataset = instantiate(cfg.dataset, text_tokenizer=tokenizer)

        loader_params = cfg.dataloader_params
        train_sampler = train_dataset.get_sampler(loader_params.batch_size, world_size=self.trainer.world_size)
        return torch.utils.data.DataLoader(
            train_dataset,
            collate_fn=train_dataset.collate_fn,
            sampler=train_sampler,
            **loader_params,
        )

    def _setup_test_dataloader(self, cfg):
        """Build an evaluation ``DataLoader`` from ``cfg.dataset`` and ``cfg.dataloader_params``.

        The dataset is instantiated with ``self.vocab`` as its text tokenizer. If the tokenizer exposes
        ``set_phone_prob``, instantiation happens inside that context with a value of ``0.0``. No sampler
        is passed to the loader.
        """
        tokenizer = self.vocab
        if hasattr(tokenizer, "set_phone_prob"):
            phone_mode = tokenizer.set_phone_prob(0.0)
        else:
            phone_mode = contextlib.nullcontext()

        with phone_mode:
            eval_dataset = instantiate(cfg.dataset, text_tokenizer=tokenizer)

        loader_params = cfg.dataloader_params
        return torch.utils.data.DataLoader(eval_dataset, collate_fn=eval_dataset.collate_fn, **loader_params)

    def __setup_dataloader_from_config(self, cfg, shuffle_should_be: bool = True, name: str = "train"):
        """Build a ``DataLoader`` for the dataset classes other than ``TextToSpeechDataset``.

        Args:
            cfg: Dataloader config; must contain ``dataset`` and ``dataloader_params`` sections.
            shuffle_should_be: Expected value of ``dataloader_params.shuffle``. When True and the key is
                absent it is added as True; every other mismatch is only reported through the logger.
            name: Label used in messages; ``"val"`` also selects ``prob=None`` for the tokenizer's
                ``set_phone_prob`` context.

        Returns:
            A ``torch.utils.data.DataLoader`` over the instantiated dataset, using its ``collate_fn``.

        Raises:
            ValueError: If ``cfg`` lacks a ``dataset`` or ``dataloader_params`` section of type ``DictConfig``.
        """
        if "dataset" not in cfg or not isinstance(cfg.dataset, DictConfig):
            raise ValueError(f"No dataset for {name}")
        if "dataloader_params" not in cfg or not isinstance(cfg.dataloader_params, DictConfig):
            raise ValueError(f"No dataloader_params for {name}")
        if shuffle_should_be:
            if 'shuffle' not in cfg.dataloader_params:
                logging.warning(
                    f"Shuffle should be set to True for {self}'s {name} dataloader but was not found in its "
                    "config. Manually setting to True"
                )
                with open_dict(cfg.dataloader_params):
                    cfg.dataloader_params.shuffle = True
            elif not cfg.dataloader_params.shuffle:
                logging.error(f"The {name} dataloader for {self} has shuffle set to False!!!")
        # `shuffle` is optional when shuffling is not expected: a missing key counts as "not shuffling"
        # instead of being read unguarded.
        elif cfg.dataloader_params.get("shuffle", False):
            logging.error(f"The {name} dataloader for {self} has shuffle set to True!!!")

        if self.ds_class == "nemo.collections.tts.data.dataset.TTSDataset":
            # TTSDataset receives the normalizer and tokenizer from the model. The tokenizer's phoneme
            # probability context (when available) wraps the dataset construction.
            phon_mode = contextlib.nullcontext()
            if hasattr(self.vocab, "set_phone_prob"):
                phon_mode = self.vocab.set_phone_prob(prob=None if name == "val" else self.vocab.phoneme_probability)

            with phon_mode:
                dataset = instantiate(
                    cfg.dataset,
                    text_normalizer=self.normalizer,
                    text_normalizer_call_kwargs=self.text_normalizer_call_kwargs,
                    text_tokenizer=self.vocab,
                )
        else:
            dataset = instantiate(cfg.dataset)

        return torch.utils.data.DataLoader(dataset, collate_fn=dataset.collate_fn, **cfg.dataloader_params)

    def setup_training_data(self, cfg):
        """Create the training dataloader for ``cfg`` and store it in ``self._train_dl``.

        ``TextToSpeechDataset`` configs use ``_setup_train_dataloader``; every other dataset class goes
        through the generic config-driven builder with its default arguments.
        """
        uses_text_to_speech_dataset = (
            self.ds_class == "nemo.collections.tts.data.text_to_speech_dataset.TextToSpeechDataset"
        )
        build_loader = (
            self._setup_train_dataloader if uses_text_to_speech_dataset else self.__setup_dataloader_from_config
        )
        self._train_dl = build_loader(cfg)

    def setup_validation_data(self, cfg):
        """Create the validation dataloader for ``cfg`` and store it in ``self._validation_dl``.

        ``TextToSpeechDataset`` configs use ``_setup_test_dataloader``; every other dataset class goes
        through the generic config-driven builder with shuffling expected to be off and the name ``"val"``.
        """
        if self.ds_class != "nemo.collections.tts.data.text_to_speech_dataset.TextToSpeechDataset":
            self._validation_dl = self.__setup_dataloader_from_config(cfg, shuffle_should_be=False, name="val")
            return
        self._validation_dl = self._setup_test_dataloader(cfg)

    def setup_test_data(self, cfg):
        """Do nothing: this model builds no test dataloader, so ``cfg`` is ignored."""
        return None

    def configure_callbacks(self):
        """Return the callbacks this model contributes: a sample ``LoggingCallback`` when ``log_config`` is set.

        Returns:
            An empty list when ``self.log_config`` is falsy, otherwise a one-element list holding the
            configured ``LoggingCallback``.

        Raises:
            ValueError: If the logging dataset is not a ``TextToSpeechDataset``.
        """
        log_cfg = self.log_config
        if not log_cfg:
            return []

        sample_ds_class = log_cfg.dataset._target_
        if sample_ds_class != "nemo.collections.tts.data.text_to_speech_dataset.TextToSpeechDataset":
            raise ValueError(f"Logging callback only supported for TextToSpeechDataset, got {sample_ds_class}")

        # The logging samples are loaded through the same path as evaluation data.
        sample_loader = self._setup_test_dataloader(log_cfg)
        sample_generators = instantiate(log_cfg.generators)

        # An empty or missing log_dir means no output directory is passed to the callback.
        output_dir = None
        if log_cfg.log_dir:
            output_dir = Path(log_cfg.log_dir)

        return [
            LoggingCallback(
                generators=sample_generators,
                data_loader=sample_loader,
                log_epochs=log_cfg.log_epochs,
                epoch_frequency=log_cfg.epoch_frequency,
                output_dir=output_dir,
                loggers=self.trainer.loggers,
                log_tensorboard=log_cfg.log_tensorboard,
                log_wandb=log_cfg.log_wandb,
            )
        ]

    @classmethod
    def list_available_models(cls) -> 'List[PretrainedModelInfo]':
        """
        This method returns a list of pre-trained model which can be instantiated directly from NVIDIA's NGC cloud.
        Returns:
            List of available pre-trained models.
        """
        list_of_models = []

        # en-US, single speaker, 22050Hz, LJSpeech (ARPABET).
        model = PretrainedModelInfo(
            pretrained_model_name="tts_en_fastpitch",
            location="https://api.ngc.nvidia.com/v2/models/nvidia/nemo/tts_en_fastpitch/versions/1.8.1/files/tts_en_fastpitch_align.nemo",
            description="This model is trained on LJSpeech sampled at 22050Hz with and can be used to generate female English voices with an American accent. It is ARPABET-based.",
            class_=cls,
        )
        list_of_models.append(model)

        # en-US, single speaker, 22050Hz, LJSpeech (IPA).
        model = PretrainedModelInfo(
            pretrained_model_name="tts_en_fastpitch_ipa",
            location="https://api.ngc.nvidia.com/v2/models/nvidia/nemo/tts_en_fastpitch/versions/IPA_1.13.0/files/tts_en_fastpitch_align_ipa.nemo",
            description="This model is trained on LJSpeech sampled at 22050Hz with and can be used to generate female English voices with an American accent. It is IPA-based.",
            class_=cls,
        )
        list_of_models.append(model)

        # en-US, multi-speaker, 44100Hz, HiFiTTS.
        model = PretrainedModelInfo(
            pretrained_model_name="tts_en_fastpitch_multispeaker",
            location="https://api.ngc.nvidia.com/v2/models/nvidia/nemo/tts_en_multispeaker_fastpitchhifigan/versions/1.10.0/files/tts_en_fastpitch_multispeaker.nemo",
            description="This model is trained on HiFITTS sampled at 44100Hz with and can be used to generate male and female English voices with an American accent.",
            class_=cls,
        )
        list_of_models.append(model)

        # de-DE, single male speaker, grapheme-based tokenizer, 22050 Hz, Thorsten Müller’s German Neutral-TTS Dataset, 21.02
        model = PretrainedModelInfo(
            pretrained_model_name="tts_de_fastpitch_singleSpeaker_thorstenNeutral_2102",
            location="https://api.ngc.nvidia.com/v2/models/nvidia/nemo/tts_de_fastpitchhifigan/versions/1.15.0/files/tts_de_fastpitch_thorstens2102.nemo",
            description="This model is trained on a single male speaker data in Thorsten Müller’s German Neutral 21.02 Dataset sampled at 22050Hz and can be used to generate male German voices.",
            class_=cls,
        )
        list_of_models.append(model)

        # de-DE, single male speaker, grapheme-based tokenizer, 22050 Hz, Thorsten Müller’s German Neutral-TTS Dataset, 22.10
        model = PretrainedModelInfo(
            pretrained_model_name="tts_de_fastpitch_singleSpeaker_thorstenNeutral_2210",
            location="https://api.ngc.nvidia.com/v2/models/nvidia/nemo/tts_de_fastpitchhifigan/versions/1.15.0/files/tts_de_fastpitch_thorstens2210.nemo",
            description="This model is trained on a single male speaker data in Thorsten Müller’s German Neutral 22.10 Dataset sampled at 22050Hz and can be used to generate male German voices.",
            class_=cls,
        )
        list_of_models.append(model)

        # de-DE, multi-speaker, 5 speakers, 44100 Hz, HUI-Audio-Corpus-German Clean.
        model = PretrainedModelInfo(
            pretrained_model_name="tts_de_fastpitch_multispeaker_5",
            location="https://api.ngc.nvidia.com/v2/models/nvidia/nemo/tts_de_fastpitch_multispeaker_5/versions/1.11.0/files/tts_de_fastpitch_multispeaker_5.nemo",
            description="This model is trained on 5 speakers in HUI-Audio-Corpus-German clean subset sampled at 44100Hz with and can be used to generate male and female German voices.",
            class_=cls,
        )
        list_of_models.append(model)

        # es, 174 speakers, 44100Hz, OpenSLR (IPA)
        model = PretrainedModelInfo(
            pretrained_model_name="tts_es_fastpitch_multispeaker",
            location="https://api.ngc.nvidia.com/v2/models/nvidia/nemo/tts_es_multispeaker_fastpitchhifigan/versions/1.15.0/files/tts_es_fastpitch_multispeaker.nemo",
            description="This model is trained on 174 speakers in 6 crowdsourced Latin American Spanish OpenSLR datasets sampled at 44100Hz and can be used to generate male and female Spanish voices with Latin American accents.",
            class_=cls,
        )
        list_of_models.append(model)

        # zh, single female speaker, 22050Hz, SFSpeech Bilingual Chinese/English dataset, improved model using richer
        # dict and jieba word segmenter for polyphone disambiguation.
        model = PretrainedModelInfo(
            pretrained_model_name="tts_zh_fastpitch_sfspeech",
            location="https://api.ngc.nvidia.com/v2/models/nvidia/nemo/tts_zh_fastpitch_hifigan_sfspeech/versions/1.15.0/files/tts_zh_fastpitch_sfspeech.nemo",
            description="This model is trained on a single female speaker in SFSpeech Bilingual Chinese/English dataset"
            " sampled at 22050Hz and can be used to generate female Mandarin Chinese voices. It is improved"
            " using richer dict and jieba word segmenter for polyphone disambiguation.",
            class_=cls,
        )
        list_of_models.append(model)

        # en, multi speaker, LibriTTS, 16000 Hz
        # stft 25ms 10ms matching ASR params
        # for use during English ASR training/adaptation
        model = PretrainedModelInfo(
            pretrained_model_name="tts_en_fastpitch_for_asr_finetuning",
            location="https://api.ngc.nvidia.com/v2/models/nvidia/nemo/tts_en_fastpitch_spectrogram_enhancer_for_asr_finetuning/versions/1.20.0/files/tts_en_fastpitch_for_asr_finetuning.nemo",
            description="This model is trained on LibriSpeech, train-960 subset."
            " STFT parameters follow those commonly used in ASR: 25 ms window, 10 ms hop."
            " This model is supposed to be used with its companion SpetrogramEnhancer for "
            " ASR fine-tuning. Usage for regular TTS tasks is not advised.",
            class_=cls,
        )
        list_of_models.append(model)

        return list_of_models

    # Methods for model exportability
    def _prepare_for_export(self, **kwargs):
        """
        Declare the neural types of the inputs and outputs used by ``export()``.

        Runs the parent class preparation first, then populates ``self._input_types`` and
        ``self._output_types`` according to the flags in ``self.export_config``.

        Args:
            **kwargs: passed through unchanged to the parent ``_prepare_for_export``.
        """
        super()._prepare_for_export(**kwargs)

        batch_axis = ('B')
        batch_time_axes = ('B', 'T')
        # With ragged batches the per-token inputs are described by a single 'T' axis.
        if self.export_config["enable_ragged_batches"]:
            token_axes = ('T')
        else:
            token_axes = batch_time_axes

        # Define input_types and output_types as required by export()
        self._input_types = dict(
            text=NeuralType(token_axes, TokenIndex()),
            pitch=NeuralType(token_axes, RegressionValuesType()),
            pace=NeuralType(token_axes),
            volume=NeuralType(token_axes, optional=True),
            batch_lengths=NeuralType(batch_axis, optional=True),
            speaker=NeuralType(batch_axis, Index(), optional=True),
        )
        self._output_types = dict(
            spect=NeuralType(('B', 'D', 'T'), MelSpectrogramType()),
            num_frames=NeuralType(batch_axis, TokenDurationType()),
            durs_predicted=NeuralType(batch_time_axes, TokenDurationType()),
            log_durs_predicted=NeuralType(batch_time_axes, TokenLogDurationType()),
            pitch_predicted=NeuralType(batch_time_axes, RegressionValuesType()),
        )
        # The aligned volume is only declared as an output when volume control is enabled.
        if self.export_config["enable_volume"]:
            self._output_types["volume_aligned"] = NeuralType(batch_time_axes, RegressionValuesType())

    def _export_teardown(self):
        """Clear the input/output type declarations by resetting both to ``None``."""
        self._input_types = None
        self._output_types = None

    @property
    def disabled_deployment_input_names(self):
        """
        Names of the inputs that must be left out of the exported model.

        Returns:
            A set containing ``"speaker"`` when the model has no speaker embedding, ``"batch_lengths"``
            when ragged batches are not enabled and ``"volume"`` when volume control is not enabled.
        """
        # Each input name is mapped to whether it is currently unavailable.
        unavailable = {
            "speaker": self.fastpitch.speaker_emb is None,
            "batch_lengths": not self.export_config["enable_ragged_batches"],
            "volume": not self.export_config["enable_volume"],
        }
        return {input_name for input_name, is_disabled in unavailable.items() if is_disabled}

    @property
    def input_types(self):
        """Input type declarations stored in ``self._input_types`` (``None`` once torn down)."""
        declared_inputs = self._input_types
        return declared_inputs

    @property
    def output_types(self):
        """Output type declarations stored in ``self._output_types`` (``None`` once torn down)."""
        declared_outputs = self._output_types
        return declared_outputs

    def input_example(self, max_batch=1, max_dim=44):
        """
        Build sample inputs for tracing etc.

        Args:
            max_batch: forwarded to ``sample_tts_input`` as ``max_batch``.
            max_dim: forwarded to ``sample_tts_input`` as ``max_dim``.

        Returns:
            A one-element tuple wrapping the sample inputs produced by ``sample_tts_input``.
        """
        # Create the samples on the device that holds the first FastPitch parameter.
        device = next(self.fastpitch.parameters()).device
        example = sample_tts_input(self.export_config, device, max_batch=max_batch, max_dim=max_dim)
        # NOTE(review): this checks whether the key exists, not whether the flag is truthy — confirm intent.
        ragged_key_missing = 'enable_ragged_batches' not in self.export_config
        if ragged_key_missing:
            example.pop('batch_lengths', None)
        return (example,)

    def forward_for_export(self, text, pitch, pace, volume=None, batch_lengths=None, speaker=None):
        """
        Forward pass used by the exported model.

        When ragged batches are enabled, ``text``, ``pitch``, ``pace`` (and ``volume`` if given) are first
        converted with ``batch_from_ragged`` using ``batch_lengths``; the result is then passed to
        ``self.fastpitch.infer``.

        Returns:
            Whatever ``self.fastpitch.infer`` returns.
        """
        ragged_enabled = self.export_config["enable_ragged_batches"]
        if ragged_enabled:
            pad_idx = self.fastpitch.encoder.padding_idx
            converted = batch_from_ragged(text, pitch, pace, batch_lengths, padding_idx=pad_idx, volume=volume)
            text, pitch, pace, batched_volume, _ = converted
            # Only replace the volume when the caller actually supplied one.
            if volume is not None:
                volume = batched_volume
        infer_kwargs = dict(text=text, pitch=pitch, pace=pace, volume=volume, speaker=speaker)
        return self.fastpitch.infer(**infer_kwargs)

    def interpolate_speaker(
        self, original_speaker_1, original_speaker_2, weight_speaker_1, weight_speaker_2, new_speaker_id
    ):
        """
        This method performs speaker interpolation between two original speakers the model is trained on.

        The embedding row ``new_speaker_id`` of ``self.fastpitch.speaker_emb`` is overwritten in place with
        ``weight_speaker_1 * emb(original_speaker_1) + weight_speaker_2 * emb(original_speaker_2)``.

        Inputs:
            original_speaker_1: Integer speaker ID of first existing speaker in the model
            original_speaker_2: Integer speaker ID of second existing speaker in the model
            weight_speaker_1: Floating point weight associated to the first speaker during weight combination
            weight_speaker_2: Floating point weight associated to the second speaker during weight combination
            new_speaker_id: Integer speaker ID of new interpolated speaker in the model

        Raises:
            RuntimeError: if the model has no speaker embedding (not a multi-speaker model).
            ValueError: if any of the speaker IDs is negative or not smaller than the number of speakers.
        """
        speaker_emb = self.fastpitch.speaker_emb
        if speaker_emb is None:
            raise RuntimeError(
                "Current FastPitch model is not a multi-speaker FastPitch model. Speaker interpolation can only "
                "be performed with a multi-speaker model"
            )

        n_speakers = speaker_emb.weight.size(0)
        speaker_ids = (original_speaker_1, original_speaker_2, new_speaker_id)
        # Negative IDs are rejected too: a negative index would otherwise silently address a row from the end.
        if any(speaker_id < 0 or speaker_id >= n_speakers for speaker_id in speaker_ids):
            raise ValueError(
                "Parameters original_speaker_1, original_speaker_2, new_speaker_id should be non-negative and "
                f"less than the total number of speakers FastPitch was trained on (n_speakers = {n_speakers})."
            )

        # Build the lookup indices on the embedding's own device instead of assuming CUDA is available/used.
        device = speaker_emb.weight.device
        speaker_emb_1 = speaker_emb(torch.tensor(original_speaker_1, dtype=torch.int32, device=device)).clone().detach()
        speaker_emb_2 = speaker_emb(torch.tensor(original_speaker_2, dtype=torch.int32, device=device)).clone().detach()

        # Both source embeddings are detached copies, so the write is safe even if new_speaker_id equals a source ID.
        new_speaker_emb = weight_speaker_1 * speaker_emb_1 + weight_speaker_2 * speaker_emb_2
        speaker_emb.weight.data[new_speaker_id] = new_speaker_emb
