import argparse
import unittest
from typing import Any, Dict

import torch
from examples.simultaneous_translation.models import (
    transformer_monotonic_attention
)


from tests.test_roberta import FakeTask


# Baseline monotonic-attention options shared by the test models in this
# file; ``generate_config`` copies this mapping and layers overrides on top.
DEFAULT_CONFIG = dict(
    attention_eps=1e-6,
    mass_preservation=True,
    noise_type="flat",
    noise_mean=0.0,
    noise_var=1.0,
    energy_bias_init=-2,
    energy_bias=True,
)


# Token id used as padding when building samples and padding masks.
PAD_INDEX = 1


def generate_config(overrides_kv):
    """Return a new dict: ``DEFAULT_CONFIG`` with ``overrides_kv`` applied.

    ``DEFAULT_CONFIG`` itself is left untouched. Keys present in
    ``overrides_kv`` replace the default of the same name; keys that are not
    in the defaults are added.
    """
    merged = dict(DEFAULT_CONFIG)
    merged.update(overrides_kv.items())
    return merged


def make_sample_with_padding(longer_src=False) -> Dict[str, Any]:
    """Build a fixed two-row batch whose rows are right-padded with PAD_INDEX.

    Two token tensors are defined, one 14 columns wide and one 8 columns
    wide. ``longer_src`` selects which of them is used as the source; the
    other becomes ``prev_output_tokens``. The source always drops column 0.

    Returns:
        A dict with ``net_input`` (``src_tokens``, ``prev_output_tokens``,
        ``src_lengths``) and ``target`` (``prev_output_tokens`` without its
        first column). ``src_lengths`` counts the non-padding source tokens
        of each row.
    """
    pad = PAD_INDEX
    wide_tokens = torch.LongTensor(
        [
            [2, 10, 11, 12, 13, 14, 15, 10, 11, 12, 13, 14, 15, 2],
            [2, 11, 12, 14, 15, 10, 11, 12, 13, 14, 15, 2, pad, pad],
        ]
    )
    narrow_tokens = torch.LongTensor(
        [
            [2, 11, 12, 13, 14, 2, pad, pad],
            [2, 11, 22, 33, 2, pad, pad, pad],
        ]
    )

    if longer_src:
        src_tokens, prev_output_tokens = wide_tokens[:, 1:], narrow_tokens
    else:
        src_tokens, prev_output_tokens = narrow_tokens[:, 1:8], wide_tokens

    net_input = {
        "src_tokens": src_tokens,
        "prev_output_tokens": prev_output_tokens,
        "src_lengths": src_tokens.ne(pad).sum(dim=1).long(),
    }
    return {"net_input": net_input, "target": prev_output_tokens[:, 1:]}


def build_transformer_monotonic_attention(**extra_args: Any):
    """Build a small ``TransformerModelSimulTrans`` for the tests.

    Args:
        **extra_args: Additional model arguments. They are merged over the
            fixed test settings below, so they win on conflicting keys.

    Returns:
        The result of ``TransformerModelSimulTrans.build_model`` for the
        assembled arguments and a ``FakeTask``.
    """
    arg_values = dict(
        # Small embedding / feed-forward sizes chosen for the tests.
        encoder_embed_dim=12,
        encoder_ffn_embed_dim=14,
        decoder_embed_dim=12,
        decoder_ffn_embed_dim=14,
        # All dropout is switched off so that runs are comparable.
        dropout=0,
        attention_dropout=0,
        activation_dropout=0,
        encoder_layerdrop=0,
    )
    arg_values.update(extra_args)
    args = argparse.Namespace(**arg_values)
    # Presumably fills in the remaining architecture defaults on ``args``
    # in place -- confirm against monotonic_tiny_architecture.
    transformer_monotonic_attention.monotonic_tiny_architecture(args)

    # Seed before the task and model are created.
    torch.manual_seed(0)
    model_cls = transformer_monotonic_attention.TransformerModelSimulTrans
    return model_cls.build_model(args, FakeTask(args))


