# SPDX-License-Identifier: Apache-2.0
"""
cache-dit integration module for SGLang DiT pipelines.

This module provides helper functions to enable cache-dit acceleration
on transformer modules in SGLang's modular pipeline architecture.
"""

from dataclasses import dataclass
from typing import List, Optional

import torch
import torch.distributed as dist

from sglang.multimodal_gen.runtime.utils.logging_utils import init_logger

logger = init_logger(__name__)

import cache_dit
from cache_dit import (
    BlockAdapter,
    DBCacheConfig,
    ForwardPattern,
    ParamsModifier,
    TaylorSeerCalibratorConfig,
    steps_mask,
)
from cache_dit.caching.block_adapters import BlockAdapterRegister
from cache_dit.parallelism import ParallelismBackend, ParallelismConfig

from sglang.multimodal_gen.runtime.distributed.parallel_state import get_dit_group

_original_similarity = None


def _patch_cache_dit_similarity():
    from cache_dit.caching.cache_contexts import cache_manager

    global _original_similarity
    if _original_similarity is not None:
        return

    _original_similarity = cache_manager.CachedContextManager.similarity

    def patched_similarity(self, t1, t2, *, threshold, parallelized=False, prefix="Fn"):
        if not parallelized:
            return _original_similarity(
                self,
                t1,
                t2,
                threshold=threshold,
                parallelized=parallelized,
                prefix=prefix,
            )

        sp_group = getattr(self, "_sglang_sp_group", None)
        tp_group = getattr(self, "_sglang_tp_group", None)
        tp_sp_group = getattr(self, "_sglang_tp_sp_group", None)
        target_group = tp_sp_group or sp_group or tp_group

        if target_group is None:
            return _original_similarity(
                self,
                t1,
                t2,
                threshold=threshold,
                parallelized=parallelized,
                prefix=prefix,
            )

        # Adapted from https://github.com/vipshop/cache-dit/blob/main/src/cache_dit/caching/cache_contexts/cache_manager.py#L495-L523
        condition_thresh = self.get_important_condition_threshold()
        if condition_thresh > 0.0:
            raw_diff = (t1 - t2).abs()
            token_m_df = raw_diff.mean(dim=-1)
            token_m_t1 = t1.abs().mean(dim=-1)
            token_diff = token_m_df / token_m_t1
            condition = token_diff > condition_thresh
            if condition.sum() > 0:
                condition = condition.unsqueeze(-1).expand_as(raw_diff)
                mean_diff = raw_diff[condition].mean()
                mean_t1 = t1[condition].abs().mean()
            else:
                mean_diff = (t1 - t2).abs().mean()
                mean_t1 = t1.abs().mean()
        else:
            mean_diff = (t1 - t2).abs().mean()
            mean_t1 = t1.abs().mean()

        dist.all_reduce(mean_diff, op=dist.ReduceOp.AVG, group=target_group)
        dist.all_reduce(mean_t1, op=dist.ReduceOp.AVG, group=target_group)

        diff = (mean_diff / mean_t1).item()
        self.add_residual_diff(diff)
        return diff < threshold

    cache_manager.CachedContextManager.similarity = patched_similarity


def _build_parallelism_config(
    sp_group: Optional[torch.distributed.ProcessGroup],
    tp_group: Optional[torch.distributed.ProcessGroup],
):
    if sp_group is None and tp_group is None:
        return None

    ulysses_size = None
    ring_size = None
    if sp_group is not None:
        ulysses_size = getattr(sp_group, "ulysses_world_size", None)
        ring_size = getattr(sp_group, "ring_world_size", None)

    tp_size = None
    if tp_group is not None:
        tp_size = dist.get_world_size(tp_group)

    return ParallelismConfig(
        backend=ParallelismBackend.NATIVE_PYTORCH,
        ulysses_size=ulysses_size,
        ring_size=ring_size,
        tp_size=tp_size,
    )


