o
    iD                     @   sD  d Z ddlZddlmZ ddlmZmZ ddlZddlmZ ddl	m
Z
 ddlmZ dd	lmZ dd
lmZmZmZ ddlmZ ddlmZmZmZ ddlmZmZmZmZ ddlmZ e e!Z"eeddG dd deZ#G dd dej$Z%G dd dej$Z&G dd dej$Z'G dd dej$Z(G dd dej$Z)G dd dej$Z*G d d! d!ej$Z+G d"d# d#ej$Z,G d$d% d%ej$Z-G d&d' d'eZ.G d(d) d)ej$Z/G d*d+ d+ej$Z0G d,d- d-ej$Z1eG d.d/ d/eZ2eG d0d1 d1e2Z3eG d2d3 d3e2Z4ed4dG d5d6 d6e2Z5ed7dG d8d9 d9e2Z6g d:Z7dS );zPyTorch Bros model.    N)	dataclass)OptionalUnion)nn)CrossEntropyLoss   )ACT2FN)GradientCheckpointingLayer)"BaseModelOutputWithCrossAttentions,BaseModelOutputWithPoolingAndCrossAttentionsTokenClassifierOutput)PreTrainedModel)apply_chunking_to_forward find_pruneable_heads_and_indicesprune_linear_layer)ModelOutputauto_docstringcan_return_tuplelogging   )
BrosConfigz@
    Base class for outputs of token classification models.
    )custom_introc                   @   st   e Zd ZU dZdZeej ed< dZ	eej ed< dZ
eej ed< dZeeej  ed< dZeeej  ed< dS )BrosSpadeOutputa  
    loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
        Classification loss.
    initial_token_logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.num_labels)`):
        Classification scores for entity initial tokens (before SoftMax).
    subsequent_token_logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, sequence_length+1)`):
        Classification scores for entity sequence tokens (before SoftMax).
    Nlossinitial_token_logitssubsequent_token_logitshidden_states
attentions)__name__
__module____qualname____doc__r   r   torchFloatTensor__annotations__r   r   r   tupler    r&   r&   Z/home/ubuntu/.local/lib/python3.10/site-packages/transformers/models/bros/modeling_bros.pyr   )   s   
 	r   c                       2   e Zd Z fddZdejdejfddZ  ZS )BrosPositionalEmbedding1Dc                    s@   t    |j| _ddtd| jd| j   }| d| d S )Nr   i'          g       @inv_freq)super__init__dim_bbox_sinusoid_emb_1dr"   arangeregister_buffer)selfconfigr+   	__class__r&   r'   r-   C   s   
z"BrosPositionalEmbedding1D.__init__pos_seqreturnc                 C   sX   |  }|\}}}||||d| jddd| jd  }tj| | gdd}|S )Nr      dim)sizeviewr+   r.   r"   catsincos)r1   r5   seq_sizeb1b2b3sinusoid_inppos_embr&   r&   r'   forwardM   s
   
(z!BrosPositionalEmbedding1D.forwardr   r   r    r-   r"   TensorrF   __classcell__r&   r&   r3   r'   r)   @   s    
r)   c                       r(   )BrosPositionalEmbedding2Dc                    s*   t    |j| _t|| _t|| _d S N)r,   r-   dim_bboxr)   	x_pos_emb	y_pos_embr1   r2   r3   r&   r'   r-   V   s   

z"BrosPositionalEmbedding2D.__init__bboxr6   c                 C   sd   g }t | jD ]!}|d dkr|| |d|f  q|| |d|f  qtj|dd}|S )Nr7   r   .r8   r9   )rangerL   appendrM   rN   r"   r=   )r1   rP   stackibbox_pos_embr&   r&   r'   rF   ]   s   z!BrosPositionalEmbedding2D.forwardrG   r&   r&   r3   r'   rJ   U   s    rJ   c                       s,   e Zd Z fddZdejfddZ  ZS )BrosBboxEmbeddingsc                    s.   t    t|| _tj|j|jdd| _d S )NF)bias)	r,   r-   rJ   bbox_sinusoid_embr   Lineardim_bbox_sinusoid_emb_2ddim_bbox_projectionbbox_projectionrO   r3   r&   r'   r-   i   s   

