o
    	۷iR                     @   s  d Z ddlZddlmZmZ ddlZddlmZ ddlmZm	Z	m
Z
 ddlmZ ddlmZmZmZmZmZmZmZ dd	lmZ dd
lmZmZ ddlmZ eeZG dd dejZ G dd dejZ!G dd dej"Z#G dd dejZ$G dd dejZ%G dd dejZ&G dd dejZ'G dd dejZ(G dd dejZ)G dd  d ejZ*G d!d" d"ejZ+G d#d$ d$ejZ,eG d%d& d&eZ-eG d'd( d(e-Z.eG d)d* d*e-Z/ed+d,G d-d. d.e-Z0eG d/d0 d0e-Z1eG d1d2 d2e-Z2eG d3d4 d4e-Z3g d5Z4dS )6zPyTorch SqueezeBert model.    N)OptionalUnion)nn)BCEWithLogitsLossCrossEntropyLossMSELoss   )ACT2FN)BaseModelOutputBaseModelOutputWithPoolingMaskedLMOutputMultipleChoiceModelOutputQuestionAnsweringModelOutputSequenceClassifierOutputTokenClassifierOutput)PreTrainedModel)auto_docstringlogging   )SqueezeBertConfigc                       s*   e Zd ZdZ fddZdddZ  ZS )SqueezeBertEmbeddingszGConstruct the embeddings from word, position and token_type embeddings.c                    s   t    tj|j|j|jd| _t|j|j| _	t|j
|j| _tj|j|jd| _t|j| _| jdt|jddd d S )N)padding_idxepsposition_ids)r   F)
persistent)super__init__r   	Embedding
vocab_sizeembedding_sizepad_token_idword_embeddingsmax_position_embeddingsposition_embeddingstype_vocab_sizetoken_type_embeddings	LayerNormhidden_sizelayer_norm_epsDropouthidden_dropout_probdropoutregister_buffertorcharangeexpandselfconfig	__class__ j/home/ubuntu/vllm_env/lib/python3.10/site-packages/transformers/models/squeezebert/modeling_squeezebert.pyr   0   s   

zSqueezeBertEmbeddings.__init__Nc           
      C   s   |d ur	|  }n|  d d }|d }|d u r$| jd d d |f }|d u r3tj|tj| jjd}|d u r<| |}| |}| |}|| | }	| 	|	}	| 
|	}	|	S )Nr   r   dtypedevice)sizer   r/   zeroslongr;   r#   r%   r'   r(   r-   )
r3   	input_idstoken_type_idsr   inputs_embedsinput_shape
seq_lengthr%   r'   
embeddingsr7   r7   r8   forward@   s    





zSqueezeBertEmbeddings.forward)NNNN__name__
__module____qualname____doc__r   rE   __classcell__r7   r7   r5   r8   r   -   s    r   c                       (   e Zd ZdZ fddZdd Z  ZS )MatMulWrapperz
    Wrapper for torch.matmul(). This makes flop-counting easier to implement. Note that if you directly call
    torch.matmul() in your code, the flop counter will typically ignore the flops of the matmul.
    c                    s   t    d S N)r   r   r3   r5   r7   r8   r   _      zMatMulWrapper.__init__c                 C   s   t ||S )a0  

        :param inputs: two torch tensors :return: matmul of these tensors

        Here are the typical dimensions found in BERT (the B is optional) mat1.shape: [B, <optional extra dims>, M, K]
        mat2.shape: [B, <optional extra dims>, K, N] output shape: [B, <optional extra dims>, M, N]
        )r/   matmul)r3   mat1mat2r7   r7   r8   rE   b   s   zMatMulWrapper.forwardrF   r7   r7   r5   r8   rM   Y   s    rM   c                   @   s"   e Zd ZdZdddZdd ZdS )	SqueezeBertLayerNormz
    This is a nn.LayerNorm subclass that accepts NCW data layout and performs normalization in the C dimension.

    N = batch C = channels W = sequence length
    -q=c                 C   s   t jj| ||d d S )N)normalized_shaper   )r   r(   r   )r3   r)   r   r7   r7   r8   r   t   s   zSqueezeBertLayerNorm.__init__c                 C   s*   | ddd}tj| |}| dddS )Nr      r   )permuter   r(   rE   )r3   xr7   r7   r8   rE   w   s   zSqueezeBertLayerNorm.forwardN)rU   )rG   rH   rI   rJ   r   rE   r7   r7   r7   r8   rT   m   s    
