o
    i                     @   s  d Z ddlZddlmZ ddlmZmZmZmZ ddl	Z	ddl
mZ ddlmZ ddlmZ ddlmZmZmZmZ dd	lmZmZ dd
lmZmZmZ ddlmZmZmZm Z m!Z!m"Z" ddl#m$Z$m%Z%m&Z& e!'e(Z)de	j*de	j*fddZ+de	j*de	j*fddZ,eeG dd deZ-G dd dej.Z/G dd dej.Z0G dd dej.Z1de0iZ2G dd dej.Z3G d d! d!ej.Z4G d"d# d#ej.Z5G d$d% d%eZ6G d&d' d'ej.Z7G d(d) d)ej.Z8	*dOd+ej.d,e	j*d-e	j*d.e	j*d/ee	j* d0e9d1e9fd2d3Z:G d4d5 d5ej.Z;G d6d7 d7ej.Z<G d8d9 d9eZ=G d:d; d;ej.Z>G d<d= d=ej.Z?eG d>d? d?eZ@G d@dA dAej.ZAG dBdC dCe@ZBedDdEG dFdG dGe@ZCG dHdI dIe@ZDG dJdK dKe@ZEdPdLdMZFg dNZGdS )QzPyTorch AltCLIP model.    N)	dataclass)AnyCallableOptionalUnion   )ACT2FN)GradientCheckpointingLayer)BaseModelOutputBaseModelOutputWithPooling,BaseModelOutputWithPoolingAndCrossAttentions'BaseModelOutputWithPoolingAndProjection)ALL_ATTENTION_FUNCTIONSPreTrainedModel)apply_chunking_to_forward find_pruneable_heads_and_indicesprune_linear_layer)ModelOutputauto_docstringcan_return_tuplefilter_out_non_signature_kwargslogging	torch_int   )AltCLIPConfigAltCLIPTextConfigAltCLIPVisionConfiglogitsreturnc                 C   s   t j| tjt| | jdS )Ndevice)nn
functionalcross_entropytorcharangelenr    )r    r'   i/home/ubuntu/veenaModal/venv/lib/python3.10/site-packages/transformers/models/altclip/modeling_altclip.pycontrastive_loss+   s   r)   
similarityc                 C   s    t | }t |  }|| d S )Ng       @)r)   t)r*   caption_loss
image_lossr'   r'   r(   	clip_loss/   s   r.   c                   @   s   e Zd ZU dZdZeej ed< dZ	eej ed< dZ
eej ed< dZeej ed< dZeej ed< dZeed< dZeed	< d
ee fddZdS )AltCLIPOutputa  
    loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `return_loss` is `True`):
        Contrastive loss for image-text similarity.
    logits_per_image (`torch.FloatTensor` of shape `(image_batch_size, text_batch_size)`):
        The scaled dot product scores between `image_embeds` and `text_embeds`. This represents the image-text
        similarity scores.
    logits_per_text (`torch.FloatTensor` of shape `(text_batch_size, image_batch_size)`):
        The scaled dot product scores between `text_embeds` and `image_embeds`. This represents the text-image
        similarity scores.
    text_embeds (`torch.FloatTensor` of shape `(batch_size, output_dim`):
        The text embeddings obtained by applying the projection layer to the pooled output of [`AltCLIPTextModel`].
    image_embeds (`torch.FloatTensor` of shape `(batch_size, output_dim`):
        The image embeddings obtained by applying the projection layer to the pooled output of [`AltCLIPVisionModel`].
    text_model_output (`BaseModelOutputWithPooling`):
        The output of the [`AltCLIPTextModel`].
    vision_model_output (`BaseModelOutputWithPooling`):
        The output of the [`AltCLIPVisionModel`].
    Nlosslogits_per_imagelogits_per_texttext_embedsimage_embedstext_model_outputvision_model_outputr   c                    s   t  fdd  D S )Nc                 3   s.    | ]}|d vr | nt  | V  qdS ))r5   r6   N)getattrto_tuple).0kselfr'   r(   	<genexpr>U   s
    
z)AltCLIPOutput.to_tuple.<locals>.<genexpr>)tuplekeysr;   r'   r;   r(   r8   T   s   zAltCLIPOutput.to_tuple)__name__
__module____qualname____doc__r0   r   r$   FloatTensor__annotations__r1   r2   r3   r4   r5   r   r6   r>   r   r8   r'   r'   r'   r(   r/   5   s   
 r/   c                       s4   e Zd ZdZ fddZ	d
ddZdd	 Z  ZS )AltRobertaEmbeddingszV
    Same as BertEmbeddings with a tiny tweak for positional embeddings indexing.
    c                    s   t    tj|j|j|jd| _t|j|j| _	t|j
