# Copyright (c) 2022, NVIDIA CORPORATION.  All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import List, Optional

from lightning.pytorch import Trainer
from omegaconf import DictConfig

from nemo.collections.asr.models.ctc_bpe_models import EncDecCTCModelBPE
from nemo.collections.asr.models.ctc_models import EncDecCTCModel
from nemo.collections.asr.models.rnnt_bpe_models import EncDecRNNTBPEModel
from nemo.collections.asr.models.rnnt_models import EncDecRNNTModel
from nemo.collections.asr.parts.k2.classes import ASRK2Mixin
from nemo.core.classes.common import PretrainedModelInfo, typecheck
from nemo.utils import logging


class EncDecK2SeqModel(EncDecCTCModel, ASRK2Mixin):
    """Character-based CTC encoder-decoder model extended with the k2 mixin.

    Only the ``"ctc"`` and ``"mmi"`` values of ``cfg.graph_module_cfg.loss_type`` are accepted.
    """

    def __init__(self, cfg: DictConfig, trainer: Trainer = None):
        """
        Validate the configured loss type, build the parent model and initialize k2.

        Args:
            cfg: Model config. ``cfg.graph_module_cfg`` is read for ``loss_type``
                (``"ctc"`` is assumed when the key is absent).
            trainer: Optional trainer instance forwarded to the parent constructor.

        Raises:
            ValueError: If ``loss_type`` is neither ``"ctc"`` nor ``"mmi"``.
        """
        graph_cfg = cfg.graph_module_cfg
        loss_type = graph_cfg.get("loss_type", "ctc")
        # Reject unsupported losses before the (potentially expensive) parent construction.
        if loss_type not in ("ctc", "mmi"):
            raise ValueError(f"Class {self.__class__.__name__} does not support `loss_type`={loss_type}")
        super().__init__(cfg=cfg, trainer=trainer)
        self._init_k2()

    @classmethod
    def list_available_models(cls) -> Optional[List[PretrainedModelInfo]]:
        """
        Report the pre-trained models that can be instantiated for this class.

        Returns:
            ``None`` - this class does not list any pre-trained models.
        """
        return None

    def change_vocabulary(self, new_vocabulary: List[str]):
        """
        Replace the model vocabulary and rebuild the k2 modules accordingly.

        The vocabulary change itself is delegated to the parent class. Afterwards, if the model
        uses a graph LM, the stored ``token_lm`` is dropped and a warning asks the caller to provide
        a new one. Finally the k2 modules are refreshed from ``self.graph_module_cfg``.

        Args:
            new_vocabulary: List with the new vocabulary (typically the target alphabet).

        Returns: None

        """
        super().change_vocabulary(new_vocabulary)

        if self.use_graph_lm:
            # The stored token LM is discarded; per the warning below, the caller has to supply a new one.
            self.token_lm = None
            criterion = self.loss.criterion_type
            message = (
                f"With .change_vocabulary() call for a model with criterion_type=`{criterion}`, \n"
                "                a new token_lm has to be set manually: call .update_k2_modules(new_cfg) \n"
                "                or update .graph_module_cfg.backend_cfg.token_lm before calling this method."
            )
            logging.warning(message)

        self.update_k2_modules(self.graph_module_cfg)

    @typecheck()
    def forward(
        self,
        input_signal=None,
        input_signal_length=None,
        processed_signal=None,
        processed_signal_length=None,
    ):
        """
        Run the parent forward pass and hand its outputs to the k2 post-processing step.

        Args:
            input_signal: Batch of raw audio signals of shape [B, T], where one second of audio
                corresponds to `self.sample_rate` floating point values.
            input_signal_length: Vector of length B with the individual lengths of the audio sequences.
            processed_signal: Batch of already processed audio signals of shape (B, D, T),
                e.g. produced by a DALI preprocessor.
            processed_signal_length: Vector of length B with the individual lengths of the
                processed audio sequences.

        Returns:
            The result of `self._forward_k2_post_processing(...)` applied to the three parent outputs:
            1) The log probabilities tensor of shape [B, T, D].
            2) The lengths of the acoustic sequence after propagation through the encoder, of shape [B].
            3) The greedy token predictions of the model of shape [B, T] (via argmax).
            NOTE(review): presumably the post-processing keeps this 3-tuple layout - confirm in ASRK2Mixin.
        """
        parent_outputs = super().forward(
            input_signal=input_signal,
            input_signal_length=input_signal_length,
            processed_signal=processed_signal,
            processed_signal_length=processed_signal_length,
        )
        acoustic_log_probs, encoded_lengths, argmax_tokens = parent_outputs
        return self._forward_k2_post_processing(
            log_probs=acoustic_log_probs,
            encoded_length=encoded_lengths,
            greedy_predictions=argmax_tokens,
        )


