o
    پia                     @   s  d dl Z d dlZd dlmZ d dlZd dlmZ d dlmZ d dl	m
Z
mZ d dlmZmZmZmZmZ d dlmZmZmZ d dlmZ d d	lmZ ej oRejjZed
ddd ZdejfddZ G dd deZ!ejfddZ"G dd deZ#G dd deZ$G dd deZ%ej&fddZ'G dd deZ(dej)d ej)d!ej)fd"d#Z*G d$d% d%eZ+d&d' Z,G d(d) d)eZ-d*d+ Z.G d,d- d-eZ/e0d.krej1d/d0 dS dS )1    N)	lru_cache)
SiluAndMul)	fused_moe)
TopKConfigselect_experts)per_tensor_quant_mla_fp8per_token_group_quant_fp8.per_token_group_quant_mla_deep_gemm_masked_fp8static_quant_fp8w8a8_block_fp8_matmul)input_to_float8mxfp8_group_quantizetriton_mxfp8_blockscaled_linear)is_sm100_supported)CustomTestCase   )maxsizec               
   C   s6   z	ddl m}  W | S  ty } ztd|d }~ww )Nr   )upcast_from_mxfp_torchz@MXFP8 dequantization requires triton_kernels with MXFP8 support.)$triton_kernels.numerics_details.mxfpr   	ExceptionRuntimeError)r   err r   N/home/ubuntu/.local/lib/python3.10/site-packages/sglang/test/test_block_fp8.py_get_triton_mxfp8_upcast   s   r   g|=c                 C   s   | j d | dksJ d|  sJ dt|}|j}|j}| |  | |}| jdddd j	|d
tj}|| }	||	 j	||d
|}
|
| j }
|	| j d	d | j d | f }	|
|	fS )
a/  Function to perform per-token-group quantization on an input tensor `x` using native torch.

    It converts the tensor values into float8 values and returns the
    quantized tensor along with the scaling factor used for quantization.
    Note that only `torch.float8_e4m3fn` is supported for now.
    r   z=the last dimension of `x` cannot be divisible by `group_size``x` is not contiguousT)dimkeepdimminr    maxN)shapeis_contiguoustorchfinfor    r"   reshapenumelabsclamptofloat32)x
group_sizeepsdtyper&   fp8_minfp8_maxx_amaxx_sx_qr   r   r    native_per_token_group_quant_fp8)   s   

&$r7   c                   @   sV   e Zd ZejejejgZg dZg dZ	g dZ
dgZedd Zdd Zd	d
 ZdS )TestPerTokenGroupQuantFP8   S                6  )@         r>   r   c                 C   "   t j s
tdt d d S NCUDA is not availablecudar%   rH   is_availableunittestSkipTestset_default_deviceclsr   r   r   
setUpClassL      

z$TestPerTokenGroupQuantFP8.setUpClassc                 C   s   t | t j|||d}t   t||\}}t||\}	}
W d    n1 s*w   Y  | t j|	t j	|t j	dd | t |
| d S )Nr0   皙?rtol)
r%   manual_seedrandinference_moder7   r   
assertTrueallcloser+   r,   )self
num_tokensdr0   r.   seedr-   ref_out	ref_scaleoutscaler   r   r   _per_token_group_quant_fp8R   s   

