o
    wih                     @   s  d Z ddlZddlmZ ddlmZmZ ddlZddlZddlm	Z	 ddl
mZ ddlmZ dd	lmZ dd
lmZmZmZ ddlmZ ddlmZmZmZ ddlmZmZ ddlmZ ee Z!G dd de	j"Z#G dd de	j"Z$G dd de	j"Z%de$iZ&G dd de	j"Z'G dd de	j"Z(G dd de	j"Z)G dd deZ*G dd  d e	j"Z+eG d!d" d"eZ,eG d#d$ d$e,Z-G d%d& d&e	j"Z.G d'd( d(e	j"Z/eG d)d* d*e,Z0eed+d,G d-d. d.eZ1ed/d,G d0d1 d1e,Z2g d2Z3dS )3zPyTorch Splinter model.    N)	dataclass)OptionalUnion)nn)CrossEntropyLoss   )ACT2FN)GradientCheckpointingLayer))BaseModelOutputWithPastAndCrossAttentionsModelOutputQuestionAnsweringModelOutput)PreTrainedModel)apply_chunking_to_forward find_pruneable_heads_and_indicesprune_linear_layer)auto_docstringlogging   )SplinterConfigc                       sj   e Zd ZdZ fddZ					ddeej deej deej d	eej d
ee	 de
fddZ  ZS )SplinterEmbeddingszGConstruct the embeddings from word, position and token_type embeddings.c                    s   t    tj|j|j|jd| _t|j|j| _	t|j
|j| _tj|j|jd| _t|j| _| jdt|jddd t|dd| _d S )	N)padding_idxepsposition_ids)r   F)
persistentposition_embedding_typeabsolute)super__init__r   	Embedding
vocab_sizehidden_sizepad_token_idword_embeddingsmax_position_embeddingsposition_embeddingstype_vocab_sizetoken_type_embeddings	LayerNormlayer_norm_epsDropouthidden_dropout_probdropoutregister_buffertorcharangeexpandgetattrr   selfconfig	__class__ k/home/ubuntu/sommelier/.venv/lib/python3.10/site-packages/transformers/models/splinter/modeling_splinter.pyr   ,   s   
zSplinterEmbeddings.__init__Nr   	input_idstoken_type_idsr   inputs_embedspast_key_values_lengthreturnc                 C   s   |d ur	|  }n|  d d }|d }|d u r&| jd d ||| f }|d u r5tj|tj| jjd}|d u r>| |}| |}|| }	| jdkrU| 	|}
|	|
7 }	| 
|	}	| |	}	|	S )Nr   r   dtypedevicer   )sizer   r/   zeroslongrA   r$   r(   r   r&   r)   r-   )r4   r:   r;   r   r<   r=   input_shape
seq_lengthr(   
embeddingsr&   r8   r8   r9   forward=   s$   






zSplinterEmbeddings.forward)NNNNr   )__name__
__module____qualname____doc__r   r   r/   
LongTensorFloatTensorinttuplerH   __classcell__r8   r8   r6   r9   r   )   s*    r   c                       s   e Zd Zd fdd	ZdejdejfddZ						dd	ejd
eej deej deej deej dee	e	ej   dee
 de	ej fddZ  ZS )SplinterSelfAttentionNc                    s   t    |j|j dkrt|dstd|j d|j d|j| _t|j|j | _| j| j | _t	
|j| j| _t	
|j| j| _t	
|j| j| _t	|j| _|p\t|dd| _| jdksh| jd	kry|j| _t	d
|j d | j| _|j| _d S )Nr   embedding_sizezThe hidden size (z6) is not a multiple of the number of attention heads ()r   r   relative_keyrelative_key_query   r   )r   r   r"   num_attention_headshasattr
ValueErrorrO   attention_head_sizeall_head_sizer   Linearquerykeyvaluer+   attention_probs_dropout_probr-   r2   r   r%   r    distance_embedding
is_decoderr4   r5   r   r6   r8   r9   r   a   s*   

