# Copyright 2025 The Lightricks team and The HuggingFace Team.
# All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


import torch
import torch.nn as nn
import torch.nn.functional as F

from ...configuration_utils import ConfigMixin, register_to_config
from ...utils.accelerate_utils import apply_forward_hook
from ..modeling_outputs import AutoencoderKLOutput
from ..modeling_utils import ModelMixin
from .vae import AutoencoderMixin, DecoderOutput, DiagonalGaussianDistribution


LATENT_DOWNSAMPLE_FACTOR = 4


class LTX2AudioCausalConv2d(nn.Module):
    """
    A causal 2D convolution that pads asymmetrically along the causal axis.
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        kernel_size: int | tuple[int, int],
        stride: int = 1,
        dilation: int | tuple[int, int] = 1,
        groups: int = 1,
        bias: bool = True,
        causality_axis: str = "height",
    ) -> None:
        super().__init__()

        self.causality_axis = causality_axis
        kernel_size = (kernel_size, kernel_size) if isinstance(kernel_size, int) else kernel_size
        dilation = (dilation, dilation) if isinstance(dilation, int) else dilation

        pad_h = (kernel_size[0] - 1) * dilation[0]
        pad_w = (kernel_size[1] - 1) * dilation[1]

        if self.causality_axis == "none":
            padding = (pad_w // 2, pad_w - pad_w // 2, pad_h // 2, pad_h - pad_h // 2)
        elif self.causality_axis in {"width", "width-compatibility"}:
            padding = (pad_w, 0, pad_h // 2, pad_h - pad_h // 2)
        elif self.causality_axis == "height":
            padding = (pad_w // 2, pad_w - pad_w // 2, pad_h, 0)
        else:
            raise ValueError(f"Invalid causality_axis: {causality_axis}")

        self.padding = padding
        self.conv = nn.Conv2d(
            in_channels,
            out_channels,
            kernel_size,
            stride=stride,
            padding=0,
            dilation=dilation,
            groups=groups,
            bias=bias,
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = F.pad(x, self.padding)
        return self.conv(x)


class LTX2AudioPixelNorm(nn.Module):
    """
    Per-pixel (per-location) RMS normalization layer.
    """

    def __init__(self, dim: int = 1, eps: float = 1e-8) -> None:
        super().__init__()
        self.dim = dim
        self.eps = eps

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        mean_sq = torch.mean(x**2, dim=self.dim, keepdim=True)
        rms = torch.sqrt(mean_sq + self.eps)
        return x / rms


class LTX2AudioAttnBlock(nn.Module):
    def __init__(
        self,
        in_channels: int,
        norm_type: str = "group",
    ) -> None:
        super().__init__()
        self.in_channels = in_channels

        if norm_type == "group":
            self.norm = nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True)
        elif norm_type == "pixel":
            self.norm = LTX2AudioPixelNorm(dim=1, eps=1e-6)
        else:
            raise ValueError(f"Invalid normalization type: {norm_type}")
        self.q = nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
        self.k = nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
        self.v = nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
        self.proj_out = nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        h_ = self.norm(x)
        q = self.q(h_)
        k = self.k(h_)
        v = self.v(h_)

        batch, channels, height, width = q.shape
        q = q.reshape(batch, channels, height * width).permute(0, 2, 1).contiguous()
        k = k.reshape(batch, channels, height * width).contiguous()
        attn = torch.bmm(q, k) * (int(channels) ** (-0.5))
        attn = torch.nn.functional.softmax(attn, dim=2)

        v = v.reshape(batch, channels, height * width)
        attn = attn.permute(0, 2, 1).contiguous()
        h_ = torch.bmm(v, attn).reshape(batch, channels, height, width)

        h_ = self.proj_out(h_)
        return x + h_


class LTX2AudioResnetBlock(nn.Module):
    def __init__(
        self,
        in_channels: int,
        out_channels: int | None = None,
        conv_shortcut: bool = False,
        dropout: float = 0.0,
        temb_channels: int = 512,
        norm_type: str = "group",
        causality_axis: str = "height",
    ) -> None:
        super().__init__()
        self.causality_axis = causality_axis

        if self.causality_axis is not None and self.causality_axis != "none" and norm_type == "group":
            raise ValueError("Causal ResnetBlock with GroupNorm is not supported.")
        self.in_channels = in_channels
        out_channels = in_channels if out_channels is None else out_channels
        self.out_channels = out_channels
        self.use_conv_shortcut = conv_shortcut

        if norm_type == "group":
            self.norm1 = nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True)
        elif norm_type == "pixel":
            self.norm1 = LTX2AudioPixelNorm(dim=1, eps=1e-6)
        else:
            raise ValueError(f"Invalid normalization type: {norm_type}")
        self.non_linearity = nn.SiLU()
        if causality_axis is not None:
            self.conv1 = LTX2AudioCausalConv2d(
                in_channels, out_channels, kernel_size=3, stride=1, causality_axis=causality_axis
            )
        else:
            self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)
        if temb_channels > 0:
            self.temb_proj = nn.Linear(temb_channels, out_channels)
        if norm_type == "group":
            self.norm2 = nn.GroupNorm(num_groups=32, num_channels=out_channels, eps=1e-6, affine=True)
        elif norm_type == "pixel":
            self.norm2 = LTX2AudioPixelNorm(dim=1, eps=1e-6)
        else:
            raise ValueError(f"Invalid normalization type: {norm_type}")
        self.dropout = nn.Dropout(dropout)
        if causality_axis is not None:
            self.conv2 = LTX2AudioCausalConv2d(
                out_channels, out_channels, kernel_size=3, stride=1, causality_axis=causality_axis
            )
        else:
            self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1)
        if self.in_channels != self.out_channels:
            if self.use_conv_shortcut:
                if causality_axis is not None:
                    self.conv_shortcut = LTX2AudioCausalConv2d(
                        in_channels, out_channels, kernel_size=3, stride=1, causality_axis=causality_axis
                    )
                else:
                    self.conv_shortcut = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)
            else:
                if causality_axis is not None:
                    self.nin_shortcut = LTX2AudioCausalConv2d(
                        in_channels, out_channels, kernel_size=1, stride=1, causality_axis=causality_axis
                    )
                else:
                    self.nin_shortcut = nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0)

    def forward(self, x: torch.Tensor, temb: torch.Tensor | None = None) -> torch.Tensor:
        h = self.norm1(x)
        h = self.non_linearity(h)
        h = self.conv1(h)

        if temb is not None:
            h = h + self.temb_proj(self.non_linearity(temb))[:, :, None, None]

        h = self.norm2(h)
        h = self.non_linearity(h)
        h = self.dropout(h)
        h = self.conv2(h)

        if self.in_channels != self.out_channels:
            x = self.conv_shortcut(x) if self.use_conv_shortcut else self.nin_shortcut(x)

        return x + h


class LTX2AudioDownsample(nn.Module):
    def __init__(self, in_channels: int, with_conv: bool, causality_axis: str | None = "height") -> None:
        super().__init__()
        self.with_conv = with_conv
        self.causality_axis = causality_axis

        if self.with_conv:
            self.conv = torch.nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=2, padding=0)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        if self.with_conv:
            # Padding tuple is in the order: (left, right, top, bottom).
            if self.causality_axis == "none":
                pad = (0, 1, 0, 1)
            elif self.causality_axis == "width":
                pad = (2, 0, 0, 1)
            elif self.causality_axis == "height":
                pad = (0, 1, 2, 0)
            elif self.causality_axis == "width-compatibility":
                pad = (1, 0, 0, 1)
            else:
                raise ValueError(
                    f"Invalid `causality_axis` {self.causality_axis}; supported values are `none`, `width`, `height`,"
                    f" and `width-compatibility`."
                )

            x = F.pad(x, pad, mode="constant", value=0)
            x = self.conv(x)
        else:
            # with_conv=False implies that causality_axis is "none"
            x = F.avg_pool2d(x, kernel_size=2, stride=2)
        return x


class LTX2AudioUpsample(nn.Module):
    def __init__(self, in_channels: int, with_conv: bool, causality_axis: str | None = "height") -> None:
        super().__init__()
        self.with_conv = with_conv
        self.causality_axis = causality_axis
        if self.with_conv:
            if causality_axis is not None:
                self.conv = LTX2AudioCausalConv2d(
                    in_channels, in_channels, kernel_size=3, stride=1, causality_axis=causality_axis
                )
            else:
                self.conv = nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=1, padding=1)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = torch.nn.functional.interpolate(x, scale_factor=2.0, mode="nearest")
        if self.with_conv:
            x = self.conv(x)
            if self.causality_axis is None or self.causality_axis == "none":
                pass
            elif self.causality_axis == "height":
                x = x[:, :, 1:, :]
            elif self.causality_axis == "width":
                x = x[:, :, :, 1:]
            elif self.causality_axis == "width-compatibility":
                pass
            else:
                raise ValueError(f"Invalid causality_axis: {self.causality_axis}")

        return x


class LTX2AudioAudioPatchifier:
    """
    Patchifier for spectrogram/audio latents.
    """

    def __init__(
        self,
        patch_size: int,
        sample_rate: int = 16000,
        hop_length: int = 160,
        audio_latent_downsample_factor: int = 4,
        is_causal: bool = True,
    ):
        self.hop_length = hop_length
        self.sample_rate = sample_rate
        self.audio_latent_downsample_factor = audio_latent_downsample_factor
        self.is_causal = is_causal
        self._patch_size = (1, patch_size, patch_size)

    def patchify(self, audio_latents: torch.Tensor) -> torch.Tensor:
        batch, channels, time, freq = audio_latents.shape
        return audio_latents.permute(0, 2, 1, 3).reshape(batch, time, channels * freq)

    def unpatchify(self, audio_latents: torch.Tensor, channels: int, mel_bins: int) -> torch.Tensor:
        batch, time, _ = audio_latents.shape
        return audio_latents.view(batch, time, channels, mel_bins).permute(0, 2, 1, 3)

    @property
    def patch_size(self) -> tuple[int, int, int]:
        return self._patch_size


class LTX2AudioEncoder(nn.Module):
    def __init__(
        self,
        base_channels: int = 128,
        output_channels: int = 1,
        num_res_blocks: int = 2,
        attn_resolutions: tuple[int, ...] | None = None,
        in_channels: int = 2,
        resolution: int = 256,
        latent_channels: int = 8,
        ch_mult: tuple[int, ...] = (1, 2, 4),
        norm_type: str = "group",
        causality_axis: str | None = "width",
        dropout: float = 0.0,
        mid_block_add_attention: bool = False,
        sample_rate: int = 16000,
        mel_hop_length: int = 160,
        is_causal: bool = True,
        mel_bins: int | None = 64,
        double_z: bool = True,
    ):
        super().__init__()

        self.sample_rate = sample_rate
        self.mel_hop_length = mel_hop_length
        self.is_causal = is_causal
        self.mel_bins = mel_bins

        self.base_channels = base_channels
        self.temb_ch = 0
        self.num_resolutions = len(ch_mult)
        self.num_res_blocks = num_res_blocks
        self.resolution = resolution
        self.in_channels = in_channels
        self.out_ch = output_channels
        self.give_pre_end = False
        self.tanh_out = False
        self.norm_type = norm_type
        self.latent_channels = latent_channels
        self.channel_multipliers = ch_mult
        self.attn_resolutions = attn_resolutions
        self.causality_axis = causality_axis

        base_block_channels = base_channels
        base_resolution = resolution
        self.z_shape = (1, latent_channels, base_resolution, base_resolution)

        if self.causality_axis is not None:
            self.conv_in = LTX2AudioCausalConv2d(
                in_channels, base_block_channels, kernel_size=3, stride=1, causality_axis=self.causality_axis
            )
        else:
            self.conv_in = nn.Conv2d(in_channels, base_block_channels, kernel_size=3, stride=1, padding=1)

        self.down = nn.ModuleList()
        block_in = base_block_channels
        curr_res = self.resolution

        for level in range(self.num_resolutions):
            stage = nn.Module()
            stage.block = nn.ModuleList()
            stage.attn = nn.ModuleList()
            block_out = self.base_channels * self.channel_multipliers[level]

            for _ in range(self.num_res_blocks):
                stage.block.append(
                    LTX2AudioResnetBlock(
                        in_channels=block_in,
                        out_channels=block_out,
                        temb_channels=self.temb_ch,
                        dropout=dropout,
                        norm_type=self.norm_type,
                        causality_axis=self.causality_axis,
                    )
                )
                block_in = block_out
                if self.attn_resolutions:
                    if curr_res in self.attn_resolutions:
                        stage.attn.append(LTX2AudioAttnBlock(block_in, norm_type=self.norm_type))

            if level != self.num_resolutions - 1:
                stage.downsample = LTX2AudioDownsample(block_in, True, causality_axis=self.causality_axis)
                curr_res = curr_res // 2

            self.down.append(stage)

        self.mid = nn.Module()
        self.mid.block_1 = LTX2AudioResnetBlock(
            in_channels=block_in,
            out_channels=block_in,
            temb_channels=self.temb_ch,
            dropout=dropout,
            norm_type=self.norm_type,
            causality_axis=self.causality_axis,
        )
        if mid_block_add_attention:
            self.mid.attn_1 = LTX2AudioAttnBlock(block_in, norm_type=self.norm_type)
        else:
            self.mid.attn_1 = nn.Identity()
        self.mid.block_2 = LTX2AudioResnetBlock(
            in_channels=block_in,
            out_channels=block_in,
            temb_channels=self.temb_ch,
            dropout=dropout,
            norm_type=self.norm_type,
            causality_axis=self.causality_axis,
        )

        final_block_channels = block_in
        z_channels = 2 * latent_channels if double_z else latent_channels
        if self.norm_type == "group":
            self.norm_out = nn.GroupNorm(num_groups=32, num_channels=final_block_channels, eps=1e-6, affine=True)
        elif self.norm_type == "pixel":
            self.norm_out = LTX2AudioPixelNorm(dim=1, eps=1e-6)
        else:
            raise ValueError(f"Invalid normalization type: {self.norm_type}")
        self.non_linearity = nn.SiLU()

        if self.causality_axis is not None:
            self.conv_out = LTX2AudioCausalConv2d(
                final_block_channels, z_channels, kernel_size=3, stride=1, causality_axis=self.causality_axis
            )
        else:
            self.conv_out = nn.Conv2d(final_block_channels, z_channels, kernel_size=3, stride=1, padding=1)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        # hidden_states expected shape: (batch_size, channels, time, num_mel_bins)
        hidden_states = self.conv_in(hidden_states)

        for level in range(self.num_resolutions):
            stage = self.down[level]
            for block_idx, block in enumerate(stage.block):
                hidden_states = block(hidden_states, temb=None)
                if stage.attn:
                    hidden_states = stage.attn[block_idx](hidden_states)

            if level != self.num_resolutions - 1 and hasattr(stage, "downsample"):
                hidden_states = stage.downsample(hidden_states)

        hidden_states = self.mid.block_1(hidden_states, temb=None)
        hidden_states = self.mid.attn_1(hidden_states)
        hidden_states = self.mid.block_2(hidden_states, temb=None)

        hidden_states = self.norm_out(hidden_states)
        hidden_states = self.non_linearity(hidden_states)
        hidden_states = self.conv_out(hidden_states)

        return hidden_states


class LTX2AudioDecoder(nn.Module):
    """
    Symmetric decoder that reconstructs audio spectrograms from latent features.

    The decoder mirrors the encoder structure with configurable channel multipliers, attention resolutions, and causal
    convolutions.
    """

    def __init__(
        self,
        base_channels: int = 128,
        output_channels: int = 1,
        num_res_blocks: int = 2,
        attn_resolutions: tuple[int, ...] | None = None,
        in_channels: int = 2,
        resolution: int = 256,
        latent_channels: int = 8,
        ch_mult: tuple[int, ...] = (1, 2, 4),
        norm_type: str = "group",
        causality_axis: str | None = "width",
        dropout: float = 0.0,
        mid_block_add_attention: bool = False,
        sample_rate: int = 16000,
        mel_hop_length: int = 160,
        is_causal: bool = True,
        mel_bins: int | None = 64,
    ) -> None:
        super().__init__()

        self.sample_rate = sample_rate
        self.mel_hop_length = mel_hop_length
        self.is_causal = is_causal
        self.mel_bins = mel_bins
        self.patchifier = LTX2AudioAudioPatchifier(
            patch_size=1,
            audio_latent_downsample_factor=LATENT_DOWNSAMPLE_FACTOR,
            sample_rate=sample_rate,
            hop_length=mel_hop_length,
            is_causal=is_causal,
        )

        self.base_channels = base_channels
        self.temb_ch = 0
        self.num_resolutions = len(ch_mult)
        self.num_res_blocks = num_res_blocks
        self.resolution = resolution
        self.in_channels = in_channels
        self.out_ch = output_channels
        self.give_pre_end = False
        self.tanh_out = False
        self.norm_type = norm_type
        self.latent_channels = latent_channels
        self.channel_multipliers = ch_mult
        self.attn_resolutions = attn_resolutions
        self.causality_axis = causality_axis

        base_block_channels = base_channels * self.channel_multipliers[-1]
        base_resolution = resolution // (2 ** (self.num_resolutions - 1))
        self.z_shape = (1, latent_channels, base_resolution, base_resolution)

        if self.causality_axis is not None:
            self.conv_in = LTX2AudioCausalConv2d(
                latent_channels, base_block_channels, kernel_size=3, stride=1, causality_axis=self.causality_axis
            )
        else:
            self.conv_in = nn.Conv2d(latent_channels, base_block_channels, kernel_size=3, stride=1, padding=1)
        self.non_linearity = nn.SiLU()
        self.mid = nn.Module()
        self.mid.block_1 = LTX2AudioResnetBlock(
            in_channels=base_block_channels,
            out_channels=base_block_channels,
            temb_channels=self.temb_ch,
            dropout=dropout,
            norm_type=self.norm_type,
            causality_axis=self.causality_axis,
        )
        if mid_block_add_attention:
            self.mid.attn_1 = LTX2AudioAttnBlock(base_block_channels, norm_type=self.norm_type)
        else:
            self.mid.attn_1 = nn.Identity()
        self.mid.block_2 = LTX2AudioResnetBlock(
            in_channels=base_block_channels,
            out_channels=base_block_channels,
            temb_channels=self.temb_ch,
            dropout=dropout,
            norm_type=self.norm_type,
            causality_axis=self.causality_axis,
        )

        self.up = nn.ModuleList()
        block_in = base_block_channels
        curr_res = self.resolution // (2 ** (self.num_resolutions - 1))

        for level in reversed(range(self.num_resolutions)):
            stage = nn.Module()
            stage.block = nn.ModuleList()
            stage.attn = nn.ModuleList()
            block_out = self.base_channels * self.channel_multipliers[level]

            for _ in range(self.num_res_blocks + 1):
                stage.block.append(
                    LTX2AudioResnetBlock(
                        in_channels=block_in,
                        out_channels=block_out,
                        temb_channels=self.temb_ch,
                        dropout=dropout,
                        norm_type=self.norm_type,
                        causality_axis=self.causality_axis,
                    )
                )
                block_in = block_out
                if self.attn_resolutions:
                    if curr_res in self.attn_resolutions:
                        stage.attn.append(LTX2AudioAttnBlock(block_in, norm_type=self.norm_type))

            if level != 0:
                stage.upsample = LTX2AudioUpsample(block_in, True, causality_axis=self.causality_axis)
                curr_res *= 2

            self.up.insert(0, stage)

        final_block_channels = block_in

        if self.norm_type == "group":
            self.norm_out = nn.GroupNorm(num_groups=32, num_channels=final_block_channels, eps=1e-6, affine=True)
        elif self.norm_type == "pixel":
            self.norm_out = LTX2AudioPixelNorm(dim=1, eps=1e-6)
        else:
            raise ValueError(f"Invalid normalization type: {self.norm_type}")

        if self.causality_axis is not None:
            self.conv_out = LTX2AudioCausalConv2d(
                final_block_channels, output_channels, kernel_size=3, stride=1, causality_axis=self.causality_axis
            )
        else:
            self.conv_out = nn.Conv2d(final_block_channels, output_channels, kernel_size=3, stride=1, padding=1)

    def forward(
        self,
        sample: torch.Tensor,
    ) -> torch.Tensor:
        _, _, frames, mel_bins = sample.shape

        target_frames = frames * LATENT_DOWNSAMPLE_FACTOR

        if self.causality_axis is not None:
            target_frames = max(target_frames - (LATENT_DOWNSAMPLE_FACTOR - 1), 1)

        target_channels = self.out_ch
        target_mel_bins = self.mel_bins if self.mel_bins is not None else mel_bins

        hidden_features = self.conv_in(sample)
        hidden_features = self.mid.block_1(hidden_features, temb=None)
        hidden_features = self.mid.attn_1(hidden_features)
        hidden_features = self.mid.block_2(hidden_features, temb=None)

        for level in reversed(range(self.num_resolutions)):
            stage = self.up[level]
            for block_idx, block in enumerate(stage.block):
                hidden_features = block(hidden_features, temb=None)
                if stage.attn:
                    hidden_features = stage.attn[block_idx](hidden_features)

            if level != 0 and hasattr(stage, "upsample"):
                hidden_features = stage.upsample(hidden_features)

        if self.give_pre_end:
            return hidden_features

        hidden = self.norm_out(hidden_features)
        hidden = self.non_linearity(hidden)
        decoded_output = self.conv_out(hidden)
        decoded_output = torch.tanh(decoded_output) if self.tanh_out else decoded_output

        _, _, current_time, current_freq = decoded_output.shape
        target_time = target_frames
        target_freq = target_mel_bins

        decoded_output = decoded_output[
            :, :target_channels, : min(current_time, target_time), : min(current_freq, target_freq)
        ]

        time_padding_needed = target_time - decoded_output.shape[2]
        freq_padding_needed = target_freq - decoded_output.shape[3]

        if time_padding_needed > 0 or freq_padding_needed > 0:
            padding = (
                0,
                max(freq_padding_needed, 0),
                0,
                max(time_padding_needed, 0),
            )
            decoded_output = F.pad(decoded_output, padding)

        decoded_output = decoded_output[:, :target_channels, :target_time, :target_freq]

        return decoded_output


class AutoencoderKLLTX2Audio(ModelMixin, AutoencoderMixin, ConfigMixin):
    r"""
    LTX2 audio VAE for encoding and decoding audio latent representations.
    """

    _supports_gradient_checkpointing = False

    @register_to_config
    def __init__(
        self,
        base_channels: int = 128,
        output_channels: int = 2,
        ch_mult: tuple[int, ...] = (1, 2, 4),
        num_res_blocks: int = 2,
        attn_resolutions: tuple[int, ...] | None = None,
        in_channels: int = 2,
        resolution: int = 256,
        latent_channels: int = 8,
        norm_type: str = "pixel",
        causality_axis: str | None = "height",
        dropout: float = 0.0,
        mid_block_add_attention: bool = False,
        sample_rate: int = 16000,
        mel_hop_length: int = 160,
        is_causal: bool = True,
        mel_bins: int | None = 64,
        double_z: bool = True,
    ) -> None:
        super().__init__()

        supported_causality_axes = {"none", "width", "height", "width-compatibility"}
        if causality_axis not in supported_causality_axes:
            raise ValueError(f"{causality_axis=} is not valid. Supported values: {supported_causality_axes}")

        attn_resolution_set = set(attn_resolutions) if attn_resolutions else attn_resolutions

        self.encoder = LTX2AudioEncoder(
            base_channels=base_channels,
            output_channels=output_channels,
            ch_mult=ch_mult,
            num_res_blocks=num_res_blocks,
            attn_resolutions=attn_resolution_set,
            in_channels=in_channels,
            resolution=resolution,
            latent_channels=latent_channels,
            norm_type=norm_type,
            causality_axis=causality_axis,
            dropout=dropout,
            mid_block_add_attention=mid_block_add_attention,
            sample_rate=sample_rate,
            mel_hop_length=mel_hop_length,
            is_causal=is_causal,
            mel_bins=mel_bins,
            double_z=double_z,
        )

        self.decoder = LTX2AudioDecoder(
            base_channels=base_channels,
            output_channels=output_channels,
            ch_mult=ch_mult,
            num_res_blocks=num_res_blocks,
            attn_resolutions=attn_resolution_set,
            in_channels=in_channels,
            resolution=resolution,
            latent_channels=latent_channels,
            norm_type=norm_type,
            causality_axis=causality_axis,
            dropout=dropout,
            mid_block_add_attention=mid_block_add_attention,
            sample_rate=sample_rate,
            mel_hop_length=mel_hop_length,
            is_causal=is_causal,
            mel_bins=mel_bins,
        )

        # Per-channel statistics for normalizing and denormalizing the latent representation. This statics is computed over
        # the entire dataset and stored in model's checkpoint under AudioVAE state_dict
        latents_std = torch.ones((base_channels,))
        latents_mean = torch.zeros((base_channels,))
        self.register_buffer("latents_mean", latents_mean, persistent=True)
        self.register_buffer("latents_std", latents_std, persistent=True)

        # TODO: calculate programmatically instead of hardcoding
        self.temporal_compression_ratio = LATENT_DOWNSAMPLE_FACTOR  # 4
        # TODO: confirm whether the mel compression ratio below is correct
        self.mel_compression_ratio = LATENT_DOWNSAMPLE_FACTOR
        self.use_slicing = False

    def _encode(self, x: torch.Tensor) -> torch.Tensor:
        return self.encoder(x)

    @apply_forward_hook
    def encode(self, x: torch.Tensor, return_dict: bool = True):
        if self.use_slicing and x.shape[0] > 1:
            encoded_slices = [self._encode(x_slice) for x_slice in x.split(1)]
            h = torch.cat(encoded_slices)
        else:
            h = self._encode(x)
        posterior = DiagonalGaussianDistribution(h)

        if not return_dict:
            return (posterior,)
        return AutoencoderKLOutput(latent_dist=posterior)

    def _decode(self, z: torch.Tensor) -> torch.Tensor:
        return self.decoder(z)

    @apply_forward_hook
    def decode(self, z: torch.Tensor, return_dict: bool = True) -> DecoderOutput | torch.Tensor:
        if self.use_slicing and z.shape[0] > 1:
            decoded_slices = [self._decode(z_slice) for z_slice in z.split(1)]
            decoded = torch.cat(decoded_slices)
        else:
            decoded = self._decode(z)

        if not return_dict:
            return (decoded,)

        return DecoderOutput(sample=decoded)

    def forward(
        self,
        sample: torch.Tensor,
        sample_posterior: bool = False,
        return_dict: bool = True,
        generator: torch.Generator | None = None,
    ) -> DecoderOutput | torch.Tensor:
        posterior = self.encode(sample).latent_dist
        if sample_posterior:
            z = posterior.sample(generator=generator)
        else:
            z = posterior.mode()
        dec = self.decode(z)
        if not return_dict:
            return (dec.sample,)
        return dec
