# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import json
import os
import tempfile

import numpy as np
import torch
import torch.nn.functional as F
from fairseq import utils
from fairseq.data import (
    Dictionary,
    IdDataset,
    ListDataset,
    NestedDictionaryDataset,
    NumelDataset,
    NumSamplesDataset,
    PadDataset,
    SortDataset,
    data_utils,
    encoders,
)
from fairseq.tasks import LegacyFairseqTask, register_task

from . import wsc_utils


@register_task("wsc")
class WSCTask(LegacyFairseqTask):
    """Task to finetune RoBERTa for Winograd Schemas."""

    @staticmethod
    def add_args(parser):
        """Add task-specific arguments to the parser."""
        parser.add_argument(
            "data", metavar="DIR", help="path to data directory; we load <split>.jsonl"
        )
        parser.add_argument(
            "--init-token",
            type=int,
            default=None,
            help="add token at the beginning of each batch item",
        )

    def __init__(self, args, vocab):
        super().__init__(args)
        self.vocab = vocab
        self.mask = vocab.add_symbol("<mask>")

        self.bpe = encoders.build_bpe(args)
        self.tokenizer = encoders.build_tokenizer(args)

        # hack to handle GPT-2 BPE, which includes leading spaces
        if args.bpe == "gpt2":
            self.leading_space = True
            self.trailing_space = False
        else:
            self.leading_space = False
            self.trailing_space = True

    @classmethod
    def load_dictionary(cls, filename):
        """Load the dictionary from the filename

        Args:
            filename (str): the filename
        """
        dictionary = Dictionary.load(filename)
        dictionary.add_symbol("<mask>")
        return dictionary

    @classmethod
    def setup_task(cls, args, **kwargs):
        assert args.criterion == "wsc", "Must set --criterion=wsc"

        # load data and label dictionaries
        vocab = cls.load_dictionary(os.path.join(args.data, "dict.txt"))
        print("| dictionary: {} types".format(len(vocab)))

        return cls(args, vocab)

    def binarize(self, s: str, append_eos: bool = False):
        if self.tokenizer is not None:
            s = self.tokenizer.encode(s)
        if self.bpe is not None:
            s = self.bpe.encode(s)
        tokens = self.vocab.encode_line(
            s,
            append_eos=append_eos,
            add_if_not_exist=False,
        ).long()
        if self.args.init_token is not None:
            tokens = torch.cat([tokens.new([self.args.init_token]), tokens])
        return tokens

    def binarize_with_mask(self, txt, prefix, suffix, leading_space, trailing_space):
        toks = self.binarize(
            prefix + leading_space + txt + trailing_space + suffix,
            append_eos=True,
        )
        mask = torch.zeros_like(toks, dtype=torch.bool)
        mask_start = len(self.binarize(prefix))
        mask_size = len(self.binarize(leading_space + txt))
        mask[mask_start : mask_start + mask_size] = 1
        return toks, mask

    def load_dataset(
        self, split, epoch=1, combine=False, data_path=None, return_only=False, **kwargs
    ):
        """Load a given dataset split.

        Args:
            split (str): name of the split (e.g., train, valid, test)
        """
        if data_path is None:
            data_path = os.path.join(self.args.data, split + ".jsonl")
        if not os.path.exists(data_path):
            raise FileNotFoundError("Cannot find data: {}".format(data_path))

        query_tokens = []
        query_masks = []
        query_lengths = []
        candidate_tokens = []
        candidate_masks = []
        candidate_lengths = []
        labels = []

        for sentence, pronoun_span, query, label in wsc_utils.jsonl_iterator(data_path):
            prefix = sentence[: pronoun_span.start].text
            suffix = sentence[pronoun_span.end :].text_with_ws

            # spaCy spans include trailing spaces, but we need to know about
            # leading spaces for the GPT-2 BPE
            leading_space = (
                " " if sentence[: pronoun_span.start].text_with_ws.endswith(" ") else ""
            )
            trailing_space = " " if pronoun_span.text_with_ws.endswith(" ") else ""

            # get noun phrases, excluding pronouns and anything overlapping with the query
            cand_spans = wsc_utils.filter_noun_chunks(
                wsc_utils.extended_noun_chunks(sentence),
                exclude_pronouns=True,
                exclude_query=query,
                exact_match=False,
            )

            if query is not None:
                query_toks, query_mask = self.binarize_with_mask(
                    query, prefix, suffix, leading_space, trailing_space
                )
                query_len = len(query_toks)
            else:
                query_toks, query_mask, query_len = None, None, 0

            query_tokens.append(query_toks)
            query_masks.append(query_mask)
            query_lengths.append(query_len)

            cand_toks, cand_masks = [], []
            for cand_span in cand_spans:
                toks, mask = self.binarize_with_mask(
                    cand_span.text,
                    prefix,
                    suffix,
                    leading_space,
                    trailing_space,
                )
                cand_toks.append(toks)
                cand_masks.append(mask)

            # collate candidates
            cand_toks = data_utils.collate_tokens(cand_toks, pad_idx=self.vocab.pad())
            cand_masks = data_utils.collate_tokens(cand_masks, pad_idx=0)
            assert cand_toks.size() == cand_masks.size()

            candidate_tokens.append(cand_toks)
            candidate_masks.append(cand_masks)
            candidate_lengths.append(cand_toks.size(1))

            labels.append(label)

        query_lengths = np.array(query_lengths)
        query_tokens = ListDataset(query_tokens, query_lengths)
        query_masks = ListDataset(query_masks, query_lengths)

        candidate_lengths = np.array(candidate_lengths)
        candidate_tokens = ListDataset(candidate_tokens, candidate_lengths)
        candidate_masks = ListDataset(candidate_masks, candidate_lengths)

        labels = ListDataset(labels, [1] * len(labels))

        dataset = {
            "id": IdDataset(),
            "query_tokens": query_tokens,
            "query_masks": query_masks,
            "candidate_tokens": candidate_tokens,
            "candidate_masks": candidate_masks,
            "labels": labels,
            "nsentences": NumSamplesDataset(),
            "ntokens": NumelDataset(query_tokens, reduce=True),
        }

        nested_dataset = NestedDictionaryDataset(
            dataset,
            sizes=[query_lengths],
        )

        with data_utils.numpy_seed(self.args.seed):
            shuffle = np.random.permutation(len(query_tokens))
        dataset = SortDataset(
            nested_dataset,
            # shuffle
            sort_order=[shuffle],
        )

        if return_only:
            return dataset

        self.datasets[split] = dataset
        return self.datasets[split]

    def build_dataset_for_inference(self, sample_json):
        with tempfile.NamedTemporaryFile(buffering=0) as h:
            h.write((json.dumps(sample_json) + "\n").encode("utf-8"))
            dataset = self.load_dataset(
                "disambiguate_pronoun",
                data_path=h.name,
                return_only=True,
            )
        return dataset

    def disambiguate_pronoun(self, model, sentence, use_cuda=False):
        sample_json = wsc_utils.convert_sentence_to_json(sentence)
        dataset = self.build_dataset_for_inference(sample_json)
        sample = dataset.collater([dataset[0]])
        if use_cuda:
            sample = utils.move_to_cuda(sample)

        def get_masked_input(tokens, mask):
            masked_tokens = tokens.clone()
            masked_tokens[mask.bool()] = self.mask
            return masked_tokens

        def get_lprobs(tokens, mask):
            logits, _ = model(src_tokens=get_masked_input(tokens, mask))
            lprobs = F.log_softmax(logits, dim=-1, dtype=torch.float)
            scores = lprobs.gather(2, tokens.unsqueeze(-1)).squeeze(-1)
            mask = mask.type_as(scores)
            scores = (scores * mask).sum(dim=-1) / mask.sum(dim=-1)
            return scores

        cand_lprobs = get_lprobs(
            sample["candidate_tokens"][0],
            sample["candidate_masks"][0],
        )
        if sample["query_tokens"][0] is not None:
            query_lprobs = get_lprobs(
                sample["query_tokens"][0].unsqueeze(0),
                sample["query_masks"][0].unsqueeze(0),
            )
            return (query_lprobs >= cand_lprobs).all().item() == 1
        else:
            best_idx = cand_lprobs.argmax().item()
            full_cand = sample["candidate_tokens"][0][best_idx]
            mask = sample["candidate_masks"][0][best_idx]
            toks = full_cand[mask.bool()]
            return self.bpe.decode(self.source_dictionary.string(toks)).strip()

    @property
    def source_dictionary(self):
        return self.vocab

    @property
    def target_dictionary(self):
        return self.vocab


