o
    (i                     @   s  d Z ddlZddlZzddlmZ W n ey$   dadefddZY nw G dd dejZ	e	j
d	iZG d
d dejZejdejdejdejdejdejdejdejdejdejdejdiZi ejdejdejdejdejdejdejdej dej!dej"d ej#d!ej$d"ejd#ej%d$ejdejdejdi ejdejdej&d%ej'd&ej(d'ej)d(ej*d)ejd*ej+d+ejd,ej,d-ejd.ej-d/ej.d0ejd1ej/d2ejd3ej0d4ej1d5ej2d6ej3d7ej4d8ej5d9ej6d:ej7d;ej8d<ej9d=ej:d>ej;d?iZ<i ejdejd@ejdAejdBejdCejdDejdEej dFej!dGej"dHej#dIej$dJejdKej%dLejdMejdNejdOi ejdPejdQej&dRej'dSej(dTej)dUej*dVejdWej+dXejdYej,dZejd[ej-d\ej.d]ejd^ej/d_ejd`ej0daej1dbej2dcej3ddej4deej5dfej6dgej7dhej8diej9djej:dkej;dliZ=i ejdejdmejdnejdoejdpejdqejdrej dsej!dnej"doej#dpej$dqejdrej%dsejdpejdpejdpi ejdtejdoej&dtej'dtej(doej)dpej*dpejdqej+dqejdrej,drejdsej-drej.drejdsej/drejduej0doej1dpej2dqej3drej4dsej5duej6doej7dpej8dqej9drej:dsej;duiZ>G dvdw dwejZ?e?j@dxe?jAdyiZBG dzd{ d{ejZCeCjDd|eCjEd}iZFeCjDd~eCjEdiZGejej-fejejfejejfgZHdd ZIdd ZJdd ZKdd ZLdd ZMdd ZNdd ZOG dd dejZPG dd dejZQeQjRdeQjSdeQjTdeQjUdeQjVdeQjWdeQjXdeQjYdeQjZdeQj[deQj\deQj]diZ^G dd dejZ_i e_j`de_jade_jbde_jcde_jdde_jede_jfde_jgde_jhde_jide_jjde_jkde_jlde_jmde_jnde_jode_jpde_jqde_jrde_jsdiZte_j`e_jae_jae_j`e_jbe_jce_jce_jbe_jde_jee_jee_jde_jfe_jge_jge_jfe_jie_jii	Zui e_j`de_jbde_jdde_jfde_jade_jcde_jede_jgde_jhde_jide_jjde_jkde_jlde_jmde_jode_jndÓe_jpdēe_jqde_jrde_jsdiZve_j`eCjDfde_j`eCjEfde_jaeCjDfde_jaeCjEfdiZwG ddɄ dejZxi exjydʓexjzd˓exj{d̓exj|d͓exj}dΓexj~dϓexjdГexjdѓexjdғexjdӓexjdԓexjdՓexjd֓exjdדexjdؓexjdٓexjdړi exjdۓexjdܓexjdݓexjdޓexjdߓexjdexjdexjdexjdexjdexjdexjdexjdexjdexjdexjdexjdi exjdexjdexjdexjdexjdexjdexjdexjdexjdexjdexjdexjdexjdexjdexjdexjdexjdexjdexjdexjdexjd exjdiZi exjydexjzdexj{dexj|dexj}dexj~dexjdexjd	exjd
exjdexjdexjdexjdexjd
exjdexjdexjdi exjdexjdexjdexjdexjdexjdexjdexjdexjdexjdexjdexjdexjdexjdexjdexjdexjdi exjd
exjdexjd	exjdexjd
exjdexjdexjdexjdexjdexjdexjdexjdexjdexjdexjdexjdexjdexjdexjdexjdexjdiZG dd dejZi ejydejd ejd!ejd"ejd#ejd$ejd%ejd&ejd'ejd(ejd)ejd*ejd+ejd,ejd-ejd.Zi ejydejdejd/ejd/ejd/ejd/ejd/ejd/ejd0ejd0ejdejd0ejd1ejd2ejd0ejd0ZG d3d4 d4ejZejd5ejÐd6iZĐd7d8 ZŐd9d: ZG d;d< d<ejZejdejɐd=ejʐd>iZejȐdejɐdejʐd?iZG d@dA dAejZejΐdBejϐdCiZejΐdDejϐdEiZG dFdG dGejZejӐdHejԐdIiZejӐdJejԐdKiZG dLdM dMejZejؐdNejِdOiZejؐdPejِdQiZG dRdS dSejZejݐdTejސdUejߐdVejdWejdXiZejݐdYejސdZejߐd[ejd\ejd]iZG d^d_ d_ejZejd`ejdaejdbejdcejddejdeejdfiZG dgdh dhejZdidjdjdkdldmdndodpZdqdqdsdrdsdrdsdtduZdvdw ZG dxdy dyejZejd`ejdzejd`ejd`ejdzejd{ejd|ejd}ejd`ejd}ejd}ejd`ejd}iZG d~d dejZejdaiZ G dd dejZejdciZG dd dejZejddiZG dd dejZejdejdiZG dd dejZG dd dejZ	e	j
de	jde	jde	jde	jde	jde	jde	jde	jʐdi	ZG dd dejZejdejdiZejdejdiZG dd dejZejdejdejdiZejdejdejdiZG dd dejZG dd dejZ e j!de j"de j#de j$de j%diZ&e j!de j"de j#de j$de j%diZ'G dd dejZ(e(j)de(j*de(j+diZ,e(j)de(j*de(j+diZ-G dd dejZ.e.j/de.j0de.j1de.j2diZ3e.j/de.j0de.j1de.j2diZ4g dZ5G ddÄ dÃZ6G dĐdń dŃZ7G dƐdǄ dǃZ8G dȐdɄ dɃZ9G dʐd˄ d˃Z:G d̐d̈́ d̓Z;dΐdτ Z<G dАdф dejZ=G dҐdӄ dejZ>dS (  z;
Data types and tags used for emitting CUTLASS C++ kernels
    N)autoreturnc                  C   s   t } t d7 a | S )N   )__cutlass_library_auto_enum)i r   h/home/ubuntu/veenaModal/venv/lib/python3.10/site-packages/flashinfer/jit/gemm/cutlass/cutlass_library.py	enum_auto2   s   r	   c                   @      e Zd Ze ZdS )GeneratorTargetN)__name__
