"""Library implementing quaternion-valued recurrent neural networks.

Authors
 * Titouan Parcollet 2020
"""

from typing import Optional

import torch

from speechbrain.nnet.quaternion_networks.q_linear import QLinear
from speechbrain.nnet.quaternion_networks.q_normalization import QBatchNorm
from speechbrain.utils.logger import get_logger

logger = get_logger(__name__)


class QLSTM(torch.nn.Module):
    """This class implements a quaternion-valued LSTM as first introduced
    in : "Quaternion Recurrent Neural Networks", Parcollet T. et al.

    Input format is (batch, time, fea) or (batch, time, fea, channel).
    In the latter shape, the two last dimensions will be merged:
    (batch, time, fea * channel)

    Arguments
    ---------
    hidden_size : int
        Number of output neurons (i.e, the dimensionality of the output).
        Specified value is in terms of quaternion-valued neurons. Thus, the output
        is 4*hidden_size.
    input_shape : tuple
        The expected shape of the input tensor.
    num_layers : int, optional
        Number of layers to employ in the RNN architecture (default 1).
    bias : bool, optional
        Stored as an attribute. It is not forwarded to the layers built by
        this class (default True).
    dropout : float, optional
        It is the dropout factor (must be between 0 and 1) (default 0.0).
    bidirectional : bool, optional
        If True, a bidirectional model that scans the sequence both
        right-to-left and left-to-right is used (default False).
    init_criterion : str , optional
        (glorot, he).
        This parameter controls the initialization criterion of the weights.
        It is combined with weights_init to build the initialization method of
        the quaternion-valued weights (default "glorot").
    weight_init : str, optional
        (quaternion, unitary).
        This parameter defines the initialization procedure of the
        quaternion-valued weights. "quaternion" will generate random quaternion
        weights following the init_criterion and the quaternion polar form.
        "unitary" will normalize the weights to lie on the unit circle (default "quaternion").
        More details in: "Quaternion Recurrent Neural Networks",
        Parcollet T. et al.
    autograd : bool, optional
        When True, the default PyTorch autograd will be used. When False, a
        custom backpropagation will be used, reducing by a factor 3 to 4 the
        memory consumption. It is also 2x slower (default True).


    Example
    -------
    >>> inp_tensor = torch.rand([10, 16, 40])
    >>> rnn = QLSTM(hidden_size=16, input_shape=inp_tensor.shape)
    >>> out_tensor, hidden = rnn(inp_tensor)
    >>> out_tensor.shape
    torch.Size([10, 16, 64])
    """

    def __init__(
        self,
        hidden_size,
        input_shape,
        num_layers=1,
        bias=True,
        dropout=0.0,
        bidirectional=False,
        init_criterion="glorot",
        weight_init="quaternion",
        autograd=True,
    ):
        super().__init__()
        # One quaternion neuron = 4 real values (q = r + xi + yj + zk)
        self.hidden_size = hidden_size * 4
        self.num_layers = num_layers
        self.bias = bias
        self.dropout = dropout
        self.bidirectional = bidirectional
        self.init_criterion = init_criterion
        self.weight_init = weight_init
        self.autograd = autograd

        # Inputs with more than 3 dims get their trailing dims merged
        self.reshape = len(input_shape) > 3

        # Computing the feature dimensionality
        self.fea_dim = torch.prod(torch.tensor(input_shape[2:]))
        self.batch_size = input_shape[0]

        self.rnn = self._init_layers()

    def _init_layers(self):
        """Initializes the layers of the quaternionLSTM.

        Returns
        -------
        rnn : ModuleList
            The initialized QLSTM_Layers
        """
        rnn = torch.nn.ModuleList([])
        current_dim = self.fea_dim

        for _ in range(self.num_layers):
            rnn.append(
                QLSTM_Layer(
                    current_dim,
                    self.hidden_size,
                    self.num_layers,
                    self.batch_size,
                    dropout=self.dropout,
                    bidirectional=self.bidirectional,
                    init_criterion=self.init_criterion,
                    weight_init=self.weight_init,
                    autograd=self.autograd,
                )
            )

            # A bidirectional layer concatenates both directions on the
            # feature axis, doubling the input size of the next layer.
            if self.bidirectional:
                current_dim = self.hidden_size * 2
            else:
                current_dim = self.hidden_size

        return rnn

    def forward(self, x, hx: Optional[torch.Tensor] = None):
        """Returns the output of the quaternion LSTM.

        Arguments
        ---------
        x : torch.Tensor
            Input tensor.
        hx : torch.Tensor
            Initial hidden state; layer ``i`` receives ``hx[i]``.

        Returns
        -------
        output : torch.Tensor
            Output of the last layer.
        hh : torch.Tensor
            Last time step of the output of every layer.
        """

        # Reshaping input tensors for 4d inputs
        if self.reshape and x.ndim == 4:
            x = x.reshape(x.shape[0], x.shape[1], x.shape[2] * x.shape[3])

        output, hh = self._forward_rnn(x, hx=hx)

        return output, hh

    def _forward_rnn(self, x, hx: Optional[torch.Tensor]):
        """Runs the input through every layer of the stack.

        Arguments
        ---------
        x : torch.Tensor
            Input tensor.
        hx : torch.Tensor
            Initial hidden state; layer ``i`` receives ``hx[i]``.

        Returns
        -------
        x : torch.Tensor
            The output of the last layer.
        h : torch.Tensor
            The last time step of every layer output. Its shape is
            (num_layers, batch, features) when unidirectional, and it is
            reshaped to (num_layers * 2, batch, hidden_size) when
            bidirectional.
        """

        h = []
        if hx is not None:
            if self.bidirectional:
                # Both directions are stacked on the batch axis inside a
                # layer. The batch size is read from the current input
                # because it may differ from the one seen at construction.
                hx = hx.reshape(
                    self.num_layers, x.shape[0] * 2, self.hidden_size
                )

        # Processing the different layers
        for i, rnn_lay in enumerate(self.rnn):
            if hx is not None:
                x = rnn_lay(x, hx=hx[i])
            else:
                x = rnn_lay(x, hx=None)
            h.append(x[:, -1, :])
        h = torch.stack(h, dim=1)

        if self.bidirectional:
            # NOTE(review): this is a plain reshape of a (batch, layers,
            # 2 * hidden) tensor, not a permutation -- confirm the intended
            # memory layout before relying on individual entries.
            h = h.reshape(h.shape[1] * 2, h.shape[0], self.hidden_size)
        else:
            h = h.transpose(0, 1)

        return x, h


