# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

"""
Model-specific extractors for TeaCache.

This module provides a registry of extractor functions that know how to extract
modulated inputs from different transformer architectures. Adding support for
a new model requires only adding a new extractor function to the registry.

With Option B enhancement, extractors now return a CacheContext object containing
all model-specific information needed for generic caching, including preprocessing,
transformer execution, and postprocessing logic.
"""

from collections.abc import Callable
from dataclasses import dataclass
from typing import Any

import torch
import torch.nn as nn

from vllm_omni.diffusion.forward_context import get_forward_context


@dataclass
class CacheContext:
    """
    Context object containing all model-specific information for caching.

    This allows the TeaCacheHook to remain completely generic - all model-specific
    logic is encapsulated in the extractor that returns this context.

    Attributes:
        modulated_input: Tensor used for cache decision (similarity comparison).
            Must be a torch.Tensor extracted from the first transformer block,
            typically after applying normalization and modulation.

        hidden_states: Current hidden states (will be modified by caching).
            Must be a torch.Tensor representing the main image/latent states
            after preprocessing but before transformer blocks.

        encoder_hidden_states: Optional encoder states (for dual-stream models).
            Set to None for single-stream models (e.g., Flux).
            For dual-stream models (e.g., Qwen), contains text encoder outputs.

        temb: Timestep embedding tensor.
            Must be a torch.Tensor containing the timestep conditioning.

        run_transformer_blocks: Callable that executes model-specific transformer blocks.
            Signature: () -> tuple[torch.Tensor, ...]

            Returns:
                tuple containing:
                - [0]: processed hidden_states (required)
                - [1]: processed encoder_hidden_states (optional, only for dual-stream)

            Example for single-stream:
                def run_blocks():
                    h = hidden_states
                    for block in module.transformer_blocks:
                        h = block(h, temb=temb)
                    return (h,)

            Example for dual-stream:
                def run_blocks():
                    h, e = hidden_states, encoder_hidden_states
                    for block in module.transformer_blocks:
                        e, h = block(h, e, temb=temb)
                    return (h, e)

        postprocess: Callable that does model-specific output postprocessing.
            Signature: (torch.Tensor) -> Union[torch.Tensor, Transformer2DModelOutput, tuple]

            Takes the processed hidden_states and applies final transformations
            (normalization, projection) to produce the model output.

            Example:
                def postprocess(h):
                    h = module.norm_out(h, temb)
                    output = module.proj_out(h)
                    return Transformer2DModelOutput(sample=output)

        extra_states: Optional dict for additional model-specific state.
            Use this for models that need to pass additional context beyond
            the standard fields.
    """

    modulated_input: torch.Tensor
    hidden_states: torch.Tensor
    encoder_hidden_states: torch.Tensor | None
    temb: torch.Tensor
    run_transformer_blocks: Callable[[], tuple[torch.Tensor, ...]]
    postprocess: Callable[[torch.Tensor], Any]
    extra_states: dict[str, Any] | None = None

    def validate(self) -> None:
        """
        Validate that the CacheContext contains valid data.

        Raises:
            TypeError: If fields have wrong types
            ValueError: If tensors have invalid properties
            RuntimeError: If callables fail basic invocation tests

        This method should be called after creating a CacheContext to catch
        common developer errors early with clear error messages.
        """
        # Validate tensor fields
        if not isinstance(self.modulated_input, torch.Tensor):
            raise TypeError(f"modulated_input must be torch.Tensor, got {type(self.modulated_input)}")

        if not isinstance(self.hidden_states, torch.Tensor):
            raise TypeError(f"hidden_states must be torch.Tensor, got {type(self.hidden_states)}")

        if self.encoder_hidden_states is not None and not isinstance(self.encoder_hidden_states, torch.Tensor):
            raise TypeError(
                f"encoder_hidden_states must be torch.Tensor or None, got {type(self.encoder_hidden_states)}"
            )

        if not isinstance(self.temb, torch.Tensor):
            raise TypeError(f"temb must be torch.Tensor, got {type(self.temb)}")

        # Validate callables
        if not callable(self.run_transformer_blocks):
            raise TypeError(f"run_transformer_blocks must be callable, got {type(self.run_transformer_blocks)}")

        if not callable(self.postprocess):
            raise TypeError(f"postprocess must be callable, got {type(self.postprocess)}")

        # Validate tensor shapes are compatible
        if self.modulated_input.shape[0] != self.hidden_states.shape[0]:
            raise ValueError(
                f"Batch size mismatch: modulated_input has batch size "
                f"{self.modulated_input.shape[0]}, but hidden_states has "
                f"{self.hidden_states.shape[0]}"
            )

        # Validate devices match
        if self.modulated_input.device != self.hidden_states.device:
            raise ValueError(
                f"Device mismatch: modulated_input on {self.modulated_input.device}, "
                f"hidden_states on {self.hidden_states.device}"
            )


