#!/usr/bin/env python3

from ast import literal_eval
from typing import List, Tuple

import torch
import torch.nn as nn
import torch.nn.functional as F
from fairseq import checkpoint_utils, utils
from fairseq.data.data_utils import lengths_to_padding_mask
from fairseq.models import (
    FairseqEncoder,
    FairseqEncoderDecoderModel,
    FairseqIncrementalDecoder,
    register_model,
    register_model_architecture,
)


@register_model("s2t_berard")
class BerardModel(FairseqEncoderDecoderModel):
    """Implementation of a model similar to https://arxiv.org/abs/1802.04200

    Paper title: End-to-End Automatic Speech Translation of Audiobooks
    An implementation is available in tensorflow at
    https://github.com/eske/seq2seq
    Relevant files in this implementation are the config
    (https://github.com/eske/seq2seq/blob/master/config/LibriSpeech/AST.yaml)
    and the model code
    (https://github.com/eske/seq2seq/blob/master/translate/models.py).
    The encoder and decoder try to be close to the original implementation.
    The attention is an MLP as in Bahdanau et al.
    (https://arxiv.org/abs/1409.0473).
    The initial decoder state is not learned: every decoder layer starts from
    the time-average of the encoder outputs (padded frames included) as its
    hidden state and from zeros as its cell state.
    """

    def __init__(self, encoder, decoder):
        super().__init__(encoder, decoder)

    @staticmethod
    def add_args(parser):
        """Add model-specific command-line arguments to ``parser``."""
        parser.add_argument(
            "--input-layers",
            type=str,
            metavar="EXPR",
            help="List of linear layer dimensions. These "
            "layers are applied to the input features and "
            "are each followed by dropout (if enabled) and then tanh.",
        )
        parser.add_argument(
            "--dropout",
            type=float,
            metavar="D",
            help="Dropout probability to use in the encoder/decoder. "
            "Note that this parameters control dropout in various places, "
            "there is no fine-grained control for dropout for embeddings "
            "vs LSTM layers for example.",
        )
        parser.add_argument(
            "--in-channels",
            type=int,
            metavar="N",
            help="Number of encoder input channels. Typically value is 1.",
        )
        parser.add_argument(
            "--conv-layers",
            type=str,
            metavar="EXPR",
            help="List of conv layers (format: (channels, kernel, stride)).",
        )
        parser.add_argument(
            "--num-blstm-layers",
            type=int,
            metavar="N",
            help="Number of encoder bi-LSTM layers.",
        )
        parser.add_argument(
            "--lstm-size", type=int, metavar="N", help="LSTM hidden size."
        )
        parser.add_argument(
            "--decoder-embed-dim",
            type=int,
            metavar="N",
            help="Embedding dimension of the decoder target tokens.",
        )
        parser.add_argument(
            "--decoder-hidden-dim",
            type=int,
            metavar="N",
            help="Decoder LSTM hidden dimension.",
        )
        parser.add_argument(
            "--decoder-num-layers",
            type=int,
            metavar="N",
            help="Number of decoder LSTM layers.",
        )
        parser.add_argument(
            "--attention-dim",
            type=int,
            metavar="N",
            help="Hidden layer dimension in MLP attention.",
        )
        parser.add_argument(
            "--output-layer-dim",
            type=int,
            metavar="N",
            help="Hidden layer dim for linear layer prior to output projection.",
        )
        parser.add_argument(
            "--load-pretrained-encoder-from",
            type=str,
            metavar="STR",
            help="model to take encoder weights from (for initialization)",
        )
        parser.add_argument(
            "--load-pretrained-decoder-from",
            type=str,
            metavar="STR",
            help="model to take decoder weights from (for initialization)",
        )

    @classmethod
    def build_encoder(cls, args, task):
        """Build the speech encoder, optionally loading pretrained weights.

        ``args.input_layers`` and ``args.conv_layers`` are Python-literal
        strings (parsed with ``ast.literal_eval``, which only accepts
        literals). The number of input channels is taken from
        ``args.input_channels`` when present and not ``None``; otherwise from
        the ``--in-channels`` flag.

        Raises:
            ValueError: if neither ``input_channels`` nor ``in_channels`` is
                set on ``args``.
        """
        # ``input_channels`` is not a flag defined in ``add_args``; it is
        # expected to be set on ``args`` externally (presumably by the task).
        # Fall back to ``--in-channels`` so that flag is not silently ignored.
        in_channels = getattr(args, "input_channels", None)
        if in_channels is None:
            in_channels = getattr(args, "in_channels", None)
        if in_channels is None:
            raise ValueError(
                "s2t_berard needs the number of input channels: set "
                "args.input_channels or pass --in-channels"
            )
        encoder = BerardEncoder(
            input_layers=literal_eval(args.input_layers),
            conv_layers=literal_eval(args.conv_layers),
            in_channels=in_channels,
            input_feat_per_channel=args.input_feat_per_channel,
            num_blstm_layers=args.num_blstm_layers,
            lstm_size=args.lstm_size,
            dropout=args.dropout,
        )
        if getattr(args, "load_pretrained_encoder_from", None):
            encoder = checkpoint_utils.load_pretrained_component_from_model(
                component=encoder, checkpoint=args.load_pretrained_encoder_from
            )
        return encoder

    @classmethod
    def build_decoder(cls, args, task):
        """Build the attentional LSTM decoder over ``task.target_dictionary``,
        optionally loading pretrained weights."""
        decoder = LSTMDecoder(
            dictionary=task.target_dictionary,
            embed_dim=args.decoder_embed_dim,
            num_layers=args.decoder_num_layers,
            hidden_size=args.decoder_hidden_dim,
            dropout=args.dropout,
            encoder_output_dim=2 * args.lstm_size,  # bidirectional
            attention_dim=args.attention_dim,
            output_layer_dim=args.output_layer_dim,
        )
        if getattr(args, "load_pretrained_decoder_from", None):
            decoder = checkpoint_utils.load_pretrained_component_from_model(
                component=decoder, checkpoint=args.load_pretrained_decoder_from
            )
        return decoder

    @classmethod
    def build_model(cls, args, task):
        """Build a new model instance."""
        encoder = cls.build_encoder(args, task)
        decoder = cls.build_decoder(args, task)

        return cls(encoder, decoder)

    def get_normalized_probs(self, net_output, log_probs, sample=None):
        """Normalize decoder logits and tag the result as batch-first."""
        # net_output[0] is the decoder output, a (B, T, V) tensor of logits
        lprobs = super().get_normalized_probs(net_output, log_probs, sample)
        # lprobs is a (B, T, V) tensor; the attribute below records that
        # layout. NOTE(review): presumably read by downstream consumers
        # (e.g. a criterion) — confirm.
        lprobs.batch_first = True
        return lprobs