|j| _tj|j|jd| _t|j| _t|dd| _| jdt|jddd | jd	tj| j tjd
dd |j| _tj|j|j| jd| _	d S )N)padding_idxepsposition_embedding_typeabsoluteposition_idsr   F
persistenttoken_type_idsdtype)super__init__r!   	Embedding
vocab_sizehidden_sizepad_token_idword_embeddingsmax_position_embeddingsposition_embeddingstype_vocab_sizetoken_type_embeddings	LayerNormlayer_norm_epsDropouthidden_dropout_probdropoutr7   rJ   register_bufferr$   r%   expandzerosrL   sizelongrG   r<   config	__class__r'   r(   rU   b   s"   
zAltRobertaEmbeddings.__init__Nr   c                 C   s   |d u r|d urt || j|}n| |}|d ur| }n| d d }|d }|d u rTt| drI| jd d d |f }||d |}	|	}ntj|tj	| j
jd}|d u r]| |}| |}
||
 }| jdkrt| |}||7 }| |}| |}|S )NrN   r   rQ   r   rS   r    rK   )"create_position_ids_from_input_idsrG   &create_position_ids_from_inputs_embedsrg   hasattrrQ   re   r$   rf   rh   rL   r    rZ   r^   rJ   r\   r_   rc   )r<   	input_idsrQ   rL   inputs_embedspast_key_values_lengthinput_shape
seq_lengthbuffered_token_type_ids buffered_token_type_ids_expandedr^   
embeddingsr\   r'   r'   r(   forward{   s0   








zAltRobertaEmbeddings.forwardc                 C   sN   |  dd }|d }tj| jd || j d tj|jd}|d|S )z
        We are provided embeddings directly. We cannot infer which are padded so just generate sequential position ids.

        Args:
            inputs_embeds: torch.Tensor

        Returns: torch.Tensor
        NrN   r   rm   r   )rg   r$   r%   rG   rh   r    	unsqueezere   )r<   rr   rt   sequence_lengthrL   r'   r'   r(   ro      s   	z;AltRobertaEmbeddings.create_position_ids_from_inputs_embeds)NNNNr   )r@   rA   rB   rC   rU   ry   ro   __classcell__r'   r'   rk   r(   rF   \   s    
(rF   c                       s\   e Zd Zd fdd	Z			ddejdeej deej dee d	e	ej f
d
dZ
  ZS )AltRobertaSelfAttentionNc                    s   t    |j|j dkrt|dstd|j d|j d|j| _t|j|j | _| j| j | _t	
|j| j| _t	
|j| j| _t	
|j| j| _t	|j| _|p\t|dd| _| jdksh| jd	kr{|j| _t	d
|j d | j| _d S d S )Nr   embedding_sizezThe hidden size (z6) is not a multiple of the number of attention heads ()rJ   rK   relative_keyrelative_key_query   r   )rT   rU   rX   num_attention_headsrp   
ValueErrorintattention_head_sizeall_head_sizer!   Linearquerykeyvaluera   attention_probs_dropout_probrc   r7   rJ   r[   rV   distance_embeddingr<   rj   rJ   rk   r'   r(   rU      s*   

z AltRobertaSelfAttention.__init__Fhidden_statesattention_mask	head_maskoutput_attentionsr   c                 C   s  |j d d }g |d| jR }| ||dd}| ||dd}| ||dd}	t||dd}
| j	dksI| j	dkr|j d |j d }}tj
|tj|jddd}tj
|tj|jddd}|| }| || j d }|j|jd}| j	dkrtd	||}|
| }
n| j	dkrtd	||}td
||}|
| | }
|
t| j }
|d ur|
| }
tjj|
dd}| |}|d ur|| }t||	}|dddd }| d d | jf }||}|r||f}|S |f}|S )NrN   r   r   r   r   rm   rR   zbhld,lrd->bhlrzbhrd,lrd->bhlrdimr   r   )shaper   r   view	transposer   r   r$   matmulrJ   r%   rh   r    r   r[   torS   einsummathsqrtr!   r"   softmaxrc   permute
contiguousrg   r   )r<   r   r   r   r   rt   hidden_shapequery_layer	key_layervalue_layerattention_scoresquery_length
key_lengthposition_ids_lposition_ids_rdistancepositional_embeddingrelative_position_scoresrelative_position_scores_queryrelative_position_scores_keyattention_probscontext_layernew_context_layer_shapeoutputsr'   r'   r(   ry      sF   




zAltRobertaSelfAttention.forwardNNNF)r@   rA   rB   rU   r$   Tensorr   rD   boolr>   ry   r|   r'   r'   rk   r(   r}      s     r}   c                       8   e Zd Z fddZdejdejdejfddZ  ZS )AltRobertaSelfOutputc                    sB   t    t|j|j| _tj|j|jd| _t|j	| _
d S NrH   )rT   rU   r!   r   rX   denser_   r`   ra   rb   rc   ri   rk   r'   r(   rU        
zAltRobertaSelfOutput.__init__r   input_tensorr   c                 C   &   |  |}| |}| || }|S r   r   rc   r_   r<   r   r   r'   r'   r(   ry        

zAltRobertaSelfOutput.forwardr@   rA   rB   rU   r$   r   ry   r|   r'   r'   rk   r(   r         $r   eagerc                       sd   e Zd Zd fdd	Zdd Z			ddejdeej d	eej d
ee	 de
ej f
ddZ  ZS )AltRobertaAttentionNc                    s4   t    t|j ||d| _t|| _t | _d S )N)rJ   )	rT   rU   "ALT_ROBERTA_SELF_ATTENTION_CLASSES_attn_implementationr<   r   outputsetpruned_headsr   rk   r'   r(   rU      s   

zAltRobertaAttention.__init__c                 C   s   t |dkrd S t|| jj| jj| j\}}t| jj|| j_t| jj|| j_t| jj	|| j_	t| j
j|dd| j
_| jjt | | j_| jj| jj | j_| j|| _d S )Nr   r   r   )r&   r   r<   r   r   r   r   r   r   r   r   r   r   union)r<   headsindexr'   r'   r(   prune_heads(  s   zAltRobertaAttention.prune_headsFr   r   r   r   r   c                 C   s8   | j ||||d}| |d |}|f|dd   }|S N)r   r   r   r   r   )r<   r   )r<   r   r   r   r   self_outputsattention_outputr   r'   r'   r(   ry   :  s   zAltRobertaAttention.forwardr   r   )r@   rA   rB   rU   r   r$   r   r   rD   r   r>   ry   r|   r'   r'   rk   r(   r     s"    r   c                       2   e Zd Z fddZdejdejfddZ  ZS )AltRobertaIntermediatec                    sD   t    t|j|j| _t|jt	rt
|j | _d S |j| _d S r   )rT   rU   r!   r   rX   intermediate_sizer   
isinstance
hidden_actstrr   intermediate_act_fnri   rk   r'   r(   rU   N  s
   
zAltRobertaIntermediate.__init__r   r   c                 C   s   |  |}| |}|S r   )r   r   r<   r   r'   r'   r(   ry   V  s   

zAltRobertaIntermediate.forwardr   r'   r'   rk   r(   r   M  s    r   c                       r   )AltRobertaOutputc                    sB   t    t|j|j| _tj|j|jd| _t	|j
| _d S r   )rT   rU   r!   r   r   rX   r   r_   r`   ra   rb   rc   ri   rk   r'   r(   rU   ^  r   zAltRobertaOutput.__init__r   r   r   c                 C   r   r   r   r   r'   r'   r(   ry   d  r   zAltRobertaOutput.forwardr   r'   r'   rk   r(   r   ]  r   r   c                       sb   e Zd Z fddZ			ddejdeej deej dee d	e	ej f
d
dZ
dd Z  ZS )AltRobertaLayerc                    s:   t    |j| _d| _t|| _t|| _t|| _	d S )Nr   )
rT   rU   chunk_size_feed_forwardseq_len_dimr   	attentionr   intermediater   r   ri   rk   r'   r(   rU   m  s   