rT   c                       rL   )ConvDropoutLayerNormz8
    ConvDropoutLayerNorm: Conv, Dropout, LayerNorm
    c                    s8   t    tj||d|d| _t|| _t|| _d S Nr   in_channelsout_channelskernel_sizegroups)	r   r   r   Conv1dconv1drT   	layernormr+   r-   )r3   cincoutr`   dropout_probr5   r7   r8   r      s   

zConvDropoutLayerNorm.__init__c                 C   s*   |  |}| |}|| }| |}|S rN   )rb   r-   rc   )r3   hidden_statesinput_tensorrY   r7   r7   r8   rE      s
   


zConvDropoutLayerNorm.forwardrF   r7   r7   r5   r8   rZ   }   s    rZ   c                       rL   )ConvActivationz*
    ConvActivation: Conv, Activation
    c                    s,   t    tj||d|d| _t| | _d S r[   )r   r   r   ra   rb   r	   act)r3   rd   re   r`   rj   r5   r7   r8   r      s   
zConvActivation.__init__c                 C   s   |  |}| |S rN   )rb   rj   )r3   rY   outputr7   r7   r8   rE      s   

zConvActivation.forwardrF   r7   r7   r5   r8   ri      s    ri   c                       s>   e Zd Zd fdd	Zdd Zdd Zdd	 Zd
d Z  ZS )SqueezeBertSelfAttentionr   c                    s   t    ||j dkrtd| d|j d|j| _t||j | _| j| j | _tj||d|d| _	tj||d|d| _
tj||d|d| _t|j| _tjdd| _t | _t | _d	S )
z
        config = used for some things; ignored for others (work in progress...) cin = input channels = output channels
        groups = number of groups to use in conv1d layers
        r   zcin (z6) is not a multiple of the number of attention heads ()r   r\   r   dimN)r   r   num_attention_heads
ValueErrorintattention_head_sizeall_head_sizer   ra   querykeyvaluer+   attention_probs_dropout_probr-   SoftmaxsoftmaxrM   	matmul_qk
matmul_qkv)r3   r4   rd   q_groupsk_groupsv_groupsr5   r7   r8   r      s   
z!SqueezeBertSelfAttention.__init__c                 C   s:   |  d | j| j|  d f}|j| }|ddddS )z
        - input: [N, C, W]
        - output: [N, C1, W, C2] where C1 is the head index, and C2 is one head's contents
        r   r   r   r   rW   )r<   rp   rs   viewrX   r3   rY   new_x_shaper7   r7   r8   transpose_for_scores   s    
z-SqueezeBertSelfAttention.transpose_for_scoresc                 C   s.   |  d | j| j|  d f}|j| }|S )z
        - input: [N, C, W]
        - output: [N, C1, C2, W] where C1 is the head index, and C2 is one head's contents
        r   r   )r<   rp   rs   r   r   r7   r7   r8   transpose_key_for_scores   s    
z1SqueezeBertSelfAttention.transpose_key_for_scoresc                 C   s>   | dddd }| d | j| d f}|j| }|S )zE
        - input: [N, C1, W, C2]
        - output: [N, C, W]
        r   r   r   rW   )rX   
contiguousr<   rt   r   r   r7   r7   r8   transpose_output   s   
z)SqueezeBertSelfAttention.transpose_outputc                 C   s   |  |}| |}| |}| |}| |}| |}	| ||}
|
t| j }
|
| }
| 	|
}| 
|}| ||	}| |}d|i}|rO|
|d< |S )z
        expects hidden_states in [N, C, W] data layout.

        The attention_mask data layout is [N, W], and it does not need to be transposed.
        context_layerattention_score)ru   rv   rw   r   r   r{   mathsqrtrs   rz   r-   r|   r   )r3   rg   attention_maskoutput_attentionsmixed_query_layermixed_key_layermixed_value_layerquery_layer	key_layervalue_layerr   attention_probsr   resultr7   r7   r8   rE      s"   








z SqueezeBertSelfAttention.forward)r   r   r   )	rG   rH   rI   r   r   r   r   rE   rK   r7   r7   r5   r8   rl      s    	