zBrosBboxEmbeddings.__init__rP   c                 C   s\   | dd}|d d d d d d d f |d d d d d d d f  }| |}| |}|S )Nr   r   )	transposerX   r\   )r1   rP   bbox_tbbox_posrU   r&   r&   r'   rF   n   s
   8

zBrosBboxEmbeddings.forwardrG   r&   r&   r3   r'   rV   h   s    rV   c                       sb   e Zd ZdZ fddZ				ddeej deej deej deej d	ejf
d
dZ  Z	S )BrosTextEmbeddingszGConstruct the embeddings from word, position and token_type embeddings.c                    s   t    tj|j|j|jd| _t|j|j| _	t|j
|j| _tj|j|jd| _t|j| _t|dd| _| dt|jd | jdtj| j tj| jjdd	d
 d S )N)padding_idxepsposition_embedding_typeabsoluteposition_ids)r   r8   token_type_idsdtypedeviceF)
persistent)r,   r-   r   	Embedding
vocab_sizehidden_sizepad_token_idword_embeddingsmax_position_embeddingsposition_embeddingstype_vocab_sizetoken_type_embeddings	LayerNormlayer_norm_epsDropouthidden_dropout_probdropoutgetattrrd   r0   r"   r/   expandzerosrf   r;   longrj   rO   r3   r&   r'   r-   z   s"   

zBrosTextEmbeddings.__init__N	input_idsrg   rf   inputs_embedsr6   c                 C   s   |d ur	|  }n|  d d }|d }|d u r$| jd d d |f }|d u rNt| drC| jd d d |f }||d |}|}ntj|tj| jjd}|d u rW| 	|}| 
|}	||	 }
| jdkrn| |}|
|7 }
| |
}
| |
}
|
S )Nr8   r   rg   r   rh   re   )r;   rf   hasattrrg   r{   r"   r|   r}   rj   rp   rt   rd   rr   ru   ry   )r1   r~   rg   rf   r   input_shape
seq_lengthbuffered_token_type_ids buffered_token_type_ids_expandedrt   
embeddingsrr   r&   r&   r'   rF      s,   







zBrosTextEmbeddings.forward)NNNN)
r   r   r    r!   r-   r   r"   rH   rF   rI   r&   r&   r3   r'   r`   w   s$    r`   c                       sz   e Zd Z fddZ					ddejdejdeej deej d	eej d
eej deej deej fddZ  Z	S )BrosSelfAttentionc                    s   t    |j|j dkrt|dstd|j d|j d|j| _t|j|j | _| j| j | _t	
|j| j| _t	
|j| j| _t	
|j| j| _t	|j| _t|dd| _| jdksf| jd	krw|j| _t	d
|j d | j| _|j| _d S )Nr   embedding_sizezThe hidden size (z6) is not a multiple of the number of attention heads ()rd   re   relative_keyrelative_key_queryr7   r   )r,   r-   rn   num_attention_headsr   
ValueErrorintattention_head_sizeall_head_sizer   rY   querykeyvaluerw   attention_probs_dropout_probry   rz   rd   rq   rl   distance_embedding
is_decoderrO   r3   r&   r'   r-      s&   