def extract_qwen_context(
    module: nn.Module,
    hidden_states: torch.Tensor,
    encoder_hidden_states: torch.Tensor,
    encoder_hidden_states_mask: torch.Tensor,
    timestep: torch.Tensor | float | int,
    img_shapes: torch.Tensor,
    txt_seq_lens: torch.Tensor,
    guidance: torch.Tensor | None = None,
    additional_t_cond: torch.Tensor | None = None,
    attention_kwargs: dict[str, Any] | None = None,
    **kwargs: Any,
) -> CacheContext:
    """
    Extract cache context for QwenImageTransformer2DModel.

    This is the ONLY Qwen-specific code needed for TeaCache support.
    It encapsulates preprocessing, modulated input extraction, transformer execution,
    and postprocessing logic.

    Args:
        module: QwenImageTransformer2DModel instance
        hidden_states: Input hidden states tensor
        encoder_hidden_states: Text encoder outputs
        encoder_hidden_states_mask: Mask for text encoder
        timestep: Current diffusion timestep
        img_shapes: Image shapes for position embedding
        txt_seq_lens: Text sequence lengths
        guidance: Optional guidance scale for CFG
        additional_t_cond: Optional additional timestep conditioning
        attention_kwargs: Additional attention arguments
        **kwargs: Additional keyword arguments ignored by this extractor

    Returns:
        CacheContext with all information needed for generic caching
    """
    from diffusers.models.modeling_outputs import Transformer2DModelOutput

    if not hasattr(module, "transformer_blocks") or len(module.transformer_blocks) == 0:
        raise ValueError("Module must have transformer_blocks")

    # ============================================================================
    # PREPROCESSING (Qwen-specific)
    # ============================================================================
    hidden_states = module.img_in(hidden_states)
    timestep = timestep.to(device=hidden_states.device, dtype=hidden_states.dtype)
    encoder_hidden_states = module.txt_norm(encoder_hidden_states)
    encoder_hidden_states = module.txt_in(encoder_hidden_states)

    if guidance is not None:
        guidance = guidance.to(hidden_states.dtype) * 1000

    temb = (
        module.time_text_embed(timestep, hidden_states, additional_t_cond)
        if guidance is None
        else module.time_text_embed(timestep, guidance, hidden_states, additional_t_cond)
    )

    image_rotary_emb = module.pos_embed(img_shapes, txt_seq_lens, device=hidden_states.device)

    # ============================================================================
    # EXTRACT MODULATED INPUT (for cache decision)
    # ============================================================================
    block = module.transformer_blocks[0]
    img_mod_params = block.img_mod(temb)
    img_mod1, _ = img_mod_params.chunk(2, dim=-1)
    img_modulated, _ = block.img_norm1(hidden_states, img_mod1)

    # ============================================================================
    # DEFINE TRANSFORMER EXECUTION (Qwen-specific)
    # ============================================================================
    def run_transformer_blocks():
        """Execute all Qwen transformer blocks."""
        h = hidden_states
        e = encoder_hidden_states
        encoder_mask = encoder_hidden_states_mask
        hidden_states_mask = None  # default
        if module.parallel_config is not None and module.parallel_config.sequence_parallel_size > 1:
            ctx = get_forward_context()
            if ctx.sp_original_seq_len is not None and ctx.sp_padding_size > 0:
                # Create mask for the full (padded) sequence
                # valid positions = True, padding positions = False
                batch_size = hidden_states.shape[0]
                padded_seq_len = ctx.sp_original_seq_len + ctx.sp_padding_size
                hidden_states_mask = torch.ones(
                    batch_size,
                    padded_seq_len,
                    dtype=torch.bool,
                    device=hidden_states.device,
                )
                hidden_states_mask[:, ctx.sp_original_seq_len :] = False

        # if mask is all true, set it to None
        if hidden_states_mask is not None and hidden_states_mask.all():
            hidden_states_mask = None
        if encoder_mask is not None and encoder_mask.all():
            encoder_mask = None
        for block in module.transformer_blocks:
            e, h = block(
                hidden_states=h,
                encoder_hidden_states=e,
                encoder_hidden_states_mask=encoder_mask,
                temb=temb,
                image_rotary_emb=image_rotary_emb,
                joint_attention_kwargs=attention_kwargs,
                hidden_states_mask=hidden_states_mask,
            )
        return (h, e)

    # ============================================================================
    # DEFINE POSTPROCESSING (Qwen-specific)
    # ============================================================================
    return_dict = kwargs.get("return_dict", True)

    def postprocess(h):
        """Apply Qwen-specific output postprocessing."""
        h = module.norm_out(h, temb)
        output = module.proj_out(h)
        if not return_dict:
            return (output,)
        return Transformer2DModelOutput(sample=output)

    # ============================================================================
    # RETURN CONTEXT
    # ============================================================================
    return CacheContext(
        modulated_input=img_modulated,
        hidden_states=hidden_states,
        encoder_hidden_states=encoder_hidden_states,
        temb=temb,
        run_transformer_blocks=run_transformer_blocks,
        postprocess=postprocess,
    )


