# Copied and adapted from: https://github.com/hao-ai-lab/FastVideo

# Copyright 2025 Black Forest Labs, The HuggingFace Team and The InstantX Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Any, Dict, List, Optional, Tuple, Union

import torch
import torch.nn as nn
from diffusers.models.attention import AttentionModuleMixin, FeedForward
from diffusers.models.modeling_outputs import Transformer2DModelOutput
from diffusers.models.normalization import (
    AdaLayerNormContinuous,
    AdaLayerNormZero,
    AdaLayerNormZeroSingle,
)
from torch.nn import LayerNorm as LayerNorm

from sglang.multimodal_gen.configs.models.dits.flux import FluxConfig
from sglang.multimodal_gen.runtime.layers.attention import USPAttention

# from sglang.multimodal_gen.runtime.layers.layernorm import LayerNorm as LayerNorm
from sglang.multimodal_gen.runtime.layers.layernorm import RMSNorm, apply_qk_norm
from sglang.multimodal_gen.runtime.layers.linear import (
    ColumnParallelLinear,
    MergedColumnParallelLinear,
)
from sglang.multimodal_gen.runtime.layers.quantization.configs.base_config import (
    QuantizationConfig,
)
from sglang.multimodal_gen.runtime.layers.quantization.configs.nunchaku_config import (
    NunchakuConfig,
    is_nunchaku_available,
)
from sglang.multimodal_gen.runtime.layers.rotary_embedding import (
    NDRotaryEmbedding,
    apply_flashinfer_rope_qk_inplace,
)
from sglang.multimodal_gen.runtime.layers.visual_embedding import (
    CombinedTimestepGuidanceTextProjEmbeddings,
    CombinedTimestepTextProjEmbeddings,
)
from sglang.multimodal_gen.runtime.models.dits.base import CachableDiT
from sglang.multimodal_gen.runtime.platforms import current_platform
from sglang.multimodal_gen.runtime.utils.layerwise_offload import OffloadableDiTMixin
from sglang.multimodal_gen.runtime.utils.logging_utils import init_logger

logger = init_logger(__name__)  # pylint: disable=invalid-name

try:
    from nunchaku.models.attention import NunchakuFeedForward  # type: ignore[import]
    from nunchaku.models.normalization import (  # type: ignore[import]
        NunchakuAdaLayerNormZero,
        NunchakuAdaLayerNormZeroSingle,
    )
    from nunchaku.ops.gemm import (
        svdq_gemm_w4a4_cuda as _svdq_gemm_w4a4,  # type: ignore[import]
    )
    from nunchaku.ops.quantize import (
        svdq_quantize_w4a4_act_fuse_lora_cuda as _svdq_quantize_w4a4,  # type: ignore[import]
    )

    _nunchaku_fused_ops_available = True
except Exception:
    NunchakuFeedForward = None
    NunchakuAdaLayerNormZero = None
    NunchakuAdaLayerNormZeroSingle = None
    _svdq_gemm_w4a4 = None
    _svdq_quantize_w4a4 = None
    _nunchaku_fused_ops_available = False


def _fused_gelu_mlp(
    x: torch.Tensor,
    fc1,
    fc2,
    pad_size: int = 256,
) -> torch.Tensor:
    """
    Fused GELU MLP matching nunchaku's fused_gelu_mlp kernel path.

    nunchaku's single-block MLP checkpoint is calibrated for the fused path where:
      1. fc1 GEMM + GELU + 0.171875 shift + unsigned re-quantization + fc2.lora_down
         are all done in a single fused kernel call
      2. fc2 GEMM then receives unsigned INT4 activations (act_unsigned=True)

    Running the sequential path instead (fc1 → GELU → fc2 with symmetric quantization)
    is incompatible with the checkpoint's weight scales (wscales) and produces
    visually wrong outputs.
    """
    batch_size, seq_len, channels = x.shape
    x_2d = x.view(batch_size * seq_len, channels)

    quantized_x, ascales, lora_act = _svdq_quantize_w4a4(
        x_2d,
        lora_down=fc1.proj_down,
        smooth=fc1.smooth_factor,
        fp4=fc1.precision == "nvfp4",
        pad_size=pad_size,
    )

    # round the flattened token count up to a multiple of pad_size
    batch_size_pad = (batch_size * seq_len + pad_size - 1) // pad_size * pad_size
    is_fp4 = fc2.precision == "nvfp4"

    qout_act = torch.empty(
        batch_size_pad,
        fc1.output_size_per_partition // 2,
        dtype=torch.uint8,
        device=x_2d.device,
    )
    if is_fp4:
        qout_ascales = torch.empty(
            fc1.output_size_per_partition // 16,
            batch_size_pad,
            dtype=torch.float8_e4m3fn,
            device=x_2d.device,
        )
    else:
        qout_ascales = torch.empty(
            fc1.output_size_per_partition // 64,
            batch_size_pad,
            dtype=x_2d.dtype,
            device=x_2d.device,
        )
    qout_lora_act = torch.empty(
        batch_size_pad, fc2.proj_down.shape[1], dtype=torch.float32, device=x_2d.device
    )

    # fused: fc1 GEMM + GELU + shift + unsigned quantize + fc2.lora_down
    _svdq_gemm_w4a4(
        act=quantized_x,
        wgt=fc1.qweight,
        qout=qout_act,
        ascales=ascales,
        wscales=fc1.wscales,
        oscales=qout_ascales,
        lora_act_in=lora_act,
        lora_up=fc1.proj_up,
        lora_down=fc2.proj_down,
        lora_act_out=qout_lora_act,
        bias=fc1.bias,
        smooth_factor=fc2.smooth_factor,
        fp4=is_fp4,
        alpha=getattr(fc1, "_nunchaku_alpha", None),
        wcscales=getattr(fc1, "wcscales", None),
    )

    output = torch.empty(
        batch_size * seq_len,
        fc2.output_size_per_partition,
        dtype=x_2d.dtype,
        device=x_2d.device,
    )
    # fc2 GEMM with unsigned INT4 activations (fused kernel shifted by 0.171875)
    _svdq_gemm_w4a4(
        act=qout_act,
        wgt=fc2.qweight,
        out=output,
        ascales=qout_ascales,
        wscales=fc2.wscales,
        lora_act_in=qout_lora_act,
        lora_up=fc2.proj_up,
        bias=fc2.bias,
        fp4=is_fp4,
        alpha=getattr(fc2, "_nunchaku_alpha", None),
        wcscales=getattr(fc2, "wcscales", None),
        act_unsigned=True,
    )

    return output.view(batch_size, seq_len, -1)


def _get_qkv_projections(
    attn: "FluxAttention",
    hidden_states: torch.Tensor,
    encoder_hidden_states: Optional[torch.Tensor] = None,
):
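    """Project hidden (and optionally encoder) states into Q/K/V.

    Uses the fused `to_qkv` / `to_added_qkv` GEMMs when the attention module was
    built with a nunchaku quant config, otherwise the separate per-projection
    linear layers.
    """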
    if getattr(attn, "use_fused_qkv", False):
        qkv, _ = attn.to_qkv(hidden_states)
        query, key, value = [x.contiguous() for x in qkv.chunk(3, dim=-1)]
    else:
        query, _ = attn.to_q(hidden_states)
        key, _ = attn.to_k(hidden_states)
        value, _ = attn.to_v(hidden_states)

    encoder_query = encoder_key = encoder_value = None
    if encoder_hidden_states is not None and attn.added_kv_proj_dim is not None:
        if getattr(attn, "use_fused_added_qkv", False):
            added_qkv, _ = attn.to_added_qkv(encoder_hidden_states)
            encoder_query, encoder_key, encoder_value = [
                x.contiguous() for x in added_qkv.chunk(3, dim=-1)
            ]
        else:
            encoder_query, _ = attn.add_q_proj(encoder_hidden_states)
            encoder_key, _ = attn.add_k_proj(encoder_hidden_states)
            encoder_value, _ = attn.add_v_proj(encoder_hidden_states)

    return query, key, value, encoder_query, encoder_key, encoder_value


class FluxAttention(torch.nn.Module, AttentionModuleMixin):
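    """Joint text/image attention used by the Flux transformer blocks.

    Q/K/V are projected either through a single fused GEMM (nunchaku checkpoints)
    or through separate linear layers, normalized with per-head RMSNorm (QK-norm),
    optionally rotated with RoPE, and run through USPAttention. When
    `added_kv_proj_dim` is set, the encoder (text) stream gets its own projections
    and is concatenated with the image stream before attention, then split back
    into the two streams afterwards.
    """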
    def __init__(
        self,
        query_dim: int,
        num_heads: int = 8,
        dim_head: int = 64,
        dropout: float = 0.0,
        bias: bool = False,
        added_kv_proj_dim: Optional[int] = None,
        added_proj_bias: Optional[bool] = True,
        out_bias: bool = True,
        eps: float = 1e-5,
        out_dim: Optional[int] = None,
        context_pre_only: Optional[bool] = None,
        pre_only: bool = False,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ):
        super().__init__()

        self.head_dim = dim_head
        self.inner_dim = out_dim if out_dim is not None else dim_head * num_heads
        self.query_dim = query_dim
        self.use_bias = bias
        self.dropout = dropout
        self.out_dim = out_dim if out_dim is not None else query_dim
        self.context_pre_only = context_pre_only
        self.pre_only = pre_only
        self.heads = out_dim // dim_head if out_dim is not None else num_heads
        self.added_kv_proj_dim = added_kv_proj_dim
        self.added_proj_bias = added_proj_bias

        self.use_fused_qkv = isinstance(quant_config, NunchakuConfig)
        self.use_fused_added_qkv = isinstance(quant_config, NunchakuConfig)

        self.norm_q = RMSNorm(dim_head, eps=eps)
        self.norm_k = RMSNorm(dim_head, eps=eps)

        if self.use_fused_qkv:
            self.to_qkv = MergedColumnParallelLinear(
                query_dim,
                [self.inner_dim] * 3,
                bias=bias,
                gather_output=True,
                quant_config=quant_config,
                prefix=f"{prefix}.to_qkv" if prefix else "to_qkv",
            )
        else:
            self.to_q = ColumnParallelLinear(
                query_dim, self.inner_dim, bias=bias, gather_output=True
            )
            self.to_k = ColumnParallelLinear(
                query_dim, self.inner_dim, bias=bias, gather_output=True
            )
            self.to_v = ColumnParallelLinear(
                query_dim, self.inner_dim, bias=bias, gather_output=True
            )
        if not self.pre_only:
            self.to_out = torch.nn.ModuleList([])
            self.to_out.append(
                ColumnParallelLinear(
                    self.inner_dim,
                    self.out_dim,
                    bias=out_bias,
                    gather_output=True,
                    quant_config=quant_config,
                    prefix=f"{prefix}.to_out.0" if prefix else "",
                )
            )
            if dropout != 0.0:
                self.to_out.append(torch.nn.Dropout(dropout))

        if added_kv_proj_dim is not None:
            self.norm_added_q = RMSNorm(dim_head, eps=eps)
            self.norm_added_k = RMSNorm(dim_head, eps=eps)
            if self.use_fused_added_qkv:
                self.to_added_qkv = MergedColumnParallelLinear(
                    added_kv_proj_dim,
                    [self.inner_dim] * 3,
                    bias=added_proj_bias,
                    gather_output=True,
                    quant_config=quant_config,
                    prefix=f"{prefix}.to_added_qkv" if prefix else "to_added_qkv",
                )
            else:
                self.add_q_proj = ColumnParallelLinear(
                    added_kv_proj_dim,
                    self.inner_dim,
                    bias=added_proj_bias,
                    gather_output=True,
                )
                self.add_k_proj = ColumnParallelLinear(
                    added_kv_proj_dim,
                    self.inner_dim,
                    bias=added_proj_bias,
                    gather_output=True,
                )
                self.add_v_proj = ColumnParallelLinear(
                    added_kv_proj_dim,
                    self.inner_dim,
                    bias=added_proj_bias,
                    gather_output=True,
                )
            self.to_add_out = ColumnParallelLinear(
                self.inner_dim,
                query_dim,
                bias=out_bias,
                gather_output=True,
                quant_config=quant_config,
                prefix=f"{prefix}.to_add_out" if prefix else "",
            )

        self.attn = USPAttention(
            num_heads=num_heads,
            head_size=self.head_dim,
            dropout_rate=0,
            softmax_scale=None,
            causal=False,
        )

    def forward(
        self,
        x: torch.Tensor,
        encoder_hidden_states: Optional[torch.Tensor] = None,
        freqs_cis=None,
    ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
        query, key, value, encoder_query, encoder_key, encoder_value = (
            _get_qkv_projections(self, x, encoder_hidden_states)
        )

        query = query.unflatten(-1, (self.heads, -1))
        key = key.unflatten(-1, (self.heads, -1))
        value = value.unflatten(-1, (self.heads, -1))
        query, key = apply_qk_norm(
            q=query,
            k=key,
            q_norm=self.norm_q,
            k_norm=self.norm_k,
            head_dim=self.head_dim,
            allow_inplace=True,
        )

        if self.added_kv_proj_dim is not None:
            encoder_query = encoder_query.unflatten(-1, (self.heads, -1))
            encoder_key = encoder_key.unflatten(-1, (self.heads, -1))
            encoder_value = encoder_value.unflatten(-1, (self.heads, -1))

            encoder_query, encoder_key = apply_qk_norm(
                q=encoder_query,
                k=encoder_key,
                q_norm=self.norm_added_q,
                k_norm=self.norm_added_k,
                head_dim=self.head_dim,
                allow_inplace=True,
            )

            bsz, seq_len, _, _ = query.shape
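            # prepend text tokens so joint attention runs over [text, image]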
            query = torch.cat([encoder_query, query], dim=1)
            key = torch.cat([encoder_key, key], dim=1)
            value = torch.cat([encoder_value, value], dim=1)

        if freqs_cis is not None:
            cos, sin = freqs_cis
            cos_sin_cache = torch.cat(
                [
                    cos.to(dtype=torch.float32).contiguous(),
                    sin.to(dtype=torch.float32).contiguous(),
                ],
                dim=-1,
            )
            query, key = apply_flashinfer_rope_qk_inplace(
                query, key, cos_sin_cache, is_neox=False
            )

        x = self.attn(query, key, value)
        x = x.flatten(2, 3)
        x = x.to(query.dtype)

        if encoder_hidden_states is not None:
            encoder_hidden_states, x = x.split_with_sizes(
                [
                    encoder_hidden_states.shape[1],
                    x.shape[1] - encoder_hidden_states.shape[1],
                ],
                dim=1,
            )
            if not self.pre_only:
                x, _ = self.to_out[0](x)
                if len(self.to_out) == 2:
                    x = self.to_out[1](x)
            encoder_hidden_states, _ = self.to_add_out(encoder_hidden_states)

            return x, encoder_hidden_states
        else:
            if not self.pre_only:
                x, _ = self.to_out[0](x)
                if len(self.to_out) == 2:
                    x = self.to_out[1](x)
            return x


class FluxSingleTransformerBlock(nn.Module):
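    """Single-stream Flux block.

    Text and image tokens are concatenated and share one AdaLayerNormZeroSingle
    modulation; attention and a GELU MLP both read the modulated input. In the
    nunchaku layout their outputs are summed directly (mlp_fc1/mlp_fc2 plus an
    attention output projection), while the standard layout concatenates them and
    applies a single joint projection (proj_out) before the gated residual.
    """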
    def __init__(
        self,
        dim: int,
        num_attention_heads: int,
        attention_head_dim: int,
        mlp_ratio: float = 4.0,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ):
        super().__init__()
        self.mlp_hidden_dim = int(dim * mlp_ratio)
        self.use_nunchaku_structure = isinstance(quant_config, NunchakuConfig)

        self.norm = AdaLayerNormZeroSingle(dim)

        if self.use_nunchaku_structure:
            self.mlp_fc1 = ColumnParallelLinear(
                dim,
                self.mlp_hidden_dim,
                bias=True,
                gather_output=True,
                quant_config=quant_config,
                prefix=f"{prefix}.mlp_fc1" if prefix else "mlp_fc1",
            )
            self.act_mlp = nn.GELU(approximate="tanh")
            self.mlp_fc2 = ColumnParallelLinear(
                self.mlp_hidden_dim,
                dim,
                bias=True,
                gather_output=True,
                quant_config=quant_config,
                prefix=f"{prefix}.mlp_fc2" if prefix else "mlp_fc2",
            )

            self.attn = FluxAttention(
                query_dim=dim,
                dim_head=attention_head_dim,
                num_heads=num_attention_heads,
                out_dim=dim,
                bias=True,
                eps=1e-6,
                pre_only=False,
                quant_config=quant_config,
                prefix=f"{prefix}.attn" if prefix else "attn",
            )
            if NunchakuAdaLayerNormZeroSingle is not None:
                self.norm = NunchakuAdaLayerNormZeroSingle(self.norm, scale_shift=0)
        else:
            self.proj_mlp = ColumnParallelLinear(
                dim,
                self.mlp_hidden_dim,
                bias=True,
                gather_output=True,
            )
            self.act_mlp = nn.GELU(approximate="tanh")
            self.proj_out = ColumnParallelLinear(
                dim + self.mlp_hidden_dim,
                dim,
                bias=True,
                gather_output=True,
            )
            self.attn = FluxAttention(
                query_dim=dim,
                dim_head=attention_head_dim,
                num_heads=num_attention_heads,
                out_dim=dim,
                bias=True,
                eps=1e-6,
                pre_only=True,
            )

    def forward(
        self,
        hidden_states: torch.Tensor,
        encoder_hidden_states: torch.Tensor,
        temb: torch.Tensor,
        freqs_cis: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
        joint_attention_kwargs: Optional[Dict[str, Any]] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        text_seq_len = encoder_hidden_states.shape[1]
        hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1)

        residual = hidden_states
        norm_hidden_states, gate = self.norm(hidden_states, emb=temb)
        joint_attention_kwargs = joint_attention_kwargs or {}

        if self.use_nunchaku_structure:
            if _nunchaku_fused_ops_available:
                mlp_hidden_states = _fused_gelu_mlp(
                    norm_hidden_states, self.mlp_fc1, self.mlp_fc2
                )
            else:
                mlp_out, _ = self.mlp_fc1(norm_hidden_states)
                mlp_hidden_states = self.act_mlp(mlp_out)
                mlp_hidden_states, _ = self.mlp_fc2(mlp_hidden_states)

            attn_output = self.attn(
                x=norm_hidden_states,
                freqs_cis=freqs_cis,
                **joint_attention_kwargs,
            )
            if isinstance(attn_output, tuple):
                attn_output = attn_output[0]

            hidden_states = attn_output + mlp_hidden_states
            gate = gate.unsqueeze(1)
            hidden_states = gate * hidden_states
            hidden_states = residual + hidden_states
        else:
            proj_hidden_states, _ = self.proj_mlp(norm_hidden_states)
            mlp_hidden_states = self.act_mlp(proj_hidden_states)

            attn_output = self.attn(
                x=norm_hidden_states,
                freqs_cis=freqs_cis,
                **joint_attention_kwargs,
            )

            hidden_states = torch.cat([attn_output, mlp_hidden_states], dim=2)
            gate = gate.unsqueeze(1)
            proj_out, _ = self.proj_out(hidden_states)
            hidden_states = gate * proj_out
            hidden_states = residual + hidden_states

        if hidden_states.dtype == torch.float16:
            hidden_states = hidden_states.clip(-65504, 65504)

        encoder_hidden_states, hidden_states = (
            hidden_states[:, :text_seq_len],
            hidden_states[:, text_seq_len:],
        )
        return encoder_hidden_states, hidden_states


class FluxTransformerBlock(nn.Module):
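    """Dual-stream Flux block.

    Image and text tokens keep separate AdaLayerNormZero modulation and
    feed-forward layers but attend jointly through a shared FluxAttention. With
    nunchaku (SVDQuant) checkpoints the norms and feed-forwards are wrapped by
    their nunchaku counterparts; those norms already fold in the "+1" scale shift
    (scale_shift=0), which is why the nunchaku branch applies `scale_mlp` directly
    while the standard branch applies `1 + scale_mlp`.
    """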
    def __init__(
        self,
        dim: int,
        num_attention_heads: int,
        attention_head_dim: int,
        qk_norm: str = "rms_norm",
        eps: float = 1e-6,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ):
        super().__init__()

        self.norm1 = AdaLayerNormZero(dim)
        self.norm1_context = AdaLayerNormZero(dim)

        self.attn = FluxAttention(
            query_dim=dim,
            added_kv_proj_dim=dim,
            dim_head=attention_head_dim,
            num_heads=num_attention_heads,
            out_dim=dim,
            context_pre_only=False,
            bias=True,
            eps=eps,
            quant_config=quant_config,
            prefix=f"{prefix}.attn" if prefix else "attn",
        )

        self.norm2 = LayerNorm(dim, eps=1e-6, elementwise_affine=False)
        self.norm2_context = LayerNorm(dim, eps=1e-6, elementwise_affine=False)

        nunchaku_enabled = (
            quant_config is not None
            and hasattr(quant_config, "get_name")
            and quant_config.get_name() == "svdquant"
            and is_nunchaku_available()
            and NunchakuFeedForward is not None
        )
        self.use_nunchaku_structure = nunchaku_enabled
        self.ff = FeedForward(dim=dim, dim_out=dim, activation_fn="gelu-approximate")
        self.ff_context = FeedForward(
            dim=dim, dim_out=dim, activation_fn="gelu-approximate"
        )
        if nunchaku_enabled:
            nunchaku_kwargs = {
                "precision": quant_config.precision,
                "rank": quant_config.rank,
                "act_unsigned": quant_config.act_unsigned,
            }
            self.ff = NunchakuFeedForward(self.ff, **nunchaku_kwargs)
            self.ff_context = NunchakuFeedForward(self.ff_context, **nunchaku_kwargs)
            if NunchakuAdaLayerNormZero is not None:
                self.norm1 = NunchakuAdaLayerNormZero(self.norm1, scale_shift=0)
                self.norm1_context = NunchakuAdaLayerNormZero(
                    self.norm1_context, scale_shift=0
                )

    def forward(
        self,
        hidden_states: torch.Tensor,
        encoder_hidden_states: torch.Tensor,
        temb: torch.Tensor,
        freqs_cis: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
        joint_attention_kwargs: Optional[Dict[str, Any]] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1(
            hidden_states, emb=temb
        )

        norm_encoder_hidden_states, c_gate_msa, c_shift_mlp, c_scale_mlp, c_gate_mlp = (
            self.norm1_context(encoder_hidden_states, emb=temb)
        )

        joint_attention_kwargs = joint_attention_kwargs or {}
        # Attention.
        attention_outputs = self.attn(
            x=norm_hidden_states,
            encoder_hidden_states=norm_encoder_hidden_states,
            freqs_cis=freqs_cis,
            **joint_attention_kwargs,
        )

        if len(attention_outputs) == 2:
            attn_output, context_attn_output = attention_outputs
        elif len(attention_outputs) == 3:
            attn_output, context_attn_output, ip_attn_output = attention_outputs

        # Process attention outputs for the `hidden_states`.
        attn_output = gate_msa.unsqueeze(1) * attn_output
        hidden_states = hidden_states + attn_output
        norm_hidden_states = self.norm2(hidden_states)
        if self.use_nunchaku_structure:
            norm_hidden_states = (
                norm_hidden_states * scale_mlp[:, None] + shift_mlp[:, None]
            )
        else:
            norm_hidden_states = (
                norm_hidden_states * (1 + scale_mlp[:, None]) + shift_mlp[:, None]
            )

        ff_output = self.ff(norm_hidden_states)
        ff_output = gate_mlp.unsqueeze(1) * ff_output

        hidden_states = hidden_states + ff_output

        if len(attention_outputs) == 3:
            hidden_states = hidden_states + ip_attn_output
        # Process attention outputs for the `encoder_hidden_states`.
        context_attn_output = c_gate_msa.unsqueeze(1) * context_attn_output
        encoder_hidden_states = encoder_hidden_states + context_attn_output

        norm_encoder_hidden_states = self.norm2_context(encoder_hidden_states)
        if self.use_nunchaku_structure:
            norm_encoder_hidden_states = (
                norm_encoder_hidden_states * c_scale_mlp[:, None] + c_shift_mlp[:, None]
            )
        else:
            norm_encoder_hidden_states = (
                norm_encoder_hidden_states * (1 + c_scale_mlp[:, None])
                + c_shift_mlp[:, None]
            )

        context_ff_output = self.ff_context(norm_encoder_hidden_states)
        encoder_hidden_states = (
            encoder_hidden_states + c_gate_mlp.unsqueeze(1) * context_ff_output
        )
        if encoder_hidden_states.dtype == torch.float16:
            encoder_hidden_states = encoder_hidden_states.clip(-65504, 65504)

        return encoder_hidden_states, hidden_states


class FluxPosEmbed(nn.Module):
    # modified from https://github.com/black-forest-labs/flux/blob/c00d7c60b085fce8058b9df845e036090873f2ce/src/flux/modules/layers.py#L11
    def __init__(self, theta: int, axes_dim: List[int]):
        super().__init__()
        self.rope = NDRotaryEmbedding(
            rope_dim_list=axes_dim,
            rope_theta=theta,
            use_real=False,
            repeat_interleave_real=False,
            dtype=(
                torch.float32
                if current_platform.is_mps() or current_platform.is_musa()
                else torch.float64
            ),
        )

    def forward(self, ids: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
        pos = ids.float()
        # TODO: potential error: flux use n_axes = ids.shape[-1]
        # see: https://github.com/huggingface/diffusers/blob/17c0e79dbdf53fb6705e9c09cc1a854b84c39249/src/diffusers/models/transformers/transformer_flux.py#L509
        freqs_cos, freqs_sin = self.rope.forward_uncached(pos=pos)
        return freqs_cos.contiguous().float(), freqs_sin.contiguous().float()


class FluxTransformer2DModel(CachableDiT, OffloadableDiTMixin):
    """
    The Transformer model introduced in Flux.

    Reference: https://blackforestlabs.ai/announcing-black-forest-labs/
    """

    param_names_mapping = FluxConfig().arch_config.param_names_mapping

    @classmethod
    def get_nunchaku_quant_rules(cls) -> dict[str, list[str]]:
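        """Per-scheme module-name patterns used to route layers during nunchaku quantization.

        `skip` leaves matching layers unquantized, `svdq_w4a4` applies the SVDQuant
        W4A4 kernels, and `awq_w4a16` applies weight-only 4-bit quantization to the
        modulation layers.
        """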
        return {
            "skip": [
                "norm",
                "embed",
                "rotary",
                "pos_embed",
            ],
            "svdq_w4a4": [
                "attn.to_qkv",
                "attn.to_out",
                "attn.add_qkv_proj",
                "attn.to_added_qkv",
                "attn.to_add_out",
                "img_mlp",
                "txt_mlp",
                "attention.to_qkv",
                "attention.to_out",
                "proj_mlp",
                "proj_out",
                "mlp_fc1",
                "mlp_fc2",
                "ff.net",
                "ff_context.net",
            ],
            "awq_w4a16": [
                "img_mod",
                "txt_mod",
            ],
        }

    def __init__(
        self,
        config: FluxConfig,
        hf_config: dict[str, Any],
        quant_config: Optional[QuantizationConfig] = None,
    ) -> None:
        super().__init__(config=config, hf_config=hf_config)
        self.config = config.arch_config

        self.out_channels = (
            getattr(self.config, "out_channels", None) or self.config.in_channels
        )
        self.inner_dim = (
            self.config.num_attention_heads * self.config.attention_head_dim
        )

        self.rotary_emb = FluxPosEmbed(theta=10000, axes_dim=self.config.axes_dims_rope)

        text_time_guidance_cls = (
            CombinedTimestepGuidanceTextProjEmbeddings
            if self.config.guidance_embeds
            else CombinedTimestepTextProjEmbeddings
        )
        self.time_text_embed = text_time_guidance_cls(
            embedding_dim=self.inner_dim,
            pooled_projection_dim=self.config.pooled_projection_dim,
        )

        self.context_embedder = ColumnParallelLinear(
            self.config.joint_attention_dim,
            self.inner_dim,
            bias=True,
            gather_output=True,
        )
        self.x_embedder = ColumnParallelLinear(
            self.config.in_channels, self.inner_dim, bias=True, gather_output=True
        )
        self.transformer_blocks = nn.ModuleList(
            [
                FluxTransformerBlock(
                    dim=self.inner_dim,
                    num_attention_heads=self.config.num_attention_heads,
                    attention_head_dim=self.config.attention_head_dim,
                    quant_config=quant_config,
                    prefix=f"transformer_blocks.{i}",
                )
                for i in range(self.config.num_layers)
            ]
        )

        self.single_transformer_blocks = nn.ModuleList(
            [
                FluxSingleTransformerBlock(
                    dim=self.inner_dim,
                    num_attention_heads=self.config.num_attention_heads,
                    attention_head_dim=self.config.attention_head_dim,
                    quant_config=quant_config,
                    prefix=f"single_transformer_blocks.{i}",
                )
                for i in range(self.config.num_single_layers)
            ]
        )

        self.norm_out = AdaLayerNormContinuous(
            self.inner_dim, self.inner_dim, elementwise_affine=False, eps=1e-6
        )
        self.proj_out = ColumnParallelLinear(
            self.inner_dim,
            self.config.patch_size * self.config.patch_size * self.out_channels,
            bias=True,
            gather_output=True,
        )

        self.layer_names = [
            "transformer_blocks",
            "single_transformer_blocks",
        ]

    def forward(
        self,
        hidden_states: torch.Tensor,
        encoder_hidden_states: Optional[torch.Tensor] = None,
        pooled_projections: Optional[torch.Tensor] = None,
        timestep: Optional[torch.LongTensor] = None,
        guidance: Optional[torch.Tensor] = None,
        freqs_cis: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
        joint_attention_kwargs: Optional[Dict[str, Any]] = None,
    ) -> Union[torch.Tensor, Transformer2DModelOutput]:
        """
        The [`FluxTransformer2DModel`] forward method.

        Args:
            hidden_states (`torch.Tensor` of shape `(batch_size, image_sequence_length, in_channels)`):
                Input `hidden_states`.
            encoder_hidden_states (`torch.Tensor` of shape `(batch_size, text_sequence_length, joint_attention_dim)`):
                Conditional embeddings (embeddings computed from the input conditions such as prompts) to use.
            pooled_projections (`torch.Tensor` of shape `(batch_size, projection_dim)`): Embeddings projected
                from the embeddings of input conditions.
            timestep (`torch.LongTensor`):
                Used to indicate denoising step.
            guidance (`torch.Tensor`):
                Guidance embeddings.
            freqs_cis (`Tuple[torch.Tensor, torch.Tensor]`, *optional*):
                Precomputed rotary cos/sin tensors, e.g. produced by `FluxPosEmbed`.
            joint_attention_kwargs (`dict`, *optional*):
                A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
                `self.processor` in
                [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
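
        Returns:
            `torch.Tensor` of shape `(batch_size, image_sequence_length, patch_size * patch_size * out_channels)`:
                The transformer output after the final AdaLayerNorm and projection.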

        """
        if (
            joint_attention_kwargs is not None
            and joint_attention_kwargs.get("scale", None) is not None
        ):
            logger.warning(
                "Passing `scale` via `joint_attention_kwargs` when not using the PEFT backend is ineffective."
            )
        hidden_states, _ = self.x_embedder(hidden_states)

        # Only pass guidance to time_text_embed if the model supports it
        if self.config.guidance_embeds and guidance is not None:
            temb = self.time_text_embed(timestep, guidance, pooled_projections)
        else:
            temb = self.time_text_embed(timestep, pooled_projections)

        encoder_hidden_states, _ = self.context_embedder(encoder_hidden_states)

        if (
            joint_attention_kwargs is not None
            and "ip_adapter_image_embeds" in joint_attention_kwargs
        ):
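            # NOTE: `encoder_hid_proj` is not created in __init__; it is assumed to be
            # attached externally (e.g. when IP-Adapter weights are loaded).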
            ip_adapter_image_embeds = joint_attention_kwargs.pop(
                "ip_adapter_image_embeds"
            )
            ip_hidden_states = self.encoder_hid_proj(ip_adapter_image_embeds)
            joint_attention_kwargs.update({"ip_hidden_states": ip_hidden_states})

        for block in self.transformer_blocks:
            encoder_hidden_states, hidden_states = block(
                hidden_states=hidden_states,
                encoder_hidden_states=encoder_hidden_states,
                temb=temb,
                freqs_cis=freqs_cis,
                joint_attention_kwargs=joint_attention_kwargs,
            )
        for block in self.single_transformer_blocks:
            encoder_hidden_states, hidden_states = block(
                hidden_states=hidden_states,
                encoder_hidden_states=encoder_hidden_states,
                temb=temb,
                freqs_cis=freqs_cis,
                joint_attention_kwargs=joint_attention_kwargs,
            )

        hidden_states = self.norm_out(hidden_states, temb)

        output, _ = self.proj_out(hidden_states)

        return output


EntryClass = FluxTransformer2DModel