zBrosSelfAttention.__init__NFr   rU   attention_mask	head_maskencoder_hidden_statesencoder_attention_maskoutput_attentionsr6   c                 C   s  |j d d| j| jf}| ||dd}	|d u}
|
r8| ||dd}| ||dd}|}n| ||dd}| ||dd}t	|	|dd}| j
dksd| j
dkr| d }tj|tj|jddd}tj|tj|jddd}|| }| || j d }|j|	jd	}| j
dkrtd
|	|}|| }n| j
dkrtd
|	|}td||}|| | }|	j \}}}}|||||}|g d}td|	|f}|| }|t| j }|d ur|| }tjdd|}| |}|d ur|| }t	||}|dddd }| d d | jf }|j| }|r5||fn|f}| jr@|d }|S )Nr   r8   r   r7   r   r   rh   )ri   zbhld,lrd->bhlrzbhrd,lrd->bhlr)r7   r   r   r   zbnid,bijd->bnijr9   r   rK   )shaper   r   r   r<   r]   r   r   r"   matmulrd   r;   r/   r}   rj   r   rq   tori   einsumpermutemathsqrtr   Softmaxry   
contiguousr   r   )r1   r   rU   r   r   r   r   r   hidden_shapequery_layeris_cross_attention	key_layervalue_layerattention_scoresr   position_ids_lposition_ids_rdistancepositional_embeddingrelative_position_scoresrelative_position_scores_queryrelative_position_scores_key
batch_sizen_headd_headbbox_pos_scoresattention_probscontext_layernew_context_layer_shapeoutputsr&   r&   r'   rF      sX   






zBrosSelfAttention.forwardNNNNF)
r   r   r    r-   r"   rH   r   r%   rF   rI   r&   r&   r3   r'   r      s0    	r   c                       8   e Zd Z fddZdejdejdejfddZ  ZS )BrosSelfOutputc                    sB   t    t|j|j| _tj|j|jd| _t|j	| _
d S Nrb   )r,   r-   r   rY   rn   denseru   rv   rw   rx   ry   rO   r3   r&   r'   r-   &     
zBrosSelfOutput.__init__r   input_tensorr6   c                 C   &   |  |}| |}| || }|S rK   r   ry   ru   r1   r   r   r&   r&   r'   rF   ,     

zBrosSelfOutput.forwardrG   r&   r&   r3   r'   r   %      $r   c                       s   e Zd Z fddZdd Z					ddejdejd	eej d
eej deej deej dee de	ej fddZ
  ZS )BrosAttentionc                    s*   t    t|| _t|| _t | _d S rK   )r,   r-   r   r1   r   outputsetpruned_headsrO   r3   r&   r'   r-   4  s   


zBrosAttention.__init__c                 C   s   t |dkrd S t|| jj| jj| j\}}t| jj|| j_t| jj|| j_t| jj	|| j_	t| j
j|dd| j
_| jjt | | j_| jj| jj | j_| j|| _d S )Nr   r   r9   )lenr   r1   r   r   r   r   r   r   r   r   r   r   union)r1   headsindexr&   r&   r'   prune_heads:  s   zBrosAttention.prune_headsNFr   rU   r   r   r   r   r   r6   c              	   C   s>   | j |||||||d}| |d |}	|	f|dd   }
|
S )Nr   rU   r   r   r   r   r   r   r   )r1   r   )r1   r   rU   r   r   r   r   r   self_outputsattention_outputr   r&   r&   r'   rF   O  s   
	zBrosAttention.forwardr   )r   r   r    r-   r   r"   rH   r   boolr%   rF   rI   r&   r&   r3   r'   r   3  s2    	r   c                       r(   )BrosIntermediatec                    sD   t    t|j|j| _t|jt	rt
|j | _d S |j| _d S rK   )r,   r-   r   rY   rn   intermediate_sizer   
isinstance
hidden_actstrr   intermediate_act_fnrO   r3   r&   r'   r-   i  s
   
zBrosIntermediate.__init__r   r6   c                 C   s   |  |}| |}|S rK   )r   r   )r1   r   r&   r&   r'   rF   q  s   