def _mark_transformer_parallelized(transformer, config, sp_group, tp_group):
    if config is None:
        return

    transformer._is_parallelized = True
    transformer._parallelism_config = config


def get_scm_mask(
    preset: str,
    num_inference_steps: int,
    compute_bins: Optional[List[int]] = None,
    cache_bins: Optional[List[int]] = None,
) -> Optional[List[int]]:
    """
    Build a Steps Computation Mask (SCM) through cache-dit's ``steps_mask()``.

    Preset lookup and scaling to the step count are delegated entirely to
    cache-dit. This wrapper only decides when SCM is off and logs the result.

    Args:
        preset: Preset name ("none", "slow", "medium", "fast", "ultra").
        num_inference_steps: Total number of inference steps the mask covers.
        compute_bins: Custom compute bins (overrides preset).
        cache_bins: Custom cache bins (overrides preset).

    Returns:
        SCM mask list (1=compute, 0=cache), or None when ``preset`` is "none"
        and custom compute/cache bins are not both provided (non-empty).
    """
    scm_disabled = preset == "none" and not (compute_bins and cache_bins)
    if scm_disabled:
        return None

    # Custom bins with preset "none": "medium" is passed as the policy name;
    # per the Args above, custom bins take precedence over the preset.
    policy = "medium" if preset == "none" else preset
    mask = steps_mask(
        compute_bins=compute_bins,
        cache_bins=cache_bins,
        total_steps=num_inference_steps,
        mask_policy=policy,
    )

    n_compute = sum(mask)
    logger.info(
        "SCM: generated mask with %d compute steps, %d cache steps (preset=%s)",
        n_compute,
        len(mask) - n_compute,
        preset,
    )
    return mask


@dataclass
class CacheDitConfig:
    """Configuration for cache-dit integration.

    Fields are declared in positional-constructor order (generated by
    ``@dataclass``), so new fields should be appended at the end.

    Attributes:
        enabled: Whether to enable cache-dit acceleration.
        Fn_compute_blocks: Number of first blocks to always compute (DBCache F).
        Bn_compute_blocks: Number of last blocks to always compute (DBCache B).
        max_warmup_steps: Number of warmup steps before caching starts (DBCache W).
        residual_diff_threshold: Threshold for residual difference (DBCache R).
        max_continuous_cached_steps: Maximum consecutive cached steps (DBCache MC).
        enable_taylorseer: Whether to enable the TaylorSeer calibrator.
        taylorseer_order: Order of the Taylor expansion (1 or 2).
        num_inference_steps: Total number of inference steps. Required for
            transformer-only mode; enable_cache_on_transformer() raises
            ValueError when it is None.
        steps_computation_mask: Binary mask for step-level caching (1=compute, 0=cache).
            Generated by get_scm_mask() (wrapper around cache_dit.steps_mask()),
            which returns None when SCM is disabled.
        steps_computation_policy: Caching policy for SCM ("dynamic" or "static").
    """

    enabled: bool = False
    Fn_compute_blocks: int = 1
    Bn_compute_blocks: int = 0
    # Default to 4 warmup steps instead of cache-dit's 8, so that DBCache also
    # works for few-step distilled models, e.g. Z-Image with 8 steps.
    max_warmup_steps: int = 4
    # Use a relatively high residual diff threshold (0.24) by default to allow
    # more aggressive caching. This is safe because the max continuous cached
    # steps limit below is also applied. Without that limit, a lower threshold
    # such as 0.12 would be preferable.
    residual_diff_threshold: float = 0.24
    max_continuous_cached_steps: int = 3
    # TaylorSeer is not suitable for few-step distilled models, so it is
    # disabled by default. References:
    # - From Reusing to Forecasting: Accelerating Diffusion Models with TaylorSeers,
    #   https://arxiv.org/pdf/2503.06923
    # - FoCa: Forecast then Calibrate: Feature Caching as ODE for Efficient
    #   Diffusion Transformers, https://arxiv.org/pdf/2508.16211
    enable_taylorseer: bool = False
    taylorseer_order: int = 1
    num_inference_steps: Optional[int] = None
    # SCM fields (generated by _maybe_enable_cache_dit from env configuration)
    steps_computation_mask: Optional[List[int]] = None
    # "dynamic" or "static"; this is the SCM caching policy, not a preset name.
    steps_computation_policy: str = "dynamic"


