"""Library to support dual-path speech separation.

Authors
 * Cem Subakan 2020
 * Mirco Ravanelli 2020
 * Samuele Cornell 2020
 * Mirko Bronzi 2020
 * Jianyuan Zhong 2020
"""

import copy
import math

import torch
import torch.nn as nn
import torch.nn.functional as F

import speechbrain.nnet.RNN as SBRNN
from speechbrain.lobes.models.transformer.Conformer import ConformerEncoder
from speechbrain.lobes.models.transformer.Transformer import (
    PositionalEncoding,
    TransformerEncoder,
)
from speechbrain.nnet.activations import Swish
from speechbrain.nnet.linear import Linear

# Small positive constant. It is not referenced in the visible part of this
# file; presumably kept for numerical-stability use elsewhere — TODO confirm.
EPS = 1e-8


class GlobalLayerNorm(nn.Module):
    """Calculate Global Layer Normalization.

    A single mean and variance are computed per sample, jointly over every
    non-batch dimension of the input, and used to normalize that sample.

    Arguments
    ---------
    dim : int
        Number of channels, i.e. the size of dimension 1 of the input.
    shape : int
        Number of dimensions of the expected input: 3 for [N, C, L] or
        4 for [N, C, K, S]. Only used to size the affine parameters.
    eps : float
        A value added to the denominator for numerical stability.
    elementwise_affine : bool
        A boolean value that when set to True,
        this module has learnable per-channel affine parameters
        initialized to ones (for weights) and zeros (for biases).

    Raises
    ------
    ValueError
        If ``elementwise_affine`` is True and ``shape`` is neither 3 nor 4.

    Example
    -------
    >>> x = torch.randn(5, 10, 20)
    >>> GLN = GlobalLayerNorm(10, 3)
    >>> x_norm = GLN(x)
    """

    def __init__(self, dim, shape, eps=1e-8, elementwise_affine=True):
        super().__init__()
        self.dim = dim
        self.eps = eps
        self.elementwise_affine = elementwise_affine

        if self.elementwise_affine:
            # Fail early: without this check an unsupported shape would leave
            # the module without weight/bias and only crash inside forward().
            if shape == 3:
                trailing = (1,)
            elif shape == 4:
                trailing = (1, 1)
            else:
                raise ValueError(
                    "GlobalLayerNorm supports shape 3 or 4, got {}".format(
                        shape
                    )
                )
            # One scale/shift per channel, broadcast over trailing dimensions.
            self.weight = nn.Parameter(torch.ones(self.dim, *trailing))
            self.bias = nn.Parameter(torch.zeros(self.dim, *trailing))
        else:
            self.register_parameter("weight", None)
            self.register_parameter("bias", None)

    def forward(self, x):
        """Returns the normalized tensor.

        Arguments
        ---------
        x : torch.Tensor
            Tensor of size [N, C, K, S] or [N, C, L].

        Returns
        -------
        out : torch.Tensor
            The normalized outputs. Inputs that are neither 3-D nor 4-D
            are returned unchanged.
        """
        n_dims = x.dim()
        if n_dims not in (3, 4):
            return x

        # Statistics are shared by all non-batch positions:
        # mean/var are N x 1 x 1 (3-D input) or N x 1 x 1 x 1 (4-D input).
        reduce_dims = tuple(range(1, n_dims))
        mean = torch.mean(x, reduce_dims, keepdim=True)
        var = torch.mean((x - mean) ** 2, reduce_dims, keepdim=True)

        if self.elementwise_affine:
            return (
                self.weight * (x - mean) / torch.sqrt(var + self.eps)
                + self.bias
            )
        return (x - mean) / torch.sqrt(var + self.eps)


class CumulativeLayerNorm(nn.LayerNorm):
    """Layer normalization applied over the channel axis of channel-first input.

    The channel axis (dimension 1) is moved to the last position, a regular
    ``nn.LayerNorm`` over that axis is applied, and the original layout is
    restored.

    Arguments
    ---------
    dim : int
        Size of the channel dimension to normalize.
    elementwise_affine : bool
        Learnable per-element affine parameters.
    eps : float
        Small value added to the denominator of the normalization.

    Example
    -------
    >>> x = torch.randn(5, 10, 20)
    >>> CLN = CumulativeLayerNorm(10)
    >>> x_norm = CLN(x)
    """

    def __init__(self, dim, elementwise_affine=True, eps=1e-8):
        super().__init__(dim, elementwise_affine=elementwise_affine, eps=eps)

    def forward(self, x):
        """Returns the normalized tensor.

        Arguments
        ---------
        x : torch.Tensor
            torch.Tensor size [N, C, K, S] or [N, C, L]

        Returns
        -------
        out : torch.Tensor
            The normalized outputs, in the same layout as the input.
            Inputs that are neither 3-D nor 4-D are returned unchanged.
        """
        n_dims = x.dim()

        if n_dims == 4:
            # [N, C, K, S] -> [N, K, S, C] so LayerNorm sees channels last
            channels_last = x.permute(0, 2, 3, 1).contiguous()
            normed = super().forward(channels_last)
            # back to [N, C, K, S]
            return normed.permute(0, 3, 1, 2).contiguous()

        if n_dims == 3:
            # [N, C, L] -> [N, L, C], normalize channels, then swap back
            normed = super().forward(torch.transpose(x, 1, 2))
            return torch.transpose(normed, 1, 2)

        return x