zBrosIntermediate.forwardrG   r&   r&   r3   r'   r   h  s    r   c                       r   )
BrosOutputc                    sB   t    t|j|j| _tj|j|jd| _t	|j
| _d S r   )r,   r-   r   rY   r   rn   r   ru   rv   rw   rx   ry   rO   r3   r&   r'   r-   x  r   zBrosOutput.__init__r   r   r6   c                 C   r   rK   r   r   r&   r&   r'   rF   ~  r   zBrosOutput.forwardrG   r&   r&   r3   r'   r   w  r   r   c                       s   e Zd Z fddZ					ddejdejdeej deej d	eej d
eej dee de	ej fddZ
dd Z  ZS )	BrosLayerc                    sn   t    |j| _d| _t|| _|j| _|j| _| jr+| js&t|  dt|| _	t
|| _t|| _d S )Nr   z> should be used as a decoder model if cross attention is added)r,   r-   chunk_size_feed_forwardseq_len_dimr   	attentionr   add_cross_attention	Exceptioncrossattentionr   intermediater   r   rO   r3   r&   r'   r-     s   



zBrosLayer.__init__NFr   rU   r   r   r   r   r   r6   c                 C   s   | j |||||d}|d }	| jr|dd }
n|dd  }
| jrI|d urIt| dr2td|  d| j|	|||||d}|d }	|
|dd  }
t| j| j| j|	}|f|
 }
| jr_|
d	 }
|
S )
N)rU   r   r   r   r   r   r8   r   z'If `encoder_hidden_states` are passed, z` has to be instantiated with cross-attention layers by setting `config.add_cross_attention=True`)r   r   r   r   r   rK   )	r   r   r   r   r   r   feed_forward_chunkr   r   )r1   r   rU   r   r   r   r   r   self_attention_outputsr   r   cross_attention_outputslayer_outputr&   r&   r'   rF     sH   



zBrosLayer.forwardc                 C   s   |  |}| ||}|S rK   )r   r   )r1   r   intermediate_outputr   r&   r&   r'   r     s   
zBrosLayer.feed_forward_chunkr   )r   r   r    r-   r"   rH   r   r#   r   r%   rF   r   rI   r&   r&   r3   r'   r     s2    	
8r   c                       s   e Zd Z fddZe							ddejdejdeej d	eej d
eej deej dee	 dee	 dee	 de
eej ef fddZ  ZS )BrosEncoderc                    s4   t     | _t fddt jD | _d S )Nc                    s   g | ]}t  qS r&   )r   ).0_r2   r&   r'   
<listcomp>  s    z(BrosEncoder.__init__.<locals>.<listcomp>)r,   r-   r2   r   
ModuleListrQ   num_hidden_layerslayerrO   r3   r   r'   r-     s   
$zBrosEncoder.__init__NFTr   rU   r   r   r   r   r   output_hidden_statesreturn_dictr6   c
              
   C   s   |rdnd }
|r
dnd }|r| j jrdnd }t| jD ]8\}}|r&|
|f }
|d ur.|| nd }||||||||d}|d }|rS||d f }| j jrS||d f }q|r[|
|f }
t||
||dS )Nr&   r   r   r   r7   )last_hidden_stater   r   cross_attentions)r2   r   	enumerater   r
   )r1   r   rU   r   r   r   r   r   r   r   all_hidden_statesall_self_attentionsall_cross_attentionsrT   layer_modulelayer_head_masklayer_outputsr&   r&   r'   rF     s<   


zBrosEncoder.forward)NNNNFFT)r   r   r    r-   r   r"   rH   r   r#   r   r   r%   r
   rF   rI   r&   r&   r3   r'   r     s>    	
r   c                       r(   )
BrosPoolerc                    s*   t    t|j|j| _t | _d S rK   )r,   r-   r   rY   rn   r   Tanh
activationrO   r3   r&   r'   r-     s   
zBrosPooler.__init__r   r6   c                 C   s(   |d d df }|  |}| |}|S )Nr   )r   r   )r1   r   first_token_tensorpooled_outputr&   r&   r'   rF     s   

