"""PyTorch CANINE model."""

import copy
import math
import os
from dataclasses import dataclass
from typing import Optional, Union

import torch
from torch import nn
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss

from ...activations import ACT2FN
from ...modeling_layers import GradientCheckpointingLayer
from ...modeling_outputs import (
    BaseModelOutput,
    ModelOutput,
    MultipleChoiceModelOutput,
    QuestionAnsweringModelOutput,
    SequenceClassifierOutput,
    TokenClassifierOutput,
)
from ...modeling_utils import PreTrainedModel
from ...pytorch_utils import apply_chunking_to_forward, find_pruneable_heads_and_indices, prune_linear_layer
from ...utils import auto_docstring, logging
from .configuration_canine import CanineConfig


logger = logging.get_logger(__name__)

# Primes used for the multi-hash character embeddings (supports up to 16 hash functions).
_PRIMES = [31, 43, 59, 61, 73, 97, 103, 113, 137, 149, 157, 173, 181, 193, 211, 223]
@dataclass
@auto_docstring(
    custom_intro="""
    Output type of [`CanineModel`]. Based on [`~modeling_outputs.BaseModelOutputWithPooling`], but with slightly
    different `hidden_states` and `attentions`, as these also include the hidden states and attentions of the shallow
    Transformer encoders.
    """
)
class CanineModelOutputWithPooling(ModelOutput):
    r"""
    last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
        Sequence of hidden-states at the output of the last layer of the model (i.e. the output of the final
        shallow Transformer encoder).
    pooler_output (`torch.FloatTensor` of shape `(batch_size, hidden_size)`):
        Hidden-state of the first token of the sequence (classification token) at the last layer of the deep
        Transformer encoder, further processed by a Linear layer and a Tanh activation function. The Linear layer
        weights are trained from the next sentence prediction (classification) objective during pretraining.
    hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
        Tuple of `torch.FloatTensor` (one for the input to each encoder + one for the output of each layer of each
        encoder) of shape `(batch_size, sequence_length, hidden_size)` and `(batch_size, sequence_length //
        config.downsampling_rate, hidden_size)`. Hidden states of the model at the output of each layer plus the
        initial input to each Transformer encoder. The hidden states of the shallow encoders have length
        `sequence_length`, but the hidden states of the deep encoder have length `sequence_length //
        config.downsampling_rate`.
    attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
        Tuple of `torch.FloatTensor` (one for each layer) of the 3 Transformer encoders of shape `(batch_size,
        num_heads, sequence_length, sequence_length)` and `(batch_size, num_heads, sequence_length //
        config.downsampling_rate, sequence_length // config.downsampling_rate)`. Attention weights after the
        attention softmax, used to compute the weighted average in the self-attention heads.
    """

    last_hidden_state: Optional[torch.FloatTensor] = None
    pooler_output: Optional[torch.FloatTensor] = None
    hidden_states: Optional[tuple[torch.FloatTensor]] = None
    attentions: Optional[tuple[torch.FloatTensor]] = None
def load_tf_weights_in_canine(model, config, tf_checkpoint_path):
    """Load tf checkpoints in a pytorch model."""
    try:
        import re

        import numpy as np
        import tensorflow as tf
    except ImportError:
        logger.error(
            "Loading a TensorFlow model in PyTorch, requires TensorFlow to be installed. Please see "
            "https://www.tensorflow.org/install/ for installation instructions."
        )
        raise
    tf_path = os.path.abspath(tf_checkpoint_path)
    logger.info(f"Converting TensorFlow checkpoint from {tf_path}")
    # Load weights from the TF checkpoint
    init_vars = tf.train.list_variables(tf_path)
    names = []
    arrays = []
    for name, shape in init_vars:
        logger.info(f"Loading TF weight {name} with shape {shape}")
        array = tf.train.load_variable(tf_path, name)
        names.append(name)
        arrays.append(array)

    for name, array in zip(names, arrays):
        name = name.split("/")
        # adam_v and adam_m are optimizer slots, and the cls/autoregressive decoder weights were only
        # used for pre-training, so none of them are needed for the PyTorch model.
        if any(
            n
            in [
                "adam_v",
                "adam_m",
                "AdamWeightDecayOptimizer",
                "AdamWeightDecayOptimizer_1",
                "global_step",
                "cls",
                "autoregressive_decoder",
                "char_output_weights",
            ]
            for n in name
        ):
            logger.info(f"Skipping {'/'.join(name)}")
            continue
        # if first scope name starts with "bert", change it to "encoder"
        if name[0] == "bert":
            name[0] = "encoder"
        # remove "embeddings" middle name of HashBucketCodepointEmbedders
        elif name[1] == "embeddings":
            name.remove(name[1])
        # rename segment_embeddings to token_type_embeddings
        elif name[1] == "segment_embeddings":
            name[1] = "token_type_embeddings"
        # rename initial convolutional projection layer
        elif name[1] == "initial_char_encoder":
            name = ["chars_to_molecules"] + name[-2:]
        # rename final convolutional projection layer
        elif name[0] == "final_char_encoder" and name[1] in ["LayerNorm", "conv"]:
            name = ["projection"] + name[1:]

        pointer = model
        for m_name in name:
            if re.fullmatch(r"[A-Za-z]+_\d+", m_name) and "Embedder" not in m_name:
                scope_names = re.split(r"_(\d+)", m_name)
            else:
                scope_names = [m_name]
            if scope_names[0] == "kernel" or scope_names[0] == "gamma":
                pointer = getattr(pointer, "weight")
            elif scope_names[0] == "output_bias" or scope_names[0] == "beta":
                pointer = getattr(pointer, "bias")
            elif scope_names[0] == "output_weights":
                pointer = getattr(pointer, "weight")
            else:
                try:
                    pointer = getattr(pointer, scope_names[0])
                except AttributeError:
                    logger.info(f"Skipping {'/'.join(name)}")
                    continue
            if len(scope_names) >= 2:
                num = int(scope_names[1])
                pointer = pointer[num]

        if m_name[-11:] == "_embeddings":
            pointer = getattr(pointer, "weight")
        elif m_name[-10:] in [f"Embedder_{i}" for i in range(8)]:
            pointer = getattr(pointer, "weight")
        elif m_name == "kernel":
            array = np.transpose(array)

        if pointer.shape != array.shape:
            raise ValueError(f"Pointer shape {pointer.shape} and array shape {array.shape} mismatched")

        logger.info(f"Initialize PyTorch weight {name}")
        pointer.data = torch.from_numpy(array)
    return model