def expected_alignment_formula(
    p_choose,
    mass_perservation=True,
    padding_mask=None
):
    """Brute-force reference for the expected monotonic alignment ``alpha``.

    Implements Eq. 18 and 19 of "Online and Linear-Time Attention by
    Enforcing Monotonic Alignments" (https://arxiv.org/pdf/1704.00784.pdf)
    with explicit Python loops, so it is only suitable for tiny inputs.

    Args:
        p_choose: Tensor of size ``(bsz, tgt_len, src_len)`` holding the
            selection probabilities. It is not modified.
        mass_perservation: When true, the result is passed through
            ``mass_perservation_formula`` before it is returned.
        padding_mask: Optional boolean tensor of size ``(bsz_pad, src_len)``
            that is true at padded source positions. ``bsz`` must be a
            multiple of ``bsz_pad``; each mask row is repeated
            ``bsz // bsz_pad`` times (one copy per attention head when the
            caller has flattened batch and heads into the first dimension).
            ``None`` means no source position is padded.

    Returns:
        Tensor with the same size as ``p_choose``.
    """
    bsz, tgt_len, src_len = p_choose.size()
    alpha = torch.zeros_like(p_choose)

    if padding_mask is not None:
        bsz_pad = padding_mask.size(0)
        num_heads = bsz // bsz_pad
        padding_mask = (
            padding_mask
            .unsqueeze(1)
            .expand([bsz_pad, num_heads, src_len])
            .contiguous()
            .view(-1, src_len)
        )
        # Padded source positions can never be selected.
        p_choose = p_choose.masked_fill(padding_mask.unsqueeze(1), 0)

    for bsz_i in range(bsz):
        for i in range(tgt_len):
            for j in range(src_len):
                if i == 0:
                    if j == 0:
                        # First target step, first source token.
                        alpha[bsz_i, i, j] = p_choose[bsz_i, i, j]
                    else:
                        # First target step: select j after skipping every
                        # earlier source token.
                        alpha[bsz_i, i, j] = (
                            p_choose[bsz_i, i, j]
                            * torch.prod(
                                1 - p_choose[bsz_i, i, :j]
                            )
                        )
                else:
                    # The previous step was already at j (nothing skipped) ...
                    alpha[bsz_i, i, j] = alpha[bsz_i, i - 1, j]
                    # ... or it was at some k < j and k..j-1 are all skipped.
                    for k in range(j):
                        alpha[bsz_i, i, j] += (
                            alpha[bsz_i, i - 1, k]
                            * torch.prod(
                                1 - p_choose[bsz_i, i, k:j]
                            )
                        )
                    alpha[bsz_i, i, j] *= p_choose[bsz_i, i, j]

    if padding_mask is not None:
        alpha = alpha.masked_fill(padding_mask.unsqueeze(1), 0)

    if mass_perservation:
        alpha = mass_perservation_formula(alpha, False, padding_mask)

    return alpha


def mass_perservation_formula(alpha, left_padding=False, padding_mask=None):
    """Move the leftover probability mass onto the last source position.

    For every target step the chosen "last" column is overwritten with one
    minus the sum of the columns before it.

    Args:
        alpha: Tensor of size ``(bsz, tgt_len, src_len)``.
        left_padding: With a mask, selects the final column as the "last"
            position instead of the last unpadded one. In that case the
            first mask column must not contain any padding.
        padding_mask: Optional boolean tensor of size ``(bsz, src_len)``,
            true at padded positions.

    Returns:
        Without a mask (or with a single source column) the input tensor
        itself, updated in place when it has more than one column. With a
        mask, a new tensor in which padded positions are zero.
    """
    src_width = alpha.size(-1)
    if padding_mask is None or src_width == 1:
        # A single column is returned untouched; otherwise the final column
        # of the caller's tensor is rewritten in place.
        if src_width > 1:
            alpha[:, :, -1] = 1 - alpha[:, :, :-1].sum(dim=-1)
        return alpha

    # Number of unpadded source positions in each row.
    valid_lens = padding_mask.logical_not().sum(dim=1).long()

    bsz, _, _ = alpha.size()

    assert not (left_padding and padding_mask[:, 0].any())

    # masked_fill returns a copy, so the caller's tensor stays unchanged.
    alpha = alpha.masked_fill(padding_mask.unsqueeze(1), 0)

    for row in range(bsz):
        last = -1 if left_padding else valid_lens[row] - 1
        alpha[row, :, last] = 1 - alpha[row, :, :last].sum(dim=-1)

    return alpha