class BerardEncoder(FairseqEncoder):
    """Speech encoder of the Berard model.

    Pipeline: per-frame linear layers on the feature axis (each followed by
    optional dropout and then tanh) -> a stack of ``nn.Conv2d`` layers over
    (time, feature) with no non-linearity in between -> a multi-layer
    bidirectional LSTM.
    """

    def __init__(
        self,
        input_layers: List[int],
        conv_layers: List[Tuple[int]],
        in_channels: int,
        input_feat_per_channel: int,
        num_blstm_layers: int,
        lstm_size: int,
        dropout: float,
    ):
        """
        Args:
            input_layers: list of linear layer dimensions. These layers are
                applied to the input features (last axis); each one is
                followed by dropout (when ``dropout > 0``) and then tanh.
                May be empty.
            conv_layers: list of conv2d layer configurations. A configuration is
                a tuple (out_channels, conv_kernel_size, stride). Every layer
                uses ``padding=conv_kernel_size // 2``. May be empty.
            in_channels: number of input channels.
            input_feat_per_channel: number of input features per channel. These
                are speech features, typically 40 or 80.
            num_blstm_layers: number of bidirectional LSTM layers.
            lstm_size: size of the LSTM hidden (and cell) size.
            dropout: dropout probability. Dropout can be applied after the
                linear layers, between LSTM layers and on the LSTM output,
                but not to the convolutional layers.
        """
        super().__init__(None)

        self.input_layers = nn.ModuleList()
        in_features = input_feat_per_channel
        for out_features in input_layers:
            if dropout > 0:
                self.input_layers.append(
                    nn.Sequential(
                        nn.Linear(in_features, out_features), nn.Dropout(p=dropout)
                    )
                )
            else:
                self.input_layers.append(nn.Linear(in_features, out_features))
            in_features = out_features

        self.in_channels = in_channels
        self.input_dim = input_feat_per_channel
        self.conv_kernel_sizes_and_strides = []
        self.conv_layers = nn.ModuleList()
        # Size of the feature (last) axis as it flows through the convs. It
        # starts at the last input layer's width, or at the raw feature size
        # when there are no input layers.
        feat_dim = in_features
        for out_channels, conv_kernel_size, conv_stride in conv_layers:
            self.conv_layers.append(
                nn.Conv2d(
                    in_channels,
                    out_channels,
                    conv_kernel_size,
                    stride=conv_stride,
                    padding=conv_kernel_size // 2,
                )
            )
            self.conv_kernel_sizes_and_strides.append((conv_kernel_size, conv_stride))
            in_channels = out_channels
            # Use the exact Conv2d output-size formula. A plain
            # ``feat_dim // stride`` disagrees with it for even kernels or
            # sizes not divisible by the stride, which broke the LSTM.
            feat_dim = self._conv_output_size(feat_dim, conv_kernel_size, conv_stride)

        # Each frame's (channels, features) block is flattened for the LSTM;
        # ``in_channels`` now holds the last conv's out_channels (or the raw
        # channel count when there are no conv layers).
        lstm_input_dim = in_channels * feat_dim
        self.lstm_size = lstm_size
        self.num_blstm_layers = num_blstm_layers
        self.lstm = nn.LSTM(
            input_size=lstm_input_dim,
            hidden_size=lstm_size,
            num_layers=num_blstm_layers,
            dropout=dropout,
            bidirectional=True,
        )
        self.output_dim = 2 * lstm_size  # bidirectional
        if dropout > 0:
            self.dropout = nn.Dropout(p=dropout)
        else:
            self.dropout = None

    @staticmethod
    def _conv_output_size(size: int, kernel_size: int, stride: int) -> int:
        """Output length of one Conv2d axis using padding ``kernel_size // 2``."""
        padding = kernel_size // 2
        return (size + 2 * padding - kernel_size) // stride + 1

    def forward(self, src_tokens, src_lengths=None, **kwargs):
        """Encode a batch of padded speech-feature sequences.

        Args
            src_tokens: padded tensor (B, T, C * feat)
            src_lengths: tensor of original lengths of input utterances (B,).
                If ``None``, every utterance is assumed to span all T frames.

        Returns:
            dict with
              * ``"encoder_out"``: (T', B, 2 * lstm_size) bi-LSTM outputs,
                where T' is the longest sequence length after the convs;
              * ``"encoder_padding_mask"``: (T', B) mask, True on padded
                frames (it is used to mask attention scores).
        """
        bsz, max_seq_len, _ = src_tokens.size()
        if src_lengths is None:
            src_lengths = src_tokens.new_full((bsz,), max_seq_len, dtype=torch.long)

        # (B, T, C * feat) -> (B, T, C, feat) -> (B, C, T, feat)
        x = (
            src_tokens.view(bsz, max_seq_len, self.in_channels, self.input_dim)
            .transpose(1, 2)
            .contiguous()
        )

        # Per-frame MLP on the feature axis. Dropout, if any, sits inside each
        # layer, so it is applied before the tanh.
        for input_layer in self.input_layers:
            x = torch.tanh(input_layer(x))

        # No non-linearity between the convolutions.
        for conv_layer in self.conv_layers:
            x = conv_layer(x)

        bsz, _, output_seq_len, _ = x.size()

        # (B, C, T, feat) -> (B, T, C, feat) -> (T, B, C, feat) ->
        # (T, B, C * feat)
        x = x.transpose(1, 2).transpose(0, 1).contiguous().view(output_seq_len, bsz, -1)

        # Propagate each utterance's length through the convolutions.
        input_lengths = src_lengths
        for k, s in self.conv_kernel_sizes_and_strides:
            p = k // 2
            input_lengths = (input_lengths.float() + 2 * p - k) / s + 1
            input_lengths = input_lengths.floor().long()

        # pack_padded_sequence requires lengths as a CPU int64 tensor (passing
        # a CUDA tensor raises). With the default enforce_sorted=True the batch
        # must be sorted by decreasing length. NOTE(review): presumably
        # guaranteed by the data loader — confirm.
        packed_x = nn.utils.rnn.pack_padded_sequence(x, input_lengths.long().cpu())

        h0 = x.new_zeros(2 * self.num_blstm_layers, bsz, self.lstm_size)
        c0 = x.new_zeros(2 * self.num_blstm_layers, bsz, self.lstm_size)
        packed_outs, _ = self.lstm(packed_x, (h0, c0))

        # unpack outputs and apply dropout
        x, output_lengths = nn.utils.rnn.pad_packed_sequence(packed_outs)
        if self.dropout is not None:
            x = self.dropout(x)

        encoder_padding_mask = (
            lengths_to_padding_mask(output_lengths).to(src_tokens.device).t()
        )

        return {
            "encoder_out": x,  # (T, B, C)
            "encoder_padding_mask": encoder_padding_mask,  # (T, B)
        }

    def reorder_encoder_out(self, encoder_out, new_order):
        """Reorder the batch axis (dim 1) of the encoder outputs and the
        padding mask according to ``new_order`` (e.g. for beam search)."""
        encoder_out["encoder_out"] = encoder_out["encoder_out"].index_select(
            1, new_order
        )
        encoder_out["encoder_padding_mask"] = encoder_out[
            "encoder_padding_mask"
        ].index_select(1, new_order)
        return encoder_out