def enable_cache_on_transformer(
    transformer: torch.nn.Module,
    config: CacheDitConfig,
    model_name: str = "transformer",
    sp_group: Optional[torch.distributed.ProcessGroup] = None,
    tp_group: Optional[torch.distributed.ProcessGroup] = None,
) -> torch.nn.Module:
    """Enable cache-dit on a single transformer pre-registered in cache-dit.

    The transformer class must be registered in cache-dit's
    ``BlockAdapterRegister``. DBCache, plus TaylorSeer and SCM when configured,
    is enabled via ``cache_dit.enable_cache``. cache-dit's own parallelism is
    not used (``parallelism_config=None``).

    When an SP and/or TP group is given, the module is instead marked as
    parallelized, and the groups are attached to its cache-dit context manager.
    The patched similarity check then reduces over them.

    Args:
        transformer: DiT transformer module to accelerate.
        config: cache-dit settings. The call is a no-op when ``config.enabled``
            is False.
        model_name: Name of the model for logging purposes.
        sp_group: Sequence parallel process group (for Ulysses/Ring).
        tp_group: Tensor parallel process group.

    Returns:
        The same ``transformer`` object that was passed in.

    Raises:
        ValueError: If ``config.num_inference_steps`` is None, or the
            transformer class is not supported by cache-dit.
    """
    if not config.enabled:
        return transformer

    if config.num_inference_steps is None:
        raise ValueError(
            "num_inference_steps is required for transformer-only mode. "
            "Please provide it in CacheDitConfig."
        )

    # Check if the transformer is pre-registered in cache-dit
    if not BlockAdapterRegister.is_supported(transformer):
        transformer_cls_name = transformer.__class__.__name__
        raise ValueError(
            f"{transformer_cls_name} is not officially supported by cache-dit. "
            "Supported cache-dit DiT families include Flux, QwenImage, HunyuanDiT, "
            "HunyuanVideo, Wan, CogVideoX, Mochi, and others. "
            "Please ensure your transformer belongs to one of these families or "
            "define a custom BlockAdapter."
        )

    # Build cache config (including SCM fields if provided)
    cache_config = DBCacheConfig(
        num_inference_steps=config.num_inference_steps,
        Fn_compute_blocks=config.Fn_compute_blocks,
        Bn_compute_blocks=config.Bn_compute_blocks,
        max_warmup_steps=config.max_warmup_steps,
        residual_diff_threshold=config.residual_diff_threshold,
        max_continuous_cached_steps=config.max_continuous_cached_steps,
        # SCM fields
        steps_computation_mask=config.steps_computation_mask,
        steps_computation_policy=config.steps_computation_policy,
    )

    # Build calibrator config if TaylorSeer is enabled
    calibrator_config = None
    if config.enable_taylorseer:
        calibrator_config = TaylorSeerCalibratorConfig(
            taylorseer_order=config.taylorseer_order,
        )

    logger.info(
        "Enabling cache-dit on %s with config: Fn=%d, Bn=%d, W=%d, R=%.2f, MC=%d, "
        "TaylorSeer=%s (order=%d), steps=%d",
        model_name,
        config.Fn_compute_blocks,
        config.Bn_compute_blocks,
        config.max_warmup_steps,
        config.residual_diff_threshold,
        config.max_continuous_cached_steps,
        config.enable_taylorseer,
        config.taylorseer_order,
        config.num_inference_steps,
    )

    # Log SCM configuration if enabled
    if config.steps_computation_mask:
        compute_steps = sum(config.steps_computation_mask)
        cache_steps = len(config.steps_computation_mask) - compute_steps
        logger.info(
            "SCM enabled: %d compute steps, %d cache steps, policy=%s",
            compute_steps,
            cache_steps,
            config.steps_computation_policy,
        )

    parallelism_config = _build_parallelism_config(sp_group, tp_group)
    if parallelism_config is not None:
        # Parallel cache decisions are synchronized by the patched similarity.
        _patch_cache_dit_similarity()

    _mark_transformer_parallelized(transformer, parallelism_config, sp_group, tp_group)

    # SGLang owns the process groups, so cache-dit's parallelism stays disabled.
    cache_dit.enable_cache(
        transformer,
        cache_config=cache_config,
        calibrator_config=calibrator_config,
        parallelism_config=None,
    )

    if parallelism_config is not None:
        context_manager = getattr(transformer, "_context_manager", None)
        if context_manager is not None:
            context_manager._sglang_sp_group = sp_group
            context_manager._sglang_tp_group = tp_group
            # In mixed TP + SP (Ulysses/Ring) mode, cache-dit decisions must be
            # consistent across the full TP×SP model-parallel slice. Prefer
            # SGLang's DIT group as a conservative superset group. If it is
            # unavailable, fall back to None, so the patched similarity uses the
            # SP group instead.
            tp_sp_group = None
            if sp_group is not None and tp_group is not None:
                try:
                    tp_sp_group = get_dit_group()
                except Exception as err:  # best-effort, as in the dual path
                    logger.warning(
                        "cache-dit: DIT group unavailable (%s); falling back to "
                        "the SP group for cache decisions on %s.",
                        err,
                        model_name,
                    )
            context_manager._sglang_tp_sp_group = tp_sp_group

    return transformer