__module____qualname__r	   Libraryr   r   r   r   r   =       
r   libraryc                   @   s&  e Zd Ze Ze Ze Ze Ze Ze Z	e Z
e Ze Ze Ze Ze Ze Ze Ze Ze Ze Ze Ze Ze Ze Ze Ze Ze Ze Ze Ze Ze Ze Z e Z!e Z"e Z#e Z$e Z%e Z&e Z'e Z(e Z)e Z*e Z+e Z,e Z-e Z.e Z/e Z0e Z1e Z2dS )DataTypeN)3r   r   r   r	   voidb1u2u4u8u16u32u64s2s4s8s16s32s64e4m3e5m2f8f6f4e3m2e2m3e2m1ue8m0ue4m3f16bf16f32tf32f64cf16cbf16cf32ctf32cf64cs2cs4cs8cs16cs32cs64cu2cu4cu8cu16cu32cu64invalidr   r   r   r   r   I   s`    
r   r   r!   r"   hsdczr#   r$   r%   r   r   r   r   r   r   r   r   r   r   r   r   r   r    r'   r&   r(   r)   r*   r+   r,   r-   r.   r/   r0   r1   r2   r3   r4   r;   r<   r=   r>   r?   r@   r5   r6   r7   r8   r9   r:   zcutlass::uint1b_tzcutlass::uint2b_tzcutlass::uint4b_tuint8_tuint16_tuint32_tuint64_tzcutlass::int2b_tzcutlass::int4b_tint8_tint16_tint32_tint64_tzcutlass::float_e4m3_tzcutlass::float_e5m2_tz%cutlass::type_erased_dynamic_float8_tz%cutlass::type_erased_dynamic_float6_tz%cutlass::type_erased_dynamic_float4_tzcutlass::float_e2m3_tzcutlass::float_e3m2_tzcutlass::float_e2m1_tzcutlass::float_ue8m0_tzcutlass::float_ue4m3_tzcutlass::half_tzcutlass::bfloat16_tfloatzcutlass::tfloat32_tdoublez!cutlass::complex<cutlass::half_t>z%cutlass::complex<cutlass::bfloat16_t>zcutlass::complex<float>z%cutlass::complex<cutlass::tfloat32_t>zcutlass::complex<double>z#cutlass::complex<cutlass::uint2b_t>z#cutlass::complex<cutlass::uint4b_t>z"cutlass::complex<cutlass::uint8_t>z#cutlass::complex<cutlass::uint16_t>z#cutlass::complex<cutlass::uint32_t>z#cutlass::complex<cutlass::uint64_t>z"cutlass::complex<cutlass::int2b_t>z"cutlass::complex<cutlass::int4b_t>z!cutlass::complex<cutlass::int8_t>z"cutlass::complex<cutlass::int16_t>z"cutlass::complex<cutlass::int32_t>z"cutlass::complex<cutlass::int64_t>r                   @         c                   @      e Zd Ze Ze ZdS )BlasModeN)r   r   r   r	   	symmetric	hermitianr   r   r   r   rZ   !      
rZ   zcutlass::BlasMode::kSymmetriczcutlass::BlasMode::kHermitianc                   @   rY   )ComplexTransformN)r   r   r   r	   noneconjr   r   r   r   r^   .  r]   r^   z cutlass::ComplexTransform::kNonez%cutlass::ComplexTransform::kConjugatezcute::identityzcute::conjugatec                    s   t  fddtD S )Nc                 3   s    | ]	\}} |kV  qd S Nr   ).0_rrE   	data_typer   r   	<genexpr>I  s    zis_complex.<locals>.<genexpr>)anyRealComplexBijectionrd   r   rd   r   
is_complexH  s   ri   c                 C      | t jt jfv S ra   )GemmKindBlockScaledUniversal3xGroupedBlockScaledUniversal3x	gemm_kindr   r   r   is_block_scaledL     rp   c                 C   rj   ra   )rk   BlockwiseUniversal3xGroupedBlockwiseUniversal3xrn   r   r   r   is_blockwiseS  rq   rt   c                 C   s   | t jt jt jfv S ra   )rk   GroupedUniversal3xrm   rs   rn   r   r   r   
is_groupedZ  s
   rv   c                 C   s$   t D ]\}}| |kr|  S qtjS ra   rh   r   rA   )	real_typerrE   r   r   r   get_complex_from_realc  
   rz   c                 C   s$   t D ]\}}| |kr|  S qtjS ra   rw   )complex_typery   rE   r   r   r   get_real_from_complexk  r{   r}   c                 C   s*   | t jkrdS t|  dkrdS dt|   S )Nr   rW   rX   )r   r   DataTypeSizerd   r   r   r   get_tma_alignments  s
   