z4TestPerTokenGroupQuantFP8._per_token_group_quant_fp8c              	   C   x   t | j| j| j| j| jD ]+}| j|d |d |d |d |d d | j|  W d    n1 s4w   Y  qd S )Nr   r            )r\   r]   r0   r.   r^   )		itertoolsproduct
NUM_TOKENSDDTYPES
GROUP_SIZESEEDSsubTestrc   r[   paramsr   r   r   test_per_token_group_quant_fp8`   $   z8TestPerTokenGroupQuantFP8.test_per_token_group_quant_fp8N)__name__
__module____qualname__r%   halfbfloat16r,   rl   rj   rk   rm   rn   classmethodrP   rc   rr   r   r   r   r   r8   E   s    
r8   c           	      C   s   |   sJ d| dksJ dt|}|j}|j}| |  | jd  | jd }d| }|| j||d	|}|| j}||fS )zFunction to perform static quantization on an input tensor `x` using native torch.

    It converts the tensor values into float8 values and returns the
    quantized tensor along with the scaling factor used for quantization.
    r   r   zonly supports per-tensor scaler   g      ?r!   )
r$   r(   r%   r&   r    r"   r'   r#   r*   r+   )	r-   r5   r0   r&   r1   r2   r3   x_s_invr6   r   r   r   native_static_quant_fp8s   s   
 r{   c                   @   sN   e Zd ZejejejgZg dZg dZ	dgZ
edd Zdd Zdd	 Zd
S )TestStaticQuantFP8r9   r=   r   c                 C   rE   rF   rI   rN   r   r   r   rP      rQ   zTestStaticQuantFP8.setUpClassc                 C   s   t | t j|||d}t t jj}| | }t   t||\}}	t||dd\}
}	W d    n1 s9w   Y  | 	t j
|
t j|t jdd d S )NrR   T)repeat_scale      ?rT   )r%   rV   rW   r&   float8_e4m3fnr"   rX   r{   r
   rY   rZ   r+   r,   )r[   r\   r]   r0   r^   r-   r2   r5   r_   _ra   r   r   r   _static_quant_fp8   s   

z$TestStaticQuantFP8._static_quant_fp8c              	   C   n   t | j| j| j| jD ](}| j|d |d |d |d d | j|  W d    n1 s/w   Y  qd S )Nr   r   re   rf   )r\   r]   r0   r^   )rh   ri   rj   rk   rl   rn   ro   r   rp   r   r   r   test_static_quant_fp8       z(TestStaticQuantFP8.test_static_quant_fp8N)rt   ru   rv   r%   rw   rx   r,   rl   rj   rk   rn   ry   rP   r   r   r   r   r   r   r|      s    
r|   c                   @   s\   e Zd ZejejejgZg dZg dZ	ddgZ
dgZdgZedd Zdd	 Zd
d ZdS )TestPerTensorQuantMlaFP8r9   r=      r   r>   c                 C   rE   rF   rI   rN   r   r   r   rP      rQ   z#TestPerTensorQuantMlaFP8.setUpClassc                 C   s   t | t j||| || f|d}|j||gdd\}}	t   t|dd\}
}t|dd\}}W d    n1 sAw   Y  | |	  | t j
|t j|
t jdd | t 
|t j|t j d S )NrR   r   r   r   r   r~   rT   )r%   rV   rW   splitrX   r   	transposer   rY   r$   rZ   r+   r,   )r[   r\   r]   
last_d_extlast_dr0   r^   r-   x_subr   r_   ref_sra   out_sr   r   r   _per_tensor_quant_mla_fp8   s"   

z2TestPerTensorQuantMlaFP8._per_tensor_quant_mla_fp8c              	   C      t | j| j| j| j| j| jD ].}| j|d |d |d |d |d |d d | j	|  W d    n1 s9w   Y  qd S )Nr   r   re   rf   rg      )r\   r]   r   r   r0   r^   )
rh   ri   rj   rk   
LAST_D_EXTLAST_Drl   rn   ro   r   rp   r   r   r   test_per_tensor_quant_mla_fp8   (   z6TestPerTensorQuantMlaFP8.test_per_tensor_quant_mla_fp8N)rt   ru   rv   r%   rw   rx   r,   rl   rj   rk   r   r   rn   ry   rP   r   r   r   r   r   r   r      s    
r   c                   @   sZ   e Zd ZejejejgZdgZg dZ	ddgZ
dgZdgZedd Zdd Zd	d
 ZdS )*TestPerTokenGroupQuantMlaDeepGemmMaskedFP8rC   )r:   r;   r<   i @  r>   r   c                 C   rE   rF   rI   rN   r   r   r   rP      rQ   z5TestPerTokenGroupQuantMlaDeepGemmMaskedFP8.setUpClassc                 C   s   t | t j||||d}t  4 t||d\}}	t||\}
}}}}|
d d d |d d f }
|d d d |d d f }W d    n1 sIw   Y  | t j|
t j	|t j	ddd | t ||	 d S )NrR   g-q=rS   {Gz?)rU   atol)
r%   rV   rW   rX   r7   r	   rY   rZ   r+   r,   )r[   br\   r]   r0   r.   r^   r-   r_   r`   ra   rb   r   r   r   r   /_per_token_group_quant_mla_deep_gemm_masked_fp8   s    

zZTestPerTokenGroupQuantMlaDeepGemmMaskedFP8._per_token_group_quant_mla_deep_gemm_masked_fp8c              	   C   r   )Nr   r   re   rf   rg   r   )r   r\   r]   r0   r.   r^   )
rh   ri   Brj   rk   rl   rm   rn   ro   r   rp   r   r   r   3test_per_token_group_quant_mla_deep_gemm_masked_fp8  r   z^TestPerTokenGroupQuantMlaDeepGemmMaskedFP8.test_per_token_group_quant_mla_deep_gemm_masked_fp8N)rt   ru   rv   r%   rw   rx   r,   rl   r   rj   rk   rm   rn   ry   rP   r   r   r   r   r   r   r      s    
r   c              
      s`    tj  tj jd jd ksJ jdkr& r&|jdks(J t|dks0J |d |d  jd  d  jd ksKJ  jdd jdd ks[J    jd  }j\ jdd f } | jd  |jd  d  } d  ||jd ksJ |jd ksJ |f}	tj	|	tj j
d fddtD }
fd	dt|D }fd
dt|D }fddtD }tD ]7}t|D ]0}|
| }|| | }|| }|| || |  }|ddddf  t|| | 7  < qq| |S )zThis function performs matrix multiplication with block-wise quantization using native torch.

    It takes two input tensors `A` and `B` with scales `As` and `Bs`.
    The output is returned in the specified `output_dtype`.
    r   re   r   r   Nr0   devicec                    2   g | ]} d d | t |d  f qS Nr   r   .0i)AKblock_kr   r   
<listcomp>@     2 z0native_w8a8_block_fp8_matmul.<locals>.<listcomp>c                    s,   g | ]  fd dt D qS )c                    sD   g | ]}  t d   | t |d   f qS )r   r   r   )r   r   Nr   block_njr   r   r   B  s    z;native_w8a8_block_fp8_matmul.<locals>.<listcomp>.<listcomp>)range)r   )r   r   r   r   r   k_tiles)r   r   r   A  s    c                    r   r   r   )r   r   )Cr   r   r   r   r   K  r   c                    s$   g | ]} d d ||d f qS r   r   r   )Asr   r   r   L  s   $ )r+   r%   r,   r#   ndimr$   lenr(   r'   zerosr   r   matmult)r   r   r   Bs
block_sizeoutput_dtypeMorigin_C_shapen_tilesC_shapeA_tilesB_tilesC_tilesAs_tilesr   r   ar   csr   )	r   r   r   r   r   r   r   r   r   r   native_w8a8_block_fp8_matmul#  sF    $ 

.r   c                   @   s   e Zd Zes"ejejejgZg dZ	dd dD Z
ddggZdgZnejgZg dZ	g dZ
ddggZdgZed	d
 Zdd Zdd ZdS )TestW8A8BlockFP8Matmul)r   r:   r;   r>   r<   c                 C   s   g | ]}d D ]}||fqqS ))rD   r?   r@   i,  rA   r   )r   r   r   r   r   r   r   _  s    z!TestW8A8BlockFP8Matmul.<listcomp>)rC   r>   r   r?   iD  rA   rC   r   )rB   rC   r>   r   r?   )
)i@     )   r   )i   r   )i `  r   )r?   r>   )r   r<   )i   r   )r>   r   )r   i 	  )r   r>   c                 C   rE   rF   rI   rN   r   r   r   rP   z  rQ   z!TestW8A8BlockFP8Matmul.setUpClassc              
   C   s  |\}}t | d}t t j}	|	j|	j}
}t j||t jdd d |
 }|j||
