from dataclasses import dataclass
from typing import List, Optional, Tuple, Union

import torch
import torch.nn as nn

from ...configuration_utils import ConfigMixin, register_to_config
from ...loaders import FromOriginalModelMixin
from ...utils import BaseOutput, is_torchvision_available, logging
from ..modeling_utils import ModelMixin
from ..transformers.transformer_cosmos import (
    CosmosEmbedding,
    CosmosLearnablePositionalEmbed,
    CosmosPatchEmbed,
    CosmosRotaryPosEmbed,
    CosmosTransformerBlock,
)


# torchvision is an optional dependency: `transforms` is only bound when it is installed. It is used below to
# resize `padding_mask` to the spatial size of the latents.
if is_torchvision_available():
    from torchvision import transforms

# Module-level logger (used e.g. to warn about mismatched conditioning scales).
logger = logging.get_logger(__name__)  # pylint: disable=invalid-name


@dataclass
class CosmosControlNetOutput(BaseOutput):
    """
    Return value of [`CosmosControlNetModel`] when called with `return_dict=True`.

    Args:
        control_block_samples (`list[torch.Tensor]`):
            One activation per control block, each already multiplied by its conditioning scale, to be injected
            into the transformer blocks.
    """

    control_block_samples: list[torch.Tensor]


