import math

import torch
import torch.nn as nn
import torch.nn.functional as F

from ...configuration_utils import ConfigMixin, register_to_config
from ...models.modeling_utils import ModelMixin


class ResBlock(nn.Module):
    r"""
    Residual stack of paired 1D convolutions.

    For every entry in `dilations` the block holds one dilated convolution (`convs1`) followed by one convolution
    with dilation 1 (`convs2`). Each pair is preceded by leaky ReLU activations and its output is added back onto
    the running input.

    Args:
        channels (`int`):
            Number of input and output channels of every convolution.
        kernel_size (`int`, *optional*, defaults to 3):
            Kernel size shared by all convolutions.
        stride (`int`, *optional*, defaults to 1):
            Stride shared by all convolutions.
        dilations (`tuple[int, ...]`, *optional*, defaults to `(1, 3, 5)`):
            Dilation of the first convolution in each pair; its length sets the number of pairs.
        leaky_relu_negative_slope (`float`, *optional*, defaults to 0.1):
            Negative slope used by every leaky ReLU in the block.
        padding_mode (`str`, *optional*, defaults to `"same"`):
            Forwarded as the `padding` argument of `nn.Conv1d` (despite the name, it is not passed as Conv1d's
            own `padding_mode` argument).
    """

    def __init__(
        self,
        channels: int,
        kernel_size: int = 3,
        stride: int = 1,
        dilations: tuple[int, ...] = (1, 3, 5),
        leaky_relu_negative_slope: float = 0.1,
        padding_mode: str = "same",
    ):
        super().__init__()
        self.dilations = dilations
        self.negative_slope = leaky_relu_negative_slope

        def build_conv(rate: int) -> nn.Conv1d:
            # All convolutions in the block share everything except the dilation rate.
            return nn.Conv1d(
                channels,
                channels,
                kernel_size,
                stride=stride,
                dilation=rate,
                padding=padding_mode,
            )

        # First convolution of each pair: one per requested dilation rate.
        self.convs1 = nn.ModuleList()
        for rate in dilations:
            self.convs1.append(build_conv(rate))

        # Second convolution of each pair: always undilated, same count as `convs1`.
        self.convs2 = nn.ModuleList()
        for _ in range(len(dilations)):
            self.convs2.append(build_conv(1))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        r"""
        Apply every (activation, dilated conv, activation, conv) pair in sequence, adding each pair's output to
        its input.

        Args:
            x (`torch.Tensor`): Input tensor fed to the `nn.Conv1d` layers.

        Returns:
            `torch.Tensor`: The input with all residual branches accumulated onto it.
        """
        slope = self.negative_slope
        for dilated_conv, plain_conv in zip(self.convs1, self.convs2):
            branch = dilated_conv(F.leaky_relu(x, negative_slope=slope))
            branch = plain_conv(F.leaky_relu(branch, negative_slope=slope))
            x = x + branch
        return x