class QLSTM_Layer(torch.nn.Module):
    """This class implements a quaternion-valued LSTM layer.

    Arguments
    ---------
    input_size : int
        Feature dimensionality of the input tensors (in term of real values).
    hidden_size : int
        Number of output values (in term of real values).
    num_layers : int, optional
        Number of layers to employ in the RNN architecture. Accepted for
        interface compatibility; it is not used by this layer.
    batch_size : int
        Batch size of the input tensors.
    dropout : float, optional
        It is the dropout factor (must be between 0 and 1) (default 0.0).
    bidirectional : bool, optional
        If True, a bidirectional model that scans the sequence both
        right-to-left and left-to-right is used (default False).
    init_criterion : str , optional
        (glorot, he).
        This parameter controls the initialization criterion of the weights.
        It is combined with weights_init to build the initialization method of
        the quaternion-valued weights (default "glorot").
    weight_init : str, optional
        (quaternion, unitary).
        This parameter defines the initialization procedure of the
        quaternion-valued weights. "quaternion" will generate random quaternion
        weights following the init_criterion and the quaternion polar form.
        "unitary" will normalize the weights to lie on the unit circle (default "quaternion").
        More details in: "Quaternion Recurrent Neural Networks",
        Parcollet T. et al.
    autograd : bool, optional
        When True, the default PyTorch autograd will be used. When False, a
        custom backpropagation will be used, reducing by a factor 3 to 4 the
        memory consumption. It is also 2x slower (default True).
    """

    def __init__(
        self,
        input_size,
        hidden_size,
        num_layers,
        batch_size,
        dropout=0.0,
        bidirectional=False,
        init_criterion="glorot",
        weight_init="quaternion",
        autograd=True,
    ):
        super().__init__()

        self.hidden_size = int(hidden_size) // 4  # Express in term of quat
        self.input_size = int(input_size)
        self.batch_size = batch_size
        self.bidirectional = bidirectional
        self.dropout = dropout
        self.init_criterion = init_criterion
        self.weight_init = weight_init
        self.autograd = autograd

        # Input-to-hidden transform, shared by the four gates
        self.w = QLinear(
            input_shape=self.input_size,
            n_neurons=self.hidden_size * 4,  # Forget, Input, Output, Cell
            bias=True,
            weight_init=self.weight_init,
            init_criterion=self.init_criterion,
            autograd=self.autograd,
        )

        # Hidden-to-hidden transform, shared by the four gates
        self.u = QLinear(
            input_shape=self.hidden_size * 4,  # The input size is in real
            n_neurons=self.hidden_size * 4,
            bias=True,
            weight_init=self.weight_init,
            init_criterion=self.init_criterion,
            autograd=self.autograd,
        )

        # Both directions are processed stacked along the batch axis
        if self.bidirectional:
            self.batch_size = self.batch_size * 2

        # Initial state
        self.register_buffer("h_init", torch.zeros(1, self.hidden_size * 4))

        # Dropout module, eval-time mask and preloaded dropout masks
        # (preloading gives some speed improvement)
        self._init_drop(self.batch_size)

    def forward(self, x, hx: Optional[torch.Tensor] = None):
        # type: (torch.Tensor, Optional[torch.Tensor]) -> torch.Tensor # noqa F821
        """Returns the output of the quaternion LSTM layer.

        Arguments
        ---------
        x : torch.Tensor
            Input tensor.
        hx : torch.Tensor
            Initial hidden state. The zero state is used when None.

        Returns
        -------
        h : torch.Tensor
            The hidden states of all time steps. When bidirectional, the
            forward and backward states are concatenated on the last axis.
        """
        if self.bidirectional:
            # The time-reversed copy is appended on the batch axis so both
            # directions are computed in a single pass.
            x_flip = x.flip(1)
            x = torch.cat([x, x_flip], dim=0)

        # Change batch size if needed
        self._change_batch_size(x)

        # Feed-forward affine transformations (all steps in parallel)
        w = self.w(x)

        # Processing time steps
        if hx is not None:
            h = self._quaternionlstm_cell(w, hx)
        else:
            h = self._quaternionlstm_cell(w, self.h_init)

        if self.bidirectional:
            h_f, h_b = h.chunk(2, dim=0)
            # Put the backward states back in chronological order
            h_b = h_b.flip(1)
            h = torch.cat([h_f, h_b], dim=2)

        return h

    def _quaternionlstm_cell(self, w, ht):
        """Returns the hidden states for each time step.

        Arguments
        ---------
        w : torch.Tensor
            Linearly transformed input.
        ht : torch.Tensor
            Initial hidden state.

        Returns
        -------
        h : torch.Tensor
            The hidden states for all steps.
        """

        hiddens = []

        # Initialise the cell state
        ct = self.h_init

        # Sampling dropout mask (one mask shared by all time steps)
        drop_mask = self._sample_drop_mask(w)

        # Loop over time axis
        for k in range(w.shape[1]):
            gates = w[:, k] + self.u(ht)

            # Four equal contiguous slices: input, forget, output gates and
            # the cell candidate. Each slice holds the r, i, j, k parts.
            it_pre, ft_pre, ot_pre, ct_pre = gates.chunk(4, dim=1)
            it = torch.sigmoid(it_pre)
            ft = torch.sigmoid(ft_pre)
            ot = torch.sigmoid(ot_pre)

            # Dropout is applied to the new cell content only
            ct = it * torch.tanh(ct_pre) * drop_mask + ft * ct
            ht = ot * torch.tanh(ct)
            hiddens.append(ht)

        # Stacking hidden states
        h = torch.stack(hiddens, dim=1)
        return h

    def _init_drop(self, batch_size):
        """Initializes the recurrent dropout operation. To speed it up,
        the dropout masks are sampled in advance.

        Arguments
        ---------
        batch_size : int
            Unused; kept for interface compatibility.
        """
        self.drop = torch.nn.Dropout(p=self.dropout, inplace=False)
        # Mask used outside of training: multiplying by one is a no-op
        self.drop_mask_te = torch.tensor([1.0]).float()

        self.N_drop_masks = 16000
        self.drop_mask_cnt = 0

        self.drop_masks = self.drop(
            torch.ones(self.N_drop_masks, self.hidden_size * 4)
        ).data

    def _sample_drop_mask(self, w):
        """Selects one of the pre-defined dropout masks.

        Arguments
        ---------
        w : torch.Tensor
            Tensor whose device the returned mask must live on.

        Returns
        -------
        drop_mask : torch.Tensor
            ``batch_size`` rows of the mask pool when training, a tensor
            holding a single one otherwise.
        """
        if self.training:
            # Sample new masks when needed
            if self.drop_mask_cnt + self.batch_size > self.N_drop_masks:
                self.drop_mask_cnt = 0
                self.drop_masks = self.drop(
                    torch.ones(
                        self.N_drop_masks, self.hidden_size * 4, device=w.device
                    )
                ).data

            # The pool is first created on the default device; move it next
            # to the data (no copy happens when it is already there).
            self.drop_masks = self.drop_masks.to(w.device)

            # Sampling the mask
            first = self.drop_mask_cnt
            last = first + self.batch_size
            drop_mask = self.drop_masks[first:last]
            self.drop_mask_cnt = last

        else:
            self.drop_mask_te = self.drop_mask_te.to(w.device)
            drop_mask = self.drop_mask_te

        return drop_mask

    def _change_batch_size(self, x):
        """This function changes the batch size when it is different from
        the one detected in the initialization method. This might happen in
        the case of multi-gpu or when we have different batch sizes in train
        and test. The drop masks are re-sampled as well when training.

        Arguments
        ---------
        x : torch.Tensor
            Input tensor whose first dimension is the current batch size.
        """

        if self.batch_size != x.shape[0]:
            self.batch_size = x.shape[0]

            if self.training:
                self.drop_masks = self.drop(
                    torch.ones(
                        self.N_drop_masks, self.hidden_size * 4, device=x.device
                    )
                ).data