class CosmosControlNetModel(ModelMixin, ConfigMixin, FromOriginalModelMixin):
    r"""
    ControlNet for Cosmos Transfer2.5.

    This model keeps its own copies of the embedding modules it needs (`patch_embed`, `patch_embed_base`,
    `time_embed`, `learnable_pos_embed`, `img_context_proj`, `crossattn_proj`) instead of sharing them with the
    transformer, so that it can be CPU-offloaded independently. `forward()` computes everything internally from the
    raw inputs and returns one scaled activation per control block.

    Args:
        n_controlnet_blocks (`int`, defaults to `4`):
            Number of `CosmosTransformerBlock`s in the control branch; `forward()` returns one tensor per block.
        in_channels (`int`, defaults to `130`):
            Input channels of `patch_embed`: the control latents (zero-padded up to `in_channels - 2` channels), one
            condition-mask channel and one padding-mask channel.
        latent_channels (`int`, defaults to `18`):
            Input channels of `patch_embed_base`: base latents (+ condition mask, when given) + padding mask.
        model_channels (`int`, defaults to `2048`):
            Hidden size produced by the patch embedders and used by the time embedding.
        num_attention_heads (`int`, defaults to `32`):
            Attention heads per control block.
        attention_head_dim (`int`, defaults to `128`):
            Per-head dimension; also the hidden size given to the rotary embedding.
        mlp_ratio (`float`, defaults to `4.0`):
            MLP expansion ratio of each control block.
        text_embed_dim (`int`, defaults to `1024`):
            Cross-attention dimension of each control block.
        adaln_lora_dim (`int`, defaults to `256`):
            AdaLN-LoRA dimension of each control block.
        patch_size (`Tuple[int, int, int]`, defaults to `(1, 2, 2)`):
            Patch size along (frames, height, width).
        max_size (`Tuple[int, int, int]`, defaults to `(128, 240, 240)`):
            Maximum size used by the rotary and (optional) learnable positional embeddings.
        rope_scale (`Tuple[float, float, float]`, defaults to `(2.0, 1.0, 1.0)`):
            Scale factors forwarded to the rotary embedding.
        extra_pos_embed_type (`str`, *optional*):
            `"learnable"` enables a learnable positional embedding; any other value disables it.
        img_context_dim_in (`int`, *optional*):
            When set and positive, enables image-context projection and image context in the control blocks.
        img_context_dim_out (`int`, defaults to `2048`):
            Output size of the image-context projection.
        use_crossattn_projection (`bool`, defaults to `False`):
            Whether to project the text context with a Linear + GELU before the control blocks.
        crossattn_proj_in_channels (`int`, defaults to `1024`):
            Input size of the text-context projection.
        encoder_hidden_states_channels (`int`, defaults to `1024`):
            Output size of the text-context projection.
    """

    _supports_gradient_checkpointing = True
    _skip_layerwise_casting_patterns = ["patch_embed", "patch_embed_base", "time_embed"]
    _no_split_modules = ["CosmosTransformerBlock"]
    _keep_in_fp32_modules = ["learnable_pos_embed"]

    @register_to_config
    def __init__(
        self,
        n_controlnet_blocks: int = 4,
        in_channels: int = 130,
        latent_channels: int = 18,  # base latent channels (latents + condition_mask) + padding_mask
        model_channels: int = 2048,
        num_attention_heads: int = 32,
        attention_head_dim: int = 128,
        mlp_ratio: float = 4.0,
        text_embed_dim: int = 1024,
        adaln_lora_dim: int = 256,
        patch_size: Tuple[int, int, int] = (1, 2, 2),
        max_size: Tuple[int, int, int] = (128, 240, 240),
        rope_scale: Tuple[float, float, float] = (2.0, 1.0, 1.0),
        extra_pos_embed_type: str | None = None,
        img_context_dim_in: int | None = None,
        img_context_dim_out: int = 2048,
        use_crossattn_projection: bool = False,
        crossattn_proj_in_channels: int = 1024,
        encoder_hidden_states_channels: int = 1024,
    ):
        super().__init__()

        has_img_context = img_context_dim_in is not None and img_context_dim_in > 0

        # Separate patch embedders for the control branch and for the base (noised) latents.
        self.patch_embed = CosmosPatchEmbed(in_channels, model_channels, patch_size, bias=False)

        self.patch_embed_base = CosmosPatchEmbed(latent_channels, model_channels, patch_size, bias=False)
        self.time_embed = CosmosEmbedding(model_channels, model_channels)

        self.learnable_pos_embed = None
        if extra_pos_embed_type == "learnable":
            self.learnable_pos_embed = CosmosLearnablePositionalEmbed(
                hidden_size=model_channels,
                max_size=max_size,
                patch_size=patch_size,
            )

        self.img_context_proj = None
        if has_img_context:
            self.img_context_proj = nn.Sequential(
                nn.Linear(img_context_dim_in, img_context_dim_out, bias=True),
                nn.GELU(),
            )

        # Cross-attention projection for text embeddings (same as transformer)
        self.crossattn_proj = None
        if use_crossattn_projection:
            self.crossattn_proj = nn.Sequential(
                nn.Linear(crossattn_proj_in_channels, encoder_hidden_states_channels, bias=True),
                nn.GELU(),
            )

        # RoPE for both control and base latents
        self.rope = CosmosRotaryPosEmbed(
            hidden_size=attention_head_dim, max_size=max_size, patch_size=patch_size, rope_scale=rope_scale
        )

        self.control_blocks = nn.ModuleList(
            [
                CosmosTransformerBlock(
                    num_attention_heads=num_attention_heads,
                    attention_head_dim=attention_head_dim,
                    cross_attention_dim=text_embed_dim,
                    mlp_ratio=mlp_ratio,
                    adaln_lora_dim=adaln_lora_dim,
                    qk_norm="rms_norm",
                    out_bias=False,
                    img_context=has_img_context,
                    # Only the first block gets an input projection; every block gets an output projection.
                    before_proj=(block_idx == 0),
                    after_proj=True,
                )
                for block_idx in range(n_controlnet_blocks)
            ]
        )

        self.gradient_checkpointing = False

    def _expand_conditioning_scale(
        self, conditioning_scale: float | list[float] | tuple[float, ...]
    ) -> List[float]:
        """
        Return exactly one conditioning scale per control block.

        A scalar is broadcast to every block. A list/tuple whose length differs from the number of blocks is
        repeated cyclically (if shorter) or truncated (if longer), and a warning is logged.

        Raises:
            ValueError: If `conditioning_scale` is an empty list or tuple.
        """
        num_blocks = len(self.control_blocks)
        if not isinstance(conditioning_scale, (list, tuple)):
            return [conditioning_scale] * num_blocks

        scales = list(conditioning_scale)
        if not scales:
            raise ValueError("`conditioning_scale` must contain at least one value, but got an empty sequence.")

        if len(scales) != num_blocks:
            logger.warning(
                "Received %d control scales, but control network defines %d blocks. "
                "Scales will be trimmed or repeated to match.",
                len(scales),
                num_blocks,
            )
            scales = (scales * num_blocks)[:num_blocks]
        return scales

    @staticmethod
    def _concat_padding_mask(hidden_states: torch.Tensor, padding_mask: torch.Tensor | None) -> torch.Tensor:
        """
        Append one padding-mask channel to a `[B, C, T, H, W]` tensor along the channel dimension.

        The mask is resized (nearest neighbour) to `(H, W)` and broadcast over frames. A mask with batch size 1 is
        shared by every sample; a mask with batch size `B` is applied per sample. When `padding_mask` is `None`, an
        all-zeros channel is appended instead.

        Raises:
            ValueError: If the mask batch size is neither 1 nor `B`.
            ImportError: If a mask is given but torchvision is not installed.
        """
        batch_size, _, num_frames, height, width = hidden_states.shape

        if padding_mask is None:
            zeros = hidden_states.new_zeros((batch_size, 1, num_frames, height, width))
            return torch.cat([hidden_states, zeros], dim=1)

        if padding_mask.shape[0] not in (1, batch_size):
            raise ValueError(
                f"`padding_mask` must have batch size 1 or {batch_size}, but got shape {tuple(padding_mask.shape)}."
            )
        if not is_torchvision_available():
            raise ImportError(
                "Resizing `padding_mask` requires torchvision. Please install it with `pip install torchvision`."
            )

        resized = transforms.functional.resize(
            padding_mask, [height, width], interpolation=transforms.InterpolationMode.NEAREST
        )
        # `expand` (not `repeat`) so that a per-sample mask keeps batch size B instead of becoming B * B.
        resized = resized.unsqueeze(2).expand(batch_size, -1, num_frames, -1, -1)
        return torch.cat([hidden_states, resized], dim=1)

    def _prepare_control_hidden_states(
        self,
        controls_latents: torch.Tensor,
        condition_mask: torch.Tensor | None,
        padding_mask: torch.Tensor | None,
    ) -> torch.Tensor:
        """Zero-pad the control latents to `in_channels - 2` channels, then append condition and padding masks."""
        batch_size, num_channels, num_frames, height, width = controls_latents.shape

        # `in_channels` = control channels + 1 condition-mask channel + 1 padding-mask channel.
        num_control_channels = self.config.in_channels - 2
        hidden_states = controls_latents
        if num_channels < num_control_channels:
            zeros = controls_latents.new_zeros(
                (batch_size, num_control_channels - num_channels, num_frames, height, width)
            )
            hidden_states = torch.cat([hidden_states, zeros], dim=1)

        if condition_mask is None:
            condition_mask = torch.zeros_like(controls_latents[:, :1])
        hidden_states = torch.cat([hidden_states, condition_mask], dim=1)

        return self._concat_padding_mask(hidden_states, padding_mask)

    def _embed_timestep(
        self,
        timestep: torch.Tensor,
        base_hidden_states: torch.Tensor,
        latents: torch.Tensor,
        post_patch_size: Tuple[int, int, int],
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Compute `(temb, embedded_timestep)` for a 1-D timestep or a per-frame `[B, 1, T, 1, 1]` timestep.

        For the per-frame case, the embeddings are broadcast over the post-patch spatial grid and flattened to one
        entry per token.

        Raises:
            ValueError: If `timestep` has an unsupported shape.
        """
        if timestep.ndim == 1:
            return self.time_embed(base_hidden_states, timestep)

        if timestep.ndim == 5:
            batch_size, _, num_frames, _, _ = latents.shape
            if timestep.shape != (batch_size, 1, num_frames, 1, 1):
                raise ValueError(f"Expected timestep to have shape [B, 1, T, 1, 1], but got {timestep.shape}")

            post_patch_num_frames, post_patch_height, post_patch_width = post_patch_size
            temb, embedded_timestep = self.time_embed(base_hidden_states, timestep.flatten())
            return tuple(
                x.view(batch_size, post_patch_num_frames, 1, 1, -1)
                .expand(-1, -1, post_patch_height, post_patch_width, -1)
                .flatten(1, 3)
                for x in (temb, embedded_timestep)
            )

        raise ValueError(f"Expected timestep to have shape [B, 1, T, 1, 1] or [T], but got {timestep.shape}")

    def forward(
        self,
        controls_latents: torch.Tensor,
        latents: torch.Tensor,
        timestep: torch.Tensor,
        encoder_hidden_states: Union[Optional[torch.Tensor], Tuple[Optional[torch.Tensor], Optional[torch.Tensor]]],
        condition_mask: torch.Tensor,
        conditioning_scale: float | list[float] = 1.0,
        padding_mask: torch.Tensor | None = None,
        attention_mask: torch.Tensor | None = None,
        fps: int | None = None,
        return_dict: bool = True,
    ) -> Union[CosmosControlNetOutput, Tuple[List[torch.Tensor]]]:
        """
        Forward pass for the ControlNet.

        Args:
            controls_latents: Control signal latents `[B, C, T, H, W]`. Zero-padded along channels up to
                `in_channels - 2` when they have fewer channels.
            latents: Base latents from the noising process `[B, C, T, H, W]`.
            timestep: Diffusion timestep, either 1-D or `[B, 1, T, 1, 1]`.
            encoder_hidden_states: Tuple of `(text_context, img_context)` or just `text_context`.
            condition_mask: Conditioning mask `[B, 1, T, H, W]`. If `None`, zeros are used for the control branch
                and nothing is concatenated to the base latents.
            conditioning_scale: One scale for all control blocks, or a list/tuple with one scale per block.
            padding_mask: Padding mask `[1, 1, H, W]` or `[B, 1, H, W]`. If `None`, an all-zeros mask is used.
            attention_mask: Optional attention mask `[B, S]`, reshaped to `[B, 1, 1, S]`.
            fps: Frames per second forwarded to the rotary embedding, or `None`.
            return_dict: Whether to return a `CosmosControlNetOutput` or a tuple.

        Returns:
            `CosmosControlNetOutput`, or `(control_block_samples,)` when `return_dict=False`, where
            `control_block_samples` holds one tensor per control block multiplied by its conditioning scale.

        Raises:
            ValueError: On an unsupported `timestep` shape, an empty `conditioning_scale` sequence, or a
                `padding_mask` whose batch size is neither 1 nor `B`.
        """
        _, _, num_frames, height, width = controls_latents.shape
        p_t, p_h, p_w = self.config.patch_size
        post_patch_size = (num_frames // p_t, height // p_h, width // p_w)

        # 1. Prepare control latents: channel padding + condition mask + padding mask.
        control_hidden_states = self._prepare_control_hidden_states(controls_latents, condition_mask, padding_mask)

        # 2. Prepare base latents (same processing as transformer.forward)
        base_hidden_states = latents
        if condition_mask is not None:
            base_hidden_states = torch.cat([base_hidden_states, condition_mask], dim=1)
        base_hidden_states = self._concat_padding_mask(base_hidden_states, padding_mask)

        # 3. Generate positional embeddings (shared for both)
        image_rotary_emb = self.rope(control_hidden_states, fps=fps)
        extra_pos_emb = (
            self.learnable_pos_embed(control_hidden_states) if self.learnable_pos_embed is not None else None
        )

        # 4-5. Patchify control and base latents into token sequences.
        control_hidden_states = self.patch_embed(control_hidden_states).flatten(1, 3)
        base_hidden_states = self.patch_embed_base(base_hidden_states).flatten(1, 3)

        # 6. Time embeddings
        temb, embedded_timestep = self._embed_timestep(timestep, base_hidden_states, latents, post_patch_size)

        # 7. Process encoder hidden states
        if isinstance(encoder_hidden_states, tuple):
            text_context, img_context = encoder_hidden_states
        else:
            text_context = encoder_hidden_states
            img_context = None

        if self.crossattn_proj is not None:
            text_context = self.crossattn_proj(text_context)

        if img_context is not None and self.img_context_proj is not None:
            img_context = self.img_context_proj(img_context)

        # Control blocks built with image context expect a `(text, image)` tuple; otherwise just the text.
        if self.config.img_context_dim_in is not None and self.config.img_context_dim_in > 0:
            processed_encoder_hidden_states = (text_context, img_context)
        else:
            processed_encoder_hidden_states = text_context

        # 8. Prepare attention mask
        if attention_mask is not None:
            attention_mask = attention_mask.unsqueeze(1).unsqueeze(1)  # [B, 1, 1, S]

        # 9. Run control blocks; `scales` has exactly one entry per block.
        scales = self._expand_conditioning_scale(conditioning_scale)
        control_block_samples = []
        for block_idx, (block, scale) in enumerate(zip(self.control_blocks, scales)):
            if torch.is_grad_enabled() and self.gradient_checkpointing:
                control_hidden_states, control_proj = self._gradient_checkpointing_func(
                    block,
                    control_hidden_states,
                    processed_encoder_hidden_states,
                    embedded_timestep,
                    temb,
                    image_rotary_emb,
                    extra_pos_emb,
                    attention_mask,
                    None,  # controlnet_residual
                    base_hidden_states,
                    block_idx,
                )
            else:
                control_hidden_states, control_proj = block(
                    hidden_states=control_hidden_states,
                    encoder_hidden_states=processed_encoder_hidden_states,
                    embedded_timestep=embedded_timestep,
                    temb=temb,
                    image_rotary_emb=image_rotary_emb,
                    extra_pos_emb=extra_pos_emb,
                    attention_mask=attention_mask,
                    controlnet_residual=None,
                    latents=base_hidden_states,
                    block_idx=block_idx,
                )
            control_block_samples.append(control_proj * scale)

        if not return_dict:
            return (control_block_samples,)

        return CosmosControlNetOutput(control_block_samples=control_block_samples)