rl   c                       $   e Zd Z fddZdd Z  ZS )SqueezeBertModulec                    s   t    |j}|j}|j}|j}t|||j|j|jd| _t	|||j
|jd| _t|||j|jd| _t	|||j|jd| _dS )a  
        - hidden_size = input chans = output chans for Q, K, V (they are all the same ... for now) = output chans for
          the module
        - intermediate_size = output chans for intermediate layer
        - groups = number of groups for all layers in the BertModule. (eventually we could change the interface to
          allow different groups for different layers)
        )r4   rd   r}   r~   r   )rd   re   r`   rf   )rd   re   r`   rj   N)r   r   r)   intermediate_sizerl   r}   r~   r   	attentionrZ   post_attention_groupsr,   post_attentionri   intermediate_groups
hidden_actintermediateoutput_groupsrk   )r3   r4   c0c1c2c3r5   r7   r8   r      s   
zSqueezeBertModule.__init__c           
      C   sT   |  |||}|d }| ||}| |}| ||}d|i}	|r(|d |	d< |	S )Nr   feature_mapr   )r   r   r   rk   )
r3   rg   r   r   attattention_outputpost_attention_outputintermediate_outputlayer_outputoutput_dictr7   r7   r8   rE     s   
zSqueezeBertModule.forwardrG   rH   rI   r   rE   rK   r7   r7   r5   r8   r      s    r   c                       s0   e Zd Z fddZ					dddZ  ZS )	SqueezeBertEncoderc                    sB   t     j jksJ dt fddt jD | _d S )NzIf you want embedding_size != intermediate hidden_size, please insert a Conv1d layer to adjust the number of channels before the first SqueezeBertModule.c                 3   s    | ]}t  V  qd S rN   )r   ).0_r4   r7   r8   	<genexpr>.  s    z.SqueezeBertEncoder.__init__.<locals>.<genexpr>)	r   r   r!   r)   r   
ModuleListrangenum_hidden_layerslayersr2   r5   r   r8   r   %  s
   
$zSqueezeBertEncoder.__init__NFTc                 C   s  |d u rd}n| d t|krd}nd}|du sJ d|ddd}|r(dnd }|r.dnd }	| jD ]+}
|rJ|ddd}||f7 }|ddd}|
|||}|d }|r^|	|d	 f7 }	q3|ddd}|rm||f7 }|s{td
d |||	fD S t|||	dS )NTFzAhead_mask is not yet supported in the SqueezeBert implementation.r   rW   r   r7   r   r   c                 s   s    | ]	}|d ur|V  qd S rN   r7   )r   vr7   r7   r8   r   [  s    z-SqueezeBertEncoder.forward.<locals>.<genexpr>)last_hidden_staterg   
attentions)countlenrX   r   rE   tupler
   )r3   rg   r   	head_maskr   output_hidden_statesreturn_dicthead_mask_is_all_noneall_hidden_statesall_attentionslayerr   r7   r7   r8   rE   0  s6   	


zSqueezeBertEncoder.forward)NNFFTr   r7   r7   r5   r8   r   $  s    r   c                       r   )SqueezeBertPoolerc                    s*   t    t|j|j| _t | _d S rN   )r   r   r   Linearr)   denseTanh
activationr2   r5   r7   r8   r   b  s   
zSqueezeBertPooler.__init__c                 C   s(   |d d df }|  |}| |}|S )Nr   )r   r   )r3   rg   first_token_tensorpooled_outputr7   r7   r8   rE   g  s   

zSqueezeBertPooler.forwardr   r7   r7   r5   r8   r   a  s    r   c                       r   )"SqueezeBertPredictionHeadTransformc                    sV   t    t|j|j| _t|jtrt	|j | _
n|j| _
tj|j|jd| _d S )Nr   )r   r   r   r   r)   r   
isinstancer   strr	   transform_act_fnr(   r*   r2   r5   r7   r8   r   q  s   
z+SqueezeBertPredictionHeadTransform.__init__c                 C   s"   |  |}| |}| |}|S rN   )r   r   r(   r3   rg   r7   r7   r8   rE   z  s   


z*SqueezeBertPredictionHeadTransform.forwardr   r7   r7   r5   r8   r   p  s    	r   c                       s.   e Zd Z fddZd	ddZdd Z  ZS )
SqueezeBertLMPredictionHeadc                    sL   t    t|| _tj|j|jdd| _t	t
|j| _| j| j_d S )NF)bias)r   r   r   	transformr   r   r)   r    decoder	Parameterr/   r=   r   r2   r5   r7   r8   r     s
   

z$SqueezeBertLMPredictionHead.__init__returnNc                 C   s   | j | j_ d S rN   )r   r   rO   r7   r7   r8   _tie_weights  rP   z(SqueezeBertLMPredictionHead._tie_weightsc                 C   s   |  |}| |}|S rN   )r   r   r   r7   r7   r8   rE     s   