def extract_bagel_context(
    module: nn.Module,
    x_t: torch.Tensor,
    timestep: torch.Tensor | float | int,
    packed_vae_token_indexes: torch.LongTensor,
    packed_vae_position_ids: torch.LongTensor,
    packed_text_ids: torch.LongTensor,
    packed_text_indexes: torch.LongTensor,
    packed_indexes: torch.LongTensor,
    packed_position_ids: torch.LongTensor,
    packed_seqlens: torch.IntTensor,
    key_values_lens: torch.IntTensor,
    past_key_values: Any,
    packed_key_value_indexes: torch.LongTensor,
    **kwargs: Any,
) -> CacheContext:
    """
    Extract cache context for Bagel model.

    Args:
        module: Bagel instance
        x_t: Latent image input
        timestep: Current timestep
        packed_vae_token_indexes: Indexes for VAE tokens in packed sequence
        packed_vae_position_ids: Position IDs for VAE tokens
        packed_text_ids: Text token IDs
        packed_text_indexes: Indexes for text tokens in packed sequence
        packed_indexes: Global indexes
        packed_position_ids: Global position IDs
        packed_seqlens: Sequence lengths
        key_values_lens: KV cache lengths
        past_key_values: KV cache
        packed_key_value_indexes: KV cache indexes
        **kwargs: Additional keyword arguments

    Returns:
        CacheContext with all information needed for generic caching
    """

    # 1. Embed text
    packed_text_embedding = module.language_model.model.embed_tokens(packed_text_ids)
    packed_sequence = packed_text_embedding.new_zeros((sum(packed_seqlens), module.hidden_size))
    packed_sequence[packed_text_indexes] = packed_text_embedding

    # 2. Embed timestep
    if not isinstance(timestep, torch.Tensor):
        timestep = torch.tensor([timestep], device=x_t.device)
    if timestep.dim() == 0:
        timestep = timestep.unsqueeze(0)

    # 3. Embed image (x_t)
    packed_pos_embed = module.latent_pos_embed(packed_vae_position_ids)
    packed_timestep_embeds = module.time_embedder(timestep)

    x_t_emb = module.vae2llm(x_t) + packed_timestep_embeds + packed_pos_embed
    if x_t_emb.dtype != packed_sequence.dtype:
        x_t_emb = x_t_emb.to(packed_sequence.dtype)

    packed_sequence[packed_vae_token_indexes] = x_t_emb

    # Use the full packed sequence as modulated input to match hidden_states size
    modulated_input = packed_sequence

    def run_transformer_blocks():
        extra_inputs = {}
        if module.use_moe:
            extra_inputs = {
                "mode": "gen",
                "packed_vae_token_indexes": packed_vae_token_indexes,
                "packed_text_indexes": packed_text_indexes,
            }

        output = module.language_model.forward(
            packed_query_sequence=packed_sequence,
            query_lens=packed_seqlens,
            packed_query_position_ids=packed_position_ids,
            packed_query_indexes=packed_indexes,
            past_key_values=past_key_values,
            key_values_lens=key_values_lens,
            packed_key_value_indexes=packed_key_value_indexes,
            update_past_key_values=False,
            is_causal=False,
            **extra_inputs,
        )
        return (output.packed_query_sequence,)

    def postprocess(h):
        v_t = module.llm2vae(h)
        v_t = v_t[packed_vae_token_indexes]
        return v_t

    return CacheContext(
        modulated_input=modulated_input,
        hidden_states=packed_sequence,  # Use full packed sequence
        encoder_hidden_states=None,
        temb=packed_timestep_embeds,  # Approximate
        run_transformer_blocks=run_transformer_blocks,
        postprocess=postprocess,
    )