@register_task("winogrande")
class WinograndeTask(WSCTask):
    """
    Task for WinoGrande dataset. Efficient implementation for Winograd schema
    tasks with exactly two candidates, one of which is correct.
    """

    @classmethod
    def setup_task(cls, args, **kwargs):
        assert args.criterion == "winogrande", "Must set --criterion=winogrande"

        # load data and label dictionaries
        vocab = cls.load_dictionary(os.path.join(args.data, "dict.txt"))
        print("| dictionary: {} types".format(len(vocab)))

        return cls(args, vocab)

    def load_dataset(
        self, split, epoch=1, combine=False, data_path=None, return_only=False, **kwargs
    ):
        """Load a given dataset split.

        Args:
            split (str): name of the split (e.g., train, valid, test)
        """
        if data_path is None:
            data_path = os.path.join(self.args.data, split + ".jsonl")
        if not os.path.exists(data_path):
            raise FileNotFoundError("Cannot find data: {}".format(data_path))

        query_tokens = []
        query_masks = []
        query_lengths = []
        candidate_tokens = []
        candidate_masks = []
        candidate_lengths = []

        itr = wsc_utils.winogrande_jsonl_iterator(data_path, eval=(split == "test"))

        for sample in itr:
            sentence, pronoun_span, query, cand_text = sample
            prefix = sentence[: pronoun_span[0]].rstrip()
            suffix = sentence[pronoun_span[1] :]

            leading_space = " " if sentence[: pronoun_span[0]].endswith(" ") else ""
            trailing_space = ""

            if query is not None:
                query_toks, query_mask = self.binarize_with_mask(
                    query,
                    prefix,
                    suffix,
                    leading_space,
                    trailing_space,
                )
                query_len = len(query_toks)
            else:
                query_toks, query_mask, query_len = None, None, 0

            query_tokens.append(query_toks)
            query_masks.append(query_mask)
            query_lengths.append(query_len)

            cand_toks, cand_mask = self.binarize_with_mask(
                cand_text,
                prefix,
                suffix,
                leading_space,
                trailing_space,
            )

            candidate_tokens.append(cand_toks)
            candidate_masks.append(cand_mask)
            candidate_lengths.append(cand_toks.size(0))

        query_lengths = np.array(query_lengths)

        def get_pad_dataset_fn(tokens, length, pad_idx):
            return PadDataset(
                ListDataset(tokens, length),
                pad_idx=pad_idx,
                left_pad=False,
            )

        query_tokens = get_pad_dataset_fn(query_tokens, query_lengths, self.vocab.pad())
        query_masks = get_pad_dataset_fn(query_masks, query_lengths, 0)

        candidate_lengths = np.array(candidate_lengths)
        candidate_tokens = get_pad_dataset_fn(
            candidate_tokens, candidate_lengths, self.vocab.pad()
        )
        candidate_masks = get_pad_dataset_fn(candidate_masks, candidate_lengths, 0)

        dataset = {
            "id": IdDataset(),
            "query_tokens": query_tokens,
            "query_masks": query_masks,
            "candidate_tokens": candidate_tokens,
            "candidate_masks": candidate_masks,
            "nsentences": NumSamplesDataset(),
            "ntokens": NumelDataset(query_tokens, reduce=True),
        }

        nested_dataset = NestedDictionaryDataset(
            dataset,
            sizes=[query_lengths],
        )

        with data_utils.numpy_seed(self.args.seed):
            shuffle = np.random.permutation(len(query_tokens))
        dataset = SortDataset(
            nested_dataset,
            # shuffle
            sort_order=[shuffle],
        )

        if return_only:
            return dataset

        self.datasets[split] = dataset
        return self.datasets[split]