class EncDecK2SeqModelBPE(EncDecCTCModelBPE, ASRK2Mixin):
    """Subword (BPE) CTC encoder-decoder model extended with the k2 mixin.

    Only the ``"ctc"`` and ``"mmi"`` values of ``cfg.graph_module_cfg.loss_type`` are accepted.
    """

    def __init__(self, cfg: DictConfig, trainer: Trainer = None):
        """
        Validate the configured loss type, build the parent model and initialize k2.

        Args:
            cfg: Model config. ``cfg.graph_module_cfg`` is read for ``loss_type``
                (``"ctc"`` is assumed when the key is absent).
            trainer: Optional trainer instance forwarded to the parent constructor.

        Raises:
            ValueError: If ``loss_type`` is neither ``"ctc"`` nor ``"mmi"``.
        """
        graph_cfg = cfg.graph_module_cfg
        loss_type = graph_cfg.get("loss_type", "ctc")
        # Reject unsupported losses before the (potentially expensive) parent construction.
        if loss_type not in ("ctc", "mmi"):
            raise ValueError(f"Class {self.__class__.__name__} does not support `loss_type`={loss_type}")
        super().__init__(cfg=cfg, trainer=trainer)
        self._init_k2()

    @classmethod
    def list_available_models(cls) -> Optional[List[PretrainedModelInfo]]:
        """
        Report the pre-trained models that can be instantiated for this class.

        Returns:
            ``None`` - this class does not list any pre-trained models.
        """
        return None

    def change_vocabulary(self, new_tokenizer_dir: str, new_tokenizer_type: str):
        """
        Swap the tokenizer (and thus the vocabulary) and rebuild the k2 modules accordingly.

        The tokenizer change itself is delegated to the parent class. Afterwards, if the model
        uses a graph LM, the stored ``token_lm`` is dropped and a warning asks the caller to provide
        a new one. Finally the k2 modules are refreshed from ``self.graph_module_cfg``.

        Args:
            new_tokenizer_dir: Path to the new tokenizer directory.
            new_tokenizer_type: Either `bpe` or `wpe`. `bpe` is used for SentencePiece tokenizers,
                whereas `wpe` is used for `BertTokenizer`.

        Returns: None

        """
        super().change_vocabulary(new_tokenizer_dir, new_tokenizer_type)

        if self.use_graph_lm:
            # The stored token LM is discarded; per the warning below, the caller has to supply a new one.
            self.token_lm = None
            criterion = self.loss.criterion_type
            message = (
                f"With .change_vocabulary() call for a model with criterion_type=`{criterion}`, \n"
                "                a new token_lm has to be set manually: call .update_k2_modules(new_cfg) \n"
                "                or update .graph_module_cfg.backend_cfg.token_lm before calling this method."
            )
            logging.warning(message)

        self.update_k2_modules(self.graph_module_cfg)

    @typecheck()
    def forward(
        self,
        input_signal=None,
        input_signal_length=None,
        processed_signal=None,
        processed_signal_length=None,
    ):
        """
        Run the parent forward pass and hand its outputs to the k2 post-processing step.

        Args:
            input_signal: Batch of raw audio signals of shape [B, T], where one second of audio
                corresponds to `self.sample_rate` floating point values.
            input_signal_length: Vector of length B with the individual lengths of the audio sequences.
            processed_signal: Batch of already processed audio signals of shape (B, D, T),
                e.g. produced by a DALI preprocessor.
            processed_signal_length: Vector of length B with the individual lengths of the
                processed audio sequences.

        Returns:
            The result of `self._forward_k2_post_processing(...)` applied to the three parent outputs:
            1) The log probabilities tensor of shape [B, T, D].
            2) The lengths of the acoustic sequence after propagation through the encoder, of shape [B].
            3) The greedy token predictions of the model of shape [B, T] (via argmax).
            NOTE(review): presumably the post-processing keeps this 3-tuple layout - confirm in ASRK2Mixin.
        """
        parent_outputs = super().forward(
            input_signal=input_signal,
            input_signal_length=input_signal_length,
            processed_signal=processed_signal,
            processed_signal_length=processed_signal_length,
        )
        acoustic_log_probs, encoded_lengths, argmax_tokens = parent_outputs
        return self._forward_k2_post_processing(
            log_probs=acoustic_log_probs,
            encoded_length=encoded_lengths,
            greedy_predictions=argmax_tokens,
        )


