o
    iy                     @   s  d Z ddlZddlmZmZ ddlZddlmZ ddlmZ ddl	m
Z
 ddlmZ dd	lmZmZmZ dd
lmZmZmZ ddlmZ ddlmZmZ ddlmZmZmZmZmZ ddl m!Z! ddl"m#Z# ddl$m%Z% e rwddl&m'Z'm(Z( G dd deZ)eG dd deZ*G dd deZ+G dd deZ,eG dd de*Z-eddG d d! d!e*eZ.G d"d# d#eZ/G d$d% d%e!Z0G d&d' d'eZ1g d(Z2dS ))zPyTorch PLBART model.    N)OptionalUnion)nn)CrossEntropyLoss   )Cache)GenerationMixin)AttentionMaskConverter_prepare_4d_attention_mask#_prepare_4d_attention_mask_for_sdpa)BaseModelOutputSeq2SeqLMOutputSeq2SeqModelOutput)PreTrainedModel)auto_docstringis_torch_flex_attn_available   )BartClassificationHeadBartDecoderBartEncoderBartForCausalLMBartScaledWordEmbedding)'BigBirdPegasusForSequenceClassification)shift_tokens_right   )PLBartConfig)	BlockMaskmake_flex_block_causal_maskc                   @      e Zd ZdS )PLBartScaledWordEmbeddingN__name__
__module____qualname__ r$   r$   f/home/ubuntu/veenaModal/venv/lib/python3.10/site-packages/transformers/models/plbart/modular_plbart.pyr   6       r   c                   @   s   e Zd ZU eed< dZdZddgZdZdZ	dZ
deejdf dejfd	d
Zdeeejdf  dejdejdefddZedejdededejdejdefddZdeejdf deejdf dejdejfddZdS )PLBartPreTrainedModelconfigmodelTPLBartDecoderLayerPLBartEncoderLayerattention_maskNinputs_embedsc                 C   s   |d ur>| j jdkrd|v r|}|S d }|S | j jdkr$t||j}|S | j jdkr8t|tjr6t|dd}|S t||j}|S )Nflash_attention_2r   sdpaflex_attentionF)	is_causal	r(   _attn_implementationr   dtype
isinstancetorchTensorr   r
   )selfr,   r-   r$   r$   r%   _update_full_maskE   s   z'PLBartPreTrainedModel._update_full_maskr   input_tensorcache_positionpast_key_valuesc                 C   sb  | j jdkr*t|tjrt|}|S |d u r(ttj|jd |jd f|jd}|S | j jdkr>|d ur<|dk	 r<|S d S |d urF|
 nd}|d urO|jnd}| j jdkre|setj|||| jd	red S |j}|jd }|rt| }	nt|tjr|jd
 n|| d }	| j|||	|||jd d}