class LTX2Vocoder(ModelMixin, ConfigMixin):
    r"""
    LTX 2.0 vocoder for converting generated mel spectrograms back to audio waveforms.

    The network is an input convolution, a sequence of transposed-convolution upsampling stages (each followed by
    a group of `ResBlock`s whose outputs are averaged), and an output convolution followed by `tanh`.

    Args:
        in_channels (`int`, *optional*, defaults to 128):
            Number of input channels of the first convolution, i.e. the size of the flattened
            (channels x mel bins) dimension seen by `forward`.
        hidden_channels (`int`, *optional*, defaults to 1024):
            Number of channels after the input convolution. The channel count is halved (integer division) at
            every upsampling stage.
        out_channels (`int`, *optional*, defaults to 2):
            Number of channels of the output waveform.
        upsample_kernel_sizes (`list[int]`, *optional*, defaults to `[16, 15, 8, 4, 4]`):
            Kernel size of each transposed convolution.
        upsample_factors (`list[int]`, *optional*, defaults to `[6, 5, 2, 2, 2]`):
            Stride of each transposed convolution. Must have the same length as `upsample_kernel_sizes`.
        resnet_kernel_sizes (`list[int]`, *optional*, defaults to `[3, 7, 11]`):
            Kernel size of each `ResBlock` attached to an upsampling stage.
        resnet_dilations (`list[list[int]]`, *optional*, defaults to `[[1, 3, 5], [1, 3, 5], [1, 3, 5]]`):
            Dilations of each `ResBlock`. Must have the same length as `resnet_kernel_sizes`.
        leaky_relu_negative_slope (`float`, *optional*, defaults to 0.1):
            Negative slope of the leaky ReLUs, except the final one before the output convolution.
        output_sampling_rate (`int`, *optional*, defaults to 24000):
            Not used by the layers built here; presumably recorded in the config via `register_to_config`.
    """

    @register_to_config
    def __init__(
        self,
        in_channels: int = 128,
        hidden_channels: int = 1024,
        out_channels: int = 2,
        upsample_kernel_sizes: list[int] = [16, 15, 8, 4, 4],
        upsample_factors: list[int] = [6, 5, 2, 2, 2],
        resnet_kernel_sizes: list[int] = [3, 7, 11],
        resnet_dilations: list[list[int]] = [[1, 3, 5], [1, 3, 5], [1, 3, 5]],
        leaky_relu_negative_slope: float = 0.1,
        output_sampling_rate: int = 24000,
    ):
        super().__init__()
        self.num_upsample_layers = len(upsample_kernel_sizes)
        self.resnets_per_upsample = len(resnet_kernel_sizes)
        self.out_channels = out_channels
        # Product of all transposed-convolution strides.
        self.total_upsample_factor = math.prod(upsample_factors)
        self.negative_slope = leaky_relu_negative_slope

        if self.num_upsample_layers != len(upsample_factors):
            raise ValueError(
                f"`upsample_kernel_sizes` and `upsample_factors` should be lists of the same length but are length"
                f" {self.num_upsample_layers} and {len(upsample_factors)}, respectively."
            )

        if self.resnets_per_upsample != len(resnet_dilations):
            # `self.resnets_per_upsample` is already an int; calling `len()` on it would raise a `TypeError` and
            # hide the intended `ValueError`.
            raise ValueError(
                f"`resnet_kernel_sizes` and `resnet_dilations` should be lists of the same length but are length"
                f" {self.resnets_per_upsample} and {len(resnet_dilations)}, respectively."
            )

        self.conv_in = nn.Conv1d(in_channels, hidden_channels, kernel_size=7, stride=1, padding=3)

        self.upsamplers = nn.ModuleList()
        # Flat list: the resnets belonging to upsampling stage `i` live at indices
        # [i * resnets_per_upsample, (i + 1) * resnets_per_upsample).
        self.resnets = nn.ModuleList()
        input_channels = hidden_channels
        for stride, kernel_size in zip(upsample_factors, upsample_kernel_sizes):
            output_channels = input_channels // 2
            self.upsamplers.append(
                nn.ConvTranspose1d(
                    input_channels,  # hidden_channels // (2 ** i)
                    output_channels,  # hidden_channels // (2 ** (i + 1))
                    kernel_size,
                    stride=stride,
                    padding=(kernel_size - stride) // 2,
                )
            )

            for resnet_kernel_size, dilations in zip(resnet_kernel_sizes, resnet_dilations):
                self.resnets.append(
                    ResBlock(
                        output_channels,
                        resnet_kernel_size,
                        dilations=dilations,
                        leaky_relu_negative_slope=leaky_relu_negative_slope,
                    )
                )
            input_channels = output_channels

        # `input_channels` holds the channel count after the last upsampling stage and, unlike the loop-local
        # `output_channels`, is also defined when no upsampling stages are configured.
        self.conv_out = nn.Conv1d(input_channels, out_channels, 7, stride=1, padding=3)

    def forward(self, hidden_states: torch.Tensor, time_last: bool = False) -> torch.Tensor:
        r"""
        Forward pass of the vocoder.

        Args:
            hidden_states (`torch.Tensor`):
                Input Mel spectrogram tensor of shape `(batch_size, num_channels, time, num_mel_bins)` if `time_last`
                is `False` (the default) or shape `(batch_size, num_channels, num_mel_bins, time)` if `time_last` is
                `True`.
            time_last (`bool`, *optional*, defaults to `False`):
                Whether the last dimension of the input is the time/frame dimension or the Mel bins dimension.

        Returns:
            `torch.Tensor`:
                Audio waveform tensor of shape (batch_size, out_channels, audio_length)
        """

        # Ensure that the time/frame dimension is last
        if not time_last:
            hidden_states = hidden_states.transpose(2, 3)
        # Combine channels and frequency (mel bins) dimensions
        hidden_states = hidden_states.flatten(1, 2)

        hidden_states = self.conv_in(hidden_states)

        for i in range(self.num_upsample_layers):
            hidden_states = F.leaky_relu(hidden_states, negative_slope=self.negative_slope)
            hidden_states = self.upsamplers[i](hidden_states)

            # Run all resnets of this stage on the same upsampled input and average their outputs
            start = i * self.resnets_per_upsample
            end = (i + 1) * self.resnets_per_upsample
            resnet_outputs = torch.stack([self.resnets[j](hidden_states) for j in range(start, end)], dim=0)

            hidden_states = torch.mean(resnet_outputs, dim=0)

        # NOTE: unlike the first leaky ReLU, this leaky ReLU is set to use the default F.leaky_relu negative slope of
        # 0.01 (whereas the others usually use a slope of 0.1). Not sure if this is intended
        hidden_states = F.leaky_relu(hidden_states, negative_slope=0.01)
        hidden_states = self.conv_out(hidden_states)
        hidden_states = torch.tanh(hidden_states)

        return hidden_states
