o
    XiZ                     @   sf   d Z ddlZddlZddlmZ ddlmZ ddlm	Z	m
Z
 dd Zdd	 ZG d
d dZdd ZdS )z
A one-layer SmolLM model test case, with inputs: input_ids, position_ids, and pask key/values.
This is an onnxscript version of the model.
    N)script)opset18)FLOATINT64c                    sl   t  dtd dtd dtd dtd dtd td	 td	 ff
 	
fd
d}| }|S )N	input_ids      position_idspast_key_values_0_0)r          @   past_key_values_0_1returnr   r	      )r   r   .   r   c           s        s"  t j | dd}t dd}t dd|}t jddd}t jd	d
gdd}t ||}	t jdd}
t j|	|
dd}t dd
}t dd
}t |d
|}t jddgdd}t j||dd}||k}t j|dd}|| }t jdgd}t j|dd}t 	|}t jddd}t jdgd}t j||dd}t jddd}t jdgd}t j||dd}t jddd}t jdgd}t j||dd}t jdgd} t 
||||| }!t jdgd}"t j|"dd}#t 	|!|#}$t j|$dd}%t jg ddd}&t |&}'t |%|'}(t jddd})t jdgd}*t j|)|*dd}+t jddd},t jdgd}-t j|,|-dd}.t jddd}/t jdgd}0t j|/|0dd}1t jdgd}2t 
||+|.|1|2}3t jdgd}4t j|4dd}5t 	|3|5}6t jddd}7t jdgd}8t j|7|8dd}9t jddd}:t jdgd};t j|:|;dd}<t jddd}=t jdgd}>t j|=|>dd}?t jdgd}@t 
|6|9|<|?|@}At j|Add}Bt j|(dd}Ct j|Bdd}Dt jg ddd}Et |E}Ft |C|F}Gt jg ddd}Ht j|G|Hdd}It jg ddd}Jt |J}Kt |D|K}Lt jg ddd}Mt j|L|Mdd}N|I|N }Ot jg ddd}Pt j|O|Pdd}Qt j|Qg dd}Rt j|R|Rdd}St |S}Tt |S}U|Td }V|Ud }Wt j|Vdd}Xt j|Wdd}Yt j|dd}Zt jddd}[|Z|[ }\t jdgd}]t jdg|]dd}^t j|\|^ddd}_|_d }`t |`}at |a}b|Z|b }ct j|cdd}d|d }et jddgd}ft jd	dgdd}gt j|e|gdd}h|h|f }it jg ddd}jt j|i|jdd}kt jddgd}lt jd	dgdd}mt j|e|mdd}n|n|l }ot jg ddd}pt j|o|pdd}qt j	ddgd}rt jd	dgdd}st j|e|sdd}t|t|r }ut jg ddd}vt j|u|vdd}wt jg ddd}xt j|k|xdd}yt j|yg dd}zt jg ddd}{t j|q|{dd}|t j||g dd}}t jg ddd}~t j|w|~dd}t j|g dd}t jdgd}t j|dd}t 	|X|}t jdgd}t j|dd}t 	|Y|}|z| }t jddd}t jdgd}t j||dd}t jd dd}t jdgd}t j||dd}t jd!dd}t jdgd}t j||dd}t jdgd}t 
|z||||}t jd dd}t jdgd}t j||dd}t jddd}t jdgd}t j||dd}t jd!dd}t jdgd}t j||dd}t jdgd}t 
|z||||}t |}t j||dd}|| }|| }|}| }t jddd}t jdgd}t j||dd}t jd dd}t jdgd}t j||dd}t jd!dd}t jdgd}t j||dd}t jdgd}t 
|}||||}t jd dd}t jdgd}t j||dd}t jddd}t jdgd}t j||dd}t jd!dd}t jdgd}t j||dd}t jdgd}t 
|}||||}t |}t j||dd}|| }|| }t j||d"d}t j||d"d}t jdgd}t j|dd}t 	||}t jdgd}t j|dd}t 	||á}t jddd}t jdgd}t j||dd}t jddd}t jdgd}t j||dd}t jddd}t jdgd}t j||dd}t jdgd}t 
|||||Ρ}t jddd}t jdgd}t j||dd}t jddd}t jdgd}t j||dd}t jd!dd}t jdgd}t j||dd}t jdgd}t 
|||||١}t jg d#dd}t |ۡ}t ||ܡ}t jddd}t jdgd}t j||dd}t jddd}t jdgd}t j||dd}t jddd}t jdgd}t j||dd}t jdgd}t 
|||||}t jddd}t jdgd}t j||dd}t jddd}t jdgd}t j||dd}t jddd}t jdgd}t j||dd}t jdgd}t 
|||||}t jddd}t jdgd}t j||dd}t jddd}t jdgd}t j||dd}t jddd}t jdgd}t j||dd}t jdgd}t 
|||||}t jddd}t jdgd} t j|| dd}t jddd}t jdgd}t j||dd}t jd!dd}t jdgd}t j||dd}t jdgd}t 
|||||}	t j|dd$}
t jdgd}t j|
|dd}t ||}t jdd%}t ||}t |}|| }t ||}t j|dd$}t jdgd}t 
|dg|}t 
|d"gdg}t jd&gd}t 
||d"g}t jdgd}t j|||dd}t j||dd}t j|g dd}t j|||dd}t j||dd}t |}|| } t |}!||! }"| |" }#|#|	 }$t j|$dd}%t |%d\}&}'|&| }(t j|dd$})t 
|)dgdg}*t 
|)dgdg}+t 
|)d"gdg},t j|+dd}-|-d' }.t |.}/|/d' }0t j|0dd}1t j|*|,|1dd}2t d|2}3t j|(g dd}4t jg d(dd}5t j|4|5dd}6t jddgd}7t jd	dgdd}8t j|6|8dd}9|9|7 }:t jg ddd};t j|:|;dd}<||< }=t j|=dd}>t jddd}?|>|? }@t jdgd}At jdg|Add}Bt j|@|Bddd}C|Cd }Dt |D}Et |E}F|>|F }Gt j|Gdd}H|H }It jddgd}Jt jd	dgdd}Kt j|I|Kdd}L|L|J }Mt jg d)dd}Nt j|M|Ndd}Ot |O}P|O|P }Qt jddgd}Rt jd	dgdd}St j|I|Sdd}T|T|R }Ut jg d)dd}Vt j|U|Vdd}W|Q|W }Xt jddgd}Yt jd	d*gdd}Zt j|X|Zdd}[|[|Y }\t jg ddd}]t j|\|]dd}^|=|^ }_t j|_dd}`t jddd}a|`|a }bt jdgd}ct jdg|cdd}dt j|b|dddd}e|ed }ft |f}gt |g}h|`|h }it j|idd}j
|j }kt j ddgd}lt jd	dgdd}mt j|k|mdd}n|n|l }ot jg d+dd}pt j|o|pdd}qt j|qdd}r|r||fS ),Nr   )axisg      ?r   r   gMr   )tor	   /      )	value_int)upperg        )	allowzero)
value_intsl       )r   r   r   )r   r   r   )r   r   r	   )r   r   r	   )r   r   r   )perm)keepdimsnoop_with_empty_axesgh㈵>   )r   r	   r!   )r   r	   r   r   )r   r   r      r   r"   )r   r   r   r   )start)value_floatl         g      @@)r   r	   r   )r   r	       r&   r   )r   GatherCastLikeRangeCastExpandConstantTriluReshape	UnsqueezeSliceAbs	TransposeConcatCosSin
ReduceMeanSqrt
ReciprocalNegShapeSoftmaxDropoutCeilSigmoid(s  r   r
   r   r   	embeddingval_2arangeval_5val_7fulldiagonal__1triuval_10val_11arange_1val_13viewgtconvert_element_type_defaultmuldim__2dim_0__2	unsqueezeval_15val_16val_17val_19val_20val_21val_23val_24val_25val_26slice_1dim__3dim_0__3unsqueeze_1_to_copy	size_0__4	size_1__4expandval_28val_29val_30val_31val_32val_33val_34val_35val_36val_37slice_2dim__5dim_0__5unsqueeze_2val_38val_39val_40val_41val_42val_43val_45val_46val_47val_48slice_3
_to_copy_1
_to_copy_2
_to_copy_3	size_0__6	size_1__6expand_1val_50view_1	size_0__7	size_1__7expand_2val_52view_2bmmval_54view_3	transposecatcossinmul_1mul_2
_to_copy_4
_to_copy_5
_to_copy_6scalar_tensor_defaultpow_1val_55val_57meanaddval_59rsqrtmul_3
_to_copy_7mul_4tval_61view_4mmval_63view_5t_1val_64view_6mm_1val_65view_7t_2val_66view_8mm_2val_67view_9val_69view_10transpose_1val_70view_11transpose_2val_71view_12transpose_3dim__8dim_0__8unsqueeze_3dim__9dim_0__9unsqueeze_4mul_5val_72val_73val_74val_76val_77val_78val_80val_81val_82val_83slice_4val_84val_85val_86val_87val_88val_89val_90val_91val_92val_93slice_5negcat_1mul_6add_1mul_7val_94val_95val_96val_97val_98val_99val_100val_101val_102val_103slice_6val_104val_105val_106val_107val_108val_109val_110val_111val_112val_113slice_7neg_1cat_2mul_8add_2cat_3cat_4dim__10	dim_0__10unsqueeze_5dim__11	dim_0__11unsqueeze_6val_114val_115val_116val_117val_118val_119val_120val_121val_122val_123slice_8val_124val_125val_126val_127val_128val_129val_130val_131val_132val_133slice_9
size_0__12
size_1__12expand_3val_135val_136val_137val_138val_139val_140val_141val_142val_143val_144slice_10val_145val_146val_147val_148val_149val_150val_151val_152val_153val_154slice_11val_155val_156val_157val_158val_159val_160val_161val_162val_163val_164slice_12val_165val_166val_167val_168val_169val_170val_171val_172val_173val_174slice_13val_175val_176val_177val_178val_179val_180val_181val_182val_183val_184val_185val_186val_188val_189val_190val_191val_192val_193val_194val_195val_196val_197val_198val_199val_200val_201val_202val_203val_204_unusedgetitemval_206val_209val_211val_212val_213val_215val_216val_217val_218val_219._scaled_dot_product_flash_attention_for_cpu__1transpose_4val_221view_13t_3val_222view_14mm_3val_223view_15add_3
_to_copy_8scalar_tensor_default_1pow_2val_224val_225mean_1add_4val_226rsqrt_1mul_9
_to_copy_9mul_10t_4val_227view_16mm_4val_229view_17val_230silut_5val_231view_18mm_5val_232view_19mul_11t_6val_234view_20mm_6val_235view_21add_5_to_copy_10scalar_tensor_default_2pow_3val_236val_237mean_2add_6val_238rsqrt_2mul_12_to_copy_11mul_13t_7val_239view_22mm_7val_241view_23_to_copy_12lm_head_weight%model_layers_0_input_layernorm_weight#model_layers_0_mlp_down_proj_weight#model_layers_0_mlp_gate_proj_weight!model_layers_0_mlp_up_proj_weight.model_layers_0_post_attention_layernorm_weight&model_layers_0_self_attn_k_proj_weight&model_layers_0_self_attn_o_proj_weight&model_layers_0_self_attn_q_proj_weight&model_layers_0_self_attn_v_proj_weightmodel_norm_weightmodel_rotary_emb_inv_freq X/home/ubuntu/.local/lib/python3.10/site-packages/onnxscript/rewriter/models/_smollm_2.py
main_graph   s  









zmake_model.<locals>.main_graph)r   r   r   to_model_proto)r  r  r  r  r  r  r  r  r  r  r  r  r  modelr  r  r  
make_model   s    &  xr  c                  C   s"  t jdt j} t jdt j}t jdt j}t jddt j}t jddt j}t jddt j}t jddt j}t jddt j}t jddt j}t jddt j}	t jddt j}
t jdt j}t| |||||||||	|
|}|S )Nr!   r   r&   r   )numpyrandomrandastypefloat32r  )r  r  r  r  r  r  r  r  r  r  r  r  r  r  r  r  make_model_with_random_weights  sJ   r  c                   @   s   e Zd Zdd Zdd ZdS )_SmollmTest2c                 C   s(   t | dst }tj|}|| _| jS )N_onnx_model)hasattrr  irserdedeserialize_modelr  )selfmodel_protor  r  r  r  get_onnx_model  s
   
z_SmollmTest2.get_onnx_modelc              	   C   sz   t | ds:tjdddtjtdddtjtjdddd	tj	tjdddd	tj	d
}|| _
| j
S )N_ort_inputsr   r   r   r	   r   r   r   r   )r   r
   r   r   )r  r  r  randintr  int64rA   reshaper  r  r  )r  inputsr  r  r  get_ort_inputs  s   
z_SmollmTest2.get_ort_inputsN)__name__
__module____qualname__r  r  r  r  r  r  r    s    r  c                   C   s   t  S )N)r  r  r  r  r  smollm_test_2  s   r  )__doc__r  onnx_irr  
onnxscriptr   onnxscript.onnx_opsetr   onnxscript.onnx_typesr   r   r  r  r  r  r  r  r  r  <module>   s      (