zSplinterSelfAttention.__init__xr>   c                 C   s6   |  d d | j| jf }||}|ddddS )Nr   r   rW   r   r   )rB   rX   r[   viewpermute)r4   re   new_x_shaper8   r8   r9   transpose_for_scores{   s   
z*SplinterSelfAttention.transpose_for_scoresFhidden_statesattention_mask	head_maskencoder_hidden_statesencoder_attention_maskpast_key_valueoutput_attentionsc                 C   s  |  |}|d u}	|	r|d ur|d }
|d }|}nP|	r/| | |}
| | |}|}n;|d urZ| | |}
| | |}tj|d |
gdd}
tj|d |gdd}n| | |}
| | |}| |}|d u}| jrz|
|f}t||
dd}| j	dks| j	dkr	|j
d |
j
d }}|rtj|d tj|jd	dd}ntj|tj|jd	dd}tj|tj|jd	dd}|| }| || j d }|j|jd
}| j	dkrtd||}|| }n| j	dkr	td||}td|
|}|| | }|t| j }|d ur|| }tjj|dd}| |}|d ur0|| }t||}|dddd }| d d | jf }||}|rX||fn|f}| jrd||f }|S )Nr   r   rW   dimr   rU   rV   r?   )r@   zbhld,lrd->bhlrzbhrd,lrd->bhlrr   ) r^   ri   r_   r`   r/   catrc   matmul	transposer   shapetensorrD   rA   rf   r0   rb   r%   tor@   einsummathsqrtr[   r   
functionalsoftmaxr-   rg   
contiguousrB   r\   )r4   rj   rk   rl   rm   rn   ro   rp   mixed_query_layeris_cross_attention	key_layervalue_layerquery_layer	use_cacheattention_scoresquery_length
key_lengthposition_ids_lposition_ids_rdistancepositional_embeddingrelative_position_scoresrelative_position_scores_queryrelative_position_scores_keyattention_probscontext_layernew_context_layer_shapeoutputsr8   r8   r9   rH      sn   









zSplinterSelfAttention.forwardNNNNNNF)rI   rJ   rK   r   r/   Tensorri   r   rN   rP   boolrH   rQ   r8   r8   r6   r9   rR   `   s4    	rR   c                       8   e Zd Z fddZdejdejdejfddZ  ZS )SplinterSelfOutputc                    sB   t    t|j|j| _tj|j|jd| _t|j	| _
d S Nr   )r   r   r   r]   r"   denser)   r*   r+   r,   r-   r3   r6   r8   r9   r         
zSplinterSelfOutput.__init__rj   input_tensorr>   c                 C   &   |  |}| |}| || }|S r   r   r-   r)   r4   rj   r   r8   r8   r9   rH         

zSplinterSelfOutput.forwardrI   rJ   rK   r   r/   r   rH   rQ   r8   r8   r6   r9   r          $r   eagerc                       s   e Zd Zd fdd	Zdd Z						ddejdeej d	eej d
eej deej dee	e	ej   dee
 de	ej fddZ  ZS )SplinterAttentionNc                    s4   t    t|j ||d| _t|| _t | _d S )Nr   )	r   r   SPLINTER_SELF_ATTENTION_CLASSES_attn_implementationr4   r   outputsetpruned_headsrd   r6   r8   r9   r      s   