zBrosPooler.forwardrG   r&   r&   r3   r'   r     s    r   c                       r(   )BrosRelationExtractorc                    s   t    |j| _|j| _|j| _|j| _t| j| _	t
| j| j| j | _t
| j| j| j | _ttd| j| _d S )Nr   )r,   r-   n_relationsrn   backbone_hidden_sizehead_hidden_sizeclassifier_dropout_probr   rw   droprY   r   r   	Parameterr"   r|   
dummy_noderO   r3   r&   r'   r-     s   
zBrosRelationExtractor.__init__r   r   c              	   C   s   |  | |}| jdd|dd}tj||gdd}| | |}|	|d|d| j
| j}|	|d|d| j
| j}t|dddd|dddd}|S )Nr   r   axisr7   r   )r   r  r  	unsqueezerepeatr;   r"   r=   r   r<   r   r  r   r   )r1   r   r   	dummy_vecrelation_scorer&   r&   r'   rF   )  s    zBrosRelationExtractor.forwardrG   r&   r&   r3   r'   r     s    r   c                   @   s*   e Zd ZU eed< dZdejfddZdS )BrosPreTrainedModelr2   brosmodulec                 C   s   | j j}t|tjr"|jjjd|d |jdur |jj	  dS dS t|tj
rC|jjjd|d |jdurA|jj|j 	  dS dS t|tjrX|jj	  |jjd dS t|trhtjj|j|d dS dS )zInitialize the weightsr*   )meanstdNg      ?)r  )r2   initializer_ranger   r   rY   weightdatanormal_rW   zero_rl   ra   ru   fill_r   initr  )r1   r  r  r&   r&   r'   _init_weightsA  s"   


z!BrosPreTrainedModel._init_weightsN)	r   r   r    r   r$   base_model_prefixr   Moduler  r&   r&   r&   r'   r  <  s   
 r  c                       s   e Zd Zd fdd	Zdd Zdd Zdd	 Zee	
	
	
	
	
	
	
	
	
	
	
	
