o
    i                     @   s  d dl Z d dlmZmZ d dlZd dlmZmZmZ d dlm	Z	 ddl
mZ ddlmZmZmZ ddlmZ dd	lmZ dd
lmZmZmZ ddlmZ ddlmZmZmZ ddlmZ ddl m!Z! ddl"m#Z# e$e%Z&G dd dej'Z(G dd dej'Z)G dd dej'Z*G dd dej'Z+G dd dej'Z,G dd dej'Z-G dd deZ.G dd  d ej'Z/G d!d" d"ej'Z0G d#d$ d$ej'Z1G d%d& d&ej'Z2G d'd( d(ej'Z3G d)d* d*eZ4G d+d, d,e4Z5G d-d. d.e4eZ6g d/Z7dS )0    N)OptionalUnion)Tensordevicenn)CrossEntropyLoss   )ACT2FN)CacheDynamicCacheEncoderDecoderCache)GenerationMixin)GradientCheckpointingLayer))BaseModelOutputWithPastAndCrossAttentions,BaseModelOutputWithPoolingAndCrossAttentions!CausalLMOutputWithCrossAttentions)PreTrainedModel)apply_chunking_to_forward find_pruneable_heads_and_indicesprune_linear_layer)logging)deprecate_kwarg   )BlipTextConfigc                       s\   e Zd ZdZ fddZ				ddeej deej deej d	e	d
ej
f
ddZ  ZS )BlipTextEmbeddingsz;Construct the embeddings from word and position embeddings.c                    s   t    tj|j|j|jd| _t|j|j| _	tj
|j|jd| _
t|j| _| jdt|jddd t|dd| _|| _d S )	N)padding_idxepsposition_ids)r   F)
persistentposition_embedding_typeabsolute)super__init__r   	Embedding
vocab_sizehidden_sizepad_token_idword_embeddingsmax_position_embeddingsposition_embeddings	LayerNormlayer_norm_epsDropouthidden_dropout_probdropoutregister_buffertorcharangeexpandgetattrr!   configselfr6   	__class__ _/home/ubuntu/.local/lib/python3.10/site-packages/transformers/models/blip/modeling_blip_text.pyr$   /   s   

zBlipTextEmbeddings.__init__Nr   	input_idsr   inputs_embedspast_key_values_lengthreturnc           	      C   s   |d ur	|  }n|  d d }|d }|d u r&| jd d ||| f }|d u r/| |}|}| jdkr?| |}||7 }| |}| |}|S )Nr   r   r"   )sizer   r)   r!   r+   r,   r0   )	r8   r=   r   r>   r?   input_shape
seq_length
embeddingsr+   r;   r;   r<   forwardA   s   





zBlipTextEmbeddings.forward)NNNr   )__name__
__module____qualname____doc__r$   r   r2   
LongTensorFloatTensorintr   rE   __classcell__r;   r;   r9   r<   r   ,   s$    r   c                       s   e Zd Zd fdd	Zdd Zdd Zdd	 Zd
d Zedddd							dde	j
dee	j dee	j dee	j dee	j dee dee dee	j
 dee	j
 fddZ  ZS )BlipTextSelfAttentionNc                    s&  t    || _|j|j dkrt|dstd|j|jf |j| _t|j|j | _| j| j | _	|| _
t|j| j	| _|rTt|j| j	| _t|j| j	| _nt|j| j	| _t|j| j	| _t|j| _t|dd| _| jdks~| jdkr|j| _td|j d	 | j| _d S d S )
Nr   embedding_sizezLThe hidden size (%d) is not a multiple of the number of attention heads (%d)r!   r"   relative_keyrelative_key_query   r   )r#   r$   r6   r'   num_attention_headshasattr
ValueErrorrL   attention_head_sizeall_head_size	layer_idxr   Linearqueryencoder_hidden_sizekeyvaluer.   attention_probs_dropout_probr0   r5   r!   r*   r%   distance_embeddingr8   r6   is_cross_attentionrX   r9   r;   r<   r$   a   s0   