class QRNN(torch.nn.Module):
    """This class implements a vanilla quaternion-valued RNN.

    Input format is (batch, time, fea) or (batch, time, fea, channel).
    In the latter shape, the two last dimensions will be merged:
    (batch, time, fea * channel)

    Arguments
    ---------
    hidden_size : int
        Number of output neurons (i.e, the dimensionality of the output).
        Specified value is in term of quaternion-valued neurons. Thus, the output
        is 4*hidden_size.
    input_shape : tuple
        Expected shape of the input tensor.
    nonlinearity : str, optional
        Type of nonlinearity (tanh, relu) (default "tanh").
    num_layers : int, optional
        Number of layers to employ in the RNN architecture (default 1).
    bias : bool, optional
        Stored as an attribute. It is not forwarded to the layers built by
        this class (default True).
    dropout : float, optional
        It is the dropout factor (must be between 0 and 1) (default 0.0).
    bidirectional : bool, optional
        If True, a bidirectional model that scans the sequence both
        right-to-left and left-to-right is used (default False).
    init_criterion : str , optional
        (glorot, he).
        This parameter controls the initialization criterion of the weights.
        It is combined with weights_init to build the initialization method of
        the quaternion-valued weights (default "glorot").
    weight_init : str, optional
        (quaternion, unitary).
        This parameter defines the initialization procedure of the
        quaternion-valued weights. "quaternion" will generate random quaternion
        weights following the init_criterion and the quaternion polar form.
        "unitary" will normalize the weights to lie on the unit circle (default "quaternion").
        More details in: "Quaternion Recurrent Neural Networks",
        Parcollet T. et al.
    autograd : bool, optional
        When True, the default PyTorch autograd will be used. When False, a
        custom backpropagation will be used, reducing by a factor 3 to 4 the
        memory consumption. It is also 2x slower (default True).


    Example
    -------
    >>> inp_tensor = torch.rand([10, 16, 40])
    >>> rnn = QRNN(hidden_size=16, input_shape=inp_tensor.shape)
    >>> out_tensor, hidden = rnn(inp_tensor)
    >>> out_tensor.shape
    torch.Size([10, 16, 64])
    """

    def __init__(
        self,
        hidden_size,
        input_shape,
        nonlinearity="tanh",
        num_layers=1,
        bias=True,
        dropout=0.0,
        bidirectional=False,
        init_criterion="glorot",
        weight_init="quaternion",
        autograd=True,
    ):
        super().__init__()
        # One quaternion neuron = 4 real values (q = r + xi + yj + zk)
        self.hidden_size = hidden_size * 4
        self.nonlinearity = nonlinearity
        self.num_layers = num_layers
        self.bias = bias
        self.dropout = dropout
        self.bidirectional = bidirectional
        self.init_criterion = init_criterion
        self.weight_init = weight_init
        self.autograd = autograd

        # Inputs with more than 3 dims get their trailing dims merged
        self.reshape = len(input_shape) > 3

        # Computing the feature dimensionality
        self.fea_dim = torch.prod(torch.tensor(input_shape[2:]))
        self.batch_size = input_shape[0]

        self.rnn = self._init_layers()

    def _init_layers(self):
        """
        Initializes the layers of the quaternionRNN.

        Returns
        -------
        rnn : ModuleList
            The initialized QRNN_Layers.
        """

        rnn = torch.nn.ModuleList([])
        current_dim = self.fea_dim

        for _ in range(self.num_layers):
            rnn.append(
                QRNN_Layer(
                    current_dim,
                    self.hidden_size,
                    self.num_layers,
                    self.batch_size,
                    dropout=self.dropout,
                    nonlinearity=self.nonlinearity,
                    bidirectional=self.bidirectional,
                    init_criterion=self.init_criterion,
                    weight_init=self.weight_init,
                    autograd=self.autograd,
                )
            )

            # A bidirectional layer concatenates both directions on the
            # feature axis, doubling the input size of the next layer.
            if self.bidirectional:
                current_dim = self.hidden_size * 2
            else:
                current_dim = self.hidden_size

        return rnn

    def forward(self, x, hx: Optional[torch.Tensor] = None):
        """Returns the output of the vanilla QuaternionRNN.

        Arguments
        ---------
        x : torch.Tensor
            Input tensor.
        hx : torch.Tensor
            Initial hidden state; layer ``i`` receives ``hx[i]``.

        Returns
        -------
        output : torch.Tensor
            Output of the last layer.
        hh : torch.Tensor
            Last time step of the output of every layer.
        """
        # Reshaping input tensors for 4d inputs
        if self.reshape and x.ndim == 4:
            x = x.reshape(x.shape[0], x.shape[1], x.shape[2] * x.shape[3])

        output, hh = self._forward_rnn(x, hx=hx)

        return output, hh

    def _forward_rnn(self, x, hx: Optional[torch.Tensor]):
        """Runs the input through every layer of the stack.

        Arguments
        ---------
        x : torch.Tensor
            Input tensor.
        hx : torch.Tensor
            Initial hidden state; layer ``i`` receives ``hx[i]``.

        Returns
        -------
        x : torch.Tensor
            Output of the last layer.
        h : torch.Tensor
            The last time step of every layer output. Its shape is
            (num_layers, batch, features) when unidirectional, and it is
            reshaped to (num_layers * 2, batch, hidden_size) when
            bidirectional.
        """
        h = []
        if hx is not None:
            if self.bidirectional:
                # Both directions are stacked on the batch axis inside a
                # layer. The batch size is read from the current input
                # because it may differ from the one seen at construction.
                hx = hx.reshape(
                    self.num_layers, x.shape[0] * 2, self.hidden_size
                )

        # Processing the different layers
        for i, rnn_lay in enumerate(self.rnn):
            if hx is not None:
                x = rnn_lay(x, hx=hx[i])
            else:
                x = rnn_lay(x, hx=None)
            h.append(x[:, -1, :])
        h = torch.stack(h, dim=1)

        if self.bidirectional:
            # NOTE(review): this is a plain reshape of a (batch, layers,
            # 2 * hidden) tensor, not a permutation -- confirm the intended
            # memory layout before relying on individual entries.
            h = h.reshape(h.shape[1] * 2, h.shape[0], self.hidden_size)
        else:
            h = h.transpose(0, 1)

        return x, h


