import torch
import functools
from typing import Optional, Tuple
from diffusers.models.modeling_utils import ModelMixin
from diffusers.models.transformers.transformer_cogview3plus import (
    CogView3PlusTransformer2DModel,
    CogVideoXAttnProcessor2_0,
)
from diffusers.models.transformers.transformer_cogview4 import (
    CogView4Transformer2DModel,
    CogView4AttnProcessor,
)
from diffusers.models.attention_processor import Attention
from diffusers.models.attention_dispatch import dispatch_attention_fn

try:
    from diffusers.models._modeling_parallel import (
        ContextParallelInput,
        ContextParallelOutput,
        ContextParallelModelPlan,
    )
except ImportError:
    raise ImportError(
        "Context parallelism requires diffusers>=0.36.dev0. "
        "Please install the latest version of diffusers from source:\n"
        "pip3 install git+https://github.com/huggingface/diffusers.git"
    )
from .cp_plan_registers import (
    ContextParallelismPlanner,
    ContextParallelismPlannerRegister,
)
from .cp_plan_cogvideox import __patch_CogVideoXAttnProcessor2_0__call__

from cache_dit.logger import init_logger

logger = init_logger(__name__)


@ContextParallelismPlannerRegister.register("CogView3Plus")
class CogView3PlusContextParallelismPlanner(ContextParallelismPlanner):
    def apply(
        self,
        transformer: Optional[torch.nn.Module | ModelMixin] = None,
        **kwargs,
    ) -> ContextParallelModelPlan:
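        """Build the context-parallel plan for CogView3PlusTransformer2DModel."""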

        # NOTE: The native diffusers CP plan is not yet supported
        # for CogView3Plus.
        self._cp_planner_preferred_native_diffusers = False

        if transformer is not None and self._cp_planner_preferred_native_diffusers:
            assert isinstance(
                transformer, CogView3PlusTransformer2DModel
            ), "Transformer must be an instance of CogView3PlusTransformer2DModel"
            if hasattr(transformer, "_cp_plan"):
                if transformer._cp_plan is not None:
                    return transformer._cp_plan

        # CogView3Plus and CogVideoX share the same attention processor
        CogVideoXAttnProcessor2_0.__call__ = __patch_CogVideoXAttnProcessor2_0__call__
        # Also patch in the parallel config and attention backend attributes,
        # which the patched attention dispatch path expects on the processor.
        if not hasattr(CogVideoXAttnProcessor2_0, "_parallel_config"):
            CogVideoXAttnProcessor2_0._parallel_config = None
        if not hasattr(CogVideoXAttnProcessor2_0, "_attention_backend"):
            CogVideoXAttnProcessor2_0._attention_backend = None

        # Otherwise, use the custom CP plan defined here, which may differ
        # slightly from the native diffusers implementation for some models.
        _cp_plan = {
            # Pattern of transformer_blocks.0, split_output=False:
            #     un-split input -> split -> to_qkv/...
            #     -> all2all
            #     -> attn (local heads, full seqlen)
            #     -> all2all
            #     -> split output
            # Pattern of the remaining transformer_blocks, split_output=False:
            #     split input (previous split output) -> to_qkv/...
            #     -> all2all
            #     -> attn (local heads, full seqlen)
            #     -> all2all
            #     -> split output
            # The `encoder_hidden_states` changes after each block forward, so
            # we split it at the first block and keep it split (namely, it is
            # automatically re-split by the all2all op after attn) for the
            # remaining blocks. The `out` tensor of the local attn is split
            # back into `hidden_states` and `encoder_hidden_states` after each
            # block forward, so both are automatically split by the all2all
            # comm op after the local attn.
            "transformer_blocks.0": {
                "hidden_states": ContextParallelInput(
                    split_dim=1, expected_dims=3, split_output=False
                ),
                "encoder_hidden_states": ContextParallelInput(
                    split_dim=1, expected_dims=3, split_output=False
                ),
            },
            # The split `hidden_states` stays split through the rest of the
            # transformer forward while using CP, since it is not gathered in
            # between. The final proj_out then gathers the split output:
            #     split input (previous split output)
            #     -> all gather
            #     -> un-split output
            "proj_out": ContextParallelOutput(gather_dim=1, expected_dims=3),
        }
        return _cp_plan
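

# Illustration of the plan semantics above (a sketch, not executed): with a
# context-parallel world size of 2, ContextParallelInput(split_dim=1) shards a
# [B, S, D] tensor along the sequence dim, and ContextParallelOutput(gather_dim=1)
# restores it. Conceptually (modulo the actual all2all/all-gather comm ops):
#
#   >>> x = torch.randn(2, 8, 64)           # [B, S, D], un-split input
#   >>> shards = torch.chunk(x, 2, dim=1)   # split_dim=1 -> [B, S/2, D] per rank
#   >>> torch.cat(shards, dim=1).shape      # gather_dim=1 -> un-split output
#   torch.Size([2, 8, 64])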


@ContextParallelismPlannerRegister.register("CogView4")
class CogView4ContextParallelismPlanner(ContextParallelismPlanner):
    def apply(
        self,
        transformer: Optional[torch.nn.Module | ModelMixin] = None,
        **kwargs,
    ) -> ContextParallelModelPlan:
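        """Build the context-parallel plan for CogView4Transformer2DModel."""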

        # NOTE: The native diffusers CP plan is not yet supported
        # for CogView4.
        self._cp_planner_preferred_native_diffusers = False

        if transformer is not None and self._cp_planner_preferred_native_diffusers:
            assert isinstance(
                transformer, CogView4Transformer2DModel
            ), "Transformer must be an instance of CogView4Transformer2DModel"
            if hasattr(transformer, "_cp_plan"):
                if transformer._cp_plan is not None:
                    return transformer._cp_plan

        CogView4AttnProcessor.__call__ = __patch_CogView4AttnProcessor__call__
        # Also patch in the parallel config and attention backend attributes,
        # which the patched __call__ below looks up via getattr.
        if not hasattr(CogView4AttnProcessor, "_parallel_config"):
            CogView4AttnProcessor._parallel_config = None
        if not hasattr(CogView4AttnProcessor, "_attention_backend"):
            CogView4AttnProcessor._attention_backend = None

        # Otherwise, use the custom CP plan defined here, which may differ
        # slightly from the native diffusers implementation for some models.
        _cp_plan = {
            # Pattern of transformer_blocks.0, split_output=False:
            #     un-split input -> split -> to_qkv/...
            #     -> all2all
            #     -> attn (local heads, full seqlen)
            #     -> all2all
            #     -> split output
            # Pattern of the remaining transformer_blocks, split_output=False:
            #     split input (previous split output) -> to_qkv/...
            #     -> all2all
            #     -> attn (local heads, full seqlen)
            #     -> all2all
            #     -> split output
            # The `encoder_hidden_states` changes after each block forward, so
            # we split it at the first block and keep it split (namely, it is
            # automatically re-split by the all2all op after attn) for the
            # remaining blocks. The `out` tensor of the local attn is split
            # back into `hidden_states` and `encoder_hidden_states` after each
            # block forward, so both are automatically split by the all2all
            # comm op after the local attn.
            "transformer_blocks.0": {
                "hidden_states": ContextParallelInput(
                    split_dim=1, expected_dims=3, split_output=False
                ),
                "encoder_hidden_states": ContextParallelInput(
                    split_dim=1, expected_dims=3, split_output=False
                ),
            },
            # Pattern of image_rotary_emb: split at every block, because it is
            # not carried through the all2all comm op and is left un-split when
            # a block forward finishes:
            #    un-split input -> split output
            #    -> after block forward
            #    -> un-split input
            #    un-split input -> split output
            #    ...
            "transformer_blocks.*": {
                "image_rotary_emb": [
                    ContextParallelInput(
                        split_dim=0, expected_dims=2, split_output=False
                    ),
                    ContextParallelInput(
                        split_dim=0, expected_dims=2, split_output=False
                    ),
                ],
            },
            # The split `hidden_states` stays split through the rest of the
            # transformer forward while using CP, since it is not gathered in
            # between. The final proj_out then gathers the split output:
            #     split input (previous split output)
            #     -> all gather
            #     -> un-split output
            "proj_out": ContextParallelOutput(gather_dim=1, expected_dims=3),
        }
        return _cp_plan
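

# Note on `image_rotary_emb` in the plan above (a sketch, not executed): it is
# a (cos, sin) pair of 2-D tensors whose first dim is the image sequence
# length (hence expected_dims=2), so sequence sharding happens along dim 0:
#
#   >>> freqs_cos = torch.randn(8, 64)             # [S_img, ...]
#   >>> torch.chunk(freqs_cos, 2, dim=0)[0].shape  # split_dim=0, world size 2
#   torch.Size([4, 64])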


@functools.wraps(CogView4AttnProcessor.__call__)
def __patch_CogView4AttnProcessor__call__(
    self: CogView4AttnProcessor,
    attn: Attention,
    hidden_states: torch.Tensor,
    encoder_hidden_states: torch.Tensor,
    attention_mask: Optional[torch.Tensor] = None,
    image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
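    """Patched CogView4AttnProcessor.__call__ that routes attention through
    diffusers' dispatch_attention_fn, so a configurable attention backend and
    context parallelism (all2all) can be applied."""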
    dtype = encoder_hidden_states.dtype

    batch_size, text_seq_length, embed_dim = encoder_hidden_states.shape
    batch_size, image_seq_length, embed_dim = hidden_states.shape
    hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1)

    # 1. QKV projections
    query = attn.to_q(hidden_states)
    key = attn.to_k(hidden_states)
    value = attn.to_v(hidden_states)

    # NOTE(DefTruth): no transpose
    query = query.unflatten(2, (attn.heads, -1))
    key = key.unflatten(2, (attn.heads, -1))
    value = value.unflatten(2, (attn.heads, -1))
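    # query/key/value now have shape [B, text_seq + image_seq, heads, head_dim];
    # heads stay on dim 2 (hence the "no transpose" note above), matching the
    # layout dispatch_attention_fn consumes and that flatten(2, 3) undoes below.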

    # 2. QK normalization
    if attn.norm_q is not None:
        query = attn.norm_q(query).to(dtype=dtype)
    if attn.norm_k is not None:
        key = attn.norm_k(key).to(dtype=dtype)

    # 3. Rotational positional embeddings applied to latent stream
    if image_rotary_emb is not None:
        from diffusers.models.embeddings import apply_rotary_emb

        query[:, text_seq_length:] = apply_rotary_emb(
            query[:, text_seq_length:],
            image_rotary_emb,
            use_real_unbind_dim=-2,
            sequence_dim=1,
        )
        key[:, text_seq_length:] = apply_rotary_emb(
            key[:, text_seq_length:],
            image_rotary_emb,
            use_real_unbind_dim=-2,
            sequence_dim=1,
        )

    # 4. Attention
    if attention_mask is not None:
        text_attn_mask = attention_mask
        assert (
            text_attn_mask.dim() == 2
        ), "the shape of text_attn_mask should be (batch_size, text_seq_length)"
        text_attn_mask = text_attn_mask.float().to(query.device)
        mix_attn_mask = torch.ones(
            (batch_size, text_seq_length + image_seq_length),
            device=query.device,
        )
        mix_attn_mask[:, :text_seq_length] = text_attn_mask  # [B, seq_len]
        # TODO(DefTruth): Permute mix_attn_mask if context parallelism is used.
        # For example, if world size = 2: [E, H] -> [E_0, H_0, E_1, H_1]
        mix_attn_mask = mix_attn_mask.unsqueeze(2)  # [B, seq_len, 1]
        attn_mask_matrix = mix_attn_mask @ mix_attn_mask.transpose(1, 2)  # [B, seq_len, seq_len]
        attention_mask = (
            (attn_mask_matrix > 0).unsqueeze(1).to(query.dtype)
        )  # [B, 1, seq_len, seq_len]
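        # Worked example: with text_attn_mask = [[1, 1, 0]] and
        # image_seq_length = 2, mix_attn_mask is [[1, 1, 0, 1, 1]], and the
        # outer product is positive exactly where both the query and key
        # positions are valid, so the padded text token is masked out along
        # both the row and column directions.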
        if hasattr(self, "_parallel_config") and self._parallel_config is not None:
            raise NotImplementedError(
                "Attention mask with context parallelism for CogView4 " "is not implemented yet."
            )

    # NOTE(DefTruth): Apply dispatch_attention_fn instead of sdpa directly
    hidden_states = dispatch_attention_fn(
        query,
        key,
        value,
        attn_mask=attention_mask,
        dropout_p=0.0,
        is_causal=False,
        backend=getattr(self, "_attention_backend", None),
        parallel_config=getattr(self, "_parallel_config", None),
    )
    hidden_states = hidden_states.flatten(2, 3)
    hidden_states = hidden_states.type_as(query)

    # 5. Output projection
    hidden_states = attn.to_out[0](hidden_states)
    hidden_states = attn.to_out[1](hidden_states)

    encoder_hidden_states, hidden_states = hidden_states.split(
        [text_seq_length, hidden_states.size(1) - text_seq_length], dim=1
    )
    return hidden_states, encoder_hidden_states