zSplinterAttention.__init__c                 C   s   t |dkrd S t|| jj| jj| j\}}t| jj|| j_t| jj|| j_t| jj	|| j_	t| j
j|dd| j
_| jjt | | j_| jj| jj | j_| j|| _d S )Nr   r   rq   )lenr   r4   rX   r[   r   r   r^   r_   r`   r   r   r\   union)r4   headsindexr8   r8   r9   prune_heads  s   zSplinterAttention.prune_headsFrj   rk   rl   rm   rn   ro   rp   r>   c              	   C   s<   |  |||||||}| |d |}	|	f|dd   }
|
S )Nr   r   )r4   r   )r4   rj   rk   rl   rm   rn   ro   rp   self_outputsattention_outputr   r8   r8   r9   rH     s   
	zSplinterAttention.forwardr   r   )rI   rJ   rK   r   r   r/   r   r   rN   rP   r   rH   rQ   r8   r8   r6   r9   r      s4    	r   c                       s2   e Zd Z fddZdejdejfddZ  ZS )SplinterIntermediatec                    sD   t    t|j|j| _t|jt	rt
|j | _d S |j| _d S r   )r   r   r   r]   r"   intermediate_sizer   
isinstance
hidden_actstrr   intermediate_act_fnr3   r6   r8   r9   r   0  s
   
zSplinterIntermediate.__init__rj   r>   c                 C   s   |  |}| |}|S r   )r   r   )r4   rj   r8   r8   r9   rH   8  s   

zSplinterIntermediate.forwardr   r8   r8   r6   r9   r   /  s    r   c                       r   )SplinterOutputc                    sB   t    t|j|j| _tj|j|jd| _t	|j
| _d S r   )r   r   r   r]   r   r"   r   r)   r*   r+   r,   r-   r3   r6   r8   r9   r   @  r   zSplinterOutput.__init__rj   r   r>   c                 C   r   r   r   r   r8   r8   r9   rH   F  r   zSplinterOutput.forwardr   r8   r8   r6   r9   r   ?  r   r   c                       s   e Zd Z fddZ						ddejdeej deej deej d	eej d
eeeej   dee	 deej fddZ
dd Z  ZS )SplinterLayerc                    sr   t    |j| _d| _t|| _|j| _|j| _| jr-| js&t|  dt|dd| _	t
|| _t|| _d S )Nr   z> should be used as a decoder model if cross attention is addedr   r   )r   r   chunk_size_feed_forwardseq_len_dimr   	attentionrc   add_cross_attentionrZ   crossattentionr   intermediater   r   r3   r6   r8   r9   r   O  s   


zSplinterLayer.__init__NFrj   rk   rl   rm   rn   ro   rp   r>   c              	   C   s  |d ur
|d d nd }| j |||||d}	|	d }
| jr(|	dd }|	d }n|	dd  }d }| jro|d urot| dsDtd|  d|d urN|d	d  nd }| |
||||||}|d }
||dd  }|d }|| }t| j| j| j|
}|f| }| jr||f }|S )
NrW   )rp   ro   r   r   r   r   z'If `encoder_hidden_states` are passed, z` has to be instantiated with cross-attention layers by setting `config.add_cross_attention=True`rs   )	r   rc   rY   rZ   r   r   feed_forward_chunkr   r   )r4   rj   rk   rl   rm   rn   ro   rp   self_attn_past_key_valueself_attention_outputsr   r   present_key_valuecross_attn_present_key_valuecross_attn_past_key_valuecross_attention_outputslayer_outputr8   r8   r9   rH   ]  sP   


	

zSplinterLayer.forwardc                 C   s   |  |}| ||}|S r   )r   r   )r4   r   intermediate_outputr   r8   r8   r9   r     s   
z SplinterLayer.feed_forward_chunkr   )rI   rJ   rK   r   r/   r   r   rN   rP   r   rH   r   rQ   r8   r8   r6   r9   r   N  s4    	
Ar   c                       s   e Zd Z fddZ									ddejdeej deej d	eej d
eej deeeej   dee	 dee	 dee	 dee	 de
eej ef fddZ  ZS )SplinterEncoderc                    s:   t     | _t fddt jD | _d| _d S )Nc                    s   g | ]}t  qS r8   )r   ).0_r5   r8   r9   
<listcomp>  s    z,SplinterEncoder.__init__.<locals>.<listcomp>F)	r   r   r5   r   
ModuleListrangenum_hidden_layerslayergradient_checkpointingr3   r6   r   r9   r     s   
 
zSplinterEncoder.__init__NFTrj   rk   rl   rm   rn   past_key_valuesr   rp   output_hidden_statesreturn_dictr>   c              
   C   s8  |	rdnd }|r
