"""Library implementing quaternion-valued linear transformation.

Authors
 * Titouan Parcollet 2020
 * Drew Wagner 2024
"""

import torch

from speechbrain.nnet.quaternion_networks.q_ops import (
    QuaternionLinearCustomBackward,
    affect_init,
    check_quaternion_input,
    quaternion_init,
    quaternion_linear_op,
    quaternion_linear_rotation_op,
    renorm_quaternion_weights_inplace,
    unitary_init,
)
from speechbrain.utils.logger import get_logger

logger = get_logger(__name__)


class QLinear(torch.nn.Module):
    """This class implements a fully connected quaternion-valued
    linear layer: y = Wx + b. y, W, x and b are thus quaternion
    numbers. A quaternion number is written as: r + xi + yj + zk.
    A tensor of quaternion numbers x = [batch, 32] can be understood as
    [batch, 0:7] = R, [batch, 8:15] = Xi, [batch, 16:23] = Yj, and
    [batch, 24:31] = Zk. Thus the features dimension is cut in four
    (must be divisible by 4).

    Arguments
    ---------
    n_neurons : int
        It is the number of output neurons (i.e, the dimensionality of the
        output). Please note that these are quaternion-valued neurons. If 256
        neurons are specified, the output dimension will be 1024.
    input_shape : tuple
        Expected size of the input. A plain int is also accepted and is
        interpreted as the size of the feature dimension.
    bias : bool
        If True, the additive bias b is adopted.
    init_criterion : str , optional
        (glorot, he).
        This parameter controls the initialization criterion of the weights.
        It is combined with weight_init to build the initialization method of
        the quaternion-valued weights (default "glorot").
    weight_init : str, optional
        (quaternion, unitary).
        This parameter defines the initialization procedure of the
        quaternion-valued weights. "quaternion" will generate quaternion-valued
        weights following the init_criterion and the quaternion  polar form.
        "unitary" will normalize the weights to lie on the unit circle (default "quaternion").
        More details in: "Quaternion recurrent neural networks", Parcollet T.
    autograd : bool, optional
        When True, the default PyTorch autograd will be used. When False, a
        custom backpropagation will be used, reducing by a factor 3 to 4 the
        memory consumption. It is also 2x slower. This only works with
        spinor = False: if spinor = True is requested together with
        autograd = False, the spinor option is ignored and a warning is
        logged (default True).
    spinor : bool, optional
        When True, the layer will be turned into a spinor layer. More precisely
        W*x will be turned into W*x*W-1. The input x will be rotated by W such
        as in a spinor neural network. However, x MUST be a quaternion with
        the real part equal to zero. (0 + xi + yj + zk). Indeed, the rotation
        operation only acts on the vector part. Note that W will always be
        normalized before the rotation to ensure the quaternion algebra (default False).
        More details in: "Quaternion neural networks", Parcollet T.
    vector_scale : bool, optional
        The vector_scale is only used when spinor = True. In the context of a
        spinor neural network, multiple rotations of the input vector x are
        performed and summed. Hence, the norm of the output vector always
        increases with the number of layers, making the neural network unstable
        with deep configurations. The vector_scale parameters are learnable
        parameters that acts like gates by multiplying the output vector with
        a small trainable parameter (default False).
    max_norm: float
        Weight max-norm. When not None, the four weight components are
        renormalized in place with this value at the beginning of every
        forward call (default None).

    Example
    -------
    >>> inputs = torch.rand(10, 50, 40)
    >>> lin = QLinear(n_neurons=100, input_shape=inputs.shape, weight_init='unitary')
    >>> output = lin(inputs)
    >>> output.shape
    torch.Size([10, 50, 400])
    """

    def __init__(
        self,
        n_neurons,
        input_shape,
        bias=True,
        init_criterion="glorot",
        weight_init="quaternion",
        autograd=True,
        spinor=False,
        vector_scale=False,
        max_norm=None,
    ):
        super().__init__()
        self.n_neurons = n_neurons
        self.init_criterion = init_criterion
        self.weight_init = weight_init
        self.autograd = autograd
        self.spinor = spinor
        self.vector_scale = vector_scale
        self.max_norm = max_norm

        # The custom backward path in forward() never looks at self.spinor,
        # so make the otherwise silent fallback visible to the user.
        if self.spinor and not self.autograd:
            logger.warning(
                "QLinear: spinor=True is ignored when autograd=False; the "
                "custom backward path only applies the standard quaternion "
                "linear operation."
            )

        # When initialising with speechbrain the input_shape is an integer !
        # we need to transform it into a list so it works with all the
        # quaternion ops
        if isinstance(input_shape, int):
            input_shape = [1, input_shape]

        # Check the quaternion_valued form of the input
        check_quaternion_input(input_shape)

        # Computing the quaternion dimensionality of the input
        self.in_features = input_shape[-1] // 4
        self.out_features = self.n_neurons

        # Defining the weights. They are allocated uninitialized on purpose:
        # affect_init() below fills all four components.
        weight_shape = (self.in_features, self.out_features)
        self.r_weight = torch.nn.Parameter(torch.empty(weight_shape))
        self.i_weight = torch.nn.Parameter(torch.empty(weight_shape))
        self.j_weight = torch.nn.Parameter(torch.empty(weight_shape))
        self.k_weight = torch.nn.Parameter(torch.empty(weight_shape))

        # Spinor specific parameters
        if self.spinor:
            self.zero_kernel = torch.nn.Parameter(
                torch.zeros(self.r_weight.shape), requires_grad=False
            )
        else:
            # Placeholder only; zero-filled instead of uninitialized memory.
            self.zero_kernel = torch.zeros(self.r_weight.shape).requires_grad_(
                False
            )

        if self.spinor and self.vector_scale:
            self.scale_param = torch.nn.Parameter(torch.empty(weight_shape))
            torch.nn.init.xavier_uniform_(self.scale_param.data)
        else:
            # Non-trainable placeholder (requires_grad is False).
            self.scale_param = torch.zeros(weight_shape).requires_grad_(False)

        # The bias always starts at zero. Without bias, a non-trainable
        # placeholder tensor (requires_grad=False) is stored instead of a
        # Parameter.
        if bias:
            self.bias = torch.nn.Parameter(torch.zeros(4 * n_neurons))
        else:
            self.bias = torch.zeros(4 * n_neurons).requires_grad_(False)

        # Managing the weight initialization and bias
        self.winit = {"quaternion": quaternion_init, "unitary": unitary_init}[
            self.weight_init
        ]

        # Initialise the weights
        affect_init(
            self.r_weight,
            self.i_weight,
            self.j_weight,
            self.k_weight,
            self.winit,
            init_criterion,
        )

    @torch.jit.ignore
    def forward(self, x):
        """Returns the linear transformation of input tensor.

        Arguments
        ---------
        x : torch.Tensor
            Input to transform linearly.

        Returns
        -------
        The linearly transformed input.
        """

        # Optional max-norm constraint, applied in place on the weights
        # before they are used.
        if self.max_norm is not None:
            renorm_quaternion_weights_inplace(
                self.r_weight,
                self.i_weight,
                self.j_weight,
                self.k_weight,
                max_norm=self.max_norm,
            )

        if self.autograd:
            if self.spinor:
                out = quaternion_linear_rotation_op(
                    x,
                    self.r_weight,
                    self.i_weight,
                    self.j_weight,
                    self.k_weight,
                    self.bias,
                    self.scale_param,
                    self.zero_kernel,
                )
            else:
                out = quaternion_linear_op(
                    x,
                    self.r_weight,
                    self.i_weight,
                    self.j_weight,
                    self.k_weight,
                    self.bias,
                )
        else:
            # The custom backward needs an input with 2D at most! Every
            # leading dimension is therefore folded into the first one and
            # restored on the output. reshape (unlike view) also accepts
            # non-contiguous inputs.
            leading_shape = x.shape[:-1]
            needs_flatten = x.dim() > 2
            if needs_flatten:
                x = x.reshape(-1, x.shape[-1])

            out = QuaternionLinearCustomBackward.apply(
                x,
                self.r_weight,
                self.i_weight,
                self.j_weight,
                self.k_weight,
                self.bias,
            )

            if needs_flatten:
                out = out.reshape(*leading_shape, out.size(-1))

        return out
