from asteroid_filterbanks import make_enc_dec
from asteroid_filterbanks.transforms import from_torch_complex, to_torch_complex
from ..masknn.convolutional import DCUMaskNet
from .base_models import BaseEncoderMaskerDecoder


class BaseDCUNet(BaseEncoderMaskerDecoder):
    """Base class for ``DCUNet`` and ``DCCRNet`` classes.

    Args:
        architecture (str): The architecture to use. Overriden by subclasses.
        stft_n_filters (int) Number of filters for the STFT.
        stft_kernel_size (int): STFT frame length to use.
        stft_stride (int, optional): STFT hop length to use.
        sample_rate (float): Sampling rate of the model.
        masknet_kwargs (optional): Passed to the masknet constructor.
    """

    masknet_class = NotImplemented

    def __init__(
        self,
        architecture,
        stft_n_filters=1024,
        stft_kernel_size=1024,
        stft_stride=256,
        sample_rate=16000.0,
        **masknet_kwargs,
    ):
        self.architecture = architecture
        self.stft_n_filters = stft_n_filters
        self.stft_kernel_size = stft_kernel_size
        self.stft_stride = stft_stride
        self.masknet_kwargs = masknet_kwargs

        encoder, decoder = make_enc_dec(
            "stft",
            n_filters=stft_n_filters,
            kernel_size=stft_kernel_size,
            stride=stft_stride,
            sample_rate=sample_rate,
        )
        masker = self.masknet_class.default_architecture(architecture, **masknet_kwargs)
        super().__init__(encoder, masker, decoder)

    def forward_encoder(self, wav):
        tf_rep = self.encoder(wav)
        return to_torch_complex(tf_rep)

    def apply_masks(self, tf_rep, est_masks):
        masked_tf_rep = est_masks * tf_rep.unsqueeze(1)
        return from_torch_complex(masked_tf_rep)

    def get_model_args(self):
        """Arguments needed to re-instantiate the model."""
        model_args = {
            "architecture": self.architecture,
            "stft_n_filters": self.stft_n_filters,
            "stft_kernel_size": self.stft_kernel_size,
            "stft_stride": self.stft_stride,
            "sample_rate": self.sample_rate,
            **self.masknet_kwargs,
        }
        return model_args


class DCUNet(BaseDCUNet):
    """DCUNet as proposed in [1].

    Args:
        architecture (str): The architecture to use, any of
            "DCUNet-10", "DCUNet-16", "DCUNet-20", "Large-DCUNet-20".
        stft_n_filters (int) Number of filters for the STFT.
        stft_kernel_size (int): STFT frame length to use.
        stft_stride (int, optional): STFT hop length to use.
        sample_rate (float): Sampling rate of the model.
        masknet_kwargs (optional): Passed to :class:`DCUMaskNet`

    References
        - [1] : "Phase-aware Speech Enhancement with Deep Complex U-Net",
          Hyeong-Seok Choi et al. https://arxiv.org/abs/1903.03107
    """

    masknet_class = DCUMaskNet
