# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import math

import numpy as np
import torch
from fairseq.data import FairseqDataset


class BlockPairDataset(FairseqDataset):
    """Break a Dataset of tokens into sentence pair blocks for next sentence
       prediction as well as masked language model.

       High-level logic:
       1. break the input tensors into blocks
       2. pair each block with either the block that follows it (label 1) or
          a randomly sampled block (label 0), chosen by a fair coin flip
       3. return the paired blocks together with the next-sentence label

    Every pair is stored in ``self.sent_pairs`` as
    ``(block_a, block_b, next_sent_label)`` where each block is a tuple
    ``(start_ds_idx, start_offset, end_ds_idx, length)`` addressing tokens in
    the underlying dataset. ``self.sizes`` holds the pair lengths including
    the three special tokens ([CLS], [SEP], [SEP]).

    Args:
        dataset (~torch.utils.data.Dataset): dataset to break into blocks;
            must expose a ``sizes`` array with the length of each item
        dictionary: dictionary for the task
        sizes: array of sentence lengths (indexed with lists of sentence ids,
            so a numpy array is expected)
        block_size: maximum block size
        break_mode: mode for breaking the corpus into block pairs. Currently
            2 modes are supported
            doc: respect document boundaries; each part of the pair belongs
                to one document
            none: don't respect any boundary and cut tokens evenly
        short_seq_prob: probability for generating shorter block pairs
        doc_break_size: Size for empty line separating documents. Typically 1 if
                        the sentences have eos, 0 otherwise.

    Raises:
        ValueError: if ``break_mode`` is not one of the supported modes, or if
            there are too few documents/blocks to sample a random partner.
    """

    def __init__(
        self,
        dataset,
        dictionary,
        sizes,
        block_size,
        break_mode="doc",
        short_seq_prob=0.1,
        doc_break_size=1,
    ):
        super().__init__()
        self.dataset = dataset
        self.pad = dictionary.pad()
        self.eos = dictionary.eos()
        self.cls = dictionary.cls()
        self.mask = dictionary.mask()
        self.sep = dictionary.sep()
        self.break_mode = break_mode
        self.dictionary = dictionary
        self.short_seq_prob = short_seq_prob
        self.block_indices = []
        # Both break modes append to these, so they must exist before either
        # branch runs.
        self.sent_pairs = []
        self.sizes = []

        assert len(dataset) == len(sizes)

        if break_mode == "doc":
            cur_doc = []
            for sent_id, sz in enumerate(sizes):
                assert doc_break_size == 0 or sz != 0, (
                    "when doc_break_size is non-zero, we expect documents to be "
                    "separated by a blank line with a single eos."
                )
                # empty line as document separator
                if sz == doc_break_size:
                    if cur_doc:
                        self.block_indices.append(cur_doc)
                        cur_doc = []
                else:
                    cur_doc.append(sent_id)
            # Keep the last document even when the corpus does not end with a
            # separator line; otherwise its sentences would be silently lost.
            if cur_doc:
                self.block_indices.append(cur_doc)
            max_num_tokens = block_size - 3  # Account for [CLS], [SEP], [SEP]
            for doc_id, doc in enumerate(self.block_indices):
                self._generate_sentence_pair(doc, doc_id, max_num_tokens, sizes)
        elif break_mode is None or break_mode == "none":
            # each block should have half of the block size since we are
            # constructing block pairs
            sent_length = (block_size - 3) // 2
            total_len = sum(dataset.sizes)
            length = math.ceil(total_len / sent_length)

            # Every block holds sent_length tokens except possibly the last,
            # which takes whatever remains.
            sent_sizes = np.array(
                [
                    min(sent_length, total_len - i * sent_length)
                    for i in range(length)
                ]
            )
            dataset_index = self._sent_to_dataset_index(sent_sizes)

            # pair sentences
            self._pair_sentences(dataset_index)
        else:
            # format() rather than "+" so a non-string mode still produces the
            # intended ValueError instead of a TypeError.
            raise ValueError("Invalid break_mode: {}".format(break_mode))

    def _pair_sentences(self, dataset_index):
        """
        Given a list of evenly cut blocks/sentences, pair each of them with
        either the block that follows it (label 1) or a randomly sampled block
        (label 0). The last block has no successor and always gets label 0.
        This is used for the "none" break mode.

        Args:
            dataset_index: list of ``(start_ds_idx, start_offset, end_ds_idx,
                length)`` tuples as built by ``_sent_to_dataset_index``
        """
        num_sents = len(dataset_index)
        for sent_id, sent in enumerate(dataset_index):
            next_sent_label = (
                1 if np.random.rand() > 0.5 and sent_id != num_sents - 1 else 0
            )
            if next_sent_label:
                next_sent = dataset_index[sent_id + 1]
            else:
                # sample any block except the current one and its successor
                next_sent = dataset_index[
                    self._skip_sampling(num_sents, [sent_id, sent_id + 1])
                ]
            self.sent_pairs.append((sent, next_sent, next_sent_label))

            # The blocks themselves don't include the special tokens, but the
            # recorded size accounts for them
            self.sizes.append(3 + sent[3] + next_sent[3])

    def _sent_to_dataset_index(self, sent_sizes):
        """
        Build index mapping block indices to the underlying dataset indices

        Args:
            sent_sizes: number of tokens in each block, in corpus order

        Returns:
            list with one ``(start_ds_idx, start_offset, end_ds_idx, length)``
            tuple per block
        """
        # Lengths of the underlying dataset items. Blocks are carved out of
        # these, so the bookkeeping below must walk the dataset sizes and not
        # the block sizes.
        ds_sizes = self.dataset.sizes
        dataset_index = []
        ds_idx, ds_remaining = -1, 0
        for to_consume in sent_sizes:
            sent_size = to_consume
            if ds_remaining == 0:
                ds_idx += 1
                ds_remaining = ds_sizes[ds_idx]
            start_ds_idx = ds_idx
            start_offset = ds_sizes[ds_idx] - ds_remaining
            # the block spills over into the following dataset items
            while to_consume > ds_remaining:
                to_consume -= ds_remaining
                ds_idx += 1
                ds_remaining = ds_sizes[ds_idx]
            ds_remaining -= to_consume
            dataset_index.append(
                (
                    start_ds_idx,  # starting index in dataset
                    start_offset,  # starting offset within starting index
                    ds_idx,  # ending index in dataset
                    sent_size,  # sentence length
                )
            )
        # all tokens of all dataset items must have been consumed
        assert ds_remaining == 0
        assert ds_idx == len(self.dataset) - 1
        return dataset_index

    def _generate_sentence_pair(self, doc, doc_id, max_num_tokens, sizes):
        """
        Go through a single document and generate sentence pairs from it.
        Results are appended to ``self.sent_pairs`` and ``self.sizes``.

        Args:
            doc: list of sentence ids belonging to the document
            doc_id: index of the document in ``self.block_indices``
            max_num_tokens: maximum number of tokens of a pair, excluding
                special tokens
            sizes: array of sentence lengths
        """
        current_chunk = []
        current_length = 0
        curr = 0
        # To provide more randomness, we decrease target seq length for parts of
        # samples (10% by default). Note that max_num_tokens is the hard threshold
        # for batching and will never be changed.
        target_seq_length = max_num_tokens
        if np.random.random() < self.short_seq_prob:
            target_seq_length = np.random.randint(2, max_num_tokens)
        # loop through all sentences in document
        while curr < len(doc):
            sent_id = doc[curr]
            current_chunk.append(sent_id)
            # running total instead of re-summing the whole chunk every step
            current_length += sizes[sent_id]
            # split chunk and generate pair when exceed target_seq_length or
            # finish the loop
            if curr == len(doc) - 1 or current_length >= target_seq_length:
                # split the chunk into 2 parts
                a_end = 1
                if len(current_chunk) > 2:
                    # NOTE(review): numpy's upper bound is exclusive, so the
                    # second part always keeps at least two sentences here.
                    # Left unchanged to preserve the sampling distribution;
                    # confirm whether that is intended.
                    a_end = np.random.randint(1, len(current_chunk) - 1)
                sent_a = current_chunk[:a_end]
                len_a = sum(sizes[sent_a])
                # generate next sentence label, note that if there is only 1 sentence
                # in current chunk, label is always 0
                next_sent_label = (
                    1 if np.random.rand() > 0.5 and len(current_chunk) != 1 else 0
                )
                if not next_sent_label:
                    # if next sentence label is 0, sample sent_b from a random doc
                    target_b_length = target_seq_length - len_a
                    rand_doc_id = self._skip_sampling(len(self.block_indices), [doc_id])
                    random_doc = self.block_indices[rand_doc_id]
                    random_start = np.random.randint(0, len(random_doc))
                    sent_b = []
                    len_b = 0
                    for j in range(random_start, len(random_doc)):
                        sent_b.append(random_doc[j])
                        len_b += sizes[random_doc[j]]
                        if len_b >= target_b_length:
                            break
                    # return the second part of the chunk since it's not used
                    num_unused_segments = len(current_chunk) - a_end
                    curr -= num_unused_segments
                else:
                    # if next sentence label is 1, use the second part of chunk as sent_b
                    sent_b = current_chunk[a_end:]
                # currently sent_a and sent_b may be longer than max_num_tokens,
                # truncate them and return block idx and offsets for them
                sent_a, sent_b = self._truncate_sentences(
                    sent_a, sent_b, max_num_tokens
                )
                self.sent_pairs.append((sent_a, sent_b, next_sent_label))
                self.sizes.append(3 + sent_a[3] + sent_b[3])
                current_chunk = []
                current_length = 0
            curr += 1

    def _skip_sampling(self, total, skip_ids):
        """
        Generate a random integer which is not in skip_ids. Sample range is [0, total)
        TODO: ids in skip_ids should be consecutive, we can extend it to more generic version later

        Raises:
            ValueError: if nothing is left to sample once skip_ids are excluded
        """
        num_candidates = total - len(skip_ids)
        if num_candidates <= 0:
            raise ValueError(
                "Cannot sample from {} candidates while skipping {} of them".format(
                    total, len(skip_ids)
                )
            )
        rand_id = np.random.randint(num_candidates)
        # ids at or past the skipped range are shifted beyond it
        return rand_id if rand_id < min(skip_ids) else rand_id + len(skip_ids)

    def _truncate_sentences(self, sent_a, sent_b, max_num_tokens):
        """
        Truncate a pair of sentences to limit total length under max_num_tokens
        Logic:
            1. Truncate the longer sentence, one token at a time
            2. Each token is removed from the beginning or the end of the
               sentence, chosen at random
        Returns:
            Truncated sentences represented by dataset idx, i.e. two
            ``(start_ds_idx, start_offset, end_ds_idx, length)`` tuples
        """
        len_a, len_b = sum(self.dataset.sizes[sent_a]), sum(self.dataset.sizes[sent_b])
        front_cut_a = front_cut_b = end_cut_a = end_cut_b = 0

        while True:
            total_length = (
                len_a + len_b - front_cut_a - front_cut_b - end_cut_a - end_cut_b
            )
            if total_length <= max_num_tokens:
                break

            if len_a - front_cut_a - end_cut_a > len_b - front_cut_b - end_cut_b:
                if np.random.rand() < 0.5:
                    front_cut_a += 1
                else:
                    end_cut_a += 1
            else:
                if np.random.rand() < 0.5:
                    front_cut_b += 1
                else:
                    end_cut_b += 1

        # calculate ds indices as well as offsets and return
        truncated_sent_a = self._cut_sentence(sent_a, front_cut_a, end_cut_a)
        truncated_sent_b = self._cut_sentence(sent_b, front_cut_b, end_cut_b)
        return truncated_sent_a, truncated_sent_b

    def _cut_sentence(self, sent, front_cut, end_cut):
        """
        Cut a sentence based on the numbers of tokens to be cut from beginning and end
        Represent the sentence as dataset idx and return

        Args:
            sent: list of dataset indices making up the sentence; the walk
                below steps by one, so the indices are treated as consecutive
            front_cut: number of tokens to drop from the beginning
            end_cut: number of tokens to drop from the end

        Returns:
            ``(start_ds_idx, start_offset, end_ds_idx, length)``
        """
        start_ds_idx, end_ds_idx, offset = sent[0], sent[-1], 0
        target_len = sum(self.dataset.sizes[sent]) - front_cut - end_cut
        # skip whole dataset items that are cut away entirely from the front;
        # the rest of the cut becomes an offset into the first kept item
        while front_cut > 0:
            if self.dataset.sizes[start_ds_idx] > front_cut:
                offset += front_cut
                break
            else:
                front_cut -= self.dataset.sizes[start_ds_idx]
                start_ds_idx += 1
        # likewise drop whole items from the end; a partial cut of the last
        # kept item is already expressed through target_len
        while end_cut > 0:
            if self.dataset.sizes[end_ds_idx] > end_cut:
                break
            else:
                end_cut -= self.dataset.sizes[end_ds_idx]
                end_ds_idx -= 1
        return start_ds_idx, offset, end_ds_idx, target_len

    def _fetch_block(self, start_ds_idx, offset, end_ds_idx, length):
        """
        Fetch a block of tokens based on its dataset idx

        Concatenates dataset items start_ds_idx..end_ds_idx (inclusive) and
        returns ``length`` tokens starting at ``offset``.
        """
        buffer = torch.cat(
            [self.dataset[idx] for idx in range(start_ds_idx, end_ds_idx + 1)]
        )
        s, e = offset, offset + length
        return buffer[s:e]

    def __getitem__(self, index):
        """Return ``(block1, block2, next_sent_label)`` for the pair at index."""
        block1, block2, next_sent_label = self.sent_pairs[index]
        block1 = self._fetch_block(*block1)
        block2 = self._fetch_block(*block2)
        return block1, block2, next_sent_label

    def __len__(self):
        """Number of generated block pairs."""
        return len(self.sizes)

    @property
    def supports_prefetch(self):
        """Whether the underlying dataset supports prefetching."""
        return getattr(self.dataset, "supports_prefetch", False)

    def prefetch(self, indices):
        """Prefetch every underlying dataset item used by the given pairs."""
        prefetch_idx = set()
        for index in indices:
            block1, block2, _ = self.sent_pairs[index]
            prefetch_idx.update(range(block1[0], block1[2] + 1))
            prefetch_idx.update(range(block2[0], block2[2] + 1))
        self.dataset.prefetch(prefetch_idx)
