# Copyright 2017 Johns Hopkins University (Shinji Watanabe)
#  Apache 2.0  (http://www.apache.org/licenses/LICENSE-2.0)
"""Class Declaration of Transformer's Training Subprocess."""
import collections
import logging
import math

import numpy as np
import six
from chainer import cuda
from chainer import functions as F
from chainer import training
from chainer.training import extension
from chainer.training.updaters.multiprocess_parallel_updater import (
    gather_grads,
    gather_params,
    scatter_grads,
)


# copied from https://github.com/chainer/chainer/blob/master/chainer/optimizer.py
def sum_sqnorm(arr):
    """Sum the squared L2 norms of a collection of arrays.

    Each array is flattened and dotted with itself while its device is
    active; partial sums are kept per device id and combined at the end.
    Entries that are ``None`` contribute nothing.

    Args:
        arr (iterable): Arrays (or ``None`` placeholders) to accumulate over.

    Returns:
        float: Total of the squared norms (``0`` when nothing was summed).
            Note that this is the *squared* norm; take the square root to
            obtain the norm itself.

    """
    per_device = collections.defaultdict(float)
    for grad in arr:
        with cuda.get_device_from_array(grad) as dev:
            if grad is None:
                # nothing to accumulate for this entry
                continue
            flat = grad.ravel()
            per_device[int(dev)] += flat.dot(flat)
    # convert every per-device partial sum to a Python float before adding
    return sum(float(partial) for partial in six.itervalues(per_device))


class CustomUpdater(training.StandardUpdater):
    """Single-device updater with gradient accumulation and a NaN guard.

    Gradients from ``accum_grad`` consecutive mini-batches are accumulated
    before the optimizer is stepped. The step is skipped when the gradient
    norm is NaN.

    Args:
        train_iter (iterator | dict[str, iterator]): Dataset iterator for the
            training dataset. It can also be a dictionary that maps strings to
            iterators. If this is just an iterator, then the iterator is
            registered by the name ``'main'``.
        optimizer (optimizer | dict[str, optimizer]): Optimizer to update
            parameters. It can also be a dictionary that maps strings to
            optimizers. If this is just an optimizer, then the optimizer is
            registered by the name ``'main'``.
        converter (espnet.asr.chainer_backend.asr.CustomConverter): Converter
            function to build input arrays. It is called with each batch
            taken from the main iterator and with ``device``.
        device (int or dict): The destination device info to send variables. In the
            case of cpu or single gpu, `device=-1 or 0`, respectively.
            In the case of multi-gpu, `device={"main":0, "sub_1": 1, ...}`.
        accum_grad (int): Number of mini-batches whose gradients are
            accumulated per parameter update. With 2, the parameters are
            updated every second call, i.e. the effective batch size doubles.

    """

    def __init__(self, train_iter, optimizer, converter, device, accum_grad=1):
        """Set up the parent updater and the accumulation state."""
        super(CustomUpdater, self).__init__(
            train_iter, optimizer, converter=converter, device=device
        )
        # number of forward/backward passes since the last parameter update
        self.forward_count = 0
        self.accum_grad = accum_grad
        # True until the very first call of update_core
        self.start = True
        self.device = device
        logging.debug("using custom converter for transformer")

    # The core part of the update routine can be customized by overriding.
    def update_core(self):
        """Run one forward/backward pass and update once enough are accumulated."""
        train_iter = self.get_iterator("main")
        optimizer = self.get_optimizer("main")

        # Fetch the next mini-batch and turn it into model inputs
        inputs = self.converter(train_iter.next(), self.device)
        model = optimizer.target
        if self.start:
            # make sure the first accumulation starts from empty gradients
            model.cleargrads()
            self.start = False

        # Scale the loss so the accumulated gradient is an average
        scaled_loss = model(*inputs) / self.accum_grad
        scaled_loss.backward()

        self.forward_count += 1
        if self.forward_count == self.accum_grad:
            self.forward_count = 0
            # inspect the gradient norm before touching the parameters
            grads = [param.grad for param in model.params(False)]
            grad_norm = np.sqrt(sum_sqnorm(grads))
            logging.info("grad norm={}".format(grad_norm))
            if not math.isnan(grad_norm):
                optimizer.update()
            else:
                logging.warning("grad norm is nan. Do not update model.")
            # start the next accumulation window from empty gradients
            model.cleargrads()

    def update(self):
        """Run ``update_core`` and count an iteration per completed window."""
        self.update_core()
        if self.forward_count != 0:
            # still accumulating; no iteration finished yet
            return
        self.iteration += 1