| j jdkr|d ur|jjdv rt|j}t|
|}
|
S )Nr0   r   r   )sizedevicer.   g        Fr/   )r-   past_key_values_lengthis_training)sequence_lengthtarget_lengthr4   r;   
batch_size)cudaxpunpu)r(   r3   r5   r6   r7   r   onesshaper>   anyget_seq_lengthis_compileabler	   _ignore_causal_mask_sdpatrainingr4   get_max_cache_shape5_prepare_4d_causal_attention_mask_with_cache_positiontypefinfomin_unmask_unattended)r8   r,   r:   r;   r<   past_seen_tokensusing_compilable_cacher4   rB   rC   causal_mask	min_dtyper$   r$   r%   _update_causal_mask\   s`   





z)PLBartPreTrainedModel._update_causal_maskrB   rC   r4   rD   c                 K   sD  | dur|   dkr| }|S t|j}tj||f|||jd}|dkr+tj|dd}|tj||jd|ddk9 }|ddddddf 	|ddd}| dur|
 }| jd }	|ddddddd|	f | ddddddf |j }
|
dk}
|ddddddd|	f |
||ddddddd|	f< |S )	aM  
        Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape
        `(batch_size, key_value_length)`, or if the input `attention_mask` is already 4D, do nothing.

        Args:
            attention_mask (`torch.Tensor`):
                A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape
                `(batch_size, 1, query_length, key_value_length)`.
            sequence_length (`int`):
                The sequence length being processed.
            target_length (`int`):
                The target length: when generating with static cache, the mask should be as long as the static cache,
                to account for the 0 padding, the part of the cache that is not filled yet.
            dtype (`torch.dtype`):
                The dtype to use for the 4D attention mask.
            cache_position (`torch.Tensor`):
                Indices depicting the position of the input sequence tokens in the sequence.
            batch_size (`torch.Tensor`):
                Batch size.
        N   )
fill_valuer4   r>   r   )diagonalr>   rA   r   )dimr6   rR   rS   fullr>   triuarangereshapeexpandclonerI   tomasked_fill)r,   rB   rC   r4   r;   rD   kwargsrW   rX   mask_lengthpadding_maskr$   r$   r%   rP      s,    $
6  zKPLBartPreTrainedModel._prepare_4d_causal_attention_mask_with_cache_positionencoder_hidden_statesencoder_attention_maskinput_shapec                 C   s   |d urM|d urM| j jdkrd|v r|}|S d }|S | j jdkr,t||j|d d}|S | j jdkrCt|tjrAt||d dd}|S t||j|d d}|S )	Nr.   r   r/   rA   )tgt_lenr0   F)query_lengthr1   r2   )r8   rj   rk   rl   r-   r$   r$   r%   _update_cross_attn_mask   s2   z-PLBartPreTrainedModel._update_cross_attn_mask)r!   r"   r#   r   __annotations__base_model_prefixsupports_gradient_checkpointing_no_split_modules_supports_flash_attn_supports_sdpa_supports_flex_attnr   r6   r7   r9   r   r   rY   staticmethodintr4   rP   Sizero   r$   r$   r$   r%   r'   :   sZ   
 

L7r'   c                   @   r   )PLBartEncoderNr    r$   r$   r$   r%   rz     r&   rz   c                   @   r   )PLBartDecoderNr    r$   r$   r$   r%   r{   	  r&   r{   c                &       s&  e Zd ZddgZdef fddZdd Zdd	 Zd
d Zdd Z	e
																d"deej deej deej deej deej deej deej deeej  dee deej deej dee dee dee dee deej deeej ef f"d d!Z  ZS )#PLBartModelencoder.embed_tokens.weightdecoder.embed_tokens.weightr(   c                    sl   t  | |j|j}}|jrt|jnd}t||j||d| _	t
|| j	| _t|| j	| _|   d S )Ng      ?)embed_scale)super__init__pad_token_id
vocab_sizescale_embeddingmathsqrtd_modelr   sharedrz   encoderr{   decoderinit_weights)r8   r(   padding_idxr   r   	__class__r$   r%   r     s   zPLBartModel.__init__c                 C      | j S N)r   r8   r$   r$   r%   get_input_embeddings     z PLBartModel.get_input_embeddingsc                 C   s   || _ | j | j_| j | j_d S r   )r   r   embed_tokensr   )r8   valuer$   r$   r%   set_input_embeddings   s   
z PLBartModel.set_input_embeddingsc                 C   s4   | j jr| | jj| j | | jj| j d S d S r   )r(   tie_word_embeddings_tie_or_clone_weightsr   r   r   r   r   r$   r$   r%   _tie_weights%  s   zPLBartModel._tie_weightsc                 C   r   r   )r   r   r$   r$   r%   get_encoder*  r   zPLBartModel.get_encoderN	input_idsr,   decoder_input_idsdecoder_attention_mask	head_maskdecoder_head_maskcross_attn_head_maskencoder_outputsr<   r-   decoder_inputs_embeds	use_cacheoutput_attentionsoutput_hidden_statesreturn_dictr;   returnc                 C   s4  |dur|n| j j}|dur|n| j j}|dur|n| j j}|dur$|n| j j}|du r7|du r7t|| j j}|du rH| j||||
|||d}n$|rlt|t	slt	|d t
|dkr]|d ndt
|dkrh|d ndd}| j|||d ||||	||||||d}|s|| S t|j|j|j|j|j|j|j|jdS )	a  
        decoder_input_ids (`torch.LongTensor` of shape `(batch_size, target_sequence_length)`, *optional*):
            Indices of decoder input sequence tokens in the vocabulary.

            Indices can be obtained using [`AutoTokenizer`] or [`PLBartMultiTokenizer`] depending on the checkpoint.
            See [`PreTrainedTokenizer.encode`] and [`PreTrainedTokenizer.__call__`] for details.

            [What are decoder input IDs?](../glossary#decoder-input-ids)

            PLBart uses a specific language id token as the starting token for `decoder_input_ids` generation that
            varies according to source and target language, *e.g.* 50003 for *en_XX*, and 50001 for *java*. If
            `past_key_values` is used, optionally only the last `decoder_input_ids` have to be input (see
            `past_key_values`).

            For translation and summarization training, `decoder_input_ids` should be provided. If no
            `decoder_input_ids` is provided, the model will create this tensor by shifting the `input_ids` to the right
            for denoising pre-training following the paper.
        decoder_attention_mask (:
            obj:*torch.LongTensor* of shape `(batch_size, target_sequence_length)`, *optional*):
            Default behavior:
            generate a tensor that ignores pad tokens in `decoder_input_ids`. Causal mask will also be used by default.
        cross_attn_head_mask (:
            obj:*torch.Tensor* of shape `(decoder_layers, decoder_attention_heads)`, *optional*):
            Mask to nullify
            selected heads of the cross-attention modules in the decoder. Mask values selected in `[0, 1]`:

            - 1 indicates the head is **not masked**,
            - 0 indicates the head is **masked**.
        N)r   r,   r   r-   r   r   r   r   r   r   )last_hidden_statehidden_states
attentions)r   r,   rj   rk   r   r   r<   r-   r   r   r   r   r;   )r   r<   decoder_hidden_statesdecoder_attentionscross_attentionsencoder_last_hidden_staterj   encoder_attentions)r(   r   r   r   use_return_dictr   r   r   r5   r   lenr   r   r   r<   r   r   r   )r8   r   r,   r   r   r   r   r   r   r<   r-   r   r   r   r   r   r;   decoder_outputsr$   r$   r%   forward-  sd   1
zPLBartModel.forward)NNNNNNNNNNNNNNNN)r!   r"   r#   _tied_weights_keysr   r   r   r   r   r   r   r   r6   
LongTensorr7   listFloatTensorr   boolr   tupler   r   __classcell__r$   r$   r   r%   r|     sv    	
r|   zv
    The PLBART Model with a language modeling head. Can be used for code-to-text, text-to-code and code-to-code.
    )custom_introc                (       sv  e Zd ZdZdgZg dZdef fddZdd Zd	d
 Z		d*de
dee
 dedejf fddZde
ddfddZe																	d+deej deej deej deej deej deej deej deeej  dee deej deej d eej d!ee d"ee d#ee d$ee d%eej deeej ef f$d&d'Zd ejfd(d)Z  ZS ),PLBartForConditionalGenerationr)   final_logits_bias)r}   r~   zlm_head.weightr(   c                    sX   t  | t|| _| dtd| jjjf t	j
|j| jjjdd| _|   d S )Nr   r   F)bias)r   r   r|   r)   register_bufferr6   zerosr   num_embeddingsr   Linearr   lm_headr   )r8   r(   r   r$   r%   r     s
   
z'PLBartForConditionalGeneration.__init__c                 C   
   | j  S r   )r)   r   r   r$   r$   r%   r        
z*PLBartForConditionalGeneration.get_encoderc                 C   r   r   )r)   get_decoderr   r$   r$   r%   r     r   z*PLBartForConditionalGeneration.get_decoderNTnew_num_tokenspad_to_multiple_ofmean_resizingr   c                    s&   t  |||}| |jjd  |S )Nr   )r   resize_token_embeddings_resize_final_logits_biasweightrI   )r8   r   r   r   new_embeddingsr   r$   r%   r     s   z6PLBartForConditionalGeneration.resize_token_embeddingsc                 C   sj   | j jd }||kr| j d d d |f }ntjd|| f| j jd}tj| j |gdd}| d| d S )NrA   r   r]   )r^   r   )r   rI   r6   r   r>   catr   )r8   r   old_num_tokensnew_bias
extra_biasr$   r$   r%   r     s   z8PLBartForConditionalGeneration._resize_final_logits_biasr   r,   r   r   r   r   r   r   r<   r-   r   labelsr   r   r   r   r;   c                 C   s
  |dur|n| j j}|dur|du r|du rt|| j j}| j|f||||||||	|
||||||d}| |d }|| j|j }d}|dur[t	 }||
d| j j|
d}|sq|f|dd  }|duro|f| S |S t|||j|j|j|j|j|j|jd	S )a  
        decoder_input_ids (`torch.LongTensor` of shape `(batch_size, target_sequence_length)`, *optional*):
            Indices of decoder input sequence tokens in the vocabulary.

            Indices can be obtained using [`AutoTokenizer`] or [`PLBartMultiTokenizer`] depending on the checkpoint.
            See [`PreTrainedTokenizer.encode`] and [`PreTrainedTokenizer.__call__`] for details.

            [What are decoder input IDs?](../glossary#decoder-input-ids)

            PLBart uses a specific language id token as the starting token for `decoder_input_ids` generation that
            varies according to source and target language, *e.g.* 50003 for *en_XX*, and 50001 for *java*. If
            `past_key_values` is used, optionally only the last `decoder_input_ids` have to be input (see
            `past_key_values`).

            For translation and summarization training, `decoder_input_ids` should be provided. If no
            `decoder_input_ids` is provided, the model will create this tensor by shifting the `input_ids` to the right
            for denoising pre-training following the paper.
        decoder_attention_mask (:
            obj:*torch.LongTensor* of shape `(batch_size, target_sequence_length)`, *optional*):
            Default behavior:
            generate a tensor that ignores pad tokens in `decoder_input_ids`. Causal mask will also be used by default.
        cross_attn_head_mask (:
            obj:*torch.Tensor* of shape `(decoder_layers, decoder_attention_heads)`, *optional*):
            Mask to nullify
            selected heads of the cross-attention modules in the decoder. Mask values selected in `[0, 1]`:

            - 1 indicates the head is **not masked**,
            - 0 indicates the head is **masked**.
        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
            config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
            (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.

        Example Mask-filling:

        ```python
        >>> from transformers import AutoTokenizer, PLBartForConditionalGeneration

        >>> model = PLBartForConditionalGeneration.from_pretrained("uclanlp/plbart-base")
        >>> tokenizer = AutoTokenizer.from_pretrained("uclanlp/plbart-base")

        >>> # en_XX is the language symbol id <LID> for English
        >>> TXT = "<s> Is 0 the <mask> Fibonacci number ? </s> en_XX"
        >>> input_ids = tokenizer([TXT], add_special_tokens=False, return_tensors="pt").input_ids

        >>> logits = model(input_ids).logits
        >>> masked_index = (input_ids[0] == tokenizer.mask_token_id).nonzero().item()
        >>> probs = logits[0, masked_index].softmax(dim=0)
        >>> values, predictions = probs.topk(5)

        >>> tokenizer.decode(predictions).split()
        ['first', 'same', 'highest', 'result', 'number']
        ```
        N)r,   r   r   r   r   r   r   r<   r-   r   r   r   r   r   r;   r   rA   r   )	losslogitsr<   r   r   r   r   rj   r   )r(   r   r   r   r)   r   r   re   r>   r   viewr   r   r<   r   r   r   r   rj   r   )r8   r   r,   r   r   r   r   r   r   r<   r-   r   r   r   r   r   r   r;   outputs	lm_logitsmasked_lm_lossloss_fctoutputr$   r$   r%   r     sV   Kz&PLBartForConditionalGeneration.forwardc                 C   s   t || jjS r   )r   r(   r   )r8   r   r$   r$   r%   %prepare_decoder_input_ids_from_labels?  s   zDPLBartForConditionalGeneration.prepare_decoder_input_ids_from_labels)NT)NNNNNNNNNNNNNNNNN)r!   r"   r#   rq   _keys_to_ignore_on_load_missingr   r   r   r   r   rx   r   r   r   	Embeddingr   r   r   r6   r   r7   r   r   r   r   r   r   r   r   r   r$   r$   r   r%   r     s    		
zr   c                   @   r   )PLBartClassificationHeadNr    r$   r$   r$   r%   r   C  r&   r   c                       s   e Zd Z fddZ  ZS )PLBartForSequenceClassificationc                        t  jdi |  dS )a  
        decoder_input_ids (`torch.LongTensor` of shape `(batch_size, target_sequence_length)`, *optional*):
            Indices of decoder input sequence tokens in the vocabulary.

            Indices can be obtained using [`AutoTokenizer`] or [`PLBartMultiTokenizer`] depending on the checkpoint.
            See [`PreTrainedTokenizer.encode`] and [`PreTrainedTokenizer.__call__`] for details.

            [What are decoder input IDs?](../glossary#decoder-input-ids)

            PLBart uses a specific language id token as the starting token for `decoder_input_ids` generation that
            varies according to source and target language, *e.g.* 50003 for *en_XX*, and 50001 for *java*. If
            `past_key_values` is used, optionally only the last `decoder_input_ids` have to be input (see
            `past_key_values`).

            For translation and summarization training, `decoder_input_ids` should be provided. If no
            `decoder_input_ids` is provided, the model will create this tensor by shifting the `input_ids` to the right
            for denoising pre-training following the paper.
        decoder_attention_mask (:
            obj:*torch.LongTensor* of shape `(batch_size, target_sequence_length)`, *optional*):
            Default behavior:
            generate a tensor that ignores pad tokens in `decoder_input_ids`. Causal mask will also be used by default.
        cross_attn_head_mask (:
            obj:*torch.Tensor* of shape `(decoder_layers, decoder_attention_heads)`, *optional*):
            Mask to nullify
            selected heads of the cross-attention modules in the decoder. Mask values selected in `[0, 1]`:

            - 1 indicates the head is **not masked**,
            - 0 indicates the head is **masked**.
        labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
            Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
            config.num_labels - 1]`. If `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
        Nr$   r   r   super_kwargsr   r$   r%   r   H  s   !z'PLBartForSequenceClassification.forward)r!   r"   r#   r   r   r$   r$   r   r%   r   G  s    r   c                       s    e Zd Ze fddZ  ZS )PLBartForCausalLMc                     r   )a  
        cross_attn_head_mask (`torch.Tensor` of shape `(decoder_layers, decoder_attention_heads)`, *optional*):
            Mask to nullify selected heads of the cross-attention modules. Mask values selected in `[0, 1]`:

            - 1 indicates the head is **not masked**,
            - 0 indicates the head is **masked**.
        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
            config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
            (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.

        Example:

        ```python
        >>> from transformers import AutoTokenizer, PLBartForCausalLM

        >>> tokenizer = AutoTokenizer.from_pretrained("uclanlp/plbart-base")
        >>> model = PLBartForCausalLM.from_pretrained("uclanlp/plbart-base", add_cross_attention=False)
        >>> assert model.config.is_decoder, f"{model.__class__} has to be configured as a decoder."
        >>> inputs = tokenizer("Hello, my dog is cute", return_tensors="pt")
        >>> outputs = model(**inputs)

        >>> logits = outputs.logits
        >>> expected_shape = [1, inputs.input_ids.shape[-1], model.config.vocab_size]
        >>> list(logits.shape) == expected_shape
        True
        ```Nr$   r   r   r   r$   r%   r   m  s   zPLBartForCausalLM.forward)r!   r"   r#   r   r   r   r$   r$   r   r%   r   l  s    r   )r   r   r   r|   r'   )3__doc__r   typingr   r   r6   r   torch.nnr   cache_utilsr   
generationr   modeling_attn_mask_utilsr	   r
   r   modeling_outputsr   r   r   modeling_utilsr   utilsr   r   bart.modeling_bartr   r   r   r   r   (bigbird_pegasus.modeling_bigbird_pegasusr   mbart.modeling_mbartr   configuration_plbartr   integrations.flex_attentionr   r   r   r'   rz   r{   r|   r   r   r   r   __all__r$   r$   r$   r%   <module>   sH    K  #%!