# Copyright (c) 2022, NVIDIA CORPORATION.  All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import itertools
import os

import numpy as np
import torch

from nemo.collections.common.tokenizers.tokenizer_spec import TokenizerSpec
from nemo.collections.nlp.data.language_modeling.megatron.dataset_utils import (
    get_indexed_dataset_,
    get_samples_mapping,
)
from nemo.collections.nlp.data.language_modeling.text_memmap_dataset import TextMemMapDataset
from nemo.core.classes import Dataset
from nemo.utils import logging

# Public dataset classes. IndexedSequenceToSequenceDataset is an abstract helper base and is not exported.
__all__ = [
    'SequenceToSequenceDataset',
    'TextMemmapSequenceToSequenceDataset',
    'BinarizedMemmapSequenceToSequenceDataset',
]


class SequenceToSequenceDataset(Dataset):
    """Sequence to Sequence Dataset held entirely in memory.

    Reads a pair of line-aligned text files, tokenizes every line up front and stores the resulting
    token-id lists in ``self.examples``. Each item yields the encoder input, the decoder input (target
    without its last token) and the labels (target without its first token).

    ``__init__`` calls the ``_check_files_exist`` and ``_get_examples`` hooks, which subclasses override.
    """

    def __init__(
        self,
        src_file_name: str,
        tgt_file_name: str,
        src_tokenizer: TokenizerSpec,
        tgt_tokenizer: TokenizerSpec,
        max_src_seq_length: int,
        max_tgt_seq_length: int,
        add_bos_to_input: bool = True,
        add_eos_to_input: bool = True,
        replace_bos_with_pad: bool = False,
    ):
        """
        src_file_name: Path to the source text file (one example per line).
        tgt_file_name: Path to the target text file, line-aligned with src_file_name.
        src_tokenizer: Tokenizer for the source side. Instance of a class that inherits TokenizerSpec.
        tgt_tokenizer: Tokenizer for the target side. Instance of a class that inherits TokenizerSpec.
        max_src_seq_length: Maximum source length (must be > 0); longer sequences are truncated.
        max_tgt_seq_length: Maximum target length (must be > 0); longer sequences are truncated.
        add_bos_to_input: Prepend a BOS token (PAD if replace_bos_with_pad) to the source.
        add_eos_to_input: Append an EOS token to the source.
        replace_bos_with_pad: Use the tokenizer's pad id instead of its bos id wherever a BOS token is added
            (both source and target).
        """
        super().__init__()
        self.src_file_name = src_file_name
        self.tgt_file_name = tgt_file_name
        self.src_tokenizer = src_tokenizer
        self.tgt_tokenizer = tgt_tokenizer
        self.max_src_seq_length = max_src_seq_length
        self.max_tgt_seq_length = max_tgt_seq_length
        self.add_bos_to_input = add_bos_to_input
        self.add_eos_to_input = add_eos_to_input
        self.replace_bos_with_pad = replace_bos_with_pad
        assert self.max_src_seq_length > 0
        assert self.max_tgt_seq_length > 0
        self._check_files_exist()
        self._get_examples()

    def _check_files_exist(self):
        """Raise FileNotFoundError if either the source or the target file is missing."""
        if not os.path.exists(self.src_file_name):
            raise FileNotFoundError(f"Source file {self.src_file_name} not found")
        if not os.path.exists(self.tgt_file_name):
            raise FileNotFoundError(f"Target file {self.tgt_file_name} not found")

    def __len__(self):
        return len(self.examples)

    def __getitem__(self, idx):
        """Return encoder input, decoder input (``tgt[:-1]``) and labels (``tgt[1:]``) for example ``idx``."""
        example = self.examples[idx]
        text_enc = example['src']
        text_dec = example['tgt'][:-1]
        labels = example['tgt'][1:]
        return {'text_enc': text_enc, 'text_dec': text_dec, 'labels': labels}

    def _get_examples(self):
        """Tokenize the aligned source/target files into ``self.examples``.

        Each example is ``{'src': [...], 'tgt': [...]}``. The source optionally gets a BOS (or PAD when
        ``replace_bos_with_pad``) and an EOS; the target always gets a BOS (or PAD) and an EOS.
        """
        self.examples = []
        with open(self.src_file_name, encoding='utf8') as f_src, open(self.tgt_file_name, encoding='utf8') as f_tgt:
            # zip_longest (rather than zip) so a line-count mismatch is reported instead of the longer
            # file's tail being dropped silently.
            for i, (src, tgt) in enumerate(itertools.zip_longest(f_src, f_tgt)):
                if src is None or tgt is None:
                    logging.warning(
                        f"{self.src_file_name} and {self.tgt_file_name} have a different number of lines; "
                        f"only the first {i} aligned line pairs were used."
                    )
                    break
                if i % 10000 == 0 and i != 0:
                    logging.info(f"Read {i} lines from {self.src_file_name} & {self.tgt_file_name}")
                src = self.src_tokenizer.text_to_ids(src.strip())
                if self.add_bos_to_input:
                    src = [self.src_tokenizer.pad_id if self.replace_bos_with_pad else self.src_tokenizer.bos_id] + src
                if self.add_eos_to_input:
                    src = src + [self.src_tokenizer.eos_id]

                tgt = (
                    [self.tgt_tokenizer.pad_id if self.replace_bos_with_pad else self.tgt_tokenizer.bos_id]
                    + self.tgt_tokenizer.text_to_ids(tgt.strip())
                    + [self.tgt_tokenizer.eos_id]
                )
                # Truncate over-long sequences by keeping only their last (max_len - 1) tokens, so the
                # leading tokens (including any BOS) are dropped. The start index is computed as a positive
                # offset: the negative-slice form `seq[-max_len + 1:]` is `seq[0:]` (no truncation) for max_len == 1.
                if len(src) > self.max_src_seq_length:
                    src = src[len(src) - self.max_src_seq_length + 1 :]
                if len(tgt) > self.max_tgt_seq_length:
                    tgt = tgt[len(tgt) - self.max_tgt_seq_length + 1 :]
                self.examples.append({'src': src, 'tgt': tgt})

        logging.info(f'Dataset Length : {len(self.examples)}')

    def collate_fn(self, batch):
        """Right-pad a list of ``__getitem__`` outputs into LongTensors.

        Returns a dict with ``text_enc``, ``text_dec`` and ``labels`` (padded with the respective
        tokenizer's pad id), ``loss_mask`` (1 on real label positions, 0 on padding) and
        ``enc_mask`` / ``dec_mask`` (1 where the token differs from the pad id).
        """

        def _as_lists(seqs):
            # Items may be Python lists or NumPy arrays (the indexed subclasses return arrays). Convert each
            # item individually so empty and mixed batches work too.
            return [x.tolist() if isinstance(x, np.ndarray) else x for x in seqs]

        text_enc = _as_lists(item['text_enc'] for item in batch)
        text_dec = _as_lists(item['text_dec'] for item in batch)
        labels = _as_lists(item['labels'] for item in batch)

        max_dec_input_length = max((len(item) for item in text_dec), default=0)
        max_enc_input_length = max((len(item) for item in text_enc), default=0)
        max_label_length = max((len(item) for item in labels), default=0)

        loss_mask = [[1] * len(item) + [0] * (max_label_length - len(item)) for item in labels]
        text_enc = [item + [self.src_tokenizer.pad_id] * (max_enc_input_length - len(item)) for item in text_enc]
        text_dec = [item + [self.tgt_tokenizer.pad_id] * (max_dec_input_length - len(item)) for item in text_dec]
        labels = [item + [self.tgt_tokenizer.pad_id] * (max_label_length - len(item)) for item in labels]

        text_enc = torch.LongTensor(text_enc)
        text_dec = torch.LongTensor(text_dec)
        labels = torch.LongTensor(labels)
        loss_mask = torch.LongTensor(loss_mask)

        # NOTE(review): masks are derived from the pad id, so when replace_bos_with_pad=True the leading
        # pad-as-BOS token is masked out (0) as well — confirm this is intended.
        enc_mask = (text_enc != self.src_tokenizer.pad_id).long()
        dec_mask = (text_dec != self.tgt_tokenizer.pad_id).long()

        return {
            'text_enc': text_enc,
            'text_dec': text_dec,
            'labels': labels,
            'loss_mask': loss_mask,
            'enc_mask': enc_mask,
            'dec_mask': dec_mask,
        }


