o
    پi                  !   @   s  d Z ddlZddlmZmZmZ ddlZddlmZ ddl	Z	ddlm
Z
mZmZmZmZ ddlmZ ddlmZmZmZmZmZmZmZmZmZmZmZmZmZmZm Z m!Z!m"Z"m#Z#m$Z$m%Z%m&Z&m'Z'm(Z(m)Z)m*Z*m+Z+m,Z,m-Z-m.Z. G d	d
 d
Z/ej0	d#de1de1de2de1de3de2de2defddZ4e									d$de	j5de	j5de	j5de	j5dB de	j5dB de	j5dB de6de1de3dB de2de2de	j5dB deee	j5e	j5f ee	j5e	j5e	j5f f fd d!Z7g d"Z8dS )%ul  
Copyright (c) 2025 by FlashInfer team.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

  http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.

Fused Add + RMSNorm + FP4 Quantization using CuTe-DSL
======================================================

High-performance fused kernel for element-wise addition followed by RMS normalization
and FP4 quantization. Supports both NVFP4 (block_size=16, E4M3 scales) and MXFP4
(block_size=32, UE8M0 scales) formats.

Operation:
    1. residual = residual + input (in-place update)
    2. output = (residual / sqrt(mean(residual²) + eps)) * weight
    3. quantize output to FP4

The residual tensor is modified in-place to contain the fused value (input + residual).

    N)CallableTupleUnion)Float32Int32Int64Uint32Uint8   )flashinfer_api   )FLOAT4_E2M1_MAXFLOAT8_E4M3_MAX	COPY_BITSget_sm_versionst_global_u64get_ptr_as_int64rcp_approx_ftzfmin_f32fmax_f32hmax2hmax_to_f32bfloat2_hmax2bfloat2_hmax_to_f32cvt_f32_to_e4m3fp8_e4m3_to_f32_and_rcpcvt_f32_to_ue8m0ue8m0_to_output_scale
row_reducepredicate_kload_8_half2half2_mul_8bfloat2_mul_8half2_max_abs_8bfloat2_max_abs_8half2_to_float16bfloat2_to_float16quantize_and_pack_16load_f32_16_from_smemcompute_y_and_max_abs_f32c                   @   sv  e Zd ZdZ			d1dejdedededed	edB d
edB defddZ	e
dedejd	edefddZe
dedefddZe
dedefddZe
dedededefddZe
dededededef
ddZdefd d!Zejd"ejd#ejd$ejd%ejd&ejd'ejd(ejd)ed*efd+d,Zejd"ejd#ejd$ejd%ejd&ejd'ejd(ejd)ed*ed-ejd.ejfd/d0ZdS )2AddRMSNormFP4QuantKernelaA  
    Fused Add + RMSNorm + FP4 Quantization Kernel.

    Computes:
        1. residual = input + residual (in-place update)
        2. y = RMSNorm(residual) * weight
        3. quantize y to FP4

    The residual tensor is modified in-place.
    Supports both NVFP4 (block_size=16) and MXFP4 (block_size=32) formats.
    NFdtypeH
block_sizeoutput_swizzledis_fp16
sm_versionscale_formatoutput_both_sf_layoutsc	                 C   sb  || _ || _|| _|| _|| _|d ur|nt | _|| _|d u r*|dkr&dnd| _n|| _|dv s8J d| | jdv sAJ d| 	||| j| _
|| j
 | _| | j| _| | j| _| j| j | _t| jd d| _|jd	 }	td	 |	 | _td| j| j | j d | j | _| j| j | j | _|| | _|s|r|| }
|
d
 d | _d| _d S d S )N    ue8m0e4m3   r3   z!block_size must be 16 or 32, got )r5   r4   z&scale_format must be 'e4m3' or 'ue8m0'r               )r+   r,   r-   r.   r/   r   r0   r2   r1   _compute_cluster_n	cluster_n	H_per_cta_compute_threads_per_rowthreads_per_row_compute_num_threadsnum_threadsrows_per_blockmaxwarps_per_rowwidthr   vec_sizenum_vec_blockscols_per_tilenum_sf_blocks_per_rownum_k_tilesk_tile_stride)selfr+   r,   r-   r.   r/   r0   r1   r2   