class CustomParallelUpdater(training.updaters.MultiprocessParallelUpdater):
    """Custom Parallel Updater for chainer.

    Defines the main update routine of the master process: one
    forward/backward pass per call, an NCCL sum-reduction of the gradients,
    and a parameter update once ``accum_grad`` passes have been accumulated.
    The update is skipped when the gradient norm is NaN.

    Args:
        train_iters: Dataset iterators for the training dataset; forwarded
            unchanged to the parent updater (presumably one iterator per
            device -- confirm against the parent class).
        optimizer (optimizer | dict[str, optimizer]): Optimizer to update
            parameters. It can also be a dictionary that maps strings to
            optimizers. If this is just an optimizer, then the optimizer is
            registered by the name ``'main'``.
        converter (espnet.asr.chainer_backend.asr.CustomConverter): Converter
            function to build input arrays. It is called with each batch
            taken from the main iterator and with the first device.
        devices: Devices forwarded to the parent updater. The first entry is
            used by this process for the master model.
        accum_grad (int):The number of gradient accumulation. if set to 2, the network
            parameters will be updated once in twice,
            i.e. actual batchsize will be doubled.

    """

    def __init__(self, train_iters, optimizer, converter, devices, accum_grad=1):
        """Initialize custom parallel updater."""
        # Imported lazily so that the module can be loaded without cupy/NCCL.
        from cupy.cuda import nccl

        super(CustomParallelUpdater, self).__init__(
            train_iters, optimizer, converter=converter, devices=devices
        )
        self.accum_grad = accum_grad
        # number of forward/backward passes since the last parameter update
        self.forward_count = 0
        # keep the module handle for the NCCL constants used in update_core
        self.nccl = nccl
        logging.debug("using custom parallel updater for transformer")

    # The core part of the update routine can be customized by overriding.
    def update_core(self):
        """Process main update routine for Custom Parallel Updater."""
        # Inherited helper; presumably spawns the worker processes on first use.
        self.setup_workers()

        # Inherited helper; presumably tells every worker to run one step.
        self._send_message(("update", None))
        with cuda.Device(self._devices[0]):
            optimizer = self.get_optimizer("main")
            batch = self.get_iterator("main").next()
            x = self.converter(batch, self._devices[0])

            # Scale the loss so the accumulated gradient is an average.
            loss = self._master(*x) / self.accum_grad
            loss.backward()

            # NCCL: reduce grads
            null_stream = cuda.Stream.null
            if self.comm is not None:
                gg = gather_grads(self._master)
                # Sum-reduce with the same buffer as source and destination and
                # root 0, then write the result back into the master's grads.
                # NOTE(review): the data type is hard-coded to NCCL_FLOAT --
                # presumably assumes float32 gradients; confirm.
                self.comm.reduce(
                    gg.data.ptr,
                    gg.data.ptr,
                    gg.size,
                    self.nccl.NCCL_FLOAT,
                    self.nccl.NCCL_SUM,
                    0,
                    null_stream.ptr,
                )
                scatter_grads(self._master, gg)
                del gg

            # update parameters
            self.forward_count += 1
            # NOTE(review): while still accumulating, this early return skips
            # the parameter broadcast at the end of the method. Confirm that
            # the workers do not wait for a broadcast on every "update"
            # message when accum_grad > 1.
            if self.forward_count != self.accum_grad:
                return
            self.forward_count = 0
            # check gradient value
            grad_norm = np.sqrt(
                sum_sqnorm([p.grad for p in optimizer.target.params(False)])
            )
            logging.info("grad norm={}".format(grad_norm))

            # update
            if math.isnan(grad_norm):
                logging.warning("grad norm is nan. Do not update model.")
            else:
                optimizer.update()
            # start the next accumulation window from empty gradients
            self._master.cleargrads()

            if self.comm is not None:
                # Broadcast the master's parameters from root 0.
                gp = gather_params(self._master)
                self.comm.bcast(
                    gp.data.ptr, gp.size, self.nccl.NCCL_FLOAT, 0, null_stream.ptr
                )

    def update(self):
        """Update step for Custom Parallel Updater.

        The iteration counter advances only when ``update_core`` has
        completed an accumulation window (``forward_count`` is back at 0).
        """
        self.update_core()
        if self.forward_count == 0:
            self.iteration += 1