class IndexedSequenceToSequenceDataset(SequenceToSequenceDataset):
    """Abstract class for TextMemmapSequenceToSequenceDataset and BinarizedMemmapSequenceToSequenceDataset.
    This class is not meant to be used standalone and just as an abstract class for the two subclasses.

    Subclasses override ``_get_examples`` to populate ``self.src_indexed_dataset`` and
    ``self.tgt_indexed_dataset`` and then call ``_build_samples_mapping``. This class supplies the shared
    length, sampling, truncation and BOS/EOS handling on top of them.
    """

    def __init__(
        self,
        src_file_name: str,
        tgt_file_name: str,
        src_tokenizer: TokenizerSpec,
        tgt_tokenizer: TokenizerSpec,
        max_src_seq_length: int,
        max_tgt_seq_length: int,
        seed: int = 1234,
        add_bos_to_enc: bool = True,
        add_eos_to_enc: bool = True,
        max_num_samples: int = None,
        prepend_id: int = None,
    ):
        """
        src_file_name: Path to a single source file on disk. This is either the path to a raw text file or the prefix to the processed src_file_name.bin/idx files.
        tgt_file_name: Path to a single target file on disk. This is either the path to a raw text file or the prefix to the processed tgt_file_name.bin/idx files.
        src_tokenizer: Tokenizer for the source dataset. Instance of a class that inherits TokenizerSpec (ex: SentencePiece).
        tgt_tokenizer: Tokenizer for the target dataset. Instance of a class that inherits TokenizerSpec (ex: SentencePiece).
        max_src_seq_length: Maximum length of the source sequences. Lines above this length will be truncated.
        max_tgt_seq_length: Maximum length of the target sequences. Lines above this length will be truncated.
        seed: Random seed for data shuffling.
        add_bos_to_enc: Add the source tokenizer's BOS token to the encoder input.
        add_eos_to_enc: Add the source tokenizer's EOS token to the encoder input.
        max_num_samples: Maximum number of samples to load. This can be > dataset length if you want to oversample data. If None, all samples will be loaded.
        prepend_id: If not None, prepend this id to the encoder input.

        Note: the parent constructor calls ``_get_examples`` (and therefore ``_build_samples_mapping``, which
        reads ``self.seed`` and ``self.max_num_samples``) before this constructor assigns them, so subclasses
        set those two attributes themselves before calling ``super().__init__``.
        """
        super().__init__(
            src_file_name=src_file_name,
            tgt_file_name=tgt_file_name,
            src_tokenizer=src_tokenizer,
            tgt_tokenizer=tgt_tokenizer,
            max_src_seq_length=max_src_seq_length,
            max_tgt_seq_length=max_tgt_seq_length,
        )
        self.seed = seed
        self.max_num_samples = max_num_samples
        self.add_bos_to_enc = add_bos_to_enc
        self.add_eos_to_enc = add_eos_to_enc
        self.prepend_id = prepend_id

        logging.info(f'Desired number of samples : {self.max_num_samples}')
        logging.info(f'Source Dataset Length : {len(self.src_indexed_dataset)}')
        logging.info(f'Target Dataset Length : {len(self.tgt_indexed_dataset)}')

    def __len__(self):
        """Number of source lines, or ``max_num_samples`` when oversampling/subsampling is requested."""
        if self.max_num_samples is None:
            return len(self.src_indexed_dataset)
        else:
            return self.max_num_samples

    def _get_sample(self, idx):
        """Resolve ``idx`` (through ``samples_mapping`` if present) and return the raw ``(src, tgt)`` pair."""
        # Normalize any NumPy integer scalar (from a sampler or the mapping array) to a Python int.
        if isinstance(idx, np.integer):
            idx = idx.item()

        if self.samples_mapping is not None:
            assert idx < len(self.samples_mapping)
            # Only the first column of a mapping row is used as the underlying line index.
            idx, _, _ = self.samples_mapping[idx]
            if isinstance(idx, np.integer):
                idx = idx.item()

        assert idx < len(self.src_indexed_dataset)
        src = self.src_indexed_dataset[idx]
        tgt = self.tgt_indexed_dataset[idx]

        return src, tgt

    def __getitem__(self, idx):
        """Return ``text_enc``, ``text_dec`` and ``labels`` NumPy arrays for sample ``idx``.

        The encoder input is ``[prepend_id] + [bos] + src + [eos]`` (each special token optional), with ``src``
        truncated so the total fits in ``max_src_seq_length``. The target is truncated to
        ``max_tgt_seq_length - 2`` tokens; the decoder input is ``[bos] + tgt`` and the labels are ``tgt + [eos]``.
        """
        src, tgt = self._get_sample(idx)

        # Count the special tokens that will be added around the source.
        offset = 0
        if self.add_bos_to_enc:
            offset += 1
        if self.add_eos_to_enc:
            offset += 1
        if self.prepend_id is not None:
            offset += 1

        # Clamp at 0: a negative bound in `src[:budget]` would drop tokens from the end instead of truncating.
        src_budget = max(self.max_src_seq_length - offset, 0)
        if len(src) > src_budget:
            src = src[:src_budget]

        if self.add_bos_to_enc:
            src = np.concatenate([[self.src_tokenizer.bos_id], src])

        if self.prepend_id is not None:
            src = np.concatenate([[self.prepend_id], src])

        if self.add_eos_to_enc:
            src = np.concatenate([src, [self.src_tokenizer.eos_id]])

        tgt_budget = max(self.max_tgt_seq_length - 2, 0)
        if len(tgt) > tgt_budget:
            tgt = tgt[:tgt_budget]

        text_dec = np.concatenate([[self.tgt_tokenizer.bos_id], tgt])
        labels = np.concatenate([tgt, [self.tgt_tokenizer.eos_id]])

        return {'text_enc': src, 'text_dec': text_dec, 'labels': labels}

    def _build_samples_mapping(self):
        """Build ``self.samples_mapping`` via ``get_samples_mapping`` when ``max_num_samples`` is set, else None.

        Raises:
            ValueError: if ``max_num_samples`` is set and the source and target max lengths differ.
        """
        if self.max_num_samples is not None:
            # This means max src and max tgt sequence length need to be the same
            if self.max_src_seq_length != self.max_tgt_seq_length:
                raise ValueError(
                    f"max_src_seq_length ({self.max_src_seq_length}) != max_tgt_seq_length ({self.max_tgt_seq_length}). This is needed for max_samples based training for now."
                )

            self.samples_mapping = get_samples_mapping(
                indexed_dataset=self.src_indexed_dataset,
                data_prefix=self.src_file_name,
                num_epochs=None,
                max_num_samples=self.max_num_samples,
                max_seq_length=self.max_src_seq_length - 2,
                short_seq_prob=0,
                seed=self.seed,
                name=os.path.basename(self.src_file_name),
                binary_head=False,
            )
        else:
            self.samples_mapping = None


