# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
cache-dit integration backend for vllm-omni.

This module provides a CacheDiTBackend class to enable cache-dit acceleration on diffusion
pipelines in vllm-omni, supporting both single and dual-transformer architectures.
"""

import functools
from collections.abc import Callable
from contextlib import ExitStack
from typing import Any, Optional

import cache_dit
import torch
from cache_dit import BlockAdapter, DBCacheConfig, ForwardPattern, ParamsModifier, TaylorSeerCalibratorConfig
from cache_dit.caching.block_adapters import FakeDiffusionPipeline
from cache_dit.caching.cache_adapters.cache_adapter import CachedAdapter
from cache_dit.caching.cache_blocks.pattern_0_1_2 import CachedBlocks_Pattern_0_1_2
from cache_dit.caching.cache_contexts import BasicCacheConfig
from cache_dit.caching.cache_contexts.cache_manager import CachedContextManager
from vllm.logger import init_logger

from vllm_omni.diffusion.cache.base import CacheBackend
from vllm_omni.diffusion.data import DiffusionCacheConfig, OmniDiffusionConfig

logger = init_logger(__name__)


# Small helper to centralize cache-dit summaries.
def cache_summary(pipeline: Any, details: bool = True) -> None:
    cache_dit.summary(pipeline.transformer, details=details)
    if hasattr(pipeline, "transformer_2"):
        cache_dit.summary(pipeline.transformer_2, details=details)


# Registry of custom cache-dit enablers for specific models
# Maps pipeline names to their cache-dit enablement functions
# Models in this registry require custom handling (e.g., dual-transformer architectures)
# Will be populated after function definitions
CUSTOM_DIT_ENABLERS: dict[str, Callable] = {}


def _build_db_cache_config(cache_config: Any) -> DBCacheConfig:
    """Build DBCacheConfig with optional SCM (Step Computation Masking) support.

    Args:
        cache_config: DiffusionCacheConfig instance.

    Returns:
        DBCacheConfig instance with SCM support if configured.
    """

    return DBCacheConfig(
        # we will refresh the context when gets num_inference_steps in the first inference request
        num_inference_steps=None,
        Fn_compute_blocks=cache_config.Fn_compute_blocks,
        Bn_compute_blocks=cache_config.Bn_compute_blocks,
        max_warmup_steps=cache_config.max_warmup_steps,
        max_cached_steps=cache_config.max_cached_steps,
        max_continuous_cached_steps=cache_config.max_continuous_cached_steps,
        residual_diff_threshold=cache_config.residual_diff_threshold,
    )


def enable_cache_for_wan22(pipeline: Any, cache_config: Any) -> Callable[[int], None]:
    """Enable cache-dit for Wan2.2 dual-transformer architecture.

    Wan2.2 uses two transformers (transformer and transformer_2) that need
    to be enabled together using BlockAdapter.

    Args:
        pipeline: The Wan2.2 pipeline instance.
        cache_config: DiffusionCacheConfig instance with cache configuration.

    Returns:
        A refresh function that can be called to update cache context with new num_inference_steps.
    """

    cache_dit.enable_cache(
        BlockAdapter(
            transformer=[
                pipeline.transformer,
                pipeline.transformer_2,
            ],
            blocks=[
                pipeline.transformer.blocks,
                pipeline.transformer_2.blocks,
            ],
            forward_pattern=[
                ForwardPattern.Pattern_2,
                ForwardPattern.Pattern_2,
            ],
            params_modifiers=[
                # high-noise transformer only have 30% steps
                ParamsModifier(
                    cache_config=DBCacheConfig().reset(
                        max_warmup_steps=cache_config.max_warmup_steps,
                        max_cached_steps=cache_config.max_cached_steps,
                    ),
                ),
                ParamsModifier(
                    cache_config=DBCacheConfig().reset(
                        max_warmup_steps=2,
                        max_cached_steps=20,
                    ),
                ),
            ],
            has_separate_cfg=True,
        ),
        cache_config=DBCacheConfig(
            Fn_compute_blocks=cache_config.Fn_compute_blocks,
            Bn_compute_blocks=cache_config.Bn_compute_blocks,
            max_warmup_steps=cache_config.max_warmup_steps,
            max_cached_steps=cache_config.max_cached_steps,
            max_continuous_cached_steps=cache_config.max_continuous_cached_steps,
            residual_diff_threshold=cache_config.residual_diff_threshold,
            num_inference_steps=None,
        ),
    )

    # from https://github.com/vipshop/cache-dit/pull/542
    def _split_inference_steps(num_inference_steps: int) -> tuple[int, int]:
        """Split inference steps into high-noise and low-noise steps for Wan2.2.

        This is an internal helper function specific to Wan2.2's dual-transformer
        architecture that uses boundary_ratio to determine the split point.

        Args:
            num_inference_steps: Total number of inference steps.

        Returns:
            A tuple of (num_high_noise_steps, num_low_noise_steps).
        """
        if pipeline.boundary_ratio is not None:
            boundary_timestep = pipeline.boundary_ratio * pipeline.scheduler.config.num_train_timesteps
        else:
            boundary_timestep = None

        # Set timesteps to calculate the split
        device = next(pipeline.transformer.parameters()).device
        pipeline.scheduler.set_timesteps(num_inference_steps, device=device)

        timesteps = pipeline.scheduler.timesteps
        num_high_noise_steps = 0  # high-noise steps for transformer
        for t in timesteps:
            if boundary_timestep is None or t >= boundary_timestep:
                num_high_noise_steps += 1
        # low-noise steps for transformer_2
        num_low_noise_steps = num_inference_steps - num_high_noise_steps
        return num_high_noise_steps, num_low_noise_steps

    def refresh_cache_context(pipeline: Any, num_inference_steps: int, verbose: bool = True) -> None:
        """Refresh cache context for both transformers with new num_inference_steps.

        Args:
            pipeline: The Wan2.2 pipeline instance.
            num_inference_steps: New number of inference steps.
        """

        num_high_noise_steps, num_low_noise_steps = _split_inference_steps(num_inference_steps)
        # Refresh context for high-noise transformer
        if cache_config.scm_steps_mask_policy is None:
            # cache_dit.refresh_context(pipeline.transformer, num_inference_steps=num_high_noise_steps, verbose=verbose)
            cache_dit.refresh_context(
                pipeline.transformer,
                num_inference_steps=num_high_noise_steps,
                verbose=verbose,
            )
            cache_dit.refresh_context(
                pipeline.transformer_2,
                num_inference_steps=num_low_noise_steps,
                verbose=verbose,
            )
        else:
            cache_dit.refresh_context(
                pipeline.transformer,
                cache_config=DBCacheConfig().reset(
                    num_inference_steps=num_high_noise_steps,
                    steps_computation_mask=cache_dit.steps_mask(
                        mask_policy=cache_config.scm_steps_mask_policy, total_steps=num_high_noise_steps
                    ),
                    steps_computation_policy=cache_config.scm_steps_policy,
                ),
                verbose=verbose,
            )

            cache_dit.refresh_context(
                pipeline.transformer_2,
                cache_config=DBCacheConfig().reset(
                    num_inference_steps=num_low_noise_steps,
                    steps_computation_mask=cache_dit.steps_mask(
                        mask_policy=cache_config.scm_steps_mask_policy, total_steps=num_low_noise_steps
                    ),
                    steps_computation_policy=cache_config.scm_steps_policy,
                ),
                verbose=verbose,
            )

    return refresh_cache_context


def enable_cache_for_longcat_image(pipeline: Any, cache_config: Any) -> Callable[[int], None]:
    """Enable cache-dit for LongCatImage pipeline.

    Args:
        pipeline: The LongCatImage pipeline instance.
        cache_config: DiffusionCacheConfig instance with cache configuration.
    """
    # Build DBCacheConfig for transformer
    db_cache_config = _build_db_cache_config(cache_config)

    calibrator = None
    if cache_config.enable_taylorseer:
        taylorseer_order = cache_config.taylorseer_order
        calibrator = TaylorSeerCalibratorConfig(taylorseer_order=taylorseer_order)
        logger.info(f"TaylorSeer enabled with order={taylorseer_order}")

    # Build ParamsModifier for transformer
    modifier = ParamsModifier(
        cache_config=db_cache_config,
        calibrator_config=calibrator,
    )

    logger.info(
        f"Enabling cache-dit on LongCatImage transformer with BlockAdapter: "
        f"Fn={db_cache_config.Fn_compute_blocks}, "
        f"Bn={db_cache_config.Bn_compute_blocks}, "
        f"W={db_cache_config.max_warmup_steps}, "
    )

    # Enable cache-dit using BlockAdapter for transformer
    cache_dit.enable_cache(
        (
            BlockAdapter(
                transformer=pipeline.transformer,
                blocks=[
                    pipeline.transformer.transformer_blocks,
                    pipeline.transformer.single_transformer_blocks,
                ],
                forward_pattern=[ForwardPattern.Pattern_1, ForwardPattern.Pattern_1],
                params_modifiers=[modifier],
            )
        ),
        cache_config=db_cache_config,
    )

    def refresh_cache_context(pipeline: Any, num_inference_steps: int, verbose: bool = True) -> None:
        """Refresh cache context for the transformer with new num_inference_steps.

        Args:
            pipeline: The LongCatImage pipeline instance.
            num_inference_steps: New number of inference steps.
        """
        if cache_config.scm_steps_mask_policy is None:
            cache_dit.refresh_context(pipeline.transformer, num_inference_steps=num_inference_steps, verbose=verbose)
        else:
            cache_dit.refresh_context(
                pipeline.transformer,
                cache_config=DBCacheConfig().reset(
                    num_inference_steps=num_inference_steps,
                    steps_computation_mask=cache_dit.steps_mask(
                        mask_policy=cache_config.scm_steps_mask_policy,
                        total_steps=num_inference_steps,
                    ),
                    steps_computation_policy=cache_config.scm_steps_policy,
                ),
                verbose=verbose,
            )

    return refresh_cache_context


def enable_cache_for_flux(pipeline: Any, cache_config: Any) -> Callable[[int], None]:
    """Enable cache-dit for Flux.1-dev pipeline.

    Args:
        pipeline: The Flux pipeline instance.
        cache_config: DiffusionCacheConfig instance with cache configuration.
    Returns:
        A refresh function that can be called with a new ``num_inference_steps``
        to update the cache context for the pipeline.
    """
    # Build DBCacheConfig for transformer
    db_cache_config = _build_db_cache_config(cache_config)

    calibrator = None
    if cache_config.enable_taylorseer:
        taylorseer_order = cache_config.taylorseer_order
        calibrator = TaylorSeerCalibratorConfig(taylorseer_order=taylorseer_order)
        logger.info(f"TaylorSeer enabled with order={taylorseer_order}")

    # Build ParamsModifier for transformer
    modifier = ParamsModifier(
        cache_config=db_cache_config,
        calibrator_config=calibrator,
    )

    logger.info(
        f"Enabling cache-dit on Flux transformer with BlockAdapter: "
        f"Fn={db_cache_config.Fn_compute_blocks}, "
        f"Bn={db_cache_config.Bn_compute_blocks}, "
        f"W={db_cache_config.max_warmup_steps}, "
    )

    # Enable cache-dit using BlockAdapter for transformer
    cache_dit.enable_cache(
        (
            BlockAdapter(
                transformer=pipeline.transformer,
                blocks=[
                    pipeline.transformer.transformer_blocks,
                    pipeline.transformer.single_transformer_blocks,
                ],
                forward_pattern=[ForwardPattern.Pattern_1, ForwardPattern.Pattern_1],
                params_modifiers=[modifier],
            )
        ),
        cache_config=db_cache_config,
    )

    def refresh_cache_context(pipeline: Any, num_inference_steps: int, verbose: bool = True) -> None:
        """Refresh cache context for the transformer with new num_inference_steps.

        Args:
            pipeline: The Flux pipeline instance.
            num_inference_steps: New number of inference steps.
        """
        if cache_config.scm_steps_mask_policy is None:
            cache_dit.refresh_context(pipeline.transformer, num_inference_steps=num_inference_steps, verbose=verbose)
        else:
            cache_dit.refresh_context(
                pipeline.transformer,
                cache_config=DBCacheConfig().reset(
                    num_inference_steps=num_inference_steps,
                    steps_computation_mask=cache_dit.steps_mask(
                        mask_policy=cache_config.scm_steps_mask_policy,
                        total_steps=num_inference_steps,
                    ),
                    steps_computation_policy=cache_config.scm_steps_policy,
                ),
                verbose=verbose,
            )

    return refresh_cache_context


def enable_cache_for_sd3(pipeline: Any, cache_config: Any) -> Callable[[int], None]:
    """Enable cache-dit for StableDiffusion3Pipeline.

    Args:
        pipeline: The StableDiffusion3 pipeline instance.
        cache_config: DiffusionCacheConfig instance with cache configuration.
    """
    # Build DBCacheConfig for transformer
    db_cache_config = _build_db_cache_config(cache_config)

    calibrator = None
    if cache_config.enable_taylorseer:
        taylorseer_order = cache_config.taylorseer_order
        calibrator = TaylorSeerCalibratorConfig(taylorseer_order=taylorseer_order)
        logger.info(f"TaylorSeer enabled with order={taylorseer_order}")

    # Build ParamsModifier for transformer
    modifier = ParamsModifier(
        cache_config=db_cache_config,
        calibrator_config=calibrator,
    )

    logger.info(
        f"Enabling cache-dit on StableDiffusion3 transformer with BlockAdapter: "
        f"Fn={db_cache_config.Fn_compute_blocks}, "
        f"Bn={db_cache_config.Bn_compute_blocks}, "
        f"W={db_cache_config.max_warmup_steps}, "
    )

    # Enable cache-dit using BlockAdapter for transformer
    cache_dit.enable_cache(
        (
            BlockAdapter(
                transformer=pipeline.transformer,
                blocks=pipeline.transformer.transformer_blocks,
                forward_pattern=ForwardPattern.Pattern_1,
                params_modifiers=[modifier],
            )
        ),
        cache_config=db_cache_config,
    )

    def refresh_cache_context(pipeline: Any, num_inference_steps: int, verbose: bool = True) -> None:
        """Refresh cache context for the transformer with new num_inference_steps.

        Args:
            pipeline: The LongCatImage pipeline instance.
            num_inference_steps: New number of inference steps.
        """
        if cache_config.scm_steps_mask_policy is None:
            cache_dit.refresh_context(pipeline.transformer, num_inference_steps=num_inference_steps, verbose=verbose)
        else:
            cache_dit.refresh_context(
                pipeline.transformer,
                cache_config=DBCacheConfig().reset(
                    num_inference_steps=num_inference_steps,
                    steps_computation_mask=cache_dit.steps_mask(
                        mask_policy=cache_config.scm_steps_mask_policy,
                        total_steps=num_inference_steps,
                    ),
                    steps_computation_policy=cache_config.scm_steps_policy,
                ),
                verbose=verbose,
            )

    return refresh_cache_context


def enable_cache_for_dit(pipeline: Any, cache_config: Any) -> Callable[[int], None]:
    """Enable cache-dit for regular single-transformer DiT models.

    Args:
        pipeline: The diffusion pipeline instance.
        cache_config: DiffusionCacheConfig instance with cache configuration.

    Returns:
        A refresh function that can be called to update cache context with new num_inference_steps.
    """
    # Build DBCacheConfig with optional SCM support
    db_cache_config = _build_db_cache_config(cache_config)

    # Build calibrator config if TaylorSeer is enabled
    calibrator_config = None
    if cache_config.enable_taylorseer:
        taylorseer_order = cache_config.taylorseer_order
        calibrator_config = TaylorSeerCalibratorConfig(taylorseer_order=taylorseer_order)
        logger.info(f"TaylorSeer enabled with order={taylorseer_order}")

    logger.info(
        f"Enabling cache-dit on transformer: "
        f"Fn={db_cache_config.Fn_compute_blocks}, "
        f"Bn={db_cache_config.Bn_compute_blocks}, "
        f"W={db_cache_config.max_warmup_steps}, "
    )

    # Enable cache-dit on the transformer
    cache_dit.enable_cache(
        pipeline.transformer,
        cache_config=db_cache_config,
        calibrator_config=calibrator_config,
    )

    def refresh_cache_context(pipeline: Any, num_inference_steps: int, verbose: bool = True) -> None:
        """Refresh cache context for the transformer with new num_inference_steps.

        Args:
            pipeline: The diffusion pipeline instance.
            num_inference_steps: New number of inference steps.
        """
        if cache_config.scm_steps_mask_policy is None:
            cache_dit.refresh_context(pipeline.transformer, num_inference_steps=num_inference_steps, verbose=verbose)
        else:
            cache_dit.refresh_context(
                pipeline.transformer,
                cache_config=DBCacheConfig().reset(
                    num_inference_steps=num_inference_steps,
                    steps_computation_mask=cache_dit.steps_mask(
                        mask_policy=cache_config.scm_steps_mask_policy,
                        total_steps=num_inference_steps,
                    ),
                    steps_computation_policy=cache_config.scm_steps_policy,
                ),
                verbose=verbose,
            )

    return refresh_cache_context


class BagelCachedContextManager(CachedContextManager):
    """
    Custom CachedContextManager for Bagel that safely handles NaiveCache objects
    (mapped to encoder_hidden_states) by skipping tensor operations on them.
    """

    @torch.compiler.disable
    def apply_cache(
        self,
        hidden_states: torch.Tensor,
        encoder_hidden_states: torch.Tensor = None,
        prefix: str = "Bn",
        encoder_prefix: str = "Bn_encoder",
    ) -> tuple[torch.Tensor, torch.Tensor | None]:
        # Allow Bn and Fn prefix to be used for residual cache.
        if "Bn" in prefix:
            hidden_states_prev = self.get_Bn_buffer(prefix)
        else:
            hidden_states_prev = self.get_Fn_buffer(prefix)

        assert hidden_states_prev is not None, f"{prefix}_buffer must be set before"

        if self.is_cache_residual():
            hidden_states = hidden_states_prev + hidden_states
        else:
            # If cache is not residual, we use the hidden states directly
            hidden_states = hidden_states_prev

        hidden_states = hidden_states.contiguous()

        if encoder_hidden_states is not None:
            if "Bn" in encoder_prefix:
                encoder_hidden_states_prev = self.get_Bn_encoder_buffer(encoder_prefix)
            else:
                encoder_hidden_states_prev = self.get_Fn_encoder_buffer(encoder_prefix)

            if encoder_hidden_states_prev is not None:
                if self.is_encoder_cache_residual():
                    # FIX: Check if encoder_hidden_states is a tensor before adding
                    if isinstance(encoder_hidden_states, torch.Tensor) and isinstance(
                        encoder_hidden_states_prev, torch.Tensor
                    ):
                        encoder_hidden_states = encoder_hidden_states_prev + encoder_hidden_states
                else:
                    # If encoder cache is not residual, we use the encoder hidden states directly
                    encoder_hidden_states = encoder_hidden_states_prev

            # FIX: Check if encoder_hidden_states is a tensor before calling contiguous
            if isinstance(encoder_hidden_states, torch.Tensor):
                encoder_hidden_states = encoder_hidden_states.contiguous()

        return hidden_states, encoder_hidden_states


class BagelCachedBlocks(CachedBlocks_Pattern_0_1_2):
    """
    Custom CachedBlocks for Bagel that safely handles NaiveCache objects
    by adding isinstance checks in call_Mn_blocks and compute_or_prune.
    """

    def call_Mn_blocks(
        self,
        hidden_states: torch.Tensor,
        encoder_hidden_states: torch.Tensor,
        *args,
        **kwargs,
    ):
        original_hidden_states = hidden_states
        original_encoder_hidden_states = encoder_hidden_states
        for block in self._Mn_blocks():
            hidden_states = block(
                hidden_states,
                encoder_hidden_states,
                *args,
                **kwargs,
            )
            hidden_states, encoder_hidden_states = self._process_block_outputs(hidden_states, encoder_hidden_states)

        # compute hidden_states residual
        hidden_states = hidden_states.contiguous()

        hidden_states_residual = hidden_states - original_hidden_states

        if (
            encoder_hidden_states is not None
            and original_encoder_hidden_states is not None
            and isinstance(encoder_hidden_states, torch.Tensor)  # FIX: Added Check
        ):
            encoder_hidden_states = encoder_hidden_states.contiguous()
            encoder_hidden_states_residual = encoder_hidden_states - original_encoder_hidden_states
        else:
            encoder_hidden_states_residual = None

        return (
            hidden_states,
            encoder_hidden_states,
            hidden_states_residual,
            encoder_hidden_states_residual,
        )

    def compute_or_prune(
        self,
        block_id: int,  # Block index in the transformer blocks
        # Below are the inputs to the block
        block,  # The transformer block to be executed
        hidden_states: torch.Tensor,
        encoder_hidden_states: torch.Tensor,
        *args,
        **kwargs,
    ):
        # NOTE: Although Bagel likely won't use pruning, implementing safe version just in case.
        # Copy-pasted from original but adding checks.

        original_hidden_states = hidden_states
        original_encoder_hidden_states = encoder_hidden_states

        can_use_prune = self._maybe_prune(
            block_id,
            hidden_states,
            prefix=f"{self.cache_prefix}_{block_id}_Fn_original",
        )

        torch._dynamo.graph_break()
        if can_use_prune:
            self.context_manager.add_pruned_step()
            hidden_states, encoder_hidden_states = self.context_manager.apply_prune(
                hidden_states,
                encoder_hidden_states,
                prefix=(
                    f"{self.cache_prefix}_{block_id}_Bn_residual"
                    if self.context_manager.is_cache_residual()
                    else f"{self.cache_prefix}_Bn_hidden_states"
                ),
                encoder_prefix=(
                    f"{self.cache_prefix}_{block_id}_Bn_encoder_residual"
                    if self.context_manager.is_encoder_cache_residual()
                    else f"{self.cache_prefix}_{block_id}_Bn_encoder_hidden_states"
                ),
            )
            torch._dynamo.graph_break()
        else:
            # Normal steps: Compute the block and cache the residuals.
            hidden_states = block(
                hidden_states,
                encoder_hidden_states,
                *args,
                **kwargs,
            )
            hidden_states, encoder_hidden_states = self._process_block_outputs(hidden_states, encoder_hidden_states)
            if not self._skip_prune(block_id):
                hidden_states = hidden_states.contiguous()
                hidden_states_residual = hidden_states - original_hidden_states

                if (
                    encoder_hidden_states is not None
                    and original_encoder_hidden_states is not None
                    and isinstance(encoder_hidden_states, torch.Tensor)  # FIX: Added Check
                ):
                    encoder_hidden_states = encoder_hidden_states.contiguous()
                    encoder_hidden_states_residual = encoder_hidden_states - original_encoder_hidden_states
                else:
                    encoder_hidden_states_residual = None

                self.context_manager.set_Fn_buffer(
                    original_hidden_states,
                    prefix=f"{self.cache_prefix}_{block_id}_Fn_original",
                )
                if self.context_manager.is_cache_residual():
                    self.context_manager.set_Bn_buffer(
                        hidden_states_residual,
                        prefix=f"{self.cache_prefix}_{block_id}_Bn_residual",
                    )
                else:
                    self.context_manager.set_Bn_buffer(
                        hidden_states,
                        prefix=f"{self.cache_prefix}_{block_id}_Bn_hidden_states",
                    )
                if encoder_hidden_states_residual is not None:
                    if self.context_manager.is_encoder_cache_residual():
                        self.context_manager.set_Bn_encoder_buffer(
                            encoder_hidden_states_residual,
                            prefix=f"{self.cache_prefix}_{block_id}_Bn_encoder_residual",
                        )
                    else:
                        self.context_manager.set_Bn_encoder_buffer(
                            encoder_hidden_states_residual,
                            prefix=f"{self.cache_prefix}_{block_id}_Bn_encoder_hidden_states",
                        )
            torch._dynamo.graph_break()

        return hidden_states, encoder_hidden_states


class BagelCachedAdapter(CachedAdapter):
    """
    Custom CachedAdapter for Bagel that uses BagelCachedContextManager and BagelCachedBlocks.
    """

    @classmethod
    def create_context(
        cls,
        block_adapter: BlockAdapter,
        **context_kwargs,
    ) -> tuple[list[str], list[dict[str, Any]]]:
        # Override to use BagelCachedContextManager

        BlockAdapter.assert_normalized(block_adapter)

        if BlockAdapter.is_cached(block_adapter.pipe):
            return block_adapter.pipe

        # Check context_kwargs
        context_kwargs = cls.check_context_kwargs(block_adapter, **context_kwargs)

        # Each Pipeline should have it's own context manager instance.
        cache_config: BasicCacheConfig = context_kwargs.get("cache_config", None)
        assert cache_config is not None, "cache_config can not be None."

        # Apply cache on pipeline: wrap cache context
        pipe_cls_name = block_adapter.pipe.__class__.__name__

        # USE CUSTOM CONTEXT MANAGER
        context_manager = BagelCachedContextManager(
            name=f"{pipe_cls_name}_{hash(id(block_adapter.pipe))}",
            persistent_context=isinstance(block_adapter.pipe, FakeDiffusionPipeline),
        )

        flatten_contexts, contexts_kwargs = cls.modify_context_params(block_adapter, **context_kwargs)

        block_adapter.pipe._context_manager = context_manager  # instance level

        if not context_manager.persistent_context:
            original_call = block_adapter.pipe.__class__.__call__

            @functools.wraps(original_call)
            def new_call(self, *args, **kwargs):
                with ExitStack() as stack:
                    # cache context will be reset for each pipe inference
                    for context_name, context_kwargs in zip(flatten_contexts, contexts_kwargs):
                        stack.enter_context(
                            context_manager.enter_context(
                                context_manager.reset_context(
                                    context_name,
                                    **context_kwargs,
                                ),
                            )
                        )
                    outputs = original_call(self, *args, **kwargs)
                    cls.apply_stats_hooks(block_adapter)
                    return outputs

            block_adapter.pipe.__class__.__call__ = new_call
            block_adapter.pipe.__class__._original_call = original_call

        else:
            # Init persistent cache context for transformer
            for context_name, context_kwargs in zip(flatten_contexts, contexts_kwargs):
                context_manager.reset_context(
                    context_name,
                    **context_kwargs,
                )

        block_adapter.pipe.__class__._is_cached = True

        cls.apply_params_hooks(block_adapter, contexts_kwargs)

        return flatten_contexts, contexts_kwargs

    @classmethod
    def collect_unified_blocks(
        cls,
        block_adapter: BlockAdapter,
        contexts_kwargs: list[dict],
    ) -> list[dict[str, torch.nn.ModuleList]]:
        # Override to use BagelCachedBlocks

        BlockAdapter.assert_normalized(block_adapter)

        total_cached_blocks: list[dict[str, torch.nn.ModuleList]] = []
        assert hasattr(block_adapter.pipe, "_context_manager")
        # Skipping isinstance check for ContextManager._supported_managers to avoid import issues

        for i in range(len(block_adapter.transformer)):
            unified_blocks_bind_context = {}
            for j in range(len(block_adapter.blocks[i])):
                cache_config: BasicCacheConfig = contexts_kwargs[i * len(block_adapter.blocks[i]) + j]["cache_config"]

                # Directly instantiate BagelCachedBlocks
                unified_blocks_bind_context[block_adapter.unique_blocks_name[i][j]] = torch.nn.ModuleList(
                    [
                        BagelCachedBlocks(
                            # 0. Transformer blocks configuration
                            block_adapter.blocks[i][j],
                            transformer=block_adapter.transformer[i],
                            forward_pattern=block_adapter.forward_pattern[i][j],
                            check_forward_pattern=block_adapter.check_forward_pattern,
                            check_num_outputs=block_adapter.check_num_outputs,
                            # 1. Cache/Prune context configuration
                            cache_prefix=block_adapter.blocks_name[i][j],
                            cache_context=block_adapter.unique_blocks_name[i][j],
                            context_manager=block_adapter.pipe._context_manager,
                            cache_type=cache_config.cache_type,
                        )
                    ]
                )

            total_cached_blocks.append(unified_blocks_bind_context)

        return total_cached_blocks


def enable_cache_for_bagel(pipeline: Any, cache_config: Any) -> Callable[[int], None]:
    """Enable cache-dit for Bagel model (via OmniDiffusion pipeline).

    Args:
        pipeline: The OmniDiffusion pipeline instance.
        cache_config: DiffusionCacheConfig instance with cache configuration.

    Returns:
        A refresh function that can be called to update cache context with new num_inference_steps.
    """
    # Build DBCacheConfig
    db_cache_config = _build_db_cache_config(cache_config)

    # Build calibrator config if TaylorSeer is enabled
    calibrator_config = None
    if cache_config.enable_taylorseer:
        taylorseer_order = cache_config.taylorseer_order
        calibrator_config = TaylorSeerCalibratorConfig(taylorseer_order=taylorseer_order)
        logger.info(f"TaylorSeer enabled with order={taylorseer_order}")

    # Access the transformer: BagelPipeline -> Qwen2MoTForCausalLM -> Qwen2MoTModel
    # BagelPipeline has self.language_model which is Qwen2MoTForCausalLM
    # Qwen2MoTForCausalLM has self.model which is Qwen2MoTModel
    transformer = pipeline.language_model.model

    logger.info(
        f"Enabling cache-dit on Bagel transformer: "
        f"Fn={db_cache_config.Fn_compute_blocks}, "
        f"Bn={db_cache_config.Bn_compute_blocks}, "
        f"W={db_cache_config.max_warmup_steps}, "
    )

    # Enable cache-dit on the transformer
    # Pattern_0 corresponds to (hidden_states, encoder_hidden_states) input, output
    # Custom adapter for Bagel to handle NaiveCache correctly
    # from vllm_omni.diffusion.cache.bagel_cache_adapter import BagelCachedAdapter # No longer needed
    BagelCachedAdapter.apply(
        BlockAdapter(
            transformer=transformer,
            blocks=transformer.layers,
            forward_pattern=ForwardPattern.Pattern_0,
        ),
        cache_config=db_cache_config,
        calibrator_config=calibrator_config,
    )

    def refresh_cache_context(pipeline: Any, num_inference_steps: int, verbose: bool = True) -> None:
        transformer = pipeline.language_model.model
        if cache_config.scm_steps_mask_policy is None:
            cache_dit.refresh_context(transformer, num_inference_steps=num_inference_steps, verbose=verbose)
        else:
            cache_dit.refresh_context(
                transformer,
                cache_config=DBCacheConfig().reset(
                    num_inference_steps=num_inference_steps,
                    steps_computation_mask=cache_dit.steps_mask(
                        mask_policy=cache_config.scm_steps_mask_policy,
                        total_steps=num_inference_steps,
                    ),
                    steps_computation_policy=cache_config.scm_steps_policy,
                ),
                verbose=verbose,
            )

    return refresh_cache_context


# Register custom cache-dit enablers after function definitions
CUSTOM_DIT_ENABLERS.update(
    {
        "Wan22Pipeline": enable_cache_for_wan22,
        "Wan22I2VPipeline": enable_cache_for_wan22,
        "Wan22TI2VPipeline": enable_cache_for_wan22,
        "FluxPipeline": enable_cache_for_flux,
        "LongCatImagePipeline": enable_cache_for_longcat_image,
        "LongCatImageEditPipeline": enable_cache_for_longcat_image,
        "StableDiffusion3Pipeline": enable_cache_for_sd3,
        "BagelPipeline": enable_cache_for_bagel,
    }
)


class CacheDiTBackend(CacheBackend):
    """Backend class for cache-dit acceleration on diffusion pipelines.

    This class implements cache-dit acceleration (DBCache, SCM, TaylorSeer) using
    the cache-dit library. It inherits from CacheBackend and provides a unified
    interface for managing cache-dit acceleration on diffusion models.

    Attributes:
        config: Cache configuration (DiffusionCacheConfig instance), inherited from CacheBackend.
        enabled: Whether cache-dit is enabled on this pipeline, inherited from CacheBackend.
        _refresh_func: Internal refresh function for updating cache context.
        _last_num_inference_steps: Last num_inference_steps used for refresh optimization.
    """

    def __init__(self, cache_config: Any = None):
        """Initialize the cache-dit backend.

        Args:
            cache_config: Cache configuration (DiffusionCacheConfig instance, dict, or None).
                         If None or empty, uses default DiffusionCacheConfig().
        """
        # Use default config if cache_config is not provided or is empty
        if cache_config is None:
            config = DiffusionCacheConfig()
        elif isinstance(cache_config, dict):
            # Convert dict to DiffusionCacheConfig, using defaults for missing keys
            config = DiffusionCacheConfig.from_dict(cache_config)
        else:
            config = cache_config

        # Initialize base class with normalized config
        super().__init__(config)

        # Cache-dit specific attributes
        self._refresh_func: Callable[[Any, int, bool], None] | None = None
        self._last_num_inference_steps: int | None = None

    def enable(self, pipeline: Any) -> None:
        """Enable cache-dit on the pipeline if configured.

        This method applies cache-dit acceleration to the appropriate transformer(s)
        in the pipeline. It handles both single-transformer and dual-transformer
        architectures (e.g., Wan2.2).

        Args:
            pipeline: The diffusion pipeline instance.
        """

        # Extract pipeline name from pipeline
        pipeline_name = pipeline.__class__.__name__
        # Check if this model has a custom cache-dit enabler
        if pipeline_name in CUSTOM_DIT_ENABLERS:
            logger.info(f"Using custom cache-dit enabler for model: {pipeline_name}")
            self._refresh_func = CUSTOM_DIT_ENABLERS[pipeline_name](pipeline, self.config)
        else:
            # For regular single-transformer models
            self._refresh_func = enable_cache_for_dit(pipeline, self.config)

        self.enabled = True
        logger.info(f"Cache-dit enabled successfully on {pipeline_name}")

    def refresh(self, pipeline: Any, num_inference_steps: int, verbose: bool = True) -> None:
        """Refresh cache context with new num_inference_steps.

        This method updates the cache context when num_inference_steps changes
        during inference. For dual-transformer models (e.g., Wan2.2), it automatically
        splits the steps based on boundary_ratio.

        Args:
            pipeline: The diffusion pipeline instance.
            num_inference_steps: New number of inference steps.
            verbose: Whether to log refresh operations.
        """
        if not self.enabled or self._refresh_func is None:
            logger.warning("Cache-dit is not enabled. Cannot refresh cache context.")
            return

        # Only refresh if num_inference_steps has changed
        if self._last_num_inference_steps is None or num_inference_steps != self._last_num_inference_steps:
            if verbose:
                logger.info(f"Refreshing cache context for transformer with num_inference_steps: {num_inference_steps}")
            self._refresh_func(pipeline, num_inference_steps, verbose)
            self._last_num_inference_steps = num_inference_steps

    def is_enabled(self) -> bool:
        """Check if cache-dit is enabled on this pipeline.

        Returns:
            True if cache-dit is enabled, False otherwise.
        """
        return self.enabled


def may_enable_cache_dit(pipeline: Any, od_config: OmniDiffusionConfig) -> Optional["CacheDiTBackend"]:
    """Enable cache-dit on the pipeline if configured (convenience function).

    This is a convenience function that creates and enables a CacheDiTBackend.
    For new code, consider using CacheDiTBackend directly.

    Args:
        pipeline: The diffusion pipeline instance.
        od_config: OmniDiffusionConfig with cache configuration.

    Returns:
        A CacheDiTBackend instance if cache-dit is enabled, None otherwise.
    """
    if od_config.cache_backend != "cache-dit" or not od_config.cache_config:
        return None

    backend = CacheDiTBackend(od_config.cache_config)
    backend.enable(pipeline)
    return backend if backend.is_enabled() else None