class CanineEmbeddings(nn.Module):
    """Construct the character, position and token_type embeddings."""

    def __init__(self, config):
        super().__init__()

        self.config = config

        # character embeddings: one shard per hash function
        shard_embedding_size = config.hidden_size // config.num_hash_functions
        for i in range(config.num_hash_functions):
            name = f"HashBucketCodepointEmbedder_{i}"
            setattr(self, name, nn.Embedding(config.num_hash_buckets, shard_embedding_size))
        self.char_position_embeddings = nn.Embedding(config.num_hash_buckets, config.hidden_size)
        self.token_type_embeddings = nn.Embedding(config.type_vocab_size, config.hidden_size)

        # self.LayerNorm is not snake-cased to stick with the TensorFlow variable name, so TF checkpoints can be loaded
        self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)

        self.register_buffer(
            "position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)), persistent=False
        )
        self.position_embedding_type = getattr(config, "position_embedding_type", "absolute")

    def _hash_bucket_tensors(self, input_ids, num_hashes: int, num_buckets: int):
        """
        Converts ids to hash bucket ids via multiple hashing.

        Args:
            input_ids: The codepoints or other IDs to be hashed.
            num_hashes: The number of hash functions to use.
            num_buckets: The number of hash buckets (i.e. embeddings in each table).

        Returns:
            A list of tensors, each of which is the hash bucket IDs from one hash function.
        """
        if num_hashes > len(_PRIMES):
            raise ValueError(f"`num_hashes` must be <= {len(_PRIMES)}")

        primes = _PRIMES[:num_hashes]

        result_tensors = []
        for prime in primes:
            hashed = ((input_ids + 1) * prime) % num_buckets
            result_tensors.append(hashed)
        return result_tensors

    def _embed_hash_buckets(self, input_ids, embedding_size: int, num_hashes: int, num_buckets: int):
        """Converts IDs (e.g. codepoints) into embeddings via multiple hashing."""
        if embedding_size % num_hashes != 0:
            raise ValueError(f"Expected `embedding_size` ({embedding_size}) % `num_hashes` ({num_hashes}) == 0")

        hash_bucket_tensors = self._hash_bucket_tensors(input_ids, num_hashes=num_hashes, num_buckets=num_buckets)
        embedding_shards = []
        for i, hash_bucket_ids in enumerate(hash_bucket_tensors):
            name = f"HashBucketCodepointEmbedder_{i}"
            shard_embeddings = getattr(self, name)(hash_bucket_ids)
            embedding_shards.append(shard_embeddings)

        return torch.cat(embedding_shards, dim=-1)

    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        token_type_ids: Optional[torch.LongTensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
    ) -> torch.FloatTensor:
        if input_ids is not None:
            input_shape = input_ids.size()
        else:
            input_shape = inputs_embeds.size()[:-1]

        seq_length = input_shape[1]

        if position_ids is None:
            position_ids = self.position_ids[:, :seq_length]

        if token_type_ids is None:
            token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=self.position_ids.device)

        if inputs_embeds is None:
            inputs_embeds = self._embed_hash_buckets(
                input_ids, self.config.hidden_size, self.config.num_hash_functions, self.config.num_hash_buckets
            )

        token_type_embeddings = self.token_type_embeddings(token_type_ids)
        embeddings = inputs_embeds + token_type_embeddings

        if self.position_embedding_type == "absolute":
            position_embeddings = self.char_position_embeddings(position_ids)
            embeddings += position_embeddings
        embeddings = self.LayerNorm(embeddings)
        embeddings = self.dropout(embeddings)
        return embeddings


class CharactersToMolecules(nn.Module):
    """Convert character sequence to initial molecule sequence (i.e. downsample) using strided convolutions."""

    def __init__(self, config):
        super().__init__()

        self.conv = nn.Conv1d(
            in_channels=config.hidden_size,
            out_channels=config.hidden_size,
            kernel_size=config.downsampling_rate,
            stride=config.downsampling_rate,
        )
        self.activation = ACT2FN[config.hidden_act]

        # self.LayerNorm is not snake-cased to stick with the TensorFlow variable name, so TF checkpoints can be loaded
        self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)

    def forward(self, char_encoding: torch.Tensor) -> torch.Tensor:
        # `cls_encoding`: [batch, 1, hidden_size]
        cls_encoding = char_encoding[:, 0:1, :]

        # `char_encoding` has shape [batch, char_seq, hidden_size]; Conv1d expects [batch, hidden_size, char_seq]
        char_encoding = torch.transpose(char_encoding, 1, 2)
        downsampled = self.conv(char_encoding)
        downsampled = torch.transpose(downsampled, 1, 2)
        downsampled = self.activation(downsampled)

        # Truncate the last molecule in order to reserve a position for [CLS], which is kept as its own molecule.
        downsampled_truncated = downsampled[:, 0:-1, :]
        result = torch.cat([cls_encoding, downsampled_truncated], dim=1)

        result = self.LayerNorm(result)

        return result