class QRNN_Layer(torch.nn.Module):
    """This class implements a quaternion-valued recurrent layer.

    Arguments
    ---------
    input_size : int
        Feature dimensionality of the input tensors (in term of real values).
    hidden_size : int
        Number of output values (in term of real values).
    num_layers : int, optional
        Number of layers to employ in the RNN architecture. Accepted for
        interface compatibility; it is not used by this layer.
    batch_size : int
        Batch size of the input tensors.
    dropout : float, optional
        It is the dropout factor (must be between 0 and 1) (default 0.0).
    nonlinearity : str, optional
        Type of nonlinearity (tanh, relu) (default "tanh"). Any value other
        than "tanh" selects ReLU.
    bidirectional : bool, optional
        If True, a bidirectional model that scans the sequence both
        right-to-left and left-to-right is used (default False).
    init_criterion : str , optional
        (glorot, he).
        This parameter controls the initialization criterion of the weights.
        It is combined with weights_init to build the initialization method of
        the quaternion-valued weights (default "glorot").
    weight_init : str, optional
        (quaternion, unitary).
        This parameter defines the initialization procedure of the
        quaternion-valued weights. "quaternion" will generate random quaternion
        weights following the init_criterion and the quaternion polar form.
        "unitary" will normalize the weights to lie on the unit circle (default "quaternion").
        More details in: "Quaternion Recurrent Neural Networks",
        Parcollet T. et al.
    autograd : bool, optional
        When True, the default PyTorch autograd will be used. When False, a
        custom backpropagation will be used, reducing by a factor 3 to 4 the
        memory consumption. It is also 2x slower (default True).
    """

    def __init__(
        self,
        input_size,
        hidden_size,
        num_layers,
        batch_size,
        dropout=0.0,
        nonlinearity="tanh",
        bidirectional=False,
        init_criterion="glorot",
        weight_init="quaternion",
        autograd=True,
    ):
        super().__init__()

        self.hidden_size = int(hidden_size) // 4  # Express in term of quat
        self.input_size = int(input_size)
        self.batch_size = batch_size
        self.bidirectional = bidirectional
        self.dropout = dropout
        self.init_criterion = init_criterion
        self.weight_init = weight_init
        self.autograd = autograd

        # Input-to-hidden transform
        self.w = QLinear(
            input_shape=self.input_size,
            n_neurons=self.hidden_size,
            bias=True,
            weight_init=self.weight_init,
            init_criterion=self.init_criterion,
            autograd=self.autograd,
        )

        # Hidden-to-hidden transform
        self.u = QLinear(
            input_shape=self.hidden_size * 4,  # The input size is in real
            n_neurons=self.hidden_size,
            bias=True,
            weight_init=self.weight_init,
            init_criterion=self.init_criterion,
            autograd=self.autograd,
        )

        # Both directions are processed stacked along the batch axis
        if self.bidirectional:
            self.batch_size = self.batch_size * 2

        # Initial state
        self.register_buffer("h_init", torch.zeros(1, self.hidden_size * 4))

        # Dropout module, eval-time mask and preloaded dropout masks
        # (preloading gives some speed improvement)
        self._init_drop(self.batch_size)

        # Setting the activation function
        if nonlinearity == "tanh":
            self.act = torch.nn.Tanh()
        else:
            self.act = torch.nn.ReLU()

    def forward(self, x, hx: Optional[torch.Tensor] = None):
        # type: (torch.Tensor, Optional[torch.Tensor]) -> torch.Tensor # noqa F821
        """Returns the output of the QuaternionRNN_layer.

        Arguments
        ---------
        x : torch.Tensor
            Input tensor.
        hx : torch.Tensor
            Initial hidden state. The zero state is used when None.

        Returns
        -------
        h : torch.Tensor
            The hidden states of all time steps. When bidirectional, the
            forward and backward states are concatenated on the last axis.
        """
        if self.bidirectional:
            # The time-reversed copy is appended on the batch axis so both
            # directions are computed in a single pass.
            x_flip = x.flip(1)
            x = torch.cat([x, x_flip], dim=0)

        # Change batch size if needed
        self._change_batch_size(x)

        # Feed-forward affine transformations (all steps in parallel)
        w = self.w(x)

        # Processing time steps
        if hx is not None:
            h = self._quaternionrnn_cell(w, hx)
        else:
            h = self._quaternionrnn_cell(w, self.h_init)

        if self.bidirectional:
            h_f, h_b = h.chunk(2, dim=0)
            # Put the backward states back in chronological order
            h_b = h_b.flip(1)
            h = torch.cat([h_f, h_b], dim=2)

        return h

    def _quaternionrnn_cell(self, w, ht):
        """Returns the hidden states for each time step.

        Arguments
        ---------
        w : torch.Tensor
            Linearly transformed input.
        ht : torch.Tensor
            Initial hidden state.

        Returns
        -------
        h : torch.Tensor
            Hidden states for each step.
        """
        hiddens = []

        # Sampling dropout mask (one mask shared by all time steps)
        drop_mask = self._sample_drop_mask(w)

        # Loop over time axis
        for k in range(w.shape[1]):
            at = w[:, k] + self.u(ht)
            ht = self.act(at) * drop_mask
            hiddens.append(ht)

        # Stacking hidden states
        h = torch.stack(hiddens, dim=1)
        return h

    def _init_drop(self, batch_size):
        """Initializes the recurrent dropout operation. To speed it up,
        the dropout masks are sampled in advance.

        Arguments
        ---------
        batch_size : int
            Unused; kept for interface compatibility.
        """

        self.drop = torch.nn.Dropout(p=self.dropout, inplace=False)
        # Mask used outside of training: multiplying by one is a no-op
        self.drop_mask_te = torch.tensor([1.0]).float()

        self.N_drop_masks = 16000
        self.drop_mask_cnt = 0

        self.drop_masks = self.drop(
            torch.ones(self.N_drop_masks, self.hidden_size * 4)
        ).data

    def _sample_drop_mask(self, w):
        """Selects one of the pre-defined dropout masks.

        Arguments
        ---------
        w : torch.Tensor
            Tensor whose device the returned mask must live on.

        Returns
        -------
        drop_mask : torch.Tensor
            ``batch_size`` rows of the mask pool when training, a tensor
            holding a single one otherwise.
        """

        if self.training:
            # Sample new masks when needed
            if self.drop_mask_cnt + self.batch_size > self.N_drop_masks:
                self.drop_mask_cnt = 0
                self.drop_masks = self.drop(
                    torch.ones(
                        self.N_drop_masks, self.hidden_size * 4, device=w.device
                    )
                ).data

            # The pool is first created on the default device; move it next
            # to the data (no copy happens when it is already there).
            self.drop_masks = self.drop_masks.to(w.device)

            # Sampling the mask
            first = self.drop_mask_cnt
            last = first + self.batch_size
            drop_mask = self.drop_masks[first:last]
            self.drop_mask_cnt = last

        else:
            self.drop_mask_te = self.drop_mask_te.to(w.device)
            drop_mask = self.drop_mask_te

        return drop_mask

    def _change_batch_size(self, x):
        """This function changes the batch size when it is different from
        the one detected in the initialization method. This might happen in
        the case of multi-gpu or when we have different batch sizes in train
        and test. The drop masks are re-sampled as well when training.

        Arguments
        ---------
        x : torch.Tensor
            Input tensor whose first dimension is the current batch size.
        """

        if self.batch_size != x.shape[0]:
            self.batch_size = x.shape[0]

            if self.training:
                # The mask width must match the hidden state, i.e. four
                # real values per quaternion neuron.
                self.drop_masks = self.drop(
                    torch.ones(
                        self.N_drop_masks, self.hidden_size * 4, device=x.device
                    )
                ).data