z#SqueezeBertLMPredictionHead.forward)r   N)rG   rH   rI   r   r   rE   rK   r7   r7   r5   r8   r     s    
r   c                       r   )SqueezeBertOnlyMLMHeadc                    s   t    t|| _d S rN   )r   r   r   predictionsr2   r5   r7   r8   r     s   
zSqueezeBertOnlyMLMHead.__init__c                 C   s   |  |}|S rN   )r   )r3   sequence_outputprediction_scoresr7   r7   r8   rE     s   
zSqueezeBertOnlyMLMHead.forwardr   r7   r7   r5   r8   r     s    r   c                   @   s"   e Zd ZU eed< dZdd ZdS )SqueezeBertPreTrainedModelr4   transformerc                 C   s   t |tjtjfr#|jjjd| jjd |j	dur!|j	j
  dS dS t |tjrF|jjjd| jjd |jdurD|jj|j 
  dS dS t |tjr[|j	j
  |jjd dS t |trh|j	j
  dS dS )zInitialize the weightsg        )meanstdNg      ?)r   r   r   ra   weightdatanormal_r4   initializer_ranger   zero_r   r   r(   fill_r   )r3   moduler7   r7   r8   _init_weights  s    


z(SqueezeBertPreTrainedModel._init_weightsN)rG   rH   rI   r   __annotations__base_model_prefixr   r7   r7   r7   r8   r     s   
 r   c                       s   e Zd Z fddZdd Zdd Zdd Ze																		dd
ee	j
 dee	j
 dee	j
 dee	j
 dee	j
 dee	j dee dee dee deeef fddZ  ZS )SqueezeBertModelc                    s6   t  | t|| _t|| _t|| _|   d S rN   )	r   r   r   rD   r   encoderr   pooler	post_initr2   r5   r7   r8   r     s
   


zSqueezeBertModel.__init__c                 C   s   | j jS rN   rD   r#   rO   r7   r7   r8   get_input_embeddings  s   z%SqueezeBertModel.get_input_embeddingsc                 C   s   || j _d S rN   r   r3   new_embeddingsr7   r7   r8   set_input_embeddings  s   z%SqueezeBertModel.set_input_embeddingsc                 C   s*   |  D ]\}}| jj| j| qdS )z
        Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base
        class PreTrainedModel
        N)itemsr   r   r   prune_heads)r3   heads_to_pruner   headsr7   r7   r8   _prune_heads  s   zSqueezeBertModel._prune_headsNr?   r   r@   r   r   rA   r   r   r   r   c
                 C   sZ  |d ur|n| j j}|d ur|n| j j}|	d ur|	n| j j}	|d ur*|d ur*td|d ur9| || | }
n|d urF| d d }
ntd|d urQ|jn|j}|d u r_tj	|
|d}|d u rltj
|
tj|d}| ||
}| || j j}| j||||d}| j||||||	d}|d }| |}|	s||f|d	d   S t|||j|jd
S )NzDYou cannot specify both input_ids and inputs_embeds at the same timer   z5You have to specify either input_ids or inputs_embeds)r;   r9   )r?   r   r@   rA   )rg   r   r   r   r   r   r   r   )r   pooler_outputrg   r   )r4   r   r   use_return_dictrq   %warn_if_padding_and_no_attention_maskr<   r;   r/   onesr=   r>   get_extended_attention_maskget_head_maskr   rD   r   r   r   rg   r   )r3   r?   r   r@   r   r   rA   r   r   r   rB   r;   extended_attention_maskembedding_outputencoder_outputsr   r   r7   r7   r8   rE     sP   

zSqueezeBertModel.forward)	NNNNNNNNN)rG   rH   rI   r   r   r   r   r   r   r/   TensorFloatTensorboolr   r   r   rE   rK   r7   r7   r5   r8   r     sH    
	

r   c                       s   e Zd ZddgZ fddZdd Zdd Ze																				dd
ee	j
 dee	j
 dee	j
 dee	j
 dee	j
 dee	j
 dee	j
 dee dee dee deeef fddZ  ZS )SqueezeBertForMaskedLMzcls.predictions.decoder.weightzcls.predictions.decoder.biasc                    s,   t  | t|| _t|| _|   d S rN   )r   r   r   r   r   clsr   r2   r5   r7   r8   r     s   

zSqueezeBertForMaskedLM.__init__c                 C   s
   | j jjS rN   )r  r   r   rO   r7   r7   r8   get_output_embeddings&  s   