class MLPAttention(nn.Module):
    """The original attention from Bahdanau et al. (2014)

    https://arxiv.org/abs/1409.0473, based on a Multi-Layer Perceptron.
    The unnormalized score between position i in the encoder and position j
    in the decoder is: e_ij = V_a * tanh(W_ae * enc_i + W_ad * dec_j + b_a).
    The attention weights alpha_ij are the softmax of e_ij over the unmasked
    encoder positions i.
    """

    def __init__(self, decoder_hidden_state_dim, context_dim, attention_dim):
        super().__init__()

        self.context_dim = context_dim
        self.attention_dim = attention_dim
        # W_ae and b_a
        self.encoder_proj = nn.Linear(context_dim, self.attention_dim, bias=True)
        # W_ad
        self.decoder_proj = nn.Linear(
            decoder_hidden_state_dim, self.attention_dim, bias=False
        )
        # V_a
        self.to_scores = nn.Linear(self.attention_dim, 1, bias=False)

    def forward(self, decoder_state, source_hids, encoder_padding_mask):
        """Attend over ``source_hids`` with the current decoder state.

        The expected input dimensions are:
        decoder_state: bsz x decoder_hidden_state_dim
        source_hids: src_len x bsz x context_dim (need not be contiguous)
        encoder_padding_mask: src_len x bsz boolean mask (True = ignore), or
            None to attend over every position

        Returns:
            (context, weights): the attention-weighted sum of ``source_hids``
            (bsz x context_dim) and the normalized weights (src_len x bsz).
        """
        src_len, bsz, _ = source_hids.size()
        # (src_len*bsz) x context_dim. Use reshape rather than view so that
        # non-contiguous encoder states (e.g. a transposed tensor) are accepted.
        flat_source_hids = source_hids.reshape(-1, self.context_dim)
        # src_len x bsz x attention_dim
        encoder_component = self.encoder_proj(flat_source_hids).view(
            src_len, bsz, self.attention_dim
        )
        # 1 x bsz x attention_dim
        decoder_component = self.decoder_proj(decoder_state).unsqueeze(0)
        # Sum with broadcasting and apply the non linearity
        # (src_len*bsz) x attention_dim
        hidden_att = torch.tanh(
            (decoder_component + encoder_component).view(-1, self.attention_dim)
        )
        # Project onto the reals to get attention scores (src_len x bsz)
        attn_scores = self.to_scores(hidden_att).view(src_len, bsz)

        # Mask + softmax (src_len x bsz)
        if encoder_padding_mask is not None:
            # FP16 support: mask in float32 and cast back. Out-of-place
            # masked_fill avoids mutating a tensor that is part of the
            # autograd graph.
            attn_scores = (
                attn_scores.float()
                .masked_fill(encoder_padding_mask, float("-inf"))
                .type_as(attn_scores)
            )
        # srclen x bsz
        normalized_masked_attn_scores = F.softmax(attn_scores, dim=0)

        # Sum weighted sources (bsz x context_dim)
        attn_weighted_context = (
            source_hids * normalized_masked_attn_scores.unsqueeze(2)
        ).sum(dim=0)

        return attn_weighted_context, normalized_masked_attn_scores


