# adapted from sglang and fastvideo
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import enum
import os
import random
from collections.abc import Callable, Mapping
from dataclasses import dataclass, field, fields
from typing import TYPE_CHECKING, Any

import torch
from pydantic import model_validator
from typing_extensions import Self
from vllm.config.utils import config
from vllm.logger import init_logger

from vllm_omni.diffusion.quantization import (
    DiffusionQuantizationConfig,
    get_diffusion_quant_config,
)
from vllm_omni.diffusion.utils.network_utils import is_port_available

if TYPE_CHECKING:
    from vllm_omni.diffusion.quantization import DiffusionQuantizationConfig

# NOTE: DiffusionQuantizationConfig is already imported eagerly above; the
# TYPE_CHECKING import and the deferred import inside __post_init__ are
# redundant safeguards kept against circular-import / import-order issues.

logger = init_logger(__name__)


@config
@dataclass
class DiffusionParallelConfig:
    """Configuration for diffusion model distributed execution."""

    pipeline_parallel_size: int = 1
    """Number of pipeline parallel stages."""

    data_parallel_size: int = 1
    """Number of data parallel groups."""

    tensor_parallel_size: int = 1
    """Number of tensor parallel groups."""

    sequence_parallel_size: int | None = None
    """Number of sequence parallel groups. sequence_parallel_size = ring_degree * ulysses_degree"""

    ulysses_degree: int = 1
    """Number of GPUs used for ulysses sequence parallelism."""

    ring_degree: int = 1
    """Number of GPUs used for ring sequence parallelism."""

    cfg_parallel_size: int = 1
    """Number of Classifier Free Guidance (CFG) parallel groups."""

    vae_patch_parallel_size: int = 1
    """Number of ranks used for VAE patch/tile parallelism (decode/encode)."""

    use_hsdp: bool = False
    """Enable Hybrid Sharded Data Parallel (HSDP) for model weight sharding."""

    hsdp_shard_size: int = -1
    """Number of GPUs to shard weights across within each replica group. -1 means auto-calculate."""

    hsdp_replicate_size: int = 1
    """Number of replica groups for HSDP. Each replica holds a full sharded copy."""

    @model_validator(mode="after")
    def _validate_parallel_config(self) -> Self:
        """Validates the config relationships among the parallel strategies."""
        # NOTE(review): these checks assume __post_init__ already replaced a
        # None sequence_parallel_size with ulysses_degree * ring_degree
        # (pydantic "after" validators run once construction has finished) —
        # confirm this ordering if instances are built outside pydantic.
        assert self.pipeline_parallel_size > 0, "Pipeline parallel size must be > 0"
        assert self.data_parallel_size > 0, "Data parallel size must be > 0"
        assert self.tensor_parallel_size > 0, "Tensor parallel size must be > 0"
        assert self.sequence_parallel_size > 0, "Sequence parallel size must be > 0"
        assert self.ulysses_degree > 0, "Ulysses degree must be > 0"
        assert self.ring_degree > 0, "Ring degree must be > 0"
        assert self.cfg_parallel_size > 0, "CFG parallel size must be > 0"
        assert self.cfg_parallel_size in [1, 2], f"CFG parallel size must be 1 or 2, but got {self.cfg_parallel_size}"
        assert self.vae_patch_parallel_size > 0, "VAE patch parallel size must be > 0"
        assert self.sequence_parallel_size == self.ulysses_degree * self.ring_degree, (
            "Sequence parallel size must be equal to the product of ulysses degree and ring degree,"
            f" but got {self.sequence_parallel_size} != {self.ulysses_degree} * {self.ring_degree}"
        )

        # Validate HSDP configuration
        if self.use_hsdp:
            assert self.hsdp_replicate_size > 0, "HSDP replicate size must be > 0"
            assert self.hsdp_shard_size > 0, "HSDP shard size must be > 0 (should be set in __post_init__)"
        return self

    def __post_init__(self) -> None:
        """Derive sequence_parallel_size, hsdp_shard_size, and world_size.

        Raises:
            ValueError: if the HSDP settings are inconsistent with the other
                parallelism dimensions (see the branches below).
        """
        # Default the sequence-parallel size to the product of its two factors.
        if self.sequence_parallel_size is None:
            self.sequence_parallel_size = self.ulysses_degree * self.ring_degree

        # Calculate world_size from other parallelism dimensions
        other_parallel_world_size = (
            self.pipeline_parallel_size
            * self.data_parallel_size
            * self.tensor_parallel_size
            * self.ulysses_degree
            * self.ring_degree
            * self.cfg_parallel_size
        )

        # Handle HSDP configuration
        # HSDP can work in two modes:
        # 1. Standalone: when other parallelism is all 1, HSDP determines world_size
        # 2. Combined: HSDP overlays on top of other parallelism
        # NOTE(review): `world_size` is not a declared dataclass field; it is
        # attached dynamically below — confirm downstream consumers tolerate
        # the attribute-only exposure (e.g. serialization may omit it).
        if self.use_hsdp:
            if self.tensor_parallel_size > 1:
                raise ValueError(
                    "HSDP (use_hsdp=True) cannot be used with Tensor Parallelism "
                    f"(tensor_parallel_size={self.tensor_parallel_size}). "
                    "Set tensor_parallel_size=1 when using HSDP."
                )
            if self.hsdp_shard_size == -1:
                # Auto-calculate: use other_parallel_world_size as shard_size
                if self.hsdp_replicate_size <= 0:
                    raise ValueError("hsdp_replicate_size must be > 0")
                if other_parallel_world_size == 1:
                    raise ValueError(
                        "Cannot auto-calculate hsdp_shard_size when other parallelism is all 1. "
                        "Please specify hsdp_shard_size explicitly for standalone HSDP."
                    )
                if other_parallel_world_size % self.hsdp_replicate_size != 0:
                    raise ValueError(
                        f"Invalid HSDP configuration: replicate_size ({self.hsdp_replicate_size}) "
                        f"must evenly divide world_size ({other_parallel_world_size}) when shard_size is -1."
                    )
                self.hsdp_shard_size = other_parallel_world_size // self.hsdp_replicate_size
                self.world_size = other_parallel_world_size
            else:
                # Explicit shard_size: HSDP can work standalone or combined
                hsdp_world_size = self.hsdp_replicate_size * self.hsdp_shard_size
                if other_parallel_world_size == 1:
                    # Standalone HSDP: world_size is determined by HSDP
                    self.world_size = hsdp_world_size
                else:
                    # Combined: HSDP must match other parallelism world_size
                    if hsdp_world_size != other_parallel_world_size:
                        raise ValueError(
                            f"HSDP dimensions "
                            f"({self.hsdp_replicate_size} × {self.hsdp_shard_size} = {hsdp_world_size}) "
                            f"must equal world_size from other parallelism ({other_parallel_world_size})"
                        )
                    self.world_size = other_parallel_world_size
        else:
            self.world_size = other_parallel_world_size

    @classmethod
    def from_dict(cls, data: dict[str, Any]) -> "DiffusionParallelConfig":
        """
        Create DiffusionParallelConfig from a dictionary.

        Args:
            data: Dictionary containing parallel configuration parameters

        Returns:
            DiffusionParallelConfig instance with parameters set from dict

        Raises:
            TypeError: if `data` is not a dict.
        """
        if not isinstance(data, dict):
            raise TypeError(f"Expected parallel config dict, got {type(data)!r}")
        return cls(**data)