dnd }|r| j jrdnd }| jr%| jr%|r%td d}|r)dnd }t| jD ]K\}}|	r;||f }|d urC|| nd }|d urM|| nd }||||||||d}|d }|rg||d f7 }|r{||d f }| j jr{||d f }q0|	r||f }|
std	d
 |||||fD S t	|||||dS )Nr8   zZ`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...F)rn   ro   rp   r   r   r   rW   c                 s   s    | ]	}|d ur|V  qd S r   r8   )r   vr8   r8   r9   	<genexpr>  s    z*SplinterEncoder.forward.<locals>.<genexpr>last_hidden_stater   rj   
attentionscross_attentions)
r5   r   r   trainingloggerwarning_once	enumerater   rP   r
   )r4   rj   rk   rl   rm   rn   r   r   rp   r   r   all_hidden_statesall_self_attentionsall_cross_attentionsnext_decoder_cacheilayer_modulelayer_head_maskro   layer_outputsr8   r8   r9   rH     sd   


zSplinterEncoder.forward)	NNNNNNFFT)rI   rJ   rK   r   r/   r   r   rN   rP   r   r   r
   rH   rQ   r8   r8   r6   r9   r     sD    		
r   c                   @   s    e Zd ZeZdZdZdd ZdS )SplinterPreTrainedModelsplinterTc                 C   s   t |tjr |jjjd| jjd |jdur|jj	  dS dS t |tj
rC|jjjd| jjd |jdurA|jj|j 	  dS dS t |tjrX|jj	  |jjd dS dS )zInitialize the weightsg        )meanstdNg      ?)r   r   r]   weightdatanormal_r5   initializer_rangebiaszero_r    r   r)   fill_)r4   moduler8   r8   r9   _init_weights  s   

z%SplinterPreTrainedModel._init_weightsN)rI   rJ   rK   r   config_classbase_model_prefixsupports_gradient_checkpointingr   r8   r8   r8   r9   r     s
    r   c                        s   e Zd ZdZ fddZdd Zdd Zdd	 Ze	
	
	
	
	
	
	
	
	
	
	
	
	
dde	e
j de	e
j de	e
j de	e
j de	e
j de	e
j de	e
j de	e
j de	ee
j  de	e de	e de	e de	e deeef fddZ  ZS )SplinterModela2  
    The model is an encoder (with only self-attention) following the architecture described in [Attention is all you
    need](https://huggingface.co/papers/1706.03762) by Ashish Vaswani, Noam Shazeer, Niki Parmar, Jakob Uszkoreit, Llion Jones,
    Aidan N. Gomez, Lukasz Kaiser and Illia Polosukhin.
    c                    s2   t  | || _t|| _t|| _|   d S r   )r   r   r5   r   rG   r   encoder	post_initr3   r6   r8   r9   r     s
   

zSplinterModel.__init__c                 C   s   | j jS r   rG   r$   )r4   r8   r8   r9   get_input_embeddings  s   z"SplinterModel.get_input_embeddingsc                 C   s   || j _d S r   r   )r4   r`   r8   r8   r9   set_input_embeddings"  s   z"SplinterModel.set_input_embeddingsc                 C   s*   |  D ]\}}| jj| j| qdS )z
        Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base
        class PreTrainedModel
        N)itemsr   r   r   r   )r4   heads_to_pruner   r   r8   r8   r9   _prune_heads%  s   zSplinterModel._prune_headsNr:   rk   r;   r   rl   r<   rm   rn   r   r   rp   r   r   r>   c                 C   s  |dur|n| j j}|dur|n| j j}|dur|n| j j}| j jr-|
