# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.


from collections import namedtuple
from enum import Enum
from typing import List, Optional, Tuple

import torch
import torch.nn as nn

from xformers._deprecation_warning import deprecated_function


class ResidualNormStyle(str, Enum):
    """Support different residual path and norm styles.
    See "On Layer Normalization in the Transformer Architecture",
    Xiong et al., https://arxiv.org/pdf/2002.04745v1.pdf

    Members also subclass ``str``, so each one compares equal to its
    lowercase string value (e.g. ``ResidualNormStyle.Pre == "pre"``).
    """

    # NOTE(review): presumably selects normalization *before* the sublayer
    # (see the PreNorm wrapper) - confirm against the code that consumes this enum.
    Pre = "pre"
    # NOTE(review): presumably selects normalization *after* the sublayer
    # (see the PostNorm wrapper) - confirm against the code that consumes this enum.
    Post = "post"
    # NOTE(review): presumably the DeepNet residual scaling scheme
    # (see get_deepnorm_coefficients) - confirm against the consumer.
    DeepNorm = "deepnorm"


class NormalizationType(str, Enum):
    """Names of the normalization layers that get_normalization_layer can resolve.

    Members also subclass ``str``, so each one compares equal to its
    lowercase string value.
    """

    # Resolved to torch.nn.LayerNorm by get_normalization_layer
    LayerNorm = "layernorm"
    # Resolved to a pass-through module whose forward returns its input unchanged
    Skip = "skip"
    # TODO: BatchNorm = "batchnorm"
    # TODO: GroupNorm = "groupnorm"


def get_normalization_layer(normalization_type: NormalizationType):
    """Return the module class registered for ``normalization_type``.

    The result is a class, not an instance: the caller is expected to
    construct it. Looking up a key that is not registered raises ``KeyError``.
    """

    class Skip(nn.Module):
        """Pass-through stand-in: ignores constructor arguments, returns its input."""

        def __init__(self, *unused_args, **unused_kwargs) -> None:
            super().__init__()
            # Hand the new instance to the project's deprecation helper
            deprecated_function(self)

        def forward(self, x: torch.Tensor, **unused_kwargs):
            return x

    # One entry per supported NormalizationType member
    registry = {
        NormalizationType.LayerNorm: nn.LayerNorm,
        NormalizationType.Skip: Skip,
    }
    return registry[normalization_type]


class RequiresWrappedInputs:
    """Marker base class with no behavior of its own.

    Inheriting from it signals that the class expects its inputs to be
    passed as a single list; wrappers detect this with ``isinstance``.
    """


# CREDITS: the following is inspired by FastAI's Transformer implementation
class Residual(nn.Module, RequiresWrappedInputs):
    """
    Wraps a layer and adds a skip connection around it.

    The first input is (optionally) multiplied by ``scale`` and added to the
    output of the wrapped layer. Scaling the residual path follows DeepNet_
    .. _DeepNet: https://arxiv.org/pdf/2203.00555v1.pdf

    .. Note: this module itself takes all its inputs as a single list. The
       wrapped layer receives that list as ``inputs=`` when it is marked with
       RequiresWrappedInputs, and as unpacked positional arguments otherwise.

    Args:
        layer: the module placed on the non-residual branch
        scale: optional multiplier for the residual path, ``None`` disables it
    """

    def __init__(self, layer: nn.Module, scale: Optional[float] = None):
        super().__init__()
        # Hand the new instance to the project's deprecation helper
        deprecated_function(self)

        self.layer = layer
        self.scale = scale

        # Marked layers (e.g. PreNorm / PostNorm) want the list itself,
        # everything else gets the tensors as positional arguments
        self.wrap_inputs = isinstance(layer, RequiresWrappedInputs)

    def forward(self, inputs: List[torch.Tensor], **kwargs):
        """Return ``inputs[0]`` (scaled if requested) plus the wrapped layer's output."""
        skip = inputs[0] if self.scale is None else inputs[0] * self.scale

        if self.wrap_inputs:
            branch = self.layer(inputs=inputs, **kwargs)
        else:
            branch = self.layer(*inputs, **kwargs)

        return skip + branch


