#!/usr/bin/env python3
# -*- coding: utf-8 -*-

# Copyright 2019 Nagoya University (Tomoki Hayashi)
#  Apache 2.0  (http://www.apache.org/licenses/LICENSE-2.0)

"""CBHG related modules."""

import torch
import torch.nn.functional as F
from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence

from espnet.nets.pytorch_backend.nets_utils import make_non_pad_mask


class CBHGLoss(torch.nn.Module):
    """Loss function module for CBHG."""

    def __init__(self, use_masking=True):
        """Initialize CBHG loss module.

        Args:
            use_masking (bool): Whether to mask padded part in loss calculation.

        """
        super(CBHGLoss, self).__init__()
        self.use_masking = use_masking

    def forward(self, cbhg_outs, spcs, olens):
        """Calculate forward propagation.

        Args:
            cbhg_outs (Tensor): Batch of CBHG outputs (B, Lmax, spc_dim).
            spcs (Tensor): Batch of groundtruth of spectrogram (B, Lmax, spc_dim).
            olens (LongTensor): Batch of the lengths of each sequence (B,).

        Returns:
            Tensor: L1 loss value
            Tensor: Mean square error loss value.

        """
        # perform masking for padded values
        if self.use_masking:
            mask = make_non_pad_mask(olens).unsqueeze(-1).to(spcs.device)
            spcs = spcs.masked_select(mask)
            cbhg_outs = cbhg_outs.masked_select(mask)

        # calculate loss
        cbhg_l1_loss = F.l1_loss(cbhg_outs, spcs)
        cbhg_mse_loss = F.mse_loss(cbhg_outs, spcs)

        return cbhg_l1_loss, cbhg_mse_loss


