import logging

import torch
import triton

from sglang.srt.utils import ceil_div, is_cuda

# Module-level logger.
logger = logging.getLogger(__name__)

# Evaluated once at import time; gates the CUDA-only import below.
_is_cuda = is_cuda()
if _is_cuda:
    # NOTE(review): `per_token_group_quant_fp8` is only bound when running on
    # CUDA; `moe_ep_deepgemm_preprocess` calls it unconditionally, so that path
    # would raise NameError on other platforms.
    from sglang.srt.layers.quantization.fp8_kernel import (
        sglang_per_token_group_quant_fp8 as per_token_group_quant_fp8,
    )

import triton.language as tl


def _get_launch_config_1d(device, numel):
    """Choose a 1-D launch configuration covering ``numel`` elements.

    The block size starts at 1024 and is halved (never below 512) as long as
    the smaller block still needs no more blocks than the device can hold
    resident at once (``sm_count * max_threads_per_sm // 1024``). The number of
    programs is then capped at 8x that resident-block budget.

    Args:
        device: CUDA device whose properties are queried.
        numel: Total number of elements the kernel has to process.

    Returns:
        ``((grid_dim,), block_dim)``.
    """
    max_block, min_block = 1024, 512
    waves = 8  # empirical cap on programs per resident-block slot

    props = torch.cuda.get_device_properties(device)
    resident_blocks = (
        props.multi_processor_count * props.max_threads_per_multi_processor // max_block
    )

    block = max_block
    while block > min_block:
        half = block // 2
        if triton.cdiv(numel, half) > resident_blocks:
            break
        block = half

    programs = min(triton.cdiv(numel, block), resident_blocks * waves)
    return (programs,), block


