from __future__ import annotations

import math
from typing import Optional

import numpy as np
import torch
import torch.nn.functional as F
from tensorrt_llm._common import default_net

from ..._utils import str_dtype_to_trt, trt_dtype_to_np
from ...functional import (
    Tensor,
    bert_attention,
    cast,
    chunk,
    concat,
    constant,
    expand_dims,
    expand_dims_like,
    expand_mask,
    gelu,
    matmul,
    permute,
    shape,
    silu,
    slice,
    softmax,
    squeeze,
    unsqueeze,
    view,
)
from ...layers import ColumnLinear, Conv1d, LayerNorm, Linear, Mish, RowLinear
from ...module import Module


class FeedForward(Module):
    def __init__(self, dim, dim_out=None, mult=4, dropout=0.0):
        super().__init__()
        inner_dim = int(dim * mult)
        dim_out = dim_out if dim_out is not None else dim

        self.project_in = Linear(dim, inner_dim)
        self.ff = Linear(inner_dim, dim_out)

    def forward(self, x):
        return self.ff(gelu(self.project_in(x)))


class AdaLayerNormZero(Module):
    def __init__(self, dim):
        super().__init__()

        self.linear = Linear(dim, dim * 6)
        self.norm = LayerNorm(dim, elementwise_affine=False, eps=1e-6)

    def forward(self, x, emb=None):
        emb = self.linear(silu(emb))
        shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = chunk(emb, 6, dim=1)
        x = self.norm(x)
        ones = constant(np.ones(1, dtype=np.float32)).cast(x.dtype)
        if default_net().plugin_config.remove_input_padding:
            x = x * (ones + scale_msa) + shift_msa
        else:
            x = x * (ones + unsqueeze(scale_msa, 1)) + unsqueeze(shift_msa, 1)
        return x, gate_msa, shift_mlp, scale_mlp, gate_mlp


class AdaLayerNormZero_Final(Module):
    def __init__(self, dim):
        super().__init__()

        self.linear = Linear(dim, dim * 2)

        self.norm = LayerNorm(dim, elementwise_affine=False, eps=1e-6)

    def forward(self, x, emb):
        emb = self.linear(silu(emb))
        scale, shift = chunk(emb, 2, dim=1)
        ones = constant(np.ones(1, dtype=np.float32)).cast(x.dtype)
        if default_net().plugin_config.remove_input_padding:
            x = self.norm(x) * (ones + scale) + shift
        else:
            x = self.norm(x) * unsqueeze((ones + scale), 1)
            x = x + unsqueeze(shift, 1)
        return x