class PreNorm(nn.Module, RequiresWrappedInputs):
    """Normalizes the inputs, then runs the wrapped sublayer on the result

    .. Note: every tensor in the input list goes through the same norm layer

    Args:
        d_norm: forwarded to the constructor of the normalization layer
        sublayer: the module called on the normalized inputs
        normalization: which normalization layer to build
        use_triton: accepted for interface compatibility, not read by this class
    """

    def __init__(
        self,
        d_norm: int,
        sublayer: nn.Module,
        normalization: NormalizationType,
        use_triton: bool = True,
    ):
        super().__init__()
        # Hand the new instance to the project's deprecation helper
        deprecated_function(self)

        norm_cls = get_normalization_layer(normalization)
        self.norm = norm_cls(d_norm)

        self.sublayer = sublayer
        # Marked sublayers take the list as `inputs=`, others take unpacked tensors
        self.wrap_inputs = isinstance(sublayer, RequiresWrappedInputs)

    def forward(self, inputs: List[torch.Tensor], **kwargs):
        """Normalize all the inputs and return the sublayer's output on them."""
        assert len(inputs) > 0

        first = inputs[0]
        if all(tensor is first for tensor in inputs):
            # Every slot holds the very same object: normalize once, reuse the result
            shared = self.norm(first)
            normed = [shared] * len(inputs)
        else:
            # At least two distinct objects: normalize each slot on its own
            normed = [self.norm(tensor) for tensor in inputs]

        if self.wrap_inputs:
            return self.sublayer(inputs=normed, **kwargs)
        return self.sublayer(*normed, **kwargs)


class PostNorm(nn.Module, RequiresWrappedInputs):
    """Runs the wrapped sublayer, then normalizes its output

    Args:
        d_norm: forwarded to the constructor of the normalization layer
        sublayer: the module whose output gets normalized
        normalization: which normalization layer to build
        use_triton: accepted for interface compatibility, not read by this class
    """

    def __init__(
        self,
        d_norm: int,
        sublayer: nn.Module,
        normalization: NormalizationType,
        use_triton: bool = True,
    ):
        super().__init__()
        # Hand the new instance to the project's deprecation helper
        deprecated_function(self)

        norm_cls = get_normalization_layer(normalization)
        self.norm = norm_cls(d_norm)

        self.sublayer = sublayer
        # Marked sublayers take the list as `inputs=`, others take unpacked tensors
        self.wrap_inputs = isinstance(sublayer, RequiresWrappedInputs)

    def forward(self, inputs: List[torch.Tensor], **kwargs):
        """Return the normalized output of the sublayer applied to ``inputs``."""
        out = (
            self.sublayer(inputs=inputs, **kwargs)
            if self.wrap_inputs
            else self.sublayer(*inputs, **kwargs)
        )
        return self.norm(out)


# The (alpha, beta) pair computed for one stack (encoder or decoder)
DeepNormCoefficients = namedtuple("DeepNormCoefficients", ["alpha", "beta"])


def get_deepnorm_coefficients(
    encoder_layers: int, decoder_layers: int
) -> Tuple[Optional[DeepNormCoefficients], Optional[DeepNormCoefficients]]:
    """
    Compute the DeepNorm ``alpha`` / ``beta`` coefficients, see DeepNet_.

    NOTE(review): in the DeepNet paper ``alpha`` scales the residual branch and
    ``beta`` scales the weight initialization - confirm how callers use them.

    Args:
        encoder_layers: number of encoder layers (N), 0 for a decoder-only model
        decoder_layers: number of decoder layers (M), 0 for an encoder-only model

    Returns:
        A ``(encoder, decoder)`` pair of DeepNormCoefficients. The entry for a
        stack that has no layers is ``None``.

    Raises:
        ValueError: if a layer count is negative, or if both counts are zero.

    .. _DeepNet: https://arxiv.org/pdf/2203.00555v1.pdf
    """

    # Reject counts the formulas are not defined for. Without this check a
    # (0, 0) call fails with an opaque ZeroDivisionError, and negative counts
    # silently produce complex or meaningless coefficients.
    if encoder_layers < 0 or decoder_layers < 0:
        raise ValueError(
            "Layer counts must be non-negative, got "
            f"encoder_layers={encoder_layers}, decoder_layers={decoder_layers}"
        )
    if encoder_layers == 0 and decoder_layers == 0:
        raise ValueError(
            "At least one of encoder_layers and decoder_layers must be positive"
        )

    # Same symbols as in the paper
    N = encoder_layers
    M = decoder_layers

    if decoder_layers == 0:
        # Encoder only
        return (
            DeepNormCoefficients(alpha=(2 * N) ** 0.25, beta=(8 * N) ** -0.25),
            None,
        )

    elif encoder_layers == 0:
        # Decoder only
        return None, DeepNormCoefficients(alpha=(2 * M) ** 0.25, beta=(8 * M) ** -0.25)
    else:
        # Encoder/decoder: the encoder coefficients depend on both depths,
        # the decoder coefficients only on the decoder depth
        encoder_coeffs = DeepNormCoefficients(
            alpha=0.81 * ((N**4) * M) ** 0.0625, beta=0.87 * ((N**4) * M) ** -0.0625
        )

        decoder_coeffs = DeepNormCoefficients(
            alpha=(3 * M) ** 0.25, beta=(12 * M) ** -0.25
        )

        return (encoder_coeffs, decoder_coeffs)