class QLiGRU(torch.nn.Module):
    """This class implements a quaternion-valued Light GRU (liGRU).

    Ligru is single-gate GRU model based on batch-norm + relu
    activations + recurrent dropout. For more info see:

    "M. Ravanelli, P. Brakel, M. Omologo, Y. Bengio,
    Light Gated Recurrent Units for Speech Recognition,
    in IEEE Transactions on Emerging Topics in Computational Intelligence,
    2018" (https://arxiv.org/abs/1803.10225)

    It accepts in input tensors formatted as (batch, time, fea).
    In the case of 4d inputs like (batch, time, fea, channel) the tensor is
    flattened as (batch, time, fea*channel).

    Arguments
    ---------
    hidden_size : int
        Number of output neurons (i.e, the dimensionality of the output).
        Specified value is in term of quaternion-valued neurons. Thus, the output
        is 4*hidden_size.
    input_shape : tuple
        Expected shape of the input.
    nonlinearity : str
        Type of nonlinearity, forwarded to every layer
        (default "leaky_relu").
    num_layers : int
        Number of layers to employ in the RNN architecture.
    bias : bool
        Stored as an attribute. It is not forwarded to the layers by this
        class.
    dropout: float
        It is the dropout factor (must be between 0 and 1).
    bidirectional : bool
        If True, a bidirectional model that scans the sequence both
        right-to-left and left-to-right is used.
    init_criterion : str, optional
        (glorot, he).
        This parameter controls the initialization criterion of the weights.
        It is combined with weights_init to build the initialization method of
        the quaternion-valued weights (default "glorot").
    weight_init : str, optional
        (quaternion, unitary).
        This parameter defines the initialization procedure of the
        quaternion-valued weights. "quaternion" will generate random quaternion-valued
        weights following the init_criterion and the quaternion polar form.
        "unitary" will normalize the weights to lie on the unit circle (default "quaternion").
        More details in: "Deep quaternion Networks", Trabelsi C. et al.
    autograd : bool, optional
        When True, the default PyTorch autograd will be used. When False, a
        custom backpropagation will be used, reducing by a factor 3 to 4 the
        memory consumption. It is also 2x slower (default True).

    Example
    -------
    >>> inp_tensor = torch.rand([10, 16, 40])
    >>> rnn = QLiGRU(input_shape=inp_tensor.shape, hidden_size=16)
    >>> out_tensor, hidden = rnn(inp_tensor)
    >>> out_tensor.shape
    torch.Size([10, 16, 64])
    """

    def __init__(
        self,
        hidden_size,
        input_shape,
        nonlinearity="leaky_relu",
        num_layers=1,
        bias=True,
        dropout=0.0,
        bidirectional=False,
        init_criterion="glorot",
        weight_init="quaternion",
        autograd=True,
    ):
        super().__init__()
        self.hidden_size = hidden_size * 4  # q = x + iy + jz + kw
        self.nonlinearity = nonlinearity
        self.num_layers = num_layers
        self.bias = bias
        self.dropout = dropout
        self.bidirectional = bidirectional
        self.init_criterion = init_criterion
        self.weight_init = weight_init
        self.autograd = autograd

        # Inputs with more than 3 dims get their trailing dims merged
        self.reshape = len(input_shape) > 3

        # Computing the feature dimensionality
        self.fea_dim = torch.prod(torch.tensor(input_shape[2:]))
        self.batch_size = input_shape[0]
        self.rnn = self._init_layers()

    def _init_layers(self):
        """
        Initializes the layers of the liGRU.

        Returns
        -------
        rnn : ModuleList
            The initialized QLiGRU_Layers.
        """
        rnn = torch.nn.ModuleList([])
        current_dim = self.fea_dim

        for _ in range(self.num_layers):
            rnn.append(
                QLiGRU_Layer(
                    current_dim,
                    self.hidden_size,
                    self.num_layers,
                    self.batch_size,
                    dropout=self.dropout,
                    nonlinearity=self.nonlinearity,
                    bidirectional=self.bidirectional,
                    init_criterion=self.init_criterion,
                    weight_init=self.weight_init,
                    autograd=self.autograd,
                )
            )

            # The next layer is sized for twice the hidden size when
            # bidirectional.
            if self.bidirectional:
                current_dim = self.hidden_size * 2
            else:
                current_dim = self.hidden_size
        return rnn

    def forward(self, x, hx: Optional[torch.Tensor] = None):
        """Returns the output of the QuaternionliGRU.

        Arguments
        ---------
        x : torch.Tensor
            Input tensor.
        hx : torch.Tensor
            Initial hidden state; layer ``i`` receives ``hx[i]``.

        Returns
        -------
        output : torch.Tensor
            Output of the last layer.
        hh : torch.Tensor
            Last time step of the output of every layer.
        """
        # Reshaping input tensors for 4d inputs
        if self.reshape and x.ndim == 4:
            x = x.reshape(x.shape[0], x.shape[1], x.shape[2] * x.shape[3])

        # run ligru
        output, hh = self._forward_ligru(x, hx=hx)

        return output, hh

    def _forward_ligru(self, x, hx: Optional[torch.Tensor]):
        """Runs the input through every layer of the stack.

        Arguments
        ---------
        x : torch.Tensor
            Input tensor.
        hx : torch.Tensor
            Initial hidden state; layer ``i`` receives ``hx[i]``.

        Returns
        -------
        x : torch.Tensor
            Output of the last layer.
        h : torch.Tensor
            The last time step of every layer output. Its shape is
            (num_layers, batch, features) when unidirectional, and it is
            reshaped to (num_layers * 2, batch, hidden_size) when
            bidirectional.
        """
        h = []
        if hx is not None:
            if self.bidirectional:
                # Regroup the state per layer. The batch size is read from
                # the current input because it may differ from the one seen
                # at construction.
                hx = hx.reshape(
                    self.num_layers, x.shape[0] * 2, self.hidden_size
                )
        # Processing the different layers
        for i, ligru_lay in enumerate(self.rnn):
            if hx is not None:
                x = ligru_lay(x, hx=hx[i])
            else:
                x = ligru_lay(x, hx=None)
            h.append(x[:, -1, :])
        h = torch.stack(h, dim=1)

        if self.bidirectional:
            # NOTE(review): this is a plain reshape of a (batch, layers,
            # 2 * hidden) tensor, not a permutation -- confirm the intended
            # memory layout before relying on individual entries.
            h = h.reshape(h.shape[1] * 2, h.shape[0], self.hidden_size)
        else:
            h = h.transpose(0, 1)

        return x, h