def enable_cache_on_dual_transformer(
    transformer: torch.nn.Module,
    transformer_2: torch.nn.Module,
    primary_config: CacheDitConfig,
    secondary_config: CacheDitConfig,
    model_name: str = "wan2.2",
    sp_group: Optional[torch.distributed.ProcessGroup] = None,
    tp_group: Optional[torch.distributed.ProcessGroup] = None,
) -> tuple[torch.nn.Module, torch.nn.Module]:
    """Enable cache-dit on dual transformers using BlockAdapter.

    For models with two transformers (high-noise expert and low-noise expert),
    cache-dit requires enabling cache on both simultaneously via BlockAdapter.
    This cannot be done by calling enable_cache separately on each transformer.

    Args:
        transformer: Primary transformer; must expose a ``blocks`` attribute.
        transformer_2: Secondary transformer; must expose a ``blocks`` attribute.
        primary_config: CacheDitConfig for primary transformer. Its ``enabled``
            flag gates the whole call.
        secondary_config: CacheDitConfig for secondary transformer.
        model_name: Dual-transformer model family; only "wan2.2" is supported.
        sp_group: Sequence parallel process group (for Ulysses/Ring).
        tp_group: Tensor parallel process group.

    Returns:
        ``(transformer, transformer_2)``, the same objects that were passed in.

    Raises:
        ValueError: On an unsupported ``model_name``, a missing
            ``primary_config.num_inference_steps``, or a transformer without a
            ``blocks`` attribute.
    """
    _supported_dual_transformer_models = [
        "wan2.2",  # Currently, only Wan2.2 will run into dual-transformer case
    ]
    if model_name not in _supported_dual_transformer_models:
        raise ValueError(
            f"Dual-transformer cache-dit is only supported for "
            f"{_supported_dual_transformer_models}, got {model_name}."
        )

    # Both transformers are cached together or not at all; the primary config decides.
    if not primary_config.enabled:
        return transformer, transformer_2

    if primary_config.num_inference_steps is None:
        raise ValueError(
            "num_inference_steps is required for dual-transformer mode. "
            "Please provide it in CacheDitConfig."
        )

    # Wan transformers use the 'blocks' attribute. Validate before mutating anything.
    transformer_blocks = getattr(transformer, "blocks", None)
    transformer_2_blocks = getattr(transformer_2, "blocks", None)
    if transformer_blocks is None or transformer_2_blocks is None:
        raise ValueError(
            "Dual transformers must have 'blocks' attribute for cache-dit. "
            f"transformer has blocks: {transformer_blocks is not None}, "
            f"transformer_2 has blocks: {transformer_2_blocks is not None}"
        )

    def _build_params_modifier(cfg: CacheDitConfig) -> ParamsModifier:
        # Translate one CacheDitConfig into cache-dit's per-transformer overrides.
        calibrator = None
        if cfg.enable_taylorseer:
            calibrator = TaylorSeerCalibratorConfig(
                taylorseer_order=cfg.taylorseer_order,
            )
        return ParamsModifier(
            cache_config=DBCacheConfig(
                num_inference_steps=cfg.num_inference_steps,
                Fn_compute_blocks=cfg.Fn_compute_blocks,
                Bn_compute_blocks=cfg.Bn_compute_blocks,
                max_warmup_steps=cfg.max_warmup_steps,
                residual_diff_threshold=cfg.residual_diff_threshold,
                max_continuous_cached_steps=cfg.max_continuous_cached_steps,
                steps_computation_mask=cfg.steps_computation_mask,
                steps_computation_policy=cfg.steps_computation_policy,
            ),
            calibrator_config=calibrator,
        )

    primary_modifier = _build_params_modifier(primary_config)
    secondary_modifier = _build_params_modifier(secondary_config)

    # Log configuration
    logger.info(
        "Enabling cache-dit on %s dual transformers with BlockAdapter",
        model_name,
    )
    for label, cfg in (
        ("Primary (transformer)", primary_config),
        ("Secondary (transformer_2)", secondary_config),
    ):
        logger.info(
            "  %s: Fn=%d, Bn=%d, W=%d, R=%.2f, MC=%d, TaylorSeer=%s (order=%d)",
            label,
            cfg.Fn_compute_blocks,
            cfg.Bn_compute_blocks,
            cfg.max_warmup_steps,
            cfg.residual_diff_threshold,
            cfg.max_continuous_cached_steps,
            cfg.enable_taylorseer,
            cfg.taylorseer_order,
        )

    # Log SCM configuration if enabled
    for label, cfg in (("primary", primary_config), ("secondary", secondary_config)):
        if cfg.steps_computation_mask:
            compute_steps = sum(cfg.steps_computation_mask)
            cache_steps = len(cfg.steps_computation_mask) - compute_steps
            logger.info(
                "  SCM enabled for %s transformer: %d compute steps, %d cache steps, policy=%s",
                label,
                compute_steps,
                cache_steps,
                cfg.steps_computation_policy,
            )

    parallelism_config = _build_parallelism_config(sp_group, tp_group)
    if parallelism_config is not None:
        _patch_cache_dit_similarity()

    _mark_transformer_parallelized(transformer, parallelism_config, sp_group, tp_group)
    _mark_transformer_parallelized(
        transformer_2, parallelism_config, sp_group, tp_group
    )

    # Enable cache-dit using BlockAdapter for both transformers simultaneously.
    # This is required for Wan2.2 and similar dual-transformer architectures.
    if model_name == "wan2.2":
        # Use Pattern_2 for Wan2.2 dual-transformer. Different models may require
        # a different ForwardPattern, hence the explicit model_name check.
        cache_dit.enable_cache(
            BlockAdapter(
                transformer=[transformer, transformer_2],
                blocks=[transformer_blocks, transformer_2_blocks],
                forward_pattern=[ForwardPattern.Pattern_2, ForwardPattern.Pattern_2],
                params_modifiers=[primary_modifier, secondary_modifier],
                has_separate_cfg=True,
            ),
            parallelism_config=None,
        )
    else:
        raise ValueError(
            f"Dual-transformer is not implemented for model {model_name} yet."
        )

    if parallelism_config is not None:
        # Resolve the TP×SP group once; it is shared by both context managers.
        tp_sp_group = None
        if sp_group is not None and tp_group is not None:
            try:
                tp_sp_group = get_dit_group()
            except Exception as err:  # best-effort: fall back to the SP group
                logger.warning(
                    "cache-dit: DIT group unavailable (%s); falling back to "
                    "the SP group for cache decisions on %s.",
                    err,
                    model_name,
                )
        for t in (transformer, transformer_2):
            context_manager = getattr(t, "_context_manager", None)
            if context_manager is None:
                continue
            context_manager._sglang_sp_group = sp_group
            context_manager._sglang_tp_group = tp_group
            context_manager._sglang_tp_sp_group = tp_sp_group

    return transformer, transformer_2