d	t j}t j||t jdd d |
 }|j||
d	t j}|d |d }}|| d | }|| d | }t j||t jd| }t j||t jd| }t 
  t||||||}t||||||}W d    n1 sw   Y  | t t |	t j|	t j t t |	t j dk  d S )	Nr   rR   r~   re   r!   r   r   gMbP?)r%   rV   r&   r   r"   r    rW   r,   r*   r+   rX   r   r   rY   meanr)   )r[   r   NKr   	out_dtyper^   r   r   factor_for_scalefp8_infor2   r1   A_fp32A_fp8B_fp32B_fp8r   r   r   r   r   r   r_   ra   r   r   r   _w8a8_block_fp8_matmul  s6   

"z-TestW8A8BlockFP8Matmul._w8a8_block_fp8_matmulc              	   C   rd   )Nr   r   re   rf   rg   )r   NKsr   r   r^   )	rh   ri   r   r   
BLOCK_SIZE
OUT_DTYPESrn   ro   r   rp   r   r   r   test_w8a8_block_fp8_matmul  rs   z1TestW8A8BlockFP8Matmul.test_w8a8_block_fp8_matmulN)rt   ru   rv   _is_cudar%   r,   rw   rx   r   r   r   r   rn   ry   rP   r   r   r   r   r   r   r   Z  s$    