class ConvProjection(nn.Module):
    """
    Project representations from hidden_size*2 back to hidden_size across a window of w = config.upsampling_kernel_size
    characters.
    """

    def __init__(self, config):
        super().__init__()
        self.config = config
        self.conv = nn.Conv1d(
            in_channels=config.hidden_size * 2,
            out_channels=config.hidden_size,
            kernel_size=config.upsampling_kernel_size,
            stride=1,
        )
        self.activation = ACT2FN[config.hidden_act]
        self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)

    def forward(
        self,
        inputs: torch.Tensor,
        final_seq_char_positions: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        # `inputs` has shape [batch, mol_seq, molecule_hidden_size + char_hidden_final]
        # Conv1d expects [batch, channels, mol_seq], so transpose and pad to keep the sequence length
        inputs = torch.transpose(inputs, 1, 2)

        pad_total = self.config.upsampling_kernel_size - 1
        pad_beg = pad_total // 2
        pad_end = pad_total - pad_beg

        pad = nn.ConstantPad1d((pad_beg, pad_end), 0)
        result = self.conv(pad(inputs))
        result = torch.transpose(result, 1, 2)
        result = self.activation(result)
        result = self.LayerNorm(result)
        result = self.dropout(result)
        final_char_seq = result

        if final_seq_char_positions is not None:
            # only predict MLM positions (masked language modeling is not supported for now)
            raise NotImplementedError("CanineForMaskedLM is currently not supported")
        else:
            query_seq = final_char_seq

        return query_seq


class CanineSelfAttention(nn.Module):
    def __init__(self, config):
        super().__init__()
        if config.hidden_size % config.num_attention_heads != 0 and not hasattr(config, "embedding_size"):
            raise ValueError(
                f"The hidden size ({config.hidden_size}) is not a multiple of the number of attention heads "
                f"({config.num_attention_heads})"
            )

        self.num_attention_heads = config.num_attention_heads
        self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
        self.all_head_size = self.num_attention_heads * self.attention_head_size

        self.query = nn.Linear(config.hidden_size, self.all_head_size)
        self.key = nn.Linear(config.hidden_size, self.all_head_size)
        self.value = nn.Linear(config.hidden_size, self.all_head_size)

        self.dropout = nn.Dropout(config.attention_probs_dropout_prob)
        self.position_embedding_type = getattr(config, "position_embedding_type", "absolute")
        if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query":
            self.max_position_embeddings = config.max_position_embeddings
            self.distance_embedding = nn.Embedding(2 * config.max_position_embeddings - 1, self.attention_head_size)

    def forward(
        self,
        from_tensor: torch.Tensor,
        to_tensor: torch.Tensor,
        attention_mask: Optional[torch.FloatTensor] = None,
        head_mask: Optional[torch.FloatTensor] = None,
        output_attentions: Optional[bool] = False,
    ) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
        batch_size, seq_length, _ = from_tensor.shape

        key_layer = self.key(to_tensor).view(batch_size, -1, self.num_attention_heads, self.attention_head_size).transpose(1, 2)
        value_layer = self.value(to_tensor).view(batch_size, -1, self.num_attention_heads, self.attention_head_size).transpose(1, 2)
        query_layer = self.query(from_tensor).view(batch_size, -1, self.num_attention_heads, self.attention_head_size).transpose(1, 2)

        # Take the dot product between "query" and "key" to get the raw attention scores.
        attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))

        if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query":
            seq_length = from_tensor.size()[1]
            position_ids_l = torch.arange(seq_length, dtype=torch.long, device=from_tensor.device).view(-1, 1)
            position_ids_r = torch.arange(seq_length, dtype=torch.long, device=from_tensor.device).view(1, -1)
            distance = position_ids_l - position_ids_r
            positional_embedding = self.distance_embedding(distance + self.max_position_embeddings - 1)
            positional_embedding = positional_embedding.to(dtype=query_layer.dtype)  # fp16 compatibility

            if self.position_embedding_type == "relative_key":
                relative_position_scores = torch.einsum("bhld,lrd->bhlr", query_layer, positional_embedding)
                attention_scores = attention_scores + relative_position_scores
            elif self.position_embedding_type == "relative_key_query":
                relative_position_scores_query = torch.einsum("bhld,lrd->bhlr", query_layer, positional_embedding)
                relative_position_scores_key = torch.einsum("bhrd,lrd->bhlr", key_layer, positional_embedding)
                attention_scores = attention_scores + relative_position_scores_query + relative_position_scores_key

        attention_scores = attention_scores / math.sqrt(self.attention_head_size)
        if attention_mask is not None:
            if attention_mask.ndim == 3:
                # Broadcast over the head dimension, then turn the 1.0/0.0 mask into an additive mask that is
                # 0.0 for positions we want to attend and the dtype's smallest value for masked positions.
                attention_mask = torch.unsqueeze(attention_mask, dim=1)
                attention_mask = (1.0 - attention_mask.float()) * torch.finfo(attention_scores.dtype).min
            # Apply the attention mask (precomputed for all layers in CanineModel forward() function)
            attention_scores = attention_scores + attention_mask

        # Normalize the attention scores to probabilities.
        attention_probs = nn.functional.softmax(attention_scores, dim=-1)

        # This is actually dropping out entire tokens to attend to, which might
        # seem a bit unusual, but is taken from the original Transformer paper.
        attention_probs = self.dropout(attention_probs)

        if head_mask is not None:
            attention_probs = attention_probs * head_mask

        context_layer = torch.matmul(attention_probs, value_layer)

        context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
        new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
        context_layer = context_layer.view(*new_context_layer_shape)

        outputs = (context_layer, attention_probs) if output_attentions else (context_layer,)

        return outputs


class CanineSelfOutput(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.dense = nn.Linear(config.hidden_size, config.hidden_size)
        self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)

    def forward(
        self, hidden_states: tuple[torch.FloatTensor], input_tensor: torch.FloatTensor
    ) -> tuple[torch.FloatTensor, torch.FloatTensor]:
        hidden_states = self.dense(hidden_states)
        hidden_states = self.dropout(hidden_states)
        hidden_states = self.LayerNorm(hidden_states + input_tensor)
        return hidden_states
