# Copyright (c) 2020, NVIDIA CORPORATION.  All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Dict, List, Optional, Tuple

import torch

from nemo.collections.asr.modules.conformer_encoder import ConformerEncoder
from nemo.collections.asr.parts.preprocessing.features import make_seq_mask_like
from nemo.collections.audio.modules.features import SpectrogramToMultichannelFeatures
from nemo.collections.audio.parts.submodules.multichannel import (
    ChannelAttentionPool,
    ChannelAveragePool,
    ParametricMultichannelWienerFilter,
    TransformAttendConcatenate,
    TransformAverageConcatenate,
    WPEFilter,
)
from nemo.collections.audio.parts.utils.audio import db2mag
from nemo.core.classes import NeuralModule, typecheck
from nemo.core.neural_types import FloatType, LengthsType, NeuralType, SpectrogramType
from nemo.utils import logging


class MaskEstimatorRNN(NeuralModule):
    """Estimate `num_outputs` masks from the input spectrogram
    using stacked RNNs and projections.

    The module is structured as follows:
        input --> spatial features --> input projection -->
            --> stacked RNNs --> projection + layer norm + skip connection -->
            --> output projection for each output --> sigmoid

    Reference:
        Multi-microphone neural speech separation for far-field multi-talker
        speech recognition (https://ieeexplore.ieee.org/stamp/stamp.jsp?arnumber=8462081)

    Args:
        num_outputs: Number of output masks to estimate
        num_subbands: Number of subbands of the input spectrogram
        num_features: Number of features after the input projections
        num_layers: Number of RNN layers
        num_hidden_features: Number of hidden features in RNN layers. If `None`, `num_features` is used.
        num_input_channels: Number of input channels
        dropout: If non-zero, introduces dropout on the outputs of each RNN layer except the last layer, with dropout
                 probability equal to `dropout`. Default: 0
        bidirectional: If `True`, use bidirectional RNN.
        rnn_type: Type of RNN, either `lstm` or `gru`. Default: `lstm`
        mag_reduction: Channel-wise reduction for magnitude features
        use_ipd: Use inter-channel phase difference (IPD) features

    Raises:
        ValueError: If `rnn_type` is neither `lstm` nor `gru`.
    """

    def __init__(
        self,
        num_outputs: int,
        num_subbands: int,
        num_features: int = 1024,
        num_layers: int = 3,
        num_hidden_features: Optional[int] = None,
        num_input_channels: Optional[int] = None,
        dropout: float = 0,
        bidirectional: bool = True,
        rnn_type: str = 'lstm',
        mag_reduction: str = 'rms',
        use_ipd: Optional[bool] = None,
    ):
        super().__init__()
        if num_hidden_features is None:
            num_hidden_features = num_features

        self.features = SpectrogramToMultichannelFeatures(
            num_subbands=num_subbands,
            num_input_channels=num_input_channels,
            mag_reduction=mag_reduction,
            use_ipd=use_ipd,
        )

        # All feature channels are flattened into a single vector per frame
        self.input_projection = torch.nn.Linear(
            in_features=self.features.num_features * self.features.num_channels, out_features=num_features
        )

        if rnn_type == 'lstm':
            rnn_class = torch.nn.LSTM
        elif rnn_type == 'gru':
            rnn_class = torch.nn.GRU
        else:
            raise ValueError(f'Unknown rnn_type: {rnn_type}')

        self.rnn = rnn_class(
            input_size=num_features,
            hidden_size=num_hidden_features,
            num_layers=num_layers,
            batch_first=True,
            dropout=dropout,
            bidirectional=bidirectional,
        )

        # The RNN emits `num_hidden_features` per direction, so the projection back to
        # `num_features` must be sized from the hidden size (not from `num_features`).
        rnn_out_features = 2 * num_hidden_features if bidirectional else num_hidden_features
        self.fc = torch.nn.Linear(in_features=rnn_out_features, out_features=num_features)
        self.norm = torch.nn.LayerNorm(num_features)

        # Each output shares the RNN and has a separate projection
        self.output_projections = torch.nn.ModuleList(
            [torch.nn.Linear(in_features=num_features, out_features=num_subbands) for _ in range(num_outputs)]
        )
        self.output_nonlinearity = torch.nn.Sigmoid()

    @property
    def input_types(self) -> Dict[str, NeuralType]:
        """Returns definitions of module input ports."""
        return {
            "input": NeuralType(('B', 'C', 'D', 'T'), SpectrogramType()),
            "input_length": NeuralType(('B',), LengthsType()),
        }

    @property
    def output_types(self) -> Dict[str, NeuralType]:
        """Returns definitions of module output ports."""
        return {
            "output": NeuralType(('B', 'C', 'D', 'T'), FloatType()),
            "output_length": NeuralType(('B',), LengthsType()),
        }

    @typecheck()
    def forward(self, input: torch.Tensor, input_length: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """Estimate `num_outputs` masks from the input spectrogram.

        Args:
            input: C-channel input, shape (B, C, F, N)
            input_length: Length of valid entries along the time dimension, shape (B,)

        Returns:
            Returns `num_outputs` masks in a tensor, shape (B, num_outputs, F, N),
            and output length with shape (B,). Frames beyond the output length are set to zero.
        """
        input, _ = self.features(input=input, input_length=input_length)
        B, _, _, N = input.shape

        # (B, num_feat_channels, num_feat, N) -> (B, N, num_feat_channels, num_feat)
        input = input.permute(0, 3, 1, 2)

        # (B, N, num_feat_channels, num_feat) -> (B, N, num_feat_channels * num_features)
        # reshape (rather than view) also handles memory layouts that cannot be viewed
        input = input.reshape(B, N, -1)

        # Apply projection on num_feat
        input = self.input_projection(input)

        # Apply RNN on the input sequence
        input_packed = torch.nn.utils.rnn.pack_padded_sequence(
            input, input_length.cpu(), batch_first=True, enforce_sorted=False
        ).to(input.device)
        self.rnn.flatten_parameters()
        input_packed, _ = self.rnn(input_packed)
        # Pad back to the full N frames: without `total_length` the output is only as long as
        # the longest sequence in the batch, which breaks the skip connection below whenever
        # all sequences are shorter than N.
        output, output_length = torch.nn.utils.rnn.pad_packed_sequence(
            input_packed, batch_first=True, total_length=N
        )
        output_length = output_length.to(input.device)

        # Layer normalization and skip connection
        output = self.norm(self.fc(output)) + input

        # Create `num_outputs` masks: projection -> sigmoid -> (B, N, F) to (B, F, N)
        masks = [
            self.output_nonlinearity(output_projection(output)).transpose(2, 1)
            for output_projection in self.output_projections
        ]

        # Stack along channel dimension to get (B, M, F, N)
        masks = torch.stack(masks, dim=1)

        # Mask frames beyond output length
        length_mask: torch.Tensor = make_seq_mask_like(
            lengths=output_length, like=masks, time_dim=-1, valid_ones=False
        )
        masks = masks.masked_fill(length_mask, 0.0)

        return masks, output_length


class MaskEstimatorFlexChannels(NeuralModule):
    """Estimate `num_outputs` masks from the input spectrogram
    using stacked channel-wise and temporal layers.

    This model is using interleaved channel blocks and temporal blocks, and
    it can process arbitrary number of input channels.
    Default channel block is the transform-attend-concatenate layer.
    Default temporal block is the Conformer encoder.
    Reduction from multichannel signal to single-channel signal is performed
    after `channel_reduction_position` blocks. Only temporal blocks are used afterwards.
    After the sequence of blocks, the output mask is computed using an additional
    output temporal layer and a nonlinearity.

    References:
        - Yoshioka et al, VarArray: Array-Geometry-Agnostic Continuous Speech Separation, 2022
        - Jukić et al, Flexible multichannel speech enhancement for noise-robust frontend, 2023

    Args:
        num_outputs: Number of output masks.
        num_subbands: Number of subbands on the input spectrogram.
        num_blocks: Number of blocks in the model.
        channel_reduction_position: The signal is reduced across channels before this block.
            If 0, reduction is applied before block 0. If -1, reduction is applied after the last block.
        channel_reduction_type: Reduction across channels: 'average' or 'attention'
        channel_block_type: Block for channel processing: 'transform_average_concatenate' or 'transform_attend_concatenate'
        temporal_block_type: Block for temporal processing: 'conformer_encoder'
        temporal_block_num_layers: Number of layers for the temporal block
        temporal_block_num_heads: Number of heads for the temporal block
        temporal_block_dimension: The hidden size of the model
        temporal_block_self_attention_model: Self attention model for the temporal block
        temporal_block_att_context_size: Attention context size for the temporal block
        num_input_channels: Optional, number of input channels, forwarded to the feature extractor
        mag_reduction: Channel-wise reduction for magnitude features
        mag_power: Power to apply on magnitude features
        use_ipd: Use inter-channel phase difference (IPD) features
        mag_normalization: Normalize using mean ('mean') or mean and variance ('mean_var')
        ipd_normalization: Normalize using mean ('mean') or mean and variance ('mean_var')

    Raises:
        ValueError: If `channel_reduction_position` is out of range, or if the channel block type,
            temporal block type or channel reduction type is unknown.
    """

    def __init__(
        self,
        num_outputs: int,
        num_subbands: int,
        num_blocks: int,
        channel_reduction_position: int = -1,  # if 0, apply before block 0, if -1 apply at the end
        channel_reduction_type: str = 'attention',
        channel_block_type: str = 'transform_attend_concatenate',
        temporal_block_type: str = 'conformer_encoder',
        temporal_block_num_layers: int = 5,
        temporal_block_num_heads: int = 4,
        temporal_block_dimension: int = 128,
        temporal_block_self_attention_model: str = 'rel_pos',
        temporal_block_att_context_size: Optional[List[int]] = None,
        num_input_channels: Optional[int] = None,
        mag_reduction: str = 'abs_mean',
        mag_power: Optional[float] = None,
        use_ipd: bool = True,
        mag_normalization: Optional[str] = None,
        ipd_normalization: Optional[str] = None,
    ):
        super().__init__()

        self.features = SpectrogramToMultichannelFeatures(
            num_subbands=num_subbands,
            num_input_channels=num_input_channels,
            mag_reduction=mag_reduction,
            mag_power=mag_power,
            use_ipd=use_ipd,
            mag_normalization=mag_normalization,
            ipd_normalization=ipd_normalization,
        )
        self.num_blocks = num_blocks
        logging.debug('Total number of blocks: %d', self.num_blocks)

        # Channel reduction
        if channel_reduction_position < -1:
            # Any other negative position would never trigger the reduction in `forward`
            raise ValueError(
                f'Channel reduction position must be -1 or non-negative, got {channel_reduction_position}'
            )

        if channel_reduction_position == -1:
            # Apply reduction after the last layer
            channel_reduction_position = num_blocks

        if channel_reduction_position > num_blocks:
            raise ValueError(
                f'Channel reduction position {channel_reduction_position} exceeds the number of blocks {num_blocks}'
            )
        self.channel_reduction_position = channel_reduction_position
        logging.debug('Channel reduction will be applied before block %d', self.channel_reduction_position)

        # Prepare processing blocks
        self.channel_blocks = torch.nn.ModuleList()
        self.temporal_blocks = torch.nn.ModuleList()

        for n in range(num_blocks):
            logging.debug('Prepare block %d', n)

            # Setup channel block (only used before the channel reduction)
            if n < channel_reduction_position:
                # Number of input features is either the number of input features or the number of temporal block features
                channel_in_features = self.features.num_features if n == 0 else temporal_block_dimension
                logging.debug(
                    'Setup channel block %s with %d input features and %d output features',
                    channel_block_type,
                    channel_in_features,
                    temporal_block_dimension,
                )

                # Instantiate the channel block
                if channel_block_type == 'transform_average_concatenate':
                    channel_block = TransformAverageConcatenate(
                        in_features=channel_in_features, out_features=temporal_block_dimension
                    )
                elif channel_block_type == 'transform_attend_concatenate':
                    channel_block = TransformAttendConcatenate(
                        in_features=channel_in_features, out_features=temporal_block_dimension
                    )
                else:
                    raise ValueError(f'Unknown channel layer type: {channel_block_type}')
                self.channel_blocks.append(channel_block)

            # Setup temporal block
            # The first temporal block sees raw features only if the reduction happens before block 0,
            # i.e., when there is no channel block in front of it.
            temporal_in_features = (
                self.features.num_features if n == self.channel_reduction_position == 0 else temporal_block_dimension
            )
            logging.debug('Setup temporal block %s', temporal_block_type)
            if temporal_block_type == 'conformer_encoder':
                temporal_block = ConformerEncoder(
                    feat_in=temporal_in_features,
                    n_layers=temporal_block_num_layers,
                    d_model=temporal_block_dimension,
                    subsampling_factor=1,
                    self_attention_model=temporal_block_self_attention_model,
                    att_context_size=temporal_block_att_context_size,
                    n_heads=temporal_block_num_heads,
                )
            else:
                # Report the requested type; `temporal_block` is not bound on this path
                raise ValueError(f'Unknown temporal block {temporal_block_type}.')

            self.temporal_blocks.append(temporal_block)

        logging.debug('Setup channel reduction %s', channel_reduction_type)
        if channel_reduction_type == 'average':
            # Mean across channel dimension
            self.channel_reduction = ChannelAveragePool()
        elif channel_reduction_type == 'attention':
            # Number of input features is either the number of input features or the number of temporal block features
            channel_reduction_in_features = (
                self.features.num_features if self.channel_reduction_position == 0 else temporal_block_dimension
            )
            # Attention across channel dimension
            self.channel_reduction = ChannelAttentionPool(in_features=channel_reduction_in_features)
        else:
            raise ValueError(f'Unknown channel reduction type: {channel_reduction_type}')

        logging.debug('Setup %d output layers', num_outputs)
        self.output_layers = torch.nn.ModuleList(
            [
                ConformerEncoder(
                    feat_in=temporal_block_dimension,
                    n_layers=1,
                    d_model=temporal_block_dimension,
                    feat_out=num_subbands,
                    subsampling_factor=1,
                    self_attention_model=temporal_block_self_attention_model,
                    att_context_size=temporal_block_att_context_size,
                    n_heads=temporal_block_num_heads,
                )
                for _ in range(num_outputs)
            ]
        )

        # Output nonlinearity
        self.output_nonlinearity = torch.nn.Sigmoid()

    @property
    def input_types(self) -> Dict[str, NeuralType]:
        """Returns definitions of module input ports."""
        return {
            "input": NeuralType(('B', 'C', 'D', 'T'), SpectrogramType()),
            "input_length": NeuralType(('B',), LengthsType()),
        }

    @property
    def output_types(self) -> Dict[str, NeuralType]:
        """Returns definitions of module output ports."""
        return {
            "output": NeuralType(('B', 'C', 'D', 'T'), FloatType()),
            "output_length": NeuralType(('B',), LengthsType()),
        }

    @typecheck()
    def forward(self, input: torch.Tensor, input_length: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """Estimate `num_outputs` masks from the input spectrogram.

        Args:
            input: Multichannel input spectrogram, shape (B, C, F, T)
            input_length: Length of valid entries along the time dimension, shape (B,)

        Returns:
            Masks stacked along the channel dimension, shape (B, num_outputs, F, T),
            and the length reported by the last output layer, shape (B,)
        """
        # get input features from a complex-valued spectrogram, (B, C, F, T)
        output, output_length = self.features(input=input, input_length=input_length)

        # batch and num channels
        # Taken from the feature tensor, since this is the tensor folded to/from the
        # single-stream layout below; identical to the input whenever the feature
        # extractor preserves the channel count.
        B, M = output.size(0), output.size(1)

        # process all blocks
        for n in range(self.num_blocks):
            if n < self.channel_reduction_position:
                # apply multichannel block
                output = self.channel_blocks[n](input=output)
                # change to a single-stream format
                F, T = output.size(-2), output.size(-1)
                # (B, M, F, T) -> (B * M, F, T)
                output = output.reshape(-1, F, T)
                if M > 1:
                    # adjust the lengths accordingly
                    output_length = output_length.repeat_interleave(M)

            elif n == self.channel_reduction_position:
                # apply channel reduction
                # (B, M, F, T) -> (B, F, T)
                output = self.channel_reduction(input=output)

            # apply temporal model on each channel independently
            with typecheck.disable_checks():
                # output is AcousticEncodedRepresentation, conformer encoder requires SpectrogramType
                output, output_length = self.temporal_blocks[n](audio_signal=output, length=output_length)

            # if channel reduction has not been applied yet, go back to multichannel layout
            if n < self.channel_reduction_position:
                # back to multi-channel format with possibly a different number of features
                T = output.size(-1)
                # (B * M, F, T) -> (B, M, F, T)
                output = output.reshape(B, M, -1, T)
                if M > 1:
                    # convert lengths from single-stream format to original multichannel
                    # (lengths were repeated M times per example, keep one per example)
                    output_length = output_length[::M]

        if self.channel_reduction_position == self.num_blocks:
            # apply channel reduction after the last layer
            # (B, M, F, T) -> (B, F, T)
            output = self.channel_reduction(input=output)

        # final mask for each output
        masks = []
        for output_layer in self.output_layers:
            # calculate mask
            with typecheck.disable_checks():
                # output is AcousticEncodedRepresentation, conformer encoder requires SpectrogramType
                mask, mask_length = output_layer(audio_signal=output, length=output_length)
            mask = self.output_nonlinearity(mask)
            # append to all masks
            masks.append(mask)

        # stack masks along channel dimensions
        masks = torch.stack(masks, dim=1)

        return masks, mask_length


class MaskEstimatorGSS(NeuralModule):
    """Estimate masks using guided source separation with a complex
    angular Central Gaussian Mixture Model (cACGMM) [1].

    This module corresponds to `GSS` in Fig. 2 in [2].

    Notation is approximately following [1], where `gamma` denotes
    the time-frequency mask, `alpha` denotes the mixture weights,
    and `BM` denotes the shape matrix. Additionally, the provided
    source activity is denoted as `activity`.

    Args:
        num_iterations: Number of iterations for the EM algorithm
        eps: Small value for regularization
        dtype: Data type for internal computations (default `torch.cdouble`)

    Raises:
        ValueError: If `num_iterations` or `eps` is not positive, or if `dtype` is not
            `torch.cfloat` or `torch.cdouble`.

    References:
        [1] Ito et al., Complex Angular Central Gaussian Mixture Model for Directional Statistics in Mask-Based Microphone Array Signal Processing, 2016
        [2] Boeddeker et al., Front-End Processing for the CHiME-5 Dinner Party Scenario, 2018
    """

    def __init__(self, num_iterations: int = 3, eps: float = 1e-8, dtype: torch.dtype = torch.cdouble):
        super().__init__()

        if num_iterations <= 0:
            raise ValueError(f'Number of iterations must be positive, got {num_iterations}')

        # number of iterations for the EM algorithm
        self.num_iterations = num_iterations

        if eps <= 0:
            raise ValueError(f'eps must be positive, got {eps}')

        # small regularization constant
        self.eps = eps

        # internal calculations
        if dtype not in [torch.cfloat, torch.cdouble]:
            raise ValueError(f'Unsupported dtype {dtype}, expecting cfloat or cdouble')
        self.dtype = dtype

        logging.debug('Initialized %s', self.__class__.__name__)
        logging.debug('\tnum_iterations: %s', self.num_iterations)
        logging.debug('\teps:            %g', self.eps)
        logging.debug('\tdtype:          %s', self.dtype)

    def normalize(self, x: torch.Tensor, dim: int = 1) -> torch.Tensor:
        """Normalize input to have a unit L2-norm across `dim`.
        By default, normalizes across the input channels.

        Args:
            x: C-channel input signal, shape (B, C, F, T)
            dim: Dimension for normalization, defaults to 1, which is the channel
                 dimension of a (B, C, F, T) input

        Returns:
            Normalized signal, shape (B, C, F, T)
        """
        norm_x = torch.linalg.vector_norm(x, ord=2, dim=dim, keepdim=True)
        # eps avoids division by zero for all-zero vectors
        x = x / (norm_x + self.eps)
        return x

    @typecheck(
        input_types={
            'alpha': NeuralType(('B', 'C', 'D')),
            'activity': NeuralType(('B', 'C', 'T')),
            'log_pdf': NeuralType(('B', 'C', 'D', 'T')),
        },
        output_types={
            'gamma': NeuralType(('B', 'C', 'D', 'T')),
        },
    )
    def update_masks(self, alpha: torch.Tensor, activity: torch.Tensor, log_pdf: torch.Tensor) -> torch.Tensor:
        """Update masks for the cACGMM.

        Args:
            alpha: component weights, shape (B, num_outputs, F)
            activity: temporal activity for the components, shape (B, num_outputs, T)
            log_pdf: logarithm of the PDF, shape (B, num_outputs, F, T)

        Returns:
            Masks for the components of the model, shape (B, num_outputs, F, T)
        """
        # (B, num_outputs, F, T)
        # subtract the maximum across outputs in the log domain, so the largest
        # term becomes exp(0) = 1 before exponentiation
        log_gamma = log_pdf - torch.max(log_pdf, dim=-3, keepdim=True)[0]

        gamma = torch.exp(log_gamma)

        # calculate the mask using weight, pdf and source activity
        gamma = alpha[..., None] * gamma * activity[..., None, :]

        # normalize across components/output channels
        gamma = gamma / (torch.sum(gamma, dim=-3, keepdim=True) + self.eps)

        return gamma

    @typecheck(
        input_types={
            'gamma': NeuralType(('B', 'C', 'D', 'T')),
        },
        output_types={
            'alpha': NeuralType(('B', 'C', 'D')),
        },
    )
    def update_weights(self, gamma: torch.Tensor) -> torch.Tensor:
        """Update weights for the individual components
        in the mixture model.

        Args:
            gamma: masks, shape (B, num_outputs, F, T)

        Returns:
            Component weights, shape (B, num_outputs, F)
        """
        # average the masks over time
        alpha = torch.mean(gamma, dim=-1)
        return alpha

    @typecheck(
        input_types={
            'z': NeuralType(('B', 'C', 'D', 'T')),
            'gamma': NeuralType(('B', 'C', 'D', 'T')),
            'zH_invBM_z': NeuralType(('B', 'C', 'D', 'T')),
        },
        output_types={
            'log_pdf': NeuralType(('B', 'C', 'D', 'T')),
            'zH_invBM_z': NeuralType(('B', 'C', 'D', 'T')),
        },
    )
    def update_pdf(
        self, z: torch.Tensor, gamma: torch.Tensor, zH_invBM_z: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Update PDF of the cACGMM.

        Args:
            z: directional statistics, shape (B, num_inputs, F, T)
            gamma: masks, shape (B, num_outputs, F, T)
            zH_invBM_z: energy weighted by shape matrices, shape (B, num_outputs, F, T)

        Returns:
            Logarithm of the PDF, shape (B, num_outputs, F, T), the energy term, shape (B, num_outputs, F, T)
        """
        num_inputs = z.size(-3)

        # shape (B, num_outputs, F, T)
        scale = gamma / (zH_invBM_z + self.eps)

        # scale outer product and sum over time
        # shape (B, num_outputs, F, num_inputs, num_inputs)
        BM = num_inputs * torch.einsum('bmft,bift,bjft->bmfij', scale.to(z.dtype), z, z.conj())

        # normalize across time
        denom = torch.sum(gamma, dim=-1)
        BM = BM / (denom[..., None, None] + self.eps)

        # make sure the matrix is Hermitian
        BM = (BM + BM.conj().transpose(-1, -2)) / 2

        # use eigenvalue decomposition to calculate the log determinant
        # and the inverse-weighted energy term
        L, Q = torch.linalg.eigh(BM)

        # BM is positive definite, so all eigenvalues should be positive
        # However, small negative values may occur due to a limited precision
        L = torch.clamp(L.real, min=self.eps)

        # PDF is invariant to scaling of the shape matrix [1], so
        # eigenvalues can be normalized (across num_inputs)
        L = L / (torch.max(L, dim=-1, keepdim=True)[0] + self.eps)

        # small regularization to avoid numerical issues
        L = L + self.eps

        # calculate the log determinant using the eigenvalues
        log_detBM = torch.sum(torch.log(L), dim=-1)

        # calculate the energy term using the inverse eigenvalues
        # NOTE: keeping an alternative implementation for reference (slower)
        # zH_invBM_z = torch.einsum('bift,bmfij,bmfj,bmfkj,bkft->bmft', z.conj(), Q, (1 / L).to(Q.dtype), Q.conj(), z)
        # zH_invBM_z = zH_invBM_z.abs() + self.eps # small regularization

        # calc L^{-1/2} * Q^H * z
        zH_invBM_z = torch.einsum('bmfj,bmfkj,bkft->bmftj', (1 / L.sqrt()).to(Q.dtype), Q.conj(), z)
        # calc squared norm
        zH_invBM_z = zH_invBM_z.abs().pow(2).sum(-1)
        # small regularization
        zH_invBM_z = zH_invBM_z + self.eps

        # final log PDF
        log_pdf = -num_inputs * torch.log(zH_invBM_z) - log_detBM[..., None]

        return log_pdf, zH_invBM_z

    @property
    def input_types(self) -> Dict[str, NeuralType]:
        """Returns definitions of module input ports."""
        return {
            "input": NeuralType(('B', 'C', 'D', 'T'), SpectrogramType()),
            "activity": NeuralType(('B', 'C', 'T')),
        }

    @property
    def output_types(self) -> Dict[str, NeuralType]:
        """Returns definitions of module output ports."""
        return {
            "gamma": NeuralType(('B', 'C', 'D', 'T')),
        }

    @typecheck()
    def forward(self, input: torch.Tensor, activity: torch.Tensor) -> torch.Tensor:
        """Apply GSS to estimate the time-frequency masks for each output source.

        Args:
            input: batched C-channel input signal, shape (B, num_inputs, F, T)
            activity: batched frame-wise activity for each output source, shape (B, num_outputs, T)

        Returns:
            Masks for the components of the model, shape (B, num_outputs, F, T)

        Raises:
            ValueError: If batch or time dimensions of `input` and `activity` do not match,
                if fewer than two outputs are provided, or if the converted input is not complex.
            RuntimeError: If the estimated masks contain NaNs.
        """
        B, num_inputs, F, T = input.shape
        num_outputs = activity.size(1)
        device = input.device.type

        if activity.size(0) != B:
            raise ValueError(f'Batch dimension mismatch: activity {activity.shape} vs input {input.shape}')

        if activity.size(-1) != T:
            raise ValueError(f'Time dimension mismatch: activity {activity.shape} vs input {input.shape}')

        # a mixture model needs at least two components; this also rejects an empty activity
        if num_outputs < 2:
            raise ValueError(f'Expecting multiple outputs, got {num_outputs}')

        # run the EM algorithm with autocast disabled, using `self.dtype` for the input
        with torch.amp.autocast(device, enabled=False):
            input = input.to(dtype=self.dtype)

            # explicit check instead of `assert`, which is removed when running with -O
            if not input.is_complex():
                raise ValueError(f'Expecting complex input, got {input.dtype}')

            # convert input to directional statistics by normalizing across channels
            z = self.normalize(input, dim=-3)

            # initialize masks
            gamma = torch.clamp(activity, min=self.eps)
            # normalize across channels
            gamma = gamma / torch.sum(gamma, dim=-2, keepdim=True)
            # expand to input shape
            gamma = gamma.unsqueeze(2).expand(-1, -1, F, -1)

            # initialize the energy term
            zH_invBM_z = torch.ones(B, num_outputs, F, T, dtype=input.dtype, device=input.device)

            # EM iterations
            for _ in range(self.num_iterations):
                alpha = self.update_weights(gamma=gamma)
                log_pdf, zH_invBM_z = self.update_pdf(z=z, gamma=gamma, zH_invBM_z=zH_invBM_z)
                gamma = self.update_masks(alpha=alpha, activity=activity, log_pdf=log_pdf)

        if torch.any(torch.isnan(gamma)):
            raise RuntimeError(f'gamma contains NaNs: {gamma}')

        return gamma


class MaskReferenceChannel(NeuralModule):
    """A simple mask processor which applies mask
    on ref_channel of the input signal.

    Args:
        ref_channel: Index of the reference channel. Negative values index from the last channel.
        mask_min_db: Threshold mask to a minimal value before applying it, defaults to -200dB
        mask_max_db: Threshold mask to a maximal value before applying it, defaults to 0dB
    """

    def __init__(self, ref_channel: int = 0, mask_min_db: float = -200, mask_max_db: float = 0):
        super().__init__()
        self.ref_channel = ref_channel
        # Mask thresholding, converted from dB to magnitude
        self.mask_min = db2mag(mask_min_db)
        self.mask_max = db2mag(mask_max_db)

        logging.debug('Initialized %s with', self.__class__.__name__)
        logging.debug('\tref_channel: %d', self.ref_channel)
        logging.debug('\tmask_min:    %f', self.mask_min)
        logging.debug('\tmask_max:    %f', self.mask_max)

    @property
    def input_types(self) -> Dict[str, NeuralType]:
        """Returns definitions of module input ports."""
        return {
            "input": NeuralType(('B', 'C', 'D', 'T'), SpectrogramType()),
            "input_length": NeuralType(('B',), LengthsType()),
            "mask": NeuralType(('B', 'C', 'D', 'T'), FloatType()),
        }

    @property
    def output_types(self) -> Dict[str, NeuralType]:
        """Returns definitions of module output ports."""
        return {
            "output": NeuralType(('B', 'C', 'D', 'T'), SpectrogramType()),
            "output_length": NeuralType(('B',), LengthsType()),
        }

    @typecheck()
    def forward(
        self,
        input: torch.Tensor,
        input_length: torch.Tensor,
        mask: torch.Tensor,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Apply mask on `ref_channel` of the input signal.
        This can be used to generate multi-channel output.
        If `mask` has `M` channels, the output will have `M` channels as well.

        Args:
            input: Input signal complex-valued spectrogram, shape (B, C, F, N)
            input_length: Length of valid entries along the time dimension, shape (B,)
            mask: Mask for M outputs, shape (B, M, F, N)

        Returns:
            M-channel output complex-valued spectrogram with shape (B, M, F, N),
            and the unchanged `input_length` with shape (B,)

        Raises:
            IndexError: If `ref_channel` is out of range for the channel dimension of `input`.
        """
        # Apply thresholds
        mask = torch.clamp(mask, min=self.mask_min, max=self.mask_max)

        # Select the reference channel and keep a singleton channel dimension, (B, 1, F, N).
        # Integer indexing is used instead of the slice `ref_channel : ref_channel + 1`,
        # because that slice is empty for `ref_channel = -1` and for out-of-range indices.
        ref_input = input[:, self.ref_channel, ...].unsqueeze(1)

        # Apply each output mask on the ref channel
        output = mask * ref_input
        return output, input_length


class MaskBasedBeamformer(NeuralModule):
    """Multi-channel processor using masks to estimate signal statistics.

    Args:
        filter_type: string denoting the type of the filter, either `pmwf` or `mvdr_souden`.
            Defaults to `mvdr_souden`
        filter_beta: Parameter of the parametric multichannel Wiener filter
        filter_rank: Parameter of the parametric multichannel Wiener filter
        filter_postfilter: Optional, postprocessing of the filter
        ref_channel: Optional, reference channel. If None, it will be estimated automatically
        ref_hard: If true, hard (one-hot) reference. If false, a soft reference
        ref_hard_use_grad: If true, use straight-through gradient when using the hard reference
        ref_subband_weighting: If true, use subband weighting when estimating reference channel
        num_subbands: Optional, used to determine the parameter size for reference estimation
        mask_min_db: Threshold mask to a minimal value before applying it, defaults to -200dB
        mask_max_db: Threshold mask to a maximal value before applying it, defaults to 0dB
        postmask_min_db: Threshold postmask to a minimal value before applying it, defaults to 0dB
        postmask_max_db: Threshold postmask to a maximal value before applying it, defaults to 0dB.
            The postmask is applied only if `postmask_min_db` is strictly smaller than `postmask_max_db`,
            so the defaults disable it.
        diag_reg: Optional, diagonal regularization for the multichannel filter
        eps: Small regularization constant to avoid division by zero

    Raises:
        ValueError: If `filter_type` is unknown, if `mask_min_db >= mask_max_db`,
            or if `postmask_min_db > postmask_max_db`.
    """

    def __init__(
        self,
        filter_type: str = 'mvdr_souden',
        filter_beta: float = 0.0,
        filter_rank: str = 'one',
        filter_postfilter: Optional[str] = None,
        ref_channel: Optional[int] = 0,
        ref_hard: bool = True,
        ref_hard_use_grad: bool = False,
        ref_subband_weighting: bool = False,
        num_subbands: Optional[int] = None,
        mask_min_db: float = -200,
        mask_max_db: float = 0,
        postmask_min_db: float = 0,
        postmask_max_db: float = 0,
        diag_reg: Optional[float] = 1e-6,
        eps: float = 1e-8,
    ):
        super().__init__()
        if filter_type not in ['pmwf', 'mvdr_souden']:
            raise ValueError(f'Unknown filter type {filter_type}')

        self.filter_type = filter_type
        # For `mvdr_souden` both beta and rank are forced, as stated in the warning below.
        # Check both parameters: a non-default rank combined with beta == 0 must be
        # overridden as well, otherwise it would be silently passed to the filter.
        if self.filter_type == 'mvdr_souden' and (filter_beta != 0 or filter_rank != 'one'):
            logging.warning(
                'Using filter type %s: beta will be automatically set to zero (current beta %f) and rank to one (current rank %s).',
                self.filter_type,
                filter_beta,
                filter_rank,
            )
            filter_beta = 0.0
            filter_rank = 'one'
        # Prepare filter
        self.filter = ParametricMultichannelWienerFilter(
            beta=filter_beta,
            rank=filter_rank,
            postfilter=filter_postfilter,
            ref_channel=ref_channel,
            ref_hard=ref_hard,
            ref_hard_use_grad=ref_hard_use_grad,
            ref_subband_weighting=ref_subband_weighting,
            num_subbands=num_subbands,
            diag_reg=diag_reg,
            eps=eps,
        )
        # Mask thresholding
        if mask_min_db >= mask_max_db:
            raise ValueError(
                f'Lower bound for the mask {mask_min_db}dB must be smaller than the upper bound {mask_max_db}dB'
            )
        self.mask_min = db2mag(mask_min_db)
        self.mask_max = db2mag(mask_max_db)
        # Postmask thresholding (equal bounds are allowed and disable the postmask in `forward`)
        if postmask_min_db > postmask_max_db:
            raise ValueError(
                f'Lower bound for the postmask {postmask_min_db}dB must be smaller or equal to the upper bound {postmask_max_db}dB'
            )
        self.postmask_min = db2mag(postmask_min_db)
        self.postmask_max = db2mag(postmask_max_db)

        logging.debug('Initialized %s', self.__class__.__name__)
        logging.debug('\tfilter_type:  %s', self.filter_type)
        logging.debug('\tmask_min:     %e', self.mask_min)
        logging.debug('\tmask_max:     %e', self.mask_max)
        logging.debug('\tpostmask_min: %e', self.postmask_min)
        logging.debug('\tpostmask_max: %e', self.postmask_max)

    @property
    def input_types(self) -> Dict[str, NeuralType]:
        """Returns definitions of module input ports."""
        return {
            "input": NeuralType(('B', 'C', 'D', 'T'), SpectrogramType()),
            "mask": NeuralType(('B', 'C', 'D', 'T'), FloatType()),
            "mask_undesired": NeuralType(('B', 'C', 'D', 'T'), FloatType(), optional=True),
            "input_length": NeuralType(('B',), LengthsType(), optional=True),
        }

    @property
    def output_types(self) -> Dict[str, NeuralType]:
        """Returns definitions of module output ports."""
        return {
            "output": NeuralType(('B', 'C', 'D', 'T'), SpectrogramType()),
            "output_length": NeuralType(('B',), LengthsType(), optional=True),
        }

    @typecheck()
    def forward(
        self,
        input: torch.Tensor,
        mask: torch.Tensor,
        mask_undesired: Optional[torch.Tensor] = None,
        input_length: Optional[torch.Tensor] = None,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
        """Apply a mask-based beamformer to the input spectrogram.
        This can be used to generate multi-channel output.
        If `mask` has multiple channels, a multichannel filter is created for each mask,
        and the output is concatenation of individual outputs along the channel dimension.
        The total number of outputs is `num_masks * M`, where `M` is the number of channels
        at the filter output.

        Args:
            input: Input signal complex-valued spectrogram, shape (B, C, F, N)
            mask: Mask for M output signals, shape (B, num_masks, F, N)
            mask_undesired: Optional mask for the undesired signal, indexed per mask in the
                same way as `mask`. If not provided, the undesired mask is `1 - mask` when a
                single mask is used, and the sum of all other masks otherwise.
            input_length: Length of valid entries along the time dimension, shape (B,)

        Returns:
            Tuple of the multichannel output signal complex-valued spectrogram,
            shape (B, num_masks * M, F, N), and `input_length`, which is passed through unchanged.
        """
        # Length mask, later used with `masked_fill` to zero out the marked entries
        length_mask: Optional[torch.Tensor] = None
        if input_length is not None:
            length_mask = make_seq_mask_like(
                lengths=input_length, like=mask[:, 0, ...], time_dim=-1, valid_ones=False
            )

        num_masks = mask.size(1)

        # The sum over all masks does not depend on the loop index, so compute it once.
        # It is only needed when the undesired mask is derived from the other sources.
        mask_sum = torch.sum(mask, dim=1) if mask_undesired is None and num_masks > 1 else None

        # Use each mask to generate an output
        output = []
        for m in range(num_masks):
            # Desired signal mask
            mask_d = mask[:, m, ...]
            # Undesired signal mask
            if mask_undesired is not None:
                mask_u = mask_undesired[:, m, ...]
            elif num_masks == 1:
                # If a single mask is estimated, use the complement
                mask_u = 1 - mask_d
            else:
                # Use sum of all other sources
                mask_u = mask_sum - mask_d

            # Threshold masks
            mask_d = torch.clamp(mask_d, min=self.mask_min, max=self.mask_max)
            mask_u = torch.clamp(mask_u, min=self.mask_min, max=self.mask_max)

            # Zero the masks after thresholding, so the lower threshold is not re-applied there
            if length_mask is not None:
                mask_d = mask_d.masked_fill(length_mask, 0.0)
                mask_u = mask_u.masked_fill(length_mask, 0.0)

            # Apply filter
            output_m = self.filter(input=input, mask_s=mask_d, mask_n=mask_u)

            # Optional: apply a postmask with min and max thresholds
            if self.postmask_min < self.postmask_max:
                postmask_m = torch.clamp(mask[:, m, ...], min=self.postmask_min, max=self.postmask_max)
                output_m = output_m * postmask_m.unsqueeze(1)

            # Save the current output (B, M, F, T)
            output.append(output_m)

        # Combine outputs along the channel dimension
        # Each output is (B, M, F, T)
        output = torch.cat(output, dim=1)

        # Apply masking
        if length_mask is not None:
            output = output.masked_fill(length_mask[:, None, ...], 0.0)

        return output, input_length


class MaskBasedDereverbWPE(NeuralModule):
    """Multi-channel linear prediction-based dereverberation using
    weighted prediction error for filter estimation.

    An optional mask to estimate the signal power can be provided.
    If a time-frequency mask is not provided, the algorithm corresponds
    to the conventional WPE algorithm.

    Args:
        filter_length: Length of the convolutional filter for each channel in frames.
        prediction_delay: Delay of the input signal for multi-channel linear prediction in frames.
        num_iterations: Number of iterations for reweighting
        mask_min_db: Threshold mask to a minimal value before applying it, defaults to -200dB
        mask_max_db: Threshold mask to a maximal value before applying it, defaults to 0dB
        diag_reg: Diagonal regularization for WPE
        eps: Small regularization constant
        dtype: Data type for internal computations, either `torch.cfloat` or `torch.cdouble`

    Raises:
        ValueError: If `dtype` is neither `torch.cfloat` nor `torch.cdouble`.

    References:
        - Kinoshita et al, Neural network-based spectrum estimation for online WPE dereverberation, 2017
        - Yoshioka and Nakatani, Generalization of Multi-Channel Linear Prediction Methods for Blind MIMO Impulse Response Shortening, 2012
    """

    def __init__(
        self,
        filter_length: int,
        prediction_delay: int,
        num_iterations: int = 1,
        mask_min_db: float = -200,
        mask_max_db: float = 0,
        diag_reg: Optional[float] = 1e-6,
        eps: float = 1e-8,
        dtype: torch.dtype = torch.cdouble,
    ):
        super().__init__()
        # Filter setup
        self.filter = WPEFilter(
            filter_length=filter_length, prediction_delay=prediction_delay, diag_reg=diag_reg, eps=eps
        )
        self.num_iterations = num_iterations
        # Mask thresholding
        self.mask_min = db2mag(mask_min_db)
        self.mask_max = db2mag(mask_max_db)
        # Internal calculations
        if dtype not in [torch.cfloat, torch.cdouble]:
            raise ValueError(f'Unsupported dtype {dtype}, expecting torch.cfloat or torch.cdouble')
        self.dtype = dtype

        logging.debug('Initialized %s', self.__class__.__name__)
        logging.debug('\tnum_iterations: %s', self.num_iterations)
        logging.debug('\tmask_min:       %g', self.mask_min)
        logging.debug('\tmask_max:       %g', self.mask_max)
        logging.debug('\tdtype:          %s', self.dtype)

    @property
    def input_types(self) -> Dict[str, NeuralType]:
        """Returns definitions of module input ports."""
        return {
            "input": NeuralType(('B', 'C', 'D', 'T'), SpectrogramType()),
            "input_length": NeuralType(('B',), LengthsType(), optional=True),
            "mask": NeuralType(('B', 'C', 'D', 'T'), FloatType(), optional=True),
        }

    @property
    def output_types(self) -> Dict[str, NeuralType]:
        """Returns definitions of module output ports."""
        return {
            "output": NeuralType(('B', 'C', 'D', 'T'), SpectrogramType()),
            "output_length": NeuralType(('B',), LengthsType(), optional=True),
        }

    @typecheck()
    def forward(
        self, input: torch.Tensor, input_length: Optional[torch.Tensor] = None, mask: Optional[torch.Tensor] = None
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
        """Given an input signal `input`, apply the WPE dereverberation algorithm.

        Args:
            input: C-channel complex-valued spectrogram, shape (B, C, F, T)
            input_length: Optional length for each signal in the batch, shape (B,)
            mask: Optional mask, shape (B, 1, F, T) or (B, C, F, T)

        Returns:
            Tuple of the processed tensor with the same number of channels as the input,
            shape (B, C, F, T), converted back to the dtype of `input`, and the output length
            reported by the filter. If `num_iterations` is not positive, the filter is not
            applied and `input_length` is returned as the output length.

        Raises:
            RuntimeError: If `input` is not complex-valued.
        """
        # Validate the input before the conversion below. `self.dtype` is always complex,
        # so a check performed after the conversion could never fail and real-valued
        # input would be silently accepted.
        if not input.is_complex():
            raise RuntimeError(f'Expecting complex input, got {input.dtype}')

        io_dtype = input.dtype
        device = input.device.type

        # Default for the case when the loop below does not run, so the return value is always defined
        output_length = input_length

        # Disable autocast to keep the internal computations in `self.dtype`
        with torch.amp.autocast(device, enabled=False):
            output = input.to(dtype=self.dtype)

            for i in range(self.num_iterations):
                magnitude = torch.abs(output)
                # The mask is used only in the first iteration, later iterations
                # use the magnitude of the previous output
                if i == 0 and mask is not None:
                    # Apply thresholds
                    mask = torch.clamp(mask, min=self.mask_min, max=self.mask_max)
                    # Mask magnitude
                    magnitude = mask * magnitude
                # Calculate power
                power = magnitude**2
                # Apply filter
                output, output_length = self.filter(input=output, input_length=input_length, power=power)

        return output.to(io_dtype), output_length