class EncDecK2RnntSeqModel(EncDecRNNTModel, ASRK2Mixin):
    """Character-based RNNT encoder-decoder model extended with the k2 mixin.

    Only ``loss_type="rnnt"`` together with ``criterion_type="ml"`` (both read from
    ``cfg.graph_module_cfg``) is accepted.
    """

    def __init__(self, cfg: DictConfig, trainer: Trainer = None):
        """
        Validate the configured loss/criterion pair, build the parent model and initialize k2.

        Args:
            cfg: Model config. ``cfg.graph_module_cfg`` is read for ``loss_type``
                (default ``"rnnt"``) and ``criterion_type`` (default ``"ml"``).
            trainer: Optional trainer instance forwarded to the parent constructor.

        Raises:
            ValueError: If ``loss_type`` is not ``"rnnt"`` or ``criterion_type`` is not ``"ml"``.
        """
        graph_cfg = cfg.graph_module_cfg
        loss_type = graph_cfg.get("loss_type", "rnnt")
        criterion_type = graph_cfg.get("criterion_type", "ml")
        # Reject unsupported combinations before the (potentially expensive) parent construction.
        if loss_type != "rnnt" or criterion_type != "ml":
            # The message text (including its embedded line break and indentation) is kept unchanged.
            raise ValueError(
                f"Class {self.__class__.__name__} does not support \n"
                f"            `criterion_type`={criterion_type} with `loss_type`={loss_type}"
            )
        super().__init__(cfg=cfg, trainer=trainer)
        self._init_k2()

    @classmethod
    def list_available_models(cls) -> Optional[List[PretrainedModelInfo]]:
        """
        Report the pre-trained models that can be instantiated for this class.

        The return annotation is a list of ``PretrainedModelInfo`` (or ``None``), matching the
        documented contract of this method.

        Returns:
            ``None`` - this class does not list any pre-trained models.
        """
        return None

    def change_vocabulary(self, new_vocabulary: List[str]):
        """
        Replace the model vocabulary and rebuild the k2 modules accordingly.

        Use this method when fine-tuning from a pre-trained model, e.g. on data in another language,
        or when the model has to learn capitalization, punctuation and/or special characters.
        The vocabulary change itself is delegated to the parent RNNT class. Afterwards, if the model
        uses a graph LM, the stored ``token_lm`` is dropped and a warning asks the caller to provide
        a new one. Finally the k2 modules are refreshed from ``self.graph_module_cfg``; this refresh
        happens on every call.

        Args:
            new_vocabulary: List with the new vocabulary (typically the target alphabet).

        Returns: None

        """
        super().change_vocabulary(new_vocabulary)

        if self.use_graph_lm:
            # The stored token LM is discarded; per the warning below, the caller has to supply a new one.
            self.token_lm = None
            criterion = self.loss.criterion_type
            logging.warning(
                f"With .change_vocabulary() call for a model with criterion_type=`{criterion}`, \n"
                "                a new token_lm has to be set manually: call .update_k2_modules(new_cfg) \n"
                "                or update .graph_module_cfg.backend_cfg.token_lm before calling this method."
            )

        self.update_k2_modules(self.graph_module_cfg)