class LSTMDecoder(FairseqIncrementalDecoder):
    """Attentional LSTM decoder.

    At every target step a stack of ``nn.LSTMCell`` layers is run. The bottom
    layer reads the (dropped-out) target-token embedding. Its hidden state
    queries an :class:`MLPAttention` over the encoder outputs, and the
    resulting context vector is the input of every layer above it. Logits are
    produced by a tanh "deep output" layer over
    [top-layer hidden; context; embedding], followed by a linear projection
    to the vocabulary.
    """

    def __init__(
        self,
        dictionary,
        embed_dim,
        num_layers,
        hidden_size,
        dropout,
        encoder_output_dim,
        attention_dim,
        output_layer_dim,
    ):
        """
        Args:
            dictionary: target text dictionary.
            embed_dim: embedding dimension for target tokens.
            num_layers: number of LSTM layers.
            hidden_size: hidden size for LSTM layers.
            dropout: dropout probability. When > 0, dropout is applied to the
                embeddings, to each LSTM layer's hidden state, to the
                attention context vector and after the deep output layer.
            encoder_output_dim: dimension of each encoder output vector
                (``BerardModel.build_decoder`` passes ``2 * lstm_size`` for
                the bidirectional encoder).
            attention_dim: attention dimension for MLP attention.
            output_layer_dim: size of the linear layer prior to output
                projection.

        NOTE(review): ``forward`` initializes every layer's hidden state with
        the mean encoder output, which implicitly requires
        ``hidden_size == encoder_output_dim``. The architectures registered
        in this file satisfy this (``2 * lstm_size == decoder_hidden_dim``),
        but it is not checked here.
        """
        super().__init__(dictionary)
        self.num_layers = num_layers
        self.hidden_size = hidden_size
        num_embeddings = len(dictionary)
        padding_idx = dictionary.pad()
        self.embed_tokens = nn.Embedding(num_embeddings, embed_dim, padding_idx)
        if dropout > 0:
            self.dropout = nn.Dropout(p=dropout)
        else:
            self.dropout = None

        # Layer 0 reads the token embedding; every higher layer reads the
        # attention context vector (size encoder_output_dim).
        self.layers = nn.ModuleList()
        for layer_id in range(num_layers):
            input_size = embed_dim if layer_id == 0 else encoder_output_dim
            self.layers.append(
                nn.LSTMCell(input_size=input_size, hidden_size=hidden_size)
            )

        self.context_dim = encoder_output_dim
        self.attention = MLPAttention(
            decoder_hidden_state_dim=hidden_size,
            context_dim=encoder_output_dim,
            attention_dim=attention_dim,
        )

        # "Deep output" over [top hidden state; context vector; embedding].
        self.deep_output_layer = nn.Linear(
            hidden_size + encoder_output_dim + embed_dim, output_layer_dim
        )
        self.output_projection = nn.Linear(output_layer_dim, num_embeddings)

    def forward(
        self, prev_output_tokens, encoder_out=None, incremental_state=None, **kwargs
    ):
        """Run the decoder over ``prev_output_tokens``.

        Args:
            prev_output_tokens: (B, T) previous target tokens. During
                incremental generation only the last column is used.
            encoder_out: dict produced by the encoder, read for
                ``"encoder_out"`` (T_src, B, C) and ``"encoder_padding_mask"``
                (T_src, B). Required despite the ``None`` default.
            incremental_state: fairseq incremental-decoding cache, or None.

        Returns:
            ``(logits, None)`` with logits of shape (B, T, vocab). Attention
            scores are not returned (see the note at the end).
        """
        encoder_padding_mask = encoder_out["encoder_padding_mask"]
        encoder_outs = encoder_out["encoder_out"]

        if incremental_state is not None:
            prev_output_tokens = prev_output_tokens[:, -1:]
        bsz, seqlen = prev_output_tokens.size()

        srclen = encoder_outs.size(0)

        # embed tokens
        embeddings = self.embed_tokens(prev_output_tokens)
        x = embeddings
        if self.dropout is not None:
            x = self.dropout(x)

        # B x T x C -> T x B x C
        x = x.transpose(0, 1)

        # initialize previous states (or get from cache during incremental
        # generation)
        cached_state = utils.get_incremental_state(
            self, incremental_state, "cached_state"
        )
        if cached_state is not None:
            # These are the cached lists themselves; they are updated in place
            # below and then stored again.
            prev_hiddens, prev_cells = cached_state
        else:
            # Hidden state of every layer := mean of the encoder outputs over
            # time (padded positions are not masked out of the mean); cell
            # state := zeros.
            prev_hiddens = [encoder_out["encoder_out"].mean(dim=0)] * self.num_layers
            prev_cells = [x.new_zeros(bsz, self.hidden_size)] * self.num_layers

        # NOTE(review): this initial value is overwritten by the attention call
        # below and attn_scores is never returned.
        attn_scores = x.new_zeros(bsz, srclen)
        attention_outs = []
        outs = []
        for j in range(seqlen):
            input = x[j, :, :]
            attention_out = None
            for i, layer in enumerate(self.layers):
                # the previous state is one layer below except for the bottom
                # layer where the previous state is the state emitted by the
                # top layer
                # (for i >= 1 that "layer below" state was already updated at
                # this time step; for i == 0 it is the top layer's state from
                # the previous step or from initialization)
                hidden, cell = layer(
                    input,
                    (
                        prev_hiddens[(i - 1) % self.num_layers],
                        prev_cells[(i - 1) % self.num_layers],
                    ),
                )
                if self.dropout is not None:
                    hidden = self.dropout(hidden)
                prev_hiddens[i] = hidden
                prev_cells[i] = cell
                # Attention is computed once per step, from the bottom layer.
                if attention_out is None:
                    attention_out, attn_scores = self.attention(
                        hidden, encoder_outs, encoder_padding_mask
                    )
                    if self.dropout is not None:
                        attention_out = self.dropout(attention_out)
                    attention_outs.append(attention_out)
                # Every layer above the bottom one consumes the context vector.
                input = attention_out

            # collect the output of the top layer
            outs.append(hidden)

        # cache previous states (no-op except during incremental generation)
        utils.set_incremental_state(
            self, incremental_state, "cached_state", (prev_hiddens, prev_cells)
        )

        # collect outputs across time steps
        x = torch.cat(outs, dim=0).view(seqlen, bsz, self.hidden_size)
        attention_outs_concat = torch.cat(attention_outs, dim=0).view(
            seqlen, bsz, self.context_dim
        )

        # T x B x C -> B x T x C
        x = x.transpose(0, 1)
        attention_outs_concat = attention_outs_concat.transpose(0, 1)

        # concat LSTM output, attention output and embedding
        # before output projection
        # (``embeddings`` is the pre-dropout embedding tensor, B x T x C)
        x = torch.cat((x, attention_outs_concat, embeddings), dim=2)
        x = self.deep_output_layer(x)
        x = torch.tanh(x)
        if self.dropout is not None:
            x = self.dropout(x)
        # project back to size of vocabulary
        x = self.output_projection(x)

        # to return the full attn_scores tensor, we need to fix the decoder
        # to account for subsampling input frames
        # return x, attn_scores
        return x, None

    def reorder_incremental_state(self, incremental_state, new_order):
        """Reorder the cached decoder states along the batch axis (dim 0)
        according to ``new_order`` (e.g. during beam search)."""
        super().reorder_incremental_state(incremental_state, new_order)
        cached_state = utils.get_incremental_state(
            self, incremental_state, "cached_state"
        )
        if cached_state is None:
            return

        def reorder_state(state):
            # Recurse into the per-layer lists; returns fresh lists so that
            # forward can keep updating them in place.
            if isinstance(state, list):
                return [reorder_state(state_i) for state_i in state]
            return state.index_select(0, new_order)

        new_state = tuple(map(reorder_state, cached_state))
        utils.set_incremental_state(self, incremental_state, "cached_state", new_state)