def extract_zimage_context(
    module: nn.Module,
    x: list[torch.Tensor],
    t: torch.Tensor,
    cap_feats: list[torch.Tensor],
    patch_size: int = 2,
    f_patch_size: int = 1,
    **kwargs: Any,
) -> CacheContext:
    """
    Extract cache context for ZImageTransformer2DModel.

    This is the ONLY Z-Image-specific code needed for TeaCache support.
    It encapsulates preprocessing, modulated input extraction, transformer execution,
    and postprocessing logic.

    Args:
        module: ZImageTransformer2DModel instance
        x: List of image tensors per batch item
        t: Timestep tensor
        cap_feats: List of caption feature tensors per batch item
        patch_size: Patch size for patchification (default: 2)
        f_patch_size: Frame patch size (default: 1)
        **kwargs: Additional keyword arguments ignored by this extractor

    Returns:
        CacheContext with all information needed for generic caching
    """
    from torch.nn.utils.rnn import pad_sequence

    if not hasattr(module, "layers") or len(module.layers) == 0:
        raise ValueError("Module must have main transformer layers")

    bsz = len(x)
    device = x[0].device

    # ============================================================================
    # PREPROCESSING (Z-Image specific)
    # ============================================================================
    # Scale timestep and create timestep embedding
    t_scaled = t * module.t_scale
    adaln_input = module.t_embedder(t_scaled)

    # Patchify and embed inputs
    (
        x_patches,
        cap_feats_processed,
        x_size,
        x_pos_ids,
        cap_pos_ids,
        x_inner_pad_mask,
        cap_inner_pad_mask,
    ) = module.patchify_and_embed(x, cap_feats, patch_size, f_patch_size)

    # Process image patches through embedder and noise refiner
    x_item_seqlens = [len(_) for _ in x_patches]
    x_max_item_seqlen = max(x_item_seqlens)

    x_embedded = torch.cat(x_patches, dim=0)
    x_embedded = module.all_x_embedder[f"{patch_size}-{f_patch_size}"](x_embedded)

    # Match adaln_input dtype to x_embedded
    adaln_input = adaln_input.type_as(x_embedded)

    # Apply pad token
    x_embedded[torch.cat(x_inner_pad_mask)] = module.x_pad_token
    x_list = list(x_embedded.split(x_item_seqlens, dim=0))

    # Compute rope embeddings for image patches
    x_cos, x_sin = module.rope_embedder(torch.cat(x_pos_ids, dim=0))
    x_cos = list(x_cos.split(x_item_seqlens, dim=0))
    x_sin = list(x_sin.split(x_item_seqlens, dim=0))

    # Pad sequences for batch processing
    x_batched = pad_sequence(x_list, batch_first=True, padding_value=0.0)
    x_cos_batched = pad_sequence(x_cos, batch_first=True, padding_value=0.0)
    x_sin_batched = pad_sequence(x_sin, batch_first=True, padding_value=0.0)
    x_attn_mask = torch.zeros((bsz, x_max_item_seqlen), dtype=torch.bool, device=device)
    for i, seq_len in enumerate(x_item_seqlens):
        x_attn_mask[i, :seq_len] = 1

    # Run noise refiner blocks
    for layer in module.noise_refiner:
        x_batched = layer(x_batched, x_attn_mask, x_cos_batched, x_sin_batched, adaln_input)

    # Process caption features through embedder and context refiner
    cap_item_seqlens = [len(_) for _ in cap_feats_processed]
    cap_max_item_seqlen = max(cap_item_seqlens)

    cap_embedded = torch.cat(cap_feats_processed, dim=0)
    cap_embedded = module.cap_embedder(cap_embedded)
    cap_embedded[torch.cat(cap_inner_pad_mask)] = module.cap_pad_token
    cap_list = list(cap_embedded.split(cap_item_seqlens, dim=0))

    # Compute rope embeddings for caption
    cap_cos, cap_sin = module.rope_embedder(torch.cat(cap_pos_ids, dim=0))
    cap_cos = list(cap_cos.split(cap_item_seqlens, dim=0))
    cap_sin = list(cap_sin.split(cap_item_seqlens, dim=0))

    # Pad sequences for batch processing
    cap_batched = pad_sequence(cap_list, batch_first=True, padding_value=0.0)
    cap_cos_batched = pad_sequence(cap_cos, batch_first=True, padding_value=0.0)
    cap_sin_batched = pad_sequence(cap_sin, batch_first=True, padding_value=0.0)
    cap_attn_mask = torch.zeros((bsz, cap_max_item_seqlen), dtype=torch.bool, device=device)
    for i, seq_len in enumerate(cap_item_seqlens):
        cap_attn_mask[i, :seq_len] = 1

    # Run context refiner blocks
    for layer in module.context_refiner:
        cap_batched = layer(cap_batched, cap_attn_mask, cap_cos_batched, cap_sin_batched)

    # Create unified sequence (image + caption)
    unified_list = []
    unified_cos_list = []
    unified_sin_list = []
    for i in range(bsz):
        x_len = x_item_seqlens[i]
        cap_len = cap_item_seqlens[i]
        unified_list.append(torch.cat([x_batched[i][:x_len], cap_batched[i][:cap_len]]))
        unified_cos_list.append(torch.cat([x_cos_batched[i][:x_len], cap_cos_batched[i][:cap_len]]))
        unified_sin_list.append(torch.cat([x_sin_batched[i][:x_len], cap_sin_batched[i][:cap_len]]))

    unified_item_seqlens = [a + b for a, b in zip(cap_item_seqlens, x_item_seqlens)]
    unified_max_item_seqlen = max(unified_item_seqlens)

    unified = pad_sequence(unified_list, batch_first=True, padding_value=0.0)
    unified_cos = pad_sequence(unified_cos_list, batch_first=True, padding_value=0.0)
    unified_sin = pad_sequence(unified_sin_list, batch_first=True, padding_value=0.0)
    unified_attn_mask = torch.zeros((bsz, unified_max_item_seqlen), dtype=torch.bool, device=device)
    for i, seq_len in enumerate(unified_item_seqlens):
        unified_attn_mask[i, :seq_len] = 1

    # ============================================================================
    # EXTRACT MODULATED INPUT (for cache decision)
    # ============================================================================
    # Use the first main transformer block's modulation
    # The main layers have modulation=True and process the unified sequence
    block = module.layers[0]
    # Get modulation parameters: scale_msa, gate_msa, scale_mlp, gate_mlp
    mod_params = block.adaLN_modulation(adaln_input).unsqueeze(1).chunk(4, dim=2)
    scale_msa = 1.0 + mod_params[0]
    # Extract modulated input: normalized hidden states scaled by modulation
    modulated_input = block.attention_norm1(unified) * scale_msa

    # ============================================================================
    # DEFINE TRANSFORMER EXECUTION (Z-Image specific)
    # ============================================================================
    def run_transformer_blocks():
        """Execute all Z-Image main transformer blocks."""
        h = unified
        for layer in module.layers:
            h = layer(h, unified_attn_mask, unified_cos, unified_sin, adaln_input)
        return (h,)

    # ============================================================================
    # DEFINE POSTPROCESSING (Z-Image specific)
    # ============================================================================
    def postprocess(h):
        """Apply Z-Image specific output postprocessing."""
        h = module.all_final_layer[f"{patch_size}-{f_patch_size}"](h, adaln_input)
        h = list(h.unbind(dim=0))
        output = module.unpatchify(h, x_size, patch_size, f_patch_size)
        return output, {}

    # ============================================================================
    # RETURN CONTEXT
    # ============================================================================
    return CacheContext(
        modulated_input=modulated_input,
        hidden_states=unified,
        encoder_hidden_states=None,  # Z-Image uses unified sequence, no separate encoder states
        temb=adaln_input,
        run_transformer_blocks=run_transformer_blocks,
        postprocess=postprocess,
        extra_states={
            "unified_attn_mask": unified_attn_mask,
            "unified_cos": unified_cos,
            "unified_sin": unified_sin,
            "x_size": x_size,
            "x_item_seqlens": x_item_seqlens,
            "patch_size": patch_size,
            "f_patch_size": f_patch_size,
        },
    )