dde	e
j de	e
j de	e
j de	e
j de	e
j de	e
j de	e
j de	e
j de	e
j de	e de	e de	e deee
j ef fddZ  ZS )	BrosModelTc                    sN   t  | || _t|| _t|| _t|| _|rt	|nd| _
|   dS )zv
        add_pooling_layer (bool, *optional*, defaults to `True`):
            Whether to add a pooling layer
        N)r,   r-   r2   r`   r   rV   bbox_embeddingsr   encoderr   poolerinit_weights)r1   r2   add_pooling_layerr3   r&   r'   r-   W  s   


zBrosModel.__init__c                 C   s   | j jS rK   r   rp   )r1   r&   r&   r'   get_input_embeddingsg  s   zBrosModel.get_input_embeddingsc                 C   s   || j _d S rK   r"  )r1   r   r&   r&   r'   set_input_embeddingsj  s   zBrosModel.set_input_embeddingsc                 C   s*   |  D ]\}}| jj| j| qdS )z
        Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base
        class PreTrainedModel
        N)itemsr  r   r   r   )r1   heads_to_pruner   r   r&   r&   r'   _prune_headsm  s   zBrosModel._prune_headsNr~   rP   r   rg   rf   r   r   r   r   r   r   r   r6   c                 C   s*  |
dur|
n| j j}
|dur|n| j j}|dur|n| j j}|dur*|dur*td|dur3| }n|dur@| dd }ntd|du rLtd|\}}|durW|jn|j}|du retj||d}|du rt	| j
dr| j
jddd|f }|||}|}n	tj|tj|d}| |||}| j jr|dur| \}}}||f}|	du rtj||d}	| |	}nd}| || j j}| j
||||d	}|jd d
kr|ddddg df }|| j j }| |}| j|||||||
|dd	}|d }| jdur| |nd}t|||j|j|jdS )a  
        bbox ('torch.FloatTensor' of shape '(batch_size, num_boxes, 4)'):
            Bounding box coordinates for each token in the input sequence. Each bounding box is a list of four values
            (x1, y1, x2, y2), where (x1, y1) is the top left corner, and (x2, y2) is the bottom right corner of the
            bounding box.

        Examples:

        ```python
        >>> import torch
        >>> from transformers import BrosProcessor, BrosModel

        >>> processor = BrosProcessor.from_pretrained("jinho8345/bros-base-uncased")

        >>> model = BrosModel.from_pretrained("jinho8345/bros-base-uncased")

        >>> encoding = processor("Hello, my dog is cute", add_special_tokens=False, return_tensors="pt")
        >>> bbox = torch.tensor([[[0, 0, 1, 1]]]).repeat(1, encoding["input_ids"].shape[-1], 1)
        >>> encoding["bbox"] = bbox

        >>> outputs = model(**encoding)
        >>> last_hidden_states = outputs.last_hidden_state
        ```NzDYou cannot specify both input_ids and inputs_embeds at the same timer8   z5You have to specify either input_ids or inputs_embedszYou have to specify bbox)rj   rg   rh   )r~   rf   rg   r      )r   r   r7   r   r7   r   r   r   T)rU   r   r   r   r   r   r   r   r   )r   pooler_outputr   r   r   )r2   r   r   use_return_dictr   r;   rj   r"   onesr   r   rg   r{   r|   r}   get_extended_attention_maskr   invert_attention_maskget_head_maskr   r   
bbox_scaler  r  r  r   r   r   r   )r1   r~   rP   r   rg   rf   r   r   r   r   r   r   r   r   r   r   rj   r   r   extended_attention_maskencoder_batch_sizeencoder_sequence_lengthr   encoder_hidden_shapeencoder_extended_attention_maskembedding_outputscaled_bboxbbox_position_embeddingsencoder_outputssequence_outputr   r&   r&   r'   rF   u  s|   (

zBrosModel.forward)TNNNNNNNNNNNN)r   r   r    r-   r#  r$  r'  r   r   r   r"   rH   r   r   r%   r   rF   rI   r&   r&   r3   r'   r  U  s\    	
r  c                          e Zd ZdgZ fddZee												ddeej	 deej	 deej	 deej	 d	eej	 d
eej	 deej	 deej	 deej	 dee
 dee
 dee
 deeej	 ef fddZ  ZS )BrosForTokenClassificationr  c                    s^   t  | |j| _t|| _t|dr|jn|j}t	|| _
t|j|j| _|   d S Nclassifier_dropout)r,   r-   
num_labelsr  r  r   r>  rx   r   rw   ry   rY   rn   
classifierr   r1   r2   r>  r3   r&   r'   r-     s   
z#BrosForTokenClassification.__init__Nr~   rP   r   bbox_first_token_maskrg   rf   r   r   labelsr   r   r   r6   c                 C   s   |dur|n| j j}| j||||||||
|dd
}|d }| |}| |}d}|	durXt }|durK|d}||d| j| |	d| }n||d| j|	d}t|||j	|j
dS )a  
        bbox ('torch.FloatTensor' of shape '(batch_size, num_boxes, 4)'):
            Bounding box coordinates for each token in the input sequence. Each bounding box is a list of four values
            (x1, y1, x2, y2), where (x1, y1) is the top left corner, and (x2, y2) is the bottom right corner of the
            bounding box.
        bbox_first_token_mask (`torch.FloatTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Mask to indicate the first token of each bounding box. Mask values selected in `[0, 1]`:

            - 1 for tokens that are **not masked**,
            - 0 for tokens that are **masked**.

        Examples:

        ```python
        >>> import torch
        >>> from transformers import BrosProcessor, BrosForTokenClassification

        >>> processor = BrosProcessor.from_pretrained("jinho8345/bros-base-uncased")

        >>> model = BrosForTokenClassification.from_pretrained("jinho8345/bros-base-uncased")

        >>> encoding = processor("Hello, my dog is cute", add_special_tokens=False, return_tensors="pt")
        >>> bbox = torch.tensor([[[0, 0, 1, 1]]]).repeat(1, encoding["input_ids"].shape[-1], 1)
        >>> encoding["bbox"] = bbox

        >>> outputs = model(**encoding)
        ```NT)	rP   r   rg   rf   r   r   r   r   r   r   r8   r   logitsr   r   )r2   r*  r  ry   r@  r   r<   r?  r   r   r   )r1   r~   rP   r   rB  rg   rf   r   r   rC  r   r   r   r   r9  rE  r   loss_fctr&   r&   r'   rF     s>   -