elem_bytesnum_col_vecs rP   \/home/ubuntu/.local/lib/python3.10/site-packages/flashinfer/cute_dsl/add_rmsnorm_fp4quant.py__init__f   sF   


z!AddRMSNormFP4QuantKernel.__init__returnc                 C   sh   |dk rdS t jt j }|j}|jd }dD ]}| | dkr"qt| ||}||kr1|  S qdS )a'  Compute optimal cluster size based on H and device shared memory.

        Dynamically determines the minimum cluster_n that fits within the
        device's shared memory limit, making it compatible with different
        GPU architectures (e.g., SM100 with 228KB vs SM120 with 128KB).
        Z   r   r8   )r   r
   r:   r8   r7   r   r7   )torchcudaget_device_propertiescurrent_deviceshared_memory_per_block_optinrF   r*   _estimate_smem_bytes)r,   r+   r0   propsmax_smem_bytes	elem_sizer=   smem_neededrP   rP   rQ   r<      s   
z+AddRMSNormFP4QuantKernel._compute_cluster_nr>   c                 C   s@   | dkrdS | dkrdS | dkrdS | dkrdS | dkrdS d	S )
z Compute optimal threads per row.@   r8      r7   i   r3   i    @     rP   r>   rP   rP   rQ   r?      s   z1AddRMSNormFP4QuantKernel._compute_threads_per_rowc                 C   s   | dkrdS dS )z Compute total threads per block.ra   r`   rb   rP   rc   rP   rP   rQ   rA      s   z-AddRMSNormFP4QuantKernel._compute_num_threadsr=   r]   c                 C   s   | | }t |}t |}|| }t|d d}td | }td|| | d | }	||	 | }
||
 | }|dkrFd| || d  S d| || | d  d S )zEstimate shared memory bytes needed for given configuration.

        This is used to dynamically determine cluster_n based on device
        shared memory limits.
        r3   r   r8   r:   r
   )r*   r?   rA   rD   r   )r,   r=   r]   r>   r@   rB   rC   rE   rG   rH   rI   
tile_bytesrP   rP   rQ   rZ      s   

z-AddRMSNormFP4QuantKernel._estimate_smem_bytesr@   rC   rG   rH   c                 C   s4   | |f||ff}|| df||| |  ff}||fS )zBCreate Thread-Value layout for coalesced vectorized memory access.r   rP   )r@   rC   rG   rH   shapestriderP   rP   rQ   _make_tv_layout   s   	
z(AddRMSNormFP4QuantKernel._make_tv_layoutc                 C   s   | j jd }| j| j | }| j| j | }| jdkr4| j| j | }| j| j | }| j| j d }nd}d}| j| j | j d }| jdkrJdnd}|| | | | | S )z$Calculate shared memory requirement.r8   r   r:   r   )r+   rF   rC   rI   r=   rE   )rM   r]   x_tile_bytesr_tile_bytesw_tile_bytesh_tile_bytesreduction_bytes
mbar_bytesrP   rP   rQ   _smem_size_in_bytes   s0   
z,AddRMSNormFP4QuantKernel._smem_size_in_bytesmXmRmWmYmSmS_unswizzledmGlobalScaleMepsc                 C   s   |  | j| j| j| j\}}tj||d}| j| jf}| |||||||||	||j	t
|| j| jdg| jddgt| jdkrGd| jdgnd|  |
d dS )aI  Host function to launch the kernel.

        Takes tensors directly via TVM-FFI.
        - mX: Input tensor, shape (M, H), row-major (read-only)
        - mR: Residual tensor, shape (M, H), row-major (modified in-place to input + residual)
        - mW: Weight tensor, shape (H,)
        - mY: Output FP4 tensor, shape (M, H // 2), row-major (packed)
        - mS: Scale factor tensor, shape depends on swizzle mode
        - mS_unswizzled: Unswizzled scale factor tensor (used when output_both_sf_layouts=True)
        - mGlobalScale: Global scale tensor, shape (1,), float32
        rf   r   N)gridblockclustersmemstream)rg   r@   rC   rG   rH   cutemake_layoutrI   kernellaunchceil_divr=   rB   cutlass
const_exprrn   )rM   ro   rp   rq   rr   rs   rt   ru   rv   rw   r}   tv_shape	tv_stride	tv_layouttiler_mnrP   rP   rQ   __call__  s(   

z!AddRMSNormFP4QuantKernel.__call__r   r   c           |   	   C   s  t j \}}}t j \}}}| j}| j}| j}| j}| j}t	
|dkr.t j d }nt	
d}|
jd d }t|d d}|d }|| }|| }ttt}t	j }|j|jt j|dddd}|j|jt j|dddd}t	
|dkr|j|jt j|dddd}|j|jt j|dddd}t	
|dkr|jtt ||fdd} d	}!n|jtt |||ffdd} |jtdd
}!t	
|dkr|dkrt j|!d t j  t j  t j  t |j}"t ||||f}#t ||||f}$t |"|||f}%t	
|dkr-t |jt j|d fdd}&t  |j!|&}'t |'|d|f}(t j"t j#j$% |jt&d})t j"t j#' |jt&d}*t (|)|
|}+t (|*|
|},|+)|}-|+)|}.|,)|}/|-*|#}0|-+|}1|.*|$}2|.+|}3|-*|%}4|/+|$}5t	
|dkr|+)|}6|6*|(}7|6+|}8|-+|}9t ,|0}:t ,|2};t ,|5}<t-|4|d}=|4d }>|>d |k }?|?rt j.|)|0|1|=d t j.|)|2|3|=d t	
|dkrt j.|)|7|8|=d t j/  t j0d t 1|1|: t 1|3|; |:2 3t}@|;2 3t}A|@|A }B|B|B }C|<4|B3| j5 t	
|dkr+|B3|j}D|94|D t6|Ct j7j8|| |!|td}E|E| }Ft j9j:|F|	 dd}G|d }Ht	
|dkr^t j  t j  nt j;  |?rot j.|*|<|5|=d || | }I|I|k ru|| d | }Jt<|JD ]}K||K|  }L|L|k rs|L| }Mt	
|dkrt	
|dkrt=|||M}Nt=|||M}Ot>|N|O|G\}P}Q|H|Q | }Rt?|Rtt@}RtA|R}StB|StCd@ }TtD|S|H }Ut	
| jEr,|LtFd }V|ItFd tFd }W|ItFd }X|LtFd }Y|ItFd }Z| jG| jH }[|Z|[ |Y| jH  |XtFd  |WtFd  |V }\|T||\< |T||I|Lf< nQt	
| jIrw|LtFd }V|ItFd tFd }W|ItFd }X|LtFd }Y|ItFd }Z| jG| jH }[|Z|[ |Y| jH  |XtFd  |WtFd  |V }\|T||\< n|T||I|Lf< tJ|P|U}]|Md }^tK||I|d  |^ }_tL|_|] qtM|||I|M|\}`}at	
|rtN|`|a}btO|b}ctP|c}dtQ|b|G}PntR|`|a}btS|b}ctT|c}dtU|b|G}P|d|G }Q|H|Q | }Rt?|Rtt@}RtA|R}StB|StCd@ }TtD|S|H }Ut	
| jErA|LtFd }V|ItFd tFd }W|ItFd }X|LtFd }Y|ItFd }Z| jG| jH }[|Z|[ |Y| jH  |XtFd  |WtFd  |V }\|T||\< |T||I|Lf< nQt	
| jIr|LtFd }V|ItFd tFd }W|ItFd }X|LtFd }Y|ItFd }Z| jG| jH }[|Z|[ |Y| jH  |XtFd  |WtFd  |V }\|T||\< n|T||I|Lf< tJ|P|U}]|Md }^tK||I|d  |^ }_tL|_|] qt	
|dkrt=|||M}et=|||M}ft>|e|f|G\}g}ht=|||MtFd }it=|||MtFd }jt>|i|j|G\}k}ltV|h|l}Qt	
| jWdkr|Q| }RtX|R}mtB|mtCd@ }ntY|m}Un|H|Q | }Rt?|Rtt@}RtA|R}StB|StCd@ }ntD|S|H }Ut	
| jErx|LtFd }V|ItFd tFd }W|ItFd }X|LtFd }Y|ItFd }Z| jG| jH }[|Z|[ |Y| jH  |XtFd  |WtFd  |V }\|n||\< |n||I|Lf< nQt	
| jIr|LtFd }V|ItFd tFd }W|ItFd }X|LtFd }Y|ItFd }Z| jG| jH }[|Z|[ |Y| jH  |XtFd  |WtFd  |V }\|n||\< n|n||I|Lf< tJ|g|U}otJ|k|U}p|I|d  |L|d   }qtK||q}rtK||qtFd }stL|r|o tL|s|p qtM|||I|M|\}t}utM|||I|MtFd |\}v}wt	
|r=tN|t|u}xtN|v|w}ytO|x}ztO|y}{tZ|z|{}ctP|c}dtQ|x|G}gtQ|y|G}kn%tR|t|u}xtR|v|w}ytS|x}ztS|y}{t[|z|{}ctT|c}dtU|x|G}gtU|y|G}k|d|G }Qt	
| jWdkr|Q| }RtX|R}mtB|mtCd@ }ntY|m}Un|H|Q | }Rt?|Rtt@}RtA|R}StB|StCd@ }ntD|S|H }Ut	
| jEr|LtFd }V|ItFd tFd }W|ItFd }X|LtFd }Y|ItFd }Z| jG| jH }[|Z|[ |Y| jH  |XtFd  |WtFd  |V }\|n||\< |n||I|Lf< nQt	
| jIr?|LtFd }V|ItFd tFd }W|ItFd }X|LtFd }Y|ItFd }Z| jG| jH }[|Z|[ |Y| jH  |XtFd  |WtFd  |V }\|n||\< n|n||I|Lf< tJ|g|U}otJ|k|U}p|I|d  |L|d   }qtK||q}rtK||qtFd }stL|r|o tL|s|p qd	S d	S )a  Device kernel with cluster sync and Half2 SIMD.

        Performs:
        1. h = input + residual (writes h back to mR in-place)
        2. y = h * rstd * w / global_scale = rmsnorm(h, w) / global_scale
        3. quantizes y to FP4

        mGlobalScale contains the global scale value. The kernel reads it and
        computes 1/global_scale, which is multiplied with rstd to apply:
        y = h * rstd * w / global_scale = rmsnorm(h, w) / global_scale
        r   r   r3   r   r   )orderr7   )byte_alignmentr:   N)	num_elems)r   rx   )num_bits_per_copy)limit))r   r   r   r   )predg        T)fastmath   r`   r
   r4   r8   )\r~   arch
thread_idx	block_idxr,   r-   rJ   r/   r=   r   r   re   rD   r   r   r   utilsSmemAllocatorallocate_tensorelement_typemake_ordered_layoutr   allocate_arrayr   mbarrier_initmbarrier_init_fencecluster_arrive_relaxedcluster_waitmake_identity_tensor
local_tileprependlayoutmake_tensoriteratormake_copy_atomnvgpucpasync	CopyG2SOpr   CopyUniversalOpmake_tiled_copy	get_slicepartition_Spartition_Dmake_fragment_liker   copycp_async_commit_groupcp_async_wait_groupautovec_copyloadtostorer+   r   ReductionOpADDmathrsqrtbarrierranger(   r)   r   r   r   r	   r   r   r2   r   rK   rL   r.   r'   r   r   r    r!   r#   r   r%   r"   r$   r   r&   r   r1   r   r   r   r   )|rM   ro   rp   rq   rr   rs   rt   ru   rv   rw   r   r   tidx_bidxr,   r-   rJ   r/   r=   	cluster_yr@   rE   rC   lane_in_rowrow_in_blockfp4_max_rcpr|   sXsRsWsHreduction_buffermbar_ptridXgXgRcXmW_expanded_layoutmW_2dgWcopy_atom_load_asynccopy_atom_storetiled_copy_loadtiled_copy_store
thr_copy_X
thr_copy_Rthr_copy_R_storetXgXtXsXtRgRtRsRtXcXtRgO
thr_copy_WtWgWtWsWtHsHtXrXtRrRtRrOtXpX	row_coordrow_in_boundsx_valsr_valsh_valsh_sqh_elemsum_sqmean_sqrstdglobal_scale_valactual_row_idxnum_sf_per_threadsf_itersf_idxblock_starth_f32w_f32y_f32max_absscale_floatscale_fp8_u32	scale_fp8	inv_scaleinner_k_idxinner_m_idxouter_m_idx
k_tile_idx
m_tile_idxm_tile_strideswizzled_offsetpacked64
out_offsetout_ptrh_h2w_h2hw_h2max_hwmax_xwh_f32_c0w_f32_c0y_f32_c0
max_abs_c0h_f32_c1w_f32_c1y_f32_c1
max_abs_c1scale_ue8m0scale_u8packed64_c0packed64_c1
fp4_offset	fp4_ptr_0	fp4_ptr_1h_h2_c0w_h2_c0h_h2_c1w_h2_c1hw_h2_c0hw_h2_c1	max_c0_h2	max_c1_h2rP   rP   rQ   r   D  s|  













































































  zAddRMSNormFP4QuantKernel.kernel)NNF)__name__
__module____qualname____doc__r   NumericintboolstrrR   staticmethodr<   r?   rA   rZ   tuplerg   rn   r~   jitTensorr   r   r   r   LayoutShaperP   rP   rP   rQ   r*   Y   s    	
6	
.	
r*   Fhidden_sizer-   r/   r0   r1   is_sf_swizzled_layoutr2   rS   c                    s  |rt jnt j}t|| ||||d}t }	tjj||	| fddd}
tjj||	| fddd}tjj|| fdd}tjjt j|	| d fddd}sMr]t }tjjt j|fdd}ntjjt j|	| | fddd}tjjt j|	| | fddd}tjj	dd}tjjt j
d	d
d}tj||
||||||tdt
d|dd dtjdtjdtjdtjdtjdtjdtjdtdtddf fdd}|S )z
    Get a compiled kernel closure that takes torch.Tensor directly.

    Uses TVM-FFI for efficient tensor passing without manual pointer construction.
    )r+   r,   r-   r.   r/   r0   r1   r2   r   r`   )stride_orderassumed_align)r5  r
   T)use_tvm_ffi_env_stream)r   r:   r   ư>z--enable-tvm-ffi)optionsxrwyss_unswizzledglobal_scalerv   rw   rS   Nc	                    sL   sr|  n| }	|tj}
 | |||