r   c                   @   rY   )ComplexMultiplyOpN)r   r   r   r	   multiply_addgaussianr   r   r   r   r   }  r]   r   c                   @   sT   e Zd Ze Ze Ze Ze Ze Ze Z	e Z
e Ze Ze Ze Ze ZdS )MathOperationN)r   r   r   r	   r   multiply_add_saturatemultiply_add_mixed_input_upcastxor_popcand_popcmultiply_add_fast_bf16multiply_add_fast_f16multiply_add_fast_f32multiply_add_complex_fast_f32multiply_add_complexmultiply_add_complex_gaussianmultiply_add_fast_accumr   r   r   r   r     s    
r   zcutlass::arch::OpMultiplyAddz$cutlass::arch::OpMultiplyAddSaturatez,cutlass::arch::OpMultiplyAddMixedInputUpcastzcutlass::arch::OpXorPopczcutlass::arch::OpAndPopcz$cutlass::arch::OpMultiplyAddFastBF16z#cutlass::arch::OpMultiplyAddFastF16z#cutlass::arch::OpMultiplyAddFastF32z*cutlass::arch::OpMultiplyAddComplexFastF32z#cutlass::arch::OpMultiplyAddComplexz+cutlass::arch::OpMultiplyAddGaussianComplexz%cutlass::arch::OpMultiplyAddFastAccumc                   @   s   e Zd Ze Ze Ze Ze Ze Ze Z	e Z
e Ze Ze Ze Ze Ze Ze Ze Ze Ze Ze Ze Ze ZdS )
LayoutTypeN)r   r   r   r	   ColumnMajorRowMajorColumnMajorInterleaved2RowMajorInterleaved2ColumnMajorInterleaved32RowMajorInterleaved32ColumnMajorInterleaved64RowMajorInterleaved64	TensorNWC
TensorNHWCTensorNDHWC
TensorNCHWTensorNGHWCTensorNC32HW32TensorNC64HW64TensorC32RSK32TensorC64RSK64	TensorKCS
TensorKCSRTensorKCSRTr   r   r   r   r     s*    
r   zcutlass::layout::ColumnMajorzcutlass::layout::RowMajorz*cutlass::layout::ColumnMajorInterleaved<2>z'cutlass::layout::RowMajorInterleaved<2>z+cutlass::layout::ColumnMajorInterleaved<32>z(cutlass::layout::RowMajorInterleaved<32>z+cutlass::layout::ColumnMajorInterleaved<64>z(cutlass::layout::RowMajorInterleaved<64>zcutlass::layout::TensorNWCzcutlass::layout::TensorNHWCzcutlass::layout::TensorNDHWCzcutlass::layout::TensorNCHWzcutlass::layout::TensorNGHWCz!cutlass::layout::TensorNCxHWx<32>z!cutlass::layout::TensorCxRSKx<32>z!cutlass::layout::TensorNCxHWx<64>z!cutlass::layout::TensorCxRSKx<64>zcutlass::layout::TensorKCSzcutlass::layout::TensorKCSRzcutlass::layout::TensorKCSRTnn2n32n64tt2t32t64nwcnhwcndhwcnchwnghwcnc32hw32nc64hw64c32rsk32c64rsk64kcskcsrkcsrtc                   @   s\  e Zd Ze Ze Ze Ze Ze Ze Z	e Z
e Ze Ze Ze Ze Ze Ze Ze Ze Ze Ze Ze Ze Ze Ze Ze Ze Ze Ze Ze Ze Ze Z e Z!e Z"e Z#e Z$e Z%e Z&e Z'e Z(e Z)e Z*e Z+e Z,e Z-e Z.e Z/e Z0e Z1e Z2e Z3e Z4e Z5e Z6e Z7e Z8e Z9e Z:e Z;dS )KernelScheduleTypeN)<r   r   r   r	   ScheduleAuto
MultistageCpAsyncWarpSpecializedCpAsyncWarpSpecializedPingpong!CpAsyncWarpSpecializedCooperativeTmaTmaWarpSpecializedTmaWarpSpecializedPingpongTmaWarpSpecializedCooperativeTmaWarpSpecializedFP8FastAccum)TmaWarpSpecializedCooperativeFP8FastAccum&TmaWarpSpecializedPingpongFP8FastAccumImplicitTmaWarpSpecializedSm90%PtrArrayTmaWarpSpecializedCooperative1PtrArrayTmaWarpSpecializedCooperativeFP8FastAccum"PtrArrayTmaWarpSpecializedPingpong.PtrArrayTmaWarpSpecializedPingpongFP8FastAccum&BlockwiseTmaWarpSpecializedCooperative.PtrArrayBlockwiseTmaWarpSpecializedCooperativeTmaWarpSpecialized1SmSm100TmaWarpSpecialized2SmSm100"ImplicitTmaWarpSpecialized1SmSm100"ImplicitTmaWarpSpecialized2SmSm100"PtrArrayTmaWarpSpecialized1SmSm100"PtrArrayTmaWarpSpecialized2SmSm100-PtrArrayTmaWarpSpecialized1SmBlockScaledSm100-PtrArrayTmaWarpSpecialized2SmBlockScaledSm100&PtrArrayNvf4TmaWarpSpecialized1SmSm100&PtrArrayNvf4TmaWarpSpecialized2SmSm100&PtrArrayMxf4TmaWarpSpecialized1SmSm100&PtrArrayMxf4TmaWarpSpecialized2SmSm100*PtrArrayMxf8f6f4TmaWarpSpecialized1SmSm100*PtrArrayMxf8f6f4TmaWarpSpecialized2SmSm100 SparseTmaWarpSpecialized1SmSm100 SparseTmaWarpSpecialized2SmSm100%BlockScaledTmaWarpSpecialized1SmSm100%BlockScaledTmaWarpSpecialized2SmSm100"Mxf8f6f4TmaWarpSpecialized1SmSm100"Mxf8f6f4TmaWarpSpecialized2SmSm100#BlockwiseTmaWarpSpecialized1SmSm100#BlockwiseTmaWarpSpecialized2SmSm100+PtrArrayBlockwiseTmaWarpSpecialized1SmSm100+PtrArrayBlockwiseTmaWarpSpecialized2SmSm100Mxf4TmaWarpSpecialized1SmSm100Mxf4TmaWarpSpecialized2SmSm100Nvf4TmaWarpSpecialized1SmSm100Nvf4TmaWarpSpecialized2SmSm100*Mxf8f6f4TmaWarpSpecializedCooperativeSm120'Mxf8f6f4TmaWarpSpecializedPingpongSm120&Nvf4TmaWarpSpecializedCooperativeSm120#Nvf4TmaWarpSpecializedPingpongSm120&Mxf4TmaWarpSpecializedCooperativeSm120#Mxf4TmaWarpSpecializedPingpongSm120.F8f6f4SparseTmaWarpSpecializedCooperativeSm120+BlockwiseTmaWarpSpecializedCooperativeSm120(BlockwiseTmaWarpSpecializedPingpongSm120r   r   r   r   r     sr    
r   z-cutlass::gemm::collective::KernelScheduleAutozcutlass::gemm::KernelMultistagez+cutlass::gemm::KernelCpAsyncWarpSpecializedz3cutlass::gemm::KernelCpAsyncWarpSpecializedPingpongz6cutlass::gemm::KernelCpAsyncWarpSpecializedCooperativezcutlass::gemm::KernelTmaz'cutlass::gemm::KernelTmaWarpSpecializedz/cutlass::gemm::KernelTmaWarpSpecializedPingpongz2cutlass::gemm::KernelTmaWarpSpecializedCooperativez3cutlass::gemm::KernelTmaWarpSpecializedFP8FastAccumz>cutlass::gemm::KernelTmaWarpSpecializedCooperativeFP8FastAccumz;cutlass::gemm::KernelTmaWarpSpecializedPingpongFP8FastAccumz3cutlass::conv::KernelImplicitTmaWarpSpecializedSm90zEcutlass::gemm::KernelTmaWarpSpecializedCooperativeFP8BlockScaledAccumz/cutlass::gemm::KernelTmaWarpSpecialized1SmSm100z/cutlass::gemm::KernelTmaWarpSpecialized2SmSm100z7cutlass::conv::KernelImplicitTmaWarpSpecialized1SmSm100z7cutlass::conv::KernelImplicitTmaWarpSpecialized2SmSm100z7cutlass::gemm::KernelPtrArrayTmaWarpSpecialized1SmSm100z7cutlass::gemm::KernelPtrArrayTmaWarpSpecialized2SmSm100z5cutlass::gemm::KernelSparseTmaWarpSpecialized1SmSm100z5cutlass::gemm::KernelSparseTmaWarpSpecialized2SmSm100z:cutlass::gemm::KernelTmaWarpSpecialized1SmBlockScaledSm100z:cutlass::gemm::KernelTmaWarpSpecialized2SmBlockScaledSm100z7cutlass::gemm::KernelTmaWarpSpecialized1SmMxf8f6f4Sm100z7cutlass::gemm::KernelTmaWarpSpecialized2SmMxf8f6f4Sm100z8cutlass::gemm::KernelTmaWarpSpecializedBlockwise1SmSm100z8cutlass::gemm::KernelTmaWarpSpecializedBlockwise2SmSm100z@cutlass::gemm::KernelPtrArrayTmaWarpSpecializedBlockwise1SmSm100z@cutlass::gemm::KernelPtrArrayTmaWarpSpecializedBlockwise2SmSm100z3cutlass::gemm::KernelTmaWarpSpecialized1SmMxf4Sm100z3cutlass::gemm::KernelTmaWarpSpecialized2SmMxf4Sm100z3cutlass::gemm::KernelTmaWarpSpecialized1SmNvf4Sm100z3cutlass::gemm::KernelTmaWarpSpecialized2SmNvf4Sm100z:cutlass::gemm::KernelPtrArrayTmaWarpSpecializedCooperativezFcutlass::gemm::KernelPtrArrayTmaWarpSpecializedCooperativeFP8FastAccumz7cutlass::gemm::KernelPtrArrayTmaWarpSpecializedPingpongzCcutlass::gemm::KernelPtrArrayTmaWarpSpecializedPingpongFP8FastAccumzMcutlass::gemm::KernelPtrArrayTmaWarpSpecializedCooperativeFP8BlockScaledAccumzBcutlass::gemm::KernelPtrArrayTmaWarpSpecialized1SmBlockScaledSm100zBcutlass::gemm::KernelPtrArrayTmaWarpSpecialized2SmBlockScaledSm100z;cutlass::gemm::KernelPtrArrayTmaWarpSpecialized1SmNvf4Sm100z;cutlass::gemm::KernelPtrArrayTmaWarpSpecialized2SmNvf4Sm100z;cutlass::gemm::KernelPtrArrayTmaWarpSpecialized1SmMxf4Sm100z;cutlass::gemm::KernelPtrArrayTmaWarpSpecialized2SmMxf4Sm100z?cutlass::gemm::KernelPtrArrayTmaWarpSpecialized1SmMxf8f6f4Sm100z?cutlass::gemm::KernelPtrArrayTmaWarpSpecialized2SmMxf8f6f4Sm100z4cutlass::gemm::KernelTmaWarpSpecializedMxf8f6f4Sm120z<cutlass::gemm::KernelTmaWarpSpecializedPingpongMxf8f6f4Sm120z0cutlass::gemm::KernelTmaWarpSpecializedNvf4Sm120z8cutlass::gemm::KernelTmaWarpSpecializedPingpongNvf4Sm120z0cutlass::gemm::KernelTmaWarpSpecializedMxf4Sm120z8cutlass::gemm::KernelTmaWarpSpecializedPingpongMxf4Sm120z.cutlass::gemm::KernelScheduleSparseF8f6f4Sm120z@cutlass::gemm::KernelTmaWarpSpecializedBlockwiseCooperativeSm120z=cutlass::gemm::KernelTmaWarpSpecializedBlockwisePingpongSm120 _cpasync_cpasync_warpspecialized!_cpasync_warpspecialized_pingpong$_cpasync_warpspecialized_cooperative_unspecialized_warpspecialized_warpspecialized_pingpong_warpspecialized_cooperative_warpspecialized_fp8_fastaccum*_warpspecialized_cooperative_fp8_fastaccum'_warpspecialized_pingpong_fp8_fastaccum_1sm_2sm_q_1sm_q_2sm_o_vs32_1sm_o_vs32_2sm_o_vs16_1sm_o_vs16_2sm_cooperative_q_pingpong_q_cooperative_o_vs16_pingpong_o_vs16_cooperative_o_vs32_pingpong_o_vs32_qc                   @   sl   e Zd Ze Ze Ze Ze Ze Ze Z	e Z
e Ze Ze Ze Ze Ze Ze Ze Ze ZdS )EpilogueScheduleTypeN)r   r   r   r	   r   EpilogueTransposedNoSmemWarpSpecializedPtrArrayNoSmemWarpSpecializedNoSmemWarpSpecialized1SmNoSmemWarpSpecialized2Sm PtrArrayNoSmemWarpSpecialized1Sm PtrArrayNoSmemWarpSpecialized2Smr   r   TmaWarpSpecialized1SmTmaWarpSpecialized2SmPtrArrayTmaWarpSpecialized1SmPtrArrayTmaWarpSpecialized2Smr   r   r   r   r   r   r    s"    
r  z3cutlass::epilogue::collective::EpilogueScheduleAutoz!cutlass::gemm::EpilogueTransposedz(cutlass::epilogue::NoSmemWarpSpecializedz0cutlass::epilogue::PtrArrayNoSmemWarpSpecializedz+cutlass::epilogue::NoSmemWarpSpecialized1Smz+cutlass::epilogue::NoSmemWarpSpecialized2Smz3cutlass::epilogue::PtrArrayNoSmemWarpSpecialized1Smz3cutlass::epilogue::PtrArrayNoSmemWarpSpecialized2Smz%cutlass::epilogue::TmaWarpSpecializedz0cutlass::epilogue::TmaWarpSpecializedCooperativez(cutlass::epilogue::TmaWarpSpecialized1Smz(cutlass::epilogue::TmaWarpSpecialized2Smz0cutlass::epilogue::PtrArrayTmaWarpSpecialized1Smz0cutlass::epilogue::PtrArrayTmaWarpSpecialized2Smz8cutlass::epilogue::PtrArrayTmaWarpSpecializedCooperativez5cutlass::epilogue::PtrArrayTmaWarpSpecializedPingpong_epi_nosmem_epi_tma_tma_1sm_tma_2smc                   @   rY   )EpilogueFunctor3xN)r   r   r   r	   LinearCombination!LinearCombinationBlockScaleFactorr   r   r   r   r    r]   r  z,cutlass::epilogue::fusion::LinearCombinationz2cutlass::epilogue::fusion::LinCombBlockScaleFactorc              
   C   s,   | t jt jt jt jt jt jt jt jt j	f	v S ra   )
