# Copied and adapted from: https://github.com/hao-ai-lab/FastVideo

# SPDX-License-Identifier: Apache-2.0
# Adapted from vllm: https://github.com/vllm-project/vllm/blob/v0.7.3/vllm/model_executor/layers/layernorm.py
"""Custom normalization layers."""

from typing import Optional, Tuple, Union

import torch
import torch.nn as nn
import torch.nn.functional as F

from sglang.multimodal_gen.runtime.platforms import current_platform

_is_cuda = current_platform.is_cuda()
_is_npu = current_platform.is_npu()
if _is_cuda:
    from sgl_kernel import fused_add_rmsnorm, rmsnorm

if _is_npu:
    import torch_npu

from sglang.jit_kernel.diffusion.triton.norm import norm_infer, rms_norm_fn
from sglang.jit_kernel.diffusion.triton.rmsnorm_onepass import triton_one_pass_rms_norm
from sglang.jit_kernel.diffusion.triton.scale_shift import fuse_scale_shift_kernel
from sglang.jit_kernel.norm import can_use_fused_inplace_qknorm, fused_inplace_qknorm
from sglang.multimodal_gen.runtime.distributed.parallel_state import (
    get_tensor_model_parallel_rank,
    get_tensor_model_parallel_world_size,
    get_tp_group,
)
from sglang.multimodal_gen.runtime.layers.custom_op import CustomOp
from sglang.multimodal_gen.runtime.utils.common import get_bool_env_var


