# Copyright 2017 Johns Hopkins University (Shinji Watanabe)
#  Apache 2.0  (http://www.apache.org/licenses/LICENSE-2.0)


import collections
import logging
import math

import numpy as np
import six

# chainer related
from chainer import Variable, cuda, training
from chainer.training.updaters.multiprocess_parallel_updater import (
    gather_grads,
    gather_params,
    scatter_grads,
)


# copied from https://github.com/chainer/chainer/blob/master/chainer/optimizer.py
def sum_sqnorm(arr):
    """Return the sum of the squared L2 norms of the given arrays.

    Each array is flattened and dotted with itself; the partial results are
    accumulated per device and added together at the end. ``None`` entries
    are skipped.

    Args:
        arr (iterable): Arrays (or ``None`` placeholders) to accumulate over.

    Returns:
        float: Sum over all arrays of the squared norm of each array.

    """
    per_device = collections.defaultdict(float)
    for array in arr:
        with cuda.get_device_from_array(array) as dev:
            if array is None:
                # nothing to accumulate for a missing array
                continue
            flat = array.ravel()
            per_device[int(dev)] += flat.dot(flat)
    # convert each per-device partial sum to a host float before adding
    return sum([float(partial) for partial in six.itervalues(per_device)])


class CustomUpdater(training.StandardUpdater):
    """Chainer updater with gradient accumulation and a NaN-gradient guard.

    Args:
        train_iter (iterator | dict[str, iterator]): Dataset iterator for the
            training dataset. It can also be a dictionary that maps strings to
            iterators. If this is just an iterator, then the iterator is
            registered by the name ``'main'``.
        optimizer (optimizer | dict[str, optimizer]): Optimizer to update
            parameters. It can also be a dictionary that maps strings to
            optimizers. If this is just an optimizer, then the optimizer is
            registered by the name ``'main'``.
        converter (espnet.asr.chainer_backend.asr.CustomConverter): Callable
            invoked as ``converter(batch, device)``. Its result is unpacked
            as the positional arguments of the optimizer's target.
        device (int or dict): The destination device info to send variables. In the
            case of cpu or single gpu, `device=-1 or 0`, respectively.
            In the case of multi-gpu, `device={"main":0, "sub_1": 1, ...}`.
        accum_grad (int): Number of forward/backward passes whose gradients
            are accumulated before one parameter update. If set to 2, the
            parameters are updated every second call, i.e. the effective
            batch size is doubled.

    """

    def __init__(self, train_iter, optimizer, converter, device, accum_grad=1):
        super(CustomUpdater, self).__init__(
            train_iter, optimizer, converter=converter, device=device
        )
        # To solve #1091, it is required to set the variable inside this class.
        self.device = device
        self.accum_grad = accum_grad
        # number of passes accumulated since the last parameter update
        self.forward_count = 0
        # True until the first call of update_core has cleared the gradients
        self.start = True

    # The core part of the update routine can be customized by overriding.
    def update_core(self):
        """Run one forward/backward pass and update once enough are accumulated."""
        iterator = self.get_iterator("main")
        opt = self.get_optimizer("main")

        # Fetch the next minibatch and turn it into model inputs
        minibatch = iterator.next()
        inputs = self.converter(minibatch, self.device)
        model = opt.target
        if self.start:
            # first call only: start from cleared gradients
            model.cleargrads()
            self.start = False

        # Scale the loss so that the accumulated gradient is an average
        loss = model(*inputs) / self.accum_grad
        loss.backward()  # Backprop
        loss.unchain_backward()  # Truncate the graph

        self.forward_count += 1
        if self.forward_count == self.accum_grad:
            self.forward_count = 0
            # inspect the gradient norm before trusting the gradients
            grads = [param.grad for param in model.params(False)]
            grad_norm = np.sqrt(sum_sqnorm(grads))
            logging.info("grad norm={}".format(grad_norm))
            if not math.isnan(grad_norm):
                opt.update()
            else:
                logging.warning("grad norm is nan. Do not update model.")
            model.cleargrads()  # Clear the parameter gradients

    def update(self):
        """Call ``update_core`` and count an iteration once a cycle completes."""
        self.update_core()
        if self.forward_count != 0:
            # still accumulating gradients; not a full iteration yet
            return
        self.iteration += 1