z,SqueezeBertForMaskedLM.get_output_embeddingsc                 C   s   || j j_|j| j j_d S rN   )r  r   r   r   r   r7   r7   r8   set_output_embeddings)  s   
z,SqueezeBertForMaskedLM.set_output_embeddingsNr?   r   r@   r   r   rA   labelsr   r   r   r   c                 C   s   |
dur|
n| j j}
| j||||||||	|
d	}|d }| |}d}|dur8t }||d| j j|d}|
sN|f|dd  }|durL|f| S |S t|||j|j	dS )a  
        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Labels for computing the masked language modeling loss. Indices should be in `[-100, 0, ...,
            config.vocab_size]` (see `input_ids` docstring) Tokens with indices set to `-100` are ignored (masked), the
            loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`
        Nr   r@   r   r   rA   r   r   r   r   r   rW   losslogitsrg   r   )
r4   r   r   r  r   r   r    r   rg   r   )r3   r?   r   r@   r   r   rA   r  r   r   r   outputsr   r   masked_lm_lossloss_fctrk   r7   r7   r8   rE   -  s6   
zSqueezeBertForMaskedLM.forward
NNNNNNNNNN)rG   rH   rI   _tied_weights_keysr   r  r  r   r   r/   r   r   r   r   r   rE   rK   r7   r7   r5   r8   r     sN    		

r   z
    SqueezeBERT Model transformer with a sequence classification/regression head on top (a linear layer on top of the
    pooled output) e.g. for GLUE tasks.
    )custom_introc                          e Zd Z fddZe										ddeej deej deej deej deej d	eej d
eej dee dee dee de	e
ef fddZ  ZS )$SqueezeBertForSequenceClassificationc                    sR   t  | |j| _|| _t|| _t|j| _	t
|j| jj| _|   d S rN   )r   r   
num_labelsr4   r   r   r   r+   r,   r-   r   r)   
classifierr   r2   r5   r7   r8   r   j  s   
z-SqueezeBertForSequenceClassification.__init__Nr?   r   r@   r   r   rA   r  r   r   r   r   c                 C   sr  |
dur|
n| j j}
| j||||||||	|
d	}|d }| |}| |}d}|dur| j jdu rV| jdkr<d| j _n| jdkrR|jtj	ksM|jtj
krRd| j _nd| j _| j jdkrtt }| jdkrn|| | }n+|||}n%| j jdkrt }||d| j|d}n| j jdkrt }|||}|
s|f|dd  }|dur|f| S |S t|||j|jd	S )
a  
        labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
            Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
            config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
            `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
        Nr  r   
regressionsingle_label_classificationmulti_label_classificationr   rW   r  )r4   r   r   r-   r  problem_typer  r:   r/   r>   rr   r   squeezer   r   r   r   rg   r   )r3   r?   r   r@   r   r   rA   r  r   r   r   r	  r   r  r  r  rk   r7   r7   r8   rE   v  sV   



"


z,SqueezeBertForSequenceClassification.forwardr  )rG   rH   rI   r   r   r   r/   r   r   r   r   r   rE   rK   r7   r7   r5   r8   r  c  sH    	

r  c                       r  )SqueezeBertForMultipleChoicec                    s@   t  | t|| _t|j| _t|j	d| _
|   d S )Nr   )r   r   r   r   r   r+   r,   r-   r   r)   r  r   r2   r5   r7   r8   r     s
   