def _get_launch_config_2d(device, m, n):
    """Choose a 2-D launch configuration for an ``m x n`` problem.

    The block size along ``n`` starts at 1024 and is halved (never below 512)
    while ``m * cdiv(n, block // 2)`` still fits in the device's resident-block
    budget (``sm_count * max_threads_per_sm // 1024``).

    Args:
        device: CUDA device whose properties are queried.
        m: Extent of the first (row / token) axis.
        n: Extent of the second (column / hidden) axis.

    Returns:
        ``((programs_along_m, tiles_along_n), block_dim)``. ``tiles_along_n`` is
        ``cdiv(n, block_dim)``. ``programs_along_m`` is at least 1 and at most
        ``m``, capped so the total program count stays within 8x the
        resident-block budget.

    NOTE: ``n == 0`` makes ``tiles_along_n`` zero and raises ZeroDivisionError.
    """
    max_block, min_block = 1024, 512
    waves = 8  # empirical cap on programs per resident-block slot

    props = torch.cuda.get_device_properties(device)
    resident_blocks = (
        props.multi_processor_count * props.max_threads_per_multi_processor // max_block
    )

    block = max_block
    while block > min_block:
        half = block // 2
        if m * triton.cdiv(n, half) > resident_blocks:
            break
        block = half

    tiles_along_n = triton.cdiv(n, block)
    programs_along_m = max(min(m, resident_blocks * waves // tiles_along_n), 1)
    return (programs_along_m, tiles_along_n), block


@triton.jit
def deepep_permute_triton_kernel(
    input_ptr,
    gateup_input_ptr,
    src2dst_ptr,
    topk_ids_ptr,
    a1_scales_ptr,
    topk,
    hidden_size,
    BLOCK_SIZE: tl.constexpr,
):
    """Copy each source token row to every destination row it is routed to.

    One program per source token (grid axis 0). For each of the token's
    ``topk`` slots, ``src2dst[token, slot]`` gives the destination row in
    ``gateup_input``; negative entries are skipped. The row is cast to
    ``gateup_input``'s element type and copied in ``BLOCK_SIZE`` chunks.

    ``topk_ids_ptr`` and ``a1_scales_ptr`` are accepted for signature
    compatibility but are not read.
    """
    OutDtype = gateup_input_ptr.dtype.element_ty

    # Widen to int64 before multiplying by hidden_size: with an int32 index the
    # row offset overflows once num_tokens * hidden_size exceeds 2**31.
    src_idx = tl.program_id(0).to(tl.int64)
    src2dst_ptr = src2dst_ptr + src_idx * topk

    src_ptr = input_ptr + src_idx * hidden_size

    for start_offset in tl.range(0, hidden_size, BLOCK_SIZE):
        offset = start_offset + tl.arange(0, BLOCK_SIZE)
        mask = offset < hidden_size
        in_data = tl.load(src_ptr + offset, mask=mask).to(OutDtype)

        for idx in range(topk):
            # Widen as well so a narrower src2dst dtype cannot overflow either.
            dst_idx = tl.load(src2dst_ptr + idx).to(tl.int64)
            if dst_idx >= 0:
                dst_ptr = gateup_input_ptr + dst_idx * hidden_size
                tl.store(dst_ptr + offset, in_data, mask=mask)


@triton.jit
def deepep_post_reorder_triton_kernel(
    down_output_ptr,
    output_ptr,
    src2dst_ptr,
    topk_ids_ptr,
    topk_weights_ptr,
    topk,
    hidden_size,
    BLOCK_SIZE: tl.constexpr,
):
    """Combine each token's expert outputs back into token order.

    One program per output token (grid axis 0). For each ``BLOCK_SIZE`` tile of
    the hidden dimension, sums
    ``topk_weights[token, k] * down_output[src2dst[token, k]]`` over the slots
    whose ``src2dst`` entry is non-negative, and stores the result at
    ``output[token]``. Accumulation happens in ``down_output``'s dtype.

    ``topk_ids_ptr`` is accepted for signature compatibility but is not read.
    """
    InDtype = down_output_ptr.dtype.element_ty

    # Widen to int64 so src_idx * hidden_size cannot overflow int32.
    src_idx = tl.program_id(0).to(tl.int64)
    src2dst_ptr = src2dst_ptr + src_idx * topk
    topk_weights_ptr = topk_weights_ptr + src_idx * topk

    store_ptr = output_ptr + src_idx * hidden_size
    for start_offset in tl.range(0, hidden_size, BLOCK_SIZE):
        offset = start_offset + tl.arange(0, BLOCK_SIZE)
        mask = offset < hidden_size
        sum_vec = tl.zeros([BLOCK_SIZE], dtype=InDtype)
        for idx in range(topk):
            dst_idx = tl.load(src2dst_ptr + idx).to(tl.int64)
            if dst_idx >= 0:
                weight_scale = tl.load(topk_weights_ptr + idx).to(InDtype)
                load_ptr = down_output_ptr + dst_idx * hidden_size
                in_data = tl.load(load_ptr + offset, mask=mask)
                sum_vec += in_data * weight_scale
        tl.store(store_ptr + offset, sum_vec, mask=mask)


@triton.jit
def compute_src2dst_triton_kernel(
    reorder_ids, src2dst, num_toks, BLOCK_SIZE: tl.constexpr
):
    """Invert a sort permutation: ``src2dst[reorder_ids[i]] = i`` for ``i < num_toks``.

    In ``cutlass_w4_run_moe_ep_preproess``, ``reorder_ids`` is the index output
    of a stable sort over the flattened ``topk_ids``, so ``i`` is the entry's
    position in expert-sorted order. Each program handles ``BLOCK_SIZE``
    consecutive positions.
    """
    pid = tl.program_id(axis=0)
    dst_id = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    mask = dst_id < num_toks  # masks the tail lanes of the last program
    src_id = tl.load(reorder_ids + dst_id, mask=mask)
    tl.store(src2dst + src_id, dst_id, mask=mask)


@triton.jit
def deepep_compute_src2dst_triton_kernel(
    reorder_ids, src2dst, num_toks, num_minus_one, BLOCK_SIZE: tl.constexpr
):
    """Invert a sort permutation, shifting destinations by a device-side scalar.

    Writes ``src2dst[reorder_ids[i]] = i - num_invalid`` for ``i < num_toks``,
    where ``num_invalid`` is the scalar loaded from ``num_minus_one``. The first
    ``num_invalid`` sorted positions therefore get negative destinations.
    ``deepep_run_moe_deep_preprocess`` passes the count of sorted ids below 0,
    so those entries are the ones marked negative.
    """
    pid = tl.program_id(axis=0)
    dst_id = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    mask = dst_id < num_toks
    src_id = tl.load(reorder_ids + dst_id, mask=mask)
    # Read from device memory so the host never needs to synchronize for it.
    num_invalid = tl.load(num_minus_one)
    tl.store(src2dst + src_id, dst_id - num_invalid, mask=mask)


def deepep_run_moe_deep_preprocess(topk_ids: torch.Tensor, num_experts: int):
    """Sort routed (token, slot) entries by expert id and build lookup tables.

    Entries with an id below 0 sort to the front and are treated as invalid.

    Returns:
        reorder_topk_ids: The stably sorted expert ids, with the leading
            entries below 0 removed.
        src2dst: int64 tensor with one entry per flat (token, slot) index. It
            holds the entry's position in sorted order minus the number of
            invalid entries, so invalid entries map to negative positions.
        seg_indptr: int64 tensor of length ``num_experts + 1``, shifted the
            same way. ``seg_indptr[e]`` is the first sorted position holding
            an id >= ``e``.
    """
    device = topk_ids.device
    flat_ids = topk_ids.view(-1)
    total = topk_ids.numel()

    sorted_ids, order = torch.sort(flat_ids, stable=True)
    seg_indptr = torch.empty(num_experts + 1, device=device, dtype=torch.int64)
    src2dst = torch.empty(total, device=device, dtype=torch.int64)

    # Left-bisect every boundary 0..num_experts into the sorted ids.
    boundaries = torch.arange(num_experts + 1, device=device, dtype=sorted_ids.dtype)
    torch.searchsorted(sorted_ids, boundaries, out=seg_indptr)
    # Number of sorted ids below 0. Kept as a 0-d device tensor for the kernel.
    num_minus_one = seg_indptr[0]
    seg_indptr = seg_indptr - num_minus_one

    block = 512
    deepep_compute_src2dst_triton_kernel[(triton.cdiv(total, block),)](
        order, src2dst, total, num_minus_one, block
    )
    return sorted_ids[num_minus_one:], src2dst, seg_indptr


@triton.jit
def compute_seg_indptr_triton_kernel(reorder_topk_ids, seg_indptr, num_toks):
    # One program per output slot `pid`. It binary-searches
    # reorder_topk_ids[0:num_toks] for the last position whose value is
    # <= pid - 1, then writes (that position + 1) to seg_indptr[pid]. This is
    # the number of entries with value < pid, or 0 if there are none.
    # NOTE(review): assumes reorder_topk_ids is sorted ascending; the caller in
    # this file obtains it from torch.sort.
    expert_id_minus_1 = tl.program_id(0) - 1
    low = 0
    high = num_toks - 1
    target_location = -1
    while low <= high:
        mid = (low + high) // 2

        if tl.load(reorder_topk_ids + mid) > expert_id_minus_1:
            high = mid - 1
        else:
            low = mid + 1
            # Last index seen so far whose value is <= expert_id_minus_1.
            target_location = mid
    tl.store(seg_indptr + expert_id_minus_1 + 1, target_location + 1)


def cutlass_w4_run_moe_ep_preproess(topk_ids: torch.Tensor):
    """Map each flat (token, slot) index to its position in expert-sorted order.

    The sort is stable, so entries routed to the same expert keep their
    original relative order.

    NOTE: the misspelled name ("preproess") is kept for caller compatibility.

    Returns:
        int32 tensor ``src2dst`` with ``topk_ids.numel()`` entries.
    """
    flat_ids = topk_ids.view(-1)
    count = topk_ids.numel()
    order = torch.sort(flat_ids, stable=True)[1]

    block = 512
    launch_grid = (triton.cdiv(count, block),)
    src2dst = torch.empty(count, device=topk_ids.device, dtype=torch.int32)
    compute_src2dst_triton_kernel[launch_grid](order, src2dst, count, block)
    return src2dst


@triton.jit
def pre_reorder_triton_kernel_for_cutlass_moe(
    input_ptr,
    gateup_input_ptr,
    src2dst_ptr,
    topk_ids_ptr,
    a1_scales_ptr,
    num_local_experts,
    topk,
    num_tokens,
    hidden_size,
    BLOCK_SIZE: tl.constexpr,
    NUM_STAGES: tl.constexpr,
):
    """Scatter token rows into expert-sorted order, optionally rescaling them.

    Grid axis 1 selects a ``BLOCK_SIZE`` tile of the hidden dimension. Programs
    along axis 0 visit tokens ``pid, pid + num_programs(0), ...``.

    For each token, the tile is multiplied by ``1 / a1_scales[0]`` when
    ``a1_scales_ptr`` is provided, then cast to ``gateup_input``'s dtype. It is
    stored at row ``src2dst[token, slot]`` for every slot whose expert id
    differs from the sentinel ``num_local_experts``.
    """
    OutDtype = gateup_input_ptr.dtype.element_ty

    if a1_scales_ptr is not None:
        a1_scale = 1.0 / tl.load(a1_scales_ptr)
    else:
        a1_scale = 1.0

    offset = BLOCK_SIZE * tl.program_id(1) + tl.arange(0, BLOCK_SIZE)
    mask = offset < hidden_size

    start_src_idx = tl.program_id(0)
    step = tl.num_programs(0)

    for src_idx_int32 in tl.range(
        start_src_idx, num_tokens, step, num_stages=NUM_STAGES
    ):
        src_idx = src_idx_int32.to(tl.int64)
        token_src2dst_ptr = src2dst_ptr + src_idx * topk
        token_topk_ids_ptr = topk_ids_ptr + src_idx * topk

        src_ptr_offs = input_ptr + src_idx * hidden_size + offset
        dst_ptr_offs = gateup_input_ptr + offset
        in_data = tl.load(src_ptr_offs, mask=mask).to(tl.float32)
        out_data = (in_data * a1_scale).to(OutDtype)
        for idx in range(topk):
            expert_id = tl.load(token_topk_ids_ptr + idx)
            if expert_id != num_local_experts:
                # src2dst may be int32 (cutlass_w4_run_moe_ep_preproess builds
                # it that way); widen before multiplying by hidden_size so
                # large num_tokens * topk * hidden_size cannot overflow int32.
                dst_idx = tl.load(token_src2dst_ptr + idx).to(tl.int64)
                tl.store(dst_ptr_offs + dst_idx * hidden_size, out_data, mask=mask)


def pre_reorder_for_cutlass_moe(
    input,
    gateup_input,
    src2dst,
    topk_ids,
    a1_scales,
    num_local_experts,
    topk,
    num_tokens,
    hidden_size,
):
    """Launch ``pre_reorder_triton_kernel_for_cutlass_moe`` over all tokens.

    The grid comes from ``_get_launch_config_2d(num_tokens, hidden_size)``.
    Rows are written into ``gateup_input`` in place. ``a1_scales`` may be None
    to skip the rescale.
    """
    launch_grid, tile = _get_launch_config_2d(input.device, num_tokens, hidden_size)

    kernel_args = dict(
        input_ptr=input,
        gateup_input_ptr=gateup_input,
        src2dst_ptr=src2dst,
        topk_ids_ptr=topk_ids,
        a1_scales_ptr=a1_scales,
        num_local_experts=num_local_experts,
        topk=topk,
        num_tokens=num_tokens,
        hidden_size=hidden_size,
        BLOCK_SIZE=tile,
        NUM_STAGES=3,
    )
    pre_reorder_triton_kernel_for_cutlass_moe[launch_grid](**kernel_args)


# Copied from https://github.com/ModelTC/lightllm/blob/a000ab69098654df4731f5b12587dd4e7f0a4f41/lightllm/common/fused_moe/moe_silu_and_mul_mix_quant_ep.py
@triton.jit
def _silu_and_mul_post_quant_kernel(
    input_ptr,
    stride_input_0,
    stride_input_1,
    stride_input_2,
    output_ptr,
    stride_output_0,
    stride_output_1,
    stride_output_2,
    output_scale_ptr,
    stride_output_scale_0,
    stride_output_scale_1,
    stride_output_scale_2,
    masked_m_ptr,
    size_n,
    fp8_max,
    fp8_min,
    BLOCK_N: tl.constexpr,
    NUM_STAGE: tl.constexpr,
    SCALE_UE8M0: tl.constexpr,
):
    """Fused ``silu(gate) * up`` with per-group FP8 quantization, per expert.

    Grid layout:
      - axis 0: hidden-block index, one ``BLOCK_N``-wide quantization group;
      - axis 1: starting token, with tokens visited at stride ``num_programs(1)``;
      - axis 2: expert.
    Only the first ``masked_m[expert]`` tokens of each expert are processed.

    Each input row holds ``gate`` in columns ``[0, size_n)`` and ``up`` in
    ``[size_n, 2 * size_n)``. For each group the kernel stores the quantized
    values in ``output`` and one scale in ``output_scale``. The scale is
    ``absmax / fp8_max``, optionally rounded up to a power of two.

    NOTE: ``stride_input_2`` and ``stride_output_2`` are not read; the last
    dimension of input/output is assumed to be contiguous.
    """
    expert_id = tl.program_id(2)
    token_id = tl.program_id(1)
    hidden_dim_block_index = tl.program_id(0)

    block_num_per_expert = tl.num_programs(1)

    token_num_cur_expert = tl.load(masked_m_ptr + expert_id)

    # Widen strides so expert/token offsets cannot overflow int32.
    stride_input_0 = tl.cast(stride_input_0, dtype=tl.int64)
    stride_output_0 = tl.cast(stride_output_0, dtype=tl.int64)
    stride_input_1 = tl.cast(stride_input_1, dtype=tl.int64)
    stride_output_1 = tl.cast(stride_output_1, dtype=tl.int64)

    offs_in_d = hidden_dim_block_index * BLOCK_N + tl.arange(0, BLOCK_N)
    input_ptr_offs = input_ptr + expert_id * stride_input_0 + offs_in_d
    output_ptr_offs = output_ptr + expert_id * stride_output_0 + offs_in_d
    output_scale_offs = (
        output_scale_ptr
        + expert_id * stride_output_scale_0
        + hidden_dim_block_index * stride_output_scale_2
    )

    for token_index in tl.range(
        token_id, token_num_cur_expert, block_num_per_expert, num_stages=NUM_STAGE
    ):
        gate = tl.load(
            input_ptr_offs + token_index * stride_input_1,
            mask=offs_in_d < size_n,
            other=0.0,
        ).to(tl.float32)
        # `up` lives size_n columns after `gate` in the same row.
        up = tl.load(
            input_ptr_offs + token_index * stride_input_1 + size_n,
            mask=offs_in_d < size_n,
            other=0.0,
        )
        # SiLU computed in float32, then cast back so gate * up runs in the
        # input dtype (`up` is never widened).
        gate = gate / (1 + tl.exp(-gate))
        gate = gate.to(input_ptr.dtype.element_ty)
        gate_up = up * gate
        # The 1e-10 floor avoids a zero scale (and division by zero) for an
        # all-zero group; masked lanes load 0 and so do not raise the max.
        _absmax = tl.maximum(tl.max(tl.abs(gate_up)), 1e-10)
        output_s = _absmax / fp8_max
        if SCALE_UE8M0:
            # Round the scale up to the next power of two.
            output_s = tl.exp2(tl.ceil(tl.log2(tl.abs(output_s))))
        output_q = tl.clamp(gate_up / output_s, fp8_min, fp8_max).to(
            output_ptr.dtype.element_ty
        )
        tl.store(
            output_ptr_offs + token_index * stride_output_1,
            output_q,
            mask=offs_in_d < size_n,
        )
        tl.store(
            output_scale_offs + token_index * stride_output_scale_1,
            output_s,
        )


def silu_and_mul_masked_post_quant_fwd(
    input: torch.Tensor,
    output: torch.Tensor,
    output_scale: torch.Tensor,
    quant_group_size: int,
    masked_m: torch.Tensor,
    scale_ue8m0: bool = False,
):
    """Apply ``silu(gate) * up`` then per-group FP8 quantization, per expert.

    Shapes:
        input:        [expert_num, token_num_padded, hidden_dim], contiguous.
        output:       [expert_num, token_num_padded, hidden_dim // 2],
                      float8_e4m3fn, contiguous.
        output_scale: [expert_num, token_num_padded,
                       hidden_dim // 2 // quant_group_size]; expected to be
                      float32 (not checked here).
        masked_m:     [expert_num]; number of valid tokens for each expert.

    ``quant_group_size`` is the number of output columns that share one scale
    and must divide ``hidden_dim // 2``. With ``scale_ue8m0=True`` each scale
    is rounded up to a power of two. Results are written in place.
    """
    assert input.is_contiguous()
    assert output.dtype == torch.float8_e4m3fn
    assert output.is_contiguous()
    assert len(input.shape) == 3
    assert input.shape[0] == masked_m.shape[0]
    assert input.shape[-1] % 2 == 0

    size_n = input.shape[-1] // 2
    assert size_n % quant_group_size == 0

    expert_num = len(masked_m)
    # Fewer experts -> more programs per expert along the token axis.
    block_num_per_expert = 64 if expert_num < 4 else 32

    # One program per quantization group along the hidden axis.
    block_n = quant_group_size
    grid = (triton.cdiv(size_n, block_n), block_num_per_expert, expert_num)

    fp8_max = torch.finfo(torch.float8_e4m3fn).max
    fp8_min = -fp8_max

    _silu_and_mul_post_quant_kernel[grid](
        input,
        *input.stride(),
        output,
        *output.stride(),
        output_scale,
        *output_scale.stride(),
        masked_m,
        size_n,
        fp8_max,
        fp8_min,
        BLOCK_N=block_n,
        NUM_STAGE=6,
        num_warps=1,
        SCALE_UE8M0=scale_ue8m0,
    )


@triton.jit
def silu_mul_static_tensorwise_quant_triton_kernel_for_cutlass_moe(
    input_ptr,
    output_ptr,
    scale_ptr,
    num_tokens_tensor_ptr,
    intermediate_size,
    BLOCK_SIZE: tl.constexpr,
    NUM_STAGES: tl.constexpr,
):
    """Compute ``silu(gate) * up / scale[0]`` and cast to the output dtype.

    Each input row is ``2 * intermediate_size`` wide: ``gate`` occupies columns
    ``[0, intermediate_size)`` and ``up`` occupies
    ``[intermediate_size, 2 * intermediate_size)``. The output is a dense
    ``num_tokens x intermediate_size`` buffer. The token count is read from
    device memory (``num_tokens_tensor_ptr``). Programs walk the flattened
    output in ``BLOCK_SIZE`` chunks with stride ``num_programs(0) * BLOCK_SIZE``.
    No explicit clamping is applied before the final cast.
    """
    OutDtype = output_ptr.dtype.element_ty

    num_tokens = tl.load(num_tokens_tensor_ptr)
    numel = num_tokens * intermediate_size
    gate_ptr = input_ptr
    up_ptr = input_ptr + intermediate_size
    # Static per-tensor quantization: multiply by the reciprocal scale.
    scale = 1.0 / tl.load(scale_ptr)

    start_idx = tl.program_id(0) * BLOCK_SIZE
    step = tl.num_programs(0) * BLOCK_SIZE

    # NOTE: the loop variable shadows the builtin `id`; harmless in a kernel.
    for id in tl.range(start_idx, numel, step, num_stages=NUM_STAGES):
        ids = id + tl.arange(0, BLOCK_SIZE)
        token_ids = ids // intermediate_size
        mask = ids < numel

        # Flat output index -> input index: each earlier token contributes an
        # extra intermediate_size columns (its `up` half).
        offs = ids + token_ids * intermediate_size
        gate = tl.load(gate_ptr + offs, mask=mask, other=0.0).to(tl.float32)
        up = tl.load(up_ptr + offs, mask=mask, other=0.0).to(tl.float32)
        output = gate / (1 + tl.exp(-gate)) * up * scale
        tl.store(output_ptr + ids, output.to(OutDtype), mask=mask)


def silu_mul_static_tensorwise_quant_for_cutlass_moe(
    input: torch.Tensor,
    output: torch.Tensor,
    scale: torch.Tensor,
    num_tokens_tensor: torch.Tensor,
    expected_num_tokens: int,
    intermediate_size: int,
):
    """Launch the fused SiLU-and-mul + static per-tensor quantization kernel.

    The launch grid is sized from the host-side ``expected_num_tokens``. The
    kernel itself reads the actual token count from ``num_tokens_tensor`` on
    the device, so the two values may differ. ``output`` is written in place.
    """
    work_items = expected_num_tokens * intermediate_size
    launch_grid, tile = _get_launch_config_1d(input.device, work_items)

    kernel = silu_mul_static_tensorwise_quant_triton_kernel_for_cutlass_moe
    kernel[launch_grid](
        input_ptr=input,
        output_ptr=output,
        scale_ptr=scale,
        num_tokens_tensor_ptr=num_tokens_tensor,
        intermediate_size=intermediate_size,
        BLOCK_SIZE=tile,
        NUM_STAGES=3,
    )


@triton.jit
def post_reorder_triton_kernel_for_cutlass_moe(
    down_output_ptr,
    output_ptr,
    src2dst_ptr,
    topk_ids_ptr,
    topk_weights_ptr,
    num_local_experts,
    topk,
    num_tokens,
    hidden_size,
    routed_scaling_factor: float,
    BLOCK_SIZE: tl.constexpr,
    NUM_STAGES: tl.constexpr,
):
    """Weighted-sum each token's expert outputs back into token order.

    Grid axis 1 selects a ``BLOCK_SIZE`` tile of the hidden dimension. Programs
    along axis 0 visit tokens ``pid, pid + num_programs(0), ...``.

    For every slot whose expert id differs from the sentinel
    ``num_local_experts``, row ``down_output[src2dst[token, slot]]`` is scaled
    by the slot's weight and accumulated in float32. The sum is multiplied by
    ``routed_scaling_factor`` and stored at ``output[token]``, cast to
    ``output``'s dtype.
    """
    OutDtype = output_ptr.dtype.element_ty

    offset = BLOCK_SIZE * tl.program_id(1) + tl.arange(0, BLOCK_SIZE)
    mask = offset < hidden_size

    down_output_ptr_offs = down_output_ptr + offset
    output_ptr_offs = output_ptr + offset

    start_src_idx = tl.program_id(0)
    step = tl.num_programs(0)

    for src_idx_int32 in tl.range(
        start_src_idx, num_tokens, step, num_stages=NUM_STAGES
    ):
        src_idx = src_idx_int32.to(tl.int64)
        token_src2dst_ptr = src2dst_ptr + src_idx * topk
        token_topk_ids_ptr = topk_ids_ptr + src_idx * topk
        token_topk_weights_ptr = topk_weights_ptr + src_idx * topk

        sum_vec = tl.zeros([BLOCK_SIZE], dtype=tl.float32)
        for idx in range(topk):
            expert_id = tl.load(token_topk_ids_ptr + idx)
            if expert_id != num_local_experts:
                # Widen so dst_idx * hidden_size cannot overflow int32.
                dst_idx = tl.load(token_src2dst_ptr + idx).to(tl.int64)
                weight_scale = tl.load(token_topk_weights_ptr + idx).to(tl.float32)
                load_ptr_offs = down_output_ptr_offs + dst_idx * hidden_size
                in_data = tl.load(load_ptr_offs, mask=mask).to(tl.float32)
                sum_vec += in_data * weight_scale
        sum_vec *= routed_scaling_factor
        store_ptr_offs = output_ptr_offs + src_idx * hidden_size
        tl.store(store_ptr_offs, sum_vec.to(OutDtype), mask=mask)


def post_reorder_for_cutlass_moe(
    down_output,
    output,
    src2dst,
    topk_ids,
    topk_weights,
    num_local_experts,
    topk,
    num_tokens,
    hidden_size,
    routed_scaling_factor: float,
):
    """Launch ``post_reorder_triton_kernel_for_cutlass_moe`` over all tokens.

    The grid comes from ``_get_launch_config_2d(num_tokens, hidden_size)``.
    ``output`` is written in place.
    """
    launch_grid, tile = _get_launch_config_2d(
        down_output.device, num_tokens, hidden_size
    )

    kernel_args = dict(
        down_output_ptr=down_output,
        output_ptr=output,
        src2dst_ptr=src2dst,
        topk_ids_ptr=topk_ids,
        topk_weights_ptr=topk_weights,
        num_local_experts=num_local_experts,
        topk=topk,
        num_tokens=num_tokens,
        hidden_size=hidden_size,
        routed_scaling_factor=routed_scaling_factor,
        BLOCK_SIZE=tile,
        NUM_STAGES=3,
    )
    post_reorder_triton_kernel_for_cutlass_moe[launch_grid](**kernel_args)


@triton.jit
def post_reorder_triton_kernel(
    down_output_ptr,
    output_ptr,
    src2dst_ptr,
    topk_ids_ptr,
    topk_weights_ptr,
    topk,
    hidden_size,
    BLOCK_SIZE: tl.constexpr,
):
    """Combine each token's expert outputs back into token order.

    One program per output token (grid axis 0). For each ``BLOCK_SIZE`` tile of
    the hidden dimension, sums
    ``topk_weights[token, k] * down_output[src2dst[token, k]]`` over every slot
    with a non-negative expert id, and stores the result at ``output[token]``.
    Accumulation happens in ``down_output``'s dtype.
    """
    InDtype = down_output_ptr.dtype.element_ty

    src_idx_int32 = tl.program_id(0)
    src_idx = src_idx_int32.to(tl.int64)
    src2dst_ptr = src2dst_ptr + src_idx * topk
    topk_ids_ptr = topk_ids_ptr + src_idx * topk
    topk_weights_ptr = topk_weights_ptr + src_idx * topk

    store_ptr = output_ptr + src_idx * hidden_size

    vec = tl.arange(0, BLOCK_SIZE)

    for start_offset in tl.range(0, hidden_size, BLOCK_SIZE):
        offset = start_offset + vec
        mask = offset < hidden_size

        sum_vec = tl.zeros([BLOCK_SIZE], dtype=InDtype)
        for idx in range(topk):
            expert_id = tl.load(topk_ids_ptr + idx)
            # Negative ids mark slots to skip (same convention as
            # fill_gateup_input_triton_kernel / ep_gather). Expert 0 is a real
            # expert, so the test must be `>= 0`; `> 0` dropped its output.
            if expert_id >= 0:
                dst_idx_int32 = tl.load(src2dst_ptr + idx)
                dst_idx = dst_idx_int32.to(tl.int64)
                weight_scale = tl.load(topk_weights_ptr + idx).to(InDtype)
                load_ptr = down_output_ptr + dst_idx * hidden_size
                in_data = tl.load(load_ptr + offset, mask=mask)
                sum_vec += in_data * weight_scale
        tl.store(store_ptr + offset, sum_vec, mask=mask)


@triton.jit
def _fwd_kernel_ep_scatter_1(
    num_recv_tokens_per_expert,
    expert_start_loc,
    m_indices,
    num_experts: tl.constexpr,
    BLOCK_E: tl.constexpr,
    BLOCK_EXPERT_NUM: tl.constexpr,
):
    """Compute per-expert start rows and fill ``m_indices`` with expert ids.

    One program per expert (grid axis 0). Every program recomputes the
    exclusive prefix sum of ``num_recv_tokens_per_expert`` and stores it to
    ``expert_start_loc``; all programs write identical values. Each program
    then writes its own expert id into ``m_indices``, starting at that expert's
    start row, ``BLOCK_E`` entries at a time.

    ``BLOCK_EXPERT_NUM`` must be a power of two >= ``num_experts``; ``ep_scatter``
    passes ``next_power_of_2(num_experts)``.
    """
    cur_expert = tl.program_id(0)

    offset_cumsum = tl.arange(0, BLOCK_EXPERT_NUM)
    tokens_per_expert = tl.load(
        num_recv_tokens_per_expert + offset_cumsum,
        mask=offset_cumsum < num_experts,
        other=0,
    )
    # Exclusive prefix sum gives the first row of each expert.
    cumsum = tl.cumsum(tokens_per_expert) - tokens_per_expert
    tl.store(expert_start_loc + offset_cumsum, cumsum, mask=offset_cumsum < num_experts)

    # NOTE(review): this reads back a value that was just stored (possibly by
    # another lane or program); confirm the read is guaranteed to see it.
    cur_expert_start = tl.load(expert_start_loc + cur_expert)
    cur_expert_token_num = tl.load(num_recv_tokens_per_expert + cur_expert)

    m_indices_start_ptr = m_indices + cur_expert_start
    off_expert = tl.arange(0, BLOCK_E)

    # Stores are unmasked, so the last chunk writes a full BLOCK_E entries even
    # when the token count is not a multiple of BLOCK_E. ep_scatter documents
    # that per-expert token counts are aligned to 128 (its BLOCK_E).
    for start_m in tl.range(0, cur_expert_token_num, BLOCK_E, num_stages=4):
        tl.store(
            m_indices_start_ptr + start_m + off_expert,
            cur_expert,
        )


@triton.jit
def _fwd_kernel_ep_scatter_2(
    total_token_num,
    expert_start_loc,
    recv_x,
    recv_x_stride0,
    recv_x_stride1,
    recv_x_scale,
    recv_x_scale_stride0,
    recv_x_scale_stride1,
    recv_topk,
    recv_topk_stride0,
    recv_topk_stride1,
    output_tensor,
    output_tensor_stride0,
    output_tensor_stride1,
    output_tensor_scale,
    output_tensor_scale_stride0,
    output_tensor_scale_stride1,
    output_index,
    output_index_stride0,
    output_index_stride1,
    topk_num: tl.constexpr,
    HIDDEN_SIZE: tl.constexpr,
    HIDDEN_SIZE_PAD: tl.constexpr,
    SCALE_HIDDEN_SIZE: tl.constexpr,
    SCALE_HIDDEN_SIZE_PAD: tl.constexpr,
):
    """Copy each received token and its scales into its experts' row ranges.

    Programs visit tokens ``pid, pid + num_programs(0), ...``. For each slot
    ``k`` with ``recv_topk[token, k] >= 0``, a destination row is claimed with
    ``atomic_add(expert_start_loc[expert], 1)``. Each entry of
    ``expert_start_loc`` is therefore advanced once per row written for that
    expert, and the order of rows within an expert depends on scheduling. The
    claimed row is recorded in ``output_index[token, k]``, and the token's
    hidden row and scale row are copied there.

    NOTE: ``recv_x_stride1``, ``recv_topk_stride1``, ``output_tensor_stride1``
    and ``output_index_stride1`` are not read; those last dimensions are
    assumed to be contiguous.
    """
    start_token_id = tl.program_id(0)
    grid_num = tl.num_programs(0)

    # Padded (power-of-two) ranges, masked back to the real sizes.
    offset_in = tl.arange(0, HIDDEN_SIZE_PAD)
    mask = offset_in < HIDDEN_SIZE

    index_in_s = tl.arange(0, SCALE_HIDDEN_SIZE_PAD)
    mask_s = index_in_s < SCALE_HIDDEN_SIZE

    for token_id_int32 in range(start_token_id, total_token_num, grid_num):
        token_id = token_id_int32.to(tl.int64)
        # Load the row and its scales once, then reuse them for every slot.
        to_copy = tl.load(recv_x + token_id * recv_x_stride0 + offset_in, mask=mask)
        to_copy_s = tl.load(
            recv_x_scale
            + token_id * recv_x_scale_stride0
            + index_in_s * recv_x_scale_stride1,
            mask=mask_s,
        )

        for topk_idx_int32 in tl.range(0, topk_num, 1, num_stages=4):
            topk_index = topk_idx_int32.to(tl.int64)
            expert_id = tl.load(recv_topk + token_id * recv_topk_stride0 + topk_index)
            if expert_id >= 0:
                # Claim the next free row of this expert (returns the old value).
                dest_token_index_int32 = tl.atomic_add(expert_start_loc + expert_id, 1)
                dest_token_index = dest_token_index_int32.to(tl.int64)

                tl.store(
                    output_index + token_id * output_index_stride0 + topk_index,
                    dest_token_index_int32,
                )
                output_tensor_ptr = (
                    output_tensor + dest_token_index * output_tensor_stride0
                )
                output_tensor_scale_ptr = (
                    output_tensor_scale + dest_token_index * output_tensor_scale_stride0
                )
                tl.store(output_tensor_ptr + offset_in, to_copy, mask=mask)
                tl.store(
                    output_tensor_scale_ptr + index_in_s * output_tensor_scale_stride1,
                    to_copy_s,
                    mask=mask_s,
                )


# Copied from https://github.com/ModelTC/lightllm/blob/main/lightllm/common/fused_moe/deepep_scatter_gather.py
@torch.no_grad()
def ep_scatter(
    recv_x: torch.Tensor,
    recv_x_scale: torch.Tensor,
    recv_topk: torch.Tensor,
    num_recv_tokens_per_expert: torch.Tensor,
    expert_start_loc: torch.Tensor,
    output_tensor: torch.Tensor,
    output_tensor_scale: torch.Tensor,
    m_indices: torch.Tensor,
    output_index: torch.Tensor,
    scale_ue8m0: bool = False,
):
    """Scatter received tokens and their scales into per-expert contiguous rows.

    Two kernels are launched:
      1. ``_fwd_kernel_ep_scatter_1`` writes each expert's start row (the
         exclusive prefix sum of ``num_recv_tokens_per_expert``) into
         ``expert_start_loc`` and fills ``m_indices`` with expert ids.
      2. ``_fwd_kernel_ep_scatter_2`` copies every routed (token, expert) pair
         into ``output_tensor`` / ``output_tensor_scale`` and records the
         chosen row in ``output_index``.

    All outputs are written in place; returns None.
    """
    block_e = 128  # per-expert token counts are expected to be aligned to this
    block_d = 128  # quantization block size along the hidden dimension
    warps = 8
    num_experts = num_recv_tokens_per_expert.shape[0]
    hidden_size = recv_x.shape[1]

    scale_hidden_size = hidden_size // block_d
    if scale_ue8m0:
        # UE8M0 scales are packed 4 per int32, so this dimension shrinks by 4.
        scale_hidden_size = ceil_div(scale_hidden_size, 4)

    assert m_indices.shape[0] % block_e == 0
    assert (
        recv_x_scale.dtype == output_tensor_scale.dtype
    ), f"recv_x_scale.dtype: {recv_x_scale.dtype}, output_tensor_scale.dtype: {output_tensor_scale.dtype}"
    assert recv_x_scale.shape[1] == output_tensor_scale.shape[1] == scale_hidden_size

    # One program per expert.
    _fwd_kernel_ep_scatter_1[(num_experts,)](
        num_recv_tokens_per_expert,
        expert_start_loc,
        m_indices,
        num_experts=num_experts,
        num_warps=warps,
        BLOCK_E=block_e,
        BLOCK_EXPERT_NUM=triton.next_power_of_2(num_experts),
    )

    # At most 8192 programs; each loops over tokens with that stride.
    num_recv_tokens = recv_topk.shape[0]
    _fwd_kernel_ep_scatter_2[(min(num_recv_tokens, 1024 * 8),)](
        num_recv_tokens,
        expert_start_loc,
        recv_x,
        recv_x.stride(0),
        recv_x.stride(1),
        recv_x_scale,
        recv_x_scale.stride(0),
        recv_x_scale.stride(1),
        recv_topk,
        recv_topk.stride(0),
        recv_topk.stride(1),
        output_tensor,
        output_tensor.stride(0),
        output_tensor.stride(1),
        output_tensor_scale,
        output_tensor_scale.stride(0),
        output_tensor_scale.stride(1),
        output_index,
        output_index.stride(0),
        output_index.stride(1),
        topk_num=recv_topk.shape[1],
        num_warps=warps,
        HIDDEN_SIZE=hidden_size,
        HIDDEN_SIZE_PAD=triton.next_power_of_2(hidden_size),
        SCALE_HIDDEN_SIZE=scale_hidden_size,
        SCALE_HIDDEN_SIZE_PAD=triton.next_power_of_2(scale_hidden_size),
    )


@triton.jit
def _fwd_kernel_ep_gather(
    total_token_num,
    input_tensor,
    input_tensor_stride0,
    input_tensor_stride1,
    recv_topk_ids,
    recv_topk_ids_stride0,
    recv_topk_ids_stride1,
    recv_topk_weight,
    recv_topk_weight_stride0,
    recv_topk_weight_stride1,
    input_index,
    input_index_stride0,
    input_index_stride1,
    output_tensor,
    output_tensor_stride0,
    output_tensor_stride1,
    topk_num: tl.constexpr,
    BLOCK_D: tl.constexpr,
):
    """Gather expert outputs back per token as a weighted sum.

    Grid axis 0 selects a ``BLOCK_D`` tile of the hidden dimension. Programs
    along axis 1 visit output tokens ``pid, pid + num_programs(1), ...``.

    For each token, the kernel sums
    ``recv_topk_weight[token, k] * input_tensor[input_index[token, k], tile]``
    in float32 over every slot with ``recv_topk_ids[token, k] >= 0``. The sum
    is stored into ``output_tensor[token, tile]`` in its dtype.

    NOTE: loads and stores along the hidden dimension are unmasked, so
    ``hidden_size`` must be a multiple of ``BLOCK_D`` (``ep_gather`` asserts
    this). The ``*_stride1`` arguments are not read; the last dimensions are
    assumed contiguous.
    """
    cur_block_int32 = tl.program_id(0)
    cur_block = cur_block_int32.to(tl.int64)

    start_cur_token_int32 = tl.program_id(1)

    grid_num = tl.num_programs(1)

    for cur_token_int32 in range(start_cur_token_int32, total_token_num, grid_num):
        cur_token = cur_token_int32.to(tl.int64)

        off_d = tl.arange(0, BLOCK_D)
        accumulator = tl.zeros([BLOCK_D], dtype=tl.float32)

        for topk_index_int32 in range(0, topk_num):
            topk_index = topk_index_int32.to(tl.int64)

            expert_id = tl.load(
                recv_topk_ids + cur_token * recv_topk_ids_stride0 + topk_index
            )
            # Negative expert ids mark slots with nothing to gather.
            if expert_id >= 0:
                # Row written for this (token, slot), e.g. by ep_scatter's
                # output_index.
                source_token_index_int32 = tl.load(
                    input_index + cur_token * input_index_stride0 + topk_index
                )
                source_token_index = source_token_index_int32.to(tl.int64)

                acc_weight = tl.load(
                    recv_topk_weight + cur_token * recv_topk_weight_stride0 + topk_index
                )
                tmp = tl.load(
                    input_tensor
                    + source_token_index * input_tensor_stride0
                    + cur_block * BLOCK_D
                    + off_d
                )
                accumulator += tmp.to(tl.float32) * acc_weight

        tl.store(
            output_tensor
            + cur_token * output_tensor_stride0
            + cur_block * BLOCK_D
            + off_d,
            accumulator.to(output_tensor.dtype.element_ty),
        )


@torch.no_grad()
def ep_gather(
    input_tensor: torch.Tensor,
    recv_topk_ids: torch.Tensor,
    recv_topk_weight: torch.Tensor,
    input_index: torch.Tensor,
    output_tensor: torch.Tensor,
):
    """Launch ``_fwd_kernel_ep_gather`` to combine expert rows per token.

    The hidden dimension is tiled in chunks of 1024 when it divides evenly,
    otherwise 128. The hidden size must be a multiple of the chosen tile.
    At most 1024 programs are launched along the token axis.
    ``output_tensor`` is written in place; returns None.
    """
    hidden_size = input_tensor.shape[1]
    tile = 1024 if hidden_size % 1024 == 0 else 128
    assert hidden_size % tile == 0

    num_tokens = output_tensor.shape[0]
    launch_grid = (triton.cdiv(hidden_size, tile), min(num_tokens, 1024))

    _fwd_kernel_ep_gather[launch_grid](
        num_tokens,
        input_tensor,
        input_tensor.stride(0),
        input_tensor.stride(1),
        recv_topk_ids,
        recv_topk_ids.stride(0),
        recv_topk_ids.stride(1),
        recv_topk_weight,
        recv_topk_weight.stride(0),
        recv_topk_weight.stride(1),
        input_index,
        input_index.stride(0),
        input_index.stride(1),
        output_tensor,
        output_tensor.stride(0),
        output_tensor.stride(1),
        topk_num=recv_topk_ids.shape[1],
        num_warps=2,
        BLOCK_D=tile,
    )


# Copied from
# https://github.com/deepseek-ai/DeepGEMM/blob/bd2a77552886b98c205af12f8d7d2d61247c4b27/deep_gemm/jit_kernels/utils.py#L58
def get_tma_aligned_size(x: int, element_size: int) -> int:
    """
    Pad an M-axis length so that it spans a whole number of 16-byte units.

    TMA requires global memory addresses to be 16-byte aligned. The LHS
    scaling tensor is stored column-major, so its M axis is padded until
    ``x * element_size`` is a multiple of 16 bytes.

    Arguments:
        x: Original M-axis length of the LHS scaling tensor.
        element_size: Size in bytes of one element of that tensor; must
            divide 16.

    Returns:
        ``x`` rounded up to a multiple of ``16 // element_size``.
    """
    TMA_ALIGNMENT_BYTES = 16
    assert TMA_ALIGNMENT_BYTES % element_size == 0
    elems_per_unit = TMA_ALIGNMENT_BYTES // element_size
    return elems_per_unit * ceil_div(x, elems_per_unit)


@triton.jit
def _tma_align_input_scale_kernel(
    input_scale_ptr,
    output_ptr,
    m,
    k_div_block_size,
    input_scale_stride_m,
    input_scale_stride_k,
    output_stride_m,
    output_stride_k,
    BLOCK_SIZE_K: tl.constexpr,
):
    """Strided copy of an ``[m, k_div_block_size]`` scale matrix.

    Programs visit rows ``pid, pid + num_programs(0), ...``. For each row they
    copy the first ``k_div_block_size`` entries, using independent input and
    output strides. ``BLOCK_SIZE_K`` must be >= ``k_div_block_size``; the
    caller uses the next power of two. ``tma_align_input_scale`` passes the
    output strides swapped so the result is laid out column-major.
    """
    pid_m = tl.program_id(axis=0)
    grid_m = tl.num_programs(0)
    k_offsets = tl.arange(0, BLOCK_SIZE_K)

    for m_base in range(pid_m, m, grid_m):
        input_offset = (
            input_scale_ptr
            + m_base * input_scale_stride_m
            + k_offsets * input_scale_stride_k
        )
        input_data = tl.load(input_offset, mask=k_offsets < k_div_block_size)

        output_offset = (
            output_ptr + k_offsets * output_stride_k + m_base * output_stride_m
        )
        tl.store(output_offset, input_data, mask=k_offsets < k_div_block_size)


# Copied from https://github.com/ModelTC/lightllm/blob/main/lightllm/common/quantization/triton_quant/fp8/fp8act_quant_kernel.py
def tma_align_input_scale(input_scale: torch.Tensor):
    """Copy a 2-D ``[m, k]`` scale tensor into a TMA-friendly column-major layout.

    The data is written into a ``[k, padded_m]`` buffer, where ``padded_m``
    comes from ``get_tma_aligned_size``. The function returns the transposed
    view trimmed to ``m`` rows: an ``[m, k]`` tensor whose column stride is
    ``padded_m`` elements. Buffer entries at positions ``m..padded_m`` are
    left uninitialized.
    """
    assert input_scale.dim() == 2
    m, k = input_scale.shape
    padded_m = get_tma_aligned_size(m, input_scale.element_size())
    buf = torch.empty(
        (k, padded_m), dtype=input_scale.dtype, device=input_scale.device
    )

    _tma_align_input_scale_kernel[(min(m, 8192),)](
        input_scale_ptr=input_scale,
        output_ptr=buf,
        m=m,
        k_div_block_size=k,
        input_scale_stride_m=input_scale.stride(0),
        input_scale_stride_k=input_scale.stride(1),
        # buf is [k, padded_m]: stepping along m moves along buf's dim 1 and
        # stepping along k moves along dim 0, which gives a column-major result.
        output_stride_m=buf.stride(1),
        output_stride_k=buf.stride(0),
        BLOCK_SIZE_K=triton.next_power_of_2(k),
    )
    return buf.t()[:m]


@triton.jit
def compute_masked_m_triton_kernel(seg_indptr, masked_m):
    # One program per expert `e`: masked_m[e] = seg_indptr[e + 1] - seg_indptr[e],
    # i.e. the length of expert e's segment in the sorted order.
    expert_id = tl.program_id(0)
    start = tl.load(seg_indptr + expert_id)
    end = tl.load(seg_indptr + expert_id + 1)
    tl.store(masked_m + expert_id, (end - start))


@triton.jit
def deepgemm_compute_src2dst_triton_kernel(
    topk_ids,
    reorder_ids,
    seg_indptr,
    src2dst,
    m_max,
    num_toks,
    BLOCK_SIZE: tl.constexpr,
):
    """Map each flat (token, slot) index to a row of a ``[experts, m_max]`` layout.

    For sorted position ``dst_id`` with source ``src_id = reorder_ids[dst_id]``
    and expert ``e = topk_ids[src_id]``, writes
    ``src2dst[src_id] = e * m_max + (dst_id - seg_indptr[e])``.
    For ``e < 0`` the stored value is not a valid row. Consumers must skip such
    slots, e.g. by checking the expert id, as fill_gateup_input_triton_kernel
    does.
    """
    pid = tl.program_id(axis=0)
    dst_id = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    mask = dst_id < num_toks
    # A masked load without `other` leaves inactive lanes undefined. Pin them
    # to 0 so the dependent loads below never use a garbage index.
    src_id = tl.load(reorder_ids + dst_id, mask=mask, other=0)
    expert_id = tl.load(
        topk_ids + src_id, mask=mask & (src_id < num_toks), other=-1
    )
    # Only look up a segment start for active lanes with a real expert.
    expert_dst_start = tl.load(
        seg_indptr + expert_id, mask=mask & (expert_id >= 0), other=0
    )
    expert_dst_offset = dst_id - expert_dst_start
    dst_id = expert_id * m_max + expert_dst_offset
    tl.store(src2dst + src_id, dst_id, mask=mask)


@triton.jit
def fill_gateup_input_triton_kernel(
    input_ptr,
    scale_ptr,
    gateup_input_ptr,
    gateup_input_scale_ptr,
    src2dst_ptr,
    topk_ids_ptr,
    topk,
    hidden_size,
    scale_size,
    BLOCK_SIZE: tl.constexpr,
):
    """Copy each token's row and scale row to every destination it is routed to.

    One program per source token (grid axis 0). For each slot ``k`` with
    ``topk_ids[token, k] >= 0``, the kernel copies ``input[token, :hidden_size]``
    to ``gateup_input[src2dst[token, k]]`` and ``scale[token, :scale_size]``
    to ``gateup_input_scale[src2dst[token, k]]``, in ``BLOCK_SIZE`` chunks.
    Rows are addressed as contiguous ``hidden_size`` / ``scale_size`` spans,
    since there are no stride arguments. No explicit dtype conversion is done
    in the kernel.
    """

    src_idx_int32 = tl.program_id(0)
    src_idx = src_idx_int32.to(tl.int64)
    src2dst_ptr = src2dst_ptr + src_idx * topk
    topk_ids_ptr = topk_ids_ptr + src_idx * topk
    src_ptr = input_ptr + src_idx * hidden_size
    scale_src_ptr = scale_ptr + src_idx * scale_size

    vec = tl.arange(0, BLOCK_SIZE)
    for idx in range(topk):
        expert_id = tl.load(topk_ids_ptr + idx)
        # Negative expert ids mark slots that are skipped.
        if expert_id >= 0:
            dst_idx_int32 = tl.load(src2dst_ptr + idx)
            # Widen so dst_idx * hidden_size cannot overflow int32.
            dst_idx = dst_idx_int32.to(tl.int64)
            dst_ptr = gateup_input_ptr + dst_idx * hidden_size
            for start_offset in tl.range(0, hidden_size, BLOCK_SIZE):
                offset = start_offset + vec
                mask = offset < hidden_size
                in_data = tl.load(src_ptr + offset, mask=mask)
                tl.store(dst_ptr + offset, in_data, mask=mask)
            scale_dst_ptr = gateup_input_scale_ptr + dst_idx * scale_size
            for start_offset in tl.range(0, scale_size, BLOCK_SIZE):
                offset = start_offset + vec
                mask = offset < scale_size
                in_scale = tl.load(scale_src_ptr + offset, mask=mask)
                tl.store(scale_dst_ptr + offset, in_scale, mask=mask)


def moe_ep_deepgemm_preprocess(
    topk_ids: torch.Tensor,
    num_local_experts: int,
    hidden_states: torch.Tensor,
    top_k: int,
    block_shape,
    output_dtype: torch.dtype = torch.float8_e4m3fn,
):
    """Quantize tokens and scatter them into per-expert padded buffers.

    The buffers use the layout expected by DeepGEMM's masked grouped GEMM.
    ``hidden_states`` is FP8-quantized per token group of ``block_shape[1]``
    elements. Each (token, routed local expert) pair is then copied into one row
    of ``gateup_input``, together with its row of scales.

    Args:
        topk_ids: routed expert id for every (token, k) slot. It is flattened
            with ``view(-1)``, so it must be viewable as 1-D. Slots with a
            negative id are not copied into ``gateup_input``.
        num_local_experts: number of experts owned by this rank.
        hidden_states: ``[num_tokens, hidden_size]`` activations to quantize.
        top_k: number of slots per token (the row length of ``topk_ids``).
        block_shape: ``[block_n, block_k]``; defaults to ``[128, 128]``. Only
            ``block_k`` (the quantization group size) is used here.
        output_dtype: element type of ``gateup_input``.

    Returns:
        A tuple ``(masked_m, expected_m, src2dst, gateup_input, gateup_input_scale)``:

        - ``masked_m``: int32 ``[num_local_experts]``, filled from
          ``seg_indptr`` by ``compute_masked_m_triton_kernel``. This is
          presumably the per-expert token count.
        - ``expected_m``: ``ceil(topk_ids.numel() / num_local_experts)``.
        - ``src2dst``: int32, one entry per slot. Each entry is the row
          ``expert_id * m_max + position within that expert`` of the flattened
          ``gateup_input``.
        - ``gateup_input``: ``[num_local_experts, m_max, hidden_size]``.
        - ``gateup_input_scale``: ``[num_local_experts, m_max, num_groups]``.

        Padding rows of both buffers are left uninitialized (``torch.empty``).
    """
    device = topk_ids.device
    num_slots = topk_ids.numel()

    # Stable sort groups slots by expert id while keeping token order inside an expert.
    reorder_topk_ids, reorder_ids = torch.sort(topk_ids.view(-1), stable=True)
    seg_indptr = torch.zeros(num_local_experts + 1, device=device, dtype=torch.int64)
    src2dst = torch.empty(num_slots, device=device, dtype=torch.int32)
    masked_m = torch.empty(num_local_experts, device=device, dtype=torch.int32)

    compute_seg_indptr_triton_kernel[(num_local_experts + 1,)](
        reorder_topk_ids, seg_indptr, num_slots
    )
    compute_masked_m_triton_kernel[(num_local_experts,)](seg_indptr, masked_m)

    # For masked grouped GEMM, shape M must be a multiple of DeepGEMM's block M:
    # https://github.com/deepseek-ai/DeepGEMM/blob/main/deep_gemm/jit_kernels/m_grouped_gemm.py#L165
    # Here M is padded to the next multiple of 256, always strictly above num_tokens.
    m_max = (hidden_states.size(0) // 256 + 1) * 256
    expected_m = (num_slots - 1) // num_local_experts + 1
    gateup_input = torch.empty(
        (num_local_experts, m_max, hidden_states.size(1)),
        device=hidden_states.device,
        dtype=output_dtype,
    )

    grid = lambda meta: (triton.cdiv(num_slots, meta["BLOCK_SIZE"]),)
    deepgemm_compute_src2dst_triton_kernel[grid](
        topk_ids,
        reorder_ids,
        seg_indptr,
        src2dst,
        m_max,
        num_slots,
        BLOCK_SIZE=256,
    )

    if block_shape is None:
        block_shape = [128, 128]
    assert len(block_shape) == 2
    block_k = block_shape[1]

    # TODO: fuse this with the preprocess
    hidden_states, scale = per_token_group_quant_fp8(hidden_states, block_k)

    gateup_input_scale = torch.empty(
        (gateup_input.size(0), gateup_input.size(1), scale.size(1)),
        device=hidden_states.device,
        dtype=scale.dtype,
    )

    # One program per source token; it fans the row out to all its routed experts.
    fill_gateup_input_triton_kernel[(hidden_states.shape[0],)](
        hidden_states,
        scale,
        gateup_input,
        gateup_input_scale,
        src2dst,
        topk_ids,
        top_k,
        hidden_states.size(1),
        scale.size(1),
        BLOCK_SIZE=1024,
    )

    return (
        masked_m,
        expected_m,
        src2dst,
        gateup_input,
        gateup_input_scale,
    )


@triton.jit
def compute_identity_kernel(
    top_k,
    hidden_states_ptr,
    expert_scales_ptr,
    num_tokens,
    output_ptr,
    hidden_dim,
    scales_stride,
    BLOCK_SIZE: tl.constexpr,
):
    """Compute ``output[t, d] = sum_i hidden_states[t, d] * expert_scales[t, i]``.

    Let ``num_dim_blocks = cdiv(hidden_dim, BLOCK_SIZE)``. Program ``pid``
    handles token ``pid // num_dim_blocks`` and column slice
    ``pid % num_dim_blocks``. The launch must use
    ``num_tokens * num_dim_blocks`` programs.

    Layout requirements:

    - ``hidden_states`` and ``output`` rows are addressed as
      ``token * hidden_dim``, so both must be packed row-major.
    - Scale ``i`` of token ``t`` is read at ``t * scales_stride + i``.
    """
    pid = tl.program_id(0)

    # Round up so the trailing partial slice is covered when hidden_dim is not
    # a multiple of BLOCK_SIZE (the mask below trims it).
    num_dim_blocks = (hidden_dim + BLOCK_SIZE - 1) // BLOCK_SIZE
    batch_id = pid // num_dim_blocks
    dim_offset = (pid % num_dim_blocks) * BLOCK_SIZE

    if batch_id >= num_tokens or dim_offset >= hidden_dim:
        return

    offs = dim_offset + tl.arange(0, BLOCK_SIZE)
    mask = offs < hidden_dim
    row_start = batch_id * hidden_dim

    h = tl.load(hidden_states_ptr + row_start + offs, mask=mask, other=0.0)

    # Accumulate in fp32; the store casts to the output element type.
    result = tl.zeros([BLOCK_SIZE], dtype=tl.float32)
    for i in range(top_k):
        scale = tl.load(expert_scales_ptr + batch_id * scales_stride + i)
        result += h * scale

    tl.store(output_ptr + row_start + offs, result, mask=mask)


def zero_experts_compute_triton(
    expert_indices, expert_scales, num_experts, zero_expert_type, hidden_states
):
    """Compute the "zero expert" output and remove zero experts from routing.

    Expert ids ``>= num_experts`` in ``expert_indices`` denote zero experts.
    With ``zero_expert_type == "identity"``, each token's contribution is its
    hidden state scaled by the sum of the weights routed to zero experts.

    Side effects: ``expert_indices`` and ``expert_scales`` are modified in
    place. Zero-expert slots get index ``-1`` and weight ``0.0``, so the
    regular MoE path can skip them.

    Args:
        expert_indices: ``[num_tokens, top_k]`` routed expert ids (mutated).
        expert_scales: routing weights matching ``expert_indices`` (mutated).
        num_experts: number of real experts; larger ids are zero experts.
        zero_expert_type: only ``"identity"`` is supported.
        hidden_states: ``[num_tokens, hidden_dim]`` activations.

    Returns:
        A tensor shaped like ``hidden_states`` holding the zero-expert output.

    Raises:
        ValueError: if ``zero_expert_type`` is not ``"identity"``. This is
            checked before any input is modified.
    """
    if zero_expert_type != "identity":
        raise ValueError(
            f"Unsupported zero_expert_type {zero_expert_type!r}; "
            "only 'identity' is implemented"
        )

    top_k = expert_indices.size(-1)
    is_zero_expert = expert_indices >= num_experts

    # Keep only the zero-expert weights. Copy before the in-place masking below
    # overwrites them, and make the copy packed because the kernel indexes it
    # as `token * stride(0) + i`.
    zero_expert_scales = expert_scales.clone(memory_format=torch.contiguous_format)
    zero_expert_scales[~is_zero_expert] = 0.0

    # Hide zero experts from the regular MoE computation.
    expert_indices[is_zero_expert] = -1
    expert_scales[is_zero_expert] = 0.0

    # The kernel addresses rows as token * hidden_dim, so it needs packed rows.
    hidden_states = hidden_states.contiguous()
    output = torch.zeros_like(hidden_states)
    hidden_dim = hidden_states.size(-1)
    num_tokens = hidden_states.size(0)

    # Must match the kernel's cdiv-based program mapping so the tail columns
    # (hidden_dim % BLOCK_SIZE) are also computed.
    grid = lambda meta: (
        num_tokens * triton.cdiv(hidden_dim, meta["BLOCK_SIZE"]),
    )
    compute_identity_kernel[grid](
        top_k,
        hidden_states,
        zero_expert_scales,
        num_tokens,
        output,
        hidden_dim,
        zero_expert_scales.stride(0),
        BLOCK_SIZE=256,
    )

    return output


@triton.jit
def compute_problem_sizes_w4a8_kernel(
    masked_m_ptr,
    problem_sizes1_ptr,
    problem_sizes2_ptr,
    n,
    k,
    num_experts,
    BLOCK_SIZE: tl.constexpr,
):
    """Write one triple per expert into each of two flat problem-size buffers.

    For every expert ``e`` in this program's ``BLOCK_SIZE`` range:
        problem_sizes1[3e : 3e + 3] = (2 * n, masked_m[e], k)
        problem_sizes2[3e : 3e + 3] = (k, masked_m[e], n)
    Experts at or beyond ``num_experts`` are masked out.
    """
    experts = tl.program_id(axis=0) * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    in_range = experts < num_experts
    tokens = tl.load(masked_m_ptr + experts, mask=in_range, other=0)

    # Triple j of expert e lives at 3 * e + j. For j in {0, 1, 2},
    # `3 * e + j < 3 * num_experts` is equivalent to `e < num_experts`, so a
    # single mask guards every store.
    row1 = problem_sizes1_ptr + experts * 3
    row2 = problem_sizes2_ptr + experts * 3

    tl.store(row1, 2 * n, mask=in_range)
    tl.store(row1 + 1, tokens, mask=in_range)
    tl.store(row1 + 2, k, mask=in_range)

    tl.store(row2, k, mask=in_range)
    tl.store(row2 + 1, tokens, mask=in_range)
    tl.store(row2 + 2, n, mask=in_range)


def compute_problem_sizes_w4a8(
    masked_m, problem_sizes1, problem_sizes2, n, k, num_experts
):
    """Fill ``problem_sizes1`` / ``problem_sizes2`` in place for all experts.

    This launches ``compute_problem_sizes_w4a8_kernel`` with 256 experts per
    program and returns the two (same) tensors for convenience.
    """
    experts_per_program = 256
    launch_grid = (triton.cdiv(num_experts, experts_per_program),)
    compute_problem_sizes_w4a8_kernel[launch_grid](
        masked_m,
        problem_sizes1,
        problem_sizes2,
        n,
        k,
        num_experts,
        BLOCK_SIZE=experts_per_program,
    )
    return problem_sizes1, problem_sizes2


def deepep_ll_get_cutlass_w4a8_moe_mm_data(
    masked_m,
    problem_sizes1,
    problem_sizes2,
    num_experts,
    n,
    k,
):
    """Fill both problem-size buffers and return them as int32 tensors.

    Note the argument order: ``num_experts`` precedes ``n, k`` here, unlike
    ``compute_problem_sizes_w4a8``. ``.to(torch.int32)`` returns the same
    tensor when it is already int32 and a converted copy otherwise.
    """
    filled = compute_problem_sizes_w4a8(
        masked_m, problem_sizes1, problem_sizes2, n, k, num_experts
    )
    return tuple(sizes.to(torch.int32) for sizes in filled)


@triton.jit
def _silu_and_mul_post_per_tensor_quant_kernel(
    input_ptr,
    stride_input_expert,
    stride_input_token,
    stride_input_dim,
    output_ptr,
    stride_output_expert,
    stride_output_token,
    stride_output_dim,
    scale_ptr,
    stride_scale_expert,
    masked_m_ptr,
    inner_dim,
    fp8_max,
    fp8_min,
    BLOCK_N: tl.constexpr,
    NUM_STAGE: tl.constexpr,
):
    """
    Triton kernel: fused SiLU(gate) * up + FP8 quantization.

    Shape:
        input:  [E, T_padded, 2*D]  -> gate: [:,:,D], up: [:,:,D]
        output: [E, T_padded, D], dtype=float8_e4m3fn

    Grid: (cdiv(D, BLOCK_N), num_token_programs, E). Program (d, t, e) handles
    dim slice d of tokens t, t + num_token_programs, ... (below masked_m[e])
    of expert e. Padded token rows at or beyond masked_m[e] are not written.

    Expert e is quantized with 1 / scale_ptr[e * stride_scale_expert].
    Pass stride 0 to share one scale across all experts.
    The last dimension of input/output is assumed contiguous
    (stride_input_dim / stride_output_dim are unused).
    """
    expert_id = tl.program_id(2)
    block_id_token = tl.program_id(1)
    block_id_dim = tl.program_id(0)

    num_token_blocks = tl.num_programs(1)

    token_num_cur_expert = tl.load(masked_m_ptr + expert_id)

    # Per-expert scale lookup; stride 0 broadcasts a single per-tensor scale.
    scale = 1.0 / tl.load(scale_ptr + expert_id * stride_scale_expert).to(tl.float32)

    # Token offsets within one expert are small, so keep that loop math 32-bit.
    stride_input_token = tl.cast(stride_input_token, tl.int32)
    stride_output_token = tl.cast(stride_output_token, tl.int32)

    offset_d = block_id_dim * BLOCK_N + tl.arange(0, BLOCK_N)
    mask_d = offset_d < inner_dim

    # Expert base offsets in 64-bit: expert_id * stride_expert (~E * T_padded * D)
    # can exceed the int32 range for large buffers.
    expert_id_64 = expert_id.to(tl.int64)
    input_base_offs = input_ptr + expert_id_64 * stride_input_expert + offset_d
    output_base_offs = output_ptr + expert_id_64 * stride_output_expert + offset_d

    for token_idx in tl.range(
        block_id_token, token_num_cur_expert, num_token_blocks, num_stages=NUM_STAGE
    ):
        gate_ptr = input_base_offs + token_idx * stride_input_token
        up_ptr = gate_ptr + inner_dim
        gate = tl.load(gate_ptr, mask=mask_d, other=0.0).to(tl.float32)
        up = tl.load(up_ptr, mask=mask_d, other=0.0).to(tl.float32)

        # SiLU: x * sigmoid(x), rounded back to the input dtype before the multiply.
        gate = gate / (1 + tl.exp(-gate))
        gate = gate.to(input_ptr.dtype.element_ty)
        gate_up = up * gate

        scaled = gate_up * scale
        output_q = tl.clamp(scaled, fp8_min, fp8_max).to(output_ptr.dtype.element_ty)
        out_ptr = output_base_offs + token_idx * stride_output_token
        tl.store(out_ptr, output_q, mask=mask_d)


def silu_and_mul_masked_post_per_tensor_quant_fwd(
    input: torch.Tensor,
    output: torch.Tensor,
    masked_m: torch.Tensor,
    scale: torch.Tensor,
) -> torch.Tensor:
    """
    Fused SiLU + Mul + Per-Tensor (or per-expert) Quantization to FP8.

    Args:
        input: [expert_num, token_num_padded, 2 * inner_dim]
        output: [expert_num, token_num_padded, inner_dim], dtype=torch.float8_e4m3fn
        masked_m: [expert_num], actual token count for each expert
        scale: [1] or [expert_num], quantization scale (per-tensor or per-expert).
            Each expert's values are divided by its scale before clamping to the
            FP8 range.

    Returns:
        output tensor (written in place; padded rows beyond masked_m are untouched)
    """
    assert input.is_contiguous()
    assert output.is_contiguous()
    assert output.dtype == torch.float8_e4m3fn
    assert input.ndim == 3
    assert input.shape[0] == masked_m.shape[0]
    assert input.shape[-1] % 2 == 0
    assert scale.numel() == 1 or scale.shape[0] == input.shape[0]

    expert_num = input.shape[0]
    inner_dim = input.shape[-1] // 2

    # A single-element scale (including a 0-d tensor) is shared via stride 0;
    # otherwise expert e reads scale[e].
    scale_stride = 0 if scale.numel() == 1 else scale.stride(0)

    BLOCK_N = 256
    # Programs along the token axis; each strides over its expert's tokens.
    num_token_programs = 64 if expert_num < 4 else 32
    NUM_STAGES = 3
    hidden_dim_split_block_num = triton.cdiv(inner_dim, BLOCK_N)

    grid = (hidden_dim_split_block_num, num_token_programs, expert_num)
    fp8_max = torch.finfo(torch.float8_e4m3fn).max
    fp8_min = -fp8_max

    _silu_and_mul_post_per_tensor_quant_kernel[grid](
        input,
        *input.stride(),
        output,
        *output.stride(),
        scale,
        scale_stride,
        masked_m,
        inner_dim,
        fp8_max,
        fp8_min,
        BLOCK_N=BLOCK_N,
        NUM_STAGE=NUM_STAGES,
    )
    return output