class CanineAttention(nn.Module):
    """
    Additional arguments related to local attention:

        - **local** (`bool`, *optional*, defaults to `False`) -- Whether to apply local attention.
        - **always_attend_to_first_position** (`bool`, *optional*, defaults to `False`) -- Should all blocks be able
          to attend to the `to_tensor`'s first position (e.g. a [CLS] position)?
        - **first_position_attends_to_all** (`bool`, *optional*, defaults to `False`) -- Should the *from_tensor*'s
          first position be able to attend to all positions within the *from_tensor*?
        - **attend_from_chunk_width** (`int`, *optional*, defaults to 128) -- The width of each block-wise chunk in
          `from_tensor`.
        - **attend_from_chunk_stride** (`int`, *optional*, defaults to 128) -- The number of elements to skip when
          moving to the next block in `from_tensor`.
        - **attend_to_chunk_width** (`int`, *optional*, defaults to 128) -- The width of each block-wise chunk in
          *to_tensor*.
        - **attend_to_chunk_stride** (`int`, *optional*, defaults to 128) -- The number of elements to skip when
          moving to the next block in *to_tensor*.
    """

    def __init__(
        self,
        config,
        local=False,
        always_attend_to_first_position: bool = False,
        first_position_attends_to_all: bool = False,
        attend_from_chunk_width: int = 128,
        attend_from_chunk_stride: int = 128,
        attend_to_chunk_width: int = 128,
        attend_to_chunk_stride: int = 128,
    ):
        super().__init__()
        self.self = CanineSelfAttention(config)
        self.output = CanineSelfOutput(config)
        self.pruned_heads = set()

        # additional arguments related to local attention
        self.local = local
        if attend_from_chunk_width < attend_from_chunk_stride:
            raise ValueError(
                "`attend_from_chunk_width` < `attend_from_chunk_stride` would cause sequence positions to get skipped."
            )
        if attend_to_chunk_width < attend_to_chunk_stride:
            raise ValueError(
                "`attend_to_chunk_width` < `attend_to_chunk_stride`would cause sequence positions to get skipped."
            )
        self.always_attend_to_first_position = always_attend_to_first_position
        self.first_position_attends_to_all = first_position_attends_to_all
        self.attend_from_chunk_width = attend_from_chunk_width
        self.attend_from_chunk_stride = attend_from_chunk_stride
        self.attend_to_chunk_width = attend_to_chunk_width
        self.attend_to_chunk_stride = attend_to_chunk_stride

    def prune_heads(self, heads):
        if len(heads) == 0:
            return
        heads, index = find_pruneable_heads_and_indices(
            heads, self.self.num_attention_heads, self.self.attention_head_size, self.pruned_heads
        )

        # Prune linear layers
        self.self.query = prune_linear_layer(self.self.query, index)
        self.self.key = prune_linear_layer(self.self.key, index)
        self.self.value = prune_linear_layer(self.self.value, index)
        self.output.dense = prune_linear_layer(self.output.dense, index, dim=1)

        # Update hyper params and store pruned heads
        self.self.num_attention_heads = self.self.num_attention_heads - len(heads)
        self.self.all_head_size = self.self.attention_head_size * self.self.num_attention_heads
        self.pruned_heads = self.pruned_heads.union(heads)

    def forward(
        self,
        hidden_states: tuple[torch.FloatTensor],
        attention_mask: Optional[torch.FloatTensor] = None,
        head_mask: Optional[torch.FloatTensor] = None,
        output_attentions: Optional[bool] = False,
    ) -> tuple[torch.FloatTensor, Optional[torch.FloatTensor]]:
        if not self.local:
            self_outputs = self.self(hidden_states, hidden_states, attention_mask, head_mask, output_attentions)
            attention_output = self_outputs[0]
        else:
            from_seq_length = to_seq_length = hidden_states.shape[1]
            from_tensor = to_tensor = hidden_states

            # Create chunks (windows) that we will attend *from* and then concatenate them.
            from_chunks = []
            if self.first_position_attends_to_all:
                from_chunks.append((0, 1))
                # We must skip this first position so that our output sequence is the correct length
                from_start = 1
            else:
                from_start = 0
            for chunk_start in range(from_start, from_seq_length, self.attend_from_chunk_stride):
                chunk_end = min(from_seq_length, chunk_start + self.attend_from_chunk_width)
                from_chunks.append((chunk_start, chunk_end))

            # Determine the chunks (windows) that will attend *to*.
            to_chunks = []
            if self.first_position_attends_to_all:
                to_chunks.append((0, to_seq_length))
            for chunk_start in range(0, to_seq_length, self.attend_to_chunk_stride):
                chunk_end = min(to_seq_length, chunk_start + self.attend_to_chunk_width)
                to_chunks.append((chunk_start, chunk_end))

            if len(from_chunks) != len(to_chunks):
                raise ValueError(
                    f"Expected to have same number of `from_chunks` ({from_chunks}) and `to_chunks` ({to_chunks}). Check strides."
                )

            # next, compute attention scores for each pair of windows and concatenate
            attention_output_chunks = []
            attention_probs_chunks = []
            for (from_start, from_end), (to_start, to_end) in zip(from_chunks, to_chunks):
                from_tensor_chunk = from_tensor[:, from_start:from_end, :]
                to_tensor_chunk = to_tensor[:, to_start:to_end, :]
                # `attention_mask`: <float>[batch_size, from_seq, to_seq]
                attention_mask_chunk = attention_mask[:, from_start:from_end, to_start:to_end]
                if self.always_attend_to_first_position:
                    cls_attention_mask = attention_mask[:, from_start:from_end, 0:1]
                    attention_mask_chunk = torch.cat([cls_attention_mask, attention_mask_chunk], dim=2)

                    cls_position = to_tensor[:, 0:1, :]
                    to_tensor_chunk = torch.cat([cls_position, to_tensor_chunk], dim=1)

                attention_outputs_chunk = self.self(
                    from_tensor_chunk, to_tensor_chunk, attention_mask_chunk, head_mask, output_attentions
                )
                attention_output_chunks.append(attention_outputs_chunk[0])
                if output_attentions:
                    attention_probs_chunks.append(attention_outputs_chunk[1])

            attention_output = torch.cat(attention_output_chunks, dim=1)

        attention_output = self.output(attention_output, hidden_states)
        outputs = (attention_output,)
        if not self.local:
            outputs = outputs + self_outputs[1:]  # add attentions if we output them
        else:
            outputs = outputs + tuple(attention_probs_chunks)  # add attentions if we output them
        return outputs