class ConvPositionEmbedding(Module):
    def __init__(self, dim, kernel_size=31, groups=16):
        super().__init__()
        assert kernel_size % 2 != 0
        self.conv1d1 = Conv1d(dim, dim, kernel_size, groups=groups, padding=kernel_size // 2)
        self.conv1d2 = Conv1d(dim, dim, kernel_size, groups=groups, padding=kernel_size // 2)
        self.mish = Mish()

    def forward(self, x, mask=None):
        if default_net().plugin_config.remove_input_padding:
            x = unsqueeze(x, 0)
        if mask is not None:
            mask = mask.view(concat([shape(mask, 0), 1, shape(mask, 1)]))  # [B 1 N]
            mask = expand_dims_like(mask, x)  # [B D N]
            mask = cast(mask, x.dtype)
        x = permute(x, [0, 2, 1])  # [B D N]

        if mask is not None:
            x = self.mish(self.conv1d2(self.mish(self.conv1d1(x * mask) * mask)) * mask)
        else:
            x = self.mish(self.conv1d2(self.mish(self.conv1d1(x))))

        x = permute(x, [0, 2, 1])  # [B N D]
        if default_net().plugin_config.remove_input_padding:
            x = squeeze(x, 0)
        return x


class Attention(Module):
    def __init__(
        self,
        processor: AttnProcessor,
        dim: int,
        heads: int = 16,
        dim_head: int = 64,
        dropout: float = 0.0,
        context_dim: Optional[int] = None,  # if not None -> joint attention
        context_pre_only=None,
    ):
        super().__init__()

        if not hasattr(F, "scaled_dot_product_attention"):
            raise ImportError("Attention equires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")

        self.processor = processor

        self.dim = dim  # hidden_size
        self.heads = heads
        self.inner_dim = dim_head * heads
        self.dropout = dropout
        self.attention_head_size = dim_head
        self.context_dim = context_dim
        self.context_pre_only = context_pre_only
        self.tp_size = 1
        self.num_attention_heads = heads // self.tp_size
        self.num_attention_kv_heads = heads // self.tp_size  # 8
        self.dtype = str_dtype_to_trt("float32")
        self.attention_hidden_size = self.attention_head_size * self.num_attention_heads
        self.to_q = ColumnLinear(
            dim,
            self.tp_size * self.num_attention_heads * self.attention_head_size,
            bias=True,
            dtype=self.dtype,
            tp_group=None,
            tp_size=self.tp_size,
        )
        self.to_k = ColumnLinear(
            dim,
            self.tp_size * self.num_attention_heads * self.attention_head_size,
            bias=True,
            dtype=self.dtype,
            tp_group=None,
            tp_size=self.tp_size,
        )
        self.to_v = ColumnLinear(
            dim,
            self.tp_size * self.num_attention_heads * self.attention_head_size,
            bias=True,
            dtype=self.dtype,
            tp_group=None,
            tp_size=self.tp_size,
        )

        if self.context_dim is not None:
            self.to_k_c = Linear(context_dim, self.inner_dim)
            self.to_v_c = Linear(context_dim, self.inner_dim)
            if self.context_pre_only is not None:
                self.to_q_c = Linear(context_dim, self.inner_dim)

        self.to_out = RowLinear(
            self.tp_size * self.num_attention_heads * self.attention_head_size,
            dim,
            bias=True,
            dtype=self.dtype,
            tp_group=None,
            tp_size=self.tp_size,
        )

        if self.context_pre_only is not None and not self.context_pre_only:
            self.to_out_c = Linear(self.inner_dim, dim)

    def forward(
        self,
        x,  # noised input x
        rope_cos,
        rope_sin,
        input_lengths,
        mask=None,
        c=None,  # context c
        scale=1.0,
        rope=None,
        c_rope=None,  # rotary position embedding for c
    ) -> torch.Tensor:
        if c is not None:
            return self.processor(self, x, c=c, input_lengths=input_lengths, scale=scale, rope=rope, c_rope=c_rope)
        else:
            return self.processor(
                self, x, rope_cos=rope_cos, rope_sin=rope_sin, input_lengths=input_lengths, scale=scale
            )


def rotate_every_two_3dim(tensor: Tensor) -> Tensor:
    shape_tensor = concat(
        [shape(tensor, i) / 2 if i == (tensor.ndim() - 1) else shape(tensor, i) for i in range(tensor.ndim())]
    )
    if default_net().plugin_config.remove_input_padding:
        assert tensor.ndim() == 2
        x1 = slice(tensor, [0, 0], shape_tensor, [1, 2])
        x2 = slice(tensor, [0, 1], shape_tensor, [1, 2])
        x1 = expand_dims(x1, 2)
        x2 = expand_dims(x2, 2)
        zero = constant(np.ascontiguousarray(np.zeros([1], dtype=trt_dtype_to_np(tensor.dtype))))
        x2 = zero - x2
        x = concat([x2, x1], 2)
        out = view(x, concat([shape(x, 0), shape(x, 1) * 2]))
    else:
        assert tensor.ndim() == 3

        x1 = slice(tensor, [0, 0, 0], shape_tensor, [1, 1, 2])
        x2 = slice(tensor, [0, 0, 1], shape_tensor, [1, 1, 2])
        x1 = expand_dims(x1, 3)
        x2 = expand_dims(x2, 3)
        zero = constant(np.ascontiguousarray(np.zeros([1], dtype=trt_dtype_to_np(tensor.dtype))))
        x2 = zero - x2
        x = concat([x2, x1], 3)
        out = view(x, concat([shape(x, 0), shape(x, 1), shape(x, 2) * 2]))

    return out


def apply_rotary_pos_emb_3dim(x, rope_cos, rope_sin, pe_attn_head):
    full_dim = x.size(-1)
    head_dim = rope_cos.size(-1)  # attn head dim, e.g. 64
    if pe_attn_head is None:
        pe_attn_head = full_dim // head_dim
    rotated_dim = head_dim * pe_attn_head

    rotated_and_unrotated_list = []

    if default_net().plugin_config.remove_input_padding:  # for [N, D] input
        new_t_shape = concat([shape(x, 0), head_dim])  # (2, -1, 64)

        for i in range(pe_attn_head):
            x_slice_i = slice(x, [0, i * 64], new_t_shape, [1, 1])
            x_rotated_i = x_slice_i * rope_cos + rotate_every_two_3dim(x_slice_i) * rope_sin
            rotated_and_unrotated_list.append(x_rotated_i)

        new_t_unrotated_shape = concat([shape(x, 0), full_dim - rotated_dim])  # (2, -1, 1024 - 64 * pe_attn_head)
        x_unrotated = slice(x, concat([0, rotated_dim]), new_t_unrotated_shape, [1, 1])
        rotated_and_unrotated_list.append(x_unrotated)

    else:  # for [B, N, D] input
        new_t_shape = concat([shape(x, 0), shape(x, 1), head_dim])  # (2, -1, 64)

        for i in range(pe_attn_head):
            x_slice_i = slice(x, [0, 0, i * 64], new_t_shape, [1, 1, 1])
            x_rotated_i = x_slice_i * rope_cos + rotate_every_two_3dim(x_slice_i) * rope_sin
            rotated_and_unrotated_list.append(x_rotated_i)

        new_t_unrotated_shape = concat(
            [shape(x, 0), shape(x, 1), full_dim - rotated_dim]
        )  # (2, -1, 1024 - 64 * pe_attn_head)
        x_unrotated = slice(x, concat([0, 0, rotated_dim]), new_t_unrotated_shape, [1, 1, 1])
        rotated_and_unrotated_list.append(x_unrotated)

    out = concat(rotated_and_unrotated_list, dim=-1)

    return out


class AttnProcessor:
    def __init__(
        self,
        pe_attn_head: Optional[int] = None,  # number of attention head to apply rope, None for all
    ):
        self.pe_attn_head = pe_attn_head

    def __call__(
        self,
        attn,
        x,  # noised input x
        rope_cos,
        rope_sin,
        input_lengths,
        scale=1.0,
        rope=None,
        mask=None,
    ) -> torch.FloatTensor:
        query = attn.to_q(x)
        key = attn.to_k(x)
        value = attn.to_v(x)
        # k,v,q all (2,1226,1024)
        query = apply_rotary_pos_emb_3dim(query, rope_cos, rope_sin, self.pe_attn_head)
        key = apply_rotary_pos_emb_3dim(key, rope_cos, rope_sin, self.pe_attn_head)

        # attention
        inner_dim = key.shape[-1]
        norm_factor = math.sqrt(attn.attention_head_size)
        q_scaling = 1.0 / norm_factor
        if default_net().plugin_config.remove_input_padding:
            mask = None

        if default_net().plugin_config.bert_attention_plugin:
            qkv = concat([query, key, value], dim=-1)
            # TRT plugin mode
            assert input_lengths is not None
            if default_net().plugin_config.remove_input_padding:
                qkv = qkv.view(concat([-1, 3 * inner_dim]))
                max_input_length = constant(
                    np.zeros(
                        [
                            2048,
                        ],
                        dtype=np.int32,
                    )
                )
            else:
                max_input_length = None
            context = bert_attention(
                qkv,
                input_lengths,
                attn.num_attention_heads,
                attn.attention_head_size,
                q_scaling=q_scaling,
                max_input_length=max_input_length,
            )
        else:
            assert not default_net().plugin_config.remove_input_padding

            def transpose_for_scores(x):
                new_x_shape = concat([shape(x, 0), shape(x, 1), attn.num_attention_heads, attn.attention_head_size])

                y = x.view(new_x_shape)
                y = y.transpose(1, 2)
                return y

            def transpose_for_scores_k(x):
                new_x_shape = concat([shape(x, 0), shape(x, 1), attn.num_attention_heads, attn.attention_head_size])

                y = x.view(new_x_shape)
                y = y.permute([0, 2, 3, 1])
                return y

            query = transpose_for_scores(query)
            key = transpose_for_scores_k(key)
            value = transpose_for_scores(value)

            attention_scores = matmul(query, key, use_fp32_acc=False)

            if mask is not None:
                attention_mask = expand_mask(mask, shape(query, 2))
                attention_mask = cast(attention_mask, attention_scores.dtype)
                attention_scores = attention_scores + attention_mask

            attention_probs = softmax(attention_scores, dim=-1)

            context = matmul(attention_probs, value, use_fp32_acc=False).transpose(1, 2)
            context = context.view(concat([shape(context, 0), shape(context, 1), attn.attention_hidden_size]))
        context = attn.to_out(context)
        if mask is not None:
            mask = mask.view(concat([shape(mask, 0), shape(mask, 1), 1]))
            mask = expand_dims_like(mask, context)
            mask = cast(mask, context.dtype)
            context = context * mask
        return context


# DiT Block
class DiTBlock(Module):
    def __init__(self, dim, heads, dim_head, ff_mult=2, dropout=0.1, pe_attn_head=None):
        super().__init__()

        self.attn_norm = AdaLayerNormZero(dim)
        self.attn = Attention(
            processor=AttnProcessor(pe_attn_head=pe_attn_head),
            dim=dim,
            heads=heads,
            dim_head=dim_head,
            dropout=dropout,
        )

        self.ff_norm = LayerNorm(dim, elementwise_affine=False, eps=1e-6)
        self.ff = FeedForward(dim=dim, mult=ff_mult, dropout=dropout)

    def forward(
        self, x, t, rope_cos, rope_sin, input_lengths, scale=1.0, rope=ModuleNotFoundError, mask=None
    ):  # x: noised input, t: time embedding
        # pre-norm & modulation for attention input
        norm, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.attn_norm(x, emb=t)
        # attention
        # norm ----> (2,1226,1024)
        attn_output = self.attn(
            x=norm, rope_cos=rope_cos, rope_sin=rope_sin, input_lengths=input_lengths, scale=scale, mask=mask
        )
        # process attention output for input x
        if default_net().plugin_config.remove_input_padding:
            x = x + gate_msa * attn_output
        else:
            x = x + unsqueeze(gate_msa, 1) * attn_output
        ones = constant(np.ones(1, dtype=np.float32)).cast(x.dtype)
        if default_net().plugin_config.remove_input_padding:
            norm = self.ff_norm(x) * (ones + scale_mlp) + shift_mlp
        else:
            norm = self.ff_norm(x) * (ones + unsqueeze(scale_mlp, 1)) + unsqueeze(shift_mlp, 1)
            # norm = self.ff_norm(x) * (ones + scale_mlp) + shift_mlp
        ff_output = self.ff(norm)
        if default_net().plugin_config.remove_input_padding:
            x = x + gate_mlp * ff_output
        else:
            x = x + unsqueeze(gate_mlp, 1) * ff_output

        return x


class TimestepEmbedding(Module):
    def __init__(self, dim, freq_embed_dim=256, dtype=None):
        super().__init__()
        # self.time_embed = SinusPositionEmbedding(freq_embed_dim)
        self.mlp1 = Linear(freq_embed_dim, dim, bias=True, dtype=dtype)
        self.mlp2 = Linear(dim, dim, bias=True, dtype=dtype)

    def forward(self, timestep):
        t_freq = self.mlp1(timestep)
        t_freq = silu(t_freq)
        t_emb = self.mlp2(t_freq)
        return t_emb