@dataclass
class TransformerConfig:
    """Container for raw transformer configuration dictionaries."""

    params: dict[str, Any] = field(default_factory=dict)

    @classmethod
    def from_dict(cls, data: dict[str, Any]) -> "TransformerConfig":
        if not isinstance(data, dict):
            raise TypeError(f"Expected transformer config dict, got {type(data)!r}")
        return cls(params=dict(data))

    def to_dict(self) -> dict[str, Any]:
        return dict(self.params)

    def get(self, key: str, default: Any | None = None) -> Any:
        return self.params.get(key, default)

    def __getattr__(self, item: str) -> Any:
        params = object.__getattribute__(self, "params")
        try:
            return params[item]
        except KeyError as exc:
            raise AttributeError(item) from exc


@dataclass
class DiffusionCacheConfig:
    """
    Configuration for cache adapters (TeaCache, cache-dit, etc.).

    This dataclass provides a unified interface for cache configuration parameters.
    It can be initialized from a dictionary and accessed via attributes.

    Common parameters:
        - TeaCache: rel_l1_thresh, coefficients (optional)
        - cache-dit: Fn_compute_blocks, Bn_compute_blocks, max_warmup_steps,
                    residual_diff_threshold, enable_taylorseer, taylorseer_order,
                    scm_steps_mask_policy, scm_steps_policy

    Example:
        >>> # From dict (user-facing API) - partial config uses defaults for missing keys
        >>> config = DiffusionCacheConfig.from_dict({"rel_l1_thresh": 0.3})
        >>> # Access via attribute
        >>> print(config.rel_l1_thresh)  # 0.3 (from dict)
        >>> print(config.Fn_compute_blocks)  # 8 (default)
        >>> # Empty dict uses all defaults
        >>> default_config = DiffusionCacheConfig.from_dict({})
        >>> print(default_config.rel_l1_thresh)  # 0.2 (default)
    """

    # TeaCache parameters [tea_cache only]
    # Default: 0.2 provides ~1.5x speedup with minimal quality loss (optimal balance)
    rel_l1_thresh: float = 0.2
    coefficients: list[float] | None = None  # Uses model-specific defaults if None

    # cache-dit parameters [cache-dit only]
    # Default: 1 forward compute block (optimized for single-transformer models)
    # Use 1 as default instead of cache-dit's 8, optimized for single-transformer models
    # This provides better performance while maintaining quality for most use cases
    Fn_compute_blocks: int = 1
    # Default: 0 backward compute blocks (no fusion by default)
    Bn_compute_blocks: int = 0
    # Default: 4 warmup steps (optimized for few-step distilled models like Z-Image with 8 steps)
    # Use 4 as default warmup steps instead of 8 in cache-dit, making DBCache work
    # for few-step distilled models (e.g., Z-Image with 8 steps)
    max_warmup_steps: int = 4
    # Default: -1 (unlimited cached steps) - DBCache disables caching when previous cached steps exceed this value
    # to prevent precision degradation. Set to -1 for unlimited caching (cache-dit default).
    max_cached_steps: int = -1
    # Default: 0.24 residual difference threshold (higher for more aggressive caching)
    # Use a relatively higher residual diff threshold (0.24) as default to allow more
    # aggressive caching. This is safe because we have max_continuous_cached_steps limit.
    # Without this limit, a lower threshold like 0.12 would be needed.
    residual_diff_threshold: float = 0.24
    # Default: Limit consecutive cached steps to 3 to prevent precision degradation
    # This allows us to use a higher residual_diff_threshold for more aggressive caching
    max_continuous_cached_steps: int = 3
    # Default: Disable TaylorSeer (not suitable for few-step distilled models)
    # TaylorSeer is not suitable for few-step distilled models, so we disable it by default.
    # References:
    # - From Reusing to Forecasting: Accelerating Diffusion Models with TaylorSeers
    # - Forecast then Calibrate: Feature Caching as ODE for Efficient Diffusion Transformers
    enable_taylorseer: bool = False
    # Default: 1st order TaylorSeer polynomial
    taylorseer_order: int = 1
    # Default: None SCM mask policy (disabled by default)
    scm_steps_mask_policy: str | None = None
    # Default: "dynamic" steps policy for adaptive caching
    scm_steps_policy: str = "dynamic"
    # Used by cache-dit for scm mask generation. If this value changes during inference,
    # we will re-generate the scm mask and refresh the cache context.
    num_inference_steps: int | None = None

    # Additional parameters that may be passed but not explicitly defined
    _extra_params: dict[str, Any] = field(default_factory=dict, repr=False)

    @classmethod
    def from_dict(cls, data: dict[str, Any]) -> "DiffusionCacheConfig":
        """
        Create DiffusionCacheConfig from a dictionary.

        Args:
            data: Dictionary containing cache configuration parameters

        Returns:
            DiffusionCacheConfig instance with parameters set from dict
        """
        if not isinstance(data, dict):
            raise TypeError(f"Expected cache config dict, got {type(data)!r}")

        # Get all dataclass field names automatically
        field_names = {f.name for f in fields(cls)}

        # Extract parameters that match dataclass fields (excluding private fields)
        known_params = {k: v for k, v in data.items() if k in field_names and not k.startswith("_")}

        # Store extra parameters
        extra_params = {k: v for k, v in data.items() if k not in field_names}

        # Create instance with known params (missing ones will use defaults)
        # Then update _extra_params after creation since it's a private field
        instance = cls(**known_params, _extra_params=extra_params)
        return instance

    def __getattr__(self, item: str) -> Any:
        """
        Allow access to extra parameters via attribute access.

        This enables accessing parameters that weren't explicitly defined
        in the dataclass fields but were passed in the dict.
        """
        if item == "_extra_params" or item.startswith("_"):
            return object.__getattribute__(self, item)

        extra = object.__getattribute__(self, "_extra_params")
        if item in extra:
            return extra[item]

        raise AttributeError(f"'{self.__class__.__name__}' object has no attribute '{item}'")