zAltRobertaLayer.__init__NFr   r   r   r   r   c           
      K   sP   | j |f|||d|}|d }|dd  }t| j| j| j|}	|	f| }|S r   )r   r   feed_forward_chunkr   r   )
r<   r   r   r   r   kwargsself_attention_outputsr   r   layer_outputr'   r'   r(   ry   u  s    
zAltRobertaLayer.forwardc                 C   s   |  |}| ||}|S r   )r   r   )r<   r   intermediate_outputr   r'   r'   r(   r     s   
z"AltRobertaLayer.feed_forward_chunkr   )r@   rA   rB   rU   r$   r   r   rD   r   r>   ry   r   r|   r'   r'   rk   r(   r   l  s"    
r   c                       sz   e Zd Z fddZe					ddejdeej deej d	ee	 d
ee	 dee	 de
eej ef fddZ  ZS )AltRobertaEncoderc                    :   t     | _t fddt jD | _d| _d S )Nc                       g | ]}t  qS r'   )r   )r9   irj   r'   r(   
<listcomp>      z.AltRobertaEncoder.__init__.<locals>.<listcomp>F)	rT   rU   rj   r!   
ModuleListrangenum_hidden_layerslayergradient_checkpointingri   rk   r   r(   rU        
 
zAltRobertaEncoder.__init__NFTr   r   r   r   output_hidden_statesreturn_dictr   c                 K   s   |rdnd }|r
dnd }	t | jD ].\}
}|r||f }|d ur$||
 nd }|d||||d|}|d }|r?|	|d f }	q|rG||f }t|||	dS )Nr'   )r   r   r   r   r   r   last_hidden_stater   
attentions)	enumerater   r
   )r<   r   r   r   r   r   r   r   all_hidden_statesall_self_attentionsr   layer_modulelayer_head_masklayer_outputsr'   r'   r(   ry     s2   

zAltRobertaEncoder.forward)NNFFT)r@   rA   rB   rU   r   r$   r   r   rD   r   r   r>   r
   ry   r|   r'   r'   rk   r(   r     s.    	r   c                       r   )AltRobertaPoolerc                    s*   t    t|j|j| _t | _d S r   )rT   rU   r!   r   rX   r   Tanh
activationri   rk   r'   r(   rU     s   
zAltRobertaPooler.__init__r   r   c                 C   s(   |d d df }|  |}| |}|S )Nr   )r   r   )r<   r   first_token_tensorpooled_outputr'   r'   r(   ry     s   

zAltRobertaPooler.forwardr   r'   r'   rk   r(   r     s    r           moduler   r   r   r   scalingrc   c           
      K   s|   t ||dd| }|d ur|| }tjj|dt jd|j}tjj	||| j
