# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

from __future__ import annotations

from dataclasses import dataclass
from typing import TYPE_CHECKING

import torch
from vllm.logger import init_logger

# import torch.distributed as dist # Not used directly here, but good practice if needed
from vllm_omni.diffusion.attention.backends.ring.ring_globals import HAS_FA3, HAS_FLASH_ATTN
from vllm_omni.diffusion.attention.backends.ring.ring_selector import AttnType
from vllm_omni.diffusion.attention.parallel.base import (
    ParallelAttentionContext,
    # ParallelAttentionStrategy, # Not used in type hint below currently
)
from vllm_omni.diffusion.distributed.group_coordinator import SequenceParallelGroupCoordinator

# from vllm_omni.diffusion.attention.backends.ring_selector import AttnType # Already imported above
from vllm_omni.diffusion.forward_context import get_forward_context

if TYPE_CHECKING:
    from vllm_omni.diffusion.attention.backends.abstract import AttentionMetadata


@dataclass(frozen=True, slots=True)
class _RingCtx(ParallelAttentionContext):
    """Per-forward context for Ring sequence-parallel attention."""

    # Ring attention typically doesn't need complex context for post-processing
    # as the output is already correctly sharded along sequence dimension.
    pass


class RingParallelAttention:
    """Ring sequence-parallel strategy.

    This strategy prepares inputs for Ring Attention.
    Key responsibilities:
    - Concatenate joint_query (Text) to query (Image) if present.
    - Keep joint_key/value separate in metadata for the Ring kernel to handle as static prefix.
    """

    def __init__(
        self,
        sp_group: SequenceParallelGroupCoordinator,
        attn_backend_pref: str | None = None,
    ) -> None:
        self._sp_group = sp_group
        self.attn_backend_pref = attn_backend_pref

    @property
    def enabled(self) -> bool:
        return True

    @property
    def name(self) -> str:
        return "ring"

    def pre_attention(
        self,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        attn_metadata: AttentionMetadata | None,
    ):
        joint_tensor_query = None
        joint_strategy = "front"

        if attn_metadata is not None:
            joint_tensor_query = attn_metadata.joint_query
            joint_strategy = attn_metadata.joint_strategy

        if joint_tensor_query is not None:
            supported_joint_strategy = ["front", "rear"]
            if joint_strategy not in supported_joint_strategy:
                raise ValueError(f"joint_strategy: {joint_strategy} not supported.")

            if joint_strategy == "front":
                query = torch.cat([joint_tensor_query, query], dim=1)
            else:
                query = torch.cat([query, joint_tensor_query], dim=1)

            # Note: We do NOT concatenate joint_key/value here.
            # They are preserved in attn_metadata and will be passed
            # explicitly to ring_flash_attn_func.

        ctx = _RingCtx(name=self.name)
        return query, key, value, attn_metadata, ctx

    def post_attention(self, attn_output: torch.Tensor, ctx: ParallelAttentionContext | None) -> torch.Tensor:
        # Ring attention output is already sharded correctly along sequence dimension.
        return attn_output

    def run_attention(
        self,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        attn_metadata: AttentionMetadata | None,
        softmax_scale: float | None = None,
        causal: bool = False,
    ) -> torch.Tensor:
        """Run the actual Ring Attention kernel."""
        if softmax_scale is None:
            softmax_scale = query.shape[-1] ** -0.5

        backend_pref = self.attn_backend_pref
        if backend_pref is None:
            try:
                config = get_forward_context().omni_diffusion_config
                # config might not have attention_backend attribute if not updated
                backend_pref = getattr(config, "attention_backend", None)
            except Exception:
                backend_pref = None

        # Determine attention type with fallback chain: FA3 -> FA2 -> SDPA
        # FP32 is not supported by Flash Attention, force SDPA
        if query.dtype == torch.float32:
            backend_pref = "sdpa"
        elif not HAS_FA3 and not HAS_FLASH_ATTN:
            if backend_pref != "sdpa":
                logger = init_logger(__name__)
                logger.warning_once("Flash Attention (FA2/FA3) is not available! Force enabling SDPA.")
            backend_pref = "sdpa"

        # Extract joint tensors
        joint_key, joint_value = None, None
        joint_strategy = "front"
        if attn_metadata is not None:
            joint_key = attn_metadata.joint_key
            joint_value = attn_metadata.joint_value
            if attn_metadata.joint_strategy is not None:
                joint_strategy = attn_metadata.joint_strategy

        if backend_pref == "sdpa" or backend_pref == "torch":
            from vllm_omni.diffusion.attention.backends.ring_pytorch_attn import ring_pytorch_attn_func

            return ring_pytorch_attn_func(
                query,
                key,
                value,
                softmax_scale=softmax_scale,
                causal=causal,
                group=self._sp_group.ring_group,
                op_type="efficient",
                joint_tensor_key=joint_key,
                joint_tensor_value=joint_value,
                joint_strategy=joint_strategy,
            )

        from vllm_omni.diffusion.attention.backends.ring_flash_attn import ring_flash_attn_func

        # Prefer FA3 over FA2 for better performance (FA3 supports Ampere/Ada/Hopper)
        attn_type = AttnType.FA3 if HAS_FA3 else AttnType.FA

        return ring_flash_attn_func(
            query,
            key,
            value,
            dropout_p=0.0,
            softmax_scale=softmax_scale,
            causal=causal,
            window_size=(-1, -1),
            softcap=0.0,
            alibi_slopes=None,
            deterministic=False,
            group=self._sp_group.ring_group,
            attn_type=attn_type,
            joint_tensor_key=joint_key,
            joint_tensor_value=joint_value,
            joint_strategy=joint_strategy,
        )
