# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import logging
from pathlib import Path
from typing import Dict, List, NamedTuple, Optional

import torch

from fairseq.data import ConcatDataset, Dictionary, ResamplingDataset
from fairseq.data import data_utils as fairseq_data_utils
from fairseq.data.audio.speech_to_text_dataset import (
    S2TDataConfig,
    SpeechToTextDataset,
    SpeechToTextDatasetCreator,
)

logger = logging.getLogger(__name__)


class S2TJointDataConfig(S2TDataConfig):
    """Wrapper around the data config YAML that adds the source-text and
    language-tag options read by the joint dataset below."""

    @property
    def src_vocab_filename(self):
        """Name of the fairseq vocabulary file under the data root
        (``src_dict.txt`` when the config does not set one)."""
        filename = self.config.get("src_vocab_filename", "src_dict.txt")
        return filename

    @property
    def src_pre_tokenizer(self) -> Dict:
        """Settings of the pre-tokenizer that runs before subword
        tokenization. The `tokenizer` entry holds the tokenizer name and
        every other entry is a tokenizer-specific argument.
        Tokenizers are defined in `fairseq.data.encoders.*`"""
        fallback = {"tokenizer": None}
        return self.config.get("src_pre_tokenizer", fallback)

    @property
    def src_bpe_tokenizer(self) -> Dict:
        """Settings of the subword tokenizer applied to source text after
        pre-tokenization. The `bpe` entry holds the tokenizer name and
        every other entry is a tokenizer-specific argument.
        Tokenizers are defined in `fairseq.data.encoders.*`"""
        fallback = {"bpe": None}
        return self.config.get("src_bpe_tokenizer", fallback)

    @property
    def prepend_tgt_lang_tag_no_change(self) -> bool:
        """Whether the target lang ID token is used as the
        prev_output_tokens BOS (e.g. for to-many multilingual setting).
        No change needed during inference.

        This option is deprecated and replaced by
        prepend_tgt_lang_tag_as_bos: the deprecated key wins when it is
        present with a non-None value, otherwise the new key is read
        (default False).
        """
        legacy = self.config.get("prepend_tgt_lang_tag_no_change", None)
        if legacy is not None:
            return legacy
        return self.config.get("prepend_tgt_lang_tag_as_bos", False)

    @property
    def sampling_text_alpha(self):
        """Hyper-parameter alpha = 1/T for temperature-based resampling of
        text-only input (alpha = 1, the default, means no resampling)."""
        alpha = self.config.get("sampling_text_alpha", 1.0)
        return alpha


class SpeechToTextJointDatasetItem(NamedTuple):
    """One sample as returned by ``SpeechToTextJointDataset.__getitem__``.

    Only ``index`` and ``source`` are required; every other field defaults
    to ``None`` and stays ``None`` when the dataset does not produce it.
    """

    # index of the sample that was requested from the dataset
    index: int
    # `source` of the item produced by the parent speech-to-text dataset
    source: torch.Tensor
    # `target` of the item produced by the parent speech-to-text dataset
    target: Optional[torch.Tensor] = None
    # source text encoded with the source dictionary (eos appended)
    src_txt_tokens: Optional[torch.Tensor] = None
    # target language tag index, looked up in the target dictionary
    tgt_lang_tag: Optional[int] = None
    # source language tag index, looked up in the source dictionary
    src_lang_tag: Optional[int] = None
    # float tensor built from this sample's alignment values
    tgt_alignment: Optional[torch.Tensor] = None