|	| |t|t|	 dS )z;Runtime API that passes torch tensors directly via TVM-FFI.N)flatten
contiguousviewrU   uint8r   r   )r9  r:  r;  r<  r=  r>  r?  rv   rw   s_tensory_uint8compiled_kernelr3  r2   rP   rQ   
tensor_api  s$   
z(_get_compiled_kernel.<locals>.tensor_api)r   Float16BFloat16r*   r~   sym_intruntimemake_fake_compact_tensorr	   make_fake_streamr   compiler   rU   r/  r)  float)r2  r-   r/   r0   r1   r3  r2   cutlass_dtype
kernel_objsym_mx_faker_fakew_fakey_fakesym_swizzled_sizes_fakes_unswizzled_fakestream_fakeglobal_scale_fakerH  rP   rF  rQ   _get_compiled_kernelr  s   


	
!r]  r7  r7   inputresidualweighty_fp4block_scaler?  rw   block_scale_unswizzledc           "   
   C   s  |   dk}|r#| j\}}}| || | }||| | }n| }|}|j\}}| j}|| dks9J d|dksAJ d|dv sIJ d|tjk}|rR|n|dkrXd	nd
}t| j}|d	krftj	ntj
}|| }|du r|rtj|||d ftj| jd}ntj||d ftj| jd}|du r|	s|
r|d d }|d d }d}|| | }tj|f|| jd}n|rtj|||f|| jd}ntj||f|| jd}|du r|rtj|||f|| jd}ntj||f|| jd}|r||| d}|	s	|
s	||| dn|}||| d} n|}|}|} |du r)tjdtj| jd}t||||||	|
}!|!| | | ||tj	| tj	| ||	 |
rW|||fS ||fS )a  
    Fused Add + RMS normalization + FP4 quantization using CuTe-DSL.

    Computes:
        1. ``residual = residual + input`` (in-place update)
        2. ``y = RMSNorm(residual) * weight``
        3. Optionally applies global scaling (``y = y / global_scale``)
        4. Quantizes ``y`` to FP4

    The residual tensor is modified in-place to contain the fused value.

    Parameters
    ----------
    input : torch.Tensor
        Input tensor, shape ``(batch_size, hidden_size)`` or ``(batch_size, seq_len, hidden_size)``.
        Must be ``torch.float16`` or ``torch.bfloat16``. Read-only.
    residual : torch.Tensor
        Residual tensor. Must have the same shape and dtype as ``input``.
        **Modified in-place** to contain ``residual + input``.
    weight : torch.Tensor
        Weight tensor for RMSNorm, shape ``(hidden_size,)``.
        Must have the same dtype as input.
    y_fp4 : torch.Tensor, optional
        Output tensor for quantized values in FP4_E2M1 format with dtype
        ``torch.float4_e2m1fn_x2``.
        Shape must be ``(batch_size, hidden_size // 2)`` or matching 3D input.
        If ``None``, will be allocated automatically.
    block_scale : torch.Tensor, optional
        Output tensor for per-block scale factors.

        - If ``is_sf_swizzled_layout=False`` and ``output_both_sf_layouts=False``: row-major
          layout with shape ``(batch_size, hidden_size // block_size)`` or matching 3D input.
        - If ``is_sf_swizzled_layout=True`` or ``output_both_sf_layouts=True``: swizzled layout
          for efficient tensor core access, with shape
          ``(batch_size * hidden_size // block_size,)`` flattened.
          The swizzle pattern uses 128x4 tiles where scales are arranged as:
          ``[m_tile][k_tile][outer_m (32)][inner_m (4)][inner_k (4)]``.

        Dtype should be ``torch.float8_e4m3fn`` for E4M3 format or ``torch.uint8``
        for UE8M0 format. If ``None``, will be allocated automatically.
    global_scale : torch.Tensor, optional
        Global scale factor tensor of shape ``(1,)`` with dtype ``torch.float32``.
        If provided, the RMSNorm output is divided by this value before quantization:
        ``y = rmsnorm(h, w) / global_scale`` where ``h = input + residual``.
        This is used for NVFP4 format where a pre-computed global scale lifts
        per-block scales into optimal dynamic range.
        If ``None``, no global scaling is applied (equivalent to global_scale=1.0).
    eps : float
        Epsilon for numerical stability in RMSNorm. Default is ``1e-6``.
    block_size : int
        Number of elements per quantization block. Default is ``16``.

        - ``16``: NVFP4 format with E4M3 scale factors
        - ``32``: MXFP4 format with UE8M0 scale factors
    scale_format : str, optional
        Scale factor format: ``"e4m3"`` or ``"ue8m0"``.
        If ``None``, auto-selects based on ``block_size``:
        ``"e4m3"`` for block_size=16, ``"ue8m0"`` for block_size=32.
    is_sf_swizzled_layout : bool
        If ``True``, output scale factors in swizzled layout optimized for
        tensor core GEMM operations. The swizzle uses 128x4 tiles with the pattern:
        ``[m_tile_idx * k_tiles * 512 + k_tile_idx * 512 + outer_m * 16 + inner_m * 4 + inner_k]``
        where ``outer_m = row % 32``, ``inner_m = (row % 128) // 32``, etc.
        Default is ``False`` (row-major layout).
        Note: This parameter is ignored when ``output_both_sf_layouts=True``.
    output_both_sf_layouts : bool
        If ``True``, return both swizzled and unswizzled scale factors.
        When enabled, ``block_scale`` contains the swizzled layout and
        ``block_scale_unswizzled`` contains the row-major layout.
        This overrides ``is_sf_swizzled_layout``.
        Default is ``False``.
    block_scale_unswizzled : torch.Tensor, optional
        Output tensor for unswizzled per-block scale factors (row-major layout).
        Only used when ``output_both_sf_layouts=True``.
        Shape is ``(batch_size, hidden_size // block_size)`` or matching 3D input.
        Dtype should be ``torch.float8_e4m3fn`` for E4M3 format or ``torch.uint8``
        for UE8M0 format. If ``None``, will be allocated automatically when
        ``output_both_sf_layouts=True``.

    Returns
    -------
    Union[Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor, torch.Tensor, torch.Tensor]]
        When ``output_both_sf_layouts=False``:
            A tuple of ``(y_fp4, block_scale)``:

            - ``y_fp4``: Quantized FP4 values packed as uint8.
            - ``block_scale``: Per-block scale factors (swizzled or row-major based on
              ``is_sf_swizzled_layout``).

        When ``output_both_sf_layouts=True``:
            A tuple of ``(y_fp4, block_scale, block_scale_unswizzled)``:

            - ``y_fp4``: Quantized FP4 values packed as uint8.
            - ``block_scale``: Per-block scale factors in swizzled layout.
            - ``block_scale_unswizzled``: Per-block scale factors in row-major layout.

    Notes
    -----
    - Requires SM100+ (Blackwell) for FP4 quantization PTX intrinsics.
    - For block_size=16 (NVFP4): uses E4M3 scale factors (max value 448.0).
    - For block_size=32 (MXFP4): uses UE8M0 scale factors (power-of-2 scales).
    - FP4 E2M1 format has a max representable value of 6.0.
    r9   r   z+hidden_size must be divisible by block_sizer_   zhidden_size must be >= 64r6   zblock_size must be 16 or 32r3   r4   r5   Nr
   )r+   device   r`   r:   r;   r   )dimre   rB  rA  r+   rU   float16r   rd  rC  float8_e4m3fnemptyfloat4_e2m1fn_x2onesfloat32r]  )"r^  r_  r`  ra  rb  r?  rw   r-   r1   r3  r2   rc  is_3dBSr,   input_2dresidual_2d
batch_sizer2  r+   r/   actual_scale_formatr0   scale_dtyperJ   num_m_tilesrK   rL   swizzled_sizey_fp4_2dblock_scale_2dblock_scale_unswizzled_2drH  rP   rP   rQ   add_rmsnorm_fp4quant  s   x




	
	


r{  )r*   r{  r   )F)	NNNr7  r7   NFFN)9r'  	functoolstypingr   r   r   r   cutlass.cuter~   rU   r   r   r   r   r	   api_loggingr   
fp4_commonr   r   r   r   r   r   r   r   r   r   r   r   r   r   r   r   r   r   r   r    r!   r"   r#   r$   r%   r&   r'   r(   r)   r*   cacher)  r*  r+  r]  r/  rP  r{  __all__rP   rP   rP   rQ   <module>   s    |0       	
" y