class EncDecK2RnntSeqModelBPE(EncDecRNNTBPEModel, ASRK2Mixin):
    """Subword (BPE) RNNT encoder-decoder model extended with the k2 mixin.

    Only ``loss_type="rnnt"`` together with ``criterion_type="ml"`` (both read from
    ``cfg.graph_module_cfg``) is accepted.
    """

    def __init__(self, cfg: DictConfig, trainer: Trainer = None):
        """
        Validate the configured loss/criterion pair, build the parent model and initialize k2.

        Args:
            cfg: Model config. ``cfg.graph_module_cfg`` is read for ``loss_type``
                (default ``"rnnt"``) and ``criterion_type`` (default ``"ml"``).
            trainer: Optional trainer instance forwarded to the parent constructor.

        Raises:
            ValueError: If ``loss_type`` is not ``"rnnt"`` or ``criterion_type`` is not ``"ml"``.
        """
        graph_cfg = cfg.graph_module_cfg
        loss_type = graph_cfg.get("loss_type", "rnnt")
        criterion_type = graph_cfg.get("criterion_type", "ml")
        # Reject unsupported combinations before the (potentially expensive) parent construction.
        if loss_type != "rnnt" or criterion_type != "ml":
            # The message text (including its embedded line break and indentation) is kept unchanged.
            raise ValueError(
                f"Class {self.__class__.__name__} does not support \n"
                f"            `criterion_type`={criterion_type} with `loss_type`={loss_type}"
            )
        super().__init__(cfg=cfg, trainer=trainer)
        self._init_k2()

    @classmethod
    def list_available_models(cls) -> Optional[List[PretrainedModelInfo]]:
        """
        Report the pre-trained models that can be instantiated for this class.

        The return annotation is a list of ``PretrainedModelInfo`` (or ``None``), matching the
        documented contract of this method.

        Returns:
            ``None`` - this class does not list any pre-trained models.
        """
        return None

    def change_vocabulary(self, new_tokenizer_dir: str, new_tokenizer_type: str):
        """
        Swap the tokenizer (and thus the vocabulary) and rebuild the k2 modules accordingly.

        Use this method when fine-tuning from a pre-trained model, e.g. on data in another language,
        or when the model has to learn capitalization, punctuation and/or special characters.
        The tokenizer change itself is delegated to the parent RNNT class. Afterwards, if the model
        uses a graph LM, the stored ``token_lm`` is dropped and a warning asks the caller to provide
        a new one. Finally the k2 modules are refreshed from ``self.graph_module_cfg``; this refresh
        happens on every call.

        Args:
            new_tokenizer_dir: Path to the new tokenizer directory.
            new_tokenizer_type: Either `bpe` or `wpe`. `bpe` is used for SentencePiece tokenizers,
                whereas `wpe` is used for `BertTokenizer`.

        Returns: None

        """
        super().change_vocabulary(new_tokenizer_dir, new_tokenizer_type)

        if self.use_graph_lm:
            # The stored token LM is discarded; per the warning below, the caller has to supply a new one.
            self.token_lm = None
            criterion = self.loss.criterion_type
            logging.warning(
                f"With .change_vocabulary() call for a model with criterion_type=`{criterion}`, \n"
                "                a new token_lm has to be set manually: call .update_k2_modules(new_cfg) \n"
                "                or update .graph_module_cfg.backend_cfg.token_lm before calling this method."
            )

        self.update_k2_modules(self.graph_module_cfg)
