o
    ٷil                     @   s~  d dl Z d dlZd dlZd dlZd dlZd dlZd dlZd dlmZ d dl	Z	d dl
Z
d dlZd dlmZ d dlmZ d dlmZ d dlmZ d dlmZ d dlmZ d d	lmZ d d
lmZmZmZmZ e e Z!g dZ"ej#dej$dej%diZ&G dd deZ'G dd deZ(G dd deZ)G dd deZ*G dd deZ+e*ddfe+ddfe)ddfdZ,G dd dZ-G d d! d!Z.dS )"    N)Path)	Precision)float_to_float16_max_diff)FusionOptions)IOBindingHelper)	OnnxModel)optimize_model)torch_onnx_export)
GPT2ConfigGPT2LMHeadModel	GPT2ModelTFGPT2Model)
distilgpt2gpt2zgpt2-mediumz
gpt2-largezgpt2-xlMb@?g?g      @c                       ,   e Zd ZdZ fddZ fddZ  ZS )GPT2ModelNoPastState2Here we wrap a class to disable past state output.c                       t  | d S Nsuper__init__selfconfig	__class__ d/home/ubuntu/.local/lib/python3.10/site-packages/onnxruntime/transformers/models/gpt2/gpt2_helper.pyr   *      zGPT2ModelNoPastState.__init__c                    s   t  j|dddS )NF)	use_cachereturn_dict)r   forwardr   	input_idsr   r   r   r#   -   s   zGPT2ModelNoPastState.forward__name__
__module____qualname____doc__r   r#   __classcell__r   r   r   r   r   '       r   c                       r   )TFGPT2ModelNoPastStater   c                    s   d|_ t | d S )NF)r!   r   r   r   r   r   r   r   4   s   zTFGPT2ModelNoPastState.__init__c                    s   t  j|ddS )NF)r!   )r   callr$   r   r   r   r#   8   r    zTFGPT2ModelNoPastState.forwardr&   r   r   r   r   r-   1   s    r-   c                       s8   e Zd ZdZ fddZedd Z fddZ  ZS )MyGPT2ModelzMHere we wrap a class for Onnx model conversion for GPT2Model with past state.c                    r   r   r   r   r   r   r   r   ?   r    zMyGPT2Model.__init__c                 C   s   t | d d ttfrNt| d |krt| d d dksJ g }t|D ] }|tj| d | d d| d | d dfdd q%| d t|fS | S )N   r      )dim)	
isinstancetuplelistlenrangeappendtorchcat	unsqueeze)result	num_layerpresentir   r   r   post_processB   s   (*zMyGPT2Model.post_processc                    &   t  j||||dd}t|| jjS NF)position_idsattention_maskpast_key_valuesr"   r   r#   r/   r@   r   n_layerr   r%   rC   rD   pastr<   r   r   r   r#   U   s   zMyGPT2Model.forward)	r'   r(   r)   r*   r   staticmethodr@   r#   r+   r   r   r   r   r/   <   s    
r/   c                       r   )MyGPT2LMHeadModelzSHere we wrap a class for Onnx model conversion for GPT2LMHeadModel with past state.c                    r   r   r   r   r   r   r   r   c   r    zMyGPT2LMHeadModel.__init__c                    rA   rB   rF   rH   r   r   r   r#   f   s   zMyGPT2LMHeadModel.forwardr&   r   r   r   r   rK   `   r,   rK   c                       r   )MyGPT2LMHeadModel_NoPaddinga  Here we wrap a class for Onnx model conversion for GPT2LMHeadModel with past state and no padding.
    When you always use batch_size=1 in inference, there is no padding in inputs. In such case, position_ids
    and attention_mask need no be in inputs.
    c                    r   r   r   r   r   r   r   r   x   r    z$MyGPT2LMHeadModel_NoPadding.__init__c                    s"   t  j||dd}t|| jjS )NF)rE   r"   rF   )r   r%   rI   r<   r   r   r   r#   {   s   z#MyGPT2LMHeadModel_NoPadding.forwardr&   r   r   r   r   rL   r   s    rL   logitsTF
