import logging
import torch
from typing import Optional
from torch import Tensor
from torch.nn.functional import pad

from ..core.transforms_interface import BaseWaveformTransform
from ..utils.dsp import convert_decibels_to_amplitude_ratio
from ..utils.object_dict import ObjectDict


class SpliceOut(BaseWaveformTransform):
    """
    SpliceOut augmentation, proposed in https://arxiv.org/pdf/2110.00046.pdf

    For every example, ``num_time_intervals`` random segments are spliced out one
    after another. A segment of (even-rounded) length ``L`` is replaced by a
    Hann-windowed crossfade of its first and second halves, so each splice shortens
    the audio by ``L // 2`` samples. Silence (zeros) is appended at the end so the
    output keeps the input's length, dtype and device.
    """

    supported_modes = {"per_batch", "per_example"}
    requires_sample_rate = True

    def __init__(
        self,
        num_time_intervals=8,
        max_width=400,
        mode: str = "per_example",
        p: float = 0.5,
        p_mode: Optional[str] = None,
        sample_rate: Optional[int] = None,
        target_rate: Optional[int] = None,
        output_type: Optional[str] = None,
    ):
        """
        :param num_time_intervals: number of segments to splice out of each example
        :param max_width: upper bound, in milliseconds, on the width of each spliced
            segment (widths are drawn from [0, max_width) after conversion to samples)
        :param mode, p, p_mode, sample_rate, target_rate, output_type: forwarded
            unchanged to ``BaseWaveformTransform.__init__``
        """
        super().__init__(
            mode=mode,
            p=p,
            p_mode=p_mode,
            sample_rate=sample_rate,
            target_rate=target_rate,
            output_type=output_type,
        )
        self.num_time_intervals = num_time_intervals
        self.max_width = max_width

    def randomize_parameters(
        self,
        samples: Tensor = None,
        sample_rate: Optional[int] = None,
        targets: Optional[Tensor] = None,
        target_rate: Optional[int] = None,
    ):
        """
        Draw one splice length (in samples) per example and per time interval.

        Stores an integer tensor of shape ``(samples.shape[0], num_time_intervals)``
        under ``self.transform_parameters["splice_lengths"]``.
        """
        # Convert the maximum width from milliseconds to samples. torch.randint
        # needs high > low, so keep at least 1 (i.e. every length is 0) when
        # max_width is shorter than a single sample.
        max_width_in_samples = max(1, int(sample_rate * self.max_width * 1e-3))
        self.transform_parameters["splice_lengths"] = torch.randint(
            low=0,
            high=max_width_in_samples,
            size=(samples.shape[0], self.num_time_intervals),
        )

    @staticmethod
    def _splice_out_once(sample: Tensor, length: int) -> Tensor:
        """
        Splice one random segment of ``length`` samples out of ``sample``.

        ``sample`` is indexed as (channels, time); all channels are spliced at the
        same position. Returns the shortened tensor, or ``sample`` itself when the
        segment is empty or does not fit in what is left of the audio.
        """
        num_frames = sample.shape[-1]
        # Earlier splices shrink the audio; a segment that no longer fits would give
        # torch.randint an empty range and raise, so skip it instead.
        if length <= 0 or length >= num_frames:
            return sample

        start = int(torch.randint(0, num_frames - length, size=(1,)))

        # The crossfade pairs two halves of equal size, so round up to even. This
        # still fits: start <= num_frames - length - 1 before rounding.
        if length % 2 != 0:
            length += 1
        half = length // 2

        # Build the window in the default float dtype, then match the audio dtype so
        # the crossfade does not promote e.g. float16 input to float32.
        window = torch.hann_window(length, device=sample.device).to(sample.dtype)
        rising, falling = window[:half], window[half:]

        fading_out = sample[:, start : start + half]
        fading_in = sample[:, start + half : start + length]
        crossfade = falling * fading_out + rising * fading_in

        return torch.cat(
            (sample[:, :start], crossfade, sample[:, start + length :]),
            dim=-1,
        )

    def apply_transform(
        self,
        samples: Tensor = None,
        sample_rate: Optional[int] = None,
        targets: Optional[Tensor] = None,
        target_rate: Optional[int] = None,
    ) -> ObjectDict:
        """
        Apply the splice lengths drawn by ``randomize_parameters`` to each example.

        ``samples`` must be 3-D, indexed as (batch, channels, time). The stored
        parameters are read but not modified.
        """
        # .tolist() yields plain ints, so rounding lengths up to even inside the
        # helper never writes back into self.transform_parameters.
        splice_lengths = self.transform_parameters["splice_lengths"].tolist()

        # Start from silence: whatever a spliced example does not cover at the end
        # stays zero, which is the padding that restores the original length.
        output = torch.zeros_like(samples)
        for i in range(samples.shape[0]):
            sample = samples[i]
            for length in splice_lengths[i]:
                sample = self._splice_out_once(sample, length)
            output[i, :, : sample.shape[-1]] = sample

        return ObjectDict(
            samples=output,
            sample_rate=sample_rate,
            targets=targets,
            target_rate=target_rate,
        )