# Copied and adapted from sglang
@CustomOp.register("rms_norm")
class RMSNorm(CustomOp):
    """Root mean square normalization.

    Computes x -> w * x / sqrt(E[x^2] + eps) where w is the learned weight.
    Refer to https://arxiv.org/abs/1910.07467

    Every forward accepts an optional ``residual``. When it is given,
    ``x + residual`` is normalized and the forward returns
    ``(normalized, x + residual)``; otherwise only the normalized tensor is
    returned.

    Args:
        hidden_size: Size of the input's last dimension and of ``weight``.
        eps: Value added to the mean square before the reciprocal square root.
        dtype: Accepted for API compatibility but currently unused; ``weight``
            is created with torch's default dtype.
        var_hidden_size: If given and different from ``hidden_size``, the mean
            square is computed over only the first ``var_hidden_size``
            channels. Only ``forward_native`` implements this; the fused
            kernels ignore it, so ``forward_cuda`` and ``forward_npu`` defer
            to ``forward_native`` while the override is active.
    """

    def __init__(
        self,
        hidden_size: int,
        eps: float = 1e-6,
        dtype: torch.dtype = torch.float32,
        var_hidden_size: Optional[int] = None,
    ) -> None:
        super().__init__()
        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.variance_epsilon = eps
        self.hidden_size = hidden_size
        # An override equal to hidden_size is a no-op, so store it as None.
        self.variance_size_override = (
            None if var_hidden_size == hidden_size else var_hidden_size
        )
        # Deterministic inference: bypass the custom kernels and always run the
        # pure-PyTorch path (assumes CustomOp dispatches via _forward_method).
        if get_bool_env_var("SGLANG_ENABLE_DETERMINISTIC_INFERENCE"):
            self._forward_method = self.forward_native

    def forward_triton(self, x: torch.Tensor, residual: Optional[torch.Tensor] = None):
        # Triton RMSNorm over the full last dimension (variance_size_override is
        # not supported here). Returns (out, residual_out) when residual is given.
        return rms_norm_fn(
            x, self.weight, bias=None, residual=residual, eps=self.variance_epsilon
        )

    def forward_cuda(
        self,
        x: torch.Tensor,
        residual: Optional[torch.Tensor] = None,
    ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
        # None of the kernels below support a partial-variance override, so
        # route to the native path first. forward_native works on the last
        # dimension and returns tensors in the caller's original shapes.
        if self.variance_size_override is not None:
            return self.forward_native(x, residual)

        shape = x.shape
        # The kernels operate on 2D (rows, hidden) tensors.
        x = x.reshape(-1, shape[-1])
        if residual is not None:
            residual_shape = residual.shape
            residual = residual.view(-1, shape[-1])

        if x.dtype == torch.float:
            # fp32: the Triton kernel also performs the residual add.
            out = self.forward_triton(x, residual)
            if residual is not None:
                return out[0].view(shape), out[1].view(residual_shape)
            return out.view(shape)

        if residual is not None:
            # sgl_kernel's fused_add_rmsnorm updates x and residual in place.
            fused_add_rmsnorm(x, residual, self.weight.data, self.variance_epsilon)
            return x.view(shape), residual.view(residual_shape)

        if x.shape[-1] <= 128:
            out = triton_one_pass_rms_norm(x, self.weight.data, self.variance_epsilon)
        else:
            out = rmsnorm(x, self.weight.data, self.variance_epsilon)
        return out.view(shape)

    def forward_native(
        self,
        x: torch.Tensor,
        residual: Optional[torch.Tensor] = None,
    ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
        """Pure-PyTorch reference implementation.

        Computes in float32 and casts the result back to the input dtype.

        Raises:
            ValueError: if the last dimension is not ``hidden_size``, or is
                smaller than ``variance_size_override``.
        """
        if not x.is_contiguous():
            x = x.contiguous()
        orig_dtype = x.dtype
        x = x.to(torch.float32)
        if residual is not None:
            x = x + residual.to(torch.float32)
            residual = x.to(orig_dtype)

        hidden_size = x.shape[-1]
        if hidden_size != self.hidden_size:
            raise ValueError(
                "Expected hidden_size to be "
                f"{self.hidden_size}, but found: {hidden_size}"
            )

        if self.variance_size_override is None:
            x_var = x
        else:
            if hidden_size < self.variance_size_override:
                raise ValueError(
                    "Expected hidden_size to be at least "
                    f"{self.variance_size_override}, but found: {hidden_size}"
                )

            # Only the leading channels contribute to the mean square.
            x_var = x[..., : self.variance_size_override]

        variance = x_var.pow(2).mean(dim=-1, keepdim=True)
        x = x * torch.rsqrt(variance + self.variance_epsilon)
        x = (x * self.weight).to(orig_dtype)
        if residual is None:
            return x
        else:
            return x, residual

    def forward_cpu(
        self,
        x: torch.Tensor,
        residual: Optional[torch.Tensor] = None,
    ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
        return self.forward_native(x, residual)

    def forward_npu(
        self,
        x: torch.Tensor,
        residual: Optional[torch.Tensor] = None,
    ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
        # The torch_npu kernels normalize over the full last dimension.
        if self.variance_size_override is not None:
            return self.forward_native(x, residual)
        if residual is not None:
            # npu_add_rms_norm returns (normed, rstd, residual + x).
            out, _, residual_out = torch_npu.npu_add_rms_norm(
                residual, x, self.weight.data, self.variance_epsilon
            )
            return out, residual_out
        return torch_npu.npu_rms_norm(x, self.weight.data, self.variance_epsilon)[0]

    def forward_hip(
        self,
        x: torch.Tensor,
        residual: Optional[torch.Tensor] = None,
    ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
        # ROCm builds of sgl-kernel do not expose rmsnorm custom ops yet.
        return self.forward_native(x, residual)

    def extra_repr(self) -> str:
        s = f"hidden_size={self.weight.data.size(0)}"
        s += f", eps={self.variance_epsilon}"
        return s


# Copied and adapted from sglang
@CustomOp.register("layer_norm")
class LayerNorm(CustomOp):
    """Layer normalization over the last dimension.

    Args:
        hidden_size: Size of the normalized (last) dimension.
        eps: Value added to the variance before the reciprocal square root.
        bias: Whether to create a bias parameter. Only used when
            ``elementwise_affine`` is True.
        elementwise_affine: Whether to create ``weight``/``bias`` parameters.
            When False, both are registered as ``None``.
        device: Device for the parameters.
        dtype: Dtype for the parameters.

    ``weight``/``bias`` are allocated with ``torch.empty`` (uninitialized);
    presumably they are filled by checkpoint loading — TODO confirm.
    """

    def __init__(
        self,
        hidden_size: int,
        eps=1e-5,
        bias: bool = True,
        elementwise_affine=True,
        device=None,
        dtype=None,
    ) -> None:
        super().__init__()
        self.eps = eps
        factory_kwargs = {"device": device, "dtype": dtype}
        self.hidden_size = hidden_size
        if elementwise_affine:
            self.weight = torch.nn.Parameter(torch.empty(hidden_size, **factory_kwargs))
            self.bias = (
                torch.nn.Parameter(torch.empty(hidden_size, **factory_kwargs))
                if bias
                else None
            )
        else:
            self.register_parameter("weight", None)
            self.register_parameter("bias", None)
            # Lazy cache for ones vector (not a registered buffer to avoid FSDP/meta issues)
            self._weight_fallback_cache = None

    def _get_weight_fallback(self, x: torch.Tensor) -> torch.Tensor:
        """Return a cached all-ones weight matching ``x``'s device and dtype.

        The cache is rebuilt whenever device, dtype, or length no longer match.
        """
        wf = getattr(self, "_weight_fallback_cache", None)
        if (
            wf is None
            or wf.device != x.device
            or wf.dtype != x.dtype
            or wf.numel() != self.hidden_size
        ):
            wf = torch.ones(self.hidden_size, device=x.device, dtype=x.dtype)
            self._weight_fallback_cache = wf
        return wf

    def forward_triton(self, x: torch.Tensor):
        # Fast inference kernel without residual/dropout branches. It operates
        # on a 2D (rows, hidden_size) tensor. reshape (rather than view) also
        # accepts inputs that cannot be flattened without a copy; for viewable
        # inputs it returns the same view.
        out = norm_infer(
            x.reshape(-1, self.hidden_size),
            self.weight,
            self.bias,
            eps=self.eps,
            is_rms_norm=False,
        )
        return out.view(x.shape)

    def forward_cuda(
        self,
        x: torch.Tensor,
    ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
        # forward_triton already flattens to 2D and restores x's shape.
        return self.forward_triton(x)

    @torch.compile(backend="inductor", disable=current_platform.is_npu())
    def forward_native(
        self,
        x: torch.Tensor,
        residual: Optional[torch.Tensor] = None,
    ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
        """Pure-PyTorch LayerNorm computed in the input dtype.

        ``residual`` is accepted for interface compatibility and is ignored.
        """
        input_dtype = x.dtype
        mean = x.mean(-1, keepdim=True)
        variance = (x - mean).pow(2).mean(-1, keepdim=True)
        x = (x - mean) * torch.rsqrt(variance + self.eps)
        if self.weight is not None:
            x = self.weight * x
        # if no affine, this is a no-op
        if self.bias is not None:
            x = x + self.bias
        return x.to(input_dtype)

    def forward_cpu(
        self,
        x: torch.Tensor,
        residual: Optional[torch.Tensor] = None,
    ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
        return self.forward_native(x, residual)

    def extra_repr(self) -> str:
        # Read hidden_size/eps directly: weight is None when
        # elementwise_affine=False, and this class stores its epsilon as ``eps``.
        s = f"hidden_size={self.hidden_size}"
        s += f", eps={self.eps}"
        return s


# adapted from Diffusers: https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/normalization.py
# NOTE(will): Needed to match behavior of diffusers and wan2.1 even while using
# FSDP's MixedPrecisionPolicy
class FP32LayerNorm(nn.LayerNorm):
    """``nn.LayerNorm`` that always computes in float32.

    The input and any affine parameters are upcast to float32 (parameters are
    also moved to the input's device) before ``F.layer_norm`` is called. The
    result is cast back to the input's original dtype.
    """

    def forward(self, inputs: torch.Tensor) -> torch.Tensor:
        target_device = inputs.device
        weight_fp32 = (
            None
            if self.weight is None
            else self.weight.float().to(device=target_device)
        )
        bias_fp32 = (
            None if self.bias is None else self.bias.float().to(device=target_device)
        )
        normalized = F.layer_norm(
            inputs.float(), self.normalized_shape, weight_fp32, bias_fp32, self.eps
        )
        return normalized.to(inputs.dtype)


################################################################################
# Fused norm kernel
################################################################################
def _ensure_contiguous(tensor: Optional[torch.Tensor]) -> Optional[torch.Tensor]:
    return tensor.contiguous() if tensor is not None else None


class _ScaleResidualNormScaleShift(CustomOp):
    """
    Fused kernel that combines:
    1. residual_out = residual + gate * x
    2. normed = layernorm(residual_out) or rmsnorm(residual_out)
    3. out = normed * (1 + scale) + shift
    compute_dtype is always fp32 for higher precision.

    Subclasses select the normalization through the ``norm_type`` class
    attribute ("layer" -> FP32LayerNorm, "rms" -> RMSNorm). The forward
    methods return ``(out, residual_out)``.

    ``gate`` may be a tensor (3D ``[batch, 1, inner_dim]`` or 4D per-frame
    ``[batch, num_frames, 1, inner_dim]``) or the int ``1``, meaning no gating.
    """

    norm_type: str

    def __init__(
        self,
        hidden_size: int,
        eps: float = 1e-6,
        elementwise_affine: bool = False,
        dtype: torch.dtype = torch.float32,
        prefix: str = "",
    ):
        super().__init__()
        self.eps = eps
        self.dtype = dtype
        if self.norm_type == "rms":
            self.norm = RMSNorm(hidden_size, eps=eps, dtype=dtype)
        elif self.norm_type == "layer":
            self.norm = FP32LayerNorm(
                hidden_size, elementwise_affine=elementwise_affine, eps=eps, dtype=dtype
            )
        else:
            raise NotImplementedError(f"Norm type {self.norm_type} not implemented")

    @staticmethod
    def _check_int_gate(gate: torch.Tensor | int) -> None:
        """Raise ValueError if ``gate`` is an int other than 1.

        An explicit raise (rather than ``assert``) keeps the check active under
        ``python -O``.
        """
        if isinstance(gate, int) and gate != 1:
            raise ValueError(
                f"Only gate value of 1 is supported for int type, but got {gate}"
            )

    def forward_cuda(
        self,
        residual: torch.Tensor,
        x: torch.Tensor,
        gate: torch.Tensor | int,
        shift: torch.Tensor,
        scale: torch.Tensor,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        # Validate before choosing a path so the fused and fallback paths
        # report an invalid int gate the same way.
        self._check_int_gate(gate)

        # NOTE(review): only hidden sizes that are <= 8192 AND not a multiple of
        # 256 fall back; larger unaligned sizes still reach the fused kernel.
        # Confirm the kernel supports them.
        if x.shape[-1] % 256 != 0 and x.shape[-1] <= 8192:
            import warnings

            warnings.warn(
                "FusedScaleResidualNormScaleShift cuda not available, using native fallback",
                stacklevel=2,
            )
            return self.forward_native(residual, x, gate, shift, scale)

        from sglang.jit_kernel.diffusion.cutedsl.scale_residual_norm_scale_shift import (
            fused_scale_residual_norm_scale_shift,
        )

        return fused_scale_residual_norm_scale_shift(
            residual.contiguous(),
            x.contiguous(),
            # An int gate (already validated to be 1) is passed as None.
            gate.contiguous() if isinstance(gate, torch.Tensor) else None,
            _ensure_contiguous(getattr(self.norm, "weight", None)),
            _ensure_contiguous(getattr(self.norm, "bias", None)),
            scale.contiguous(),
            shift.contiguous(),
            self.norm_type,
            self.eps,
        )

    def forward_hip(self, *args, **kwargs):
        # ROCm does not support CUDA/CUTLASS-based fused kernels yet,
        # so we fall back to the native PyTorch implementation.
        return self.forward_native(*args, **kwargs)

    def forward_native(
        self,
        residual: torch.Tensor,
        x: torch.Tensor,
        gate: torch.Tensor | int,
        shift: torch.Tensor,
        scale: torch.Tensor,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        # x.shape: [batch_size, seq_len, inner_dim]
        if isinstance(gate, int):
            # used by cross-attention, should be 1
            self._check_int_gate(gate)
            residual_output = residual + x
        elif isinstance(gate, torch.Tensor):
            if gate.dim() == 4:
                # gate.shape: [batch_size, num_frames, 1, inner_dim]
                # Split the sequence into equal per-frame chunks so each frame
                # is scaled by its own gate.
                num_frames = gate.shape[1]
                frame_seqlen = x.shape[1] // num_frames
                residual_output = residual + (
                    x.unflatten(dim=1, sizes=(num_frames, frame_seqlen)) * gate
                ).flatten(1, 2)
            else:
                # gate.shape: [batch_size, 1, inner_dim]
                residual_output = residual + x * gate
        else:
            raise ValueError(f"Gate type {type(gate)} not supported")
        normalized = self.norm(residual_output)
        modulated = fuse_scale_shift_kernel(normalized, scale, shift)
        return modulated, residual_output


class ScaleResidualLayerNormScaleShift(_ScaleResidualNormScaleShift):
    """Gated residual add, then LayerNorm, then scale/shift modulation."""

    norm_type: str = "layer"


class ScaleResidualRMSNormScaleShift(_ScaleResidualNormScaleShift):
    """Gated residual add, then RMSNorm, then scale/shift modulation."""

    norm_type: str = "rms"


class _NormScaleShift(CustomOp):
    """
    Normalize, then modulate with a scale/shift:
    1. normed = layernorm(x) or rmsnorm(x)
    2. out = normed * (1 + scale) + shift
    compute_dtype is always fp32 for higher precision.

    The concrete normalization comes from the subclass's ``norm_type``
    ("layer" -> FP32LayerNorm, "rms" -> RMSNorm).
    """

    norm_type: str

    def __init__(
        self,
        hidden_size: int,
        eps: float = 1e-6,
        elementwise_affine: bool = False,
        dtype: torch.dtype = torch.float32,
        prefix: str = "",
    ):
        super().__init__()
        self.eps = eps
        # ``prefix`` is accepted for interface compatibility and not used here.
        kind = self.norm_type
        if kind == "layer":
            self.norm = FP32LayerNorm(
                hidden_size, elementwise_affine=elementwise_affine, eps=eps, dtype=dtype
            )
        elif kind == "rms":
            # elementwise_affine does not apply: RMSNorm always owns a weight.
            self.norm = RMSNorm(hidden_size, eps=eps, dtype=dtype)
        else:
            raise NotImplementedError(f"Norm type {self.norm_type} not implemented")

    def forward_cuda(
        self, x: torch.Tensor, shift: torch.Tensor, scale: torch.Tensor
    ) -> torch.Tensor:
        hidden = x.shape[-1]
        # Hidden sizes <= 8192 that are not a multiple of 256 take the
        # PyTorch path instead of the fused kernel.
        if hidden % 256 != 0 and hidden <= 8192:
            import warnings

            warnings.warn(
                "FusedNormScaleShift cuda not available, using native fallback",
                stacklevel=2,
            )
            return self.forward_native(x, shift, scale)

        # Imported lazily; only reached when the fused kernel is used.
        from sglang.jit_kernel.diffusion.cutedsl.scale_residual_norm_scale_shift import (
            fused_norm_scale_shift,
        )

        x_c = x.contiguous()
        weight = _ensure_contiguous(getattr(self.norm, "weight", None))
        bias = _ensure_contiguous(getattr(self.norm, "bias", None))
        scale_c = scale.contiguous()
        shift_c = shift.contiguous()
        return fused_norm_scale_shift(
            x_c, weight, bias, scale_c, shift_c, self.norm_type, self.eps
        )

    def forward_hip(self, *args, **kwargs):
        # ROCm does not support CUDA/CUTLASS-based fused kernels yet,
        # so we fall back to the native PyTorch implementation.
        return self.forward_native(*args, **kwargs)

    def forward_native(
        self, x: torch.Tensor, shift: torch.Tensor, scale: torch.Tensor
    ) -> torch.Tensor:
        # Normalize, modulate, then restore the caller's dtype.
        return fuse_scale_shift_kernel(self.norm(x), scale, shift).to(x.dtype)


class LayerNormScaleShift(_NormScaleShift):
    """LayerNorm followed by scale/shift modulation."""

    norm_type: str = "layer"


class RMSNormScaleShift(_NormScaleShift):
    """RMSNorm followed by scale/shift modulation."""

    norm_type: str = "rms"


def apply_qk_norm(
    q: torch.Tensor,
    k: torch.Tensor,
    q_norm: "RMSNorm",
    k_norm: "RMSNorm",
    head_dim: int,
    allow_inplace: bool = True,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Apply QK normalization for query and key tensors.

    Uses the JIT fused in-place kernel when available. Otherwise it falls back
    to calling ``q_norm``/``k_norm`` (out of place) on every group of
    ``head_dim`` trailing elements.

    Args:
        q: Query tensor; first dim is the batch. Its element count must be a
            multiple of ``head_dim``.
        k: Key tensor; same requirements as ``q``.
        q_norm: RMSNorm applied to the query heads.
        k_norm: RMSNorm applied to the key heads.
        head_dim: Size of each normalized head chunk.
        allow_inplace: If False, never use the in-place fused kernel.

    Returns:
        ``(q_out, k_out)`` with the same shapes as ``q`` and ``k``. On the fused
        path these are ``q`` and ``k`` themselves, modified in place.
    """
    q_eps = q_norm.variance_epsilon
    k_eps = k_norm.variance_epsilon
    # Only try the fused path on CUDA and when it won't introduce implicit
    # copies. Requiring contiguous q/k guarantees the (batch, -1, head_dim)
    # views below exist and alias q/k, so the in-place writes land in them.
    if (
        _is_cuda
        and allow_inplace
        and (q_eps == k_eps)
        and q.is_contiguous()
        and k.is_contiguous()
        and can_use_fused_inplace_qknorm(head_dim, q.dtype)
    ):
        batch_size = q.size(0)
        fused_inplace_qknorm(
            q=q.view(batch_size, -1, head_dim),
            k=k.view(batch_size, -1, head_dim),
            q_weight=q_norm.weight,
            k_weight=k_norm.weight,
            head_dim=head_dim,
            eps=q_eps,
        )
        return q, k

    # reshape (not view) so non-contiguous inputs, such as strided slices of a
    # larger tensor, are accepted; viewable inputs are not copied.
    q_shape = q.shape
    k_shape = k.shape
    q_out = q_norm(q.reshape(-1, head_dim)).view(q_shape)
    k_out = k_norm(k.reshape(-1, head_dim)).view(k_shape)
    return q_out, k_out


def tensor_parallel_rms_norm(x: torch.Tensor, norm: "RMSNorm") -> torch.Tensor:
    """RMS-normalize an activation whose last dimension is sharded across TP ranks.

    ``x`` holds this rank's slice of the hidden dimension. ``norm.weight`` is the
    full weight, and the slice matching this rank is taken with
    ``tensor_split``. The mean square is computed over the *full* hidden
    dimension: the per-rank sum of squares is SUM-all-reduced and divided by
    the full hidden size.

    That stays exact when the hidden size is not divisible by the TP world
    size, where ``tensor_split`` produces unequal shards and an average of
    per-shard means would be biased. SUM is also supported by every
    ``torch.distributed`` backend, while AVG is NCCL-only.

    Computation is done in float32; the result is cast back to ``x.dtype``.
    """
    from torch.distributed import ReduceOp

    tp_rank = get_tensor_model_parallel_rank()
    tp_size = get_tensor_model_parallel_world_size()
    src_dtype = x.dtype
    weight = norm.weight.tensor_split(tp_size)[tp_rank].float()
    x_fp32 = x.float()
    sum_sq = x_fp32.pow(2).sum(dim=-1, keepdim=True)
    sum_sq = get_tp_group().all_reduce(sum_sq, op=ReduceOp.SUM)
    variance = sum_sq / norm.weight.numel()
    output = x_fp32 * torch.rsqrt(variance + norm.variance_epsilon) * weight
    return output.to(dtype=src_dtype)


# TODO: Workaround, fuse norm with new select01 kernel
def apply_layernorm_only(x: torch.Tensor, layernorm_scale_shift: LayerNormScaleShift):
    """Apply only the LayerNorm part of a ``LayerNormScaleShift`` (no scale/shift).

    Runs the Triton ``norm_infer`` kernel on ``x`` flattened to
    ``(-1, x.shape[-1])``, using the module's norm weight, bias and eps. The
    result has ``x``'s original shape.
    """
    inner_norm = layernorm_scale_shift.norm
    flat = x.view(-1, x.shape[-1])
    normed = norm_infer(
        flat,
        inner_norm.weight,
        inner_norm.bias,
        eps=layernorm_scale_shift.eps,
        is_rms_norm=False,
    )
    return normed.view(x.shape)
