o
    i                     @   s  d Z ddlZddlZddlZddlmZmZ ddlZddlm	Z	 ddl
Z
ddlmZmZmZmZmZ ddlmZmZ ddlmZ dZdZd	Zd
Zejd	dddee
jB eB dB defddZedddde	jdedefddZeddddede	jde	jdeddf
ddZ edddde	j!de	jfddZ"eddddedeeeeef fddZ#eddddedefd d!Z$edddd"e	j!d#edefd$d%Z%edddd&edefd'd(Z&edddd&ed)edefd*d+Z'edddd&ed)edefd,d-Z(edddd&edefd.d/Z)edddd&ed)edefd0d1Z*edddd&ed)edefd2d3Z+eddddedefd4d5Z,edddd&ed)edefd6d7Z-eddddedefd8d9Z.edddd:ed;edeeef fd<d=Z/edddd&ed)edefd>d?Z0edddd&ed)edefd@dAZ1eddddedefdBdCZ2edddd&ed)edefdDdEZ3eddddedefdFdGZ4eddddHed;edeeef fdIdJZ5edddd&edefdKdLZ6eddddMedefdNdOZ7eddddPedefdQdRZ8eddddSedefdTdUZ9eddddVedWedXedYedZed[ed\ed]edefd^d_Z:e	j;ddaej<e fdbdcZ=e	j;deddedee	j!dfedef
dgdhZ>e	j;deddedee	j!de	jdiej<e dfedefdjdkZ?e	j;de	j@dde	jAdlej<e dee	j!diej<e dfefdmdnZBe	j;doe	j!dpede	j!fdqdrZCe	j;dse	j!dte	j!duedvedwef
dxdyZDe	j;dze	j!d{e	j!de	j!fd|d}ZEe	j;dze	j!d{e	j!de	j!fd~dZFe	j;de	j!defddZGe	j;de	j!defddZHe	j;de	j!d;ede	j!fddZIe	j;de	j!d;ede	j!fddZJe	j;de	j!dedefddZKe	j;de	j!dedvede	j!fddZLe	j;de	j!de	j!dedee	j!ef fddZMdS )a  
Copyright (c) 2025 by FlashInfer team.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

  http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.

Common utilities for FP4 quantization kernels using CuTe-DSL.

This module contains shared PTX intrinsics, helper functions, and reduction
utilities used by both rmsnorm_fp4quant.py and add_rmsnorm_fp4quant.py.
    N)CallableTuple)Float32Int32Int64Uint32Uint64)Tdsl_user_op)llvmg      @g      |@      )maxsizedevicereturnc                 C   s<   t j sdS | du rt j } t j| }|jd |j S )a  Get the SM version of a CUDA device.

    Args:
        device: CUDA device to query. Can be an int (device index), torch.device,
            device string (e.g., 'cuda:0'), or None to use current device.

    Returns:
        SM version as an integer (e.g., 100 for SM100).
    P   N
   )torchcudais_availablecurrent_deviceget_device_propertiesmajorminor)r   props r   T/home/ubuntu/vllm_env/lib/python3.10/site-packages/flashinfer/cute_dsl/fp4_common.pyget_sm_version2   s   

r   locipsmem_ptrpeer_cta_rank_in_clusterc             
   C   s>   | j ||d }ttjt || gddddtjjdS )z?Map smem pointer to address at another CTA rank in the cluster.r   z$mapa.shared::cluster.u32 $0, $1, $2;=r,r,rFhas_side_effectsis_align_stackasm_dialect)	tointir_valuer   r   
