# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import datetime
import hashlib
import logging
import time
from bisect import bisect_right
from collections import OrderedDict, defaultdict
from enum import Enum
from typing import List

import numpy as np
import torch

from fairseq.data import FairseqDataset, data_utils
from fairseq.distributed import utils as distributed_utils


def get_time_gap(s, e):
    """Return the elapsed time between two timestamps as a readable string.

    Args:
        s (float): start time in seconds (e.g. a value from ``time.time()``).
        e (float): end time in seconds, on the same clock as ``s``.

    Returns:
        str: ``str()`` of the ``datetime.timedelta`` spanning ``e - s``,
        e.g. ``"1:01:01"`` or ``"0:00:00.500000"``.
    """
    # Build the timedelta from the raw difference instead of subtracting two
    # naive local datetimes: the latter is off by the UTC-offset change when
    # a DST transition falls between ``s`` and ``e``.
    return str(datetime.timedelta(seconds=e - s))


logger = logging.getLogger(__name__)


def default_virtual_size_func(datasets, ratios, max_scale_up=1.5):
    """Compute the default virtual size of a sampled multi-dataset.

    Args:
        datasets: sequence of sized objects (``len()`` is taken of each).
        ratios: per-dataset sampling ratios aligned with ``datasets``, or
            ``None`` to simply concatenate the datasets.
        max_scale_up (float): upper bound on the result, expressed as a
            multiple of the total real size (default: 1.5).

    Returns:
        int: the total real size when ``ratios`` is ``None``; otherwise the
        ratio-scaled total (truncated to int), capped at
        ``max_scale_up * total real size``.
    """
    real_sizes = [len(dataset) for dataset in datasets]
    real_total = sum(real_sizes)
    if ratios is None:
        return real_total

    # The largest dataset keeps its real size; every other dataset is scaled
    # by its ratio relative to the largest dataset's ratio.
    anchor = np.argmax(real_sizes)
    anchor_ratio = ratios[anchor]
    anchor_size = real_sizes[anchor]
    scaled_total = sum(
        [(ratio / anchor_ratio) * anchor_size for ratio in ratios]
    )

    cap = real_total * max_scale_up
    if scaled_total < cap:
        return int(scaled_total)
    return int(cap)


class CollateFormat(Enum):
    """Output layout selector for ``SampledMultiDataset.collater``."""

    # one batch mixing samples from all sub-datasets
    single = 1
    # a dictionary of batches indexed by sub-dataset key
    ordered_dict = 2