r  r   r   r   r  r  r  r  r   r   )epilogue_schedule_typer   r   r   is_tma_epilogue  s   r   c                 C   s   |s| S i t jt jt jt jt jt jt jt jt j	t j
tjtjtjtjtjtjt jt jt jt jt jt jt jt jt jt jt jt jt jt jt jt jt jt j t j!t j"tj#tj$tj%tj&i}||  S ra   )'r   r   r   r   r   r   r   r   r   r   r   r  r   r  r  r   r   r   r   r   r   r   r   r   r   r   r   r   r   r   r   r   r   r   r   r  r  r  r  )schedulegroupedgroup_schedule_mapr   r   r   to_grouped_schedule  sT   	r$  c                   @      e Zd Ze Ze Ze ZdS )TileSchedulerTypeN)r   r   r   r	   Default
PersistentStreamKr   r   r   r   r&  =      
r&  z"cutlass::gemm::PersistentSchedulerzcutlass::gemm::StreamKScheduler	_stream_kc                   @   rY   )SideModeN)r   r   r   r	   LeftRightr   r   r   r   r,  U  r]   r,  zcutlass::SideMode::kLeftzcutlass::SideMode::kRightlsrsc                   @   rY   )FillModeN)r   r   r   r	   LowerUpperr   r   r   r   r1  g  r]   r1  zcutlass::FillMode::kLowerzcutlass::FillMode::kUpperluc                   @   rY   )DiagTypeN)r   r   r   r	   NonUnitUnitr   r   r   r   r6  y  r]   r6  zcutlass::DiagType::kNonUnitzcutlass::DiagType::kUnitnuunc                   @   s*   e Zd Ze Ze Ze Ze Ze ZdS )OpcodeClassN)	r   r   r   r	   SimtTensorOpWmmaTensorOpSparseTensorOpBlockScaledTensorOpr   r   r   r   r;    s    
r;  simttensoropwmma_tensorop
sptensorop
bstensoropzcutlass::arch::OpClassSimtzcutlass::arch::OpClassTensorOpz"cutlass::arch::OpClassWmmaTensorOpz$cutlass::arch::OpClassSparseTensorOpz)cutlass::arch::OpClassBlockScaledTensorOpc                   @   s6   e Zd Ze Ze Ze Ze Ze Ze Z	e Z
dS )OperationKindN)r   r   r   r	   GemmRankKRank2KTrmmSymmConv2dConv3dr   r   r   r   rF    s    
rF  gemmrank_krank_2ktrmmsymmconv2dconv3dc                   @   r
   )TargetN)r   r   r   r	   r   r   r   r   r   rU    r   rU  maxwellpascalvoltaturingampereadahopper)2   <   =   F   K   P   Y   Z   `      c      )r`  H   ra  rb  V   W   rc  rd  c                 C   sP   | }d}|r&d}|  D ]\}}d| }t|||}||kr!d}|}q|s|S )NTFz\$\{%s\})itemsresub)templatevaluestextchangedkeyvalueregexnewtextr   r   r   SubstituteTemplate  s   rw  c                   @   sZ   e Zd Ze Ze Ze Ze Ze Ze Z	e Z