def _build_refresh_cache_config(
    num_inference_steps: int,
    scm_preset: str | None,
    scm_policy: str,
):
    """Build the ``DBCacheConfig`` passed to ``cache_dit.refresh_context``.

    A missing or "none" preset disables SCM (``steps_computation_mask=None``)
    instead of being forwarded to ``cache_dit.steps_mask``; ``get_scm_mask``
    never forwards "none" either.

    The preset only selects the mask. The SCM caching policy ("dynamic" or
    "static") is passed separately, because a preset name such as "medium" is
    not a valid ``steps_computation_policy``.
    """
    if scm_preset is None or scm_preset == "none":
        mask = None
    else:
        mask = cache_dit.steps_mask(
            mask_policy=scm_preset, total_steps=num_inference_steps
        )
    return DBCacheConfig().reset(
        num_inference_steps=num_inference_steps,
        steps_computation_mask=mask,
        steps_computation_policy=scm_policy,
    )


def refresh_context_on_transformer(
    transformer: torch.nn.Module,
    num_inference_steps: int,
    scm_preset: str | None = None,
    verbose: bool = False,
    scm_policy: str = "dynamic",
) -> None:
    """Refresh cache-dit context for transformer.

    Args:
        transformer: Transformer previously passed to enable_cache_on_transformer().
        num_inference_steps: Step count for the upcoming request.
        scm_preset: SCM preset ("slow", "medium", "fast", "ultra"); None or
            "none" disables SCM.
        verbose: Forwarded to ``cache_dit.refresh_context``.
        scm_policy: SCM caching policy, "dynamic" (default) or "static".
    """
    cache_dit.refresh_context(
        transformer,
        cache_config=_build_refresh_cache_config(
            num_inference_steps, scm_preset, scm_policy
        ),
        verbose=verbose,
    )
    logger.debug("cache-dit refreshed on transformer (steps=%s)", num_inference_steps)