# use_src_lang_id:
#   0: don't use src_lang_id
#   1: replace the eos token at the end of src_txt_tokens with src_lang_id
class SpeechToTextJointDataset(SpeechToTextDataset):
    """Speech-to-text dataset that additionally serves encoded source text.

    On top of what the parent dataset yields, each item may carry the
    encoded source text, source/target language tag indices and a
    per-sample alignment tensor; ``collater`` adds the batched versions of
    these to ``net_input``.

    ``use_src_lang_id`` controls the source language tag: ``0`` ignores it
    and ``1`` writes it over the last (eos) token of every source text
    sequence. When source text is collated, any other positive value raises
    ``NotImplementedError``.
    """

    def __init__(
        self,
        split: str,
        is_train_split: bool,
        cfg: S2TJointDataConfig,
        audio_paths: List[str],
        n_frames: List[int],
        src_texts: Optional[List[str]] = None,
        tgt_texts: Optional[List[str]] = None,
        speakers: Optional[List[str]] = None,
        src_langs: Optional[List[str]] = None,
        tgt_langs: Optional[List[str]] = None,
        ids: Optional[List[str]] = None,
        tgt_dict: Optional[Dictionary] = None,
        src_dict: Optional[Dictionary] = None,
        pre_tokenizer=None,
        bpe_tokenizer=None,
        src_pre_tokenizer=None,
        src_bpe_tokenizer=None,
        append_eos: Optional[bool] = True,
        alignment: Optional[List[str]] = None,
        use_src_lang_id: Optional[int] = 0,
    ):
        """Create the dataset.

        All arguments not listed here are forwarded unchanged to the parent
        constructor.

        Args:
            src_dict: dictionary used to encode the source text.
            src_pre_tokenizer: pre-tokenizer applied to the source text.
            src_bpe_tokenizer: subword tokenizer applied after
                ``src_pre_tokenizer``.
            alignment: one string per sample holding whitespace-separated
                numbers; each is parsed into a list of floats.
            use_src_lang_id: see the class docstring.
        """
        super().__init__(
            split,
            is_train_split,
            cfg,
            audio_paths,
            n_frames,
            src_texts=src_texts,
            tgt_texts=tgt_texts,
            speakers=speakers,
            src_langs=src_langs,
            tgt_langs=tgt_langs,
            ids=ids,
            tgt_dict=tgt_dict,
            pre_tokenizer=pre_tokenizer,
            bpe_tokenizer=bpe_tokenizer,
            append_eos=append_eos,
        )

        # Source-side text processing is kept separate from the target-side
        # tokenizers handled by the parent class.
        self.src_dict = src_dict
        self.src_pre_tokenizer = src_pre_tokenizer
        self.src_bpe_tokenizer = src_bpe_tokenizer
        self.use_src_lang_id = use_src_lang_id
        # Parse "0.1 0.2 ..." strings into lists of floats once, up front.
        self.alignment = (
            None
            if alignment is None
            else [[float(value) for value in line.split()] for line in alignment]
        )

    def get_tokenized_src_text(self, index: int):
        """Return the source text of sample ``index`` after running the
        source pre-tokenizer followed by the source subword tokenizer."""
        pre_tokenized = self.tokenize(self.src_pre_tokenizer, self.src_texts[index])
        return self.tokenize(self.src_bpe_tokenizer, pre_tokenized)

    def __getitem__(self, index: int) -> SpeechToTextJointDatasetItem:
        """Build the joint item for sample ``index``.

        ``source`` and ``target`` are taken from the parent item; the
        source text, language tags and alignment are filled in only when
        the corresponding data/option is available, otherwise they stay
        ``None``.
        """
        base_item = super().__getitem__(index)

        src_txt_tokens, src_lang_tag = None, None
        if self.src_texts is not None and self.src_dict is not None:
            src_txt_tokens = self.src_dict.encode_line(
                self.get_tokenized_src_text(index),
                add_if_not_exist=False,
                append_eos=True,
            ).long()
            if self.use_src_lang_id > 0:
                src_lang_tag = self.get_lang_tag_idx(
                    self.src_langs[index], self.src_dict
                )

        # prepend_tgt_lang_tag_no_change: the tag is carried on the item so
        # that the collater can modify prev_output_tokens instead.
        tgt_lang_tag = (
            self.get_lang_tag_idx(self.tgt_langs[index], self.tgt_dict)
            if self.cfg.prepend_tgt_lang_tag_no_change
            else None
        )

        tgt_alignment = (
            torch.Tensor(self.alignment[index]).float()
            if self.alignment is not None
            else None
        )

        return SpeechToTextJointDatasetItem(
            index=index,
            source=base_item.source,
            target=base_item.target,
            src_txt_tokens=src_txt_tokens,
            tgt_lang_tag=tgt_lang_tag,
            src_lang_tag=src_lang_tag,
            tgt_alignment=tgt_alignment,
        )

    def __len__(self):
        """Number of samples in the dataset."""
        return self.n_samples

    def collater(self, samples: List[SpeechToTextJointDatasetItem]) -> Dict:
        """Merge ``samples`` into a mini-batch.

        The parent collater builds the base batch and reports the row
        ``order`` it applied; everything added here is re-ordered with that
        same ``order``. ``net_input`` gains ``src_txt_tokens`` and
        ``src_txt_lengths`` (when source text is available) and
        ``alignment`` (``None`` when the dataset has no alignment).

        Returns:
            The parent's empty result unchanged when it returns ``{}``,
            otherwise a dict with ``id``, ``net_input``, ``target``,
            ``target_lengths``, ``ntokens`` and ``nsentences``.
        """
        base_batch = super().collater(samples, return_order=True)
        if base_batch == {}:
            return base_batch
        net_input = base_batch["net_input"]
        order = base_batch["order"]

        if self.src_texts is not None and self.src_dict is not None:
            token_seqs = [sample.src_txt_tokens for sample in samples]
            src_txt_tokens = fairseq_data_utils.collate_tokens(
                token_seqs,
                self.src_dict.pad(),
                self.src_dict.eos(),
                left_pad=False,
                move_eos_to_beginning=False,
            )
            src_txt_lengths = torch.tensor(
                [tokens.size()[0] for tokens in token_seqs], dtype=torch.long
            )
            if self.use_src_lang_id > 0:
                lang_ids = torch.tensor(
                    [sample.src_lang_tag for sample in samples],
                    dtype=src_txt_tokens.dtype,
                )
                if self.use_src_lang_id != 1:
                    raise NotImplementedError("Implementation is required")
                # Mode 1: replace eos with lang_id. The eos sits at
                # position length - 1 of each (right-padded) row.
                last_positions = src_txt_lengths - 1
                src_txt_tokens.scatter_(
                    1, last_positions.view(-1, 1), lang_ids.view(-1, 1)
                )

            net_input["src_txt_tokens"] = src_txt_tokens.index_select(0, order)
            net_input["src_txt_lengths"] = src_txt_lengths.index_select(0, order)

        batched_alignment = None
        if self.alignment is not None:
            ali_lengths = [sample.tgt_alignment.size(0) for sample in samples]
            # Rows shorter than the longest alignment keep 1.0 as filler.
            padded = torch.ones(len(samples), max(ali_lengths)).float()
            for row, (sample, length) in enumerate(zip(samples, ali_lengths)):
                padded[row][:length].copy_(sample.tgt_alignment)
            batched_alignment = padded.index_select(0, order)
        net_input["alignment"] = batched_alignment

        if self.tgt_texts is not None and samples[0].tgt_lang_tag is not None:
            # Rows of the batch are already re-ordered, so row i belongs to
            # samples[order[i]]; its first decoder input becomes the tag.
            prev_output_tokens = net_input["prev_output_tokens"]
            for row in range(len(samples)):
                prev_output_tokens[row][0] = samples[order[row]].tgt_lang_tag

        return {
            "id": base_batch["id"],
            "net_input": net_input,
            "target": base_batch["target"],
            "target_lengths": base_batch["target_lengths"],
            "ntokens": base_batch["ntokens"],
            "nsentences": len(samples),
        }