e Ze Ze Ze Ze Ze ZdS )rk   N)r   r   r   r	   rG  Sparse	UniversalUniversal3xSparseUniversal3xPlanarComplexPlanarComplexArrayGroupedrl   ru   rm   rr   rs   r   r   r   r   rk     s    
rk   spgemmgemm_planar_complexgemm_planar_complex_arraygemm_groupedc                   @   r
   )	RankKKindNr   r   r   r	   ry  r   r   r   r   r    r   r  c                   @   r
   )TrmmKindNr  r   r   r   r   r    r   r  c                   @   r
   )SymmKindNr  r   r   r   r   r  %  r   r  c                   @   rY   )EpilogueFunctorN)r   r   r   r	   r  LinearCombinationClampr   r   r   r   r  .  r]   r  z,cutlass::epilogue::thread::LinearCombinationz1cutlass::epilogue::thread::LinearCombinationClampc                   @   r%  )MixedInputModeN)r   r   r   r	   ConvertOnly	ScaleOnlyScaleWithZeroPointr   r   r   r   r  ;  r*  r  c                   @   sB   e Zd Ze Ze Ze Ze Ze Ze Z	e Z
e Ze ZdS )SwizzlingFunctorN)r   r   r   r	   	Identity1	Identity2	Identity4	Identity8
HorizontalStridedDgradIdentity1StridedDgradIdentity4StridedDgradHorizontalr)  r   r   r   r   r  B  s    
r  z=cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<1>z=cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<2>z=cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<4>z=cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<8>z<cutlass::gemm::threadblock::GemmHorizontalThreadblockSwizzlezEcutlass::conv::threadblock::StridedDgradIdentityThreadblockSwizzle<1>zEcutlass::conv::threadblock::StridedDgradIdentityThreadblockSwizzle<4>zDcutlass::conv::threadblock::StridedDgradHorizontalThreadblockSwizzlez5cutlass::gemm::threadblock::ThreadblockSwizzleStreamKc                   @   s   e Zd Ze fZe ZdS )GroupScheduleModeN)r   r   r   r	   DeviceHostr   r   r   r   r  ]  s    
r  z5cutlass::gemm::kernel::GroupScheduleMode::kDeviceOnlyz9cutlass::gemm::kernel::GroupScheduleMode::kHostPrecomputer  r  c                   @      e Zd ZdZdZdZdS )ConvKindr   r   rQ   N)r   r   r   FpropDgradWgradr   r   r   r   r  r      r  zcutlass::conv::Operator::kFpropzcutlass::conv::Operator::kDgradzcutlass::conv::Operator::kWgradfpropdgradwgradc                   @   s   e Zd ZdZdZdS )ConvModer   r   N)r   r   r   CrossCorrelationConvolutionr   r   r   r   r    s    r  c                   @       e Zd ZdZdZdZdZdZdS )IteratorAlgorithmr   r   rQ      rR   N)r   r   r   Analytic	OptimizedFixedChannelsFewChannelsFixedStrideDilationr   r   r   r   r    s    r  z+cutlass::conv::IteratorAlgorithm::kAnalyticz,cutlass::conv::IteratorAlgorithm::kOptimizedz0cutlass::conv::IteratorAlgorithm::kFixedChannelsz.cutlass::conv::IteratorAlgorithm::kFewChannelsz6cutlass::conv::IteratorAlgorithm::kFixedStrideDilationanalytic	optimizedfixed_channelsfew_channelsfixed_stride_dilationc                   @   r  )StrideSupportr   r   rQ   N)r   r   r   StridedUnityFixedr   r   r   r   r    r  r  z&cutlass::conv::StrideSupport::kStridedz$cutlass::conv::StrideSupport::kUnityz$cutlass::conv::StrideSupport::kFixedunity_stridefixed_stridec                   @   s$   e Zd Ze Ze Ze Ze ZdS )	GroupModeN)r   r   r   r	   	NoneGroupSingleGroupMultipleGroup	Depthwiser   r   r   r   r    s
    