!r   qscale_u8returnc                 C   s   t  }|| |tjddS )Nr   )axis)r   r%   r,   )r   r   r   r   r   r   _mxfp8_group_dequant  s   r   c                   @   sF   e Zd ZejgZg dZg dZdgZe	dd Z
dd Zdd	 Zd
S )TestMXFP8DenseLinear)r      rC         rD   ))rD   r>   )i  r   )r>   r<   )i   r   r   c                 C   s2   t j s
tdt stdt d d S )NrG   z!MXFP8 requires Blackwell (SM100+)rH   )r%   rH   rJ   rK   rL   r   rM   rN   r   r   r   rP     s
   


zTestMXFP8DenseLinear.setUpClassc              
   C   sn  |\}}t | t j||ft jdd }||}t j||ft jdd }	t|	\}
}t  7 t|t j\}}t||}t|
|}t ||	 |}t
||
|d}t
||
|||d}W d    n1 sjw   Y  | t t |t j|t j t t |t j dk  | t t |t j|t j t t |t j dk  d S )NrR   rg   )inputweightweight_scale)r   r   r   input_scaler   {Gz?)r%   rV   randnr,   r+   r   rX   r   r   r   r   rY   r   r)   )r[   r   r   r0   r^   r   r   
input_fp32
input_fp16weight_fp32weight_qweight_scale_u8q_inputinput_scale_u8a_dqb_dqr_   ra   out_prequantr   r   r   _mxfp8_dense_linear  sP   




"z(TestMXFP8DenseLinear._mxfp8_dense_linearc              	   C   r   )Nr   r   re   rf   )r   r   r0   r^   )rh   ri   r   r   rl   rn   ro   r   rp   r   r   r   test_mxfp8_dense_linear  r   z,TestMXFP8DenseLinear.test_mxfp8_dense_linearN)rt   ru   rv   r%   rx   rl   r   r   rn   ry   rP   r   r   r   r   r   r   r     s    
*r   c              	   C   sp  | j \}}	| |d|	d|dd|	} tj|| |j d | j| jd}
tj|dtj	d}t
||\}}|d}|d}|d |d }}t| |\}}|tj	}t|j d D ]A}||k}| rt|| || || || || jd}t |}t||\}}|tj	}t||| ||| || jd|
|< q]|
|d|j d ||dd|
j jddS )zQThis function performs fused moe with block-wise quantization using native torch.r   r   r   )r   r0   r   r   r   )r#   viewrepeatr'   r%   r   r0   r   softmaxr,   topkr7   r+   r   sumr   r   forward_native)r   w1w2w1_sw2_sscorer   block_shaper   rk   ra   topk_weighttopk_idsr   r   a_qa_sr   mask	inter_outact_out	act_out_q	act_out_sr   r   r   torch_w8a8_block_fp8_moe  s4   
  


(r  c                   @   s   e Zd ZejejejgZg dZg dZ	g dZ
ddgZddgZddgdd	gd	dgd	d	ggZd
gZedd Zdd Zdd ZdS )TestW8A8BlockFP8FusedMoE)r   !   rB      i   )rC   r   r<   )rD   r?   r@         re      rB   rC   r   c                 C   rE   rF   rI   rN   r   r   r   rP   1  rQ   z#TestW8A8BlockFP8FusedMoE.setUpClassc	                 C   s  t | d}	t t j}
|
j|
j}}t j||f|dd }t j|d| |ft jdd d | }|j	||d
t j}t j|||ft jdd d | }|j	||d
t j}|d |d }}d| | d | }|| d | }|| d | }|| d | }t j|||ft jd|	 }t j|||ft jd|	 }t j||f|d}t  * t||||||||}t||t|d	d
d}t||||d|||d}W d    n1 sw   Y  | t t |
t j|
t j t t |
t j dk  d S )Nr   rR   
   re   r~   r!   r   r   F)top_krenormalize)hidden_statesrouter_logitstopk_configT)use_fp8_w8a8w1_scalew2_scaler   r   )r%   rV   r&   r   r"   r    r   rW   r,   r*   r+   rX   r  r   r   r   rY   r   r)   )r[   r   r   r   Er   r   r0   r^   r   r   r2   r1   r   w1_fp32r   w2_fp32r   r   r   
n_tiles_w1
n_tiles_w2
k_tiles_w1
k_tiles_w2r   r   r   r_   topk_outputra   r   r   r   _w8a8_block_fp8_fused_moe7  sd   
&"