class TextMemmapSequenceToSequenceDataset(IndexedSequenceToSequenceDataset):
    """Sequence-to-sequence dataset backed by memory-mapped raw text files.

    Lines are tokenized lazily, on access, by ``TextMemMapDataset``; no pre-binarization is required.
    """

    def __init__(
        self,
        src_file_name: str,
        tgt_file_name: str,
        src_tokenizer: TokenizerSpec,
        tgt_tokenizer: TokenizerSpec,
        max_src_seq_length: int,
        max_tgt_seq_length: int,
        seed: int = 1234,
        max_num_samples: int = None,
        add_bos_to_enc: bool = True,
        add_eos_to_enc: bool = True,
        prepend_id: int = None,
    ):
        """
        src_file_name: Raw text file with one source sentence per line.
        tgt_file_name: Raw text file with one target sentence per line, aligned line-by-line with src_file_name.
        src_tokenizer: TokenizerSpec subclass instance used to tokenize source lines (ex: SentencePiece).
        tgt_tokenizer: TokenizerSpec subclass instance used to tokenize target lines (ex: SentencePiece).
        max_src_seq_length: Upper bound on source length; longer lines are truncated.
        max_tgt_seq_length: Upper bound on target length; longer lines are truncated.
        seed: Seed used when shuffling data for the samples mapping.
        max_num_samples: Number of samples to expose; may exceed the number of lines to oversample. None means use every line once.
        add_bos_to_enc: Whether to put a BOS token at the start of the encoder input.
        add_eos_to_enc: Whether to put an EOS token at the end of the encoder input.
        prepend_id: Optional token id placed in front of the encoder input.
        """
        # The parent constructor ends up calling _build_samples_mapping() (through _get_examples), which
        # reads these two attributes, so they must be set before super().__init__ runs.
        self.max_num_samples = max_num_samples
        self.seed = seed
        super().__init__(
            src_file_name=src_file_name,
            tgt_file_name=tgt_file_name,
            src_tokenizer=src_tokenizer,
            tgt_tokenizer=tgt_tokenizer,
            max_src_seq_length=max_src_seq_length,
            max_tgt_seq_length=max_tgt_seq_length,
            seed=seed,
            max_num_samples=max_num_samples,
            add_bos_to_enc=add_bos_to_enc,
            add_eos_to_enc=add_eos_to_enc,
            prepend_id=prepend_id,
        )

    def _get_examples(self):
        """Open both text files as TextMemMapDatasets, check they align, then build the samples mapping."""

        def _open_memmap(path, tokenizer):
            # Single-file memmap with no header lines to skip.
            return TextMemMapDataset(dataset_paths=[path], tokenizer=tokenizer, header_lines=0)

        self.src_indexed_dataset = _open_memmap(self.src_file_name, self.src_tokenizer)
        self.tgt_indexed_dataset = _open_memmap(self.tgt_file_name, self.tgt_tokenizer)

        assert len(self.src_indexed_dataset) == len(self.tgt_indexed_dataset), "src and tgt has different number of lines"
        self._build_samples_mapping()