# Registry for model-specific extractors
# Key: Transformer class name
# Value: extractor function with signature (module, *args, **kwargs) -> CacheContext
#
# Note: Use the transformer class name as specified in pipelines as TeaCache hooks operate
# on the transformer module and multiple pipelines can share the same transformer.
EXTRACTOR_REGISTRY: dict[str, Callable] = {
    "QwenImageTransformer2DModel": extract_qwen_context,
    "Bagel": extract_bagel_context,
    "ZImageTransformer2DModel": extract_zimage_context,
    # Future models:
    # "FluxTransformer2DModel": extract_flux_context,
    # "CogVideoXTransformer3DModel": extract_cogvideox_context,
}


def register_extractor(transformer_cls_name: str, extractor_fn: Callable) -> None:
    """
    Register a new extractor function for a model type.

    This allows extending TeaCache support to new models without modifying
    the core TeaCache code.

    Args:
        transformer_cls_name: Transformer model type identifier (class name or type string)
        extractor_fn: Function with signature (module, *args, **kwargs) -> CacheContext

    Example:
        >>> def extract_flux_context(module, hidden_states, timestep, guidance=None, **kwargs):
        ...     # Preprocessing
        ...     temb = module.time_text_embed(timestep, guidance)
        ...     # Extract modulated input
        ...     modulated = module.transformer_blocks[0].norm1(hidden_states, emb=temb)
        ...     # Define execution
        ...     def run_blocks():
        ...         h = hidden_states
        ...         for block in module.transformer_blocks:
        ...             h = block(h, temb=temb)
        ...         return (h,)
        ...     # Define postprocessing
        ...     def postprocess(h):
        ...         return module.proj_out(module.norm_out(h, temb))
        ...     # Return context
        ...     return CacheContext(modulated, hidden_states, None, temb, run_blocks, postprocess)
        >>> register_extractor("FluxTransformer2DModel", extract_flux_context)
    """
    EXTRACTOR_REGISTRY[transformer_cls_name] = extractor_fn