last_state)r   GPT2LMHeadModel_NoPaddingr   c                   @   s8   e Zd Zdd ZdefddZdefddZdd	 Zd
S )
Gpt2Inputsc                 C   s   || _ || _|| _|| _d S r   )r%   rC   rD   rI   )r   r%   rC   rD   rI   r   r   r   r      s   
zGpt2Inputs.__init__returnc                 C   s0   dd | j | j| jfD }| jr|| j |S )Nc                 S   s   g | ]}|d ur|qS r   r   .0vr   r   r   
<listcomp>       z&Gpt2Inputs.to_list.<locals>.<listcomp>)r%   rC   rD   rI   extend)r   
input_listr   r   r   to_list   s   zGpt2Inputs.to_listc                 C   s"   t dd | j| j| j| jfD S )Nc                 s   s    | ]	}|d ur|V  qd S r   r   rR   r   r   r   	<genexpr>   s    z&Gpt2Inputs.to_tuple.<locals>.<genexpr>)r4   r%   rC   rD   rI   )r   r   r   r   to_tuple   s   "zGpt2Inputs.to_tuplec                 C   sT   d }| j d ur| j jtjkr| j jtjdn| j }dd | jD }t| j| j	||S )Ndtypec                 S   s   g | ]	}|j tjd qS )r\   )tor9   float32rS   pr   r   r   rU      s    z&Gpt2Inputs.to_fp32.<locals>.<listcomp>)
rD   r]   r9   float16r^   r_   rI   rP   r%   rC   )r   rD   rI   r   r   r   to_fp32   s   
zGpt2Inputs.to_fp32N)	r'   r(   r)   r   r5   rY   r4   r[   rc   r   r   r   r   rP      s
    rP   c                "   @   s  e Zd ZdZedddejejejdfdededededed	ed
edejde	de	de	dej
dej
dej
de	def ddZe	dXdedededededeeee f fddZedd ZedYddZedYdd ZedZd"d#Zed[d%d&Zeddddejejejfd'ed(e	d)e	de	de	dej
dej
dej
fd*d+Ze			,d\d-d.Zeg d/fd0ed1ee fd2d3Zed]d4ed5efd6d7Zed]d4ed5efd8d9Zed:d; Zed^d<d=Ze	,		d_d4ed>eeejf d?eeee f d5ed@e	dAe	fdBdCZ edDdE Z!edFdG Z"eddHdHdIdJddddejejejd,ddfdKdLZ#eddMddddejejejdNdJdOfdPdQZ$ed`dRdSZ%edddg dTfdefdUdVZ&dWS )a
Gpt2HelperzEA helper class for Gpt2 model conversion, inference and verification.FT
batch_sizepast_sequence_lengthsequence_lengthnum_attention_headshidden_sizer=   
vocab_sizedevicerb   has_position_idshas_attention_maskinput_ids_dtypeposition_ids_dtypeattention_mask_dtypeleft_side_paddingrQ   c                    s"  |rt jnt jd| ||t|| g fddt|D }t jd|d | |f| d}d}|
rh|| }t j| |g| d}|dkrht| D ]}td|d }|r]d||d|f< qHd|||| df< qHd}|	r| 	d	d }|
|dk d |dd|df |}t||||S )
zCreate random inputs for GPT2 model.
        Returns torch tensors of input_ids, position_ids, attention_mask and a list of past state tensors.
        r1   c                    s$   g | ]}t j d d d qS )r]   rk   g       @      ?)r9   rand)rS   _rk   
float_type
past_shaper   r   rU      s   $ z/Gpt2Helper.get_dummy_inputs.<locals>.<listcomp>r   r0   )lowhighsizer]   rk   Nrr   )r9   rb   r_   intr7   randintonesrandomlongcumsummasked_fill_r^   rP   )re   rf   rg   rh   ri   r=   rj   rk   rb   rl   rm   rn   ro   rp   rq   rI   r%   rD   total_sequence_lengthr?   padding_lengthrC   r   rv   r   get_dummy_inputs   sF   
zGpt2Helper.get_dummy_inputsr   r   model_classc                 C   s~   |j }|j}|j}|j}t| d }	| ||	dkr|n|g}
d| ||| t|| g}|	|
i}t|D ]
}||dt| < q2|S )zAReturns a dictionary with output name as key, and shape as value.r0   rM   r1   present_)rh   ri   num_hidden_layersrj   MODEL_CLASSESr}   r7   str)re   rf   rg   r   r   rh   ri   r=   rj   output_namelast_state_shapepresent_state_shapeoutput_shapesr?   r   r   r   get_output_shapes   s&   	
zGpt2Helper.get_output_shapesc                 C   sZ   |D ](}|| v s
J | | }t || | kr*tjt || |j|jd| |< qd S )Nrr   )numpyprodnelementr9   emptyr]   rk   )output_buffersr   keybufferr   r   r   auto_increase_buffer_size  s   
z$Gpt2Helper.auto_increase_buffer_sizec                 C   sD   |rt jnt j}i }|  D ]\}}t jt|||d||< q|S )zpReturns a dictionary of output name as key, and 1D tensor as value. The tensor has enough space for given shape.rr   )r9   rb   r_   itemsr   r   r   )r   rk   
is_float16	data_typer   nameshaper   r   r   get_output_buffers  s
   zGpt2Helper.get_output_buffersc                 C   sH   | d    }t||d  }|rt|t|d  S t|S )zGReturns the maximum difference between PyTorch and OnnxRuntime outputs.r   ư>)cpur   absamax)torch_outputsort_outputsrelativeexpected_outputsdiffr   r   r   diff_outputs%  s
   