r  zcutlass::conv::GroupMode::kNonez&cutlass::conv::GroupMode::kSingleGroupz(cutlass::conv::GroupMode::kMultipleGroupz$cutlass::conv::GroupMode::kDepthwisesingle_groupmultiple_group	depthwise)r   r   r   c                   @   s   e Zd ZejdfddZdS )MathInstructionNc                 C   .   || _ || _|| _|| _|| _|| _|| _d S ra   )instruction_shape	element_a	element_belement_accumulatoropcode_classmath_operationelement_scale_factor)selfr  r  r  r  r  r  r  r   r   r   __init__     

zMathInstruction.__init__)r   r   r   r   r   r  r   r   r   r   r    s    r  c                   @   s"   e Zd Z		dddZdd ZdS )TileDescriptionr   r   r   Nc	           	      C   s:   || _ || _|| _|| _|| _|| _|| _|| _|| _d S ra   )	threadblock_shape
tile_shapestages
warp_countmath_instructionminimum_compute_capabilitymaximum_compute_capabilitycluster_shapeexplicit_vector_sizes)	r  r  r  r  r  min_computemax_computer  r  r   r   r   r    s   
zTileDescription.__init__c              	   C   sl   | j dkr$dj| jd | jd | jd | jd | jd | jd | jdS d| jd | jd | jd | jf S )Nrd  z${tbm}x{tbn}x{tbk}_{cm}x{cn}x{ck}_{s}r   r   rQ   )tbmtbntbkcmcnckrC   z%dx%d_%dx%d)r  formatr  r  r  )r  r   r   r   procedural_name  s    

