# -*- coding: utf-8 -*-

# Copyright 2019 Tomoki Hayashi (Nagoya University)
#  Apache 2.0  (http://www.apache.org/licenses/LICENSE-2.0)

"""This code is based on https://github.com/kan-bayashi/PytorchWaveNetVocoder."""

import logging
import sys
import time

import numpy as np
import torch
import torch.nn.functional as F
from torch import nn


def encode_mu_law(x, mu=256):
    """Quantize a waveform with mu-law companding.

    Args:
        x (ndarray): Audio signal with the range from -1 to 1.
        mu (int): Number of quantization levels.

    Returns:
        ndarray: Quantized audio signal (int64) with the range from 0 to mu - 1.

    """
    levels = mu - 1
    # Compress the amplitude logarithmically; the result stays in [-1, 1].
    companded = np.sign(x) * np.log(1 + levels * np.abs(x)) / np.log(1 + levels)
    # Map [-1, 1] onto [0, levels] and round to the nearest integer level.
    scaled = (companded + 1) / 2 * levels + 0.5
    return np.floor(scaled).astype(np.int64)


def decode_mu_law(y, mu=256):
    """Expand a mu-law quantized signal back to a waveform.

    Args:
        y (ndarray): Quantized audio signal with the range from 0 to mu - 1.
        mu (int): Number of quantization levels.

    Returns:
        ndarray: Audio signal with the range of roughly -1 to 1.

    """
    levels = mu - 1
    # Map the integer levels back onto the companded amplitude axis.
    companded = (y - 0.5) / levels * 2 - 1
    # Invert the logarithmic compression.
    expanded = (1 + levels) ** np.abs(companded) - 1
    return np.sign(companded) / levels * expanded


def initialize(m):
    """Initialize convolution layers.

    ``nn.Conv1d`` weights are drawn from the Xavier uniform distribution and
    ``nn.ConvTranspose2d`` weights are set to one. The bias of either layer
    type is zeroed when it exists. Any other module is left untouched, so the
    function can be passed directly to ``Module.apply``.

    Args:
        m (torch.nn.Module): Torch module.

    """
    if isinstance(m, nn.Conv1d):
        nn.init.xavier_uniform_(m.weight)
    elif isinstance(m, nn.ConvTranspose2d):
        nn.init.constant_(m.weight, 1.0)
    else:
        return

    # Layers built with ``bias=False`` expose ``bias`` as None.
    if m.bias is not None:
        nn.init.constant_(m.bias, 0.0)


class OneHot(nn.Module):
    """Convert integer indices to one-hot vectors.

    Args:
        depth (int): Dimension of one-hot vector.

    """

    def __init__(self, depth):
        super(OneHot, self).__init__()
        self.depth = depth

    def forward(self, x):
        """Calculate forward propagation.

        Args:
            x (LongTensor): Long tensor with the shape (B, T). Values are
                wrapped into the range [0, depth) by a modulo operation.

        Returns:
            Tensor: Float tensor with the shape (B, T, depth).

        """
        # Wrap indices into the valid range and add the axis to scatter along.
        indices = (x % self.depth).unsqueeze(2)
        shape = (indices.size(0), indices.size(1), self.depth)
        onehot = indices.new_zeros(shape, dtype=torch.float)
        return onehot.scatter_(2, indices, 1)


class CausalConv1d(nn.Module):
    """1D dilated causal convolution.

    The wrapped ``nn.Conv1d`` pads both ends of the input by
    ``(kernel_size - 1) * dilation`` and the trailing padded outputs are cut
    off, so each output step depends only on current and past input steps.

    Args:
        in_channels (int): Number of input channels.
        out_channels (int): Number of output channels.
        kernel_size (int): Filter size.
        dilation (int): Dilation factor.
        bias (bool): Whether the convolution has a bias term.

    """

    def __init__(self, in_channels, out_channels, kernel_size, dilation=1, bias=True):
        super(CausalConv1d, self).__init__()
        n_padding = (kernel_size - 1) * dilation
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.kernel_size = kernel_size
        self.dilation = dilation
        self.padding = n_padding
        self.conv = nn.Conv1d(
            in_channels,
            out_channels,
            kernel_size,
            padding=n_padding,
            dilation=dilation,
            bias=bias,
        )

    def forward(self, x):
        """Calculate forward propagation.

        Args:
            x (Tensor): Input tensor with the shape (B, in_channels, T).

        Returns:
            Tensor: Tensor with the shape (B, out_channels, T)

        """
        y = self.conv(x)
        if self.padding == 0:
            # Nothing was padded, so there is nothing to trim.
            return y
        # Drop the outputs produced from the right-hand padding.
        return y[:, :, : -self.padding]