zGpt2Helper.diff_outputsMbP?c           
   	   K   s   t j|d | d    ||d}td|  |}t|d }t|D ])}t j|d|  | d |    ||d}td| d| d|  |oM|}q%|s`t| |}	t	d|	d	 |S )
zReturns True if torch and ORT outputs are close for given thresholds, and False otherwise.
        Note: need kwargs since Gpt2BeamSearchHelper.compare_outputs has an extra parameter model_class
        r   )rtolatolz9PyTorch and OnnxRuntime output 0 (last_state) are close: r0   zPyTorch and OnnxRuntime layer z state (present_z) are close:z@PyTorch and OnnxRuntime results are not all close: max_abs_diff=.5f)
r   allcloser   loggerdebugr6   r7   rd   r   info)
r   r   r   r   kwargsis_closeis_all_close
num_layerslayermax_abs_diffr   r   r   compare_outputs/  s"   "

zGpt2Helper.compare_outputsr   c                 C   s  d}d}g }g }t t|D ]}|| }|dkr| d n| d |d    }	tj||	|dd}
|tt|	|  |oA|
}t|		 rRt
d| d t|		 rbt
d| d t|	 rrt
d	| d t|	 rt
d	| d t||	 }t| |j}|d
|| dd| d|| ddt|	| d |dkrttj|dd|j}ttj|	dd|	j}t||}q|t|}|t||||fS )a  Compare outputs from PyTorch and OnnxRuntime

        Args:
            torch_outputs (Tuple[Torch.Tensor]): PyTorch model output
            ort_outputs (List[numpy.ndarray]): OnnxRuntime output
            atol (float, optional): Absolute tollerance. Defaults to 1e-06.

        Returns:
            is_all_close(bool): whether all elements are close.
            max_abs_diff(float): maximum absolute difference.
            messages(str): a list of debug message for each output
        TFr   r0   )r   r   zPyTorch output z has nanz has infzORT output zdiff=z.9fz index=z ort=z torch=N)axis)r7   r6   r   r   r   r8   r   r   isnananyr   r   isinffabsunravel_indexargmaxr   floatarray_equalindexmax)r   r   r   r   is_top1_matched	max_diffsmessagesr?   
ort_outputtorch_outputr   r   idxort_max_indextorch_max_indexmax_diff_output_indexr   r   r   compare_outputs_v2J  sF   (0zGpt2Helper.compare_outputs_v2onnx_model_pathverboseuse_external_data_formatc
                 C   s  | j }
|
j}tjddd|
j|
j||
j|d|||||	d}| }t	  | | }W d   n1 s3w   Y  dd t
|D }dd t
|D }|d jd	 |
jks`|d jd	 |
jks`J |d jd	 |
jkrld
ndg|}dddd|d dddi}|D ]	}ddd||< q|D ]	}ddd||< qdg}|rddd|d< |d |rddd|d< |d || t|d	krt|d |ksJ td|jj d|jd j d|d j d|d d j  t|jjddd |rAt ;}tj|d}t|jjddd t| t||d|||ddd|d tj|dd} tj | |ddd W d   dS 1 s:w   Y  dS t| t||d|||ddd|d dS ) z1Export GPT-2 model with past state to ONNX model.r0   F)re   rf   rg   rh   ri   r=   rj   rk   rb   rl   rm   rn   ro   rp   Nc                 S      g | ]}d | qS )past_r   rS   r?   r   r   r   rU         z*Gpt2Helper.export_onnx.<locals>.<listcomp>c                 S   r   )r   r   r   r   r   r   rU     r   r   r1   rM   rN   r%   re   seq_len)r   r0   past_seq_len)r0      total_seq_lenrC   rD   zShapes: input_ids=z past=z output=z	 present=T)parentsexist_okz	gpt2.onnx   )
argsfexport_paramsinput_namesoutput_namesdynamic_axesopset_versiondo_constant_foldingr   r   )load_external_data)save_as_external_dataall_tensors_to_one_file)!r   rG   rd   r   rh   ri   rj   rY   r9   no_gradr7   r   r8   rW   r6   r   r   r%   rI   r   parentmkdirtempfileTemporaryDirectoryospathjoinr	   r4   onnx
load_modelr   save)modelrk   r   r   r   rl   rm   rn   ro   rp   r   r=   dummy_inputsrX   outputs
past_namespresent_namesr   r   r   r   tmp_dir_nametemp_onnx_model_pathr   r   r   export_onnx  s   

,"



 6
$
zGpt2Helper.export_onnxr   c              	   K   sf   t d}	t| d||d|	dd}
|r+|rt|
 nd|vr!d|d< |
jd	ddi| |
|| |
S )
zHOptimize ONNX model with an option to convert it to use mixed precision.r   r   F)
model_type	num_headsri   	opt_leveloptimization_optionsuse_gpukeep_io_typesuse_symbolic_shape_inferTNr   )r   r   rd   auto_mixed_precisionconvert_float_to_float16save_model_to_file)r   optimized_model_pathr   rh   ri   r   r   stager   r   mr   r   r   optimize_onnx  s$   
zGpt2Helper.optimize_onnx)AddLayerNormalizationSkipLayerNormalizationFastGeluEmbedLayerNormalization
onnx_modelop_block_listc                 C   sP  dd |   D }t|}||}td| d|  |  jd j}d}|  }||v s1J || }d}	|j	dkro|}	td	|j  d}
|j
D ]}| |}
|
durY nqLt|
}td
|j d|  |dk }ntd|j	 d|j  g }g }|s|	dur|g}|	jg}||||d}td|  | jdddi| |S )a?  Convert GPT-2 model to mixed precision.
           It detects whether original model has fp16 weights, and set parameters for float16 conversion automatically.
        Args:
            onnx_model (OnnxModel): optimized ONNX model
            op_block_list (List[str], optional): operators to compute in fp32. Defaults to ["Add", "LayerNormalization",
                                                 "SkipLayerNormalization", "FastGelu", "EmbedLayerNormalization"]
        Returns:
            parameters(dict): a dictionary of parameters used in float16 conversion
        c                 S   s   h | ]}|j qS r   )op_type)rS   noder   r   r   	<setcomp>0  s    z2Gpt2Helper.auto_mixed_precision.<locals>.<setcomp>z	fp32 op: z
 fp16 op: r   FNMatMulz#Found last MatMul node for logits: z3max diff of converting weights in last MatMul node : r   z-Failed to find MatMul node for logits. Found z	 of node )r   r  node_block_listforce_fp16_initializersz!auto_mixed_precision parameters: r   Tr   )nodesset
differencer   r   graphoutputr   output_name_to_noder  inputget_initializerr   r   warningr  )r  r  op_full_setfp32_op_setfp16_op_setlogits_output_nameis_weight_fp16_precisionr  r  last_matmul_nodeinitializerr  max_diffr   r  
parametersr   r   r   r     sH   




zGpt2Helper.auto_mixed_precisioninputs
total_runsc           	      C   s   t d |  }t  | | }W d   n1 sw   Y  |dkr)|S g }t   t|D ]}t }| | }|t |  q4W d   n1 sRw   Y  t	|d t
| }t dt|d ||fS )zfRun inference of PyTorch model, and returns average latency in ms when total_runs > 0 besides outputs.zstart pytorch_inferenceNr     zPyTorch inference time = {} ms.2f)r   r   rc   rY   r9   r   r7   timer8   sumr6   format)	r   r'  r(  rX   r   latencyru   startaverage_latencyr   r   r   pytorch_inferenceb  s$   



zGpt2Helper.pytorch_inferencec                 C   s"  t d dt|j  i}|jdur.t|jD ]\}}t|  |d| < q|jdur?t|j  |d< |j	durPt|j	  |d< | 
d|}|dkr\|S g }t|D ]}t }	| 
d|}|t |	  qbt|d t| }
t d	t|
d
 ||
fS )zcRun inference of ONNX model, and returns average latency in ms when total_runs > 0 besides outputs.zstart onnxruntime_inferencer%   Nr   rD   rC   r   r)  z"OnnxRuntime Inference time = {} msr*  )r   r   r   ascontiguousarrayr%   r   rI   	enumeraterD   rC   runr7   r+  r8   r,  r6   r-  )ort_sessionr'  r(  
ort_inputsr?   past_ir   r.  ru   r/  r0  r   r   r   onnxruntime_inference|  s(   



z Gpt2Helper.onnxruntime_inferencec              	   C   s   t | ||||||S )z)Returnas IO binding object for a session.)r   prepare_io_binding)r5  r%   rC   rD   rI   r   r   r   r   r   r9    s   zGpt2Helper.prepare_io_bindingc                 C   s   t | |||S )z3Copy results to cpu. Returns a list of numpy array.)r   "get_outputs_from_io_binding_buffer)r5  r   r   return_numpyr   r   r   r:    s   z-Gpt2Helper.get_outputs_from_io_binding_bufferr   r   r;  include_copy_output_latencyc              	   C   s   t d t| |j|j|j|j||}| | t	| |||}|dkr'|S g }	t
|D ]}
t }| | |rBt	| |||}
|	t |  q-t|	d t|	 }t d| ||fS )zUInference with IO binding. Returns outputs, and optional latency when total_runs > 0.z*start onnxruntime_inference_with_binded_ior   r)  z4OnnxRuntime with IO binding inference time = %.2f ms)r   r   rd   r9  r%   rC   rD   rI   run_with_iobindingr:  r7   r+  r8   r,  r6   )r5  r'  r   r   r(  r;  r<  
io_bindingr   r.  ru   r/  r0  r   r   r   $onnxruntime_inference_with_binded_io  s8   


z/Gpt2Helper.onnxruntime_inference_with_binded_ioc                 C   s   t d|  dd}t|| W d    n1 sw   Y  td|  d t d|  dd}t|| W d    n1 sBw   Y  td|  d d S )Nort_outputs_.picklewbz$ORT output are saved to ort_outputs_torch_outputs_z(Torch output are saved to torch_outputs_openpickledumpr   r   )r?   r   r   r   r   r   r   save_outputs  s   zGpt2Helper.save_outputsc                 C   sT   t d|  dd}t|| W d    n1 sw   Y  td|  d d S )Ndummy_inputs_rA  rB  z!inputs are saved to dummy_inputs_rD  )r?   r   r   r   r   r   r   r   save_inputs  s   zGpt2Helper.save_inputsr   i'  r0   c           ,         s(  |j }td| d d| d| d|	 d| d d}d	}d
}d}|r5t|||||	}t|||}d}d}g  dg| }| }t|D ]}t| }t	d|}|dkr\dnt	d|}t	d|} t
d|  d| d tj| |||j|j|j|j|||
||||dd}!t||!}"|rt| |!}#nt| ||||	}$t| |!||$}#tj|"|#|d\}%}&}'}(})t|&sÈ |& |%r|d7 }|)r|d7 }||  d7  < |r|%std| d|  d| d| d|& 
 t|(D ]\}}*td| d|  | j d|*  q|r*t|&s|&d| kr*t||! t||#|" qH r8 fdddD }+ndd dD }+|d  | |+d!< fd"d#|D |+d$< |d  | |+d%< |t  d  | |+d&< td'| d(| d)|t   d*|  |d+| krtd,t|d | d-d. |+S )/zKGenerate random inputs and compare the results of PyTorch and Onnx Runtime.zRunning parity test (atol=z, test_cases=z, runs=z, use_io_binding=z, model_class=z, is_float16=z) ...      r1   Nr   r0   z#Running parity test for batch_size=z past_sequence_length=z...T)rn   ro   rp   rq   )r   z
test_case=z batch_size=z sequence_length=z	 MaxDiff=	z: Name=z, d   c                    s$   i | ]}d | t  |dqS )max_diff_percentile_r   )r   
percentiler`   )max_abs_diff_listr   r   
<dictcomp>o  s    z*Gpt2Helper.test_parity.<locals>.<dictcomp>)2   Z   _   c   c                 S   s   i | ]}d | dqS )rO  nanr   r`   r   r   r   rR  s  rV   rs   top1_match_ratec                    s   g | ]}|d    qS )rs   r   )rS   x)test_cases_per_runr   r   rU   v  rV   z*Gpt2Helper.test_parity.<locals>.<listcomp>top1_match_rate_per_rundiff_pass_ratenan_ratezParity Test Cases=z	; Passed=z; Nan=z; Top1_Matched=gffffff?zParity is good: passed rate=z.0f%)r   r   r   rd   r   r   r7   r}   r   r~   r   r   rh   ri   rG   rj   r1  r8  r?  r   r   r   r8   r3  get_outputsr   rJ  rH  r6   ),r5  r   rk   r   r   r   rZ  r(  use_io_bindingr   rl   rm   rn   ro   rp   r  r   enable_pickle_outputr   max_batch_sizemax_past_seq_lenmax_seq_lenr   max_output_shapespassed_test_casestop1_matched_casestop1_matched_cases_per_runtotal_test_casesr?   run_idrg   rf   re   r   r   r   r   r   r   r   r   r   messager<   r   )rQ  rZ  r   test_parity  s   (




 ( 
" zGpt2Helper.test_parityrN  rK      c                 C   s   |j }d}|rt|||||}t|||}tj||||j|j|j|j|||||	|
|d}|r;t	| ||\}}|S t
| ||||\}}|S )zCGenerate random inputs and measure average latency of Onnx Runtime.N)rn   ro   rp   )r   rd   r   r   r   rh   ri   rG   rj   r8  r?  )r5  r   rk   r   r(  r`  r   rl   rm   rn   ro   rp   re   rg   rf   r   r   r   r   ru   r.  r   r   r   test_performance  s<   