zTileDescription.procedural_name)r  Nr   r   r   r  r  r   r   r   r   r    s
    	
r  c                   @   s   e Zd Zdd Zdd ZdS ).Direct2dConvFixedStrideDilationTileDescriptionc
           
      C   sj   |d |d  |d  |d |d |d  g| _ || _|| _|| _|| _|| _|| _|| _|| _|	| _	d S )Nr   r   rQ   r  )
r  threadblock_output_shapefilter_shaper  r  stridedilationr  r  r  )
r  r  r  r  r  r  r  r  r  r  r   r   r   r    s"   
z7Direct2dConvFixedStrideDilationTileDescription.__init__c                 C   s   d| j d | j d | j d | jd | jd | jd | jd | j| jd | jd f
 }| jddgkrN| jddgkrN|d| jd | jd | jd | jd f 7 }|S )Nz#%dx%dx%d_%dx%dx%dx%d_%d_filter%dx%dr   r   rQ   r  z_stride%dx%d_dilation%dx%d)r  r  r  r  r  r  )r  str_namer   r   r   r  5  s(   z>Direct2dConvFixedStrideDilationTileDescription.procedural_nameNr  r   r   r   r   r    s    r  c                   @      e Zd ZdejfddZdS )TensorDescriptionr   c                 C   s   || _ || _|| _|| _d S ra   )elementlayout	alignmentcomplex_transform)r  r  r  r  r  r   r   r   r  O  s   
