# Copyright (c) 2026 SandAI. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Optional, Tuple, Union

import torch
import torch.nn as nn
from diffusers.configuration_utils import ConfigMixin, register_to_config
from diffusers.models.modeling_utils import ModelMixin
from einops import rearrange


__all__ = ["TurboVAED"]

ACT2CLS = {"swish": nn.SiLU, "silu": nn.SiLU, "mish": nn.Mish, "gelu": nn.GELU, "relu": nn.ReLU}


def get_activation(act_fn: str) -> nn.Module:
    act_fn = act_fn.lower()
    if act_fn in ACT2CLS:
        return ACT2CLS[act_fn]()
    else:
        raise ValueError(f"activation function {act_fn} not found in ACT2FN mapping {list(ACT2CLS.keys())}")


def unpatchify(x, patch_size):
    """
    Unpatchify operation: convert patched representation back to original spatial resolution.
    Similar to Wan VAE's unpatchify.

    Args:
        x: Input tensor with shape [batch_size, (channels * patch_size * patch_size), frame, height, width]
        patch_size: The patch size used during patchification

    Returns:
        Tensor with shape [batch_size, channels, frame, height * patch_size, width * patch_size]
    """
    if patch_size == 1:
        return x

    if x.dim() != 5:
        raise ValueError(f"Invalid input shape: {x.shape}")

    # x shape: [batch_size, (channels * patch_size * patch_size), frame, height, width]
    batch_size, c_patches, frames, height, width = x.shape
    channels = c_patches // (patch_size * patch_size)

    # Reshape to [b, c, patch_size, patch_size, f, h, w]
    x = x.view(batch_size, channels, patch_size, patch_size, frames, height, width)

    # Rearrange to [b, c, f, h * patch_size, w * patch_size]
    x = x.permute(0, 1, 4, 5, 3, 6, 2).contiguous()
    x = x.view(batch_size, channels, frames, height * patch_size, width * patch_size)

    return x


class RMSNorm(nn.Module):
    r"""
    RMS Norm as introduced in https://huggingface.co/papers/1910.07467 by Zhang et al.

    Args:
        dim (`int`): Number of dimensions to use for `weights`. Only effective when `elementwise_affine` is True.
        eps (`float`): Small value to use when calculating the reciprocal of the square-root.
        elementwise_affine (`bool`, defaults to `True`):
            Boolean flag to denote if affine transformation should be applied.
        bias (`bool`, defaults to False): If also training the `bias` param.
    """

    def __init__(self, dim, eps: float, elementwise_affine: bool = True, bias: bool = False):
        super().__init__()

        self.eps = eps
        self.elementwise_affine = elementwise_affine

        if isinstance(dim, int):
            dim = (dim,)

        self.dim = torch.Size(dim)

        self.weight = None

        if elementwise_affine:
            self.weight = nn.Parameter(torch.ones(dim))

    def forward(self, hidden_states):
        input_dtype = hidden_states.dtype
        variance = hidden_states.to(torch.float32).pow(2).mean(1, keepdim=True)
        hidden_states = hidden_states * torch.rsqrt(variance + self.eps)

        if self.weight is not None:
            # convert into half-precision if necessary
            if self.weight.dtype in [torch.float16, torch.bfloat16]:
                hidden_states = hidden_states.to(self.weight.dtype)
            hidden_states = hidden_states * self.weight
        else:
            hidden_states = hidden_states.to(input_dtype)

        return hidden_states


class TurboVAEDConv2dSplitUpsampler(nn.Module):
    def __init__(
        self,
        in_channels: int,
        kernel_size: Union[int, Tuple[int, int]] = 3,
        stride: Union[int, Tuple[int, int]] = 1,
        upscale_factor: int = 1,
        padding_mode: str = "zeros",
    ):
        super().__init__()

        self.in_channels = in_channels
        self.stride = stride if isinstance(stride, tuple) else (stride, stride)
        self.upscale_factor = upscale_factor

        out_channels = in_channels

        self.kernel_size = kernel_size if isinstance(kernel_size, tuple) else (kernel_size, kernel_size)

        height_pad = self.kernel_size[0] // 2
        width_pad = self.kernel_size[1] // 2
        padding = (height_pad, width_pad)

        self.conv = nn.Conv2d(
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=self.kernel_size,
            stride=1,
            padding=padding,
            padding_mode=padding_mode,
        )

    @torch.compile
    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        hidden_states = self.conv(hidden_states)
        hidden_states = torch.nn.functional.pixel_shuffle(hidden_states, self.stride[0])

        return hidden_states