zBlipTextSelfAttention.__init__c                 C   
   || _ d S Nattn_gradients)r8   re   r;   r;   r<   save_attn_gradients}      
z)BlipTextSelfAttention.save_attn_gradientsc                 C      | j S rc   rd   r8   r;   r;   r<   get_attn_gradients      z(BlipTextSelfAttention.get_attn_gradientsc                 C   rb   rc   attention_map)r8   rm   r;   r;   r<   save_attention_map   rg   z(BlipTextSelfAttention.save_attention_mapc                 C   rh   rc   rl   ri   r;   r;   r<   get_attention_map   rk   z'BlipTextSelfAttention.get_attention_mappast_key_valuepast_key_values4.58new_nameversionFhidden_statesattention_mask	head_maskencoder_hidden_statesencoder_attention_maskoutput_attentionscache_positionr@   c	                 C   s  |j \}	}
}| ||	d| j| jdd}|d u}|r|n|}d}|d ur?t|tr=|j	| j
}|r9|j}n|j}n|}|rC|n|}|r\|d ur\|r\|j| j
 j}|j| j
 j}nF| ||	d| j| jdd}| ||	d| j| jdd}|d ur|s|nd }|||| j
d|i\}}|rt|trd|j| j
< t||dd}| jdks| jd	kr| d }
tj|
tj|jd
dd}tj|
tj|jd
dd}|| }| || j d }|j|jd}| jdkrtd||}|| }n| jd	krtd||}td||}|| | }|t | j }|d ur.|||j }t!j"dd|}| #|}|d urD|| }t||}|$dddd% }| d d | j&f }|j| }||fS )Nr   r   rR   Fr|   TrP   rQ   )dtyper   r~   zbhld,lrd->bhlrzbhrd,lrd->bhlrdimr   r   )'shaperZ   viewrS   rV   	transpose
isinstancer   
is_updatedgetrX   cross_attention_cacheself_attention_cachelayerskeysvaluesr\   r]   updater2   matmulr!   rA   r3   longr   r_   r*   tor~   einsummathsqrtr   Softmaxr0   permute
contiguousrW   )r8   rv   rw   rx   ry   rz   rq   r{   r|   
batch_sizerC   _query_layerra   r   curr_past_key_valuecurrent_states	key_layervalue_layerattention_scoresposition_ids_lposition_ids_rdistancepositional_embeddingrelative_position_scoresrelative_position_scores_queryrelative_position_scores_keyattention_probsattention_probs_droppedcontext_layernew_context_layer_shaper;   r;   r<   rE      sz   	






zBlipTextSelfAttention.forwardrc   NNNNNFN)rF   rG   rH   r$   rf   rj   rn   ro   r   r2   r   r   rK   r
   booltuplerE   rM   r;   r;   r9   r<   rN   `   sB    	
rN   c                       8   e Zd Z fddZdejdejdejfddZ  ZS )BlipTextSelfOutputc                    sB   t    t|j|j| _tj|j|jd| _t|j	| _
d S Nr   )r#   r$   r   rY   r'   denser,   r-   r.   r/   r0   r7   r9   r;   r<   r$         
zBlipTextSelfOutput.__init__rv   input_tensorr@   c                 C   &   |  |}| |}| || }|S rc   r   r0   r,   r8   rv   r   r;   r;   r<   rE         

zBlipTextSelfOutput.forwardrF   rG   rH   r$   r2   r   rE   rM   r;   r;   r9   r<   r          $r   c                       s   e Zd Zd fdd	Zdd Zeddd	d
						ddejdeej	 deej	 deej	 dee
 dee deej deej fddZ  ZS )BlipTextAttentionFNc                    s0   t    t|||d| _t|| _t | _d S )NrX   )r#   r$   rN   r8   r   outputsetpruned_headsr`   r9   r;   r<   r$     s   