@dataclass
class OmniDiffusionConfig:
    """Top-level runtime configuration for diffusion model execution.

    Aggregates model selection, dtype, parallelism, caching, offloading,
    quantization, LoRA, and server settings. ``__post_init__`` normalizes
    loosely-typed inputs (dicts / strings) into structured form and settles
    the distributed master port.
    """

    # Model and path configuration (for convenience)
    model: str | None = None

    model_class_name: str | None = None

    dtype: torch.dtype = torch.bfloat16

    tf_model_config: TransformerConfig = field(default_factory=TransformerConfig)

    # Attention
    attention_backend: str | None = None

    # Running mode
    # mode: ExecutionMode = ExecutionMode.INFERENCE

    # Workload type
    # workload_type: WorkloadType = WorkloadType.T2V

    # Cache strategy (legacy)
    cache_strategy: str = "none"
    parallel_config: DiffusionParallelConfig = field(default_factory=DiffusionParallelConfig)

    # Cache backend configuration (NEW)
    cache_backend: str = "none"  # "tea_cache", "deep_cache", etc.
    cache_config: DiffusionCacheConfig | dict[str, Any] = field(default_factory=dict)
    enable_cache_dit_summary: bool = False

    # Distributed executor backend
    distributed_executor_backend: str = "mp"
    nccl_port: int | None = None

    # HuggingFace specific parameters
    trust_remote_code: bool = False
    revision: str | None = None

    num_gpus: int | None = None

    dist_timeout: int | None = None  # timeout for torch.distributed

    # pipeline_config: PipelineConfig = field(default_factory=PipelineConfig, repr=False)

    # LoRA parameters
    lora_path: str | None = None
    lora_scale: float = 1.0
    max_cpu_loras: int | None = None

    output_type: str = "pil"

    # CPU offload parameters
    # When enabled, DiT and encoders swap GPU access (mutual exclusion):
    # - Text encoders run on GPU while DiT is on CPU
    # - DiT runs on GPU while encoders are on CPU
    enable_cpu_offload: bool = False
    # Layer-wise offloading (block-level offloading) parameters
    enable_layerwise_offload: bool = False

    pin_cpu_memory: bool = True  # Use pinned memory for faster transfers when offloading

    # VAE memory optimization parameters
    vae_use_slicing: bool = False
    vae_use_tiling: bool = False

    # STA (Sliding Tile Attention) parameters
    mask_strategy_file_path: str | None = None
    # STA_mode: STA_Mode = STA_Mode.STA_INFERENCE
    skip_time_steps: int = 15

    # Compilation
    enforce_eager: bool = False

    # Enable sleep mode
    enable_sleep_mode: bool = False

    disable_autocast: bool = False

    # VSA parameters
    VSA_sparsity: float = 0.0  # inference/validation sparsity

    # V-MoBA parameters
    moba_config_path: str | None = None
    # moba_config: dict[str, Any] = field(default_factory=dict)

    # Master port for distributed inference
    # TODO: do not hard code
    master_port: int | None = None

    # Worker extension class for custom functionality
    worker_extension_cls: str | None = None

    # Custom pipeline arguments for custom pipelines
    custom_pipeline_args: dict[str, Any] | None = None

    # Diffusion model loading format
    diffusion_load_format: str = "default"  # "default", "custom_pipeline", "dummy"

    # http server endpoint config, would be ignored in local mode
    host: str | None = None
    port: int | None = None

    scheduler_port: int = 5555

    # Stage verification
    enable_stage_verification: bool = True

    # Prompt text file for batch processing
    prompt_file_path: str | None = None

    # model paths for correct deallocation
    model_paths: dict[str, str] = field(default_factory=dict)
    model_loaded: dict[str, bool] = field(
        default_factory=lambda: {
            "transformer": True,
            "vae": True,
        }
    )
    override_transformer_cls_name: str | None = None

    # # DMD parameters
    # dmd_denoising_steps: List[int] | None = field(default=None)

    # MoE parameters used by Wan2.2
    boundary_ratio: float | None = None
    # Scheduler flow_shift for Wan2.2 (12.0 for 480p, 5.0 for 720p)
    flow_shift: float | None = None

    # support multi images input
    supports_multimodal_inputs: bool = False

    log_level: str = "info"

    # Omni configuration (injected from stage config)
    omni_kv_config: dict[str, Any] = field(default_factory=dict)

    # Quantization settings
    # Supported methods: "fp8" (FP8 W8A8 on Ada/Hopper, weight-only on older GPUs)
    quantization: str | None = None
    quantization_config: "DiffusionQuantizationConfig | dict[str, Any] | None" = None

    def settle_port(self, port: int, port_inc: int = 42, max_attempts: int = 100) -> int:
        """
        Find an available port with retry logic.

        Args:
            port: Initial port to check
            port_inc: Port increment for each attempt
            max_attempts: Maximum number of attempts to find an available port

        Returns:
            An available port number

        Raises:
            RuntimeError: If no available port is found after max_attempts
        """
        attempts = 0
        original_port = port

        while attempts < max_attempts:
            if is_port_available(port):
                if attempts > 0:
                    logger.info(f"Port {original_port} was unavailable, using port {port} instead")
                return port

            attempts += 1
            if port < 60000:
                port += port_inc
            else:
                # Wrap around with randomization to avoid collision
                port = 5000 + random.randint(0, 1000)

        raise RuntimeError(
            f"Failed to find available port after {max_attempts} attempts (started from port {original_port})"
        )

    def __post_init__(self) -> None:
        """Normalize loosely-typed fields and derive dependent settings.

        Converts dict/str inputs (parallel_config, dtype, cache_config,
        quantization_config) into their structured forms, settles the
        distributed master port, and validates num_gpus / max_cpu_loras.

        Raises:
            ValueError: if num_gpus is below the configured world size, or
                max_cpu_loras < 1.
            TypeError: if quantization_config has an unsupported type.
        """
        # TODO: remove hard code
        # Randomize the starting port to reduce collisions between concurrent runs.
        initial_master_port = (self.master_port or 30005) + random.randint(0, 100)
        self.master_port = self.settle_port(initial_master_port, 37)

        # Convert parallel_config dict to DiffusionParallelConfig if needed
        # This must be done before accessing parallel_config.world_size
        if isinstance(self.parallel_config, dict):
            self.parallel_config = DiffusionParallelConfig.from_dict(self.parallel_config)
        elif not isinstance(self.parallel_config, DiffusionParallelConfig):
            # If it's neither dict nor DiffusionParallelConfig, use default config
            self.parallel_config = DiffusionParallelConfig()

        # Default num_gpus to the total world size implied by the parallel config.
        if self.num_gpus is None:
            if self.parallel_config is not None:
                self.num_gpus = self.parallel_config.world_size
            else:
                self.num_gpus = 1

        if self.num_gpus < self.parallel_config.world_size:
            raise ValueError(
                f"num_gpus ({self.num_gpus}) < parallel_config.world_size ({self.parallel_config.world_size})"
            )

        # Convert string dtype to torch.dtype if needed
        if isinstance(self.dtype, str):
            dtype_map = {
                "auto": torch.bfloat16,
                "bfloat16": torch.bfloat16,
                "bf16": torch.bfloat16,
                "float16": torch.float16,
                "fp16": torch.float16,
                "half": torch.float16,
                "float32": torch.float32,
                "fp32": torch.float32,
                "float": torch.float32,
            }
            dtype_lower = self.dtype.lower()
            if dtype_lower in dtype_map:
                self.dtype = dtype_map[dtype_lower]
            else:
                # Unknown strings fall back to bfloat16 rather than failing hard.
                logger.warning(f"Unknown dtype string '{self.dtype}', defaulting to bfloat16")
                self.dtype = torch.bfloat16

        # Convert cache_config dict to DiffusionCacheConfig if needed
        if isinstance(self.cache_config, dict):
            self.cache_config = DiffusionCacheConfig.from_dict(self.cache_config)
        elif not isinstance(self.cache_config, DiffusionCacheConfig):
            # If it's neither dict nor DiffusionCacheConfig, convert to empty config
            self.cache_config = DiffusionCacheConfig()

        # Convert quantization config (deferred import to avoid circular imports)
        if self.quantization is not None or self.quantization_config is not None:
            # NOTE(review): this name is also imported eagerly at module top;
            # the local import is a redundant safeguard against import-order issues.
            from vllm_omni.diffusion.quantization import (
                DiffusionQuantizationConfig,
            )

            # Handle dict or DictConfig (from OmegaConf) - use Mapping for broader compatibility
            if isinstance(self.quantization_config, Mapping):
                # Convert DictConfig to dict if needed (OmegaConf compatibility)
                config_dict = dict(self.quantization_config)
                # Use get() instead of pop() to avoid mutating original dict
                quant_method = config_dict.get("method", self.quantization)
                # Filter out "method" key for kwargs
                quant_kwargs = {k: v for k, v in config_dict.items() if k != "method"}

                # Validate conflicting methods
                if self.quantization is not None and quant_method is not None and quant_method != self.quantization:
                    logger.warning(
                        f"Conflicting quantization methods: quantization={self.quantization!r}, "
                        f"quantization_config['method']={quant_method!r}. Using quantization_config['method']."
                    )

                self.quantization_config = get_diffusion_quant_config(quant_method, **quant_kwargs)
            elif self.quantization_config is None and self.quantization is not None:
                self.quantization_config = get_diffusion_quant_config(self.quantization)
            elif not isinstance(self.quantization_config, DiffusionQuantizationConfig):
                raise TypeError(
                    f"quantization_config must be a DiffusionQuantizationConfig, dict, or None, "
                    f"got {type(self.quantization_config)!r}"
                )

        # LoRA CPU cache must hold at least one entry.
        if self.max_cpu_loras is None:
            self.max_cpu_loras = 1
        elif self.max_cpu_loras < 1:
            raise ValueError("max_cpu_loras must be >= 1 for diffusion LoRA")

    def update_multimodal_support(self) -> None:
        """Set supports_multimodal_inputs based on the configured model class."""
        self.supports_multimodal_inputs = self.model_class_name in {"QwenImageEditPlusPipeline"}

    @classmethod
    def from_kwargs(cls, **kwargs: Any) -> "OmniDiffusionConfig":
        """Build a config from keyword arguments, silently dropping unknown keys.

        Also normalizes the legacy "static_lora_scale" kwarg and falls back to
        the DIFFUSION_CACHE_BACKEND / DIFFUSION_CACHE_ADAPTER environment
        variables for cache_backend.
        """
        # Backwards-compatibility: older callers may use a diffusion-specific
        # "static_lora_scale" kwarg. Normalize it to the canonical "lora_scale"
        # before constructing the dataclass to avoid TypeError on unknown fields.
        if "static_lora_scale" in kwargs:
            if "lora_scale" not in kwargs:
                kwargs["lora_scale"] = kwargs["static_lora_scale"]
            kwargs.pop("static_lora_scale", None)

        # Check environment variable as fallback for cache_backend
        # Support both old DIFFUSION_CACHE_ADAPTER and new DIFFUSION_CACHE_BACKEND for backwards compatibility
        if "cache_backend" not in kwargs:
            cache_backend = os.environ.get("DIFFUSION_CACHE_BACKEND") or os.environ.get("DIFFUSION_CACHE_ADAPTER")
            kwargs["cache_backend"] = cache_backend.lower() if cache_backend else "none"

        # Filter kwargs to only include valid fields
        valid_fields = {f.name for f in fields(cls)}
        filtered_kwargs = {k: v for k, v in kwargs.items() if k in valid_fields}

        return cls(**filtered_kwargs)