d}t ||}	|	dd }	|	|fS )NrN   r   )r   rS   )ptrainingr   r   )r$   r   r   r!   r"   r   float32r   rS   rc   r  r   )
r  r   r   r   r   r  rc   r   attn_weightsattn_outputr'   r'   r(   eager_attention_forward  s   
r	  c                       sh   e Zd ZdZ fddZ			ddejdeej deej d	ee d
e	ejeej f f
ddZ
  ZS )AltCLIPAttentionz=Multi-headed attention from 'Attention Is All You Need' paperc                    s   t    || _|j| _|j| _| j| j | _| j| j | jkr-td| j d| j d| jd | _	|j
| _d| _t| j| j| _t| j| j| _t| j| j| _t| j| j| _d S )Nz;embed_dim must be divisible by num_heads (got `embed_dim`: z and `num_heads`: ).      F)rT   rU   rj   rX   	embed_dimr   	num_headshead_dimr   scaleattention_dropoutrc   	is_causalr!   r   k_projv_projq_projout_projri   rk   r'   r(   rU     s$   

zAltCLIPAttention.__init__NFr   r   causal_attention_maskr   r   c              
   C   s0  |j \}}}| |}| |}	| |}
|||| j| jdd}|	||| j| jdd}	|
||| j| jdd}
| jj	dkrY|durR|durR|| }n|durX|}n|du| _
t}| jj	dkrlt| jj	 }|| ||	|
|| j
| j| js{dn| jd\}}|||| }| |}|sd}||fS )z#Input shape: Batch x Time x Channelr   r   flash_attention_2Nr   r  )r  r  rc   )r   r  r  r  r   r  r  r   rj   r   r  r	  r   r  r  rc   reshaper   r  )r<   r   r   r  r   
batch_sizeru   r  queriesr?   valuesattention_interfacer  r  r'   r'   r(   ry     s@   	






zAltCLIPAttention.forwardr   )r@   rA   rB   rC   rU   r$   r   r   r   r>   ry   r|   r'   r'   rk   r(   r
    s"    r
  c                       r   )
AltCLIPMLPc                    sD   t    || _t|j | _t|j|j	| _
t|j	|j| _d S r   )rT   rU   rj   r   r   activation_fnr!   r   rX   r   fc1fc2ri   rk   r'   r(   rU   9  s
   
zAltCLIPMLP.__init__r   r   c                 C   s"   |  |}| |}| |}|S r   )r   r  r!  r   r'   r'   r(   ry   @  s   


zAltCLIPMLP.forwardr   r'   r'   rk   r(   r  8  s    r  c                       sT   e Zd Zdef fddZ	ddejdejdejdee d	e	ej
 f
d
dZ  ZS )AltCLIPEncoderLayerrj   c                    sR   t    |j| _t|| _tj| j|jd| _	t
|| _tj| j|jd| _d S r   )rT   rU   rX   r  r
  	self_attnr!   r_   r`   layer_norm1r  mlplayer_norm2ri   rk   r'   r(   rU   H  s   


zAltCLIPEncoderLayer.__init__Fr   r   r  r   r   c                 C   sd   |}|  |}| j||||d\}}|| }|}| |}| |}|| }|f}|r0||f7 }|S )aI  
        Args:
            hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
            attention_mask (`torch.FloatTensor`): attention mask of size
                `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
                `(config.encoder_attention_heads,)`.
            output_attentions (`bool`, *optional*):
                Whether or not to return the attentions tensors of all attention layers. See `attentions` under
                returned tensors for more detail.
        )r   r   r  r   )r$  r#  r&  r%  )r<   r   r   r  r   residualr  r   r'   r'   r(   ry   P  s"   




zAltCLIPEncoderLayer.forwardF)r@   rA   rB   r   rU   r$   r   r   r   r>   rD   ry   r|   r'   r'   rk   r(   r"  G  s    r"  c                       sx   e Zd ZdZdef fddZe					ddeej	 deej	 dee
 d	ee
 d
ee
 deeef fddZ  ZS )AltCLIPEncoderz
    Transformer encoder consisting of `config.num_hidden_layers` self attention layers. Each layer is a
    [`AltCLIPEncoderLayer`].

    Args:
        config: AltCLIPConfig
    rj   c                    r   )Nc                    r   r'   )r"  )r9   _r   r'   r(   r     r   z+AltCLIPEncoder.__init__.<locals>.<listcomp>F)	rT   rU   rj   r!   r   r   r   layersr   ri   rk   r   r(   rU     r   zAltCLIPEncoder.__init__Nr   r  r   r   r   r   c                 C   s   |dur|n| j j}|dur|n| j j}|dur|n| j j}|r"dnd}|r(dnd}|}	t| jD ] \}
}|r<||	f }||	|||d}|d }	|rQ||d f }q1|rY||	f }t|	||dS )a  
        Args:
            inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
                Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation.
                This is useful if you want more control over how to convert `input_ids` indices into associated vectors
                than the model's internal embedding lookup matrix.
            attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
                Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:

                - 1 for tokens that are **not masked**,
                - 0 for tokens that are **masked**.

                [What are attention masks?](../glossary#attention-mask)
            causal_attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
                Causal mask for the text model. Mask values selected in `[0, 1]`:

                - 1 for tokens that are **not masked**,
                - 0 for tokens that are **masked**.

                [What are attention masks?](../glossary#attention-mask)
            output_attentions (`bool`, *optional*):
                Whether or not to return the attentions tensors of all attention layers. See `attentions` under
                returned tensors for more detail.
            output_hidden_states (`bool`, *optional*):
                Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors
                for more detail.
            return_dict (`bool`, *optional*):
                Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
        Nr'   )r   r   r   r   )rj   r   r   use_return_dictr   r+  r
   )r<   rr   r   r  r   r   r   encoder_statesall_attentionsr   idxencoder_layerr   r'   r'   r(   ry     s2   '

zAltCLIPEncoder.forward)NNNNN)r@   rA   rB   rC   r   rU   r   r   r$   r   r   r   r>   r
   ry   r|   r'   r'   rk   r(   r)  y  s,    
