#!/usr/bin/env python3

# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import torch
from examples.speech_recognition.data.replabels import pack_replabels
from fairseq import utils
from fairseq.criterions import FairseqCriterion, register_criterion


@register_criterion("asg_loss")
class ASGCriterion(FairseqCriterion):
    """Criterion that scores encoder emissions with flashlight's ``ASGLoss``.

    Targets are post-processed before being handed to the loss: a trailing
    EOS is dropped or swapped for the silence token, repeated labels are
    packed with ``pack_replabels`` and the result is truncated to the number
    of emission frames. For the first ``linseg_updates`` training steps the
    targets are additionally stretched to one label per frame (LinSeg).
    """

    @staticmethod
    def add_args(parser):
        """Register the ASG-specific command line options on ``parser``."""
        group = parser.add_argument_group("ASG Loss")
        group.add_argument(
            "--asg-transitions-init",
            help="initial diagonal value of transition matrix",
            type=float,
            default=0.0,
        )
        group.add_argument(
            "--max-replabel", help="maximum # of replabels", type=int, default=2
        )
        group.add_argument(
            "--linseg-updates",
            help="# of training updates to use LinSeg initialization",
            type=int,
            default=0,
        )
        group.add_argument(
            "--hide-linseg-messages",
            help="hide messages about LinSeg initialization",
            action="store_true",
        )

    def __init__(
        self,
        task,
        silence_token,
        asg_transitions_init,
        max_replabel,
        linseg_updates,
        hide_linseg_messages,
        sentence_avg=False,
    ):
        """Build the criterion.

        Args:
            task: task object; only ``task.target_dictionary`` is read here.
            silence_token: symbol that replaces a trailing EOS in targets;
                ignored when it is not present in the target dictionary.
            asg_transitions_init: value placed on the diagonal of the initial
                transition matrix (all other entries start at zero).
            max_replabel: forwarded to ``pack_replabels``.
            linseg_updates: number of training steps that use LinSeg targets.
            hide_linseg_messages: suppress the LinSeg start/finish messages.
            sentence_avg: when true, ``forward`` reports the number of
                sentences as the sample size instead of the token count.
        """
        from flashlight.lib.sequence.criterion import ASGLoss, CriterionScaleMode

        super().__init__(task)
        self.tgt_dict = task.target_dictionary
        self.eos = self.tgt_dict.eos()
        self.silence = (
            self.tgt_dict.index(silence_token)
            if silence_token in self.tgt_dict
            else None
        )
        self.max_replabel = max_replabel
        # Captured here because this class never stores an ``args`` namespace
        # of its own for ``forward`` to consult.
        self.sentence_avg = sentence_avg

        num_labels = len(self.tgt_dict)
        self.asg = ASGLoss(num_labels, scale_mode=CriterionScaleMode.TARGET_SZ_SQRT)
        # Learnable square transition matrix, initialised as a scaled identity.
        self.asg.trans = torch.nn.Parameter(
            asg_transitions_init * torch.eye(num_labels), requires_grad=True
        )

        # Step counter held as a non-trainable parameter rather than a plain
        # int (presumably so it is saved with the module state -- TODO confirm).
        self.linseg_progress = torch.nn.Parameter(
            torch.tensor([0], dtype=torch.int), requires_grad=False
        )
        self.linseg_maximum = linseg_updates
        # Message state machine: "start" -> "finish" -> "none".
        self.linseg_message_state = "none" if hide_linseg_messages else "start"

    @classmethod
    def build_criterion(cls, args, task):
        """Construct the criterion from the parsed ``args`` and ``task``."""
        return cls(
            task,
            args.silence_token,
            args.asg_transitions_init,
            args.max_replabel,
            args.linseg_updates,
            args.hide_linseg_messages,
            sentence_avg=getattr(args, "sentence_avg", False),
        )

    def linseg_step(self):
        """Advance the LinSeg counter and report whether LinSeg is active.

        Returns:
            ``True`` while in training mode and fewer than ``linseg_maximum``
            steps have been counted (the counter is incremented as a side
            effect); ``False`` otherwise. Start/finish messages are printed
            once each unless they were hidden at construction time.
        """
        if not self.training:
            return False
        if self.linseg_progress.item() < self.linseg_maximum:
            if self.linseg_message_state == "start":
                print("| using LinSeg to initialize ASG")
                self.linseg_message_state = "finish"
            self.linseg_progress.add_(1)
            return True
        elif self.linseg_message_state == "finish":
            print("| finished LinSeg initialization")
            self.linseg_message_state = "none"
        return False

    def replace_eos_with_silence(self, tgt):
        """Strip a trailing EOS from ``tgt``, substituting silence if possible.

        Args:
            tgt: list of target label indices.

        Returns:
            ``tgt`` unchanged when it is empty or does not end with EOS.
            Otherwise the list without its final EOS, with the silence index
            appended unless no silence index is configured or the label just
            before the EOS is already silence.
        """
        # An empty target has no final label to inspect.
        if not tgt or tgt[-1] != self.eos:
            return tgt
        elif self.silence is None or (len(tgt) > 1 and tgt[-2] == self.silence):
            return tgt[:-1]
        else:
            return tgt[:-1] + [self.silence]

    def _use_sentence_avg(self):
        """Return whether the sample size should be the sentence count."""
        # Honour an ``args`` namespace if one happens to be attached to the
        # instance; otherwise fall back to the value given to ``__init__``.
        args = getattr(self, "args", None)
        return getattr(args, "sentence_avg", self.sentence_avg)

    def forward(self, model, sample, reduce=True):
        """Compute the loss for the given sample.

        Args:
            model: callable invoked as ``model(**sample["net_input"])``; its
                result must provide an ``"encoder_out"`` tensor.
            sample: batch dict with ``"net_input"``, ``"target"``,
                ``"target_lengths"`` and ``"ntokens"`` entries.
            reduce: when true, the per-sentence losses are summed.

        Returns a tuple with three elements:
        1) the loss
        2) the sample size, which is used as the denominator for the gradient
        3) logging outputs to display while training

        Raises:
            ValueError: if any target in the batch has length zero.
        """

        net_output = model(**sample["net_input"])
        # After the transpose, dim 0 is treated as batch and dim 1 as time.
        emissions = net_output["encoder_out"].transpose(0, 1).contiguous()
        B = emissions.size(0)
        T = emissions.size(1)
        device = emissions.device

        # Zero-filled so the padding past each target's length is
        # deterministic instead of whatever the allocator returned.
        target = torch.zeros(B, T, dtype=torch.int)
        target_size = torch.zeros(B, dtype=torch.int)
        using_linseg = self.linseg_step()

        for b in range(B):
            initial_target_size = sample["target_lengths"][b].item()
            if initial_target_size == 0:
                raise ValueError("target size cannot be zero")

            tgt = sample["target"][b, :initial_target_size].tolist()
            tgt = self.replace_eos_with_silence(tgt)
            tgt = pack_replabels(tgt, self.tgt_dict, self.max_replabel)
            # A target can never be longer than the number of frames.
            tgt = tgt[:T]

            if using_linseg:
                # Spread the labels evenly so that every frame gets one.
                tgt = [tgt[t * len(tgt) // T] for t in range(T)]

            target[b][: len(tgt)] = torch.tensor(tgt, dtype=torch.int)
            target_size[b] = len(tgt)

        loss = self.asg.forward(emissions, target.to(device), target_size.to(device))

        if reduce:
            loss = torch.sum(loss)

        sample_size = (
            sample["target"].size(0)
            if self._use_sentence_avg()
            else sample["ntokens"]
        )
        logging_output = {
            "loss": utils.item(loss.data) if reduce else loss.data,
            "ntokens": sample["ntokens"],
            "nsentences": sample["target"].size(0),
            "sample_size": sample_size,
        }
        return loss, sample_size, logging_output

    @staticmethod
    def aggregate_logging_outputs(logging_outputs):
        """Aggregate logging outputs from data parallel training.

        Sums ``loss``, ``ntokens``, ``nsentences`` and ``sample_size`` over
        all entries and reports the loss averaged per sentence. When no
        sentences were logged the averaged loss is reported as ``0.0``.
        """
        loss_sum = sum(log.get("loss", 0) for log in logging_outputs)
        ntokens = sum(log.get("ntokens", 0) for log in logging_outputs)
        nsentences = sum(log.get("nsentences", 0) for log in logging_outputs)
        sample_size = sum(log.get("sample_size", 0) for log in logging_outputs)
        agg_output = {
            # Guard against an empty aggregation window.
            "loss": loss_sum / nsentences if nsentences > 0 else 0.0,
            "ntokens": ntokens,
            "nsentences": nsentences,
            "sample_size": sample_size,
        }
        return agg_output
