# Copyright The Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from collections.abc import Sequence
from typing import Any, Optional, Union

from torch import Tensor, tensor

from torchmetrics.functional.audio.srmr import (
    _srmr_arg_validate,
    speech_reverberation_modulation_energy_ratio,
)
from torchmetrics.metric import Metric
from torchmetrics.utilities.imports import (
    _GAMMATONE_AVAILABLE,
    _MATPLOTLIB_AVAILABLE,
    _TORCHAUDIO_AVAILABLE,
)
from torchmetrics.utilities.plot import _AX_TYPE, _PLOT_OUT_TYPE

if not all([_GAMMATONE_AVAILABLE, _TORCHAUDIO_AVAILABLE]):
    __doctest_skip__ = ["SpeechReverberationModulationEnergyRatio", "SpeechReverberationModulationEnergyRatio.plot"]
elif not _MATPLOTLIB_AVAILABLE:
    __doctest_skip__ = ["SpeechReverberationModulationEnergyRatio.plot"]


class SpeechReverberationModulationEnergyRatio(Metric):
    """Calculate `Speech-to-Reverberation Modulation Energy Ratio`_ (SRMR).

    SRMR is a non-intrusive metric for speech quality and intelligibility based on
    a modulation spectral representation of the speech signal.
    This code is translated from SRMRToolbox and `SRMRpy`_.

    As input to ``forward`` and ``update`` the metric accepts the following input

    - ``preds`` (:class:`~torch.Tensor`): float tensor with shape ``(...,time)``

    As output of `forward` and `compute` the metric returns the following output

    - ``srmr`` (:class:`~torch.Tensor`): float scaler tensor

    .. hint::
        Using this metrics requires you to have ``gammatone`` and ``torchaudio`` installed.
        Either install as ``pip install torchmetrics[audio]`` or ``pip install torchaudio``
        and ``pip install git+https://github.com/detly/gammatone``.

    .. attention::
        This implementation is experimental, and might not be consistent with the matlab
        implementation SRMRToolbox, especially the fast implementation.
        The slow versions, a) ``fast=False, norm=False, max_cf=128``, b) ``fast=False, norm=True, max_cf=30``,
        have a relatively small inconsistency.

    Args:
        fs: the sampling rate
        n_cochlear_filters: Number of filters in the acoustic filterbank
        low_freq: determines the frequency cutoff for the corresponding gammatone filterbank.
        min_cf: Center frequency in Hz of the first modulation filter.
        max_cf: Center frequency in Hz of the last modulation filter. If None is given,
            then 30 Hz will be used for `norm==False`, otherwise 128 Hz will be used.
        norm: Use modulation spectrum energy normalization
        fast: Use the faster version based on the gammatonegram.
            Note: this argument is inherited from `SRMRpy`_. As the translated code is based to pytorch,
            setting `fast=True` may slow down the speed for calculating this metric on GPU.

    Raises:
        ModuleNotFoundError:
            If ``gammatone`` or ``torchaudio`` package is not installed

    Example:
        >>> from torch import randn
        >>> from torchmetrics.audio import SpeechReverberationModulationEnergyRatio
        >>> preds = randn(8000)
        >>> srmr = SpeechReverberationModulationEnergyRatio(8000)
        >>> srmr(preds)
        tensor(0.3191)

    """

    msum: Tensor
    total: Tensor
    full_state_update: bool = False
    is_differentiable: bool = True
    higher_is_better: bool = True
    plot_lower_bound: Optional[float] = None
    plot_upper_bound: Optional[float] = None

    def __init__(
        self,
        fs: int,
        n_cochlear_filters: int = 23,
        low_freq: float = 125,
        min_cf: float = 4,
        max_cf: Optional[float] = None,
        norm: bool = False,
        fast: bool = False,
        **kwargs: Any,
    ) -> None:
        super().__init__(**kwargs)
        if not _TORCHAUDIO_AVAILABLE or not _GAMMATONE_AVAILABLE:
            raise ModuleNotFoundError(
                "speech_reverberation_modulation_energy_ratio requires you to have `gammatone` and"
                " `torchaudio>=0.10` installed. Either install as ``pip install torchmetrics[audio]`` or "
                "``pip install torchaudio>=0.10`` and ``pip install git+https://github.com/detly/gammatone``"
            )
        _srmr_arg_validate(
            fs=fs,
            n_cochlear_filters=n_cochlear_filters,
            low_freq=low_freq,
            min_cf=min_cf,
            max_cf=max_cf,
            norm=norm,
            fast=fast,
        )

        self.fs = fs
        self.n_cochlear_filters = n_cochlear_filters
        self.low_freq = low_freq
        self.min_cf = min_cf
        self.max_cf = max_cf
        self.norm = norm
        self.fast = fast

        self.add_state("msum", default=tensor(0.0), dist_reduce_fx="sum")
        self.add_state("total", default=tensor(0), dist_reduce_fx="sum")

    def update(self, preds: Tensor) -> None:
        """Update state with predictions."""
        metric_val_batch = speech_reverberation_modulation_energy_ratio(
            preds, self.fs, self.n_cochlear_filters, self.low_freq, self.min_cf, self.max_cf, self.norm, self.fast
        ).to(self.msum.device)

        self.msum += metric_val_batch.sum()
        self.total += metric_val_batch.numel()

    def compute(self) -> Tensor:
        """Compute metric."""
        return self.msum / self.total

    def plot(self, val: Union[Tensor, Sequence[Tensor], None] = None, ax: Optional[_AX_TYPE] = None) -> _PLOT_OUT_TYPE:
        """Plot a single or multiple values from the metric.

        Args:
            val: Either a single result from calling `metric.forward` or `metric.compute` or a list of these results.
                If no value is provided, will automatically call `metric.compute` and plot that result.
            ax: An matplotlib axis object. If provided will add plot to that axis

        Returns:
            Figure and Axes object

        Raises:
            ModuleNotFoundError:
                If `matplotlib` is not installed

        .. plot::
            :scale: 75

            >>> # Example plotting a single value
            >>> import torch
            >>> from torchmetrics.audio import SpeechReverberationModulationEnergyRatio
            >>> metric = SpeechReverberationModulationEnergyRatio(8000)
            >>> metric.update(torch.rand(8000))
            >>> fig_, ax_ = metric.plot()

        .. plot::
            :scale: 75

            >>> # Example plotting multiple values
            >>> import torch
            >>> from torchmetrics.audio import SpeechReverberationModulationEnergyRatio
            >>> metric = SpeechReverberationModulationEnergyRatio(8000)
            >>> values = [ ]
            >>> for _ in range(10):
            ...     values.append(metric(torch.rand(8000)))
            >>> fig_, ax_ = metric.plot(values)

        """
        return self._plot(val, ax)