def expected_soft_attention_formula(
    alpha,
    soft_energy,
    padding_mask=None,
    chunksize=1e10,
):
    """Brute-force reference for the expected soft attention ``beta``.

    Follows Eq. 14 of "Monotonic Infinite Lookback Attention for
    Simultaneous Machine Translation" (https://arxiv.org/pdf/1906.05218.pdf)
    and Eq. 17 of "Monotonic Chunkwise Attention"
    (https://arxiv.org/abs/1712.05382). Written with explicit Python loops,
    so it is only suitable for tiny inputs.

    Args:
        alpha: Tensor of size ``(bsz, tgt_len, src_len)``.
        soft_energy: Tensor indexed the same way as ``alpha``; its values
            are exponentiated and normalised over each window.
        padding_mask: Optional boolean tensor of size ``(bsz_pad, src_len)``
            that is true at padded source positions. Each row is repeated
            ``bsz // bsz_pad`` times to cover a flattened head dimension.
            ``None`` means no source position is padded.
        chunksize: Width of the window that ends at the selected position.
            The large default makes the window reach back to the first
            source position. Converted to ``int`` before use.

    Returns:
        Tensor with the same size as ``alpha``; padded positions are zero.
    """
    bsz, tgt_len, src_len = alpha.size()
    beta = torch.zeros_like(alpha)
    # The default is a float; the window bounds below feed range() and
    # slicing, which need integers.
    chunksize = int(chunksize)

    if padding_mask is not None:
        bsz_pad = padding_mask.size(0)
        num_heads = bsz // bsz_pad
        # Expanding for potential head dimension
        padding_mask = (
            padding_mask
            .unsqueeze(1)
            .expand([bsz_pad, num_heads, src_len])
            .contiguous()
            .view(-1, src_len)
        )
        # exp(-inf) == 0, so padded positions drop out of every normaliser.
        soft_energy = soft_energy.masked_fill(
            padding_mask.unsqueeze(1), float('-inf')
        )

    for bsz_i in range(bsz):
        for i in range(tgt_len):
            for j in range(src_len):
                if padding_mask is not None and padding_mask[bsz_i, j]:
                    # Padded source positions receive no attention.
                    continue
                # Every selected position k whose window still covers j.
                for k in range(j, min(src_len, j + chunksize)):
                    window_start = max(0, k - chunksize + 1)
                    beta[bsz_i, i, j] += (
                        alpha[bsz_i, i, k]
                        * torch.exp(soft_energy[bsz_i, i, j])
                        / torch.sum(
                            torch.exp(
                                soft_energy[bsz_i, i, window_start:k + 1]
                            )
                        )
                    )
    return beta


class MonotonicAttentionTestAbstractClass:
    """Checks shared by the monotonic-attention test cases.

    Meant to be combined with ``unittest.TestCase`` in a subclass whose
    ``setUp`` stores the model under test in ``self.model``.
    """

    def test_forward(self):
        """Forward pass plus backward through the summed output must run."""
        net_input = make_sample_with_padding()["net_input"]
        out, _ = self.model.forward(**net_input)
        out.sum().backward()

    def test_p_choose(self):
        """Every ``p_choose`` value must lie within [0, 1]."""
        net_input = make_sample_with_padding()["net_input"]
        _, extra_out = self.model.forward(**net_input)
        for attn in extra_out.attn_list:
            probs = attn["p_choose"]
            self.assertTrue(probs.le(1.0).all())
            self.assertTrue(probs.ge(0.0).all())

    def test_expected_alignment(self):
        """Model ``alpha`` must match ``expected_alignment_formula``.

        Checked for both sample layouts and for every entry of
        ``extra_out.attn_list``, with an absolute tolerance of 5e-5.
        """
        for longer_src in (True, False):
            net_input = make_sample_with_padding(longer_src)["net_input"]
            _, extra_out = self.model.forward(**net_input)
            for attn in extra_out.attn_list:
                probs = attn["p_choose"]
                alpha_model = attn["alpha"]
                self.assertTrue(probs.size() == alpha_model.size())

                # Fold (batch, head) into one leading dimension.
                _, _, tgt_len, src_len = alpha_model.size()
                alpha_model = alpha_model.view(-1, tgt_len, src_len)
                probs = probs.view(-1, tgt_len, src_len)

                first_layer_attn = self.model.decoder.layers[0].encoder_attn
                alpha_reference = expected_alignment_formula(
                    probs,
                    first_layer_attn.mass_preservation,
                    net_input["src_tokens"].eq(PAD_INDEX),
                )

                max_error_ok = (
                    torch.abs(alpha_model - alpha_reference).le(5e-5).all()
                )
                self.assertTrue(max_error_ok)