def select_norm(norm, dim, shape, eps=1e-8):
    """Return the normalization layer selected by ``norm``.

    Arguments
    ---------
    norm : str
        "gln" for GlobalLayerNorm, "cln" for CumulativeLayerNorm, "ln" for
        a single-group ``nn.GroupNorm``. Any other value selects batch
        normalization.
    dim : int
        Number of channels to normalize.
    shape : int
        Number of dimensions of the expected input (3 or 4).
    eps : float
        Numerical-stability constant. It is forwarded to the "gln", "cln"
        and "ln" layers; batch normalization keeps the PyTorch default.

    Returns
    -------
    torch.nn.Module
        The instantiated normalization layer.
    """
    if norm == "gln":
        return GlobalLayerNorm(dim, shape, elementwise_affine=True, eps=eps)
    if norm == "cln":
        return CumulativeLayerNorm(dim, elementwise_affine=True, eps=eps)
    if norm == "ln":
        return nn.GroupNorm(1, dim, eps=eps)

    # BatchNorm1d only accepts 2-D/3-D input, so a 4-D [B, N, K, S] tensor
    # needs the 2-D variant (same parameters, per-channel statistics).
    if shape == 4:
        return nn.BatchNorm2d(dim)
    return nn.BatchNorm1d(dim)


class Encoder(nn.Module):
    """Convolutional encoder: a bias-free 1-D convolution followed by ReLU.

    The convolution stride is half of the kernel size.

    Arguments
    ---------
    kernel_size : int
        Length of filters.
    out_channels : int
        Number of output channels.
    in_channels : int
        Number of input channels.

    Example
    -------
    >>> x = torch.randn(2, 1000)
    >>> encoder = Encoder(kernel_size=4, out_channels=64)
    >>> h = encoder(x)
    >>> h.shape
    torch.Size([2, 64, 499])
    """

    def __init__(self, kernel_size=2, out_channels=64, in_channels=1):
        super().__init__()
        self.in_channels = in_channels
        hop = kernel_size // 2
        self.conv1d = nn.Conv1d(
            in_channels,
            out_channels,
            kernel_size,
            stride=hop,
            groups=1,
            bias=False,
        )

    def forward(self, x):
        """Return the encoded output.

        Arguments
        ---------
        x : torch.Tensor
            Input tensor with dimensionality [B, L].

        Returns
        -------
        x : torch.Tensor
            Encoded tensor with dimensionality [B, N, T_out].
            where B = Batchsize
                  L = Number of timepoints
                  N = Number of filters
                  T_out = Number of timepoints at the output of the encoder
        """
        if self.in_channels == 1:
            # Single-channel input has no channel axis: B x L -> B x 1 x L
            x = x.unsqueeze(1)
        # B x C x L -> B x N x T_out, then keep only non-negative activations
        return F.relu(self.conv1d(x))