"z2TestW8A8BlockFP8FusedMoE._w8a8_block_fp8_fused_moec                 C   s   t | j| j| j| j| j| j| j| j	D ]4}| j
|d |d |d |d |d |d |d |d d	 | j|  W d    n1 sCw   Y  qd S )
Nr   r   re   rf   rg   r   r  r:   )r   r   r   r  r   r   r0   r^   )rh   ri   r   r   r   r  TOP_KSr   rl   rn   ro   r  rp   r   r   r   test_w8a8_block_fp8_fused_moeq  s0   

z6TestW8A8BlockFP8FusedMoE.test_w8a8_block_fp8_fused_moeN)rt   ru   rv   r%   r,   rw   rx   rl   r   r   r   r  r   r   rn   ry   rP   r  r!  r   r   r   r   r  &  s    
:r  c              	   C   sj   |j \}}}| j \}}	}tj||	|f|| jd}
t|D ]}t| | || || || ||d|
|< q|
S )zKThis function performs bmm with block-wise quantization using native torch.r   r   )r#   r%   emptyr   r   r   )r   r  ww_sr   r   r   r   r   r   ra   r   r   r   r   torch_w8a8_block_fp8_bmm  s   r%  c                   @   s^   e Zd ZejgZg dZddgZddgZdgZ	ddggZ
dgZedd Zdd Zd	d
 ZdS )TestW8A8BlockFP8BatchedDeepGemm)r   r	  rB   r
  i    rC   r>   r   c                 C   sH   t j s
tdzdd l}W n ty   tdw t d d S )NrG   r   zDeepGEMM is not availablerH   )r%   rH   rJ   rK   rL   	deep_gemmImportErrorrM   )rO   r'  r   r   r   rP     s   


z*TestW8A8BlockFP8BatchedDeepGemm.setUpClassc               
   C   sb  t | d}t t j}	|	j|	j}
}t j|||ft jdd }|j||
d	t j}t j
|||ft jdd d |
 }|j||
d	t j}|d |d }}|| d | }|| d | }t j
|||ft jd| }t j
|||ft jd| }|||d	 d
 d
 |}|||d	 d
 d
 |}t j||d	 d
 d
 |f|d}||d d d |d d f< ||d d d |d d f< t j|f|t jd}|}||f}||f}ddlm} t  & t||||||}|||||| |d d d |d d f }W d    n	1 sw   Y  | t t |	t j|	t j t t |	t j dk  d S )Nr   rR   r  r!   r~   re   r   r   r   rD   )fp8_m_grouped_gemm_nt_maskedg-C6?)r%   rV   r&   r   r"   r    r   r,   r*   r+   rW   	new_emptyr"  fullintr'  r)  rX   r%  rY   r   r)   ) r[   r   r   r   r   r   r0   r^   r   r   r2   r1   a_fp32r   w_fp32r#  r   r   	n_tiles_w	k_tiles_wr$  r  aeae_soemasked_m
expected_mlhsrhsr)  r_   ra   r   r   r   !_w8a8_block_fp8_batched_deep_gemm  sT   
" 
"zATestW8A8BlockFP8BatchedDeepGemm._w8a8_block_fp8_batched_deep_gemmc              
   C   s   t | j| j| j| j| j| j| jD ]1}| j	|d |d |d |d |d |d |d d | j
|  W d    n1 s>w   Y  qd S )	Nr   r   re   rf   rg   r   r  )r   r   r   r   r   r0   r^   )rh   ri   r   r   r   BATCHr   rl   rn   ro   r8  rp   r   r   r   %test_w8a8_block_fp8_batched_deep_gemm  s,   		zETestW8A8BlockFP8BatchedDeepGemm.test_w8a8_block_fp8_batched_deep_gemmN)rt   ru   rv   r%   rx   rl   r   r   r   r9  r   rn   ry   rP   r8  r:  r   r   r   r   r&    s    

	4r&  __main__re   )	verbosity)2rh   rK   	functoolsr   r%   sglang.srt.layers.activationr   0sglang.srt.layers.moe.fused_moe_triton.fused_moer   sglang.srt.layers.moe.topkr   r   )sglang.srt.layers.quantization.fp8_kernelr   r   r	   r
   r   (sglang.srt.layers.quantization.fp8_utilsr   r   r   sglang.srt.utilsr   sglang.test.test_utilsr   rH   rJ   versionr   r   r   r7   r8   r{   r|   r   r   float16r   r   Tensorr   r   r  r  r%  r&  rt   mainr   r   r   r   <module>   s@    

.+797YN d^