inline_asmr	   i32
AsmDialectAD_ATT)r!   r"   r   r    smem_ptr_i32r   r   r   set_block_rankJ   s   
r/   valmbar_ptrc             	   C   sX   t ||||d }t ||||d }tjd|| j||d|gddddtjjd dS )zDStore Float32 value to shared memory on a remote CTA in the cluster.r   NzIst.async.shared::cluster.mbarrier::complete_tx::bytes.f32 [$0], $1, [$2];zr,f,rTFr$   )r/   r)   r   r*   r,   r-   )r0   r!   r1   r"   r   r    remote_smem_ptr_i32remote_mbar_ptr_i32r   r   r   store_shared_remote]   s&   
r4   xc                C   s   | j tj|| j||d S )z/Get pointer to element at coordinate in tensor.r   )iteratorcutecrd2idxlayout)r5   coordr   r    r   r   r   elem_pointery   s   r;   base_ptrc                C   s   t jt jt t t t gt| j||dgddddt jj	||d	}t j
t |dg||d}t j
t |dg||d}t j
t |dg||d}t j
t |d	g||d}t|t|t|t|fS )
z.Load 128 bits (4 x uint32) from global memory.r   z(ld.global.v4.u32 {$0, $1, $2, $3}, [$4];z=r,=r,=r,=r,lFr%   r&   r'   r   r    r            )r   r*   
StructTypeget_literalr	   r+   r   r)   r,   r-   extractvaluer   )r<   r   r    resultv0v1v2v3r   r   r   ld_global_v4_u32   s    "rI   valuec             	   C   s@   t jdt| j||dt|j||dgddddt jjd dS )zStore 64 bits to global memory.Nr   zst.global.u64 [$0], $1;zl,lTFr$   )r   r*   r   r)   r   r,   r-   )r<   rJ   r   r    r   r   r   st_global_u64   s   
rK   tensoroffsetc                C   s.   | j t| }tjt |j||d}t|S )z2Get the memory address of tensor[offset] as Int64.r   )r6   r   r   ptrtointr	   i64llvm_ptrr   )rL   rM   r   r    elem_ptrptr_intr   r   r   get_ptr_as_int64   s   rS   ac             
   C   4   t tjt t | j||dgddddtjjdS )z-Fast reciprocal using PTX rcp.approx.ftz.f32.r   zrcp.approx.ftz.f32 $0, $1;=f,fFr$   r   r   r*   r	   f32r)   r,   r-   rT   r   r    r   r   r   rcp_approx_ftz      rZ   bc             
   C   D   t tjt t | j||dt |j||dgddddtjjdS )z4Compute min of two float32 values using PTX min.f32.r   zmin.f32 $0, $1, $2;=f,f,fFr$   rW   rT   r\   r   r    r   r   r   fmin_f32      "r`   c             
   C   r]   )z4Compute max of two float32 values using PTX max.f32.r   zmax.f32 $0, $1, $2;r^   Fr$   rW   r_   r   r   r   fmax_f32   ra   rb   c             
   C   rU   )z4Compute absolute value of float32 using PTX abs.f32.r   zabs.f32 $0, $1;rV   Fr$   rW   rY   r   r   r   fabs_f32   r[   rc   c             
   C   r]   )z;Multiply two Half2 values element-wise: (a.x*b.x, a.y*b.y).r   zmul.f16x2 $0, $1, $2;r#   Fr$   r   r   r*   r	   r+   r)   r,   r-   r_   r   r   r   	half2_mul   ra   re   c             
   C   r]   )z6Add two Half2 values element-wise: (a.x+b.x, a.y+b.y).r   zadd.f16x2 $0, $1, $2;r#   Fr$   rd   r_   r   r   r   hadd2  ra   rf   c             
   C   rU   )z<Half2 absolute value - clears sign bits of both fp16 values.r   and.b32 $0, $1, 0x7FFF7FFF;=r,rFr$   rd   r5   r   r    r   r   r   habs2   r[   rj   c             
   C   r]   )z-Half2 max - element-wise max of 2 fp16 pairs.r   zmax.f16x2 $0, $1, $2;r#   Fr$   rd   r_   r   r   r   hmax20  ra   rk   c             
   C   4   t tjt t| j||dgddddtjjdS )z1Extract max of 2 fp16 values in half2 as float32.r   z
            {
                .reg .b16 h0, h1;
                .reg .f32 f0, f1;
                mov.b32 {h0, h1}, $1;
                cvt.f32.f16 f0, h0;
                cvt.f32.f16 f1, h1;
                max.f32 $0, f0, f1;
            }
            =f,rFr$   	r   r   r*   r	   rX   r   r)   r,   r-   ri   r   r   r   hmax_to_f32@  s   