@dataclass
class DiffusionOutput:
    """
    Final output (after pipeline completion)
    """

    # Generated result tensor; semantics depend on the pipeline that produced it.
    output: torch.Tensor | None = None
    # Timesteps recorded along the denoising trajectory, when tracking is enabled.
    trajectory_timesteps: list[torch.Tensor] | None = None
    # Latents recorded along the denoising trajectory, when tracking is enabled.
    trajectory_latents: torch.Tensor | None = None
    # Decoded intermediate results for trajectory points, when tracking is enabled.
    trajectory_decoded: list[torch.Tensor] | None = None
    # Error description string; presumably None on success — confirm at producers.
    error: str | None = None

    # Optional callable the consumer may apply to post-process the output.
    post_process_func: Callable[..., Any] | None = None

    # logged timings info, directly from Req.timings
    # timings: Optional["RequestTimings"] = None


class AttentionBackendEnum(enum.Enum):
    """Identifiers for the attention backend implementations available at runtime."""

    # Member order is meaningful: enum.auto() assigns sequential values.
    FA = enum.auto()
    SLIDING_TILE_ATTN = enum.auto()
    TORCH_SDPA = enum.auto()
    SAGE_ATTN = enum.auto()
    SAGE_ATTN_THREE = enum.auto()
    VIDEO_SPARSE_ATTN = enum.auto()
    VMOBA_ATTN = enum.auto()
    AITER = enum.auto()
    NO_ATTENTION = enum.auto()

    def __str__(self) -> str:
        # Render as the lower-cased member name, e.g. "torch_sdpa".
        lowered = self.name.lower()
        return lowered


# Special message broadcast via scheduler queues to signal worker shutdown.
# Consumers presumably compare incoming payloads against this sentinel dict —
# confirm at the queue readers before changing its shape.
SHUTDOWN_MESSAGE = {"type": "shutdown"}