r)  c                       sX   e Zd Zdef fddZdejdededejfdd	Zddej	dejfddZ
  ZS )AltCLIPVisionEmbeddingsrj   c                    s   t    || _|j| _|j| _|j| _tt	
| j| _tj|j| j| j| jdd| _| j| j d | _| jd | _t| j| j| _| jdt	| jddd d S )NF)in_channelsout_channelskernel_sizestridebiasr   r   rL   rM   rO   )rT   rU   rj   rX   r  
image_size
patch_sizer!   	Parameterr$   randnclass_embeddingConv2dnum_channelspatch_embeddingnum_patchesnum_positionsrV   position_embeddingrd   r%   re   ri   rk   r'   r(   rU     s"   
"z AltCLIPVisionEmbeddings.__init__rx   heightwidthr   c                 C   s  |j d d }| jjd}|j d d }tj s(||kr(||kr(| | jS |ddddf }|ddddf }|j d }	|| j }
|| j }t	|d }|
d|||	}|dddd}tjj||
|fdd	d
}|dddddd|	}tj||fddS )a   
        This method allows to interpolate the pre-trained position encodings, to be able to use the model on higher resolution
        images. This method is also adapted to support torch.jit tracing.

        Adapted from:
        - https://github.com/facebookresearch/dino/blob/de9ee3df6cf39fac952ab558447af1fa1365362a/vision_transformer.py#L174-L194, and
        - https://github.com/facebookresearch/dinov2/blob/e1277af2ba9496fbadf7aec6eba56e8d882d1e35/dinov2/models/vision_transformer.py#L179-L211
        r   r   NrN   g      ?r   r   bicubicF)rg   modealign_cornersr   )r   rA  weightrz   r$   jit
is_tracingrL   r8  r   r  r   r!   r"   interpolater   cat)r<   rx   rB  rC  r?  rA  r@  class_pos_embedpatch_pos_embedr   
new_height	new_widthsqrt_num_positionsr'   r'   r(   interpolate_pos_encoding  s*   



z0AltCLIPVisionEmbeddings.interpolate_pos_encodingFpixel_valuesc              
   C   s   |j \}}}}|s&|| jks|| jkr&td| d| d| j d| j d	| jjj}| |j|d}|ddd}| j	
|dd}	tj|	|gdd	}
|r[|
| |
|| }
|
S |
| | j }
|
S )
NzInput image size (*z) doesn't match model (r  rR   r   r   rN   r   )r   r7  r   r>  rG  rS   r   flattenr   r;  re   r$   rK  rQ  rA  rL   )r<   rR  rQ  r  r*  rB  rC  target_dtypepatch_embedsclass_embedsrx   r'   r'   r(   ry     s    
zAltCLIPVisionEmbeddings.forwardr(  )r@   rA   rB   r   rU   r$   r   r   rQ  rD   ry   r|   r'   r'   rk   r(   r1    s     )r1  c                   @   s*   e Zd ZU eed< dZdZg Zdd ZdS )AltCLIPPreTrainedModelrj   altclipTc                 C   s  | j j}t|tr:| j j}tjj|jd|jd | d tjj|j	j
|j j| d tjj|jj
|j j| d dS t|tr| j j}|jd d|j j d  | }|jd | }tjj|jj
|d tjj|jj
|d tjj|jj
|d tjj|jj
|d dS t|tr| j j}|j jd d|j j d  | }d|j j d | }tjj|jj
|d tjj|jj
|d dS t|trtjj|jj
|jd | j j d d|j_tjj|jj
|jd | j j d d|j_dS t|tjr|jj   |j
j!d dS t|tj"r%|j
jjd| j jd |jdur#|jj   dS dS t|tj#rH|j
jjd| j jd |j$durJ|j
j|j$    dS dS dS )	zInitialize the weightsr  r  )meanstd)r[  r   Tg      ?N)%rj   initializer_factorr   r1  r!   initnormal_r;  r  r>  rG  initializer_rangerA  r
  r   r  r  r  r  r  rX   r   r!  AltCLIPModeltext_projectiontext_embed_dim_is_hf_initializedvisual_projectionvision_embed_dimr_   r6  datazero_fill_r   rV   rG   )r<   r  factorin_proj_stdout_proj_stdfc_stdr'   r'   r(   _init_weights+  sZ   
 

 
z$AltCLIPPreTrainedModel._init_weightsN)	r@   rA   rB   r   rE   base_model_prefixsupports_gradient_checkpointing_no_split_modulerm  r'   r'   r'   r(   rX  $  s   
 rX  c                       sv   e Zd Zdef fddZee					ddeej	 dee
 dee
 d	ee
 d