class UpSampling(nn.Module):
    """Upsampling layer with deconvolution.

    A single-channel 2D transposed convolution stretches the time axis while
    treating the feature axis as the (untouched) height of an image.

    Args:
        upsampling_factor (int): Upsampling factor.
        bias (bool): Whether the deconvolution has a bias term.

    """

    def __init__(self, upsampling_factor, bias=True):
        super(UpSampling, self).__init__()
        self.upsampling_factor = upsampling_factor
        self.bias = bias
        # Kernel and stride only span the time axis.
        time_window = (1, upsampling_factor)
        self.conv = nn.ConvTranspose2d(
            1,
            1,
            kernel_size=time_window,
            stride=time_window,
            bias=bias,
        )

    def forward(self, x):
        """Calculate forward propagation.

        Args:
            x (Tensor): Input tensor with the shape  (B, C, T)

        Returns:
            Tensor: Tensor with the shape (B, C, T') where T' = T * upsampling_factor.

        """
        # B x C x T -> B x 1 x C x T -> B x 1 x C x T' -> B x C x T'
        return self.conv(x.unsqueeze(1)).squeeze(1)


class WaveNet(nn.Module):
    """Conditional wavenet.

    The quantized waveform is one-hot encoded, passed through a stack of
    gated dilated causal convolutions conditioned on auxiliary features, and
    the summed skip connections are projected to logits over the
    quantization levels.

    Args:
        n_quantize (int): Number of quantization.
        n_aux (int): Number of aux feature dimension.
        n_resch (int): Number of filter channels for residual block.
        n_skipch (int): Number of filter channels for skip connection.
        dilation_depth (int): Number of dilation depth
            (e.g. if set 10, max dilation = 2^(10-1)).
        dilation_repeat (int): Number of dilation repeat.
        kernel_size (int): Filter size of dilated causal convolution.
        upsampling_factor (int): Upsampling factor of the auxiliary features
            (no upsampling layer is created if 0).

    """

    def __init__(
        self,
        n_quantize=256,
        n_aux=28,
        n_resch=512,
        n_skipch=256,
        dilation_depth=10,
        dilation_repeat=3,
        kernel_size=2,
        upsampling_factor=0,
    ):
        super(WaveNet, self).__init__()
        self.n_aux = n_aux
        self.n_quantize = n_quantize
        self.n_resch = n_resch
        self.n_skipch = n_skipch
        self.kernel_size = kernel_size
        self.dilation_depth = dilation_depth
        self.dilation_repeat = dilation_repeat
        self.upsampling_factor = upsampling_factor

        # 1, 2, 4, ..., 2^(depth - 1), repeated `dilation_repeat` times
        self.dilations = [
            2**i for i in range(self.dilation_depth)
        ] * self.dilation_repeat
        self.receptive_field = (self.kernel_size - 1) * sum(self.dilations) + 1

        # for preprocessing
        self.onehot = OneHot(self.n_quantize)
        self.causal = CausalConv1d(self.n_quantize, self.n_resch, self.kernel_size)
        if self.upsampling_factor > 0:
            self.upsampling = UpSampling(self.upsampling_factor)

        # for residual blocks
        self.dil_sigmoid = nn.ModuleList()
        self.dil_tanh = nn.ModuleList()
        self.aux_1x1_sigmoid = nn.ModuleList()
        self.aux_1x1_tanh = nn.ModuleList()
        self.skip_1x1 = nn.ModuleList()
        self.res_1x1 = nn.ModuleList()
        for d in self.dilations:
            self.dil_sigmoid.append(
                CausalConv1d(self.n_resch, self.n_resch, self.kernel_size, d)
            )
            self.dil_tanh.append(
                CausalConv1d(self.n_resch, self.n_resch, self.kernel_size, d)
            )
            self.aux_1x1_sigmoid.append(nn.Conv1d(self.n_aux, self.n_resch, 1))
            self.aux_1x1_tanh.append(nn.Conv1d(self.n_aux, self.n_resch, 1))
            self.skip_1x1.append(nn.Conv1d(self.n_resch, self.n_skipch, 1))
            self.res_1x1.append(nn.Conv1d(self.n_resch, self.n_resch, 1))

        # for postprocessing
        self.conv_post_1 = nn.Conv1d(self.n_skipch, self.n_skipch, 1)
        self.conv_post_2 = nn.Conv1d(self.n_skipch, self.n_quantize, 1)

    def forward(self, x, h):
        """Calculate forward propagation.

        Args:
            x (LongTensor): Quantized input waveform tensor with the shape  (B, T).
            h (Tensor): Auxiliary feature tensor with the shape  (B, n_aux, T).

        Returns:
            Tensor: Logits with the shape (B, T, n_quantize).

        """
        # preprocess
        output = self._preprocess(x)
        if self.upsampling_factor > 0:
            h = self.upsampling(h)

        # residual blocks
        skip_connections = []
        for layer in self._residual_layers():
            output, skip = self._residual_forward(output, h, *layer)
            skip_connections.append(skip)

        # skip-connection part
        return self._postprocess(sum(skip_connections))

    @torch.no_grad()
    def generate(self, x, h, n_samples, interval=None, mode="sampling"):
        """Generate a waveform with fast generation algorithm.

        This generation based on `Fast WaveNet Generation Algorithm`_.
        The method runs without gradient tracking since its result is an
        integer array.

        Args:
            x (LongTensor): Initial waveform tensor with the shape  (T,).
            h (Tensor): Auxiliary feature tensor with the shape  (n_samples + T, n_aux).
            n_samples (int): Number of samples to be generated.
            interval (int, optional): Log interval.
            mode (str, optional): "sampling" or "argmax".

        Return:
            ndarray: Generated quantized waveform (n_samples).

        Raises:
            ValueError: If ``mode`` is neither "sampling" nor "argmax", or if
                ``n_samples`` is negative.

        .. _`Fast WaveNet Generation Algorithm`: https://arxiv.org/abs/1611.09482

        """
        # validate cheap arguments before any expensive computation
        if mode not in ("sampling", "argmax"):
            raise ValueError("mode should be sampling or argmax")
        if n_samples < 0:
            raise ValueError("n_samples should be non-negative")

        # reshape inputs
        assert len(x.shape) == 1
        assert len(h.shape) == 2 and h.shape[1] == self.n_aux
        x = x.unsqueeze(0)  # 1 x T
        h = h.transpose(0, 1).unsqueeze(0)  # 1 x n_aux x T

        # perform upsampling
        if self.upsampling_factor > 0:
            h = self.upsampling(h)

        # padding for shortage
        if n_samples > h.shape[2]:
            h = F.pad(h, (0, n_samples - h.shape[2]), "replicate")

        # left-pad the seed so that it covers the receptive field
        n_pad = self.receptive_field - x.size(1)
        if n_pad > 0:
            x = F.pad(x, (n_pad, 0), "constant", self.n_quantize // 2)
            h = F.pad(h, (n_pad, 0), "replicate")

        layers = self._residual_layers()
        max_dilation = 2 ** (self.dilation_depth - 1)

        # prepare buffer: run the seed through the network once and keep, for
        # every layer, the past outputs the following layer will need
        output = self._preprocess(x)
        h_ = h[:, :, : x.size(1)]
        output_buffer = []
        buffer_size = []
        for d, layer in zip(self.dilations, layers):
            output, _ = self._residual_forward(output, h_, *layer)
            if d == max_dilation:
                # the following layer restarts the dilation cycle at 1
                size = self.kernel_size - 1
            else:
                # the following layer has dilation 2 * d
                size = d * 2 * (self.kernel_size - 1)
            buffer_size.append(size)
            # the newest step is excluded; it is recomputed in the first
            # generation step below
            output_buffer.append(output[:, :, -size - 1 : -1])

        # generate
        n_init = x.size(1)
        # preallocate instead of growing the tensor by concatenation
        samples = x.new_empty(n_init + n_samples)
        samples[:n_init] = x[0]
        n_context = 2 * self.kernel_size - 1
        start_time = time.time()
        for i in range(n_samples):
            n_done = n_init + i  # number of samples available so far
            recent = samples[max(n_done - n_context, 0) : n_done]
            output = self._preprocess(recent.unsqueeze(0))
            h_ = h[:, :, n_done - 1].contiguous().view(1, self.n_aux, 1)
            skip_connections = []
            for j, layer in enumerate(layers):
                output, skip = self._generate_residual_forward(output, h_, *layer)
                output = torch.cat([output_buffer[j], output], dim=2)
                # keep the newest `buffer_size[j]` steps; an explicit start
                # index stays correct when the size is zero
                output_buffer[j] = output[:, :, output.size(2) - buffer_size[j] :]
                skip_connections.append(skip)

            # get predicted sample
            logits = self._postprocess(sum(skip_connections))[0]  # 1 x n_quantize
            if mode == "sampling":
                posterior = F.softmax(logits[-1], dim=0)
                dist = torch.distributions.Categorical(posterior)
                sample = dist.sample().unsqueeze(0)
            else:
                sample = logits.argmax(-1)
            samples[n_done : n_done + 1] = sample

            # show progress
            if interval is not None and (i + 1) % interval == 0:
                elapsed_time_per_sample = (time.time() - start_time) / interval
                logging.info(
                    "%d/%d estimated time = %.3f sec (%.3f sec / sample)",
                    i + 1,
                    n_samples,
                    (n_samples - i - 1) * elapsed_time_per_sample,
                    elapsed_time_per_sample,
                )
                start_time = time.time()

        # slice from an explicit start so that n_samples == 0 gives an empty array
        return samples[n_init:].cpu().numpy()

    def _residual_layers(self):
        """Return the per-block layers as a list of tuples.

        Each tuple is ordered as the layer arguments of ``_residual_forward``.
        """
        return list(
            zip(
                self.dil_sigmoid,
                self.dil_tanh,
                self.aux_1x1_sigmoid,
                self.aux_1x1_tanh,
                self.skip_1x1,
                self.res_1x1,
            )
        )

    def _preprocess(self, x):
        """One-hot encode (B, T) indices and apply the first causal convolution."""
        x = self.onehot(x).transpose(1, 2)  # B x n_quantize x T
        output = self.causal(x)
        return output

    def _postprocess(self, x):
        """Project summed skip connections (B, C, T) to logits (B, T, n_quantize)."""
        output = F.relu(x)
        output = self.conv_post_1(output)
        output = F.relu(output)  # B x C x T
        output = self.conv_post_2(output).transpose(1, 2)  # B x T x C
        return output

    def _residual_forward(
        self,
        x,
        h,
        dil_sigmoid,
        dil_tanh,
        aux_1x1_sigmoid,
        aux_1x1_tanh,
        skip_1x1,
        res_1x1,
    ):
        """Apply one gated residual block to a whole sequence.

        Returns the residual output (block input added back) and the skip
        connection output.
        """
        output_sigmoid = dil_sigmoid(x)
        output_tanh = dil_tanh(x)
        aux_output_sigmoid = aux_1x1_sigmoid(h)
        aux_output_tanh = aux_1x1_tanh(h)
        # gated activation unit conditioned on the auxiliary features
        output = torch.sigmoid(output_sigmoid + aux_output_sigmoid) * torch.tanh(
            output_tanh + aux_output_tanh
        )
        skip = skip_1x1(output)
        output = res_1x1(output)
        output = output + x
        return output, skip

    def _generate_residual_forward(
        self,
        x,
        h,
        dil_sigmoid,
        dil_tanh,
        aux_1x1_sigmoid,
        aux_1x1_tanh,
        skip_1x1,
        res_1x1,
    ):
        """Apply one gated residual block, keeping only the newest time step.

        Returns the residual output and the skip connection output, both
        with a single time step.
        """
        output_sigmoid = dil_sigmoid(x)[:, :, -1:]
        output_tanh = dil_tanh(x)[:, :, -1:]
        aux_output_sigmoid = aux_1x1_sigmoid(h)
        aux_output_tanh = aux_1x1_tanh(h)
        # gated activation unit conditioned on the auxiliary features
        output = torch.sigmoid(output_sigmoid + aux_output_sigmoid) * torch.tanh(
            output_tanh + aux_output_tanh
        )
        skip = skip_1x1(output)
        output = res_1x1(output)
        output = output + x[:, :, -1:]  # B x C x 1
        return output, skip