class CanineIntermediate(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.dense = nn.Linear(config.hidden_size, config.intermediate_size)
        if isinstance(config.hidden_act, str):
            self.intermediate_act_fn = ACT2FN[config.hidden_act]
        else:
            self.intermediate_act_fn = config.hidden_act

    def forward(self, hidden_states: torch.FloatTensor) -> torch.FloatTensor:
        hidden_states = self.dense(hidden_states)
        hidden_states = self.intermediate_act_fn(hidden_states)
        return hidden_states


class CanineOutput(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.dense = nn.Linear(config.intermediate_size, config.hidden_size)
        self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)

    def forward(self, hidden_states: tuple[torch.FloatTensor], input_tensor: torch.FloatTensor) -> torch.FloatTensor:
        hidden_states = self.dense(hidden_states)
        hidden_states = self.dropout(hidden_states)
        hidden_states = self.LayerNorm(hidden_states + input_tensor)
        return hidden_states


class CanineLayer(GradientCheckpointingLayer):
    def __init__(
        self,
        config,
        local,
        always_attend_to_first_position,
        first_position_attends_to_all,
        attend_from_chunk_width,
        attend_from_chunk_stride,
        attend_to_chunk_width,
        attend_to_chunk_stride,
    ):
        super().__init__()
        self.chunk_size_feed_forward = config.chunk_size_feed_forward
        self.seq_len_dim = 1
        self.attention = CanineAttention(
            config,
            local,
            always_attend_to_first_position,
            first_position_attends_to_all,
            attend_from_chunk_width,
            attend_from_chunk_stride,
            attend_to_chunk_width,
            attend_to_chunk_stride,
        )
        self.intermediate = CanineIntermediate(config)
        self.output = CanineOutput(config)

    def forward(
        self,
        hidden_states: tuple[torch.FloatTensor],
        attention_mask: Optional[torch.FloatTensor] = None,
        head_mask: Optional[torch.FloatTensor] = None,
        output_attentions: Optional[bool] = False,
    ) -> tuple[torch.FloatTensor, Optional[torch.FloatTensor]]:
        self_attention_outputs = self.attention(
            hidden_states,
            attention_mask=attention_mask,
            head_mask=head_mask,
            output_attentions=output_attentions,
        )
        attention_output = self_attention_outputs[0]

        outputs = self_attention_outputs[1:]  # add self attentions if we output attention weights

        layer_output = apply_chunking_to_forward(
            self.feed_forward_chunk, self.chunk_size_feed_forward, self.seq_len_dim, attention_output
        )
        outputs = (layer_output,) + outputs

        return outputs

    def feed_forward_chunk(self, attention_output):
        intermediate_output = self.intermediate(attention_output)
        layer_output = self.output(intermediate_output, attention_output)
        return layer_output


class CanineEncoder(nn.Module):
    def __init__(
        self,
        config,
        local=False,
        always_attend_to_first_position=False,
        first_position_attends_to_all=False,
        attend_from_chunk_width=128,
        attend_from_chunk_stride=128,
        attend_to_chunk_width=128,
        attend_to_chunk_stride=128,
    ):
        super().__init__()
        self.config = config
        self.layer = nn.ModuleList(
            [
                CanineLayer(
                    config,
                    local,
                    always_attend_to_first_position,
                    first_position_attends_to_all,
                    attend_from_chunk_width,
                    attend_from_chunk_stride,
                    attend_to_chunk_width,
                    attend_to_chunk_stride,
                )
                for _ in range(config.num_hidden_layers)
            ]
        )
        self.gradient_checkpointing = False

    def forward(
        self,
        hidden_states: tuple[torch.FloatTensor],
        attention_mask: Optional[torch.FloatTensor] = None,
        head_mask: Optional[torch.FloatTensor] = None,
        output_attentions: Optional[bool] = False,
        output_hidden_states: Optional[bool] = False,
        return_dict: Optional[bool] = True,
    ) -> Union[tuple, BaseModelOutput]:
        all_hidden_states = () if output_hidden_states else None
        all_self_attentions = () if output_attentions else None

        for i, layer_module in enumerate(self.layer):
            if output_hidden_states:
                all_hidden_states = all_hidden_states + (hidden_states,)

            layer_head_mask = head_mask[i] if head_mask is not None else None

            layer_outputs = layer_module(hidden_states, attention_mask, layer_head_mask, output_attentions)
            hidden_states = layer_outputs[0]

            if output_attentions:
                all_self_attentions = all_self_attentions + (layer_outputs[1],)

        if output_hidden_states:
            all_hidden_states = all_hidden_states + (hidden_states,)

        if not return_dict:
            return tuple(v for v in [hidden_states, all_hidden_states, all_self_attentions] if v is not None)
        return BaseModelOutput(
            last_hidden_state=hidden_states,
            hidden_states=all_hidden_states,
            attentions=all_self_attentions,
        )


class CaninePooler(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.dense = nn.Linear(config.hidden_size, config.hidden_size)
        self.activation = nn.Tanh()

    def forward(self, hidden_states: tuple[torch.FloatTensor]) -> torch.FloatTensor:
        # We "pool" the model by simply taking the hidden state corresponding to the first token.
        first_token_tensor = hidden_states[:, 0]
        pooled_output = self.dense(first_token_tensor)
        pooled_output = self.activation(pooled_output)
        return pooled_output


class CaninePredictionHeadTransform(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.dense = nn.Linear(config.hidden_size, config.hidden_size)
        if isinstance(config.hidden_act, str):
            self.transform_act_fn = ACT2FN[config.hidden_act]
        else:
            self.transform_act_fn = config.hidden_act
        self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)

    def forward(self, hidden_states: tuple[torch.FloatTensor]) -> torch.FloatTensor:
        hidden_states = self.dense(hidden_states)
        hidden_states = self.transform_act_fn(hidden_states)
        hidden_states = self.LayerNorm(hidden_states)
        return hidden_states


class CanineLMPredictionHead(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.transform = CaninePredictionHeadTransform(config)

        # The output weights are the same as the input embeddings, but there is
        # an output-only bias for each token.
        self.decoder = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
        self.bias = nn.Parameter(torch.zeros(config.vocab_size))

        # Need a link between the two variables so that the bias is correctly resized with `resize_token_embeddings`
        self.decoder.bias = self.bias

    def forward(self, hidden_states: tuple[torch.FloatTensor]) -> torch.FloatTensor:
        hidden_states = self.transform(hidden_states)
        hidden_states = self.decoder(hidden_states)
        return hidden_states


class CanineOnlyMLMHead(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.predictions = CanineLMPredictionHead(config)

    def forward(self, sequence_output: tuple[torch.Tensor]) -> tuple[torch.Tensor]:
        prediction_scores = self.predictions(sequence_output)
        return prediction_scores


@auto_docstring
class CaninePreTrainedModel(PreTrainedModel):
    config: CanineConfig
    load_tf_weights = load_tf_weights_in_canine
    base_model_prefix = "canine"
    supports_gradient_checkpointing = True

    def _init_weights(self, module):
        """Initialize the weights"""
        if isinstance(module, (nn.Linear, nn.Conv1d)):
            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
            if module.bias is not None:
                module.bias.data.zero_()
        elif isinstance(module, nn.Embedding):
            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
            if module.padding_idx is not None:
                module.weight.data[module.padding_idx].zero_()
        elif isinstance(module, nn.LayerNorm):
            module.bias.data.zero_()
            module.weight.data.fill_(1.0)
@auto_docstring
class CanineModel(CaninePreTrainedModel):
    def __init__(self, config, add_pooling_layer=True):
        r"""
        add_pooling_layer (bool, *optional*, defaults to `True`):
            Whether to add a pooling layer
        """
        super().__init__(config)
        self.config = config
        shallow_config = copy.deepcopy(config)
        shallow_config.num_hidden_layers = 1

        self.char_embeddings = CanineEmbeddings(config)
        # shallow/low-dim transformer encoder to get a initial character encoding
        self.initial_char_encoder = CanineEncoder(
            shallow_config,
            local=True,
            always_attend_to_first_position=False,
            first_position_attends_to_all=False,
            attend_from_chunk_width=config.local_transformer_stride,
            attend_from_chunk_stride=config.local_transformer_stride,
            attend_to_chunk_width=config.local_transformer_stride,
            attend_to_chunk_stride=config.local_transformer_stride,
        )
        self.chars_to_molecules = CharactersToMolecules(config)
        # deep transformer encoder
        self.encoder = CanineEncoder(config)
        self.projection = ConvProjection(config)
        # shallow/low-dim transformer encoder to get a final character encoding
        self.final_char_encoder = CanineEncoder(shallow_config)

        self.pooler = CaninePooler(config) if add_pooling_layer else None

        # Initialize weights and apply final processing
        self.post_init()

    def _prune_heads(self, heads_to_prune):
        """
        Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base
        class PreTrainedModel
        """
        for layer, heads in heads_to_prune.items():
            self.encoder.layer[layer].attention.prune_heads(heads)

    def _create_3d_attention_mask_from_input_mask(self, from_tensor, to_mask):
        """
        Create 3D attention mask from a 2D tensor mask.

        Args:
            from_tensor: 2D or 3D Tensor of shape [batch_size, from_seq_length, ...].
            to_mask: int32 Tensor of shape [batch_size, to_seq_length].

        Returns:
            float Tensor of shape [batch_size, from_seq_length, to_seq_length].
        """
        batch_size, from_seq_length = from_tensor.shape[0], from_tensor.shape[1]

        to_seq_length = to_mask.shape[1]

        to_mask = torch.reshape(to_mask, (batch_size, 1, to_seq_length)).float()

        # We don't assume that `from_tensor` is a mask (although it could be). We don't actually care if we attend
        # *from* padding tokens (only *to* padding tokens) so we create a tensor of all ones.
        broadcast_ones = torch.ones(size=(batch_size, from_seq_length, 1), dtype=torch.float32, device=to_mask.device)

        # Here we broadcast along two dimensions to create the mask.
        mask = broadcast_ones * to_mask

        return mask

    def _downsample_attention_mask(self, char_attention_mask: torch.Tensor, downsampling_rate: int):
        """Downsample 2D character attention mask to 2D molecule attention mask using MaxPool1d layer."""

        # first, make char_attention_mask 3D by adding a channel dim
        batch_size, char_seq_len = char_attention_mask.shape
        poolable_char_mask = torch.reshape(char_attention_mask, (batch_size, 1, char_seq_len))

        # next, apply MaxPool1d to get pooled_molecule_mask of shape (batch_size, 1, mol_seq_len)
        pooled_molecule_mask = torch.nn.MaxPool1d(kernel_size=downsampling_rate, stride=downsampling_rate)(
            poolable_char_mask.float()
        )

        # finally, squeeze to get tensor of shape (batch_size, mol_seq_len)
        molecule_attention_mask = torch.squeeze(pooled_molecule_mask, dim=-1)

        return molecule_attention_mask

    def _repeat_molecules(self, molecules: torch.Tensor, char_seq_length: torch.Tensor) -> torch.Tensor:
        """Repeats molecules to make them the same length as the char sequence."""

        rate = self.config.downsampling_rate

        molecules_without_extra_cls = molecules[:, 1:, :]
        # `repeated`: [batch_size, almost_char_seq_len, molecule_hidden_size]
        repeated = torch.repeat_interleave(molecules_without_extra_cls, repeats=rate, dim=-2)

        # So far, we've repeated enough elements for any `char_seq_length` that's a multiple of
        # `downsampling_rate`. Now we account for the remainder by repeating the last molecule extra times.
        last_molecule = molecules[:, -1:, :]
        remainder_length = torch.fmod(torch.tensor(char_seq_length), torch.tensor(rate)).item()
        remainder_repeated = torch.repeat_interleave(
            last_molecule,
            # +1 molecule to compensate for truncation.
            repeats=remainder_length + rate,
            dim=-2,
        )

        # `repeated`: [batch_size, char_seq_len, molecule_hidden_size]
        return torch.cat([repeated, remainder_repeated], dim=-2)

    @auto_docstring
    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.FloatTensor] = None,
        token_type_ids: Optional[torch.LongTensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        head_mask: Optional[torch.FloatTensor] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[tuple, CanineModelOutputWithPooling]:
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        all_hidden_states = () if output_hidden_states else None
        all_self_attentions = () if output_attentions else None
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        if input_ids is not None and inputs_embeds is not None:
            raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
        elif input_ids is not None:
            self.warn_if_padding_and_no_attention_mask(input_ids, attention_mask)
            input_shape = input_ids.size()
        elif inputs_embeds is not None:
            input_shape = inputs_embeds.size()[:-1]
        else:
            raise ValueError("You have to specify either input_ids or inputs_embeds")

        batch_size, seq_length = input_shape
        device = input_ids.device if input_ids is not None else inputs_embeds.device

        if attention_mask is None:
            attention_mask = torch.ones((batch_size, seq_length), device=device)
        if token_type_ids is None:
            token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=device)

        extended_attention_mask = self.get_extended_attention_mask(attention_mask, input_shape)
        molecule_attention_mask = self._downsample_attention_mask(
            attention_mask, downsampling_rate=self.config.downsampling_rate
        )
        extended_molecule_attention_mask = self.get_extended_attention_mask(
            molecule_attention_mask, (batch_size, molecule_attention_mask.shape[-1])
        )

        # Prepare head mask if needed
        head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers)

        # `input_char_embeddings`: shape (batch_size, char_seq, char_dim)
        input_char_embeddings = self.char_embeddings(
            input_ids=input_ids,
            position_ids=position_ids,
            token_type_ids=token_type_ids,
            inputs_embeds=inputs_embeds,
        )

        # Contextualize character embeddings using shallow Transformer.
        # We use a 3D attention mask for the local attention.
        # `input_char_encoding`: shape (batch_size, char_seq_len, char_dim)
        char_attention_mask = self._create_3d_attention_mask_from_input_mask(
            input_ids if input_ids is not None else inputs_embeds, attention_mask
        )
        init_chars_encoder_outputs = self.initial_char_encoder(
            input_char_embeddings,
            attention_mask=char_attention_mask,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
        )
        input_char_encoding = init_chars_encoder_outputs.last_hidden_state

        # Downsample chars to molecules.
        # `init_molecule_encoding`: shape (batch_size, mol_seq_len, mol_dim)
        init_molecule_encoding = self.chars_to_molecules(input_char_encoding)

        # Deep BERT encoder
        # `molecule_sequence_output`: shape (batch_size, mol_seq_len, mol_dim)
        encoder_outputs = self.encoder(
            init_molecule_encoding,
            attention_mask=extended_molecule_attention_mask,
            head_mask=head_mask,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        molecule_sequence_output = encoder_outputs[0]
        pooled_output = self.pooler(molecule_sequence_output) if self.pooler is not None else None

        # Upsample molecules back to characters and concatenate with the initial character encoding.
        repeated_molecules = self._repeat_molecules(molecule_sequence_output, char_seq_length=input_shape[-1])
        concat = torch.cat([input_char_encoding, repeated_molecules], dim=-1)
        sequence_output = self.projection(concat)

        # Apply final shallow Transformer, giving `sequence_output` of shape (batch_size, char_seq_len, hidden_size)
        final_chars_encoder_outputs = self.final_char_encoder(
            sequence_output,
            attention_mask=extended_attention_mask,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
        )
        sequence_output = final_chars_encoder_outputs.last_hidden_state

        if output_hidden_states:
            deep_encoder_hidden_states = encoder_outputs.hidden_states if return_dict else encoder_outputs[1]
            all_hidden_states = (
                all_hidden_states
                + init_chars_encoder_outputs.hidden_states
                + deep_encoder_hidden_states
                + final_chars_encoder_outputs.hidden_states
            )

        if output_attentions:
            deep_encoder_self_attentions = encoder_outputs.attentions if return_dict else encoder_outputs[-1]
            all_self_attentions = (
                all_self_attentions
                + init_chars_encoder_outputs.attentions
                + deep_encoder_self_attentions
                + final_chars_encoder_outputs.attentions
            )

        if not return_dict:
            output = (sequence_output, pooled_output)
            output += tuple(v for v in [all_hidden_states, all_self_attentions] if v is not None)
            return output

        return CanineModelOutputWithPooling(
            last_hidden_state=sequence_output,
            pooler_output=pooled_output,
            hidden_states=all_hidden_states,
            attentions=all_self_attentions,
        )


@auto_docstring(
    custom_intro="""
    CANINE Model transformer with a sequence classification/regression head on top (a linear layer on top of the
    pooled output) e.g. for GLUE tasks.
    """
)
class CanineForSequenceClassification(CaninePreTrainedModel):
    def __init__(self, config):
        super().__init__(config)
        self.num_labels = config.num_labels

        self.canine = CanineModel(config)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)
        self.classifier = nn.Linear(config.hidden_size, config.num_labels)

        # Initialize weights and apply final processing
        self.post_init()

    @auto_docstring
    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.FloatTensor] = None,
        token_type_ids: Optional[torch.LongTensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        head_mask: Optional[torch.FloatTensor] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[tuple, SequenceClassifierOutput]:
        r"""
        labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
            Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
            config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
            `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        outputs = self.canine(
            input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        pooled_output = outputs[1]

        pooled_output = self.dropout(pooled_output)
        logits = self.classifier(pooled_output)

        loss = None
        if labels is not None:
            if self.config.problem_type is None:
                if self.num_labels == 1:
                    self.config.problem_type = "regression"
                elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
                    self.config.problem_type = "single_label_classification"
                else:
                    self.config.problem_type = "multi_label_classification"

            if self.config.problem_type == "regression":
                loss_fct = MSELoss()
                if self.num_labels == 1:
                    loss = loss_fct(logits.squeeze(), labels.squeeze())
                else:
                    loss = loss_fct(logits, labels)
            elif self.config.problem_type == "single_label_classification":
                loss_fct = CrossEntropyLoss()
                loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
            elif self.config.problem_type == "multi_label_classification":
                loss_fct = BCEWithLogitsLoss()
                loss = loss_fct(logits, labels)

        if not return_dict:
            output = (logits,) + outputs[2:]
            return ((loss,) + output) if loss is not None else output

        return SequenceClassifierOutput(
            loss=loss,
            logits=logits,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )


@auto_docstring
class CanineForMultipleChoice(CaninePreTrainedModel):
    def __init__(self, config):
        super().__init__(config)

        self.canine = CanineModel(config)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)
        self.classifier = nn.Linear(config.hidden_size, 1)

        # Initialize weights and apply final processing
        self.post_init()

    @auto_docstring
    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.FloatTensor] = None,
        token_type_ids: Optional[torch.LongTensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        head_mask: Optional[torch.FloatTensor] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[tuple, MultipleChoiceModelOutput]:
        r"""
        input_ids (`torch.LongTensor` of shape `(batch_size, num_choices, sequence_length)`):
            Indices of input sequence tokens in the vocabulary.

            Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
            [`PreTrainedTokenizer.__call__`] for details.

            [What are input IDs?](../glossary#input-ids)
        token_type_ids (`torch.LongTensor` of shape `(batch_size, num_choices, sequence_length)`, *optional*):
            Segment token indices to indicate first and second portions of the inputs. Indices are selected in `[0,
            1]`:

            - 0 corresponds to a *sentence A* token,
            - 1 corresponds to a *sentence B* token.

            [What are token type IDs?](../glossary#token-type-ids)
        position_ids (`torch.LongTensor` of shape `(batch_size, num_choices, sequence_length)`, *optional*):
            Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
            config.max_position_embeddings - 1]`.

            [What are position IDs?](../glossary#position-ids)
        inputs_embeds (`torch.FloatTensor` of shape `(batch_size, num_choices, sequence_length, hidden_size)`, *optional*):
            Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
            is useful if you want more control over how to convert *input_ids* indices into associated vectors than the
            model's internal embedding lookup matrix.
        labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
            Labels for computing the multiple choice classification loss. Indices should be in `[0, ...,
            num_choices-1]` where `num_choices` is the size of the second dimension of the input tensors. (See
            `input_ids` above)
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
        num_choices = input_ids.shape[1] if input_ids is not None else inputs_embeds.shape[1]

        input_ids = input_ids.view(-1, input_ids.size(-1)) if input_ids is not None else None
        attention_mask = attention_mask.view(-1, attention_mask.size(-1)) if attention_mask is not None else None
        token_type_ids = token_type_ids.view(-1, token_type_ids.size(-1)) if token_type_ids is not None else None
        position_ids = position_ids.view(-1, position_ids.size(-1)) if position_ids is not None else None
        inputs_embeds = (
            inputs_embeds.view(-1, inputs_embeds.size(-2), inputs_embeds.size(-1))
            if inputs_embeds is not None
            else None
        )

        outputs = self.canine(
            input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        pooled_output = outputs[1]

        pooled_output = self.dropout(pooled_output)
        logits = self.classifier(pooled_output)
        reshaped_logits = logits.view(-1, num_choices)

        loss = None
        if labels is not None:
            loss_fct = CrossEntropyLoss()
            loss = loss_fct(reshaped_logits, labels)

        if not return_dict:
            output = (reshaped_logits,) + outputs[2:]
            return ((loss,) + output) if loss is not None else output

        return MultipleChoiceModelOutput(
            loss=loss,
            logits=reshaped_logits,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )


@auto_docstring
class CanineForTokenClassification(CaninePreTrainedModel):
    def __init__(self, config):
        super().__init__(config)
        self.num_labels = config.num_labels

        self.canine = CanineModel(config)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)
        self.classifier = nn.Linear(config.hidden_size, config.num_labels)

        # Initialize weights and apply final processing
        self.post_init()

    @auto_docstring
    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.FloatTensor] = None,
        token_type_ids: Optional[torch.LongTensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        head_mask: Optional[torch.FloatTensor] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[tuple, TokenClassifierOutput]:
        r"""
        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Labels for computing the token classification loss. Indices should be in `[0, ..., config.num_labels - 1]`.

        Example:

        ```python
        >>> from transformers import AutoTokenizer, CanineForTokenClassification
        >>> import torch

        >>> tokenizer = AutoTokenizer.from_pretrained("google/canine-s")
        >>> model = CanineForTokenClassification.from_pretrained("google/canine-s")

        >>> inputs = tokenizer(
        ...     "HuggingFace is a company based in Paris and New York", add_special_tokens=False, return_tensors="pt"
        ... )

        >>> with torch.no_grad():
        ...     logits = model(**inputs).logits

        >>> predicted_token_class_ids = logits.argmax(-1)

        >>> # Note that tokens are classified rather than input words which means that
        >>> # there might be more predicted token classes than words.
        >>> # Multiple token classes might account for the same word
        >>> predicted_tokens_classes = [model.config.id2label[t.item()] for t in predicted_token_class_ids[0]]
        >>> predicted_tokens_classes  # doctest: +SKIP
        ```

        ```python
        >>> labels = predicted_token_class_ids
        >>> loss = model(**inputs, labels=labels).loss
        >>> round(loss.item(), 2)  # doctest: +SKIP
        ```"""
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        outputs = self.canine(
            input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        sequence_output = outputs[0]

        sequence_output = self.dropout(sequence_output)
        logits = self.classifier(sequence_output)

        loss = None
        if labels is not None:
            loss_fct = CrossEntropyLoss()
            loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))

        if not return_dict:
            output = (logits,) + outputs[2:]
            return ((loss,) + output) if loss is not None else output

        return TokenClassifierOutput(
            loss=loss,
            logits=logits,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )


@auto_docstring
class CanineForQuestionAnswering(CaninePreTrainedModel):
    def __init__(self, config):
        super().__init__(config)
        self.num_labels = config.num_labels

        self.canine = CanineModel(config)
        self.qa_outputs = nn.Linear(config.hidden_size, config.num_labels)

        # Initialize weights and apply final processing
        self.post_init()

    @auto_docstring
    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.FloatTensor] = None,
        token_type_ids: Optional[torch.LongTensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        head_mask: Optional[torch.FloatTensor] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        start_positions: Optional[torch.LongTensor] = None,
        end_positions: Optional[torch.LongTensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[tuple, QuestionAnsweringModelOutput]:
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        outputs = self.canine(
            input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        sequence_output = outputs[0]

        logits = self.qa_outputs(sequence_output)
        start_logits, end_logits = logits.split(1, dim=-1)
        start_logits = start_logits.squeeze(-1)
        end_logits = end_logits.squeeze(-1)

        total_loss = None
        if start_positions is not None and end_positions is not None:
            # If we are on multi-GPU, split add a dimension
            if len(start_positions.size()) > 1:
                start_positions = start_positions.squeeze(-1)
            if len(end_positions.size()) > 1:
                end_positions = end_positions.squeeze(-1)
            # sometimes the start/end positions are outside our model inputs, we ignore these terms
            ignored_index = start_logits.size(1)
            start_positions.clamp_(0, ignored_index)
            end_positions.clamp_(0, ignored_index)

            loss_fct = CrossEntropyLoss(ignore_index=ignored_index)
            start_loss = loss_fct(start_logits, start_positions)
            end_loss = loss_fct(end_logits, end_positions)
            total_loss = (start_loss + end_loss) / 2

        if not return_dict:
            output = (start_logits, end_logits) + outputs[2:]
            return ((total_loss,) + output) if total_loss is not None else output

        return QuestionAnsweringModelOutput(
            loss=total_loss,
            start_logits=start_logits,
            end_logits=end_logits,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )


__all__ = [
    "CanineForMultipleChoice",
    "CanineForQuestionAnswering",
    "CanineForSequenceClassification",
    "CanineForTokenClassification",
    "CanineLayer",
    "CanineModel",
    "CaninePreTrainedModel",
    "load_tf_weights_in_canine",
]