dur(|
n| j j}
nd}
|dur;|dur;td|durJ| || | }n|durW| dd }ntd|\}}|durf|j	n|j	}|	durv|	d d j
d nd}|du rtj||| f|d}|du rtj|tj|d	}| ||}| j jr|dur| \}}}||f}|du rtj||d}| |}nd}| || j j}| j|||||d
}| j||||||	|
|||d
}|d }|s|f|dd  S t||j|j|j|jdS )a  
        token_type_ids (`torch.LongTensor` of shape `batch_size, sequence_length`, *optional*):
            Segment token indices to indicate first and second portions of the inputs. Indices are selected in `[0,
            1]`:

            - 0 corresponds to a *sentence A* token,
            - 1 corresponds to a *sentence B* token.

            [What are token type IDs?](../glossary#token-type-ids)
        position_ids (`torch.LongTensor` of shape `batch_size, sequence_length`, *optional*):
            Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
            config.max_position_embeddings - 1]`.

            [What are position IDs?](../glossary#position-ids)
        NFzDYou cannot specify both input_ids and inputs_embeds at the same timer   z5You have to specify either input_ids or inputs_embedsr   rW   )rA   r?   )r:   r   r;   r<   r=   )	rk   rl   rm   rn   r   r   rp   r   r   r   r   )r5   rp   r   use_return_dictrc   r   rZ   %warn_if_padding_and_no_attention_maskrB   rA   rw   r/   onesrC   rD   get_extended_attention_maskinvert_attention_maskget_head_maskr   rG   r   r
   r   rj   r   r   )r4   r:   rk   r;   r   rl   r<   rm   rn   r   r   rp   r   r   rE   
batch_sizerF   rA   r=   extended_attention_maskencoder_batch_sizeencoder_sequence_lengthr   encoder_hidden_shapeencoder_extended_attention_maskembedding_outputencoder_outputssequence_outputr8   r8   r9   rH   -  sx    
zSplinterModel.forward)NNNNNNNNNNNNN)rI   rJ   rK   rL   r   r   r   r  r   r   r/   r   listrN   r   r   rP   r
   rH   rQ   r8   r8   r6   r9   r     sb    
	

r   c                       s4   e Zd Zd fdd	ZdejdejfddZ  ZS )	SplinterFullyConnectedLayergeluc                    sD   t    || _|| _t| j| j| _t| | _t	| j| _	d S r   )
r   r   	input_dim
output_dimr   r]   r   r   act_fnr)   )r4   r  r  r   r6   r8   r9   r     s   

z$SplinterFullyConnectedLayer.__init__inputsr>   c                 C   s"   |  |}| |}| |}|S r   )r   r  r)   )r4   r  rj   r8   r8   r9   rH     s   


z#SplinterFullyConnectedLayer.forward)r  r   r8   r8   r6   r9   r    s    
r  c                       s(   e Zd ZdZ fddZdd Z  ZS )QuestionAwareSpanSelectionHeadzf
    Implementation of Question-Aware Span Selection (QASS) head, described in Splinter's paper:

    c                    sz   t    t|j|j| _t|j|j| _t|j|j| _t|j|j| _tj	|j|jdd| _
tj	|j|jdd| _d S )NF)r   )r   r   r  r"   query_start_transformquery_end_transformstart_transformend_transformr   r]   start_classifierend_classifierr3   r6   r8   r9   r     s   
z'QuestionAwareSpanSelectionHead.__init__c                 C   s   |  \}}}|ddd|}tj|d|d}| |}| |}| |}	| |}
| 	|}|	
ddd}	t||	}| |}|

ddd}
t||
}||fS )Nr   r   )rr   r   r   rW   )rB   	unsqueezerepeatr/   gatherr  r  r  r  r  rg   ru   r  )r4   r  	positionsr   rr   r   gathered_repsquery_start_repsquery_end_reps
start_repsend_repsrj   start_logits
end_logitsr8   r8   r9   rH     s   





z&QuestionAwareSpanSelectionHead.forward)rI   rJ   rK   rL   r   rH   rQ   r8   r8   r6   r9   r    s    r  c                       s   e Zd Z fddZe												ddeej deej deej deej deej d	eej d
eej deej dee	 dee	 dee	 deej de
eef fddZ  ZS )SplinterForQuestionAnsweringc                    4   t  | t|| _t|| _|j| _|   d S r   r   r   r   r   r  splinter_qassquestion_token_idr   r3   r6   r8   r9   r     
   

z%SplinterForQuestionAnswering.__init__Nr:   rk   r;   r   rl   r<   start_positionsend_positionsrp   r   r   question_positionsr>   c                 C   s  |dur|n| j j}d}|du r9|dur#tjt|| j dd}ntj|dtj	|j
|jd}|d}d}| j|||||||	|
|d	}|d }| ||\}}|r`|d	|d	}}|dur~|d	| t|jj  }|d	| t|jj  }d}|dur|durt| d	kr|d}t| d	kr|d}|d	}|d| |d| t|d
}|||}|||}|| d }|s||f|d	d  }|dur|f| S |S t||||j|jdS )a  
        token_type_ids (`torch.LongTensor` of shape `batch_size, sequence_length`, *optional*):
            Segment token indices to indicate first and second portions of the inputs. Indices are selected in `[0,
            1]`:

            - 0 corresponds to a *sentence A* token,
            - 1 corresponds to a *sentence B* token.

            [What are token type IDs?](../glossary#token-type-ids)
        position_ids (`torch.LongTensor` of shape `batch_size, sequence_length`, *optional*):
            Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
            config.max_position_embeddings - 1]`.

            [What are position IDs?](../glossary#position-ids)
        question_positions (`torch.LongTensor` of shape `(batch_size, num_questions)`, *optional*):
            The positions of all question tokens. If given, start_logits and end_logits will be of shape `(batch_size,
            num_questions, sequence_length)`. If None, the first question token in each sequence in the batch will be
            the only one for which start_logits and end_logits are calculated and they will be of shape `(batch_size,
            sequence_length)`.
        NFr   rq   r   )r@   layoutrA   Trk   r;   r   rl   r<   rp   r   r   r   ignore_indexrW   lossr(  r)  rj   r   )r5   r  r/   argmaxeqr.  rO   rC   rB   rD   r3  rA   r  r   r-  squeezefinfor@   minr   clamp_r   r   rj   r   )r4   r:   rk   r;   r   rl   r<   r0  r1  rp   r   r   r2  question_positions_were_none"question_position_for_each_exampler   r  r(  r)  
total_lossignored_indexloss_fct
start_lossend_lossr   r8   r8   r9   rH     sj   $






z$SplinterForQuestionAnswering.forwardNNNNNNNNNNNN)rI   rJ   rK   r   r   r   r/   r   rM   r   r   rP   r   rH   rQ   r8   r8   r6   r9   r*    sT    
	

r*  zB
    Class for outputs of Splinter as a span selection model.
    )custom_introc                   @   st   e Zd ZU dZdZeej ed< dZ	eej ed< dZ
eej ed< dZeeej  ed< dZeeej  ed< dS )SplinterForPreTrainingOutputa  
    loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when start and end positions are provided):
        Total span extraction loss is the sum of a Cross-Entropy for the start and end positions.
    start_logits (`torch.FloatTensor` of shape `(batch_size, num_questions, sequence_length)`):
        Span-start scores (before SoftMax).
    end_logits (`torch.FloatTensor` of shape `(batch_size, num_questions, sequence_length)`):
        Span-end scores (before SoftMax).
    Nr8  r(  r)  rj   r   )rI   rJ   rK   rL   r8  r   r/   rN   __annotations__r(  r)  rj   rP   r   r8   r8   r8   r9   rH  O  s   
 	rH  z
    Splinter Model for the recurring span selection task as done during the pretraining. The difference to the QA task
    is that we do not have a question, but multiple question tokens that replace the occurrences of recurring spans
    instead.
    c                       s   e Zd Z fddZe												ddeej deej deej deej deej d	eej d
eej deej dee	 dee	 dee	 deej de
eef fddZdejdejfddZ  ZS )SplinterForPreTrainingc                    r+  r   r,  r3   r6   r8   r9   r   n  r/  zSplinterForPreTraining.__init__Nr:   rk   r;   r   rl   r<   r0  r1  rp   r   r   r2  r>   c                 C   s  |dur|n| j j}|du r|dur|durtd|du r&|du r&td|du r/| |}| j|||||||	|
|d	}|d }| \}}}| ||\}}|d}|dur}|d|||}|d| t	
|jj  }|d| t	
|jj  }d}|dur|dur|dtd|d  |dtd|d  t| j jd}|||| |||| }|||| |||| }|| d }|s||f|dd  }|dur|f| S |S t||||j|jd	S )
a  
        input_ids (`torch.LongTensor` of shape `(batch_size, num_questions, sequence_length)`):
            Indices of input sequence tokens in the vocabulary.

            Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
            [`PreTrainedTokenizer.__call__`] for details.

            [What are input IDs?](../glossary#input-ids)
        token_type_ids (`torch.LongTensor` of shape `batch_size, num_questions, sequence_length`, *optional*):
            Segment token indices to indicate first and second portions of the inputs. Indices are selected in `[0,
            1]`:

            - 0 corresponds to a *sentence A* token,
            - 1 corresponds to a *sentence B* token.

            [What are token type IDs?](../glossary#token-type-ids)
        position_ids (`torch.LongTensor` of shape `batch_size, num_questions, sequence_length`, *optional*):
            Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
            config.max_position_embeddings - 1]`.

            [What are position IDs?](../glossary#position-ids)
        inputs_embeds (`torch.FloatTensor` of shape `(batch_size, num_questions, sequence_length, hidden_size)`, *optional*):
            Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
            is useful if you want more control over how to convert *input_ids* indices into associated vectors than the
            model's internal embedding lookup matrix.
        start_positions (`torch.LongTensor` of shape `(batch_size, num_questions)`, *optional*):
            Labels for position (index) of the start of the labelled span for computing the token classification loss.
            Positions are clamped to the length of the sequence (`sequence_length`). Position outside of the sequence
            are not taken into account for computing the loss.
        end_positions (`torch.LongTensor` of shape `(batch_size, num_questions)`, *optional*):
            Labels for position (index) of the end of the labelled span for computing the token classification loss.
            Positions are clamped to the length of the sequence (`sequence_length`). Position outside of the sequence
            are not taken into account for computing the loss.
        question_positions (`torch.LongTensor` of shape `(batch_size, num_questions)`, *optional*):
            The positions of all question tokens. If given, start_logits and end_logits will be of shape `(batch_size,
            num_questions, sequence_length)`. If None, the first question token in each sequence in the batch will be
            the only one for which start_logits and end_logits are calculated and they will be of shape `(batch_size,
            sequence_length)`.
        NzCquestion_positions must be specified in order to calculate the lossz>question_positions must be specified when input_embeds is usedr4  r   r   r5  rW   r7  )r5   r  	TypeError_prepare_question_positionsr   rB   r-  r  r1   r/   r<  r@   r=  r>  maxr   r#   rf   rH  rj   r   )r4   r:   rk   r;   r   rl   r<   r0  r1  rp   r   r   r2  r   r  r  sequence_lengthrr   r(  r)  num_questions attention_mask_for_each_questionrA  rC  rD  rE  r   r8   r8   r9   rH   x  sh   7