ro   h2scalec                C      t jt jt t gt| j||dt|j||dgddddt j	j
||d	}t jt |dg||d}t jt |dg||d}t|t|fS )z.Convert half2 to float2 AND multiply by scale.r   z
        {
            .reg .b16 h0, h1;
            .reg .f32 f0, f1;
            mov.b32 {h0, h1}, $2;
            cvt.f32.f16 f0, h0;
            cvt.f32.f16 f1, h1;
            mul.f32 $0, f0, $3;
            mul.f32 $1, f1, $3;
        }
        	=f,=f,r,fFr=   r   r>   r   r*   rA   rB   r	   rX   r   r)   r   r,   r-   rC   )rp   rq   r   r    rD   f0f1r   r   r   half2_to_float2_scaledY  s   "rw   c             
   C   r]   )z=Multiply two BFloat2 values element-wise: (a.x*b.x, a.y*b.y).r   zmul.bf16x2 $0, $1, $2;r#   Fr$   rd   r_   r   r   r   bfloat2_mul  ra   rx   c             
   C   r]   )z8Add two BFloat2 values element-wise: (a.x+b.x, a.y+b.y).r   zadd.bf16x2 $0, $1, $2;r#   Fr$   rd   r_   r   r   r   bfloat2_add  ra   ry   c             
   C   rU   )zABFloat16x2 absolute value - clears sign bits of both bf16 values.r   rg   rh   Fr$   rd   ri   r   r   r   bfloat2_habs2  r[   rz   c             
   C   r]   )z2BFloat16x2 max - element-wise max of 2 bf16 pairs.r   zmax.bf16x2 $0, $1, $2;r#   Fr$   rd   r_   r   r   r   bfloat2_hmax2  ra   r{   c             
   C   rl   )z3Extract max of 2 bf16 values in bfloat2 as float32.r   ae  
            {
                .reg .b32 lo, hi;
                .reg .f32 f0, f1;
                and.b32 lo, $1, 0xFFFF;
                shr.b32 hi, $1, 16;
                shl.b32 lo, lo, 16;
                shl.b32 hi, hi, 16;
                mov.b32 f0, lo;
                mov.b32 f1, hi;
                max.f32 $0, f0, f1;
            }
            rm   Fr$   rn   ri   r   r   r   bfloat2_hmax_to_f32  s   r|   bf2c                C   rr   )z3Convert bfloat16x2 to float2 AND multiply by scale.r   aU  
        {
            .reg .b32 lo, hi;
            .reg .f32 f0, f1;
            and.b32 lo, $2, 0xFFFF;
            shr.b32 hi, $2, 16;
            shl.b32 lo, lo, 16;
            shl.b32 hi, hi, 16;
            mov.b32 f0, lo;
            mov.b32 f1, hi;
            mul.f32 $0, f0, $3;
            mul.f32 $1, f1, $3;
        }
        rs   Fr=   r   r>   rt   )r}   rq   r   r    rD   ru   rv   r   r   r   bfloat2_to_float2_scaled  s   "r~   c             
   C   rl   )zAConvert float32 to E4M3 using native cvt.rn.satfinite.e4m3x2.f32.r   a  
            {
                .reg .b16 fp8_pair;
                .reg .f32 zero;
                mov.f32 zero, 0f00000000;
                cvt.rn.satfinite.e4m3x2.f32 fp8_pair, zero, $1;
                cvt.u32.u16 $0, fp8_pair;
            }
            =r,fFr$   	r   r   r*   r	   r+   r   r)   r,   r-   rY   r   r   r   cvt_f32_to_e4m3  s   	r   fp8_valc             
   C   rl   )z3Convert FP8 E4M3 to float32 AND compute reciprocal.r   a  
            {
                .reg .pred p_zero;
                .reg .u32 exp_u, mant_u;
                .reg .s32 exp_s;
                .reg .f32 exp_f, mant_f, fp8_float, result;

                setp.eq.u32 p_zero, $1, 0;
                and.b32 mant_u, $1, 7;
                shr.b32 exp_u, $1, 3;
                and.b32 exp_u, exp_u, 15;
                sub.s32 exp_s, exp_u, 7;
                cvt.rn.f32.s32 exp_f, exp_s;
                ex2.approx.f32 exp_f, exp_f;
                cvt.rn.f32.u32 mant_f, mant_u;
                fma.rn.f32 mant_f, mant_f, 0f3E000000, 0f3F800000;
                mul.f32 fp8_float, exp_f, mant_f;
                rcp.approx.ftz.f32 result, fp8_float;
                selp.f32 $0, 0f00000000, result, p_zero;
            }
            rm   Fr$   rn   )r   r   r    r   r   r   fp8_e4m3_to_f32_and_rcp  s   r   max_valc             
   C   rl   )aS  
    Convert float32 max value to UE8M0 scale factor.

    UE8M0 is unsigned 8-bit exponent-only format:
    - value = 2^(ue8m0 - 127)
    - ue8m0 = ceil(log2(max_val)) + 127

    Uses lg2.approx.f32 for fast log2 approximation.
    Uses cvt.rpi (round towards positive infinity, i.e., ceiling).
    Returns value clamped to [0, 255].
    r   a  
            {
                .reg .pred p_zero, p_neg, p_ovf;
                .reg .f32 log2_val;
                .reg .s32 exp_int, result;

                // Check for zero/negative
                setp.le.f32 p_zero, $1, 0f00000000;

                // Compute ceil(log2(max_val)) using cvt.rpi (round towards +inf)
                lg2.approx.f32 log2_val, $1;
                cvt.rpi.s32.f32 exp_int, log2_val;

                // Add bias and clamp to [0, 255]
                add.s32 result, exp_int, 127;
                setp.lt.s32 p_neg, result, 0;
                setp.gt.s32 p_ovf, result, 255;
                selp.s32 result, 0, result, p_neg;
                selp.s32 result, 255, result, p_ovf;
                selp.s32 $0, 0, result, p_zero;
            }
            r   Fr$   r   )r   r   r    r   r   r   cvt_f32_to_ue8m0E  s   r   	ue8m0_valc             
   C   rl   )z
    Convert UE8M0 to output_scale for MXFP4 quantization.

    UE8M0 value = 2^(ue8m0 - 127)
    Returns 1 / 2^(ue8m0 - 127) = 2^(127 - ue8m0)
    r   a  
            {
                .reg .pred p_zero;
                .reg .s32 neg_exp;
                .reg .f32 neg_exp_f, result;

                // Check for zero
                setp.eq.u32 p_zero, $1, 0;

                // Compute 2^(127 - ue8m0) = 1 / 2^(ue8m0 - 127)
                sub.s32 neg_exp, 127, $1;
                cvt.rn.f32.s32 neg_exp_f, neg_exp;
                ex2.approx.f32 result, neg_exp_f;
                selp.f32 $0, 0f00000000, result, p_zero;
            }
            rm   Fr$   rn   )r   r   r    r   r   r   ue8m0_to_output_scalet  s   r   rE   rF   rG   rH   v4v5v6v7c          
      C   s   t tjt t| j||	dt|j||	dt|j||	dt|j||	dt|j||	dt|j||	dt|j||	dt|j||	dgddddtjjdS )zMConvert eight float32 values to eight E2M1 (4-bit) values packed into uint32.r   a  
            {
                .reg .b8 byte0, byte1, byte2, byte3;
                cvt.rn.satfinite.e2m1x2.f32 byte0, $2, $1;
                cvt.rn.satfinite.e2m1x2.f32 byte1, $4, $3;
                cvt.rn.satfinite.e2m1x2.f32 byte2, $6, $5;
                cvt.rn.satfinite.e2m1x2.f32 byte3, $8, $7;
                mov.b32 $0, {byte0, byte1, byte2, byte3};
            }
            z=r,f,f,f,f,f,f,f,fFr$   r   )
