import inspect
import typing
from functools import wraps

from . import util


def format_figure(func):
    """Decorator for formatting figures produced by the code below.
    See :py:func:`audiotools.core.util.format_figure` for more.

    Parameters
    ----------
    func : Callable
        Plotting function that is decorated by this function.

    """

    @wraps(func)
    def wrapper(*args, **kwargs):
        f_keys = inspect.signature(util.format_figure).parameters.keys()
        f_kwargs = {}
        for k, v in list(kwargs.items()):
            if k in f_keys:
                kwargs.pop(k)
                f_kwargs[k] = v
        func(*args, **kwargs)
        util.format_figure(**f_kwargs)

    return wrapper


class DisplayMixin:
    @format_figure
    def specshow(
        self,
        preemphasis: bool = False,
        x_axis: str = "time",
        y_axis: str = "linear",
        n_mels: int = 128,
        **kwargs,
    ):
        """Displays a spectrogram, using ``librosa.display.specshow``.

        Parameters
        ----------
        preemphasis : bool, optional
            Whether or not to apply preemphasis, which makes high
            frequency detail easier to see, by default False
        x_axis : str, optional
            How to label the x axis, by default "time"
        y_axis : str, optional
            How to label the y axis, by default "linear"
        n_mels : int, optional
            If displaying a mel spectrogram with ``y_axis = "mel"``,
            this controls the number of mels, by default 128.
        kwargs : dict, optional
            Keyword arguments to :py:func:`audiotools.core.util.format_figure`.
        """
        import librosa
        import librosa.display

        # Always re-compute the STFT data before showing it, in case
        # it changed.
        signal = self.clone()
        signal.stft_data = None

        if preemphasis:
            signal.preemphasis()

        ref = signal.magnitude.max()
        log_mag = signal.log_magnitude(ref_value=ref)

        if y_axis == "mel":
            log_mag = 20 * signal.mel_spectrogram(n_mels).clamp(1e-5).log10()
            log_mag -= log_mag.max()

        librosa.display.specshow(
            log_mag.numpy()[0].mean(axis=0),
            x_axis=x_axis,
            y_axis=y_axis,
            sr=signal.sample_rate,
            **kwargs,
        )

    @format_figure
    def waveplot(self, x_axis: str = "time", **kwargs):
        """Displays a waveform plot, using ``librosa.display.waveshow``.

        Parameters
        ----------
        x_axis : str, optional
            How to label the x axis, by default "time"
        kwargs : dict, optional
            Keyword arguments to :py:func:`audiotools.core.util.format_figure`.
        """
        import librosa
        import librosa.display

        audio_data = self.audio_data[0].mean(dim=0)
        audio_data = audio_data.cpu().numpy()

        plot_fn = "waveshow" if hasattr(librosa.display, "waveshow") else "waveplot"
        wave_plot_fn = getattr(librosa.display, plot_fn)
        wave_plot_fn(audio_data, x_axis=x_axis, sr=self.sample_rate, **kwargs)

    @format_figure
    def wavespec(self, x_axis: str = "time", **kwargs):
        """Displays a waveform plot, using ``librosa.display.waveshow``.

        Parameters
        ----------
        x_axis : str, optional
            How to label the x axis, by default "time"
        kwargs : dict, optional
            Keyword arguments to :py:func:`audiotools.core.display.DisplayMixin.specshow`.
        """
        import matplotlib.pyplot as plt
        from matplotlib.gridspec import GridSpec

        gs = GridSpec(6, 1)
        plt.subplot(gs[0, :])
        self.waveplot(x_axis=x_axis)
        plt.subplot(gs[1:, :])
        self.specshow(x_axis=x_axis, **kwargs)

    def write_audio_to_tb(
        self,
        tag: str,
        writer,
        step: int = None,
        plot_fn: typing.Union[typing.Callable, str] = "specshow",
        **kwargs,
    ):
        """Writes a signal and its spectrogram to Tensorboard. Will show up
        under the Audio and Images tab in Tensorboard.

        Parameters
        ----------
        tag : str
            Tag to write signal to (e.g. ``clean/sample_0.wav``). The image will be
            written to the corresponding ``.png`` file (e.g. ``clean/sample_0.png``).
        writer : SummaryWriter
            A SummaryWriter object from PyTorch library.
        step : int, optional
            The step to write the signal to, by default None
        plot_fn : typing.Union[typing.Callable, str], optional
            How to create the image. Set to ``None`` to avoid plotting, by default "specshow"
        kwargs : dict, optional
            Keyword arguments to :py:func:`audiotools.core.display.DisplayMixin.specshow` or
            whatever ``plot_fn`` is set to.
        """
        import matplotlib.pyplot as plt

        audio_data = self.audio_data[0, 0].detach().cpu()
        sample_rate = self.sample_rate
        writer.add_audio(tag, audio_data, step, sample_rate)

        if plot_fn is not None:
            if isinstance(plot_fn, str):
                plot_fn = getattr(self, plot_fn)
            fig = plt.figure()
            plt.clf()
            plot_fn(**kwargs)
            writer.add_figure(tag.replace("wav", "png"), fig, step)

    def save_image(
        self,
        image_path: str,
        plot_fn: typing.Union[typing.Callable, str] = "specshow",
        **kwargs,
    ):
        """Save AudioSignal spectrogram (or whatever ``plot_fn`` is set to) to
        a specified file.

        Parameters
        ----------
        image_path : str
            Where to save the file to.
        plot_fn : typing.Union[typing.Callable, str], optional
            How to create the image. Set to ``None`` to avoid plotting, by default "specshow"
        kwargs : dict, optional
            Keyword arguments to :py:func:`audiotools.core.display.DisplayMixin.specshow` or
            whatever ``plot_fn`` is set to.
        """
        import matplotlib.pyplot as plt

        if isinstance(plot_fn, str):
            plot_fn = getattr(self, plot_fn)

        plt.clf()
        plot_fn(**kwargs)
        plt.savefig(image_path, bbox_inches="tight", pad_inches=0)
        plt.close()