class SpeechToTextJointDatasetCreator(SpeechToTextDatasetCreator):
    """Factory that builds ``SpeechToTextJointDataset`` objects from lists
    of per-sample dicts or from TSV splits."""

    # optional per-sample key holding the alignment string
    KEY_ALIGN = "align"

    @classmethod
    def _from_list(
        cls,
        split_name: str,
        is_train_split,
        samples: List[Dict],
        cfg: S2TJointDataConfig,
        tgt_dict,
        src_dict,
        pre_tokenizer,
        bpe_tokenizer,
        src_pre_tokenizer,
        src_bpe_tokenizer,
        append_eos,
        use_src_lang_id,
    ) -> SpeechToTextJointDataset:
        """Build one dataset from a list of per-sample dicts.

        The id, audio, n_frames and target-text keys are required in every
        sample; source text, speaker and languages fall back to the class
        defaults. Alignment is read only when the first sample has the
        ``KEY_ALIGN`` key, and is then required in every sample.
        """
        audio_root = Path(cfg.audio_root)

        # Required fields (a missing key raises KeyError).
        sample_ids = [row[cls.KEY_ID] for row in samples]
        audio_paths = [
            (audio_root / row[cls.KEY_AUDIO]).as_posix() for row in samples
        ]
        frame_counts = [int(row[cls.KEY_N_FRAMES]) for row in samples]
        tgt_texts = [row[cls.KEY_TGT_TEXT] for row in samples]

        # Optional fields with class-level defaults.
        src_texts = [
            row.get(cls.KEY_SRC_TEXT, cls.DEFAULT_SRC_TEXT) for row in samples
        ]
        speakers = [row.get(cls.KEY_SPEAKER, cls.DEFAULT_SPEAKER) for row in samples]
        src_langs = [row.get(cls.KEY_SRC_LANG, cls.DEFAULT_LANG) for row in samples]
        tgt_langs = [row.get(cls.KEY_TGT_LANG, cls.DEFAULT_LANG) for row in samples]

        # The first sample decides whether alignment data is present.
        if cls.KEY_ALIGN in samples[0]:
            alignment = [row[cls.KEY_ALIGN] for row in samples]
        else:
            alignment = None

        return SpeechToTextJointDataset(
            split=split_name,
            is_train_split=is_train_split,
            cfg=cfg,
            audio_paths=audio_paths,
            n_frames=frame_counts,
            src_texts=src_texts,
            tgt_texts=tgt_texts,
            speakers=speakers,
            src_langs=src_langs,
            tgt_langs=tgt_langs,
            ids=sample_ids,
            tgt_dict=tgt_dict,
            src_dict=src_dict,
            pre_tokenizer=pre_tokenizer,
            bpe_tokenizer=bpe_tokenizer,
            src_pre_tokenizer=src_pre_tokenizer,
            src_bpe_tokenizer=src_bpe_tokenizer,
            append_eos=append_eos,
            alignment=alignment,
            use_src_lang_id=use_src_lang_id,
        )

    @classmethod
    def _from_tsv(
        cls,
        root: str,
        cfg: S2TJointDataConfig,
        split: str,
        tgt_dict,
        src_dict,
        is_train_split: bool,
        pre_tokenizer,
        bpe_tokenizer,
        src_pre_tokenizer,
        src_bpe_tokenizer,
        append_eos: bool,
        use_src_lang_id: int,
    ) -> SpeechToTextJointDataset:
        """Load the samples of a single ``split`` under ``root`` and turn
        them into a dataset via ``_from_list``."""
        rows = cls._load_samples_from_tsv(root, split)
        return cls._from_list(
            split,
            is_train_split,
            rows,
            cfg,
            tgt_dict,
            src_dict,
            pre_tokenizer,
            bpe_tokenizer,
            src_pre_tokenizer,
            src_bpe_tokenizer,
            append_eos,
            use_src_lang_id,
        )

    @classmethod
    def from_tsv(
        cls,
        root: str,
        cfg: S2TJointDataConfig,
        splits: str,
        tgt_dict,
        src_dict,
        pre_tokenizer,
        bpe_tokenizer,
        src_pre_tokenizer,
        src_bpe_tokenizer,
        is_train_split: bool,
        epoch: int,
        seed: int,
        append_eos: Optional[bool] = True,
        use_src_lang_id: Optional[int] = 0,
    ) -> SpeechToTextJointDataset:
        """Build the dataset for one or more comma-separated ``splits``.

        One dataset is created per split. For a training split with more
        than one dataset and ``cfg.sampling_alpha != 1.0``, each dataset is
        wrapped in a ``ResamplingDataset`` (``epoch`` and ``seed`` are
        passed on to it).

        Returns:
            The single dataset when only one split was given, otherwise a
            ``ConcatDataset`` over all of them.
        """

        def load_split(name):
            return cls._from_tsv(
                root,
                cfg,
                name,
                tgt_dict,
                src_dict,
                is_train_split,
                pre_tokenizer,
                bpe_tokenizer,
                src_pre_tokenizer,
                src_bpe_tokenizer,
                append_eos=append_eos,
                use_src_lang_id=use_src_lang_id,
            )

        datasets = [load_split(name) for name in splits.split(",")]

        if is_train_split and len(datasets) > 1:
            if cfg.sampling_alpha != 1.0:
                # temperature-based sampling
                size_ratios = cls.get_size_ratios(datasets, alpha=cfg.sampling_alpha)
                resampled = []
                for ratio, dataset in zip(size_ratios, datasets):
                    resampled.append(
                        ResamplingDataset(
                            dataset,
                            size_ratio=ratio,
                            seed=seed,
                            epoch=epoch,
                            replace=(ratio >= 1.0),
                        )
                    )
                datasets = resampled

        if len(datasets) > 1:
            return ConcatDataset(datasets)
        return datasets[0]
