from collections.abc import Callable
from dataclasses import dataclass

import torch
from torch import nn

from ... import initialization as init
from ...activations import ACT2FN
from ...modeling_layers import GradientCheckpointingLayer
from ...modeling_outputs import BaseModelOutput, ImageClassifierOutput
from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
from ...utils import ModelOutput, auto_docstring, can_return_tuple, logging
from .configuration_vjepa2 import VJEPA2Config


logger = logging.get_logger(__name__)


@dataclass
@auto_docstring(
    custom_intro="""
    VJEPA Predictor outputs that also contain the masked encoder outputs
    """
)
class VJEPA2WithMaskedInputPredictorOutput(ModelOutput):
    r"""
    masked_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*, returned when `context_mask` is provided which is applied on VJEPA2Encoder outputs):
        The masked hidden state of the model.
    target_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*, returned when `target_mask` is provided which is applied on VJEPA2Encoder outputs):
        The target hidden state of the model.
    """

    last_hidden_state: torch.FloatTensor
    masked_hidden_state: torch.FloatTensor | None = None
    hidden_states: tuple[torch.FloatTensor, ...] | None = None
    attentions: tuple[torch.FloatTensor, ...] | None = None
    target_hidden_state: torch.FloatTensor | None = None


@dataclass
@auto_docstring(
    custom_intro="""
    VJEPA outputs that also contain the masked encoder outputs
    Optionally contains the predictor outputs
    """
)
class VJEPA2WithMaskedInputModelOutput(ModelOutput):
    r"""
    masked_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*, returned when `context_mask` is provided which is applied on VJEPA2Encoder outputs):
        The masked hidden state of the model.
    predictor_output (`VJEPA2WithMaskedInputPredictorOutput`, *optional*):
        The output from the Predictor module.
    """

    last_hidden_state: torch.FloatTensor
    masked_hidden_state: torch.FloatTensor | None = None
    hidden_states: tuple[torch.FloatTensor, ...] | None = None
    attentions: tuple[torch.FloatTensor, ...] | None = None
    predictor_output: VJEPA2WithMaskedInputPredictorOutput | None = None

    def to_tuple(self):
        output = list(super().to_tuple())
        if isinstance(output[-1], VJEPA2WithMaskedInputPredictorOutput):
            # flatten the nested predictor output as well
            output[-1] = output[-1].to_tuple()
        return tuple(output)


class VJEPA2PatchEmbeddings3D(nn.Module):
    """
    Image to Patch Embedding
    """

    def __init__(
        self,
        config: VJEPA2Config,
        hidden_size: int = 1024,
    ):
        super().__init__()
        self.patch_size = config.patch_size
        self.tubelet_size = config.tubelet_size
        self.hidden_size = hidden_size

        self.proj = nn.Conv3d(
            in_channels=config.in_chans,
            out_channels=hidden_size,
            kernel_size=(config.tubelet_size, config.patch_size, config.patch_size),
            stride=(config.tubelet_size, config.patch_size, config.patch_size),
        )

    @staticmethod
    def num_patches(config):
        return (
            (config.frames_per_clip // config.tubelet_size)
            * (config.crop_size // config.patch_size)
            * (config.crop_size // config.patch_size)
        )

    def forward(self, pixel_values_videos: torch.Tensor) -> torch.Tensor:
        x = self.proj(pixel_values_videos).flatten(2).transpose(1, 2)
        return x


class VJEPA2Embeddings(nn.Module):
    """
    Construct mask token, position and patch embeddings.
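
    Shape sketch (illustrative, assuming `tubelet_size=2` and `patch_size=16`): a clip of shape
    `(batch_size, 16, 3, 224, 224)` is patchified into `(batch_size, 8 * 14 * 14, hidden_size)`.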
    """

    def __init__(self, config: VJEPA2Config, hidden_size: int = 1024):
        super().__init__()

        self.config = config
        self.hidden_size = hidden_size
        self.patch_embeddings = VJEPA2PatchEmbeddings3D(config, hidden_size=hidden_size)

        self.num_patches = self.patch_embeddings.num_patches
        self.patch_size = config.patch_size

    def forward(self, pixel_values_videos: torch.Tensor) -> torch.Tensor:
        num_frames = pixel_values_videos.shape[1]

        # Swap `frames` and `channels` dims; the result is
        # (batch_size, channels, num_frames, height, width)
        pixel_values_videos = pixel_values_videos.permute(0, 2, 1, 3, 4)

        # If the input clip has fewer frames than one tubelet, duplicate frames so the
        # 3D patchification below can still be applied.
        if num_frames < self.config.tubelet_size:
            pixel_values_videos = pixel_values_videos.repeat(1, 1, self.config.tubelet_size, 1, 1)

        target_dtype = self.patch_embeddings.proj.weight.dtype
        pixel_values_videos = pixel_values_videos.to(dtype=target_dtype)
        embeddings = self.patch_embeddings(pixel_values_videos)

        return embeddings


def eager_attention_forward(
    module: nn.Module,
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    attention_mask: torch.Tensor | None,
    scaling: float,
    dropout: float = 0.0,
    **kwargs,
):
    # Scaled dot-product attention computed explicitly (used when `_attn_implementation == "eager"`)
    attn_weights = torch.matmul(query, key.transpose(-1, -2)) * scaling
    attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
    attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)

    attn_output = torch.matmul(attn_weights, value)
    attn_output = attn_output.transpose(1, 2).contiguous()

    return attn_output, attn_weights


def rotate_queries_or_keys(x, pos):
    B, num_heads, N, D = x.size()

    # Compute the rotation angle for each position (the inverse frequencies are recomputed
    # on every call rather than cached).
    omega = torch.arange(D // 2, dtype=x.dtype, device=x.device)
    omega /= D / 2.0
    omega = 1.0 / 10000**omega  # (D/2,)
    freq = pos.unsqueeze(-1) * omega  # (..., N, D/2), outer product

    # Build the rotation terms
    emb_sin = freq.sin()  # (..., N, D/2)
    emb_cos = freq.cos()  # (..., N, D/2)
    emb_sin = emb_sin.repeat(1, 1, 1, 2)
    emb_cos = emb_cos.repeat(1, 1, 1, 2)

    # Rotate channel pairs: (y1, y2) -> (-y2, y1)
    y = x.unflatten(-1, (-1, 2))
    y1, y2 = y.unbind(dim=-1)
    y = torch.stack((-y2, y1), dim=-1)
    y = y.flatten(-2)
    return (x * emb_cos) + (y * emb_sin)


class VJEPA2RopeAttention(nn.Module):
    def __init__(
        self,
        config: VJEPA2Config,
        hidden_size: int = 1024,
        num_attention_heads: int = 16,
    ):
        super().__init__()
        self.config = config
        self.hidden_size = hidden_size
        self.num_attention_heads = num_attention_heads
        if hidden_size % num_attention_heads != 0:
            raise ValueError(
                f"The hidden size {hidden_size} is not a multiple of the number of attention heads {num_attention_heads}."
            )

        self.attention_head_size = int(hidden_size / num_attention_heads)
        self.all_head_size = self.num_attention_heads * self.attention_head_size

        self.query = nn.Linear(hidden_size, self.all_head_size, bias=config.qkv_bias)
        self.key = nn.Linear(hidden_size, self.all_head_size, bias=config.qkv_bias)
        self.value = nn.Linear(hidden_size, self.all_head_size, bias=config.qkv_bias)

        self.proj = nn.Linear(hidden_size, hidden_size)
        self.dropout_prob = config.attention_probs_dropout_prob
        self.dropout = nn.Dropout(self.dropout_prob)

        self.grid_size = self.config.crop_size // self.config.patch_size
        self.grid_depth = self.config.frames_per_clip // self.config.tubelet_size

        # The head dimension is split into three rotary sub-spaces: depth (time), height and width
        self.d_dim = int(2 * ((self.attention_head_size // 3) // 2))
        self.h_dim = int(2 * ((self.attention_head_size // 3) // 2))
        self.w_dim = int(2 * ((self.attention_head_size // 3) // 2))

        self.scaling = self.attention_head_size**-0.5
        self.is_causal = False

    def _get_frame_pos(self, ids):
        tokens_per_frame = int(self.grid_size * self.grid_size)
        return ids // tokens_per_frame

    def _get_height_pos(self, ids):
        # Remove the frame component from the ids
        tokens_per_frame = int(self.grid_size * self.grid_size)
        frame_ids = self._get_frame_pos(ids)
        ids = ids - tokens_per_frame * frame_ids
        tokens_per_row = self.grid_size
        return ids // tokens_per_row

    def get_position_ids(self, x, masks=None):
        device = x.device
        token_size = x.size(1)

        # Note: when masks is None, a 1d position id range is used instead of B x N ids
        if masks is not None:
            ids = masks.unsqueeze(1).repeat(1, self.num_attention_heads, 1)
        else:
            ids = torch.arange(token_size, device=device)
        tokens_per_frame = int(self.grid_size * self.grid_size)
        frame_ids = self._get_frame_pos(ids)

        tokens_per_row = self.grid_size
        height_ids = self._get_height_pos(ids)

        # Remove the frame component (1st term) and the height component (2nd term) from the ids
        width_ids = (ids - tokens_per_frame * frame_ids) - tokens_per_row * height_ids
        return frame_ids, height_ids, width_ids

    def apply_rotary_embeddings(self, qk, pos_ids):
        d_mask, h_mask, w_mask = pos_ids
        s = 0
        qkd = rotate_queries_or_keys(qk[..., s : s + self.d_dim], pos=d_mask)
        s += self.d_dim
        qkh = rotate_queries_or_keys(qk[..., s : s + self.h_dim], pos=h_mask)
        s += self.h_dim
        qkw = rotate_queries_or_keys(qk[..., s : s + self.w_dim], pos=w_mask)
        s += self.w_dim

        # Combine the rotated sub-spaces with any leftover (unrotated) channels
        if s < self.attention_head_size:
            qkr = qk[..., s:]
            qk = torch.cat([qkd, qkh, qkw, qkr], dim=-1)
        else:
            qk = torch.cat([qkd, qkh, qkw], dim=-1)
        return qk

    def forward(
        self,
        hidden_states,
        position_mask: torch.Tensor | None = None,
        output_attentions: bool = False,
    ) -> tuple[torch.Tensor, torch.Tensor] | tuple[torch.Tensor]:
        batch_size, seq_length, _ = hidden_states.shape
        query_layer = (
            self.query(hidden_states).view(batch_size, seq_length, self.num_attention_heads, -1).transpose(1, 2)
        )
        key_layer = self.key(hidden_states).view(batch_size, seq_length, self.num_attention_heads, -1).transpose(1, 2)
        value_layer = (
            self.value(hidden_states).view(batch_size, seq_length, self.num_attention_heads, -1).transpose(1, 2)
        )

        pos_ids = self.get_position_ids(hidden_states, masks=position_mask)
        query_layer = self.apply_rotary_embeddings(query_layer, pos_ids)
        key_layer = self.apply_rotary_embeddings(key_layer, pos_ids)

        attention_interface: Callable = eager_attention_forward
        if self.config._attn_implementation != "eager":
            attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]

        context_layer, attention_probs = attention_interface(
            self,
            query_layer,
            key_layer,
            value_layer,
            None,
            is_causal=self.is_causal,
            scaling=self.scaling,
            dropout=0.0 if not self.training else self.dropout_prob,
        )

        new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
        context_layer = self.proj(context_layer.reshape(new_context_layer_shape))

        outputs = (context_layer, attention_probs) if output_attentions else (context_layer,)

        return outputs


def drop_path(input: torch.Tensor, drop_prob: float = 0.0, training: bool = False) -> torch.Tensor:
    """
    Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
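
    Each sample's residual branch is zeroed with probability `drop_prob` and scaled by `1 / keep_prob`
    (with `keep_prob = 1 - drop_prob`) otherwise, so the expected value of the output is unchanged.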
    """
    if drop_prob == 0.0 or not training:
        return input
    keep_prob = 1 - drop_prob
    shape = (input.shape[0],) + (1,) * (input.ndim - 1)  # broadcast over all but the batch dimension
    random_tensor = keep_prob + torch.rand(shape, dtype=input.dtype, device=input.device)
    random_tensor.floor_()  # binarize
    output = input.div(keep_prob) * random_tensor
    return output


class VJEPA2DropPath(nn.Module):
    """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks)."""

    def __init__(self, drop_prob: float | None = None) -> None:
        super().__init__()
        self.drop_prob = drop_prob

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        return drop_path(hidden_states, self.drop_prob, self.training)

    def extra_repr(self) -> str:
        return f"p={self.drop_prob}"


class VJEPA2MLP(nn.Module):
    def __init__(self, config: VJEPA2Config, hidden_size: int = 1024, mlp_ratio: float = 4.0):
        super().__init__()
        in_features = out_features = hidden_size
        hidden_features = int(hidden_size * mlp_ratio)
        self.fc1 = nn.Linear(in_features, hidden_features, bias=True)
        self.activation = ACT2FN[config.hidden_act]
        self.fc2 = nn.Linear(hidden_features, out_features, bias=True)

    def forward(self, hidden_state: torch.Tensor) -> torch.Tensor:
        hidden_state = self.fc1(hidden_state)
        hidden_state = self.activation(hidden_state)
        hidden_state = self.fc2(hidden_state)
        return hidden_state


class VJEPA2Layer(GradientCheckpointingLayer):
    """This corresponds to the Block class in the original implementation."""

    def __init__(
        self,
        config: VJEPA2Config,
        drop_path_rate: float = 0.0,
        hidden_size: int = 1024,
        num_attention_heads: int = 16,
        mlp_ratio: float = 4.0,
    ):
        super().__init__()
        self.config = config
        self.hidden_size = hidden_size
        self.num_attention_heads = num_attention_heads
        self.mlp_ratio = mlp_ratio

        self.norm1 = nn.LayerNorm(hidden_size, eps=config.layer_norm_eps)
        self.attention = VJEPA2RopeAttention(config, hidden_size, num_attention_heads)
        self.drop_path = VJEPA2DropPath(drop_path_rate) if config.drop_path_rate > 0.0 else nn.Identity()
        self.norm2 = nn.LayerNorm(hidden_size, eps=config.layer_norm_eps)
        self.mlp = VJEPA2MLP(config, hidden_size=hidden_size, mlp_ratio=mlp_ratio)

    def forward(
        self,
        hidden_states: torch.Tensor,
        position_mask: torch.Tensor | None = None,
        output_attentions: bool = False,
    ) -> tuple[torch.Tensor, ...]:
        # Self-attention with a pre-norm residual connection
        residual = hidden_states
        hidden_states = self.norm1(hidden_states)
        self_attention_outputs = self.attention(
            hidden_states,
            position_mask=position_mask,
            output_attentions=output_attentions,
        )
        attention_output = self_attention_outputs[0]
        hidden_states = self.drop_path(attention_output) + residual

        # MLP with a pre-norm residual connection
        residual = hidden_states
        hidden_states = self.norm2(hidden_states)
        hidden_states = self.mlp(hidden_states)
        hidden_states = self.drop_path(hidden_states) + residual

        # Add self-attention weights if we output them
        outputs = self_attention_outputs[1:]
        outputs = (hidden_states,) + outputs

        return outputs


class VJEPA2Encoder(nn.Module):
    def __init__(self, config: VJEPA2Config):
        super().__init__()
        self.config = config

        self.embeddings = VJEPA2Embeddings(config, hidden_size=config.hidden_size)
        drop_path_rates = [
            (config.drop_path_rate * i / (config.num_hidden_layers - 1) if config.num_hidden_layers > 1 else 0.0)
            for i in range(config.num_hidden_layers)
        ]
        self.layer = nn.ModuleList(
            [
                VJEPA2Layer(
                    config,
                    drop_path_rate=drop_path_rates[i],
                    hidden_size=config.hidden_size,
                    num_attention_heads=config.num_attention_heads,
                    mlp_ratio=config.mlp_ratio,
                )
                for i in range(config.num_hidden_layers)
            ]
        )
        self.layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
        self.gradient_checkpointing = False

    @can_return_tuple
    def forward(
        self,
        pixel_values_videos: torch.Tensor | None = None,
        output_attentions: bool = False,
        output_hidden_states: bool = False,
        **kwargs,
    ) -> BaseModelOutput:
        all_hidden_states = () if output_hidden_states else None
        all_self_attentions = () if output_attentions else None

        hidden_states = self.embeddings(pixel_values_videos)

        for i, layer_module in enumerate(self.layer):
            if output_hidden_states:
                all_hidden_states = all_hidden_states + (hidden_states,)
            layer_outputs = layer_module(hidden_states, None, output_attentions)
            hidden_states = layer_outputs[0]
            if output_attentions:
                all_self_attentions = all_self_attentions + (layer_outputs[1],)

        hidden_states = self.layernorm(hidden_states)
        if output_hidden_states:
            all_hidden_states = all_hidden_states + (hidden_states,)

        return BaseModelOutput(
            last_hidden_state=hidden_states,
            hidden_states=all_hidden_states,
            attentions=all_self_attentions,
        )


def apply_masks(tensor: torch.Tensor, masks: list[torch.Tensor]) -> torch.Tensor:
    """
    Args:
        tensor (`torch.Tensor`):
            Tensor of shape [batch_size, num_patches, feature_dim]
        masks (`List[torch.Tensor]`):
            List of tensors of shape [batch_size, num_patches] containing indices of patches to keep
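
    Example (illustrative): for `tensor` of shape [2, 4, 8] and `masks = [tensor([[0, 2], [1, 3]])]`, the
    result has shape [2, 2, 8], keeping patches 0 and 2 of sample 0 and patches 1 and 3 of sample 1.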
    """
    all_masked_tensors = []
    for mask in masks:
        mask = mask.to(tensor.device)
        # Expand each index vector to the feature dimension so it can drive `torch.gather`
        mask_keep = mask.unsqueeze(-1).repeat(1, 1, tensor.size(-1))
        all_masked_tensors += [torch.gather(tensor, dim=1, index=mask_keep)]

    return torch.cat(all_masked_tensors, dim=0)


class VJEPA2PredictorEmbeddings(nn.Module):
    """
    Construct mask token, position and patch embeddings.
    """

    def __init__(self, config: VJEPA2Config):
        super().__init__()

        self.config = config
        self.predictor_embeddings = nn.Linear(config.hidden_size, config.pred_hidden_size)
        self.num_mask_tokens = 0
        self.zero_init_mask_tokens = config.pred_zero_init_mask_tokens
        self.num_mask_tokens = config.pred_num_mask_tokens
        self.mask_tokens = nn.Parameter(torch.zeros(self.num_mask_tokens, 1, 1, config.pred_hidden_size))

        self.patch_size = config.patch_size
        self.config = config

    @staticmethod
    def num_patches(config):
        if config.frames_per_clip > 1:
            return (
                (config.frames_per_clip // config.tubelet_size)
                * (config.crop_size // config.patch_size)
                * (config.crop_size // config.patch_size)
            )
        else:
            return (config.crop_size // config.patch_size) * (config.crop_size // config.patch_size)

    def forward(
        self,
        hidden_states: torch.Tensor,
        context_mask: list[torch.Tensor],
        target_mask: list[torch.Tensor],
        mask_index: int = 1,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """
        hidden_states : encoder outputs (context)
        context_mask: tokens of the context (outputs from the encoder)
        target_mask: tokens to predict
        mask_index: index of the target mask to choose (useful for multiclip?)
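
        Returns a `(embeddings, masks)` pair: `embeddings` concatenates the projected context tokens and the
        selected mask tokens along the sequence dimension, while `masks` carries their position ids.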
        """
        B = hidden_states.size(0)
        context = self.predictor_embeddings(hidden_states)

        # Make target tokens
        mask_index = mask_index % self.num_mask_tokens
        target = self.mask_tokens[mask_index]
        # Use the provided target mask to size the target tokens: this stays correct even when the
        # predictor is run for more tokens than `config.frames_per_clip` would imply.
        max_patch_num = target_mask[0].max() + 1  # one extra to include the last patch
        target = target.repeat(B, max_patch_num, 1)
        target = apply_masks(target, target_mask)

        # Concatenate context & target tokens
        context = context.repeat(len(context_mask), 1, 1)
        embeddings = torch.cat([context, target], dim=1)

        # Positions of context & target tokens
        cm = torch.cat(context_mask, dim=0)
        tm = torch.cat(target_mask, dim=0)
        masks = torch.cat([cm, tm], dim=1)

        return embeddings, masks


class VJEPA2Predictor(nn.Module):
    def __init__(self, config: VJEPA2Config):
        super().__init__()
        self.config = config
        self.gradient_checkpointing = False
        self.embeddings = VJEPA2PredictorEmbeddings(config)
        drop_path_rates = [
            (
                config.drop_path_rate * i / (config.pred_num_hidden_layers - 1)
                if config.pred_num_hidden_layers > 1
                else 0.0
            )
            for i in range(config.pred_num_hidden_layers)
        ]
        self.layer = nn.ModuleList(
            [
                VJEPA2Layer(
                    config,
                    drop_path_rate=drop_path_rates[i],
                    hidden_size=config.pred_hidden_size,
                    num_attention_heads=config.pred_num_attention_heads,
                    mlp_ratio=config.pred_mlp_ratio,
                )
                for i in range(config.pred_num_hidden_layers)
            ]
        )
        self.layernorm = nn.LayerNorm(config.pred_hidden_size, eps=config.layer_norm_eps)
        self.proj = nn.Linear(config.pred_hidden_size, config.hidden_size, bias=True)

    def sort_tokens(self, hidden_states, position_masks, argsort):
        # gather position masks
        argsort = argsort.to(position_masks.device)
        position_masks = torch.gather(position_masks, dim=1, index=argsort)

        # gather hidden states
        argsort = argsort.to(hidden_states.device)
        hidden_states_argsort = argsort.unsqueeze(-1).expand(-1, -1, hidden_states.size(-1))
        hidden_states = torch.gather(hidden_states, dim=1, index=hidden_states_argsort)

        return hidden_states, position_masks

    def unsort_tokens(self, hidden_states, argsort):
        argsort = argsort.to(hidden_states.device)
        reverse_argsort = torch.argsort(argsort, dim=1)
        reverse_argsort = reverse_argsort.unsqueeze(-1).expand(-1, -1, hidden_states.size(-1))
        hidden_states = torch.gather(hidden_states, dim=1, index=reverse_argsort)
        return hidden_states

    @can_return_tuple
    def forward(
        self,
        encoder_hidden_states: torch.Tensor,
        context_mask: list[torch.Tensor],
        target_mask: list[torch.Tensor],
        output_attentions: bool = False,
        output_hidden_states: bool = False,
        **kwargs,
    ) -> BaseModelOutput:
        all_hidden_states = () if output_hidden_states else None
        all_self_attentions = () if output_attentions else None

        # mask out the encoder hidden states
        encoder_hidden_states = apply_masks(encoder_hidden_states, context_mask)
        _, N_ctxt, D = encoder_hidden_states.shape
        hidden_states, position_masks = self.embeddings(encoder_hidden_states, context_mask, target_mask)

        # put tokens in sorted order
        argsort = torch.argsort(position_masks, dim=1)  # [B, N]
        hidden_states, position_masks = self.sort_tokens(hidden_states, position_masks, argsort)

        for i, layer_module in enumerate(self.layer):
            if output_hidden_states:
                all_hidden_states = all_hidden_states + (hidden_states,)
            layer_outputs = layer_module(hidden_states, position_masks, output_attentions)
            hidden_states = layer_outputs[0]
            if output_attentions:
                all_self_attentions = all_self_attentions + (layer_outputs[1],)

        if output_hidden_states:
            all_hidden_states = all_hidden_states + (hidden_states,)

        hidden_states = self.layernorm(hidden_states)
        # unsort the tokens and keep only the predicted (target) part
        hidden_states = self.unsort_tokens(hidden_states, argsort)
        hidden_states = hidden_states[:, N_ctxt:]
        # project back to the encoder hidden size
        hidden_states = self.proj(hidden_states)

        return BaseModelOutput(
            last_hidden_state=hidden_states,
            hidden_states=all_hidden_states,
            attentions=all_self_attentions,
        )


class VJEPA2PoolerSelfAttention(nn.Module):
    """Multi-headed attention from 'Attention Is All You Need' paper"""

    def __init__(self, config: VJEPA2Config):
        super().__init__()
        self.config = config
        self.embed_dim = config.hidden_size
        self.num_heads = config.num_attention_heads
        self.head_dim = self.embed_dim // self.num_heads
        if self.head_dim * self.num_heads != self.embed_dim:
            raise ValueError(
                f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim} and `num_heads`:"
                f" {self.num_heads})."
            )
        self.scale = self.head_dim**-0.5
        self.dropout = config.attention_dropout
        self.is_causal = False

        self.k_proj = nn.Linear(self.embed_dim, self.embed_dim)
        self.v_proj = nn.Linear(self.embed_dim, self.embed_dim)
        self.q_proj = nn.Linear(self.embed_dim, self.embed_dim)
        self.out_proj = nn.Linear(self.embed_dim, self.embed_dim)

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: torch.Tensor | None = None,
        output_attentions: bool | None = False,
    ) -> tuple[torch.Tensor, torch.Tensor | None]:
        """Input shape: Batch x Time x Channel"""

        batch_size, seq_length, embed_dim = hidden_states.shape

        queries = self.q_proj(hidden_states)
        keys = self.k_proj(hidden_states)
        values = self.v_proj(hidden_states)

        queries = queries.view(batch_size, seq_length, self.num_heads, self.head_dim).transpose(1, 2)
        keys = keys.view(batch_size, seq_length, self.num_heads, self.head_dim).transpose(1, 2)
        values = values.view(batch_size, seq_length, self.num_heads, self.head_dim).transpose(1, 2)

        attention_interface: Callable = eager_attention_forward
        if self.config._attn_implementation != "eager":
            attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]

        attn_output, attn_weights = attention_interface(
            self,
            queries,
            keys,
            values,
            attention_mask,
            is_causal=self.is_causal,
            scaling=self.scale,
            dropout=0.0 if not self.training else self.dropout,
        )

        attn_output = attn_output.reshape(batch_size, seq_length, embed_dim).contiguous()
        attn_output = self.out_proj(attn_output)

        if not output_attentions:
            attn_weights = None

        return attn_output, attn_weights


class VJEPA2PoolerCrossAttention(nn.Module):
    """It's different from other cross-attention layers, doesn't have output projection layer (o_proj)"""

    def __init__(self, config: VJEPA2Config):
        super().__init__()
        self.config = config
        self.embed_dim = config.hidden_size
        self.num_heads = config.num_attention_heads
        self.head_dim = self.embed_dim // self.num_heads
        if self.head_dim * self.num_heads != self.embed_dim:
            raise ValueError(
                f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim} and `num_heads`:"
                f" {self.num_heads})."
            )
        self.scale = self.head_dim**-0.5
        self.dropout = config.attention_dropout
        self.is_causal = False

        self.k_proj = nn.Linear(self.embed_dim, self.embed_dim)
        self.v_proj = nn.Linear(self.embed_dim, self.embed_dim)
        self.q_proj = nn.Linear(self.embed_dim, self.embed_dim)

    def forward(
        self,
        queries: torch.Tensor,
        keys: torch.Tensor,
        values: torch.Tensor,
        attention_mask: torch.Tensor | None = None,
        output_attentions: bool | None = False,
    ) -> tuple[torch.Tensor, torch.Tensor | None]:
        """Input shape: Batch x Time x Channel"""

        batch_size, q_seq_length, embed_dim = queries.shape
        kv_seq_length = keys.shape[1]

        queries = self.q_proj(queries)
        keys = self.k_proj(keys)
        values = self.v_proj(values)

        queries = queries.view(batch_size, q_seq_length, self.num_heads, self.head_dim).transpose(1, 2)
        keys = keys.view(batch_size, kv_seq_length, self.num_heads, self.head_dim).transpose(1, 2)
        values = values.view(batch_size, kv_seq_length, self.num_heads, self.head_dim).transpose(1, 2)

        attention_interface: Callable = eager_attention_forward
        if self.config._attn_implementation != "eager":
            attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]

        attn_output, attn_weights = attention_interface(
            self,
            queries,
            keys,
            values,
            attention_mask,
            is_causal=self.is_causal,
            scaling=self.scale,
            dropout=0.0 if not self.training else self.dropout,
        )

        attn_output = attn_output.reshape(batch_size, q_seq_length, embed_dim).contiguous()

        if not output_attentions:
            attn_weights = None

        return attn_output, attn_weights


class VJEPA2PoolerSelfAttentionLayer(GradientCheckpointingLayer):
    def __init__(self, config: VJEPA2Config):
        super().__init__()
        self.layer_norm1 = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
        self.self_attn = VJEPA2PoolerSelfAttention(config)
        self.layer_norm2 = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
        self.mlp = VJEPA2MLP(config, hidden_size=config.hidden_size)

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: torch.Tensor,
        output_attentions: bool | None = False,
    ) -> tuple[torch.Tensor, ...]:
        """
        Args:
            hidden_states (`torch.FloatTensor`):
                Input to the layer of shape `(batch, seq_len, embed_dim)`.
            attention_mask (`torch.FloatTensor`):
                Attention mask of shape `(batch, 1, q_len, k_v_seq_len)` where padding elements are indicated by very large negative values.
            output_attentions (`bool`, *optional*, defaults to `False`):
                Whether or not to return the attentions tensors of all attention layers. See `attentions` under
                returned tensors for more detail.
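
        Returns:
            A tuple whose first element is the layer output; when `output_attentions=True`, the attention
            weights are appended as a second element.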
        """
        residual = hidden_states
        hidden_states = self.layer_norm1(hidden_states)
        hidden_states, attn_weights = self.self_attn(
            hidden_states=hidden_states,
            attention_mask=attention_mask,
            output_attentions=output_attentions,
        )
        hidden_states = residual + hidden_states

        residual = hidden_states
        hidden_states = self.layer_norm2(hidden_states)
        hidden_states = self.mlp(hidden_states)
        hidden_states = residual + hidden_states

        outputs = (hidden_states,)
        if output_attentions:
            outputs += (attn_weights,)
        return outputs


class VJEPA2PoolerCrossAttentionLayer(GradientCheckpointingLayer):
    def __init__(self, config: VJEPA2Config):
        super().__init__()
        self.layer_norm1 = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
        self.cross_attn = VJEPA2PoolerCrossAttention(config)
        self.layer_norm2 = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
        self.mlp = VJEPA2MLP(config, hidden_size=config.hidden_size)

    def forward(
        self,
        queries: torch.Tensor,
        hidden_state: torch.Tensor,
        attention_mask: torch.Tensor | None = None,
        output_attentions: bool = False,
    ) -> tuple[torch.Tensor, ...]:
        # Apply cross-attention
        residual = queries
        hidden_state = self.layer_norm1(hidden_state)
        hidden_state, *attn_weights = self.cross_attn(
            queries,
            hidden_state,
            hidden_state,
            attention_mask=attention_mask,
            output_attentions=output_attentions,
        )
        hidden_state = residual + hidden_state

        # Apply MLP
        residual = hidden_state
        hidden_state = self.layer_norm2(hidden_state)
        hidden_state = self.mlp(hidden_state)
        hidden_state = residual + hidden_state

        outputs = (hidden_state,)
        if output_attentions:
            outputs += tuple(attn_weights)
        return outputs


class VJEPA2AttentivePooler(nn.Module):
    """Attentive Pooler"""

    def __init__(self, config: VJEPA2Config):
        super().__init__()
        self.query_tokens = nn.Parameter(torch.zeros(1, 1, config.hidden_size))
        self.cross_attention_layer = VJEPA2PoolerCrossAttentionLayer(config)
        self.self_attention_layers = nn.ModuleList(
            [VJEPA2PoolerSelfAttentionLayer(config) for _ in range(config.num_pooler_layers)]
        )

    def forward(self, hidden_state: torch.Tensor) -> torch.Tensor:
        for layer in self.self_attention_layers:
            hidden_state = layer(hidden_state, attention_mask=None)[0]
        queries = self.query_tokens.repeat(hidden_state.shape[0], 1, 1)
        hidden_state = self.cross_attention_layer(queries, hidden_state)[0]
        return hidden_state.squeeze(1)


@auto_docstring
class VJEPA2PreTrainedModel(PreTrainedModel):
    config: VJEPA2Config
    base_model_prefix = "vjepa2"
    main_input_name = "pixel_values_videos"
    input_modalities = "video"
    supports_gradient_checkpointing = True
    _no_split_modules = [
        "VJEPA2Layer",
        "VJEPA2PoolerSelfAttentionLayer",
        "VJEPA2PoolerCrossAttentionLayer",
        "VJEPA2PredictorEmbeddings",
    ]
    _supports_sdpa = True
    _supports_flash_attn = True

    @torch.no_grad()
    def _init_weights(self, module):
        """Initialize the weights"""
        init_std = self.config.initializer_range
        if isinstance(module, VJEPA2AttentivePooler):
            init.trunc_normal_(module.query_tokens, std=init_std)
            # Rescale the output projections depth-wise, as in the original attentive probe
            for i, layer in enumerate(module.self_attention_layers, 1):
                std = init_std / (i**0.5)
                init.trunc_normal_(layer.self_attn.out_proj.weight, std=std)
                init.trunc_normal_(layer.mlp.fc2.weight, std=std)
            std = init_std / (len(module.self_attention_layers) + 1) ** 0.5
            init.trunc_normal_(module.cross_attention_layer.mlp.fc2.weight, std=std)
        elif isinstance(module, VJEPA2PredictorEmbeddings):
            if module.zero_init_mask_tokens:
                init.zeros_(module.mask_tokens)
            else:
                init.trunc_normal_(module.mask_tokens, std=init_std)
        elif isinstance(module, (nn.Linear, nn.Conv2d, nn.Conv3d)):
            init.trunc_normal_(module.weight, std=init_std)
            if module.bias is not None:
                init.zeros_(module.bias)
        elif isinstance(module, nn.LayerNorm):
            init.zeros_(module.bias)
            init.ones_(module.weight)


@auto_docstring
class VJEPA2Model(VJEPA2PreTrainedModel):
    def __init__(self, config: VJEPA2Config):
        super().__init__(config)
        self.config = config

        self.encoder = VJEPA2Encoder(config)
        self.predictor = VJEPA2Predictor(config)

        # Initialize weights and apply final processing
        self.post_init()

    def get_input_embeddings(self) -> VJEPA2PatchEmbeddings3D:
        return self.encoder.embeddings.patch_embeddings

    @can_return_tuple
    @auto_docstring
    def forward(
        self,
        pixel_values_videos: torch.Tensor,
        context_mask: list[torch.Tensor] | None = None,
        target_mask: list[torch.Tensor] | None = None,
        skip_predictor: bool = False,
        output_attentions: bool | None = None,
        output_hidden_states: bool | None = None,
        **kwargs,
    ) -> VJEPA2WithMaskedInputModelOutput:
        r"""
        context_mask (`torch.Tensor` with shape `[batch_size, patch_size, 1]`, *optional*):
            The mask position ids indicating which encoder output patches are going to be exposed to the predictor.
            By default, this mask is created as torch.arange(N).unsqueeze(0).repeat(B,1), indicating full context
            available to the predictor.
        target_mask (`torch.Tensor` with shape `[batch_size, patch_size, 1]`, *optional*):
            The mask position ids indicating which encoder output patches are going to be used as a prediction target
            for the predictor. By default, this mask is created as torch.arange(N).unsqueeze(0).repeat(B,1), indicating
            that the predictor should predict all encoder patches.
        skip_predictor (bool):
            flag to skip the predictor forward, useful if you just need the encoder outputs
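
        Example (a minimal feature-extraction sketch; the checkpoint id is an assumption, any released
        V-JEPA 2 checkpoint should work):

        ```python
        >>> import torch
        >>> import numpy as np
        >>> from transformers import AutoVideoProcessor, VJEPA2Model

        >>> video_processor = AutoVideoProcessor.from_pretrained("facebook/vjepa2-vitl-fpc64-256")
        >>> model = VJEPA2Model.from_pretrained("facebook/vjepa2-vitl-fpc64-256")

        >>> video = np.ones((64, 256, 256, 3))  # 64 frames, 256x256 RGB
        >>> inputs = video_processor(video, return_tensors="pt")
        >>> with torch.no_grad():
        ...     outputs = model(**inputs, skip_predictor=True)
        >>> features = outputs.last_hidden_state  # (batch_size, num_patches, hidden_size)
        ```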
        """
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )

        if pixel_values_videos is None:
            raise ValueError("You have to specify pixel_values_videos")

        encoder_outputs: BaseModelOutput = self.encoder(
            pixel_values_videos=pixel_values_videos,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
        )
        sequence_output = encoder_outputs.last_hidden_state

        if context_mask is None and target_mask is None:
            B = pixel_values_videos.size(0)
            N = sequence_output.size(1)  # use the dynamic patch count rather than the config value
            context_mask = [torch.arange(N, device=sequence_output.device).unsqueeze(0).repeat((B, 1))]
            target_mask = [torch.arange(N, device=sequence_output.device).unsqueeze(0).repeat((B, 1))]

        if not skip_predictor:
            predictor_outputs: BaseModelOutput = self.predictor(
                encoder_hidden_states=sequence_output,
                context_mask=context_mask,
                target_mask=target_mask,
                output_attentions=output_attentions,
                output_hidden_states=output_hidden_states,
            )
            predictor_output = VJEPA2WithMaskedInputPredictorOutput(
                last_hidden_state=predictor_outputs.last_hidden_state,
                target_hidden_state=apply_masks(sequence_output, target_mask),
                hidden_states=predictor_outputs.hidden_states,
                attentions=predictor_outputs.attentions,
            )
        else:
            predictor_output = None

        encoder_output = VJEPA2WithMaskedInputModelOutput(
            last_hidden_state=sequence_output,
            masked_hidden_state=apply_masks(sequence_output, context_mask),
            hidden_states=encoder_outputs.hidden_states,
            attentions=encoder_outputs.attentions,
            predictor_output=predictor_output,
        )

        return encoder_output

    def get_vision_features(self, pixel_values_videos) -> torch.Tensor:
        encoder_output = self.forward(pixel_values_videos, skip_predictor=True)
        return encoder_output.last_hidden_state


@auto_docstring(
    custom_intro="""
    V-JEPA 2 Model transformer with a video classification head on top (a linear layer on top of the attentive pooler).
    """
)
class VJEPA2ForVideoClassification(VJEPA2PreTrainedModel):
    def __init__(self, config: VJEPA2Config):
        super().__init__(config)
        self.num_labels = config.num_labels
        self.vjepa2 = VJEPA2Model(config)

        # Classifier head
        self.pooler = VJEPA2AttentivePooler(config)
        self.classifier = nn.Linear(config.hidden_size, config.num_labels, bias=True)

        # Initialize weights and apply final processing
        self.post_init()

    @can_return_tuple
    @auto_docstring
    def forward(
        self,
        pixel_values_videos: torch.Tensor,
        labels: torch.Tensor | None = None,
        output_attentions: bool | None = None,
        output_hidden_states: bool | None = None,
        **kwargs,
    ) -> tuple | ImageClassifierOutput:
        r"""
        labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
            Labels for computing the image classification/regression loss. Indices should be in `[0, ...,
            config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
            `config.num_labels > 1` a classification loss is computed (Cross-Entropy).

        Examples:

        ```python
        >>> import torch
        >>> import numpy as np
        >>> from transformers import AutoVideoProcessor, VJEPA2ForVideoClassification

        >>> device = "cuda"

        >>> video_processor = AutoVideoProcessor.from_pretrained("facebook/vjepa2-vitl-fpc16-256-ssv2")
        >>> model = VJEPA2ForVideoClassification.from_pretrained("facebook/vjepa2-vitl-fpc16-256-ssv2").to(device)

        >>> video = np.ones((64, 256, 256, 3))  # 64 frames, 256x256 RGB
        >>> inputs = video_processor(video, return_tensors="pt").to(device)

        >>> # For inference
        >>> with torch.no_grad():
        ...     outputs = model(**inputs)
        >>> logits = outputs.logits

        >>> predicted_label = logits.argmax(-1).item()
        >>> print(model.config.id2label[predicted_label])

        >>> # For training
        >>> labels = torch.ones(1, dtype=torch.long, device=device)
        >>> loss = model(**inputs, labels=labels).loss

        ```"""
        outputs = self.vjepa2(
            pixel_values_videos=pixel_values_videos,
            skip_predictor=True,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
        )

        last_hidden_state = outputs.last_hidden_state
        pooler_output = self.pooler(last_hidden_state)
        logits = self.classifier(pooler_output)

        loss = None
        if labels is not None:
            loss = self.loss_function(pooled_logits=logits, labels=labels, config=self.config)

        return ImageClassifierOutput(
            loss=loss,
            logits=logits,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )


__all__ = ["VJEPA2Model", "VJEPA2PreTrainedModel", "VJEPA2ForVideoClassification"]