ee
 deeef fddZ  ZS )AltCLIPVisionTransformerrj   c                    sR   t    || _|j}t|| _tj||jd| _	t
|| _tj||jd| _d S r   )rT   rU   rj   rX   r1  rx   r!   r_   r`   pre_layrnormr)  encoderpost_layernorm)r<   rj   r  rk   r'   r(   rU   Z  s   


z!AltCLIPVisionTransformer.__init__NFrR  r   r   r   rQ  r   c           
      C   s   |d ur|n| j j}|d ur|n| j j}|d ur|n| j j}|d u r&td| j||d}| |}| j|||dd}|d }|d d dd d f }	| |	}	t	||	|j
|jdS )Nz You have to specify pixel_values)rQ  T)rr   r   r   r   r   r   pooler_outputr   r   )rj   r   r   r,  r   rx   rr  rs  rt  r   r   r   )
r<   rR  r   r   r   rQ  r   encoder_outputsr   r   r'   r'   r(   ry   d  s.   


z AltCLIPVisionTransformer.forward)NNNNF)r@   rA   rB   r   rU   r   r   r   r$   rD   r   r   r>   r   ry   r|   r'   r'   rk   r(   rq  Y  s,    

rq  c                       s   e Zd ZU eed< dZdef fddZdejfddZ	e
						ddeej d
ee dee dedee deeef fddZ  ZS )AltCLIPVisionModelrj   rR  c                    s"   t  | t|| _|   d S r   )rT   rU   rq  vision_model	post_initri   rk   r'   r(   rU     s   
zAltCLIPVisionModel.__init__r   c                 C   
   | j jjS r   )ry  rx   r>  r;   r'   r'   r(   get_input_embeddings     
z'AltCLIPVisionModel.get_input_embeddingsNFr   r   rQ  r   c                 C   s(   |dur|n| j j}| j|||||dS )a  
        Examples:

        ```python
        >>> from PIL import Image
        >>> import requests
        >>> from transformers import AutoProcessor, AltCLIPVisionModel

        >>> model = AltCLIPVisionModel.from_pretrained("BAAI/AltCLIP")
        >>> processor = AutoProcessor.from_pretrained("BAAI/AltCLIP")

        >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
        >>> image = Image.open(requests.get(url, stream=True).raw)

        >>> inputs = processor(images=image, return_tensors="pt")

        >>> outputs = model(**inputs)
        >>> last_hidden_state = outputs.last_hidden_state
        >>> pooled_output = outputs.pooler_output  # pooled CLS states
        ```NrR  r   r   rQ  r   )rj   r,  ry  )r<   rR  r   r   rQ  r   r'   r'   r(   ry     s   zAltCLIPVisionModel.forward)NNNFN)r@   rA   rB   r   rE   main_input_namerU   r!   Moduler|  r   r   r$   rD   r   r   r>   r   ry   r|   r'   r'   rk   r(   rx    s0   
 
rx  aE  
    The model behaves as an encoder following the architecture described in *Attention is
    all you need*_ by Ashish Vaswani, Noam Shazeer, Niki Parmar, Jakob Uszkoreit, Llion Jones, Aidan N. Gomez, Lukasz
    Kaiser and Illia Polosukhin.

    .. _*Attention is all you need*: https://huggingface.co/papers/1706.03762
    )custom_introc                       s   e Zd ZU eed< d fdd	Zdd Zdd Zd	d
 Ze										dde
ej de
ej de
ej de
ej de
ej de
ej de
e de
e de
e deeej ef fddZ  ZS )AltRobertaModelrj   Tc                    sD   t  | || _t|| _t|| _|rt|nd| _| 	  dS )zv
        add_pooling_layer (bool, *optional*, defaults to `True`):
            Whether to add a pooling layer
        N)
rT   rU   rj   rF   rx   r   rs  r   poolerrz  )r<   rj   add_pooling_layerrk   r'   r(   rU     s   

zAltRobertaModel.__init__c                 C   s   | j jS r   rx   rZ   r;   r'   r'   r(   r|    s   z$AltRobertaModel.get_input_embeddingsc                 C   s   || j _d S r   r  r<   r   r'   r'   r(   set_input_embeddings     z$AltRobertaModel.set_input_embeddingsc                 C   s*   |  D ]\}}| jj| j| qdS )z
        Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base
        class PreTrainedModel
        N)itemsrs  r   r   r   )r<   heads_to_pruner   r   r'   r'   r(   _prune_heads  s   zAltRobertaModel._prune_headsNrq   r   rQ   rL   r   rr   r   r   r   r   c
                 C   s  |d ur|n| j j}|d ur|n| j j}|	d ur|	n| j j}	|d ur*|d ur*td|d ur9| || | }
n|d urF| d d }
ntd|
\}}|d urU|jn|j}|d u retj	||f|d}|d u rt
| jdr| jjd d d |f }|||}|}n	tj|
tj|d}| ||
}| || j j}| j||||d}| j|||||dd	}|d
 }| jd ur| |nd }t|||j|jdS )NzDYou cannot specify both input_ids and inputs_embeds at the same timerN   z5You have to specify either input_ids or inputs_embedsr   rQ   rm   )rq   rL   rQ   rr   T)r   r   r   r   r   r   ru  )rj   r   r   r,  r   %warn_if_padding_and_no_attention_maskrg   r    r$   onesrp   rx   rQ   re   rf   rh   get_extended_attention_maskget_head_maskr   rs  r  r   r   r   )r<   rq   r   rQ   rL   r   rr   r   r   r   rt   r  ru   r    rv   rw   extended_attention_maskembedding_outputrw  sequence_outputr   r'   r'   r(   ry     s\   
zAltRobertaModel.forward)T	NNNNNNNNN)r@   rA   rB   r   rE   rU   r|  r  r  r   r   r$   r   r   r   r>   r   ry   r|   r'   r'   rk   r(   r    sJ   
 
	
r  c                       s   e Zd ZU eed<  fddZdejfddZdej	ddfd	d
Z
ddee dej	f fddZee									ddeej deej deej deej deej deej dee dee dee deeef fddZ  ZS )AltCLIPTextModelrj   c                    sL   t  | t|dd| _t|j|j| _tj	|j|j
d| _|   d S )NF)r  rH   )rT   rU   r  robertar!   r   rX   project_dimtransformationr_   r`   pre_LNrz  ri   rk   r'   r(   rU   <  s
   zAltCLIPTextModel.__init__r   c                 C   r{  r   r  rx   rZ   r;   r'   r'   r(   r|  C  r}  z%AltCLIPTextModel.get_input_embeddingsr   Nc                 C   s   || j j_d S r   r  r  r'   r'   r(   r  F  s   z%AltCLIPTextModel.set_input_embeddingsnew_num_tokensc                    s   t  |S r   )rT   resize_token_embeddings)r<   r  rk   r'   r(   r  I  r  z(AltCLIPTextModel.resize_token_embeddingsrq   r   rQ   rL   r   rr   r   r   r   c
                 C   sp   |dur|n| j j}| j||||||||	dd	}