rE   rF   rG   rH   r   r   r   r   r   r    r   r   r   cvt_e2m1x8_f32  s&   

r       widthc                 C   s   t t| tjr0t| j| j}||  t 	t
| jD ]}t|| ||||< q| S t 	tt|D ]}|| tjj| d|> d} q:| S )z8Reduce across threads in a warp using butterfly shuffle.r>   )rM   )cutlass
const_expr
isinstancer7   	TensorSSAmake_rmem_tensorshapedtypestorerange_constexprsizewarp_reduceloadintmathlog2archshuffle_sync_bfly)r0   opr   resir   r   r   r     s   
r   r   reduction_bufferinit_valc           
      C   st   t j }t j }t |jd }|| }|| }|dkr$| |||f< t j  |}	||k r5|||f }	t|	|S )z:Block reduction across multiple warps using shared memory.r>   r   )r7   r   lane_idxwarp_idxr   r   barrierr   )
r0   r   r   r   r   r   warps_per_rowrow_idxcol_idxblock_reduce_valr   r   r   block_reduce  s   



r   	cluster_nc                 C   s(  t j }t j }t j }|jd }	|jd d }
||
 }||
 }|dkrMt j  |	|
 }|| d }t j|| W d   n1 sHw   Y  ||k r`t| t	||||ff||d t jj
|dd |
| }t |d}|}t|D ]}||d  }||k r|||||f }qyt||S )z6Cluster reduction across multiple CTAs using mbarrier.r   r>      N)r"   )phaser   )r7   r   block_idx_in_clusterr   r   r   	elect_onembarrier_arrive_and_expect_txr4   r;   mbarrier_waitceil_divr   r   r   )r0   r   r   r1   r   r   cta_rank_in_clusterr   r   rows_per_blockr   r   r   	num_warpsexpected_bytes	num_totalnum_iterr   r   idxr   r   r   cluster_reduce  s<   





r   threads_per_rowc                 C   s   | j ||dd}tjjtjtjjtjji| }t	|d}	t
|||	d}
t|d d}t|dkp3|dkrLt|dkrCt|
|||S t|
|||||S |
S )z,Row reduction with optional cluster support.r   )r   reduction_profiler   )r   r>   )reducer7   ReductionOpADDoperatoraddMAXr   fmaxminr   maxr   r   r   r   )r5   r   r   r   r1   r   r   	local_valwarp_op
warp_widthwarp_valr   r   r   r   
row_reduce+  s    

r   tXcXlimitc              	   C   s   t t jt j| ddgdt j| dgdt j| dgdft j| dgdddfdtj}t|jd D ]!}t|jd D ]}t | d|fd|f d |||d|f< q=q3|S )z,Create predicate tensor for bounds checking.r   r>   )moder?   )stride)	r7   r   make_layoutr   r   Booleanr   r   	elem_less)r   r   tXpXrest_vrest_kr   r   r   predicate_kQ  s"   r   mXmW
row_offset
col_offsetHc                 C   s   t dt}t dt}t| || | }t| || | td }t|\|d< |d< |d< |d< t|\|d< |d< |d	< |d
< t||}	t||td }
t|	\|d< |d< |d< |d< t|
\|d< |d< |d	< |d
< ||fS )zLoad 16 elements (8 half2 pairs) of X and W from global memory.

    Returns:
        x_h2: rmem_tensor of shape (8,) containing X as half2
        w_h2: rmem_tensor of shape (8,) containing W as half2
       r   r   r>   r?   r@   r            )r7   r   r   rS   r   rI   )r   r   r   r   r   x_h2w_h2x_ptr0x_ptr1w_ptr0w_ptr1r   r   r   load_8_half2l  s   
  
  r   r   r   c                 C   6   t dt}tdD ]}t| | || ||< q|S )z$Multiply 8 half2 pairs element-wise.r   r   )r7   r   r   r   r   re   r   r   xw_h2r   r   r   r   half2_mul_8     r   c                 C   r   )z&Multiply 8 bfloat2 pairs element-wise.r   r   )r7   r   r   r   r   rx   r   r   r   r   bfloat2_mul_8  r   r   r   c           	      C      t dt}tdD ]
}t| | ||< qt|d |d }t|d |d }t|d |d }t|d	 |d
 }t||}t||}t||S )zFCompute max absolute value across 8 half2 values using tree reduction.r   r   r   r>   r?   r@   r   r   r   r   )r7   r   r   r   r   rj   rk   	r   abs_h2r   max_01max_23max_45max_67max_0123max_4567r   r   r   half2_max_abs_8     


r   c           	      C   r   )zHCompute max absolute value across 8 bfloat2 values using tree reduction.r   r   r   r>   r?   r@   r   r   r   r   )r7   r   r   r   r   rz   r{   r   r   r   r   bfloat2_max_abs_8  r   r   c                 C   F   t dt}tdD ]}t| | |\||d < ||d d < q|S )z+Convert 8 half2 to 16 float32 with scaling.r   r   r?   r>   )r7   r   r   r   r   rw   r   rq   y_f32r   r   r   r   half2_to_float16     (r   c                 C   r   )z-Convert 8 bfloat2 to 16 float32 with scaling.r   r   r?   r>   )r7   r   r   r   r   r~   r   r   r   r   bfloat2_to_float16  r   r  r   	inv_scalec              
   C   s   t dt}tdD ]
}| | | ||< qt|d |d |d |d |d |d |d	 |d
 }t|d |d |d |d |d |d |d |d }t|td> t|B S )z7Quantize 16 float32 values to FP4 and pack into uint64.r   r   r   r>   r?   r@   r   r   r   r   r   	   r                  r   )r7   r   r   r   r   r   r   )r   r  qr   	packed_lo	packed_hir   r   r   quantize_and_pack_16  s   66r  sHr   c                 C   s8   t dt}tdD ]}t| ||| f ||< q|S )z*Load 16 Float32 values from shared memory.r   r   )r7   r   r   r   r   )r  r   r   h_f32r   r   r   r   load_f32_16_from_smem  s   r  r  w_f32rstdc                 C   st   t dt}| d | |d  |d< t|d }tddD ]}| | | ||  ||< t|t|| }q||fS )z;Compute y = h * rstd * w and max_abs for 16 Float32 values.r   r   r>   r   )r7   r   r   rc   r   r   rb   )r  r  r  r   max_absr   r   r   r   compute_y_and_max_abs_f32  s   r  )N)r   )N__doc__	functoolsr   r   typingr   r   r   cutlass.cuter7   r   r   r   r   r   r   cutlass.cutlass_dslr	   r
   cutlass._mlir.dialectsr   FLOAT4_E2M1_MAXFLOAT8_E4M3_MAXSF_VEC_SIZE	COPY_BITS	lru_cacher   r   strr   Pointerr/   r4   Tensorr;   rI   rK   rS   rZ   r`   rb   rc   re   rf   rj   rk   ro   rw   rx   ry   rz   r{   r|   r~   r   r   r   r   r   jit	Constexprr   r   r   r   r   r   r   r   r   r   r   r   r   r  r  r  r  r   r   r   r   <module>   s   
$ 
"     
%   
((.(23%