class CBHG(torch.nn.Module):
    """CBHG module to convert log Mel-filterbanks to linear spectrogram.

    This is a module of CBHG introduced
    in `Tacotron: Towards End-to-End Speech Synthesis`_.
    The CBHG converts the sequence of log Mel-filterbanks into linear spectrogram.

    .. _`Tacotron: Towards End-to-End Speech Synthesis`:
         https://arxiv.org/abs/1703.10135

    """

    def __init__(
        self,
        idim,
        odim,
        conv_bank_layers=8,
        conv_bank_chans=128,
        conv_proj_filts=3,
        conv_proj_chans=256,
        highway_layers=4,
        highway_units=128,
        gru_units=256,
    ):
        """Initialize CBHG module.

        Args:
            idim (int): Dimension of the inputs.
            odim (int): Dimension of the outputs.
            conv_bank_layers (int, optional): The number of convolution bank layers.
            conv_bank_chans (int, optional): The number of channels in convolution bank.
            conv_proj_filts (int, optional):
                Kernel size of convolutional projection layer.
            conv_proj_chans (int, optional):
                The number of channels in convolutional projection layer.
            highway_layers (int, optional): The number of highway network layers.
            highway_units (int, optional): The number of highway network units.
            gru_units (int, optional): The number of GRU units (for both directions).

        """
        super(CBHG, self).__init__()
        self.idim = idim
        self.odim = odim
        self.conv_bank_layers = conv_bank_layers
        self.conv_bank_chans = conv_bank_chans
        self.conv_proj_filts = conv_proj_filts
        self.conv_proj_chans = conv_proj_chans
        self.highway_layers = highway_layers
        self.highway_units = highway_units
        self.gru_units = gru_units

        # define 1d convolution bank
        self.conv_bank = torch.nn.ModuleList()
        for k in range(1, self.conv_bank_layers + 1):
            if k % 2 != 0:
                padding = (k - 1) // 2
            else:
                padding = ((k - 1) // 2, (k - 1) // 2 + 1)
            self.conv_bank += [
                torch.nn.Sequential(
                    torch.nn.ConstantPad1d(padding, 0.0),
                    torch.nn.Conv1d(
                        idim, self.conv_bank_chans, k, stride=1, padding=0, bias=True
                    ),
                    torch.nn.BatchNorm1d(self.conv_bank_chans),
                    torch.nn.ReLU(),
                )
            ]

        # define max pooling (need padding for one-side to keep same length)
        self.max_pool = torch.nn.Sequential(
            torch.nn.ConstantPad1d((0, 1), 0.0), torch.nn.MaxPool1d(2, stride=1)
        )

        # define 1d convolution projection
        self.projections = torch.nn.Sequential(
            torch.nn.Conv1d(
                self.conv_bank_chans * self.conv_bank_layers,
                self.conv_proj_chans,
                self.conv_proj_filts,
                stride=1,
                padding=(self.conv_proj_filts - 1) // 2,
                bias=True,
            ),
            torch.nn.BatchNorm1d(self.conv_proj_chans),
            torch.nn.ReLU(),
            torch.nn.Conv1d(
                self.conv_proj_chans,
                self.idim,
                self.conv_proj_filts,
                stride=1,
                padding=(self.conv_proj_filts - 1) // 2,
                bias=True,
            ),
            torch.nn.BatchNorm1d(self.idim),
        )

        # define highway network
        self.highways = torch.nn.ModuleList()
        self.highways += [torch.nn.Linear(idim, self.highway_units)]
        for _ in range(self.highway_layers):
            self.highways += [HighwayNet(self.highway_units)]

        # define bidirectional GRU
        self.gru = torch.nn.GRU(
            self.highway_units,
            gru_units // 2,
            num_layers=1,
            batch_first=True,
            bidirectional=True,
        )

        # define final projection
        self.output = torch.nn.Linear(gru_units, odim, bias=True)

    def forward(self, xs, ilens):
        """Calculate forward propagation.

        Args:
            xs (Tensor): Batch of the padded sequences of inputs (B, Tmax, idim).
            ilens (LongTensor): Batch of lengths of each input sequence (B,).

        Return:
            Tensor: Batch of the padded sequence of outputs (B, Tmax, odim).
            LongTensor: Batch of lengths of each output sequence (B,).

        """
        xs = xs.transpose(1, 2)  # (B, idim, Tmax)
        convs = []
        for k in range(self.conv_bank_layers):
            convs += [self.conv_bank[k](xs)]
        convs = torch.cat(convs, dim=1)  # (B, #CH * #BANK, Tmax)
        convs = self.max_pool(convs)
        convs = self.projections(convs).transpose(1, 2)  # (B, Tmax, idim)
        xs = xs.transpose(1, 2) + convs
        # + 1 for dimension adjustment layer
        for i in range(self.highway_layers + 1):
            xs = self.highways[i](xs)

        # sort by length
        xs, ilens, sort_idx = self._sort_by_length(xs, ilens)

        # total_length needs for DataParallel
        # (see https://github.com/pytorch/pytorch/pull/6327)
        total_length = xs.size(1)
        if not isinstance(ilens, torch.Tensor):
            ilens = torch.tensor(ilens)
        xs = pack_padded_sequence(xs, ilens.cpu(), batch_first=True)
        self.gru.flatten_parameters()
        xs, _ = self.gru(xs)
        xs, ilens = pad_packed_sequence(xs, batch_first=True, total_length=total_length)

        # revert sorting by length
        xs, ilens = self._revert_sort_by_length(xs, ilens, sort_idx)

        xs = self.output(xs)  # (B, Tmax, odim)

        return xs, ilens

    def inference(self, x):
        """Inference.

        Args:
            x (Tensor): The sequences of inputs (T, idim).

        Return:
            Tensor: The sequence of outputs (T, odim).

        """
        assert len(x.size()) == 2
        xs = x.unsqueeze(0)
        ilens = x.new([x.size(0)]).long()

        return self.forward(xs, ilens)[0][0]

    def _sort_by_length(self, xs, ilens):
        sort_ilens, sort_idx = ilens.sort(0, descending=True)
        return xs[sort_idx], ilens[sort_idx], sort_idx

    def _revert_sort_by_length(self, xs, ilens, sort_idx):
        _, revert_idx = sort_idx.sort(0)
        return xs[revert_idx], ilens[revert_idx]


class HighwayNet(torch.nn.Module):
    """Highway Network module.

    This is a module of Highway Network introduced in `Highway Networks`_.

    .. _`Highway Networks`: https://arxiv.org/abs/1505.00387

    """

    def __init__(self, idim):
        """Initialize Highway Network module.

        Args:
            idim (int): Dimension of the inputs.

        """
        super(HighwayNet, self).__init__()
        self.idim = idim
        self.projection = torch.nn.Sequential(
            torch.nn.Linear(idim, idim), torch.nn.ReLU()
        )
        self.gate = torch.nn.Sequential(torch.nn.Linear(idim, idim), torch.nn.Sigmoid())

    def forward(self, x):
        """Calculate forward propagation.

        Args:
            x (Tensor): Batch of inputs (B, ..., idim).

        Returns:
            Tensor: Batch of outputs, which are the same shape as inputs (B, ..., idim).

        """
        proj = self.projection(x)
        gate = self.gate(x)
        return proj * gate + x * (1.0 - gate)