z"BrosForTokenClassification.forwardr:  r   r   r    "_keys_to_ignore_on_load_unexpectedr-   r   r   r   r"   rH   r   r   r%   r   rF   rI   r&   r&   r3   r'   r<    sX    	
r<  a  
    Bros Model with a token classification head on top (initial_token_layers and subsequent_token_layer on top of the
    hidden-states output) e.g. for Named-Entity-Recognition (NER) tasks. The initial_token_classifier is used to
    predict the first token of each entity, and the subsequent_token_classifier is used to predict the subsequent
    tokens within an entity. Compared to BrosForTokenClassification, this model is more robust to serialization errors
    since it predicts next token from one token.
    c                !       s   e Zd ZdgZ fddZee													ddeej	 deej	 deej	 deej	 d	eej	 d
eej	 deej	 deej	 deej	 deej	 dee
 dee
 dee
 deeej	 ef fddZ  ZS )!BrosSpadeEEForTokenClassificationr  c              	      s   t  | || _|j| _|j| _|j| _t|| _t	|dr"|j
n|j}tt|t|j|jt|t|j|j| _t|| _|   d S r=  )r,   r-   r2   r?  r   rn   r  r  r  r   r>  rx   r   
Sequentialrw   rY   initial_token_classifierr   subsequent_token_classifierr   rA  r3   r&   r'   r-   h  s    