class SampledMultiDataset(FairseqDataset):
    """Samples from multiple sub-datasets according to given sampling ratios.

    Args:
        datasets (
            List[~torch.utils.data.Dataset]
            or OrderedDict[str, ~torch.utils.data.Dataset]
        ): datasets
        sampling_ratios (List[float]): list of probability of each dataset to be sampled
            (default: None, which corresponds to concatenating all datasets together).
        seed (int): RNG seed to use (default: 2).
        epoch (int): starting epoch number (default: 1).
        eval_key (str, optional): a key used at evaluation time that causes
            this instance to pass-through batches from *datasets[eval_key]*.
        collate_format (CollateFormat): collater output format, either CollateFormat.ordered_dict or
            CollateFormat.single (default: CollateFormat.single) where CollateFormat.single configures
            the collater to output batches of data mixed from all sub-datasets,
            and CollateFormat.ordered_dict configures the collater to output a dictionary of batches indexed by keys
            of sub-datasets.
            Note that not all sub-datasets will be present in a single batch in both formats.
        virtual_size (int, or callable): the expected virtual size of the dataset (default: default_virtual_size_func).
        split (str): the split of the data, e.g. 'train', 'valid' or 'test'.
        shared_collater (bool): whether or not all sub-datasets have the same collater.
        shuffle (bool): whether or not to shuffle data (default: True).
    """

    def __init__(
        self,
        datasets,
        sampling_ratios=None,
        seed=2,
        epoch=1,
        eval_key=None,
        collate_format=CollateFormat.single,
        virtual_size=default_virtual_size_func,
        split="",
        shared_collater=False,
        shuffle=True,
    ):
        super().__init__()
        self.shared_collater = shared_collater
        self.shuffle = shuffle

        if isinstance(datasets, OrderedDict):
            self.keys = list(datasets.keys())
            datasets = list(datasets.values())
        elif isinstance(datasets, list):
            # plain lists are keyed by position
            self.keys = list(range(len(datasets)))
        else:
            raise AssertionError(
                "datasets must be a list or an OrderedDict, "
                f"got {type(datasets).__name__}"
            )
        self.datasets = datasets
        self.split = split

        self.eval_key = eval_key
        if self.eval_key is not None:
            # an eval key forces the single collate format
            self.collate_format = CollateFormat.single
        else:
            self.collate_format = collate_format

        self.seed = seed
        self._cur_epoch = None

        self.cumulated_sizes = None
        # self.datasets[k][self._cur_indices[i]] is the data item i in this sampled dataset
        # namely, data item i is sampled from the kth sub-dataset self.datasets[k]
        # where self.cumulated_sizes[k-1] <= i < self.cumulated_sizes[k]
        self._cur_indices = None

        self._sizes = None
        self.virtual_size_per_dataset = None
        # caching properties
        self._reset_cached_properties()
        self.setup_sampling(sampling_ratios, virtual_size)
        self.set_epoch(epoch)

    def _clean_if_not_none(self, var_list):
        """Kept for backward compatibility; has no observable effect.

        ``del v`` only unbinds the local loop variable, so the objects in
        ``var_list`` (and the attributes they came from) are untouched.
        Callers release memory by rebinding the owning attribute to ``None``.
        """
        for v in var_list:
            if v is not None:
                del v

    def _reset_cached_properties(self):
        """Drop the cached per-item sizes and the sampled index map."""
        self._sizes = None
        self._cur_indices = None

    def setup_sampling(self, sample_ratios, virtual_size):
        """Store the sampling ratios and resolve the virtual dataset size.

        Args:
            sample_ratios: per-dataset ratios, or ``None`` to concatenate the
                sub-datasets (in which case ``virtual_size`` is ignored and
                the virtual size is the sum of the real sizes).
            virtual_size (int, callable or None): explicit size, or a callable
                ``f(datasets, ratios)`` returning it; ``None`` selects
                ``default_virtual_size_func``.
        """
        sizes = [len(d) for d in self.datasets]
        if sample_ratios is None:
            # default back to concating datasets
            self.sample_ratios = None
            self.virtual_size = sum(sizes)
        else:
            if not isinstance(sample_ratios, np.ndarray):
                sample_ratios = np.array(sample_ratios)
            self.sample_ratios = sample_ratios
            virtual_size = (
                default_virtual_size_func if virtual_size is None else virtual_size
            )
            self.virtual_size = (
                virtual_size(self.datasets, self.sample_ratios)
                if callable(virtual_size)
                else virtual_size
            )

    def adjust_sampling(self, epoch, sampling_ratios, virtual_size):
        """Install new sampling ratios after syncing them across processes.

        Does nothing when ``sampling_ratios`` is ``None``. ``epoch`` is
        accepted for interface compatibility but not used here.
        """
        if sampling_ratios is not None:
            sampling_ratios = self._sync_sample_ratios(sampling_ratios)
            self.setup_sampling(sampling_ratios, virtual_size)

    def _sync_sample_ratios(self, ratios):
        """All-reduce ``ratios`` over the data-parallel group when distributed.

        Returns:
            np.ndarray: the ratios as a float64 array. Without an initialized
            process group the input values are returned unchanged.
        """
        # in case the ratios are not precisely the same across processes
        # also to ensure every process updates the ratios at the same pace
        ratios = torch.DoubleTensor(ratios)
        if torch.distributed.is_initialized():
            if torch.cuda.is_available():
                # keep a handle on the device tensor so the reduced values are
                # not lost with a temporary
                ratios = ratios.cuda()
            reduced = distributed_utils.all_reduce(
                ratios, group=distributed_utils.get_data_parallel_group()
            )
            # NOTE(review): works whether all_reduce reduces in place (and
            # returns None or the same tensor) or returns a new tensor.
            if reduced is not None:
                ratios = reduced
        return ratios.cpu().numpy()

    def random_choice_in_dataset(self, rng, dataset, choice_size):
        """Draw ``choice_size`` item indices from ``dataset``.

        Delegates to ``dataset.random_choice_in_dataset`` when the dataset
        provides it; otherwise samples uniformly with ``rng``, with
        replacement only when more items are requested than exist.
        """
        if hasattr(dataset, "random_choice_in_dataset"):
            return dataset.random_choice_in_dataset(rng, choice_size)
        dataset_size = len(dataset)
        return rng.choice(
            dataset_size, choice_size, replace=(choice_size > dataset_size)
        )

    def get_virtual_indices(self, rng, datasets, sample_ratios, virtual_size):
        """Build the index map of the virtual dataset.

        Args:
            rng (np.random.RandomState): random source for sampling.
            datasets: the sub-datasets.
            sample_ratios (np.ndarray or None): sampling ratios; ``None``
                concatenates the sub-datasets without sampling.
            virtual_size (int): total number of items in the virtual dataset.

        Returns:
            tuple: ``(in_dataset_indices, cumulative_sizes,
            virtual_sizes_per_dataset)`` where the first is the flat array of
            in-sub-dataset indices, the second the cumulative item counts per
            sub-dataset and the third the item count per sub-dataset.
        """

        def get_counts(ratios):
            counts = np.array([virtual_size * r for r in ratios], dtype=np.int64)
            diff = virtual_size - counts.sum()
            assert diff >= 0
            # due to round-offs, the size might not match the desired sizes
            if diff > 0:
                dataset_indices = rng.choice(len(ratios), size=diff, p=ratios)
                for i in dataset_indices:
                    counts[i] += 1
            return counts

        sizes = [len(d) for d in datasets]
        if sample_ratios is None:
            # default back to concating datasets; np.arange keeps an integer
            # dtype even for an empty sub-dataset (an empty list would make
            # np.hstack promote the whole index array to float)
            in_dataset_indices = [np.arange(s) for s in sizes]
            virtual_sizes_per_dataset = sizes
        else:
            ratios = sample_ratios / sample_ratios.sum()
            counts = get_counts(ratios)
            # uniformly sample desired counts for each dataset
            # if the desired counts are large, sample with replacement:
            in_dataset_indices = [
                self.random_choice_in_dataset(rng, d, c)
                for c, d in zip(counts, datasets)
            ]
            virtual_sizes_per_dataset = [len(d) for d in in_dataset_indices]
        virtual_sizes_per_dataset = np.array(virtual_sizes_per_dataset, np.int64)
        cumulative_sizes = np.cumsum(virtual_sizes_per_dataset)
        assert sum(virtual_sizes_per_dataset) == virtual_size
        assert cumulative_sizes[-1] == virtual_size
        if virtual_size < sum(sizes):
            logger.warning(
                f"virtual data size ({virtual_size}) is less than real data size ({sum(sizes)})."
                " If virtual size << real data size, there could be data coverage issue."
            )
        in_dataset_indices = np.hstack(in_dataset_indices)
        return in_dataset_indices, cumulative_sizes, virtual_sizes_per_dataset

    def _get_dataset_and_index(self, index):
        """Map a virtual index to ``(sub-dataset position, index within it)``."""
        i = bisect_right(self.cumulated_sizes, index)
        return i, self._cur_indices[index]

    def __getitem__(self, index):
        # self.__getitem__(index) returns self.datasets[k][self._cur_indices[index]]
        # where k satisfies self.cumulated_sizes[k - 1] <= index < self.cumulated_sizes[k]
        ds_idx, ds_sample_idx = self._get_dataset_and_index(index)
        ret = (ds_idx, self.datasets[ds_idx][ds_sample_idx])
        return ret

    def num_tokens(self, index):
        """Return the largest entry of the size record for item ``index``."""
        return self.sizes[index].max()

    def num_tokens_vec(self, indices):
        """Vectorized ``num_tokens`` over ``indices``."""
        sizes_vec = self.sizes[np.array(indices)]
        # max across all dimensions but first one
        return np.amax(sizes_vec, axis=tuple(range(1, len(sizes_vec.shape))))

    def size(self, index):
        """Return the size record of item ``index``."""
        return self.sizes[index]

    def __len__(self):
        return self.virtual_size

    def collater(self, samples, **extra_args):
        """Merge a list of samples to form a mini-batch.

        Args:
            samples: list of ``(sub-dataset position, sample)`` pairs as
                produced by ``__getitem__``.
            **extra_args: may carry ``pad_to_length``, a mapping with
                ``"source"``/``"target"`` minimum lengths; it is copied, not
                modified.

        Returns:
            ``None`` for an empty ``samples``; an ``OrderedDict`` of
            per-sub-dataset batches in the ordered-dict format; otherwise one
            merged batch.
        """
        if len(samples) == 0:
            return None
        # Accept the enum member as well as its legacy string spelling; an
        # Enum member never compares equal to a plain string.
        if self.collate_format in (CollateFormat.ordered_dict, "ordered_dict"):
            collect_samples = [[] for _ in range(len(self.datasets))]
            for i, sample in samples:
                collect_samples[i].append(sample)
            batch = OrderedDict(
                (key, dataset.collater(group))
                for key, dataset, group in zip(
                    self.keys, self.datasets, collect_samples
                )
                if len(group) > 0
            )
        elif self.shared_collater:
            batch = self.datasets[0].collater([s for _, s in samples])
        else:
            samples_dict = defaultdict(list)
            # Work on a private defaultdict so the caller's mapping is never
            # mutated and missing keys start from 0.
            pad_to_length = defaultdict(int)
            pad_to_length.update(extra_args.get("pad_to_length") or {})
            for ds_idx, s in samples:
                pad_to_length["source"] = max(
                    pad_to_length["source"], s["source"].size(0)
                )
                if s["target"] is not None:
                    pad_to_length["target"] = max(
                        pad_to_length["target"], s["target"].size(0)
                    )
                samples_dict[ds_idx].append(s)
            # every sub-batch is padded to the same lengths so they can be
            # concatenated along the batch dimension below
            batches = [
                self.datasets[i].collater(samples_dict[i], pad_to_length=pad_to_length)
                for i in range(len(self.datasets))
                if len(samples_dict[i]) > 0
            ]

            def straight_data(tensors):
                batch = torch.cat(tensors, dim=0)
                return batch

            src_lengths = straight_data(
                [b["net_input"]["src_lengths"] for b in batches]
            )
            # re-sort the merged batch by descending source length
            src_lengths, sort_order = src_lengths.sort(descending=True)

            def straight_order(tensors):
                batch = straight_data(tensors)
                return batch.index_select(0, sort_order)

            batch = {
                "id": straight_order([b["id"] for b in batches]),
                "nsentences": sum(b["nsentences"] for b in batches),
                "ntokens": sum(b["ntokens"] for b in batches),
                "net_input": {
                    "src_tokens": straight_order(
                        [b["net_input"]["src_tokens"] for b in batches]
                    ),
                    "src_lengths": src_lengths,
                },
                "target": straight_order([b["target"] for b in batches])
                if batches[0]["target"] is not None
                else None,
            }
            if "prev_output_tokens" in batches[0]["net_input"]:
                batch["net_input"]["prev_output_tokens"] = straight_order(
                    [b["net_input"]["prev_output_tokens"] for b in batches]
                )
            if "src_lang_id" in batches[0]["net_input"]:
                batch["net_input"]["src_lang_id"] = straight_order(
                    [b["net_input"]["src_lang_id"] for b in batches]
                )
            if "tgt_lang_id" in batches[0]:
                batch["tgt_lang_id"] = straight_order(
                    [b["tgt_lang_id"] for b in batches]
                )
        return batch

    @property
    def sizes(self):
        """Size records of all virtual items, stacked in virtual-index order.

        Computed on first access from each sub-dataset's ``sizes`` and cached
        until the virtual dataset is re-established.
        """
        if self._sizes is not None:
            return self._sizes
        start_time = time.time()
        in_sub_dataset_indices = [
            self._cur_indices[
                0 if i == 0 else self.cumulated_sizes[i - 1] : self.cumulated_sizes[i]
            ]
            for i in range(len(self.datasets))
        ]
        sub_dataset_sizes = [
            d.sizes[indices]
            for d, indices in zip(self.datasets, in_sub_dataset_indices)
        ]
        self._sizes = np.vstack(sub_dataset_sizes)
        logger.info(f"sizes() calling time: {get_time_gap(start_time, time.time())}")
        return self._sizes

    @staticmethod
    def _split_src_tgt_sizes(sizes):
        """Split a sizes array into 1-D ``(src_sizes, tgt_sizes)``.

        For a 2-D array, column 0 is the source sizes and column 1 (when
        present) the target sizes. A 1-D array is returned as the source
        sizes. ``tgt_sizes`` is ``None`` when there is no second column.
        """
        if sizes.ndim > 1 and sizes.shape[1] > 0:
            tgt_sizes = sizes[:, 1] if sizes.shape[1] > 1 else None
            return sizes[:, 0], tgt_sizes
        return sizes, None

    def ordered_indices(self):
        """Return virtual indices ordered by source length.

        Starts from a random permutation when ``self.shuffle`` is set (else
        the identity order), then applies stable sorts by target length and
        by source length, so ties keep the starting order.
        """
        if self.shuffle:
            indices = np.random.permutation(len(self))
        else:
            indices = np.arange(len(self))

        src_sizes, tgt_sizes = self._split_src_tgt_sizes(self.sizes)

        # sort by target length, then source length
        if tgt_sizes is not None:
            indices = indices[np.argsort(tgt_sizes[indices], kind="mergesort")]
        sort_indices = indices[np.argsort(src_sizes[indices], kind="mergesort")]
        return sort_indices

    def prefetch(self, indices):
        """Forward prefetch requests to the sub-datasets owning ``indices``."""
        prefetch_indices = [[] for _ in range(len(self.datasets))]
        for i in indices:
            ds_idx, ds_sample_idx = self._get_dataset_and_index(i)
            prefetch_indices[ds_idx].append(ds_sample_idx)
        for dataset, sub_indices in zip(self.datasets, prefetch_indices):
            dataset.prefetch(sub_indices)

    @property
    def can_reuse_epoch_itr_across_epochs(self):
        return False

    def set_epoch(self, epoch):
        """Propagate the epoch to sub-datasets and rebuild the virtual dataset.

        A repeated call with the current epoch returns without resampling.
        """
        super().set_epoch(epoch)
        if epoch == self._cur_epoch:
            # re-enter so return
            return
        for d in self.datasets:
            if hasattr(d, "set_epoch"):
                d.set_epoch(epoch)
        self._cur_epoch = epoch
        self._establish_virtual_datasets()

    def _establish_virtual_datasets(self):
        """(Re)build the virtual index map for the current epoch."""
        if self.sample_ratios is None and self._cur_indices is not None:
            # not a sampling dataset, no need to resample if indices are already established
            return
        self._reset_cached_properties()

        start_time = time.time()
        # Generate a weighted sample of indices as a function of the
        # random seed and the current epoch.
        rng = np.random.RandomState(
            [
                int(
                    hashlib.sha1(
                        str(self.__class__.__name__).encode("utf-8")
                    ).hexdigest(),
                    16,
                )
                % (2**32),
                self.seed % (2**32),  # global seed
                self._cur_epoch,  # epoch index,
            ]
        )
        self._sizes = None

        indices, cumulated_sizes, virtual_size_per_dataset = self.get_virtual_indices(
            rng, self.datasets, self.sample_ratios, self.virtual_size
        )
        self._cur_indices = indices
        self.cumulated_sizes = cumulated_sizes
        self.virtual_size_per_dataset = virtual_size_per_dataset

        raw_sizes = [len(d) for d in self.datasets]
        sampled_sizes = self.virtual_size_per_dataset
        logger.info(
            f"[{self.split}] Raw sizes: {str(dict(zip(self.keys, raw_sizes)))}; "
            f"raw total size: {sum(raw_sizes)}"
        )
        logger.info(
            f"[{self.split}] Resampled sizes: {str(dict(zip(self.keys, sampled_sizes)))}; "
            f"resampled total size: {sum(sampled_sizes)}"
        )
        if self.sample_ratios is not None:
            logger.info(
                f"[{self.split}] Upsampling ratios: {str(dict(zip(self.keys, self.sample_ratios)))}"
            )
        else:
            logger.info(f"[{self.split}] A concat dataset")
        logger.info(
            f"[{self.split}] virtual dataset established time: {get_time_gap(start_time, time.time())}"
        )

    def filter_indices_by_size(self, indices, max_sizes):
        """Filter a list of sample indices. Remove those that are longer
            than specified in max_sizes.

        Args:
            indices (np.array): original array of sample indices
            max_sizes (int or list[int] or tuple[int]): max sample size,
                can be defined separately for src and tgt (then list or tuple)

        Returns:
            np.array: filtered sample array
            list: list of removed indices
        """
        src_sizes, tgt_sizes = self._split_src_tgt_sizes(self.sizes)

        return data_utils.filter_paired_dataset_indices_by_size(
            src_sizes, tgt_sizes, indices, max_sizes
        )
