import torch
from typing import Optional

from diffusers.models.modeling_utils import ModelMixin
from cache_dit.parallelism.backend import ParallelismBackend
from cache_dit.parallelism.config import ParallelismConfig
from cache_dit.logger import init_logger

try:
    from diffusers import ContextParallelConfig  # noqa: F401
    from cache_dit.parallelism.attention import (
        _maybe_register_custom_attn_backends,
        _is_diffusers_parallelism_available,
        enable_ulysses_anything,
        enable_ulysses_float8,
    )
    from .cp_plan_registers import ControlNetContextParallelismPlannerRegister
    from .cp_planners import _activate_controlnet_cp_planners

    # Import-time side effects: register custom attention backends and
    # activate the ControlNet context-parallelism planners. Presumably the
    # latter is what makes planners visible to
    # ControlNetContextParallelismPlannerRegister.get_planner — TODO confirm.
    _maybe_register_custom_attn_backends()
    _activate_controlnet_cp_planners()
except ImportError as e:
    # Chain explicitly and keep the original ``name``/``path`` attributes;
    # ``raise ImportError(e)`` discarded them and produced a confusing
    # "During handling of the above exception" traceback.
    raise ImportError(
        f"Failed to import ControlNet context parallelism dependencies: {e}",
        name=e.name,
        path=e.path,
    ) from e


logger = init_logger(__name__)


def maybe_enable_context_parallelism(
    controlnet: torch.nn.Module,
    parallelism_config: Optional[ParallelismConfig],
) -> torch.nn.Module:
    """Enable diffusers-native context parallelism on a ControlNet, if configured.

    Context parallelism is applied only when ``parallelism_config.backend`` is
    ``ParallelismBackend.NATIVE_DIFFUSER``, ``_is_diffusers_parallelism_available()``
    returns True, and at least one of ``ulysses_size`` / ``ring_size`` is set.
    In every other case the ControlNet is returned untouched.

    Recognized ``parallelism_config.parallel_kwargs`` keys (``None`` is treated
    as an empty dict):
      - ``experimental_ulysses_anything``: if truthy, call ``enable_ulysses_anything()``.
      - ``experimental_ulysses_float8``: if truthy, call ``enable_ulysses_float8()``
        (float8 all_to_all for Ulysses / Ulysses Anything attention).
      - ``cp_plan``: custom plan passed to ``enable_parallelism``; when missing or
        None, the plan comes from the planner registered for this ControlNet, whose
        ``apply`` receives all of ``parallel_kwargs`` as keyword arguments.

    Args:
        controlnet: ControlNet model; must be a diffusers ``ModelMixin``.
        parallelism_config: Parallelism settings, or None to skip entirely.

    Returns:
        The same ``controlnet`` object; when parallelism is applied, its
        ``enable_parallelism`` method has been called.

    Raises:
        AssertionError: If ``controlnet`` or ``parallelism_config`` has the wrong type.
        ValueError: If context parallelism is requested but ``controlnet`` has no
            ``enable_parallelism`` method.
    """
    assert isinstance(controlnet, ModelMixin), (
        "controlnet must be an instance of diffusers' ModelMixin, " f"but got {type(controlnet)}"
    )
    if parallelism_config is None:
        return controlnet

    assert isinstance(parallelism_config, ParallelismConfig), (
        "parallelism_config must be an instance of ParallelismConfig"
        f" but got {type(parallelism_config)}"
    )

    if not (
        parallelism_config.backend == ParallelismBackend.NATIVE_DIFFUSER
        and _is_diffusers_parallelism_available()
    ):
        return controlnet

    if parallelism_config.ulysses_size is None and parallelism_config.ring_size is None:
        # No Ulysses or Ring degree given: context parallelism is not requested.
        return controlnet

    cp_config = ContextParallelConfig(
        ulysses_degree=parallelism_config.ulysses_size,
        ring_degree=parallelism_config.ring_size,
    )

    # Check support BEFORE the enable_ulysses_* calls below: they take no
    # arguments, so they presumably toggle module-level state that would
    # otherwise be left applied for an unsupported ControlNet.
    if not hasattr(controlnet, "enable_parallelism"):
        raise ValueError(
            f"{controlnet.__class__.__name__} does not support context parallelism."
        )

    # parallel_kwargs may be None; normalize once instead of calling .get on None.
    parallel_kwargs = parallelism_config.parallel_kwargs or {}

    # Must call enable_ulysses_anything before enable_ulysses_float8.
    if parallel_kwargs.get("experimental_ulysses_anything", False):
        enable_ulysses_anything()
    # Float8 all_to_all for Ulysses Attention/Ulysses Anything Attention
    if parallel_kwargs.get("experimental_ulysses_float8", False):
        enable_ulysses_float8()

    cp_plan = _resolve_controlnet_cp_plan(controlnet, parallel_kwargs)
    controlnet.enable_parallelism(config=cp_config, cp_plan=cp_plan)
    return controlnet


def _resolve_controlnet_cp_plan(controlnet: torch.nn.Module, parallel_kwargs: dict):
    """Return the user-supplied ``cp_plan`` or one built by the registered planner."""
    # Prefer custom cp_plan if provided
    cp_plan = parallel_kwargs.get("cp_plan", None)
    if cp_plan is not None:
        logger.info(f"Using custom context parallelism plan: {cp_plan}")
        return cp_plan
    # Otherwise get the context parallelism plan from the planner registered
    # for this ControlNet; all parallel_kwargs are forwarded to apply().
    planner_cls = ControlNetContextParallelismPlannerRegister.get_planner(controlnet)
    return planner_cls().apply(controlnet=controlnet, **parallel_kwargs)