class QLiGRU_Layer(torch.nn.Module):
    """This function implements quaternion-valued Light-Gated Recurrent Units
    (ligru) layer.

    Arguments
    ---------
    input_size: int
        Feature dimensionality of the input tensors.
    hidden_size: int
        Number of output values. It is divided by four (integer division) to
        obtain the number of quaternion-valued neurons.
    num_layers: int
        Number of layers to employ in the RNN architecture. It is not used
        by the layer itself.
    batch_size: int
        Batch size of the input tensors.
    dropout: float
        It is the dropout factor (must be between 0 and 1).
    nonlinearity: str
        Type of nonlinearity (tanh, leaky_relu). Any other value selects
        relu.
    normalization: str
        The type of normalization to use (batchnorm or none)
    bidirectional: bool
        If True, a bidirectional model that scans the sequence both
        right-to-left and left-to-right is used.
    init_criterion: str , optional
        (glorot, he).
        This parameter controls the initialization criterion of the weights.
        It is combined with weights_init to build the initialization method of
        the quaternion-valued weights (default "glorot").
    weight_init: str, optional
        (quaternion, unitary).
        This parameter defines the initialization procedure of the
        quaternion-valued weights. "quaternion" will generate random quaternion
        weights following the init_criterion and the quaternion polar form.
        "unitary" will normalize the weights to lie on the unit circle (default "quaternion").
        More details in: "Deep quaternion Networks", Trabelsi C. et al.
    autograd: bool, optional
        When True, the default PyTorch autograd will be used. When False, a
        custom backpropagation will be used, reducing by a factor 3 to 4 the
        memory consumption. It is also 2x slower (default True).
    """

    def __init__(
        self,
        input_size,
        hidden_size,
        num_layers,
        batch_size,
        dropout=0.0,
        nonlinearity="leaky_relu",
        normalization="batchnorm",
        bidirectional=False,
        init_criterion="glorot",
        weight_init="quaternion",
        autograd=True,
    ):
        super().__init__()
        # Number of quaternion neurons; real-valued widths are multiples of it
        self.hidden_size = int(hidden_size) // 4
        self.input_size = int(input_size)
        self.batch_size = batch_size
        self.bidirectional = bidirectional
        self.dropout = dropout
        self.init_criterion = init_criterion
        self.weight_init = weight_init
        self.normalization = normalization
        self.nonlinearity = nonlinearity
        self.autograd = autograd

        # Input-to-hidden projection producing both gates at once
        self.w = QLinear(
            input_shape=self.input_size,
            n_neurons=self.hidden_size * 2,
            bias=False,
            weight_init=self.weight_init,
            init_criterion=self.init_criterion,
            autograd=self.autograd,
        )

        # Hidden-to-hidden projection producing both gates at once
        self.u = QLinear(
            input_shape=self.hidden_size * 4,  # The input size is in real
            n_neurons=self.hidden_size * 2,
            bias=False,
            weight_init=self.weight_init,
            init_criterion=self.init_criterion,
            autograd=self.autograd,
        )

        # Both directions are processed as one batch of twice the size
        if self.bidirectional:
            self.batch_size = self.batch_size * 2

        # Initializing batch norm. self.norm is always created (even when
        # normalization is disabled) to avoid jit issues. Its size is derived
        # from the same quaternion count as self.w, which equals
        # hidden_size * 2 whenever hidden_size is a multiple of four.
        self.norm = QBatchNorm(input_size=self.hidden_size * 8, dim=-1)
        self.normalize = self.normalization == "batchnorm"

        # Initial state
        self.register_buffer("h_init", torch.zeros(1, self.hidden_size * 4))

        # Preloading dropout masks (gives some speed improvement). This also
        # creates self.drop and self.drop_mask_te.
        self._init_drop(self.batch_size)

        # Setting the activation function
        if self.nonlinearity == "tanh":
            self.act = torch.nn.Tanh()
        elif self.nonlinearity == "leaky_relu":
            self.act = torch.nn.LeakyReLU()
        else:
            self.act = torch.nn.ReLU()

    def forward(self, x, hx: Optional[torch.Tensor] = None):
        # type: (torch.Tensor, Optional[torch.Tensor]) -> torch.Tensor # noqa F821
        """Returns the output of the quaternion liGRU layer.

        Arguments
        ---------
        x : torch.Tensor
            Input tensor.
        hx : torch.Tensor
            Optional initial hidden state. When None, the zero-initialised
            ``h_init`` buffer is used.

        Returns
        -------
        Output of quaternion liGRU layer.
        """

        # The backward direction is handled as extra batch items holding the
        # time-reversed sequences
        if self.bidirectional:
            x = torch.cat([x, x.flip(1)], dim=0)

        # Change batch size if needed
        self._change_batch_size(x)

        # Feed-forward affine transformations (all steps in parallel)
        w = self.w(x)

        # Apply batch normalization on the (batch * time, features) view
        if self.normalize:
            n_batch, n_steps, n_fea = w.shape[0], w.shape[1], w.shape[2]
            w_bn = self.norm(w.reshape(n_batch * n_steps, n_fea))
            w = w_bn.reshape(n_batch, n_steps, n_fea)

        # Processing time steps
        if hx is not None:
            h = self._quaternion_ligru_cell(w, hx)
        else:
            h = self._quaternion_ligru_cell(w, self.h_init)

        # Undo the time reversal of the backward half and concatenate both
        # directions along the feature axis
        if self.bidirectional:
            h_f, h_b = h.chunk(2, dim=0)
            h = torch.cat([h_f, h_b.flip(1)], dim=2)

        return h

    def _quaternion_ligru_cell(self, w, ht):
        """Returns the hidden states for each time step.

        Arguments
        ---------
        w : torch.Tensor
            Linearly transformed input.
        ht : torch.Tensor
            Hidden state used for the first time step.

        Returns
        -------
        h : torch.Tensor
            Hidden states for all steps.
        """

        # Sampling dropout mask (the same mask is used for every time step)
        drop_mask = self._sample_drop_mask(w)

        hiddens = []

        # Loop over time axis
        for k in range(w.shape[1]):
            gates = w[:, k] + self.u(ht)
            # First half of the features is the candidate pre-activation,
            # second half is the update gate
            at, zt = gates.chunk(2, 1)
            zt = torch.sigmoid(zt)
            hcand = self.act(at) * drop_mask
            ht = zt * ht + (1 - zt) * hcand
            hiddens.append(ht)

        # Stacking hidden states
        return torch.stack(hiddens, dim=1)

    def _new_drop_masks(self, n_masks, device):
        """Draws ``n_masks`` dropout masks on the given device."""
        return self.drop(
            torch.ones(n_masks, self.hidden_size * 4, device=device)
        ).data

    def _init_drop(self, batch_size):
        """Initializes the recurrent dropout operation. To speed it up,
        the dropout masks are sampled in advance.
        """

        self.drop = torch.nn.Dropout(p=self.dropout, inplace=False)
        # Mask used outside of training: a single one, i.e. no dropout
        self.drop_mask_te = torch.tensor([1.0]).float()

        self.N_drop_masks = 16000
        self.drop_mask_cnt = 0

        self.register_buffer(
            "drop_masks",
            self.drop(torch.ones(self.N_drop_masks, self.hidden_size * 4)).data,
        )

    def _sample_drop_mask(self, w):
        """Selects one of the pre-defined dropout masks"""

        if not self.training:
            self.drop_mask_te = self.drop_mask_te.to(w.device)
            return self.drop_mask_te

        # The pool cannot serve a batch larger than itself: slicing it would
        # yield fewer masks than batch items. Draw the masks directly instead.
        if self.batch_size > self.N_drop_masks:
            return self._new_drop_masks(self.batch_size, w.device)

        # Sample new masks when needed
        if self.drop_mask_cnt + self.batch_size > self.N_drop_masks:
            self.drop_mask_cnt = 0
            self.drop_masks = self._new_drop_masks(self.N_drop_masks, w.device)

        # Sampling the mask
        start = self.drop_mask_cnt
        self.drop_mask_cnt = start + self.batch_size
        return self.drop_masks[start : self.drop_mask_cnt]

    def _change_batch_size(self, x):
        """This function changes the batch size when it is different from
        the one detected in the initialization method. This might happen in
        the case of multi-gpu or when we have different batch sizes in train
        and test. In training mode the drop masks are sampled again as well.
        """

        if self.batch_size == x.shape[0]:
            return

        self.batch_size = x.shape[0]

        if self.training:
            self.drop_masks = self._new_drop_masks(self.N_drop_masks, x.device)