z*BrosSpadeEEForTokenClassification.__init__Nr~   rP   r   rB  rg   rf   r   r   initial_token_labelssubsequent_token_labelsr   r   r   r6   c                 C   s  |dur|n| j j}| j|||||||||dd
}|d }|dd }| |dd }| ||d}d| }|j\}}|j	}t
j|t
|dg|gdd }||dddddf t
|jj}t
||d j|t
jd}||dddddf t
|jj}|d }d}|	dur|
durt }|	d}	|dur|d}||d| j| |	| }n
||d| j|	}|
d}
||d|d | |
| }|| }t||||j|jd	S )
a>  
        bbox ('torch.FloatTensor' of shape '(batch_size, num_boxes, 4)'):
            Bounding box coordinates for each token in the input sequence. Each bounding box is a list of four values
            (x1, y1, x2, y2), where (x1, y1) is the top left corner, and (x2, y2) is the bottom right corner of the
            bounding box.
        bbox_first_token_mask (`torch.FloatTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Mask to indicate the first token of each bounding box. Mask values selected in `[0, 1]`:

            - 1 for tokens that are **not masked**,
            - 0 for tokens that are **masked**.
        initial_token_labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Labels for the initial token classification.
        subsequent_token_labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Labels for the subsequent token classification.

        Examples:

        ```python
        >>> import torch
        >>> from transformers import BrosProcessor, BrosSpadeEEForTokenClassification

        >>> processor = BrosProcessor.from_pretrained("jinho8345/bros-base-uncased")

        >>> model = BrosSpadeEEForTokenClassification.from_pretrained("jinho8345/bros-base-uncased")

        >>> encoding = processor("Hello, my dog is cute", add_special_tokens=False, return_tensors="pt")
        >>> bbox = torch.tensor([[[0, 0, 1, 1]]]).repeat(1, encoding["input_ids"].shape[-1], 1)
        >>> encoding["bbox"] = bbox

        >>> outputs = model(**encoding)
        ```NT
r~   rP   r   rg   rf   r   r   r   r   r   r   r   r  rj   ri   r8   )r   r   r   r   r   )r2   r*  r  r]   r   rK  rL  squeezer   rj   r"   r=   r|   r   r   masked_fillfinfori   mineyer<   r   r?  r   r   r   )r1   r~   rP   r   rB  rg   rf   r   r   rM  rN  r   r   r   r   last_hidden_statesr   r   inv_attention_maskr   max_seq_lengthrj   invalid_token_maskself_token_masksubsequent_token_maskr   rF  initial_token_losssubsequent_token_lossr&   r&   r'   rF     sj   2
&  


z)BrosSpadeEEForTokenClassification.forward)NNNNNNNNNNNNN)r   r   r    rH  r-   r   r   r   r"   rH   r   r   r%   r   rF   rI   r&   r&   r3   r'   rI  \  s^    
	
rI  z
    Bros Model with a token classification head on top (a entity_linker layer on top of the hidden-states output) e.g.
    for Entity-Linking. The entity_linker is used to predict intra-entity links (one entity to another entity).
    c                       r;  )!BrosSpadeELForTokenClassificationr  c                    sx   t  | || _|j| _|j| _|j| _t|| _t	|dr"|j
n|j t|| _|   d S  t|| _|   d S r=  )r,   r-   r2   r?  r   rn   r  r  r  r   r>  rx   r   entity_linkerr   rO   r3   r&   r'   r-     s   


z*BrosSpadeELForTokenClassification.__init__Nr~   rP   r   rB  rg   rf   r   r   rC  r   r   r   r6   c                 C   sR  |dur|n| j j}| j||||||||
|dd
}|d }|dd }| ||d}d}|	durt }|j\}}|j	}t
||d j|t
jd}|d}t
j| t
j|dgt
j|dgdd	}||dddddf t
|jj}||dddddf t
|jj}||d|d | |	d| }t|||j|jd
S )a  
        bbox ('torch.FloatTensor' of shape '(batch_size, num_boxes, 4)'):
            Bounding box coordinates for each token in the input sequence. Each bounding box is a list of four values
            (x1, y1, x2, y2), where (x1, y1) is the top left corner, and (x2, y2) is the bottom right corner of the
            bounding box.
        bbox_first_token_mask (`torch.FloatTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Mask to indicate the first token of each bounding box. Mask values selected in `[0, 1]`:

            - 1 for tokens that are **not masked**,
            - 0 for tokens that are **masked**.

        Examples:

        ```python
        >>> import torch
        >>> from transformers import BrosProcessor, BrosSpadeELForTokenClassification

        >>> processor = BrosProcessor.from_pretrained("jinho8345/bros-base-uncased")

        >>> model = BrosSpadeELForTokenClassification.from_pretrained("jinho8345/bros-base-uncased")

        >>> encoding = processor("Hello, my dog is cute", add_special_tokens=False, return_tensors="pt")
        >>> bbox = torch.tensor([[[0, 0, 1, 1]]]).repeat(1, encoding["input_ids"].shape[-1], 1)
        >>> encoding["bbox"] = bbox

        >>> outputs = model(**encoding)
        ```NTrO  r   r   rP  r8   rh   r  rD  )r2   r*  r  r]   r   r_  rQ  r   r   rj   r"   rU  r   r   r<   r=   r|   rR  rS  ri   rT  r   r   r   )r1   r~   rP   r   rB  rg   rf   r   r   rC  r   r   r   r   rV  rE  r   rF  r   rX  rj   rZ  maskr&   r&   r'   rF     sL   ,

(($z)BrosSpadeELForTokenClassification.forwardr:  rG  r&   r&   r3   r'   r^    sX    	
r^  )r  r  r<  rI  r^  )8r!   r   dataclassesr   typingr   r   r"   r   torch.nnr   activationsr   modeling_layersr	   modeling_outputsr
   r   r   modeling_utilsr   pytorch_utilsr   r   r   utilsr   r   r   r   configuration_brosr   
get_loggerr   loggerr   r  r)   rJ   rV   r`   r   r   r   r   r   r   r   r   r   r  r  r<  rI  r^  __all__r&   r&   r&   r'   <module>   sd   
Am5M9" "d	 o