class HardMonotonicAttentionTestCase(
    unittest.TestCase,
    MonotonicAttentionTestAbstractClass
):
    """Runs the shared checks on a model built with ``hard_aligned``."""

    def setUp(self):
        config = generate_config({"simul_type": "hard_aligned"})
        self.model = build_transformer_monotonic_attention(**config)


class InfiniteLookbackTestCase(
    unittest.TestCase,
    MonotonicAttentionTestAbstractClass
):
    """Shared checks plus soft-attention checks for ``infinite_lookback``.

    Other test cases in this file subclass it, override ``setUp`` to build a
    different model, and inherit the tests defined here.
    """

    def setUp(self):
        self.model = build_transformer_monotonic_attention(
            **generate_config({"simul_type": "infinite_lookback"})
        )
        self.model.train()

    # The test moves the model and its inputs to CUDA, so it cannot run on
    # a machine without a GPU; skip instead of erroring there.
    @unittest.skipIf(not torch.cuda.is_available(), "test requires a GPU")
    def test_fp16_for_long_input(self):
        """A 1001-token input in half precision must not produce NaNs."""
        src_tokens = torch.LongTensor([7] * 1000 + [2]).cuda().unsqueeze(0)
        prev_output_tokens = (
            torch.LongTensor([7] * 1000 + [2]).cuda().unsqueeze(0)
        )
        sample = {
            "net_input": {
                "src_tokens": src_tokens,
                "prev_output_tokens": prev_output_tokens,
                # Counted from the tokens (as make_sample_with_padding
                # does) so the length always matches the tensor.
                "src_lengths": src_tokens.ne(PAD_INDEX).sum(dim=1).long(),
            },
            "target": torch.LongTensor([2] + [7] * 1000).unsqueeze(0).cuda()
        }
        self.model.cuda().half()
        _, extra_out = self.model.forward(**sample["net_input"])
        for item in extra_out.attn_list:
            for key in ["p_choose", "alpha", "beta", "soft_energy"]:
                self.assertFalse(torch.isnan(item[key]).any())

    def test_expected_attention(self):
        """Model ``beta`` must match ``expected_soft_attention_formula``.

        Checked for both sample layouts and for every entry of
        ``extra_out.attn_list``, with an absolute tolerance of 1e-5. The
        reference ``beta`` is computed from the reference ``alpha`` and the
        model's own ``soft_energy``.
        """
        for longer_src in [True, False]:
            sample = make_sample_with_padding(longer_src)
            _, extra_out = self.model.forward(**sample["net_input"])
            for item in extra_out.attn_list:
                p_choose = item["p_choose"]
                alpha_system = item["alpha"]
                beta_system = item["beta"]
                soft_energy_system = item["soft_energy"]
                self.assertTrue(beta_system.size() == alpha_system.size())
                self.assertTrue(p_choose.size() == alpha_system.size())

                bsz, num_head, tgt_len, src_len = alpha_system.size()

                # Fold (batch, head) into one leading dimension.
                alpha_system = alpha_system.view(-1, tgt_len, src_len)
                beta_system = beta_system.view(-1, tgt_len, src_len)
                p_choose = p_choose.view(-1, tgt_len, src_len)
                soft_energy_system = soft_energy_system.view(
                    -1, tgt_len, src_len
                )

                alpha_real = expected_alignment_formula(
                    p_choose,
                    self.model.decoder.layers[0].encoder_attn.mass_preservation,
                    sample["net_input"]["src_tokens"].eq(PAD_INDEX)
                )

                beta_real = expected_soft_attention_formula(
                    alpha_real,
                    soft_energy_system,
                    sample["net_input"]["src_tokens"].eq(PAD_INDEX),
                    # Attention modules without a ``chunk_size`` attribute
                    # get an effectively unbounded window.
                    chunksize=getattr(
                        self.model.decoder.layers[0].encoder_attn,
                        "chunk_size",
                        int(1e10)
                    )
                )

                self.assertTrue(
                    torch.abs(beta_system - beta_real).le(1e-5).all(),
                )