zSplinterForPreTraining.forwardc                 C   sl   t || jjk\}}t |}t j|d| f| jjt j	|j
d}t dd |D }||||f< |S )Nr   r?   c                 S   s   g | ]}t |qS r8   )r/   r0   )r   nr8   r8   r9   r     s    zFSplinterForPreTraining._prepare_question_positions.<locals>.<listcomp>)r/   wherer5   r.  bincountfullrB   rM  r#   rD   rA   rt   )r4   r:   rowsflat_positionsrO  r"  colsr8   r8   r9   rL    s   
z2SplinterForPreTraining._prepare_question_positionsrF  )rI   rJ   rK   r   r   r   r/   r   rM   r   r   rP   rH  rH   rL  rQ   r8   r8   r6   r9   rJ  f  sV    
	

|rJ  )r*  rJ  r   r   r   )4rL   r{   dataclassesr   typingr   r   r/   torch.utils.checkpointr   torch.nnr   activationsr   modeling_layersr	   modeling_outputsr
   r   r   modeling_utilsr   pytorch_utilsr   r   r   utilsr   r   configuration_splinterr   
get_loggerrI   r   Moduler   rR   r   r   r   r   r   r   r   r   r   r  r  r*  rH  rJ  __all__r8   r8   r8   r9   <module>   s^   
7 4WQ &r 