from typing import Any, Callable, Optional

from typing_extensions import Literal

from torchmetrics.audio.pit import PermutationInvariantTraining
from torchmetrics.audio.sdr import ScaleInvariantSignalDistortionRatio, SignalDistortionRatio
from torchmetrics.audio.snr import ScaleInvariantSignalNoiseRatio, SignalNoiseRatio
from torchmetrics.utilities.prints import _deprecated_root_import_class


class _PermutationInvariantTraining(PermutationInvariantTraining):
    """Deprecated-import wrapper around ``PermutationInvariantTraining``.

    Constructing this subclass first calls
    ``_deprecated_root_import_class("PermutationInvariantTraining", "audio")`` and then passes
    every constructor argument on to the parent class unchanged. Nothing else is overridden here.

    Note: the example below samples its inputs with ``torch.randn`` and does not seed the
    generator itself, so the printed value depends on RNG state set up outside this docstring
    (presumably by the doctest harness -- TODO confirm).

    >>> import torch
    >>> from torchmetrics.functional import scale_invariant_signal_noise_ratio
    >>> preds = torch.randn(3, 2, 5) # [batch, spk, time]
    >>> target = torch.randn(3, 2, 5) # [batch, spk, time]
    >>> pit = _PermutationInvariantTraining(scale_invariant_signal_noise_ratio,
    ...     mode="speaker-wise", eval_func="max")
    >>> pit(preds, target)
    tensor(-2.1065)

    """

    def __init__(
        self,
        metric_func: Callable,
        mode: Literal["speaker-wise", "permutation-wise"] = "speaker-wise",
        eval_func: Literal["max", "min"] = "max",
        **kwargs: Any,
    ) -> None:
        """Call the deprecation helper, then defer to the parent constructor.

        Args:
            metric_func: forwarded to the parent as ``metric_func``.
            mode: forwarded to the parent as ``mode``; defaults to ``"speaker-wise"``.
            eval_func: forwarded to the parent as ``eval_func``; defaults to ``"max"``.
            kwargs: any further keyword arguments, forwarded to the parent untouched.

        """
        # The helper runs before the parent constructor, so it is invoked even if that constructor raises.
        _deprecated_root_import_class("PermutationInvariantTraining", "audio")
        parent_args = {"metric_func": metric_func, "mode": mode, "eval_func": eval_func}
        super().__init__(**parent_args, **kwargs)


class _ScaleInvariantSignalDistortionRatio(ScaleInvariantSignalDistortionRatio):
    """Deprecated-import wrapper around ``ScaleInvariantSignalDistortionRatio``.

    Constructing this subclass first calls
    ``_deprecated_root_import_class("ScaleInvariantSignalDistortionRatio", "audio")`` and then
    passes ``zero_mean`` plus any extra keyword arguments on to the parent class unchanged.

    >>> from torch import tensor
    >>> target = tensor([3.0, -0.5, 2.0, 7.0])
    >>> preds = tensor([2.5, 0.0, 2.0, 8.0])
    >>> si_sdr = _ScaleInvariantSignalDistortionRatio()
    >>> si_sdr(preds, target)
    tensor(18.4030)

    """

    def __init__(self, zero_mean: bool = False, **kwargs: Any) -> None:
        """Call the deprecation helper, then defer to the parent constructor.

        Args:
            zero_mean: forwarded to the parent as ``zero_mean``; defaults to ``False``.
            kwargs: any further keyword arguments, forwarded to the parent untouched.

        """
        # The helper runs before the parent constructor, so it is invoked even if that constructor raises.
        _deprecated_root_import_class("ScaleInvariantSignalDistortionRatio", "audio")
        parent_args = {"zero_mean": zero_mean}
        super().__init__(**parent_args, **kwargs)