@register_model_architecture(model_name="s2t_berard", arch_name="s2t_berard")
def berard(args):
    """The original version: "End-to-End Automatic Speech Translation of
    Audiobooks" (https://arxiv.org/abs/1802.04200)

    Every hyper-parameter below that is missing from ``args`` is set to its
    default, in the listed order; values already present are kept. ``args``
    is modified in place.
    """
    defaults = (
        ("input_layers", "[256, 128]"),
        ("conv_layers", "[(16, 3, 2), (16, 3, 2)]"),
        ("num_blstm_layers", 3),
        ("lstm_size", 256),
        ("dropout", 0.2),
        ("decoder_embed_dim", 128),
        ("decoder_num_layers", 2),
        ("decoder_hidden_dim", 512),
        ("attention_dim", 512),
        ("output_layer_dim", 128),
        ("load_pretrained_encoder_from", None),
        ("load_pretrained_decoder_from", None),
    )
    for attr_name, default_value in defaults:
        setattr(args, attr_name, getattr(args, attr_name, default_value))


@register_model_architecture(model_name="s2t_berard", arch_name="s2t_berard_256_3_3")
def berard_256_3_3(args):
    """Used in
    * "Harnessing Indirect Training Data for End-to-End Automatic Speech
    Translation: Tricks of the Trade" (https://arxiv.org/abs/1909.06515)
    * "CoVoST: A Diverse Multilingual Speech-To-Text Translation Corpus"
    (https://arxiv.org/pdf/2002.01320.pdf)
    * "Self-Supervised Representations Improve End-to-End Speech Translation"
    (https://arxiv.org/abs/2006.12124)

    Identical to ``s2t_berard`` except that the decoder defaults to 3 layers.
    """
    num_decoder_layers = getattr(args, "decoder_num_layers", 3)
    args.decoder_num_layers = num_decoder_layers
    # Everything else falls back to the base architecture's defaults.
    berard(args)