class Decoder(nn.ConvTranspose1d):
    """A decoder layer that consists of ConvTranspose1d.

    Arguments
    ---------
    *args : tuple
    **kwargs : dict
        Arguments passed through to nn.ConvTranspose1d

    Example
    -------
    >>> x = torch.randn(2, 100, 1000)
    >>> decoder = Decoder(kernel_size=4, in_channels=100, out_channels=1)
    >>> h = decoder(x)
    >>> h.shape
    torch.Size([2, 1003])
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

    def forward(self, x):
        """Return the decoded output.

        Arguments
        ---------
        x : torch.Tensor
            Input tensor with dimensionality [B, N, L].
                where, B = Batchsize,
                       N = number of filters
                       L = time points
            A 2-D tensor is also accepted; a singleton dimension is then
            inserted at position 1 before the transposed convolution.

        Returns
        -------
        out : torch.Tensor
            The decoded outputs with singleton dimensions removed.

        Raises
        ------
        RuntimeError
            If the input is neither 2-D nor 3-D.
        """
        if x.dim() not in [2, 3]:
            # Instances have no __name__; the class name lives on the type.
            raise RuntimeError(
                "{} accept 2/3D tensor as input".format(type(self).__name__)
            )
        x = super().forward(x if x.dim() == 3 else torch.unsqueeze(x, 1))

        if torch.squeeze(x).dim() == 1:
            # Everything but time is singleton (e.g. batch of one, one output
            # channel): drop only the channel axis so the batch axis survives.
            x = torch.squeeze(x, dim=1)
        else:
            # Removes every singleton dimension of the output.
            x = torch.squeeze(x)
        return x


class IdentityBlock:
    """This block is used when we want to have identity transformation within the Dual_path block.

    Arguments
    ---------
    **kwargs : dict
        Arguments are ignored.

    Example
    -------
    >>> x = torch.randn(10, 100)
    >>> IB = IdentityBlock()
    >>> xhat = IB(x)
    """

    def __init__(self, **kwargs):
        # Accept and discard any keyword arguments so the block can be built
        # with the same configuration as the other intra/inter blocks.
        pass

    def __call__(self, x):
        """Return the input unchanged."""
        return x


class FastTransformerBlock(nn.Module):
    """This block is used to implement fast transformer models with efficient attention.

    The implementations are taken from https://fast-transformers.github.io/

    Arguments
    ---------
    attention_type : str
        Specifies the type of attention.
        Check https://fast-transformers.github.io/  for details.
    out_channels : int
        Dimensionality of the representation.
    num_layers : int
        Number of layers.
    nhead : int
        Number of attention heads.
    d_ffn : int
        Dimensionality of positional feed-forward.
    dropout : float
        Dropout drop rate (used for both the layer and attention dropout).
    activation : str
        Activation function. Accepted for interface compatibility; it is
        currently not forwarded to the builder.
    reformer_bucket_size : int
        bucket size for reformer.

    Example
    -------
    # >>> x = torch.randn(10, 100, 64)
    # >>> block = FastTransformerBlock('linear', 64)
    # >>> x = block(x)
    # >>> x.shape
    # torch.Size([10, 100, 64])
    """

    def __init__(
        self,
        attention_type,
        out_channels,
        num_layers=6,
        nhead=8,
        d_ffn=1024,
        dropout=0,
        activation="relu",
        reformer_bucket_size=32,
    ):
        super().__init__()
        # Imported lazily so the optional dependency is only required when
        # this block is actually instantiated.
        from fast_transformers.builders import TransformerEncoderBuilder

        builder = TransformerEncoderBuilder.from_kwargs(
            attention_type=attention_type,
            n_layers=num_layers,
            n_heads=nhead,
            feed_forward_dimensions=d_ffn,
            query_dimensions=out_channels // nhead,
            value_dimensions=out_channels // nhead,
            dropout=dropout,
            attention_dropout=dropout,
            chunk_size=reformer_bucket_size,
        )
        self.mdl = builder.get()

        self.attention_type = attention_type
        self.reformer_bucket_size = reformer_bucket_size

    def forward(self, x):
        """Returns the transformed input.

        Arguments
        ---------
        x : torch.Tensor
            Tensor shaper [B, L, N].
            where, B = Batchsize,
                   N = number of filters
                   L = time points

        Returns
        -------
        out : torch.Tensor
            The transformed outputs.
        """
        if self.attention_type != "reformer":
            return self.mdl(x)

        # The time axis is zero-padded up to the next multiple of twice the
        # bucket size (a full extra block is added when already aligned).
        block_len = self.reformer_bucket_size * 2
        pad_size = block_len - (x.shape[1] % block_len)

        # new_zeros keeps the padding on x's device AND dtype, so half or
        # double precision inputs are not silently promoted by the concat.
        padding = x.new_zeros(x.size(0), pad_size, x.size(-1))
        x_padded = torch.cat([x, padding], dim=1)

        # apply the model
        x_padded = self.mdl(x_padded)

        # get rid of the padded frames at the end
        return x_padded[:, :-pad_size, :]


class PyTorchPositionalEncoding(nn.Module):
    """Positional encoder for the pytorch transformer.

    A sinusoidal table of shape [max_len, 1, d_model] is precomputed and
    stored as the buffer ``pe``. Even feature indices hold sines and odd
    feature indices hold cosines.

    Arguments
    ---------
    d_model : int
        Representation dimensionality (even or odd).
    dropout : float
        Dropout drop prob.
    max_len : int
        Max sequence length.

    Example
    -------
    >>> x = torch.randn(10, 100, 64)
    >>> enc = PyTorchPositionalEncoding(64)
    >>> x = enc(x)
    """

    def __init__(self, d_model, dropout=0.1, max_len=5000):
        super().__init__()
        self.dropout = nn.Dropout(p=dropout)

        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(
            torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model)
        )
        angles = position * div_term
        pe[:, 0::2] = torch.sin(angles)
        # For odd d_model there is one fewer odd column than even columns,
        # so the cosine terms are truncated to fit.
        pe[:, 1::2] = torch.cos(angles[:, : d_model // 2])
        # [max_len, d_model] -> [max_len, 1, d_model]
        pe = pe.unsqueeze(0).transpose(0, 1)
        self.register_buffer("pe", pe)

    def forward(self, x):
        """Returns the encoded output.

        Arguments
        ---------
        x : torch.Tensor
            3-D tensor whose last dimension is ``d_model``. The encoding row
            is selected by the index along dimension 0 and broadcast over
            dimension 1, i.e. dimension 0 is treated as the position axis.
            NOTE(review): callers in this file document their input as
            [B, L, N]; confirm which axis is meant to carry time.

        Returns
        -------
        out : torch.Tensor
            The encoded output (input plus encoding, followed by dropout).
        """
        x = x + self.pe[: x.size(0), :]
        return self.dropout(x)


class PytorchTransformerBlock(nn.Module):
    """Thin wrapper around ``nn.TransformerEncoder`` with optional positional encoding.

    Arguments
    ---------
    out_channels : int
        Dimensionality of the representation.
    num_layers : int
        Number of layers.
    nhead : int
        Number of attention heads.
    d_ffn : int
        Dimensionality of positional feed forward.
    dropout : float
        Dropout drop rate.
    activation : str
        Activation function.
    use_positional_encoding : bool
        If true we use a positional encoding.

    Example
    -------
    >>> x = torch.randn(10, 100, 64)
    >>> block = PytorchTransformerBlock(64)
    >>> x = block(x)
    >>> x.shape
    torch.Size([10, 100, 64])
    """

    def __init__(
        self,
        out_channels,
        num_layers=6,
        nhead=8,
        d_ffn=2048,
        dropout=0.1,
        activation="relu",
        use_positional_encoding=True,
    ):
        super().__init__()

        # NOTE(review): the layer is created without ``batch_first``, so
        # PyTorch's default layout applies — confirm this matches the
        # [B, L, N] input documented in forward().
        layer = nn.TransformerEncoderLayer(
            d_model=out_channels,
            nhead=nhead,
            dim_feedforward=d_ffn,
            dropout=dropout,
            activation=activation,
        )
        # TODO: the stacked encoder has its own normalization component
        # that may deserve a closer look.
        self.mdl = nn.TransformerEncoder(layer, num_layers=num_layers)

        self.pos_encoder = (
            PyTorchPositionalEncoding(out_channels)
            if use_positional_encoding
            else None
        )

    def forward(self, x):
        """Returns the transformed output.

        Arguments
        ---------
        x : torch.Tensor
            Tensor shape [B, L, N]
            where, B = Batchsize,
                   N = number of filters
                   L = time points

        Returns
        -------
        out : torch.Tensor
            The transformed output.
        """
        if self.pos_encoder is None:
            return self.mdl(x)
        encoded = self.pos_encoder(x)
        return self.mdl(encoded)


class SBTransformerBlock(nn.Module):
    """Wrapper around the SpeechBrain ``TransformerEncoder``.

    Arguments
    ---------
    num_layers : int
        Number of layers.
    d_model : int
        Dimensionality of the representation.
    nhead : int
        Number of attention heads.
    d_ffn : int
        Dimensionality of positional feed forward.
    input_shape : tuple
        Shape of input.
    kdim : int
        Dimension of the key (Optional).
    vdim : int
        Dimension of the value (Optional).
    dropout : float
        Dropout rate.
    activation : str
        Activation function, either "relu" or "gelu".
    use_positional_encoding : bool
        If true we use a positional encoding.
    norm_before : bool
        Use normalization before transformations.
    attention_type : str
        Type of attention to use, default "regularMHA"

    Raises
    ------
    ValueError
        If ``activation`` is neither "relu" nor "gelu".

    Example
    -------
    >>> x = torch.randn(10, 100, 64)
    >>> block = SBTransformerBlock(1, 64, 8)
    >>> x = block(x)
    >>> x.shape
    torch.Size([10, 100, 64])
    """

    def __init__(
        self,
        num_layers,
        d_model,
        nhead,
        d_ffn=2048,
        input_shape=None,
        kdim=None,
        vdim=None,
        dropout=0.1,
        activation="relu",
        use_positional_encoding=False,
        norm_before=False,
        attention_type="regularMHA",
    ):
        super().__init__()
        self.use_positional_encoding = use_positional_encoding

        # Map the activation name to the module class the encoder expects.
        if activation == "relu":
            activation_cls = nn.ReLU
        elif activation == "gelu":
            activation_cls = nn.GELU
        else:
            raise ValueError("unknown activation")

        encoder_kwargs = dict(
            num_layers=num_layers,
            nhead=nhead,
            d_ffn=d_ffn,
            input_shape=input_shape,
            d_model=d_model,
            kdim=kdim,
            vdim=vdim,
            dropout=dropout,
            activation=activation_cls,
            normalize_before=norm_before,
            attention_type=attention_type,
        )
        self.mdl = TransformerEncoder(**encoder_kwargs)

        if use_positional_encoding:
            self.pos_enc = PositionalEncoding(input_size=d_model)

    def forward(self, x):
        """Returns the transformed output.

        Arguments
        ---------
        x : torch.Tensor
            Tensor shape [B, L, N],
            where, B = Batchsize,
                   L = time points
                   N = number of filters

        Returns
        -------
        out : torch.Tensor
            The transformed output (first element returned by the encoder).
        """
        if self.use_positional_encoding:
            # The encoding is added to the input rather than replacing it.
            x = x + self.pos_enc(x)
        encoder_out = self.mdl(x)
        return encoder_out[0]


class SBRNNBlock(nn.Module):
    """Recurrent block for the dual path pipeline.

    The recurrent layer class is looked up by name in ``speechbrain.nnet.RNN``.

    Arguments
    ---------
    input_size : int
        Dimensionality of the input features.
    hidden_channels : int
        Dimensionality of the latent layer of the rnn.
    num_layers : int
        Number of the rnn layers.
    rnn_type : str
        Name of the rnn class to use (attribute of ``speechbrain.nnet.RNN``).
    dropout : float
        Dropout rate
    bidirectional : bool
        If True, bidirectional.

    Example
    -------
    >>> x = torch.randn(10, 100, 64)
    >>> rnn = SBRNNBlock(64, 100, 1, bidirectional=True)
    >>> x = rnn(x)
    >>> x.shape
    torch.Size([10, 100, 200])
    """

    def __init__(
        self,
        input_size,
        hidden_channels,
        num_layers,
        rnn_type="LSTM",
        dropout=0,
        bidirectional=True,
    ):
        super().__init__()

        rnn_cls = getattr(SBRNN, rnn_type)
        self.mdl = rnn_cls(
            hidden_channels,
            input_size=input_size,
            num_layers=num_layers,
            dropout=dropout,
            bidirectional=bidirectional,
        )

    def forward(self, x):
        """Returns the transformed output.

        Arguments
        ---------
        x : torch.Tensor
            [B, L, N]
            where, B = Batchsize,
                   N = number of filters
                   L = time points

        Returns
        -------
        out : torch.Tensor
            The transformed output (first element returned by the rnn).
        """
        rnn_result = self.mdl(x)
        # Only the first returned element is kept; the rest is discarded.
        return rnn_result[0]


class DPTNetBlock(nn.Module):
    """The DPT Net block.

    Self-attention followed by a bidirectional LSTM that takes the place of
    the usual position-wise feed-forward network. Both sub-layers use a
    residual connection followed by layer normalization.

    Arguments
    ---------
    d_model : int
        Number of expected features in the input (required).
    nhead : int
        Number of heads in the multiheadattention models (required).
    dim_feedforward : int
        Dimension of the feedforward network model (default=256).
        Currently unused: the LSTM hidden size is ``2 * d_model``.
    dropout : float
        Dropout value (default=0).
    activation : str
        Activation function of intermediate layer, relu or gelu (default=relu).

    Examples
    --------
    >>> encoder_layer = DPTNetBlock(d_model=512, nhead=8)
    >>> src = torch.rand(10, 100, 512)
    >>> out = encoder_layer(src)
    >>> out.shape
    torch.Size([10, 100, 512])
    """

    def __init__(
        self, d_model, nhead, dim_feedforward=256, dropout=0, activation="relu"
    ):
        super().__init__()

        # Attention sub-layer
        self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout)

        # Recurrent "feed-forward" sub-layer: a bidirectional LSTM with
        # hidden size 2 * d_model emits 4 * d_model features, which the
        # linear layer maps back to d_model.
        self.rnn = nn.LSTM(d_model, d_model * 2, 1, bidirectional=True)
        self.dropout = nn.Dropout(dropout)
        self.linear2 = nn.Linear(d_model * 2 * 2, d_model)

        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.dropout1 = nn.Dropout(dropout)
        self.dropout2 = nn.Dropout(dropout)

        self.activation = _get_activation_fn(activation)

    def __setstate__(self, state):
        # States saved without an activation entry fall back to ReLU.
        state.setdefault("activation", F.relu)
        super().__setstate__(state)

    def forward(self, src):
        """Pass the input through the encoder layer.

        Arguments
        ---------
        src : torch.Tensor
            Tensor shape [B, L, N]
            where, B = Batchsize,
                   N = number of filters
                   L = time points

        Returns
        -------
        Encoded outputs.
        """
        # Self-attention with residual connection and normalization
        attended = self.self_attn(
            src, src, src, attn_mask=None, key_padding_mask=None
        )[0]
        src = self.norm1(src + self.dropout1(attended))

        # LSTM -> activation -> dropout -> projection back to d_model
        recurrent = self.rnn(src)[0]
        projected = self.linear2(self.dropout(self.activation(recurrent)))

        # Second residual connection and normalization
        return self.norm2(src + self.dropout2(projected))


def _get_activation_fn(activation):
    """Return the functional activation matching ``activation``.

    Arguments
    ---------
    activation : str
        Either "relu" or "gelu".

    Returns
    -------
    callable
        ``F.relu`` or ``F.gelu``.

    Raises
    ------
    ValueError
        If ``activation`` is not a supported name. Failing here avoids
        handing back ``None``, which would only break at the first call.
    """
    if activation == "relu":
        return F.relu
    if activation == "gelu":
        return F.gelu
    raise ValueError(
        "activation should be relu/gelu, not {}".format(activation)
    )


class Dual_Computation_Block(nn.Module):
    """Computation block for dual-path processing.

    The input is first processed along the within-chunk axis by the intra
    model and then along the across-chunk axis by the inter model.

    Arguments
    ---------
    intra_mdl : torch.nn.module
        Model to process within the chunks.
    inter_mdl : torch.nn.module
        Model to process across the chunks.
    out_channels : int
        Dimensionality of inter/intra model.
    norm : str
        Normalization type (see ``select_norm``); None disables it.
    skip_around_intra : bool
        Skip connection around the intra layer.
    linear_layer_after_inter_intra : bool
        Linear layer or not after inter or intra.

    Example
    -------
    >>> intra_block = SBTransformerBlock(1, 64, 8)
    >>> inter_block = SBTransformerBlock(1, 64, 8)
    >>> dual_comp_block = Dual_Computation_Block(intra_block, inter_block, 64)
    >>> x = torch.randn(10, 64, 100, 10)
    >>> x = dual_comp_block(x)
    >>> x.shape
    torch.Size([10, 64, 100, 10])
    """

    def __init__(
        self,
        intra_mdl,
        inter_mdl,
        out_channels,
        norm="ln",
        skip_around_intra=True,
        linear_layer_after_inter_intra=True,
    ):
        super().__init__()

        self.intra_mdl = intra_mdl
        self.inter_mdl = inter_mdl
        self.skip_around_intra = skip_around_intra
        self.linear_layer_after_inter_intra = linear_layer_after_inter_intra

        # Norm
        self.norm = norm
        if norm is not None:
            self.intra_norm = select_norm(norm, out_channels, 4)
            self.inter_norm = select_norm(norm, out_channels, 4)

        # Linear: each projection is sized from the model it follows
        if linear_layer_after_inter_intra:
            self.intra_linear = self._make_projection(intra_mdl, out_channels)
            self.inter_linear = self._make_projection(inter_mdl, out_channels)

    @staticmethod
    def _make_projection(mdl, out_channels):
        """Build the linear layer mapping the output of ``mdl`` to ``out_channels``."""
        if isinstance(mdl, SBRNNBlock):
            rnn = mdl.mdl.rnn
            # A bidirectional RNN concatenates both directions. When the
            # flag is not exposed, assume bidirectional.
            num_directions = 2 if getattr(rnn, "bidirectional", True) else 1
            return Linear(
                out_channels, input_size=num_directions * rnn.hidden_size
            )
        return Linear(out_channels, input_size=out_channels)

    def forward(self, x):
        """Returns the output tensor.

        Arguments
        ---------
        x : torch.Tensor
            Input tensor of dimension [B, N, K, S].

        Returns
        -------
        out: torch.Tensor
            Output tensor of dimension [B, N, K, S].
            where, B = Batchsize,
               N = number of filters
               K = time points in each chunk
               S = the number of chunks
        """
        B, N, K, S = x.shape

        # ---- intra-chunk processing ----
        # [B, N, K, S] -> [BS, K, N]: every chunk becomes one sequence
        intra = x.permute(0, 3, 2, 1).contiguous().view(B * S, K, N)
        # [BS, K, H]
        intra = self.intra_mdl(intra)

        # [BS, K, N]
        if self.linear_layer_after_inter_intra:
            intra = self.intra_linear(intra)

        # [B, S, K, N]
        intra = intra.view(B, S, K, N)
        # [B, N, K, S]
        intra = intra.permute(0, 3, 2, 1).contiguous()
        if self.norm is not None:
            intra = self.intra_norm(intra)

        # [B, N, K, S]
        if self.skip_around_intra:
            intra = intra + x

        # ---- inter-chunk processing ----
        # [B, N, K, S] -> [BK, S, N]: same position of every chunk is one sequence
        inter = intra.permute(0, 2, 3, 1).contiguous().view(B * K, S, N)
        # [BK, S, H]
        inter = self.inter_mdl(inter)

        # [BK, S, N]
        if self.linear_layer_after_inter_intra:
            inter = self.inter_linear(inter)

        # [B, K, S, N]
        inter = inter.view(B, K, S, N)
        # [B, N, K, S]
        inter = inter.permute(0, 3, 1, 2).contiguous()
        if self.norm is not None:
            inter = self.inter_norm(inter)

        # [B, N, K, S]: residual connection around the inter model
        out = inter + intra

        return out


class Dual_Path_Model(nn.Module):
    """The dual path model which is the basis for dualpathrnn, sepformer, dptnet.

    Arguments
    ---------
    in_channels : int
        Number of channels at the output of the encoder.
    out_channels : int
        Number of channels that would be inputted to the intra and inter blocks.
    intra_model : torch.nn.module
        Model to process within the chunks.
    inter_model : torch.nn.module
        model to process across the chunks,
    num_layers : int
        Number of layers of Dual Computation Block.
    norm : str
        Normalization type (see ``select_norm``).
    K : int
        Chunk length.
    num_spks : int
        Number of sources (speakers).
    skip_around_intra : bool
        Skip connection around intra.
    linear_layer_after_inter_intra : bool
        Linear layer after inter and intra.
    use_global_pos_enc : bool
        Global positional encodings.
    max_length : int
        Maximum sequence length (used for the global positional encoding).

    Example
    -------
    >>> intra_block = SBTransformerBlock(1, 64, 8)
    >>> inter_block = SBTransformerBlock(1, 64, 8)
    >>> dual_path_model = Dual_Path_Model(64, 64, intra_block, inter_block, num_spks=2)
    >>> x = torch.randn(10, 64, 2000)
    >>> x = dual_path_model(x)
    >>> x.shape
    torch.Size([2, 10, 64, 2000])
    """

    def __init__(
        self,
        in_channels,
        out_channels,
        intra_model,
        inter_model,
        num_layers=1,
        norm="ln",
        K=200,
        num_spks=2,
        skip_around_intra=True,
        linear_layer_after_inter_intra=True,
        use_global_pos_enc=False,
        max_length=20000,
    ):
        super().__init__()
        self.K = K
        self.num_spks = num_spks
        self.num_layers = num_layers
        self.norm = select_norm(norm, in_channels, 3)
        self.conv1d = nn.Conv1d(in_channels, out_channels, 1, bias=False)
        self.use_global_pos_enc = use_global_pos_enc

        if self.use_global_pos_enc:
            # The encoding is added to a [B, L, out_channels] tensor, so its
            # feature size must be out_channels; max_length bounds the
            # number of positions.
            self.pos_enc = PositionalEncoding(out_channels, max_length)

        self.dual_mdl = nn.ModuleList([])
        for _ in range(num_layers):
            # deepcopy gives every layer its own copy of the intra/inter
            # models instead of sharing the instances passed in.
            self.dual_mdl.append(
                copy.deepcopy(
                    Dual_Computation_Block(
                        intra_model,
                        inter_model,
                        out_channels,
                        norm,
                        skip_around_intra=skip_around_intra,
                        linear_layer_after_inter_intra=linear_layer_after_inter_intra,
                    )
                )
            )

        self.conv2d = nn.Conv2d(
            out_channels, out_channels * num_spks, kernel_size=1
        )
        self.end_conv1x1 = nn.Conv1d(out_channels, in_channels, 1, bias=False)
        self.prelu = nn.PReLU()
        self.activation = nn.ReLU()
        # gated output layer
        self.output = nn.Sequential(
            nn.Conv1d(out_channels, out_channels, 1), nn.Tanh()
        )
        self.output_gate = nn.Sequential(
            nn.Conv1d(out_channels, out_channels, 1), nn.Sigmoid()
        )

    def forward(self, x):
        """Returns the output tensor.

        Arguments
        ---------
        x : torch.Tensor
            Input tensor of dimension [B, N, L].

        Returns
        -------
        out : torch.Tensor
            Output tensor of dimension [spks, B, N, L]
            where, spks = Number of speakers
               B = Batchsize,
               N = number of filters
               L = the number of time points
        """

        # before each line we indicate the shape after executing the line

        # [B, N, L]
        x = self.norm(x)

        # [B, N, L]
        x = self.conv1d(x)
        if self.use_global_pos_enc:
            # The encoder works on [B, L, N]; the features are scaled by
            # sqrt(N) before the encoding is added.
            x = self.pos_enc(x.transpose(1, -1)).transpose(1, -1) + x * (
                x.size(1) ** 0.5
            )

        # [B, N, K, S]
        x, gap = self._Segmentation(x, self.K)

        # [B, N, K, S]
        for layer in self.dual_mdl:
            x = layer(x)
        x = self.prelu(x)

        # [B, N*spks, K, S]
        x = self.conv2d(x)
        B, _, K, S = x.shape

        # [B*spks, N, K, S]
        x = x.view(B * self.num_spks, -1, K, S)

        # [B*spks, N, L]
        x = self._over_add(x, gap)
        x = self.output(x) * self.output_gate(x)

        # [B*spks, N, L]
        x = self.end_conv1x1(x)

        # [B, spks, N, L]
        _, N, L = x.shape
        x = x.view(B, self.num_spks, N, L)
        x = self.activation(x)

        # [spks, B, N, L]
        x = x.transpose(0, 1)

        return x

    def _padding(self, input, K):
        """Padding the audio times.

        Zeros are appended so that the sequence can be cut into half
        overlapping chunks of length K, and K // 2 zeros are added on both
        ends.

        Arguments
        ---------
        input : torch.Tensor
            Tensor of size [B, N, L].
            where, B = Batchsize,
                   N = number of filters
                   L = time points
        K : int
            Chunks of length.

        Returns
        -------
        output : torch.Tensor
            Padded inputs
        gap : int
            Size of padding appended to the end (excluding the K // 2
            border zeros).
        """
        _, _, L = input.shape
        P = K // 2
        # gap always lies in [1, K], so padding is always appended
        gap = K - (P + L % K) % K

        # P zeros in front, gap + P zeros at the end, in a single operation
        # that keeps the dtype and device of the input.
        input = F.pad(input, (P, gap + P))

        return input, gap

    def _Segmentation(self, input, K):
        """The segmentation stage splits the sequence into half-overlapping chunks.

        Arguments
        ---------
        input : torch.Tensor
            Tensor with dim [B, N, L].
        K : int
            Length of the chunks.

        Return
        ------
        output : torch.Tensor
            Tensor with dim [B, N, K, S].
            where, B = Batchsize,
               N = number of filters
               K = time points in each chunk
               S = the number of chunks
               L = the number of time points
        gap : int
            Size of padding
        """
        B, N, _ = input.shape
        P = K // 2
        input, gap = self._padding(input, K)
        # Two chunkings of the padded signal, offset by half a chunk
        # [B, N, S/2, K] each
        input1 = input[:, :, :-P].contiguous().view(B, N, -1, K)
        input2 = input[:, :, P:].contiguous().view(B, N, -1, K)
        # Interleave them so consecutive chunks overlap by P samples
        # [B, N, K, S]
        input = (
            torch.cat([input1, input2], dim=3).view(B, N, -1, K).transpose(2, 3)
        )

        return input.contiguous(), gap

    def _over_add(self, input, gap):
        """Merge the sequence with the overlap-and-add method.

        Arguments
        ---------
        input : torch.Tensor
            Tensor with dim [B, N, K, S].
        gap : int
            Padding length.

        Return
        ------
        output : torch.Tensor
            Tensor with dim [B, N, L].
            where, B = Batchsize,
               N = number of filters
               K = time points in each chunk
               S = the number of chunks
               L = the number of time points
        """
        B, N, K, _ = input.shape
        P = K // 2
        # [B, N, S/2, 2K]: each row holds one chunk from either chunking
        input = input.transpose(2, 3).contiguous().view(B, N, -1, K * 2)

        # Undo the interleaving and drop the P border zeros so that both
        # streams are aligned on the same time axis.
        input1 = input[:, :, :, :K].contiguous().view(B, N, -1)[:, :, P:]
        input2 = input[:, :, :, K:].contiguous().view(B, N, -1)[:, :, :-P]
        input = input1 + input2
        # [B, N, L]
        if gap > 0:
            input = input[:, :, :-gap]

        return input


class SepformerWrapper(nn.Module):
    """The wrapper for the sepformer model which combines the Encoder, Masknet and the decoder
    https://arxiv.org/abs/2010.13154

    Arguments
    ---------
    encoder_kernel_size: int
        The kernel size used in the encoder
    encoder_in_nchannels: int
        The number of channels of the input audio
    encoder_out_nchannels: int
        The number of filters used in the encoder.
        Also, number of channels that would be inputted to the intra and inter blocks.
    masknet_chunksize: int
        The chunk length that is to be processed by the intra blocks
    masknet_numlayers: int
        The number of layers of combination of inter and intra blocks
    masknet_norm: str
        The normalization type to be used in the masknet
        Should be one of 'ln' -- layernorm, 'gln' -- globallayernorm
                         'cln' -- cumulative layernorm, 'bn' -- batchnorm
                         -- see the select_norm function above for more details
    masknet_useextralinearlayer: bool
        Whether or not to use a linear layer at the output of intra and inter blocks
    masknet_extraskipconnection: bool
        This introduces extra skip connections around the intra block
    masknet_numspks: int
        This determines the number of speakers to estimate
    intra_numlayers: int
        This determines the number of layers in the intra block
    inter_numlayers: int
        This determines the number of layers in the inter block
    intra_nhead: int
        This determines the number of parallel attention heads in the intra block
    inter_nhead: int
        This determines the number of parallel attention heads in the inter block
    intra_dffn: int
        The number of dimensions in the positional feedforward model in the intra block
    inter_dffn: int
        The number of dimensions in the positional feedforward model in the inter block
    intra_use_positional: bool
        Whether or not to use positional encodings in the intra block
    inter_use_positional: bool
        Whether or not to use positional encodings in the inter block
    intra_norm_before: bool
        Whether or not we use normalization before the transformations in the intra block
    inter_norm_before: bool
        Whether or not we use normalization before the transformations in the inter block

    Example
    -------
    >>> model = SepformerWrapper()
    >>> inp = torch.rand(1, 160)
    >>> result = model.forward(inp)
    >>> result.shape
    torch.Size([1, 160, 2])
    """

    def __init__(
        self,
        encoder_kernel_size=16,
        encoder_in_nchannels=1,
        encoder_out_nchannels=256,
        masknet_chunksize=250,
        masknet_numlayers=2,
        masknet_norm="ln",
        masknet_useextralinearlayer=False,
        masknet_extraskipconnection=True,
        masknet_numspks=2,
        intra_numlayers=8,
        inter_numlayers=8,
        intra_nhead=8,
        inter_nhead=8,
        intra_dffn=1024,
        inter_dffn=1024,
        intra_use_positional=True,
        inter_use_positional=True,
        intra_norm_before=True,
        inter_norm_before=True,
    ):
        super().__init__()
        self.encoder = Encoder(
            kernel_size=encoder_kernel_size,
            out_channels=encoder_out_nchannels,
            in_channels=encoder_in_nchannels,
        )
        intra_model = SBTransformerBlock(
            num_layers=intra_numlayers,
            d_model=encoder_out_nchannels,
            nhead=intra_nhead,
            d_ffn=intra_dffn,
            use_positional_encoding=intra_use_positional,
            norm_before=intra_norm_before,
        )

        inter_model = SBTransformerBlock(
            num_layers=inter_numlayers,
            d_model=encoder_out_nchannels,
            nhead=inter_nhead,
            d_ffn=inter_dffn,
            use_positional_encoding=inter_use_positional,
            norm_before=inter_norm_before,
        )

        self.masknet = Dual_Path_Model(
            in_channels=encoder_out_nchannels,
            out_channels=encoder_out_nchannels,
            intra_model=intra_model,
            inter_model=inter_model,
            num_layers=masknet_numlayers,
            norm=masknet_norm,
            K=masknet_chunksize,
            num_spks=masknet_numspks,
            skip_around_intra=masknet_extraskipconnection,
            linear_layer_after_inter_intra=masknet_useextralinearlayer,
        )
        self.decoder = Decoder(
            in_channels=encoder_out_nchannels,
            out_channels=encoder_in_nchannels,
            kernel_size=encoder_kernel_size,
            stride=encoder_kernel_size // 2,
            bias=False,
        )
        self.num_spks = masknet_numspks

        # reinitialize the parameters
        for module in [self.encoder, self.masknet, self.decoder]:
            self.reset_layer_recursively(module)

    def reset_layer_recursively(self, layer):
        """Reinitializes the parameters of ``layer`` and all of its submodules.

        Every module in the tree rooted at ``layer`` (the root included) that
        exposes a ``reset_parameters`` method has that method called exactly
        once.

        Arguments
        ---------
        layer : torch.nn.Module
            Root of the module tree to reinitialize.
        """
        # `modules()` already walks the complete subtree, so one flat pass is
        # enough. Recursing on top of it would visit a module once per
        # ancestor path and reset deeply nested layers many times over.
        for module in layer.modules():
            if hasattr(module, "reset_parameters"):
                module.reset_parameters()

    def forward(self, mix):
        """Separates the mixture into ``num_spks`` estimated sources.

        Arguments
        ---------
        mix : torch.Tensor
            The input mixture, with time on dimension 1
            (e.g. [B, T] as in the class example).

        Returns
        -------
        est_source : torch.Tensor
            The estimated sources, with time on dimension 1 matching the
            length of ``mix`` and one entry per speaker on the last dimension
            (e.g. [B, T, num_spks] as in the class example).
        """
        mix_w = self.encoder(mix)
        est_mask = self.masknet(mix_w)
        # Replicate the encoded mixture once per speaker so that it can be
        # multiplied with the per-speaker masks (speaker on dimension 0).
        mix_w = torch.stack([mix_w] * self.num_spks)
        sep_h = mix_w * est_mask

        # Decoding: each masked representation is decoded separately and the
        # results are joined along a new trailing speaker dimension.
        est_source = torch.cat(
            [
                self.decoder(sep_h[i]).unsqueeze(-1)
                for i in range(self.num_spks)
            ],
            dim=-1,
        )

        # T changed after conv1d in encoder, fix it here: zero-pad the end of
        # the time axis when the estimate is too short, crop it otherwise.
        T_origin = mix.size(1)
        T_est = est_source.size(1)
        if T_origin > T_est:
            est_source = F.pad(est_source, (0, 0, 0, T_origin - T_est))
        else:
            est_source = est_source[:, :T_origin, :]

        return est_source


class SBConformerEncoderBlock(nn.Module):
    """A wrapper for the SpeechBrain implementation of the ConformerEncoder.

    Arguments
    ---------
    num_layers : int
        Number of layers.
    d_model : int
        Dimensionality of the representation.
    nhead : int
        Number of attention heads.
    d_ffn : int
        Dimensionality of positional feed forward.
    input_shape : tuple
        Shape of input. Accepted but not used by this wrapper.
    kdim : int
        Dimension of the key (Optional).
    vdim : int
        Dimension of the value (Optional).
    dropout : float
        Dropout rate.
    activation : str
        Activation function, one of "relu", "gelu" or "swish".
    kernel_size: int
        Kernel size in the conformer encoder
    bias: bool
        Use bias or not in the convolution part of conformer encoder
    use_positional_encoding : bool
        If true we use a positional encoding. Only consulted when
        attention_type is "regularMHA"; "RelPosMHAXL" always uses one.
    attention_type : str
        The type of attention to use, "RelPosMHAXL" (default) or "regularMHA".

    Example
    -------
    >>> x = torch.randn(10, 100, 64)
    >>> block = SBConformerEncoderBlock(1, 64, 8)
    >>> x = block(x)
    >>> x.shape
    torch.Size([10, 100, 64])
    """

    def __init__(
        self,
        num_layers,
        d_model,
        nhead,
        d_ffn=2048,
        input_shape=None,
        kdim=None,
        vdim=None,
        dropout=0.1,
        activation="swish",
        kernel_size=31,
        bias=True,
        use_positional_encoding=True,
        attention_type="RelPosMHAXL",
    ):
        super().__init__()
        self.use_positional_encoding = use_positional_encoding
        self.attention_type = attention_type

        # Validate the name first, then translate it to the activation class.
        if activation not in ("relu", "gelu", "swish"):
            raise ValueError("unknown activation")
        activation_classes = {"relu": nn.ReLU, "gelu": nn.GELU, "swish": Swish}
        activation = activation_classes[activation]

        self.mdl = ConformerEncoder(
            num_layers=num_layers,
            nhead=nhead,
            d_ffn=d_ffn,
            d_model=d_model,
            kdim=kdim,
            vdim=vdim,
            dropout=dropout,
            activation=activation,
            kernel_size=kernel_size,
            bias=bias,
            attention_type=attention_type,
        )

        if self.attention_type not in ("RelPosMHAXL", "regularMHA"):
            raise ValueError("Unsupported attention type")

        # for RelPosMHAXL, we need the positional encoding (not optional);
        # for regularMHA it is only created on request.
        needs_pos_enc = (
            self.attention_type == "RelPosMHAXL"
            or self.use_positional_encoding
        )
        if needs_pos_enc:
            self.pos_enc = PositionalEncoding(input_size=d_model)

    def forward(self, x):
        """Returns the transformed output.

        Arguments
        ---------
        x : torch.Tensor
            Tensor shape [B, L, N],
            where, B = Batchsize,
                   L = time points
                   N = number of filters

        Returns
        -------
        Transformed output, i.e. the first element of what the wrapped
        encoder returns.
        """
        if self.attention_type == "RelPosMHAXL":
            # The positional encoding is queried with a dummy tensor holding
            # 2 * L - 1 time points and handed to the encoder separately.
            dummy = torch.ones(
                x.shape[0], 2 * x.shape[1] - 1, x.shape[2], device=x.device
            )
            return self.mdl(x, pos_embs=self.pos_enc(dummy))[0]

        if self.attention_type != "regularMHA":
            raise ValueError("Unsupported attention type")

        # regularMHA: the positional encoding, when enabled, is simply added
        # to the input before it enters the encoder.
        if self.use_positional_encoding:
            encoder_input = x + self.pos_enc(x)
        else:
            encoder_input = x
        return self.mdl(encoder_input)[0]