class _ScaleInvariantSignalNoiseRatio(ScaleInvariantSignalNoiseRatio):
    """Deprecated-import wrapper around ``ScaleInvariantSignalNoiseRatio``.

    Constructing this subclass first calls
    ``_deprecated_root_import_class("ScaleInvariantSignalNoiseRatio", "audio")`` and then passes
    all keyword arguments on to the parent class unchanged.

    >>> from torch import tensor
    >>> target = tensor([3.0, -0.5, 2.0, 7.0])
    >>> preds = tensor([2.5, 0.0, 2.0, 8.0])
    >>> si_snr = _ScaleInvariantSignalNoiseRatio()
    >>> si_snr(preds, target)
    tensor(15.0918)

    """

    def __init__(self, **kwargs: Any) -> None:
        """Call the deprecation helper, then defer to the parent constructor.

        Args:
            kwargs: keyword arguments, forwarded to the parent constructor untouched.

        """
        # The helper runs before the parent constructor, so it is invoked even if that constructor raises.
        _deprecated_root_import_class("ScaleInvariantSignalNoiseRatio", "audio")
        super().__init__(**kwargs)


class _SignalDistortionRatio(SignalDistortionRatio):
    """Deprecated-import wrapper around ``SignalDistortionRatio``.

    Constructing this subclass first calls
    ``_deprecated_root_import_class("SignalDistortionRatio", "audio")`` and then passes every
    constructor argument on to the parent class unchanged. Nothing else is overridden here.

    Note: the example below samples its inputs with ``torch.randn`` and does not seed the
    generator itself, so the printed values depend on RNG state set up outside this docstring
    (presumably by the doctest harness -- TODO confirm).

    >>> import torch
    >>> preds = torch.randn(8000)
    >>> target = torch.randn(8000)
    >>> sdr = _SignalDistortionRatio()
    >>> sdr(preds, target)
    tensor(-11.9930)
    >>> # use with pit
    >>> from torchmetrics.functional import signal_distortion_ratio
    >>> preds = torch.randn(4, 2, 8000)  # [batch, spk, time]
    >>> target = torch.randn(4, 2, 8000)
    >>> pit = _PermutationInvariantTraining(signal_distortion_ratio,
    ...     mode="speaker-wise", eval_func="max")
    >>> pit(preds, target)
    tensor(-11.7277)

    """

    def __init__(
        self,
        use_cg_iter: Optional[int] = None,
        filter_length: int = 512,
        zero_mean: bool = False,
        load_diag: Optional[float] = None,
        **kwargs: Any,
    ) -> None:
        """Call the deprecation helper, then defer to the parent constructor.

        Args:
            use_cg_iter: forwarded to the parent as ``use_cg_iter``; defaults to ``None``.
            filter_length: forwarded to the parent as ``filter_length``; defaults to ``512``.
            zero_mean: forwarded to the parent as ``zero_mean``; defaults to ``False``.
            load_diag: forwarded to the parent as ``load_diag``; defaults to ``None``.
            kwargs: any further keyword arguments, forwarded to the parent untouched.

        """
        # The helper runs before the parent constructor, so it is invoked even if that constructor raises.
        _deprecated_root_import_class("SignalDistortionRatio", "audio")
        sdr_options = {
            "use_cg_iter": use_cg_iter,
            "filter_length": filter_length,
            "zero_mean": zero_mean,
            "load_diag": load_diag,
        }
        super().__init__(**sdr_options, **kwargs)


class _SignalNoiseRatio(SignalNoiseRatio):
    """Deprecated-import wrapper around ``SignalNoiseRatio``.

    Constructing this subclass first calls
    ``_deprecated_root_import_class("SignalNoiseRatio", "audio")`` and then passes ``zero_mean``
    plus any extra keyword arguments on to the parent class unchanged.

    >>> from torch import tensor
    >>> target = tensor([3.0, -0.5, 2.0, 7.0])
    >>> preds = tensor([2.5, 0.0, 2.0, 8.0])
    >>> snr = _SignalNoiseRatio()
    >>> snr(preds, target)
    tensor(16.1805)

    """

    def __init__(self, zero_mean: bool = False, **kwargs: Any) -> None:
        """Call the deprecation helper, then defer to the parent constructor.

        Args:
            zero_mean: forwarded to the parent as ``zero_mean``; defaults to ``False``.
            kwargs: any further keyword arguments, forwarded to the parent untouched.

        """
        # The helper runs before the parent constructor, so it is invoked even if that constructor raises.
        _deprecated_root_import_class("SignalNoiseRatio", "audio")
        parent_args = {"zero_mean": zero_mean}
        super().__init__(**parent_args, **kwargs)