zBlipTextAttention.__init__c                 C   s   t |dkrd S t|| jj| jj| j\}}t| jj|| j_t| jj|| j_t| jj	|| j_	t| j
j|dd| j
_| jjt | | j_| jj| jj | j_| j|| _d S )Nr   r   r   )lenr   r8   rS   rV   r   r   rZ   r\   r]   r   r   rW   union)r8   headsindexr;   r;   r<   prune_heads  s   zBlipTextAttention.prune_headsrp   rq   rr   rs   rv   rw   rx   ry   r{   r|   r@   c              	   C   s>   | j |||||||d}| |d |}	|	f|dd   }
|
S )Nrw   rx   ry   rq   r{   r|   r   r   )r8   r   )r8   rv   rw   rx   ry   rq   r{   r|   self_outputsattention_outputoutputsr;   r;   r<   rE     s   	zBlipTextAttention.forward)FN)NNNNFN)rF   rG   rH   r$   r   r   r2   r   r   rK   r
   r   r   rE   rM   r;   r;   r9   r<   r     s6    	r   c                       2   e Zd Z fddZdejdejfddZ  ZS )BlipTextIntermediatec                    sD   t    t|j|j| _t|jt	rt
|j | _d S |j| _d S rc   )r#   r$   r   rY   r'   intermediate_sizer   r   
hidden_actstrr	   intermediate_act_fnr7   r9   r;   r<   r$   9  s
   
zBlipTextIntermediate.__init__rv   r@   c                 C      |  |}| |}|S rc   )r   r   r8   rv   r;   r;   r<   rE   A     

zBlipTextIntermediate.forwardr   r;   r;   r9   r<   r   8  s    r   c                       r   )BlipTextOutputc                    sB   t    t|j|j| _tj|j|jd| _t	|j
| _d S r   )r#   r$   r   rY   r   r'   r   r,   r-   r.   r/   r0   r7   r9   r;   r<   r$   I  r   zBlipTextOutput.__init__rv   r   r@   c                 C   r   rc   r   r   r;   r;   r<   rE   O  r   zBlipTextOutput.forwardr   r;   r;   r9   r<   r   H  r   r   c                       s   e Zd Z fddZedddd							dd	ejd
eej deej deej deej dee	 dee
 deej deej fddZdd Z  ZS )BlipTextLayerc                    sf   t    || _|j| _d| _t||d| _|| _| jjr't|| jj|d| _	t
|| _t|| _d S )Nr   r   )ra   rX   )r#   r$   r6   chunk_size_feed_forwardseq_len_dimr   	attention	layer_num
is_decodercrossattentionr   intermediater   r   )r8   r6   r   r9   r;   r<   r$   W  s   


zBlipTextLayer.__init__rp   rq   rr   rs   NFrv   rw   rx   ry   rz   r{   r|   r@   c	              	   C   s   | j ||||||d}	|	d }
|	dd  }|d ur1| j|
||||||d}|d }
||dd   }t| j| j| j|
}|f| S )N)rw   rx   r{   rq   r|   r   r   r   )r   r   r   feed_forward_chunkr   r   )r8   rv   rw   rx   ry   rz   rq   r{   r|   self_attention_outputsr   r   cross_attention_outputslayer_outputr;   r;   r<   rE   e  s4   	
zBlipTextLayer.forwardc                 C   s   |  |}| ||}|S rc   )r   r   )r8   r   intermediate_outputr   r;   r;   r<   r     s   
z BlipTextLayer.feed_forward_chunkr   )rF   rG   rH   r$   r   r2   r   r   rK   r
   r   r   rE   r   rM   r;   r;   r9   r<   r   V  s<    	
'r   c                       s   e Zd Z fddZ										ddejdeej deej d	eej d
eej dee dee	 dee	 dee	 dee	 deej de
eej ef fddZ  ZS )BlipTextEncoderc                    s:   t     | _t fddt jD | _d| _d S )Nc                    s   g | ]}t  |qS r;   )r   ).0ir6   r;   r<   
<listcomp>      z,BlipTextEncoder.__init__.<locals>.<listcomp>F)	r#   r$   r6   r   
ModuleListrangenum_hidden_layerslayergradient_checkpointingr7   r9   r   r<   r$     s   
 