z%SqueezeBertForMultipleChoice.__init__Nr?   r   r@   r   r   rA   r  r   r   r   r   c                 C   sn  |
dur|
n| j j}
|dur|jd n|jd }|dur%|d|dnd}|dur4|d|dnd}|durC|d|dnd}|durR|d|dnd}|dure|d|d|dnd}| j||||||||	|
d	}|d }| |}| |}|d|}d}|durt }|||}|
s|f|dd  }|dur|f| S |S t	|||j
|jdS )a[  
        input_ids (`torch.LongTensor` of shape `(batch_size, num_choices, sequence_length)`):
            Indices of input sequence tokens in the vocabulary.

            Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
            [`PreTrainedTokenizer.__call__`] for details.

            [What are input IDs?](../glossary#input-ids)
        token_type_ids (`torch.LongTensor` of shape `(batch_size, num_choices, sequence_length)`, *optional*):
            Segment token indices to indicate first and second portions of the inputs. Indices are selected in `[0,
            1]`:

            - 0 corresponds to a *sentence A* token,
            - 1 corresponds to a *sentence B* token.

            [What are token type IDs?](../glossary#token-type-ids)
        position_ids (`torch.LongTensor` of shape `(batch_size, num_choices, sequence_length)`, *optional*):
            Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
            config.max_position_embeddings - 1]`.

            [What are position IDs?](../glossary#position-ids)
        inputs_embeds (`torch.FloatTensor` of shape `(batch_size, num_choices, sequence_length, hidden_size)`, *optional*):
            Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
            is useful if you want more control over how to convert `input_ids` indices into associated vectors than the
            model's internal embedding lookup matrix.
        labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
            Labels for computing the multiple choice classification loss. Indices should be in `[0, ...,
            num_choices-1]` where *num_choices* is the size of the second dimension of the input tensors. (see
            *input_ids* above)
        Nr   r   r  rW   r  )r4   r   shaper   r<   r   r-   r  r   r   rg   r   )r3   r?   r   r@   r   r   rA   r  r   r   r   num_choicesr	  r   r  reshaped_logitsr  r  rk   r7   r7   r8   rE     sL   ,


z$SqueezeBertForMultipleChoice.forwardr  )rG   rH   rI   r   r   r   r/   r   r   r   r   r   rE   rK   r7   r7   r5   r8   r    sH    
	

r  c                       r  )!SqueezeBertForTokenClassificationc                    sJ   t  | |j| _t|| _t|j| _t	|j
|j| _|   d S rN   )r   r   r  r   r   r   r+   r,   r-   r   r)   r  r   r2   r5   r7   r8   r   *  s   
z*SqueezeBertForTokenClassification.__init__Nr?   r   r@   r   r   rA   r  r   r   r   r   c                 C   s   |
dur|
n| j j}
| j||||||||	|
d	}|d }| |}| |}d}|dur<t }||d| j|d}|
sR|f|dd  }|durP|f| S |S t|||j	|j
dS )z
        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Labels for computing the token classification loss. Indices should be in `[0, ..., config.num_labels - 1]`.
        Nr  r   r   rW   r  )r4   r   r   r-   r  r   r   r  r   rg   r   )r3   r?   r   r@   r   r   rA   r  r   r   r   r	  r   r  r  r  rk   r7   r7   r8   rE   5  s8   

z)SqueezeBertForTokenClassification.forwardr  )rG   rH   rI   r   r   r   r/   r   r   r   r   r   rE   rK   r7   r7   r5   r8   r  (  sH    	

r  c                       s   e Zd Z fddZe											ddeej deej deej deej deej d	eej d
eej deej dee dee dee de	e
ef fddZ  ZS )SqueezeBertForQuestionAnsweringc                    s<   t  | |j| _t|| _t|j|j| _| 	  d S rN   )
r   r   r  r   r   r   r   r)   
qa_outputsr   r2   r5   r7   r8   r   m  s
   
z(SqueezeBertForQuestionAnswering.__init__Nr?   r   r@   r   r   rA   start_positionsend_positionsr   r   r   r   c                 C   sH  |d ur|n| j j}| j|||||||	|
|d	}|d }| |}|jddd\}}|d }|d }d }|d ur|d urt| dkrO|d}t| dkr\|d}|d}|	d|}|	d|}t
|d}|||}|||}|| d }|s||f|dd   }|d ur|f| S |S t||||j|jdS )	Nr  r   r   r   rn   )ignore_indexrW   )r  start_logits
end_logitsrg   r   )r4   r   r   r  splitr  r   r   r<   clampr   r   rg   r   )r3   r?   r   r@   r   r   rA   r   r!  r   r   r   r	  r   r  r#  r$  
total_lossignored_indexr  
start_lossend_lossrk   r7   r7   r8   rE   w  sP   






z'SqueezeBertForQuestionAnswering.forward)NNNNNNNNNNN)rG   rH   rI   r   r   r   r/   r   r   r   r   r   rE   rK   r7   r7   r5   r8   r  k  sN    
	

r  )r   r  r  r  r  r   r   r   )5rJ   r   typingr   r   r/   r   torch.nnr   r   r   activationsr	   modeling_outputsr
   r   r   r   r   r   r   modeling_utilsr   utilsr   r   configuration_squeezebertr   
get_loggerrG   loggerModuler   rM   r(   rT   rZ   ri   rl   r   r   r   r   r   r   r   r   r   r  r  r  r  __all__r7   r7   r7   r8   <module>   sR   $	
,Z*=
^IWgBM