class BinarizedMemmapSequenceToSequenceDataset(IndexedSequenceToSequenceDataset):
    """Memory-mapped sequence to sequence dataset. Operates on pre-tokenized, binarized (.bin/.idx) data files."""

    def __init__(
        self,
        src_dataset_prefix: str,
        tgt_dataset_prefix: str,
        src_tokenizer: TokenizerSpec,
        tgt_tokenizer: TokenizerSpec,
        max_src_seq_length: int,
        max_tgt_seq_length: int,
        seed: int = 1234,
        max_num_samples: int = None,
        add_bos_to_enc: bool = True,
        add_eos_to_enc: bool = True,
        prepend_id: int = None,
    ):
        """
        src_dataset_prefix: Path to the *prefix* of a single source bin/idx file pair on disk. Both src_dataset_prefix.bin and src_dataset_prefix.idx must exist.
        tgt_dataset_prefix: Path to the *prefix* of a single target bin/idx file pair on disk, aligned with the source. Both tgt_dataset_prefix.bin and tgt_dataset_prefix.idx must exist.
        src_tokenizer: Tokenizer for the source dataset. Instance of a class that inherits TokenizerSpec (ex: SentencePiece).
        tgt_tokenizer: Tokenizer for the target dataset. Instance of a class that inherits TokenizerSpec (ex: SentencePiece).
        max_src_seq_length: Maximum length of the source sequences. Lines above this length will be truncated.
        max_tgt_seq_length: Maximum length of the target sequences. Lines above this length will be truncated.
        seed: Random seed for data shuffling.
        max_num_samples: Maximum number of samples to load. This can be > dataset length if you want to oversample data. If None, all samples will be loaded.
        add_bos_to_enc: Add BOS token to the encoder input.
        add_eos_to_enc: Add EOS token to the encoder input.
        prepend_id: If not None, prepend this id to the encoder input.
        """
        # These are read by _check_files_exist / _get_examples / _build_samples_mapping, which the parent
        # constructor invokes, so they must be set before super().__init__.
        self.src_dataset_prefix = src_dataset_prefix
        self.tgt_dataset_prefix = tgt_dataset_prefix
        self.seed = seed
        self.max_num_samples = max_num_samples
        super().__init__(
            src_file_name=src_dataset_prefix,
            tgt_file_name=tgt_dataset_prefix,
            src_tokenizer=src_tokenizer,
            tgt_tokenizer=tgt_tokenizer,
            max_src_seq_length=max_src_seq_length,
            max_tgt_seq_length=max_tgt_seq_length,
            seed=seed,
            max_num_samples=max_num_samples,
            add_bos_to_enc=add_bos_to_enc,
            add_eos_to_enc=add_eos_to_enc,
            prepend_id=prepend_id,
        )

    def _check_files_exist(self):
        """Raise FileNotFoundError unless both <prefix>.bin and <prefix>.idx exist for source and target."""
        for prefix in (self.src_dataset_prefix, self.tgt_dataset_prefix):
            if not (os.path.exists(prefix + ".bin") and os.path.exists(prefix + ".idx")):
                raise FileNotFoundError(f"{prefix}.bin or {prefix}.idx not found")

    def _get_examples(self):
        """Load both mmap indexed datasets, check they are aligned, then build the samples mapping."""
        self.src_indexed_dataset = self._get_indexed_dataset(
            self.src_dataset_prefix, data_impl='mmap', skip_warmup=True
        )
        self.tgt_indexed_dataset = self._get_indexed_dataset(
            self.tgt_dataset_prefix, data_impl='mmap', skip_warmup=True
        )
        assert len(self.src_indexed_dataset) == len(self.tgt_indexed_dataset), (
            f"src ({len(self.src_indexed_dataset)}) and tgt ({len(self.tgt_indexed_dataset)}) "
            f"have a different number of samples"
        )
        self._build_samples_mapping()

    def _get_indexed_dataset(self, data_prefix, data_impl, skip_warmup):
        """Thin wrapper around ``get_indexed_dataset_`` for the given prefix and implementation."""
        indexed_dataset = get_indexed_dataset_(data_prefix, data_impl, skip_warmup)
        return indexed_dataset