class CustomParallelUpdater(training.updaters.MultiprocessParallelUpdater):
    """Custom Parallel Updater for chainer.

    Defines the main update routine of the master process, adding gradient
    accumulation and a NaN-gradient guard on top of
    ``MultiprocessParallelUpdater``.

    Args:
        train_iters: Dataset iterators, forwarded unchanged to
            ``MultiprocessParallelUpdater`` (presumably one per device;
            confirm against the chainer documentation).
        optimizer (optimizer | dict[str, optimizer]): Optimizer to update
            parameters. It can also be a dictionary that maps strings to
            optimizers. If this is just an optimizer, then the optimizer is
            registered by the name ``'main'``.
        converter (espnet.asr.chainer_backend.asr.CustomConverter): Callable
            invoked as ``converter(batch, device)``. Its result is unpacked
            as the positional arguments of the master model.
        devices: Devices forwarded unchanged to
            ``MultiprocessParallelUpdater``. The master model runs on the
            first entry.
        accum_grad (int): Number of forward/backward passes whose gradients
            are accumulated before one parameter update. If set to 2, the
            parameters are updated every second call, i.e. the effective
            batch size is doubled.

    """

    def __init__(self, train_iters, optimizer, converter, devices, accum_grad=1):
        super(CustomParallelUpdater, self).__init__(
            train_iters, optimizer, converter=converter, devices=devices
        )
        # imported lazily so the module can be loaded without cupy/NCCL
        from cupy.cuda import nccl

        self.accum_grad = accum_grad
        # number of passes accumulated since the last parameter update
        self.forward_count = 0
        self.nccl = nccl

    # The core part of the update routine can be customized by overriding.
    def update_core(self):
        """Main Update routine of the custom parallel updater.

        Runs one forward/backward pass on the master device, reduces the
        gradients over NCCL, applies the optimizer once ``accum_grad`` passes
        have been accumulated (unless the gradient norm is NaN), and
        broadcasts the master parameters.
        """
        self.setup_workers()

        self._send_message(("update", None))
        with cuda.Device(self._devices[0]):
            optimizer = self.get_optimizer("main")
            batch = self.get_iterator("main").next()
            x = self.converter(batch, self._devices[0])

            # Scale the loss so that the accumulated gradient is an average
            loss = self._master(*x) / self.accum_grad
            loss.backward()
            loss.unchain_backward()

            # NCCL: reduce grads
            null_stream = cuda.Stream.null
            if self.comm is not None:
                gg = gather_grads(self._master)
                self.comm.reduce(
                    gg.data.ptr,
                    gg.data.ptr,
                    gg.size,
                    self.nccl.NCCL_FLOAT,
                    self.nccl.NCCL_SUM,
                    0,
                    null_stream.ptr,
                )
                scatter_grads(self._master, gg)
                del gg

            # update parameters only once enough passes have been accumulated
            self.forward_count += 1
            if self.forward_count == self.accum_grad:
                self.forward_count = 0
                # check gradient value
                grad_norm = np.sqrt(
                    sum_sqnorm([p.grad for p in optimizer.target.params(False)])
                )
                logging.info("grad norm={}".format(grad_norm))

                # update
                if math.isnan(grad_norm):
                    logging.warning("grad norm is nan. Do not update model.")
                else:
                    optimizer.update()
                self._master.cleargrads()

            # Always take part in the parameter broadcast, even on
            # accumulation-only steps. In chainer's MultiprocessParallelUpdater
            # each worker issues a reduce followed by a bcast for every
            # "update" message, so skipping the bcast here would leave the
            # collective calls of master and workers unmatched.
            if self.comm is not None:
                gp = gather_params(self._master)
                self.comm.bcast(
                    gp.data.ptr, gp.size, self.nccl.NCCL_FLOAT, 0, null_stream.ptr
                )

    def update(self):
        """Call ``update_core`` and count an iteration once a cycle completes."""
        self.update_core()
        if self.forward_count == 0:
            self.iteration += 1


class CustomConverter(object):
    """Converter that turns a raw minibatch into model inputs.

    Args:
        subsampling_factor (int): Keep every ``subsampling_factor``-th row
            of each input sequence. Values of 1 or less leave the inputs
            untouched.

    """

    def __init__(self, subsampling_factor=1):
        self.subsampling_factor = subsampling_factor

    def __call__(self, batch, device):
        """Subsample the inputs and wrap them as chainer variables.

        Args:
            batch (list): Single-element list whose only item is a pair
                ``(xs, ys)`` of input sequences and target sequences.
            device (int): Device id; ``-1`` selects numpy, any other value
                selects cupy.

        Returns:
            list[chainer.Variable]: float32 variables, one per (subsampled)
                input sequence.
            xp.array: int32 array holding the length of each (subsampled)
                input sequence.
            list[chainer.Variable]: int32 variables, one per target sequence.

        """
        # choose the array module matching the target device
        if device != -1:
            xp = cuda.cupy
        else:
            xp = np

        # the minibatch is expected to be wrapped in a one-element list
        assert len(batch) == 1
        inputs, targets = batch[0]

        # keep every `factor`-th row of each input sequence
        factor = self.subsampling_factor
        if factor > 1:
            inputs = [seq[::factor, :] for seq in inputs]

        # sequence lengths are measured after subsampling
        lengths = xp.array([seq.shape[0] for seq in inputs], dtype=xp.int32)

        # wrap the inputs and targets as Variables
        x_vars = [Variable(xp.array(seq, dtype=xp.float32)) for seq in inputs]
        y_vars = [Variable(xp.array(tgt, dtype=xp.int32)) for tgt in targets]

        return x_vars, lengths, y_vars