zBlipTextEncoder.__init__NFTrv   rw   rx   ry   rz   rq   	use_cacher{   output_hidden_statesreturn_dictr|   r@   c              
   C   sr  | j r| jr|rtd d}|rAt|tr!td t|}n t|t	r0t|t	| j
d}n|d u rAtt	| j
dt	| j
d}|	rEdnd }|rKdnd }|rU|d urUdnd }t| j
jD ];}| j| }|	rk||f }|d urs|| nd }|||||||||}|d }|r||d f }|d ur||d f }q]|	r||f }|
std	d
 |||||fD S t|||||dS )NzZ`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...FzPassing a tuple of `past_key_values` is deprecated and will be removed in Transformers v4.58.0. You should pass an instance of `EncoderDecoderCache` instead, e.g. `past_key_values=EncoderDecoderCache.from_legacy_cache(past_key_values)`.r   r;   r   r   rR   c                 s   s    | ]	}|d ur|V  qd S rc   r;   )r   vr;   r;   r<   	<genexpr>  s    z*BlipTextEncoder.forward.<locals>.<genexpr>)last_hidden_staterq   rv   
attentionscross_attentions)r   trainingloggerwarningr   r   warning_oncer   from_legacy_cacher   r6   r   r   r   r   )r8   rv   rw   rx   ry   rz   rq   r   r{   r   r   r|   all_hidden_statesall_self_attentionsall_cross_attentionsr   layer_modulelayer_head_masklayer_outputsr;   r;   r<   rE     sx   




zBlipTextEncoder.forward)
NNNNNNFFTN)rF   rG   rH   r$   r2   r   r   rK   r
   r   r   r   r   rE   rM   r;   r;   r9   r<   r     sJ    		
r   c                       r   )BlipTextPoolerc                    s*   t    t|j|j| _t | _d S rc   )r#   r$   r   rY   r'   r   Tanh
activationr7   r9   r;   r<   r$     s   
zBlipTextPooler.__init__rv   r@   c                 C   s(   |d d df }|  |}| |}|S )Nr   )r   r   )r8   rv   first_token_tensorpooled_outputr;   r;   r<   rE     s   

zBlipTextPooler.forwardr   r;   r;   r9   r<   r     s    r   c                       r   )BlipTextPredictionHeadTransformc                    sV   t    t|j|j| _t|jtrt	|j | _
n|j| _
tj|j|jd| _d S r   )r#   r$   r   rY   r'   r   r   r   r   r	   transform_act_fnr,   r-   r7   r9   r;   r<   r$     s   
z(BlipTextPredictionHeadTransform.__init__rv   r@   c                 C   s"   |  |}| |}| |}|S rc   )r   r  r,   r   r;   r;   r<   rE     s   


z'BlipTextPredictionHeadTransform.forwardr   r;   r;   r9   r<   r    s    	r  c                       s,   e Zd Z fddZdd Zdd Z  ZS )BlipTextLMPredictionHeadc                    sL   t    t|| _tj|j|jdd| _t	t
|j| _| j| j_d S )NF)bias)r#   r$   r  	transformr   rY   r'   r&   decoder	Parameterr2   zerosr  r7   r9   r;   r<   r$     s
   

z!BlipTextLMPredictionHead.__init__c                 C   s   | j | j_ d S rc   )r  r  ri   r;   r;   r<   _tie_weights&  s   z%BlipTextLMPredictionHead._tie_weightsc                 C   r   rc   )r  r  r   r;   r;   r<   rE   )  r   z BlipTextLMPredictionHead.forward)rF   rG   rH   r$   r	  rE   rM   r;   r;   r9   r<   r    s    r  c                       r   )BlipTextOnlyMLMHeadc                    s   t    t|| _d S rc   )r#   r$   r  predictionsr7   r9   r;   r<   r$   1  s   
zBlipTextOnlyMLMHead.__init__sequence_outputr@   c                 C   s   |  |}|S rc   )r  )r8   r  prediction_scoresr;   r;   r<   rE   5  s   
zBlipTextOnlyMLMHead.forwardr   r;   r;   r9   r<   r
  0  s    r
  c                   @   s*   e Zd ZU dZeed< dZg Zdd ZdS )BlipTextPreTrainedModelz
    An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
    models.
    r6   bertc                 C   s~   t |tjtjfr|jjjd| jjd nt |tj	r(|j
j  |jjd t |tjr;|j
dur=|j
j  dS dS dS )zInitialize the weightsg        )meanstd      ?N)r   r   rY   r%   weightdatanormal_r6   initializer_ranger,   r  zero_fill_)r8   moduler;   r;   r<   _init_weightsE  s   z%BlipTextPreTrainedModel._init_weightsN)	rF   rG   rH   rI   r   __annotations__base_model_prefix_no_split_modulesr  r;   r;   r;   r<   r  ;  s   
 r  c                #       s"  e Zd ZdZd# fdd	Zdd Zdd Zd	d
 Zdede	e
 dededef
ddZ															d$deej deej deej deej deej deej deej deej dee dee dee dee dee dee d eej dee	ej ef f d!d"Z  ZS )%BlipTextModela&  
    The model can behave as an encoder (with only self-attention) as well as a decoder, in which case a layer of
    cross-attention is added between the self-attention layers, following the architecture described in [Attention is
    all you need](https://huggingface.co/papers/1706.03762) by Ashish Vaswani, Noam Shazeer, Niki Parmar, Jakob Uszkoreit,
    Llion Jones, Aidan N. Gomez, Lukasz Kaiser and Illia Polosukhin. argument and `is_decoder` set to `True`; an
    `encoder_hidden_states` is then expected as an input to the forward pass.
    Tc                    sD   t  | || _t|| _t|| _|rt|nd | _| 	  d S rc   )
r#   r$   r6   r   rD   r   encoderr   pooler	post_init)r8   r6   add_pooling_layerr9   r;   r<   r$   \  s   

zBlipTextModel.__init__c                 C   s   | j jS rc   rD   r)   ri   r;   r;   r<   get_input_embeddingsf  s   z"BlipTextModel.get_input_embeddingsc                 C   s   || j _d S rc   r#  )r8   r]   r;   r;   r<   set_input_embeddingsi  s   z"BlipTextModel.set_input_embeddingsc                 C   s*   |  D ]\}}| jj| j| qdS )z
        Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base
        class PreTrainedModel
        N)itemsr  r   r   r   )r8   heads_to_pruner   r   r;   r;   r<   _prune_headsm  s   zBlipTextModel._prune_headsrw   rB   r   r   r@   c                 C   s^  |  dkr|dddddddf }n|  dkr|r|\}}tj||d}|ddddf ||d|ddddf k}	|	|j}	|	jd |jd k rl|jd |	jd  }
tjtj|||
f||	jd|	gdd}	|	dddddddf |ddddddf  }n|ddddddf }nt	d	| d
|j d|j| jd}d| d }|S )a=  
        Makes broadcastable attention and causal masks so that future and masked tokens are ignored.

        Arguments:
            attention_mask (`torch.Tensor`):
                Mask with ones indicating tokens to attend to, zeros for tokens to ignore.
            input_shape (`tuple[int]`):
                The shape of the input to the model.
            device (`torch.device`):
                The device of the input to the model.

        Returns:
            `torch.Tensor` The extended attention mask, with a the same dtype as `attention_mask.dtype`.
        r   NrR   r   r   )r   r~   r   )axisz!Wrong shape for input_ids (shape z) or attention_mask (shape )r   r  g     )
r   r2   r3   repeatr   r~   r   catonesrU   )r8   rw   rB   r   r   extended_attention_maskr   rC   seq_idscausal_maskprefix_seq_lenr;   r;   r<   get_extended_attention_masku  s4   .6
	z)BlipTextModel.get_extended_attention_maskNFr=   r   rx   r>   encoder_embedsry   rz   rq   r   r{   r   r   r|   c                    s  |dur|n j j}|dur|n j j}|dur|n j j}|r+|
dur&|
n j j}
nd}
|dur9|dur9td|durO || | }|\}}|j}n,|durc| dd }|\}}|j}n|durw| dd }|\}}|j}ntdd}|	durt	|	t
s|	d d jd n|	 }|du rt||| f|} ||||}|durt	|tr|d  \}}}n| \}}}||f}t	|tr؇ fdd	|D }n|du rtj||d
} |}n |}nd} | j j}|du r j||||d}n|} j||||||	|
||||d}|d } jdur) |nd}|s8||f|dd  S t|||j|j|j|jdS )a  
        encoder_hidden_states  (`torch.FloatTensor`, *optional*):
            Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention if
            the model is configured as a decoder.
        encoder_attention_mask (`torch.FloatTensor`, *optional*):
            Mask to avoid performing attention on the padding token indices of the encoder input. This mask is used in
            the cross-attention if the model is configured as a decoder. Mask values selected in `[0, 1]`:
            - 1 for tokens that are **not masked**,
            - 0 for tokens that are **masked**.
        past_key_values (`Cache`, *optional*):
            Contains precomputed key and value hidden states of the attention blocks. Can be used to speed up decoding.
            If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those that
            don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of all
            `decoder_input_ids` of shape `(batch_size, sequence_length)`.
        use_cache (`bool`, *optional*):
            If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see
            `past_key_values`).
        NFzDYou cannot specify both input_ids and inputs_embeds at the same timer   zGYou have to specify either input_ids or inputs_embeds or encoder_embedsr   r}   c                    s   g | ]}  |qS r;   )invert_attention_mask)r   maskri   r;   r<   r     r   z)BlipTextModel.forward.<locals>.<listcomp>r)  )r=   r   r>   r?   )
rw   rx   ry   rz   rq   r   r{   r   r   r|   r   )r   pooler_outputrq   rv   r   r   )r6   r{   r   use_return_dictr   rU   %warn_if_padding_and_no_attention_maskrA   r   r   r
   r   get_seq_lengthr2   r.  r   r3  listr5  get_head_maskr   rD   r  r   r   rq   rv   r   r   )r8   r=   rw   r   rx   r>   r4  ry   rz   rq   r   r{   r   r   r   r|   rB   r   rC   r   r?   r/  encoder_batch_sizeencoder_sequence_lengthr   encoder_hidden_shapeencoder_extended_attention_maskembedding_outputencoder_outputsr  r   r;   ri   r<   rE     s   $


zBlipTextModel.forward)T)NNNNNNNNNNNNNFN)rF   rG   rH   rI   r$   r$  r%  r(  r   r   rL   r   r   r3  r   r2   r
   r   r   rE   rM   r;   r;   r9   r<   r  S  s    

@	
r  c                '       s,  e Zd ZddgZ fddZdd Zdd Zd	d
 Zdd Z																	d'de	e
j de	e
j de	e
j de	e
j de	e
j de	e
j de	e
j de	e
j de	e de	e de	e de	e de	e de	e de	e d e	e d!e	e
j d"eee
j ef f$d#d$Zd( fd%d&	Z  ZS ))BlipTextLMHeadModelzcls.predictions.decoder.weightzcls.predictions.decoder.biasc                    s0   t  | t|dd| _t|| _|j| _d S )NF)r"  )r#   r$   r  r  r
  clslabel_smoothingr7   r9   r;   r<   r$   J  s   
zBlipTextLMHeadModel.__init__c                 C   s
   | j  S rc   )r  r$  ri   r;   r;   r<   r$  Q  rg   z(BlipTextLMHeadModel.get_input_embeddingsc                 C   s   | j | d S rc   )r  r%  r8   new_embeddingsr;   r;   r<   r%  T  s   z(BlipTextLMHeadModel.set_input_embeddingsc                 C   s
   | j jjS rc   )rD  r  r  ri   r;   r;   r<   get_output_embeddingsW  rg   z)BlipTextLMHeadModel.get_output_embeddingsc                 C   s   || j j_|j| j j_d S rc   )rD  r  r  r  rF  r;   r;   r<   set_output_embeddingsZ  s   
z)BlipTextLMHeadModel.set_output_embeddingsNFTr  r=   rw   r   rx   r>   ry   rz   labelsrq   r   r{   r   r   return_logitsr   	reductionr|   r@   c                 C   sV  |dur|n| j j}|durd}
| j||||||||	|
|||||d}|d }| |}|r=|ddddddf  S d}|dur|ddddddf  }|ddddf  |j}t|| jd}||	d| j j
|	d}|dkr|	|ddd}|s|f|d	d  }|dur|f| S |S t|||j|j|j|jd
S )a  
        encoder_hidden_states (`torch.FloatTensor`, *optional*): Sequence of
            hidden-states at the output of the last layer of the encoder. Used in the cross-attention if the model is
            configured as a decoder.
        encoder_attention_mask (`torch.FloatTensor`, *optional*):
            Mask to avoid performing attention on the padding token indices of the encoder input. This mask is used in
            the cross-attention if the model is configured as a decoder. Mask values selected in `[0, 1]`:
            - 1 for tokens that are **not masked**,
            - 0 for tokens that are **masked**.
        labels (`torch.LongTensor`, *optional*):
            Labels for computing the left-to-right language modeling loss (next word prediction). Indices should be in
            `[-100, 0, ..., config.vocab_size]` (see `input_ids` docstring) Tokens with indices set to `-100` are
            ignored (masked), the loss is only computed for the tokens with labels n `[0, ..., config.vocab_size]`
        past_key_values (`Cache`, *optional*):
            Contains precomputed key and value hidden states of the attention blocks. Can be used to speed up decoding.
            If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those that
            don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of all
            `decoder_input_ids` of shape `(batch_size, sequence_length)`.
        use_cache (`bool`, *optional*):
            If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see
            `past_key_values`).
        NF)rw   r   rx   r>   ry   rz   rq   r   r{   r   r   r   r|   r   r   r   )rL  rE  nonerR   )losslogitsrq   rv   r   r   )r6   r8  r  rD  r   r   r   r   rE  r   r&   rA   sumr   rq   rv   r   r   )r8   r=   rw   r   rx   r>   ry   rz   rJ  rq   r   r{   r   r   rK  r   rL  r|   r   r  r  lm_lossshifted_prediction_scoresloss_fctr   r;   r;   r<   rE   ^  sT   *
 zBlipTextLMHeadModel.forwardc                    s&   t  j|f||d|}d|d< |S )N)rq   rw   Tr   )r#   prepare_inputs_for_generation)r8   r=   rq   rw   model_kwargsmodel_inputsr9   r;   r<   rT    s   z1BlipTextLMHeadModel.prepare_inputs_for_generation)NNNNNNNNNNNNNFTr  N)NN)rF   rG   rH   _tied_weights_keysr$   r$  r%  rH  rI  r   r2   r   r
   r   r   r   r   r   rE   rT  rM   r;   r;   r9   r<   rC  G  s|    	

\rC  )r  rC  r  )8r   typingr   r   r2   r   r   r   torch.nnr   activationsr	   cache_utilsr
   r   r   
generationr   modeling_layersr   modeling_outputsr   r   r   modeling_utilsr   pytorch_utilsr   r   r   utilsr   utils.deprecationr   configuration_blipr   
get_loggerrF   r   Moduler   rN   r   r   r   r   r   r   r   r  r  r
  r  r  rC  __all__r;   r;   r;   r<   <module>   sF   
4 3>b u 