zGpt2Helper.test_performancec                 C   s:   t jddd|j|j|j|j|d||d }tj	| |S )zJIT trace for TorchScript.r0   F)re   rf   rg   rh   ri   r=   rj   rk   rb   rl   rm   )
rd   r   rh   ri   rG   rj   rY   r9   jittrace)r   r   rk   rl   rm   rX   r   r   r   torchscript  s    zGpt2Helper.torchscriptrawfp32fp16int8c                 C   s  |}t j|rt|jd }n|dd  |dkr!|d| 7 }|r'|d7 }|rdddd	d
}d
D ]P}t j| |||  }	t j|	r||v rwzt	|	 t
d|	  W q2 tyv }
 zt
d|	 d|
j  W Y d}
~
q2d}
~
ww t
d| d|	  q2t jt j| ||d t jt j| |d |d t jt j| |d |d t jt j| |d	 |d d
S t j| |d t j| |d t j| |d t j| |d d
S )z=Build a  path name for given model based on given attributes.r|   /r   ru   _past _fp32_fp16_int8rr  zRemoved the existed directory: zFailed to remove the directory r  NzDirectory for z
 existed: z.onnxz
_fp32.onnxz
_fp16.onnxz
_int8.onnx)r   r   isdirr   partssplitr   existsshutilrmtreer   r   OSErrorstrerror)
output_dirmodel_name_or_pathr   has_past
new_folderremove_existing
model_namesuffixr   new_direr   r   r   get_onnx_paths  sT   

$zGpt2Helper.get_onnx_pathsN)r   )F)r   r   )r   )FFr   )r   )T)r   TF)TT)'r'   r(   r)   r*   rJ   r9   int32r}   rk   boolr]   rP   r   r
   r   dictr5   r   r   r   r   r   r   r   r  r   r   r1  r8  r9  r:  Tensorr?  rH  rJ  rl  rn  rq  r  r   r   r   r   rd      s`   
	
@"
		5	
w#E
2
	
 6rd   )/loggingr   rF  r   r  r   r+  pathlibr   r   r   r9   benchmark_helperr   rb   r   fusion_optionsr   io_binding_helperr   r  r   	optimizerr   torch_onnx_export_helperr	   transformersr
   r   r   r   	getLoggerr'   r   PRETRAINED_GPT2_MODELSFLOAT32FLOAT16INT8DEFAULT_TOLERANCEr   r-   r/   rK   rL   r   rP   rd   r   r   r   r   <module>   sH   

$