class TurboVAEDCausalConv3d(nn.Module):
    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        kernel_size: Union[int, Tuple[int, int, int]] = 3,
        stride: Union[int, Tuple[int, int, int]] = 1,
        dilation: Union[int, Tuple[int, int, int]] = 1,
        groups: int = 1,
        padding_mode: str = "zeros",
        is_causal: bool = False,
    ):
        super().__init__()

        assert is_causal == False
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.is_causal = is_causal
        self.kernel_size = kernel_size if isinstance(kernel_size, tuple) else (kernel_size, kernel_size, kernel_size)

        dilation = dilation if isinstance(dilation, tuple) else (dilation, 1, 1)
        stride = stride if isinstance(stride, tuple) else (stride, stride, stride)
        height_pad = self.kernel_size[1] // 2
        width_pad = self.kernel_size[2] // 2
        padding = (0, height_pad, width_pad)

        self.conv = nn.Conv3d(
            in_channels,
            out_channels,
            self.kernel_size,
            stride=stride,
            dilation=dilation,
            groups=groups,
            padding=padding,
            padding_mode=padding_mode,
        )

    @torch.compile
    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        time_kernel_size = self.kernel_size[0]

        if time_kernel_size > 1:
            pad_left = hidden_states[:, :, :1, :, :].repeat((1, 1, (time_kernel_size - 1) // 2, 1, 1))
            pad_right = hidden_states[:, :, -1:, :, :].repeat((1, 1, (time_kernel_size - 1) // 2, 1, 1))
            hidden_states = torch.cat([pad_left, hidden_states, pad_right], dim=2)

        hidden_states = self.conv(hidden_states)
        return hidden_states


class TurboVAEDCausalDepthwiseSeperableConv3d(nn.Module):
    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        kernel_size: Union[int, Tuple[int, int, int]] = 3,
        stride: Union[int, Tuple[int, int, int]] = 1,
        dilation: Union[int, Tuple[int, int, int]] = 1,
        padding_mode: str = "zeros",
        is_causal: bool = True,
    ):
        super().__init__()

        self.in_channels = in_channels
        self.out_channels = out_channels
        self.is_causal = is_causal
        self.kernel_size = kernel_size if isinstance(kernel_size, tuple) else (kernel_size, kernel_size, kernel_size)
        self.stride = stride if isinstance(stride, tuple) else (stride, stride, stride)
        self.dilation = dilation if isinstance(dilation, tuple) else (dilation, 1, 1)

        # Calculate padding for height and width dimensions
        height_pad = self.kernel_size[1] // 2
        width_pad = self.kernel_size[2] // 2
        self.padding = (0, height_pad, width_pad)

        # Depthwise Convolution
        self.depthwise_conv = nn.Conv3d(
            in_channels,
            in_channels,
            self.kernel_size,
            stride=self.stride,
            dilation=self.dilation,
            groups=in_channels,  # Each input channel is convolved separately
            padding=self.padding,
            padding_mode=padding_mode,
        )

        # Pointwise Convolution
        self.pointwise_conv = nn.Conv3d(in_channels, out_channels, kernel_size=1)  # 1x1x1 convolution to mix channels

    @torch.compile
    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        time_kernel_size = self.kernel_size[0]
        if time_kernel_size > 1:
            pad_count = (time_kernel_size - 1) // 2
            pad_left = hidden_states[:, :, :1, :, :].repeat((1, 1, pad_count, 1, 1))
            pad_right = hidden_states[:, :, -1:, :, :].repeat((1, 1, pad_count, 1, 1))
            hidden_states = torch.cat([pad_left, hidden_states, pad_right], dim=2)

        # Apply depthwise convolution
        hidden_states = self.depthwise_conv(hidden_states)
        # Apply pointwise convolution
        hidden_states = self.pointwise_conv(hidden_states)

        return hidden_states


class TurboVAEDResnetBlock3d(nn.Module):
    r"""
    A 3D ResNet block used in the TurboVAED model.

    Args:
        in_channels (`int`):
            Number of input channels.
        out_channels (`int`, *optional*):
            Number of output channels. If None, defaults to `in_channels`.
        dropout (`float`, defaults to `0.0`):
            Dropout rate.
        eps (`float`, defaults to `1e-6`):
            Epsilon value for normalization layers.
        elementwise_affine (`bool`, defaults to `False`):
            Whether to enable elementwise affinity in the normalization layers.
        non_linearity (`str`, defaults to `"swish"`):
            Activation function to use.
        conv_shortcut (bool, defaults to `False`):
            Whether or not to use a convolution shortcut.
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: Optional[int] = None,
        dropout: float = 0.0,
        eps: float = 1e-6,
        elementwise_affine: bool = False,
        non_linearity: str = "swish",
        is_causal: bool = True,
        is_upsampler_modified: bool = False,
        is_dw_conv: bool = False,
        dw_kernel_size: int = 3,
    ) -> None:
        super().__init__()

        out_channels = out_channels or in_channels

        self.nonlinearity = get_activation(non_linearity)

        self.conv_operation = TurboVAEDCausalConv3d if not is_dw_conv else TurboVAEDCausalDepthwiseSeperableConv3d
        self.kernel_size = 3 if not is_dw_conv else dw_kernel_size

        self.is_upsampler_modified = is_upsampler_modified
        self.replace_nonlinearity = get_activation("relu")

        self.norm1 = RMSNorm(in_channels, eps=1e-8, elementwise_affine=elementwise_affine)
        self.conv1 = self.conv_operation(
            in_channels=in_channels, out_channels=out_channels, kernel_size=self.kernel_size, is_causal=is_causal
        )

        self.norm2 = RMSNorm(out_channels, eps=1e-8, elementwise_affine=elementwise_affine)
        self.dropout = nn.Dropout(dropout)
        self.conv2 = self.conv_operation(
            in_channels=out_channels, out_channels=out_channels, kernel_size=self.kernel_size, is_causal=is_causal
        )

        self.norm3 = None
        self.conv_shortcut = None
        if in_channels != out_channels:
            self.norm3 = RMSNorm(in_channels, eps=eps, elementwise_affine=elementwise_affine)
            self.conv_shortcut = self.conv_operation(
                in_channels=in_channels, out_channels=out_channels, kernel_size=1, stride=1, is_causal=is_causal
            )

    @torch.compile
    def forward(self, inputs: torch.Tensor) -> torch.Tensor:
        hidden_states = inputs

        hidden_states = self.norm1(hidden_states)

        if self.is_upsampler_modified:
            hidden_states = self.replace_nonlinearity(hidden_states)
        else:
            hidden_states = self.nonlinearity(hidden_states)

        hidden_states = self.conv1(hidden_states)

        hidden_states = self.norm2(hidden_states)

        hidden_states = self.nonlinearity(hidden_states)
        hidden_states = self.dropout(hidden_states)

        hidden_states = self.conv2(hidden_states)

        if self.norm3 is not None:
            inputs = self.norm3(inputs)

        if self.conv_shortcut is not None:
            inputs = self.conv_shortcut(inputs)

        hidden_states = hidden_states + inputs
        return hidden_states


class WanUpsample(nn.Upsample):
    r"""
    Perform upsampling while ensuring the output tensor has the same data type as the input.

    Args:
        x (torch.Tensor): Input tensor to be upsampled.

    Returns:
        torch.Tensor: Upsampled tensor with the same data type as the input.
    """

    def forward(self, x):
        return super().forward(x.float()).type_as(x)


class WanResample(nn.Module):
    r"""
    A custom resampling module for 2D and 3D data.

    Args:
        dim (int): The number of input/output channels.
        mode (str): The resampling mode. Must be one of:
            - 'none': No resampling (identity operation).
            - 'upsample2d': 2D upsampling with nearest-exact interpolation and convolution.
            - 'upsample3d': 3D upsampling with nearest-exact interpolation, convolution, and causal 3D convolution.
    """

    def __init__(self, dim: int, mode: str, upsample_out_dim: int = None) -> None:
        super().__init__()
        self.dim = dim
        self.mode = mode

        # default to dim //2
        if upsample_out_dim is None:
            upsample_out_dim = dim // 2

        # layers
        if mode == "upsample2d":
            self.resample = nn.Sequential(
                WanUpsample(scale_factor=(2.0, 2.0), mode="nearest-exact"), nn.Conv2d(dim, upsample_out_dim, 3, padding=1)
            )
        elif mode == "upsample3d":
            self.resample = nn.Sequential(
                WanUpsample(scale_factor=(2.0, 2.0), mode="nearest-exact"), nn.Conv2d(dim, upsample_out_dim, 3, padding=1)
            )
            self.time_conv = TurboVAEDCausalConv3d(dim, dim * 2, (3, 1, 1))
        else:
            self.resample = nn.Identity()

    def forward(self, x, is_first_chunk: bool = True):
        b, c, t, h, w = x.shape
        if self.mode == "upsample3d":
            x = self.time_conv(x)
            x = rearrange(x, 'b (n_split c) t h w -> b c (t n_split) h w', n_split=2)
            assert x.shape == (b, c, t * 2, h, w), "x.shape: {}, expected: {}".format(x.shape, (b, c, t * 2, h, w))
            if is_first_chunk:
                x = x[:, :, 1:]

        x = rearrange(x, "b c t h w -> (b t) c h w")
        x = self.resample(x)
        x = rearrange(x, "(b t) c h w -> b c t h w", b=b)

        return x


class TurboVAEDMidBlock3d(nn.Module):
    r"""
    A middle block used in the TurboVAED model.

    Args:
        in_channels (`int`):
            Number of input channels.
        num_layers (`int`, defaults to `1`):
            Number of resnet layers.
        dropout (`float`, defaults to `0.0`):
            Dropout rate.
        resnet_eps (`float`, defaults to `1e-6`):
            Epsilon value for normalization layers.
        resnet_act_fn (`str`, defaults to `"swish"`):
            Activation function to use.
        is_causal (`bool`, defaults to `True`):
            Whether this layer behaves causally (future frames depend only on past frames) or not.
    """

    _supports_gradient_checkpointing = True

    def __init__(
        self,
        in_channels: int,
        num_layers: int = 1,
        dropout: float = 0.0,
        resnet_eps: float = 1e-6,
        resnet_act_fn: str = "swish",
        is_causal: bool = True,
        is_dw_conv: bool = False,
        dw_kernel_size: int = 3,
    ) -> None:
        super().__init__()

        resnets = []
        for _ in range(num_layers):
            resnets.append(
                TurboVAEDResnetBlock3d(
                    in_channels=in_channels,
                    out_channels=in_channels,
                    dropout=dropout,
                    eps=resnet_eps,
                    non_linearity=resnet_act_fn,
                    is_causal=is_causal,
                    is_dw_conv=is_dw_conv,
                    dw_kernel_size=dw_kernel_size,
                )
            )
        self.resnets = nn.ModuleList(resnets)

        self.gradient_checkpointing = False

    @torch.compile
    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        r"""Forward method of the `LTXMidBlock3D` class."""

        for i, resnet in enumerate(self.resnets):
            if torch.is_grad_enabled() and self.gradient_checkpointing:
                hidden_states = self._gradient_checkpointing_func(resnet, hidden_states)
            else:
                hidden_states = resnet(hidden_states)

        return hidden_states


class TurboVAEDUpBlock3d(nn.Module):
    r"""
    Up block used in the TurboVAED model.

    Args:
        in_channels (`int`):
            Number of input channels.
        out_channels (`int`, *optional*):
            Number of output channels. If None, defaults to `in_channels`.
        num_layers (`int`, defaults to `1`):
            Number of resnet layers.
        dropout (`float`, defaults to `0.0`):
            Dropout rate.
        resnet_eps (`float`, defaults to `1e-6`):
            Epsilon value for normalization layers.
        resnet_act_fn (`str`, defaults to `"swish"`):
            Activation function to use.
        spatio_temporal_scale (`bool`, defaults to `True`):
            Whether or not to use a downsampling layer. If not used, output dimension would be same as input dimension.
            Whether or not to downsample across temporal dimension.
        is_causal (`bool`, defaults to `True`):
            Whether this layer behaves causally (future frames depend only on past frames) or not.
    """

    _supports_gradient_checkpointing = True

    def __init__(
        self,
        in_channels: int,
        out_channels: Optional[int] = None,
        num_layers: int = 1,
        dropout: float = 0.0,
        resnet_eps: float = 1e-6,
        resnet_act_fn: str = "swish",
        spatio_temporal_scale: bool = True,
        is_causal: bool = True,
        is_dw_conv: bool = False,
        dw_kernel_size: int = 3,
        spatio_only: bool = False,
    ):
        super().__init__()

        out_channels = out_channels or in_channels

        self.conv_in = None
        if in_channels != out_channels:
            self.conv_in = TurboVAEDResnetBlock3d(
                in_channels=in_channels,
                out_channels=out_channels,
                dropout=dropout,
                eps=resnet_eps,
                non_linearity=resnet_act_fn,
                is_causal=is_causal,
                is_dw_conv=is_dw_conv,
                dw_kernel_size=dw_kernel_size,
            )

        self.upsamplers = None
        if spatio_temporal_scale:
            self.upsamplers = nn.ModuleList(
                [
                    WanResample(
                        dim=out_channels, mode="upsample2d" if spatio_only else "upsample3d", upsample_out_dim=out_channels
                    )
                ]
            )

        resnets = []
        for _ in range(num_layers):
            resnets.append(
                TurboVAEDResnetBlock3d(
                    in_channels=out_channels,
                    out_channels=out_channels,
                    dropout=dropout,
                    eps=resnet_eps,
                    non_linearity=resnet_act_fn,
                    is_causal=is_causal,
                    is_dw_conv=is_dw_conv,
                    dw_kernel_size=dw_kernel_size,
                    is_upsampler_modified=(spatio_temporal_scale),
                )
            )
        self.resnets = nn.ModuleList(resnets)

        self.gradient_checkpointing = False

    @torch.compile
    def forward(self, hidden_states: torch.Tensor, is_first_chunk: bool) -> torch.Tensor:
        if self.conv_in is not None:
            hidden_states = self.conv_in(hidden_states)

        if self.upsamplers is not None:
            for upsampler in self.upsamplers:
                hidden_states = upsampler(hidden_states, is_first_chunk=is_first_chunk)

        for i, resnet in enumerate(self.resnets):
            if torch.is_grad_enabled() and self.gradient_checkpointing:
                hidden_states = self._gradient_checkpointing_func(resnet, hidden_states)
            else:
                hidden_states = resnet(hidden_states)

        return hidden_states


class TurboVAEDDecoder3d(nn.Module):
    r"""
    The `TurboVAEDDecoder3d` layer of a variational autoencoder that decodes its latent representation into an output
    sample.

    Args:
        in_channels (`int`, defaults to 128):
            Number of latent channels.
        out_channels (`int`, defaults to 3):
            Number of output channels.
        block_out_channels (`Tuple[int, ...]`, defaults to `(128, 256, 512, 512)`):
            The number of output channels for each block.
        spatio_temporal_scaling (`Tuple[bool, ...], defaults to `(True, True, True, False)`:
            Whether a block should contain spatio-temporal upscaling layers or not.
        layers_per_block (`Tuple[int, ...]`, defaults to `(4, 3, 3, 3, 4)`):
            The number of layers per block.
        patch_size (`int`, defaults to `4`):
            The size of spatial patches.
        patch_size_t (`int`, defaults to `1`):
            The size of temporal patches.
        resnet_norm_eps (`float`, defaults to `1e-6`):
            Epsilon value for ResNet normalization layers.
        is_causal (`bool`, defaults to `False`):
            Whether this layer behaves causally (future frames depend only on past frames) or not.
    """

    def __init__(
        self,
        in_channels: int = 128,
        out_channels: int = 3,
        block_out_channels: Tuple[int, ...] = (128, 256, 512, 512),
        spatio_temporal_scaling: Tuple[bool, ...] = (True, True, True, False),
        layers_per_block: Tuple[int, ...] = (4, 3, 3, 3, 4),
        patch_size: int = 4,
        patch_size_t: int = 1,
        resnet_norm_eps: float = 1e-6,
        is_causal: bool = False,
        decoder_is_dw_conv: Tuple[bool, ...] = (False, False, False, False, False),
        decoder_dw_kernel_size: int = 3,
        spatio_only: Tuple[bool, ...] = (False, False, False, False),
        upsampling: bool = False,
        use_unpatchify: bool = False,
    ) -> None:
        super().__init__()

        self.patch_size = patch_size
        self.patch_size_t = patch_size_t
        self.out_channels = out_channels

        self.upsampling = upsampling
        self.use_unpatchify = use_unpatchify

        block_out_channels = tuple(reversed(block_out_channels))
        spatio_temporal_scaling = tuple(reversed(spatio_temporal_scaling))
        layers_per_block = tuple(reversed(layers_per_block))
        decoder_is_dw_conv = tuple(reversed(decoder_is_dw_conv))
        spatio_only = tuple(reversed(spatio_only))
        output_channel = block_out_channels[0]

        self.conv_in = TurboVAEDCausalConv3d(
            in_channels=in_channels, out_channels=output_channel, kernel_size=3, stride=1, is_causal=is_causal
        )

        self.mid_block = TurboVAEDMidBlock3d(
            in_channels=output_channel,
            num_layers=layers_per_block[0],
            resnet_eps=resnet_norm_eps,
            is_causal=is_causal,
            is_dw_conv=decoder_is_dw_conv[0],
            dw_kernel_size=decoder_dw_kernel_size,
        )

        # up blocks
        num_block_out_channels = len(block_out_channels)
        self.up_blocks = nn.ModuleList([])
        for i in range(num_block_out_channels):
            input_channel = output_channel
            output_channel = block_out_channels[i]

            up_block = TurboVAEDUpBlock3d(
                in_channels=input_channel,
                out_channels=output_channel,
                num_layers=layers_per_block[i + 1],
                resnet_eps=resnet_norm_eps,
                spatio_temporal_scale=spatio_temporal_scaling[i],
                is_causal=is_causal,
                is_dw_conv=decoder_is_dw_conv[i + 1],
                dw_kernel_size=decoder_dw_kernel_size,
                spatio_only=spatio_only[i],
            )

            self.up_blocks.append(up_block)

        # out
        assert self.patch_size == 2
        if not self.use_unpatchify:
            self.norm_up_1 = RMSNorm(output_channel, eps=1e-8, elementwise_affine=False)
            self.upsampler2d_1 = TurboVAEDConv2dSplitUpsampler(in_channels=output_channel, kernel_size=3, stride=(2, 2))
            output_channel = output_channel // (2 * 2)

        self.conv_act = nn.SiLU()

        # When use_unpatchify=True, conv_out outputs more channels (out_channels * patch_size^2)
        # and unpatchify will recover the spatial resolution
        conv_out_channels = self.out_channels
        if self.use_unpatchify and self.patch_size >= 2:
            conv_out_channels = self.out_channels * self.patch_size * self.patch_size

        self.conv_out = TurboVAEDCausalConv3d(
            in_channels=output_channel, out_channels=conv_out_channels, kernel_size=3, stride=1, is_causal=is_causal
        )

        self.gradient_checkpointing = False

    @torch.compile
    def forward(self, hidden_states: torch.Tensor, is_first_chunk: bool) -> torch.Tensor:
        hidden_states = self.conv_in(hidden_states)

        if torch.is_grad_enabled() and self.gradient_checkpointing:

            def create_custom_forward(module):
                def create_forward(*inputs):
                    return module(*inputs)

                return create_forward

            hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(self.mid_block), hidden_states)

            for up_block in self.up_blocks:
                hidden_states = torch.utils.checkpoint.checkpoint(
                    create_custom_forward(up_block), hidden_states, is_first_chunk
                )
        else:
            hidden_states = self.mid_block(hidden_states)

            for index, up_block in enumerate(self.up_blocks):
                hidden_states = up_block(hidden_states, is_first_chunk=is_first_chunk)

        if not self.use_unpatchify:
            hidden_states = self.norm_up_1(hidden_states)
            hidden_states = self.conv_act(hidden_states)

            hidden_states_array = []
            for t in range(hidden_states.shape[2]):
                h = self.upsampler2d_1(hidden_states[:, :, t, :, :])
                hidden_states_array.append(h)
            hidden_states = torch.stack(hidden_states_array, dim=2)

        # RMSNorm
        input_dtype = hidden_states.dtype
        variance = hidden_states.to(torch.float32).pow(2).mean(1, keepdim=True)
        hidden_states = hidden_states * torch.rsqrt(variance + 1e-8)
        hidden_states = hidden_states.to(input_dtype)

        hidden_states = self.conv_act(hidden_states)

        hidden_states = self.conv_out(hidden_states)

        if self.use_unpatchify:
            hidden_states = unpatchify(hidden_states, self.patch_size)

        return hidden_states


class TurboVAED(ModelMixin, ConfigMixin):
    _supports_gradient_checkpointing = True

    @register_to_config
    def __init__(
        self,
        in_channels: int = 3,  # useless arg for compatibility, we only use latent channels
        out_channels: int = 3,
        latent_channels: int = 128,
        decoder_block_out_channels: Tuple[int, ...] = (128, 256, 512, 512),
        decoder_layers_per_block: Tuple[int, ...] = (4, 3, 3, 3, 4),
        decoder_spatio_temporal_scaling: Tuple[bool, ...] = (True, True, True, False),
        patch_size: int = 4,
        patch_size_t: int = 1,
        resnet_norm_eps: float = 1e-6,
        scaling_factor: float = 1.0,
        decoder_causal: bool = False,
        decoder_is_dw_conv: Tuple[bool, ...] = (False, False, False, False, False),
        decoder_dw_kernel_size: int = 3,
        decoder_spatio_only: Tuple[bool, ...] = (False, False, False, False),
        first_chunk_size: int = 3,
        step_size: int = 5,
        spatial_compression_ratio: int = 16,
        temporal_compression_ratio: int = 4,
        use_unpatchify: bool = False,
    ):
        super().__init__()

        self.decoder = TurboVAEDDecoder3d(
            in_channels=latent_channels,
            out_channels=out_channels,
            block_out_channels=decoder_block_out_channels,
            spatio_temporal_scaling=decoder_spatio_temporal_scaling,
            layers_per_block=decoder_layers_per_block,
            patch_size=patch_size,
            patch_size_t=patch_size_t,
            resnet_norm_eps=resnet_norm_eps,
            is_causal=decoder_causal,
            decoder_is_dw_conv=decoder_is_dw_conv,
            decoder_dw_kernel_size=decoder_dw_kernel_size,
            spatio_only=decoder_spatio_only,
            use_unpatchify=use_unpatchify,
        )

        self.first_chunk_size = first_chunk_size
        self.step_size = step_size

        self.spatial_compression_ratio = spatial_compression_ratio
        self.temporal_compression_ratio = temporal_compression_ratio

        self.z_dim = latent_channels
        self.mean = torch.tensor(
            [
                -0.2289,
                -0.0052,
                -0.1323,
                -0.2339,
                -0.2799,
                0.0174,
                0.1838,
                0.1557,
                -0.1382,
                0.0542,
                0.2813,
                0.0891,
                0.1570,
                -0.0098,
                0.0375,
                -0.1825,
                -0.2246,
                -0.1207,
                -0.0698,
                0.5109,
                0.2665,
                -0.2108,
                -0.2158,
                0.2502,
                -0.2055,
                -0.0322,
                0.1109,
                0.1567,
                -0.0729,
                0.0899,
                -0.2799,
                -0.1230,
                -0.0313,
                -0.1649,
                0.0117,
                0.0723,
                -0.2839,
                -0.2083,
                -0.0520,
                0.3748,
                0.0152,
                0.1957,
                0.1433,
                -0.2944,
                0.3573,
                -0.0548,
                -0.1681,
                -0.0667,
            ],
            dtype=torch.float32,
            device="cuda",
        )
        self.std = torch.tensor(
            [
                0.4765,
                1.0364,
                0.4514,
                1.1677,
                0.5313,
                0.4990,
                0.4818,
                0.5013,
                0.8158,
                1.0344,
                0.5894,
                1.0901,
                0.6885,
                0.6165,
                0.8454,
                0.4978,
                0.5759,
                0.3523,
                0.7135,
                0.6804,
                0.5833,
                1.4146,
                0.8986,
                0.5659,
                0.7069,
                0.5338,
                0.4889,
                0.4917,
                0.4069,
                0.4999,
                0.6866,
                0.4093,
                0.5709,
                0.6065,
                0.6415,
                0.4944,
                0.5726,
                1.2042,
                0.5458,
                1.6887,
                0.3971,
                1.0600,
                0.3943,
                0.5537,
                0.5444,
                0.4089,
                0.7468,
                0.7744,
            ],
            dtype=torch.float32,
            device="cuda",
        )
        self.scale = [self.mean, 1.0 / self.std]

    def _sliding_window_decode(self, z, output_offload=False):
        z_dtype = z.dtype
        z_device = z.device
        scale = self.scale
        assert isinstance(scale[0], torch.Tensor), "scale[0] must be a tensor"
        z = z / scale[1].view(1, self.z_dim, 1, 1, 1) + scale[0].view(1, self.z_dim, 1, 1, 1)
        z = z.to(z_dtype)

        first_chunk_size = self.first_chunk_size
        step = self.step_size

        # Context mapping: 1 latent frame context -> temporal_compression_ratio pixel frames overlap
        num_overlap_pixel_frames = 1 * self.temporal_compression_ratio

        _, _, num_frames, _, _ = z.shape

        # 1. Pad frames to satisfy chunking requirements
        #     The total number of frames must follow the formula:
        #     num_frames = first_chunk_size + n * step_size
        num_padding_frames = 0

        if num_frames < first_chunk_size:
            # if input is shorter than first_chunk_size
            num_padding_frames = first_chunk_size - num_frames
        elif (num_frames - first_chunk_size) % step != 0:
            num_padding_frames = step - (num_frames - first_chunk_size) % step

        if num_padding_frames > 0:
            z = torch.cat([z, z[:, :, -1:].repeat(1, 1, num_padding_frames, 1, 1)], dim=2)
            num_frames = num_frames + num_padding_frames

        # 2. Decode with overlapping windows
        # Collect chunks on CPU to avoid GPU OOM for high resolution (e.g., 1080P) when output_offload=True
        out_chunks = []

        if num_frames == first_chunk_size:
            # if only one chunk, decode directly
            out = self.decoder(z, is_first_chunk=True)
            out_chunks.append(out.cpu() if output_offload else out)
            del out
        else:
            # first chunk: attach the right frame
            out = self.decoder(z[:, :, 0 : first_chunk_size + 1, :, :], is_first_chunk=True)
            out = out[:, :, :-num_overlap_pixel_frames]
            out_chunks.append(out.cpu() if output_offload else out)
            del out

            # middle chunk: attach the left and right frames
            # last chunk: attach the left frame
            for i in range(first_chunk_size, num_frames, step):
                is_last_chunk = i + step == num_frames
                left = i - 1
                right = i + step + 1 if not is_last_chunk else i + step

                assert left >= 0 and right <= num_frames, f"left: {left}, right: {right}, num_frames: {num_frames}"

                out_ = self.decoder(z[:, :, left:right, :, :], is_first_chunk=False)

                if is_last_chunk:
                    out_ = out_[:, :, num_overlap_pixel_frames:]
                else:
                    out_ = out_[:, :, num_overlap_pixel_frames:-num_overlap_pixel_frames]

                out_chunks.append(out_.cpu() if output_offload else out_)
                del out_

        # Concatenate chunks (on CPU if output_offload, otherwise on GPU)
        out = torch.cat(out_chunks, dim=2)
        del out_chunks

        # 3. Remove padded frames
        if num_padding_frames > 0:
            out = out[:, :, : -num_padding_frames * self.temporal_compression_ratio]

        return out

    def decode(self, z: torch.Tensor, output_offload: bool = False):
        return self._sliding_window_decode(z, output_offload=output_offload)