class ChunkwiswTestCase(
    InfiniteLookbackTestCase
):
    """Re-runs the inherited tests on a ``chunkwise`` model (chunk size 3).

    NOTE: the class name is misspelled ("Chunkwisw"); it is kept as is so
    that existing test ids keep working.
    """

    def setUp(self):
        config = generate_config(
            {
                "simul_type": "chunkwise",
                "mocha_chunk_size": 3
            }
        )
        self.model = build_transformer_monotonic_attention(**config)


class WaitkTestCase(InfiniteLookbackTestCase):
    """Re-runs the inherited tests on a ``waitk`` model and checks p_choose."""

    def setUp(self):
        config = generate_config(
            {
                "simul_type": "waitk",
                "waitk_lagging": 3,
            }
        )
        self.model = build_transformer_monotonic_attention(**config)

    def check_waitk(self, p_choose, lagging, padding_mask):
        """Assert that ``p_choose`` is the 0/1 pattern of a wait-k policy.

        At every unpadded source position ``j`` and target step ``i`` the
        value must be 1 when ``j - i == lagging - 1`` and 0 otherwise.
        Padded positions are not checked.

        Args:
            p_choose: Tensor of size ``(bsz, tgt_len, src_len)``.
            lagging: The ``k`` of the wait-k policy.
            padding_mask: Tensor indexed ``[bsz, src_len]``, truthy at padded
                source positions.
        """
        bsz, tgt_len, src_len = p_choose.size()
        for row in range(bsz):
            for tgt_pos in range(tgt_len):
                for src_pos in range(src_len):
                    if padding_mask[row, src_pos]:
                        continue
                    value = p_choose[row, tgt_pos, src_pos]
                    if src_pos - tgt_pos == lagging - 1:
                        self.assertTrue(value == 1)
                    else:
                        self.assertTrue(value == 0)

    def test_waitk_p_choose(self):
        """Check the wait-k pattern for several lags and both sample layouts.

        A fresh model is built for every lag value instead of using
        ``self.model``.
        """
        for longer_src in (True, False):
            for lag in (1, 3, 10, 20, 100):
                sample = make_sample_with_padding(longer_src)
                config = generate_config(
                    {
                        "simul_type": "waitk",
                        "waitk_lagging": lag,
                    }
                )
                model = build_transformer_monotonic_attention(**config)
                model.train()
                _, extra_out = model.forward(**sample["net_input"])
                for attn in extra_out.attn_list:
                    probs = attn["p_choose"]
                    bsz, num_heads, tgt_len, src_len = probs.size()
                    # Repeat each mask row once per head so that it lines
                    # up with the flattened (batch * heads) dimension.
                    pad_positions = sample["net_input"]["src_tokens"].eq(
                        PAD_INDEX
                    )
                    pad_positions = (
                        pad_positions
                        .unsqueeze(1)
                        .expand([bsz, num_heads, src_len])
                        .contiguous()
                        .view(-1, src_len)
                    )
                    probs = probs.view(bsz * num_heads, tgt_len, src_len)
                    self.check_waitk(probs, lag, pad_positions)