|
d }| |}| |}|dddf }t|||
j|
jdS )a+  
        Examples:

        ```python
        >>> from transformers import AutoProcessor, AltCLIPTextModel

        >>> model = AltCLIPTextModel.from_pretrained("BAAI/AltCLIP")
        >>> processor = AutoProcessor.from_pretrained("BAAI/AltCLIP")

        >>> texts = ["it's a cat", "it's a dog"]

        >>> inputs = processor(text=texts, padding=True, return_tensors="pt")

        >>> outputs = model(**inputs)
        >>> last_hidden_state = outputs.last_hidden_state
        >>> pooled_output = outputs.pooler_output  # pooled CLS states
        ```NT)	rq   r   rQ   rL   r   rr   r   r   r   r   ru  )rj   r,  r  r  r  r   r   r   )r<   rq   r   rQ   rL   r   rr   r   r   r   r   r  projection_staterv  r'   r'   r(   ry   L  s,    

zAltCLIPTextModel.forwardr   r  )r@   rA   rB   r   rE   rU   r!   r  r|  rV   r  r   r   r  r   r   r$   r   r   r   r>   r   ry   r|   r'   r'   rk   r(   r  9  sL   
 	

r  c                       s   e Zd ZU eed< def fddZe e			ddej	de
ej	 de
ej	 de
ej	 d	ejf
d
dZe e	ddejded	ejfddZe										dde
ej de
ej de
ej	 de
ej de
ej	 de
e de
e de
e dede
e d	eeef fddZ  ZS )r`  rj   c                    s   t  | t|jtstdt|j dt|jts(tdt|j d|j}|j}|j	|_	|j
| _
|j| _|j| _t|| _t|| _tj| j| j
dd| _tj| j| j
dd| _tt| jj| _|   d S )NzRconfig.vision_config is expected to be of type AltCLIPVisionConfig but is of type .zNconfig.text_config is expected to be of type AltCLIPTextConfig but is of type F)r6  )rT   rU   r   vision_configr   	TypeErrortypetext_configr   r   projection_dimr  rb  rX   re  r  
text_modelrq  ry  r!   r   rd  ra  r9  r$   tensorrj   logit_scale_init_valuelogit_scalerz  )r<   rj   r  r  rk   r'   r(   rU     s2   

zAltCLIPModel.__init__Nrq   r   rL   rQ   r   c                 C   s&   | j ||||d}|j}| |}|S )a  
        Returns:
            text_features (`torch.FloatTensor` of shape `(batch_size, output_dim`): The text embeddings obtained by
            applying the projection layer to the pooled output of [`AltCLIPTextModel`].

        Examples:

        ```python
        >>> import torch
        >>> from transformers import AutoProcessor, AltCLIPModel

        >>> model = AltCLIPModel.from_pretrained("BAAI/AltCLIP")
        >>> processor = AutoProcessor.from_pretrained("BAAI/AltCLIP")

        >>> inputs = processor(text=["a photo of a cat", "a photo of a dog"], padding=True, return_tensors="pt")
        >>> with torch.inference_mode():
        ...     text_features = model.get_text_features(**inputs)
        ```)rq   r   rL   rQ   )r  rv  ra  )r<   rq   r   rL   rQ   text_outputsr   text_featuresr'   r'   r(   get_text_features  s   
zAltCLIPModel.get_text_featuresFrR  rQ  c                 C   s"   | j ||d}|j}| |}|S )aQ  
        Returns:
            image_features (`torch.FloatTensor` of shape `(batch_size, output_dim`): The image embeddings obtained by
            applying the projection layer to the pooled output of [`AltCLIPVisionModel`].

        Examples:

        ```python
        >>> import torch
        >>> from transformers import AutoProcessor, AltCLIPModel
        >>> from transformers.image_utils import load_image

        >>> model = AltCLIPModel.from_pretrained("BAAI/AltCLIP")
        >>> processor = AutoProcessor.from_pretrained("BAAI/AltCLIP")

        >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
        >>> image = load_image(url)

        >>> inputs = processor(images=image, return_tensors="pt")
        >>> with torch.inference_mode():
        ...     image_features = model.get_image_features(**inputs)
        ```)rR  rQ  )ry  rv  rd  )r<   rR  rQ  vision_outputsr   image_featuresr'   r'   r(   get_image_features  s   
zAltCLIPModel.get_image_featuresreturn_lossr   r   r   c              	   C   s(  |dur|n| j j}|dur|n| j j}|
dur|
n| j j}
| j|||||||
d}| j||||	|
d}|d }| |}|d }| |}||jdddd }||jdddd }| j	
 }t|| | }|j}d}|rtt|}|
s||||||f}|dur|f| S |S t|||||||d	S )
a  
        return_loss (`bool`, *optional*):
            Whether or not to return the contrastive loss.

        Examples:

        ```python
        >>> from PIL import Image
        >>> import requests
        >>> from transformers import AutoProcessor, AltCLIPModel

        >>> model = AltCLIPModel.from_pretrained("BAAI/AltCLIP")
        >>> processor = AutoProcessor.from_pretrained("BAAI/AltCLIP")
        >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
        >>> image = Image.open(requests.get(url, stream=True).raw)
        >>> inputs = processor(
        ...     text=["a photo of a cat", "a photo of a dog"], images=image, return_tensors="pt", padding=True
        ... )
        >>> outputs = model(**inputs)
        >>> logits_per_image = outputs.logits_per_image  # this is the image-text similarity score
        >>> probs = logits_per_image.softmax(dim=1)  # we can take the softmax to get the label probabilities
        ```N)rq   r   rQ   rL   r   r   r   r~  r   r   rN   T)r  r   keepdim)r0   r1   r2   r3   r4   r5   r6   )rj   r   r   r,  r  ry  rd  ra  normr  expr$   r   r+   Tr.   r/   )r<   rq   rR  r   rL   rQ   r  r   r   rQ  r   r  r  r4   r3   r  r2   r1   r0   r   r'   r'   r(   ry     sX   %



zAltCLIPModel.forward)NNNr(  )
NNNNNNNNFN)r@   rA   rB   r   rE   rU   r   r   r$   r   r   rD   r  r   r  
LongTensorr   r>   r/   ry   r|   r'   r'   rk   r(   r`    s~   
 !$$	

r`  c                 C   s6   |  | }tj|dd|| | }| | S )a  
    Replace non-padding symbols with their position numbers. Position numbers begin at padding_idx+1. Padding symbols
    are ignored. This is modified from fairseq's `utils.make_positions`.

    Args:
        x: torch.Tensor x:

    Returns: torch.Tensor
    r   r   )ner   r$   cumsumtype_asrh   )rq   rG   rs   maskincremental_indicesr'   r'   r(   rn   \  s   rn   )rX  rx  r  r`  )r  )r   )HrC   r   dataclassesr   typingr   r   r   r   r$   torch.nnr!   activationsr   modeling_layersr	   modeling_outputsr
   r   r   r   modeling_utilsr   r   pytorch_utilsr   r   r   utilsr   r   r   r   r   r   configuration_altclipr   r   r   
get_loggerr@   loggerr   r)   r.   r/   r  rF   r}   r   r   r   r   r   r   r   r   floatr	  r
  r  r"  r)  r1  rX  rq  rx  r  r  r`  rn   __all__r'   r'   r'   r(   <module>   s    
$YW.)2
J2XS445	nS 
Q