class VaswaniRule(extension.Extension):
    """Trainer extension that schedules an optimizer attribute by Vaswani's rule.

    At the ``t``-th call the attribute is set to::

        scale * d ** -0.5 * min(t ** -0.5, t * warmup_steps ** -1.5)

    Args:
        attr (str): Name of the optimizer attribute to overwrite.
        d (int): Value whose inverse square root scales the schedule
            (presumably the model dimension -- confirm with the caller).
        warmup_steps (int): Number of warm-up steps in the formula above.
        init (float): Value applied in ``initialize``. If it is ``None``, the
            value the formula yields for ``t = 1`` during warm-up is used.
        target (float): Stored on the instance; not used by this class.
        optimizer (~chainer.Optimizer): Target optimizer to adjust the
            attribute. If it is ``None``, the main optimizer of the updater is
            used.
        scale (float): Constant factor multiplied into the schedule.

    """

    def __init__(
        self,
        attr,
        d,
        warmup_steps=4000,
        init=None,
        target=None,
        optimizer=None,
        scale=1.0,
    ):
        """Precompute the constant factors of the schedule."""
        self._attr = attr
        self._optimizer = optimizer
        self._init = init
        self._target = target
        # constant factors, computed once instead of at every call
        self._d_inv05 = d ** (-0.5) * scale
        self._warmup_steps_inv15 = warmup_steps ** (-1.5)
        # call counter and most recently applied value (both are serialized)
        self._t = 0
        self._last_value = None

    def initialize(self, trainer):
        """Apply the starting value, or the restored one when resuming."""
        optimizer = self._get_optimizer(trainer)
        # ensure that _init is set
        if self._init is None:
            self._init = self._d_inv05 * (1.0 * self._warmup_steps_inv15)
        # a restored _last_value means we are resuming from a snapshot
        resumed = self._last_value is not None
        start_value = self._last_value if resumed else self._init
        self._update_value(optimizer, start_value)

    def __call__(self, trainer):
        """Advance the step counter and apply the scheduled value."""
        self._t += 1
        optimizer = self._get_optimizer(trainer)
        step = self._t
        decay_term = step ** (-0.5)
        warmup_term = step * self._warmup_steps_inv15
        self._update_value(optimizer, self._d_inv05 * min(decay_term, warmup_term))

    def serialize(self, serializer):
        """Save or restore the step counter and the last applied value."""
        self._t = serializer("_t", self._t)
        self._last_value = serializer("_last_value", self._last_value)

    def _get_optimizer(self, trainer):
        """Return the configured optimizer, else the updater's main one."""
        if self._optimizer:
            return self._optimizer
        return trainer.updater.get_optimizer("main")

    def _update_value(self, optimizer, value):
        """Write ``value`` to the optimizer attribute and remember it."""
        setattr(optimizer, self._attr, value)
        self._last_value = value


class CustomConverter(object):
    """Custom Converter.

    Pads the input sequences of a mini-batch into a single array and reports
    the length of each input sequence. The converter takes no configuration.

    """

    def __init__(self):
        """Initialize the converter (there is nothing to configure)."""

    def __call__(self, batch, device):
        """Pad the inputs of a mini-batch and collect their lengths.

        Args:
            batch (list): One-element list whose only item is an ``(xs, ys)``
                pair; ``xs`` is a sequence of input arrays whose first axis
                is the sequence length.
            device: Not used here; the data is processed on the CPU.

        Returns:
            xp.array: Inputs padded with ``-1`` to a common length.
            numpy.ndarray: int32 array with the unpadded length of each input.
            object: ``ys`` exactly as found in the batch.

        Raises:
            AssertionError: If ``batch`` does not hold exactly one element.

        """
        # For transformer, data is processed in CPU.
        # batch should be located in list
        assert len(batch) == 1
        xs, ys = batch[0]
        # Take the lengths BEFORE padding: once padded, every row has the same
        # (maximum) length and the true sequence lengths are lost.
        ilens = np.array([x.shape[0] for x in xs], dtype=np.int32)
        xs = F.pad_sequence(xs, padding=-1).data
        return xs, ilens, ys