@register_model_architecture(model_name="s2t_berard", arch_name="s2t_berard_512_3_2")
def berard_512_3_2(args):
    """Wider variant: 3-layer, 512-unit encoder bi-LSTM and a 2-layer,
    1024-unit decoder. Missing values are filled in place, in the listed
    order; anything not listed falls back to ``berard``."""
    variant_defaults = (
        ("num_blstm_layers", 3),
        ("lstm_size", 512),
        ("dropout", 0.3),
        ("decoder_embed_dim", 256),
        ("decoder_num_layers", 2),
        ("decoder_hidden_dim", 1024),
        ("attention_dim", 512),
        ("output_layer_dim", 256),
    )
    for attr_name, default_value in variant_defaults:
        setattr(args, attr_name, getattr(args, attr_name, default_value))
    berard(args)


@register_model_architecture(model_name="s2t_berard", arch_name="s2t_berard_512_5_3")
def berard_512_5_3(args):
    """Deeper variant: 5-layer, 512-unit encoder bi-LSTM and a 3-layer,
    1024-unit decoder. Missing values are filled in place, in the listed
    order; anything not listed falls back to ``berard``."""
    variant_defaults = (
        ("num_blstm_layers", 5),
        ("lstm_size", 512),
        ("dropout", 0.3),
        ("decoder_embed_dim", 256),
        ("decoder_num_layers", 3),
        ("decoder_hidden_dim", 1024),
        ("attention_dim", 512),
        ("output_layer_dim", 256),
    )
    for attr_name, default_value in variant_defaults:
        setattr(args, attr_name, getattr(args, attr_name, default_value))
    berard(args)