def refresh_context_on_dual_transformer(
    transformer: torch.nn.Module,
    transformer_2: torch.nn.Module,
    num_high_noise_steps: int,
    num_low_noise_steps: int,
    scm_preset: str | None = None,
    verbose: bool = False,
    scm_policy: str = "dynamic",
) -> None:
    """Refresh cache-dit context for dual transformers.

    Args:
        transformer: Primary transformer, refreshed with ``num_high_noise_steps``.
        transformer_2: Secondary transformer, refreshed with ``num_low_noise_steps``.
        num_high_noise_steps: Step count handled by ``transformer``.
        num_low_noise_steps: Step count handled by ``transformer_2``.
        scm_preset: SCM preset ("slow", "medium", "fast", "ultra"); None or
            "none" disables SCM for both transformers.
        verbose: Forwarded to ``cache_dit.refresh_context``.
        scm_policy: SCM caching policy, "dynamic" (default) or "static".
    """
    for module, steps in (
        (transformer, num_high_noise_steps),
        (transformer_2, num_low_noise_steps),
    ):
        cache_dit.refresh_context(
            module,
            cache_config=_build_refresh_cache_config(steps, scm_preset, scm_policy),
            verbose=verbose,
        )
    logger.debug(
        "cache-dit refreshed on dual transformers (steps=%s, %s)",
        num_high_noise_steps,
        num_low_noise_steps,
    )