zTensorDescription.__init__Nr   r   r   r^   r_   r  r   r   r   r   r  N  s    r  c                   @   s    e Zd ZdejejfddZdS )SymmetricTensorDescriptionr   c                 C   s(   || _ || _|| _|| _|| _|| _d S ra   )r  r  	fill_moder  r  	side_mode)r  r  r  r  r  r  r  r   r   r   r  Z  s   	
z#SymmetricTensorDescription.__init__N)r   r   r   r^   r_   r,  r-  r  r   r   r   r   r  Y  s
    r  c                   @   r  )TriangularTensorDescriptionr   c                 C   r  ra   )r  r  r  r  	diag_typer  r  )r  r  r  r  r  r  r  r  r   r   r   r  m  r  z$TriangularTensorDescription.__init__Nr  r   r   r   r   r  l  s    r  c                 C   s"  | j j}| j j}| jtjkr[| jtjkr[t	| j
j dkrd}nt	| j
j dkr*d}nd}t	| j
j |d  |d d  d t	| jj |d  |d  d  |d |d d  |  }n.t	| j
j }t	| j
j }|  rqt	| jj }||d  |d  d ||d  |d  d  }|| }|d? S )NrU   rQ   rR   rS   r   r   
   )tile_descriptionr  r  operation_kindrF  rG  ro   rk   rx  r~   Ar  Bis_mixed_input)	operation	cta_shaper  elements_per_8b_mdsmem_per_stagedata_type_size_adata_type_size_b
smem_usager   r   r   CalculateSmemUsage  s2   "r  c                   @   r  )GemmUniversalModez2
    Types corresponding to GemmUniversalMode
    r   r   rQ   r  N)r   r   r   __doc__rG  GemmSplitKParallelBatchedArrayr   r   r   r   r    s    r  c                   @   s   e Zd ZdZdZdZdZdS )
SplitKModez+
    Types corresponding to SplitKMode
    r   r   rQ   N)r   r   r   r  
NoneSplitKSerialParallelr   r   r   r   r    s
    r  (?  r  enumrm  r   r	   ImportErrorr   intEnumr   r   GeneratorTargetNamesr   r   r!   r"   r+   r-   r/   r2   r4   r#   r$   r%   ShortDataTypeNamesr   r   r   r   r   r   r   r   r   r   r   r   r    r'   r&   r(   r)   r*   r,   r.   r0   r1   r3   r;   r<   r=   r>   r?   r@   r5   r6   r7   r8   r9   r:   DataTypeNamesDataTypeTagr~   rZ   r[   r\   BlasModeTagr^   r_   r`   ComplexTransformTagComplexTransformTag3xrh   ri   rp   rt   rv   rz   r}   r   r   r   r   r   r   r   r   r   r   r   r   r   r   r   MathOperationTagr   r   r   r   r   r   r   r   r   r   r   r   r   r   r   r   r   r   r   r   r   	LayoutTagTransposedLayoutShortLayoutTypeNamesShortComplexLayoutNamesr   r   r   r   r   r   r   r   r   r   r   r   r   r   r   r   r   r   r   r   r   r   r   r   r   r   r   r   r   r   r   r   r   r   r   r   r   r   r   r   r   r   r   r   r   r   r   r   r   r   r   r   r   r   r   r   r   KernelScheduleTagKernelScheduleSuffixesr  r  r  r  r  r  r  r  r  r  r  r  EpilogueScheduleTagEpilogueScheduleSuffixesr  r  r  EpilogueFunctor3xTagr   r$  r&  r'  r(  r)  TileSchedulerTagTileSchedulerSuffixesr,  r-  r.  SideModeTagShortSideModeNamesr1  r2  r3  FillModeTagShortFillModeNamesr6  r7  r8  DiagTypeTagShortDiagTypeNamesr;  r<  r=  r>  r?  r@  OpcodeClassNamesOpcodeClassTagrF  rG  rH  rI  rJ  rK  rL  rM  OperationKindNamesrU  ArchitectureNamesSharedMemPerCCrw  rk   rx  ry  rz  r{  r|  r}  r~  rl   ru   rm   rr   rs   GemmKindNamesr  RankKKindNamesr  TrmmKindNamesr  SymmKindNamesr  r  EpilogueFunctorTagr  r  r  r  r  r  r  r  r  r  SwizzlingFunctorTagr  r  r  GroupScheduleModeTagShortGroupScheduleModeNamesIntEnumr  r  r  r  ConvKindTagConvKindNamesr  r  r  r  r  r  r  IteratorAlgorithmTagIteratorAlgorithmNamesr  r  r  r  StrideSupportTagStrideSupportNamesr  r  r  r  r  GroupModeTagGroupModeNamesDynamicClusterShaper  r  r  r  r  r  r  r  r  r   r   r   r   <module>   s   
4	
 !"#1	
 !"#1	
 !"#4


	
		
	
	G	
 !"#$%&'()*+,-./01234<	
 !"#$%&'()*+,-./01234<	
	
 
		

	
	
	*7% 