def get_extractor(transformer_cls_name: str) -> Callable:
    """
    Get extractor function for given transformer class.

    This function looks up the extractor based on the exact transformer_cls_name string,
    which should match the transformer type in the pipeline (i.e., pipeline.transformer.__class__.__name__).

    Args:
        transformer_cls_name: Transformer class name (e.g., "QwenImageTransformer2DModel")
                                Must exactly match a key in EXTRACTOR_REGISTRY.

    Returns:
        Extractor function with signature (module, *args, **kwargs) -> CacheContext

    Raises:
        ValueError: If model type not found in registry

    Example:
        >>> # Get extractor for QwenImageTransformer2DModel
        >>> extractor = get_extractor("QwenImageTransformer2DModel")
        >>> ctx = extractor(transformer, hidden_states, encoder_hidden_states, timestep, ...)
    """
    # Direct lookup - no substring matching
    if transformer_cls_name in EXTRACTOR_REGISTRY:
        return EXTRACTOR_REGISTRY[transformer_cls_name]

    # No match found
    available_types = list(EXTRACTOR_REGISTRY.keys())
    raise ValueError(
        f"Unknown model type: '{transformer_cls_name}'. "
        f"Available types: {available_types}\n"
        f"To add support for a new model, use register_extractor() or add to EXTRACTOR_REGISTRY."
    )
