o
    i                     @   s  d Z ddlZddlZddlmZ ddlmZ ddlm	Z	m
Z
 ddlZddlmZ ddlmZmZmZ dd	lmZ dd
lmZ ddlmZmZ ddlmZ ddlmZmZ ddlmZmZm Z m!Z!m"Z" ddl#m$Z$ e!%e&Z'dZ(dZ)eG dd deZ*eG dd deZ+eG dd deZ,dNddZ-dOddZ.dPd d!Z/G d"d# d#ej0Z1G d$d% d%ej0Z2G d&d' d'ej0Z3G d(d) d)ej0Z4G d*d+ d+ej0Z5G d,d- d-ej0Z6G d.d/ d/ej0Z7G d0d1 d1ej0Z8G d2d3 d3ej0Z9G d4d5 d5eZ:G d6d7 d7ej0Z;G d8d9 d9eZ<d:Z=d;Z>ed<e=G d=d> d>e<Z?G d?d@ d@ej0Z@edAe=G dBdC dCe<ZAG dDdE dEej0ZBG dFdG dGej0ZCG dHdI dIej0ZDedJe=G dKdL dLe<ZEg dMZFdS )QzPyTorch TVLT model.    N)deepcopy)	dataclass)OptionalUnion)nn)BCEWithLogitsLossCrossEntropyLossMSELoss   )ACT2FN)GradientCheckpointingLayer)BaseModelOutputSequenceClassifierOutput)PreTrainedModel) find_pruneable_heads_and_indicesprune_linear_layer)ModelOutputadd_start_docstrings%add_start_docstrings_to_model_forwardloggingreplace_return_docstrings   )
TvltConfigr   zZinengTang/tvlt-basec                   @   s   e Zd ZU dZdZeej ed< dZ	eej ed< dZ
eej ed< dZeej ed< dZeej ed< dZeej ed< dZeej ed	< dZeeejd
f  ed< dZeeejd
f  ed< dS )TvltModelOutputa  
    Class for TvltModel's outputs, with potential hidden states and attentions.

    Args:
        last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
            Sequence of hidden-states at the output of the last layer of the model.
        last_pixel_hidden_state (`torch.FloatTensor` of shape `(batch_size, pixel_sequence_length, hidden_size)`):
            Pixel sequence of hidden-states at the output of the last layer of the model.
        last_audio_hidden_state (`torch.FloatTensor` of shape `(batch_size, audio_sequence_length, hidden_size)`):
            Audio sequence of hidden-states at the output of the last layer of the model.
        pixel_label_masks (`torch.FloatTensor` of shape `(batch_size, pixel_patch_length)`):
            Tensor indicating which pixel patches are masked (1) and which are not (0).
        audio_label_masks (`torch.FloatTensor` of shape `(batch_size, audio_patch_length)`):
            Tensor indicating which audio patches are masked (1) and which are not (0).
        pixel_ids_restore (`torch.LongTensor` of shape `(batch_size, pixel_patch_length)`):
            Tensor containing the ids permutation of pixel masking.
        audio_ids_restore (`torch.LongTensor` of shape `(batch_size, audio_patch_length)`):
            Tensor containing the ids permutation of audio masking.
        hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
            Tuple of `torch.FloatTensor` (one for the output of the embeddings and one for the output of each layer) of
            shape `(batch_size, sequence_length, hidden_size)`. Hidden-states of the model at the output of each layer
            plus the initial embedding outputs.
        attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
            Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
            sequence_length)`. Attentions weights after the attention softmax, used to compute the weighted average in
            the self-attention heads.
    Nlast_hidden_statelast_pixel_hidden_statelast_audio_hidden_statepixel_label_masksaudio_label_maskspixel_ids_restoreaudio_ids_restore.hidden_states
attentions)__name__
__module____qualname____doc__r   r   torchFloatTensor__annotations__r   r   r   
LongTensorr   r   r    r!   tupler"    r,   r,   e/home/ubuntu/.local/lib/python3.10/site-packages/transformers/models/deprecated/tvlt/modeling_tvlt.pyr   0   s   
 r   c                   @   sX   e Zd ZU dZdZeej ed< dZ	ee
ejdf  ed< dZee
ejdf  ed< dS )TvltDecoderOutputaM  
    Class for TvltDecoder's outputs, with potential hidden states and attentions.

    Args:
        logits (`torch.FloatTensor` of shape `(batch_size, patch_size ** 2 * num_channels)`):
            Pixel reconstruction logits.
        hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
            Tuple of `torch.FloatTensor` (one for the output of the embeddings and one for the output of each layer) of
            shape `(batch_size, sequence_length, hidden_size)`. Hidden-states of the model at the output of each layer
            plus the initial embedding outputs.
        attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
            Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
            sequence_length)`. Attentions weights after the attention softmax, used to compute the weighted average in
            the self-attention heads.
    Nlogits.r!   r"   )r#   r$   r%   r&   r/   r   r'   r(   r)   r!   r+   r"   r,   r,   r,   r-   r.   Y   s
   
 r.   c                   @   s   e Zd ZU dZdZeej ed< dZ	eej ed< dZ
eej ed< dZeej ed< dZeeejdf  ed< dZeeejdf  ed	< dS )
TvltForPreTrainingOutputa
  
    Class for TvltForPreTraining's outputs, with potential hidden states and attentions.

    Args:
        loss (`torch.FloatTensor` of shape `(1,)`):
            Pixel reconstruction loss.
        matching_logits (`torch.FloatTensor` of shape `(batch_size, 1)`):
            Matching objective logits.
        pixel_logits (`torch.FloatTensor` of shape
            `(batch_size, pixel_patch_length, image_patch_size ** 3 * pixel_num_channels)`): Pixel reconstruction
            logits.
        audio_logits (`torch.FloatTensor` of shape
            `(batch_size, audio_patch_length, image_patch_size[0] * image_patch_size[1])`): Audio reconstruction
            logits.
        hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
            Tuple of `torch.FloatTensor` (one for the output of the embeddings and one for the output of each layer) of
            shape `(batch_size, sequence_length, hidden_size)`. Hidden-states of the model at the output of each layer
            plus the initial embedding outputs.
        attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
            Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
            sequence_length)`. Attentions weights after the attention softmax, used to compute the weighted average in
            the self-attention heads.
    Nlossmatching_logitspixel_logitsaudio_logits.r!   r"   )r#   r$   r%   r&   r1   r   r'   r(   r)   r2   r3   r4   r!   r+   r"   r,   r,   r,   r-   r0   p   s   
 r0         ?c                 C   s>   | j dd \}}tj||f| jd}t|d|  }||fS )!Generate noise for audio masking.N   devicer   )shaper'   randr9   int)pixel_values
pixel_mask
mask_ratio
batch_sizeseq_lennoiselen_keepr,   r,   r-   generate_pixel_mask_noise   s   rD   patch-level   c           
      C   s   | j dd \}}|dkr'|| }tj||| jdddd|||}n|dkr4tj||| jd}t|d|  }	||	fS )r6   Nr7   zframe-levelr8   r   rE   )r:   r'   r;   r9   	unsqueezerepeatviewr<   )
audio_values
audio_maskr?   	mask_typefreq_lenr@   rA   num_time_patchesrB   rC   r,   r,   r-   generate_audio_mask_noise   s   
rP   c                 C   s   | j \}}}tj|dd}tj|dd}|ddd|f }	tj| d|	ddd|d}
tj||g| jd}d|ddd|f< tj|d|d}|durZ||9 }tj|d|	d}|
|||fS )z
    Perform random masking by per-sample shuffling on frame-level. Per-sample shuffling is done by argsort random
    noise. sequence: [batch_size, seq_len, hidden_dim], sequence
    r   dimNrG   rR   indexr8   r   )r:   r'   argsortgatherrH   rI   onesr9   )sequencerB   rC   attention_masksr@   rA   
hidden_dimids_shuffleids_restoreids_keepsequence_maskedlabel_masksr,   r,   r-   random_masking   s    r`   c                       *   e Zd ZdZ fddZdddZ  ZS )TvltPixelEmbeddings,Construct the patch and position embeddings.c                    st   t    t|| _| jj| _ttdd|j	| _
ttd|j|j	| _ttd| j|j	| _|| _d S Nr   )super__init__TvltPixelPatchEmbeddingspatch_embeddingsnum_patches_per_imager   	Parameterr'   zeroshidden_sizetype_embed_v
num_framestemporal_embedpos_embed_vconfigselfrq   	__class__r,   r-   rf      s   



zTvltPixelEmbeddings.__init__Nc           	      C   sh   |j \}}}}}| |}|| jd|d7 }|tj| jd d d |f | jdd7 }|| j7 }||fS Nr   rQ   )	r:   rh   rp   rI   r'   repeat_interleavero   ri   rm   )	rs   r=   rY   r@   rn   num_channelsheightwidth
embeddingsr,   r,   r-   forward   s   
(
zTvltPixelEmbeddings.forwardNr#   r$   r%   r&   rf   r|   __classcell__r,   r,   rt   r-   rb      s    rb   c                       ra   )TvltAudioEmbeddingsrc   c                    s   t    t|| _| jj| _ttdd|j	| _
|j|jd  | _ttd| j| j |j	| _ttd| j|j	| _|j|jd  | _|| _d S rd   )re   rf   TvltAudioPatchEmbeddingsrh   num_patchesr   rj   r'   rk   rl   type_embed_afrequency_lengthaudio_patch_sizenum_freq_patchespos_embed_a
freq_embedrq   rr   rt   r,   r-   rf      s   


 
zTvltAudioEmbeddings.__init__Nc                 C   sh   |  |}|d| j }|| jd|d7 }|tj| jd d d |f | jdd7 }|| j7 }||fS rv   )	rh   sizer   r   rI   r'   rw   r   r   )rs   rK   rY   r{   rO   r,   r,   r-   r|      s   
(
zTvltAudioEmbeddings.forwardr}   r~   r,   r,   rt   r-   r      s    r   c                       6   e Zd ZdZ fddZdejdejfddZ  ZS )rg   z
    This class turns `pixel_values` of shape `(batch_size, num_channels, height, width)` into the initial
    `hidden_states` (patch embeddings) of shape `(batch_size, seq_length, hidden_size)` to be consumed by a
    Transformer.
    c                    s   t    |j|j}}|j|j}}t|tjj	r|n||f}t|tjj	r)|n||f}|d |d  |d |d   }|| _|| _
|| _|| _|| _tj||||d| _d S Nr   r   )kernel_sizestride)re   rf   
image_sizeimage_patch_sizenum_image_channelsrl   
isinstancecollectionsabcIterable
patch_sizerx   ri   r   Conv2d
projection)rs   rq   r   r   rx   rl   ri   rt   r,   r-   rf   	  s   
 z!TvltPixelPatchEmbeddings.__init__r=   returnc              
   C   s   |j \}}}}}|| jkrtd|| jd ks|| jd kr6td| d| d| jd  d| jd  d	||| |||}| |ddd}|||| j | j	}|S )	NeMake sure that the channel dimension of the pixel values match with the one set in the configuration.r   r   zInput image size (*) doesn't match model ().r7   )
r:   rx   
ValueErrorr   reshaper   flatten	transposeri   rl   )rs   r=   r@   rn   rx   ry   rz   r{   r,   r,   r-   r|     s   
(z TvltPixelPatchEmbeddings.forward	r#   r$   r%   r&   rf   r'   Tensorr|   r   r,   r,   rt   r-   rg     s    rg   c                       r   )r   z
    This class turns `audio_values` of shape `(batch_size, num_channels, height, width)` into the initial
    `hidden_states` (patch embeddings) of shape `(batch_size, seq_length, hidden_size)` to be consumed by a
    Transformer.
    c           
         s   t    |j|j|j}}}|j|j}}||f}t|tj	j
r$|n||f}|d |d  |d |d   }|d |d  |d |d  f}	|| _|| _|| _|| _|	| _tj||||d| _d S r   )re   rf   spectrogram_lengthr   r   num_audio_channelsrl   r   r   r   r   spectrogram_sizer   rx   r   patch_shaper   r   r   )
rs   rq   r   r   r   rx   rl   r   r   r   rt   r,   r-   rf   2  s    

  z!TvltAudioPatchEmbeddings.__init__rK   r   c              
   C   s   |j \}}}}|| jkrtd|| jd ks|| jd kr5td| d| d| jd  d| jd  d	| |ddd}|S )	Nr   r   r   zInput audio size (r   r   r   r7   )r:   rx   r   r   r   r   r   )rs   rK   r@   rx   ry   rz   r{   r,   r,   r-   r|   G  s   
z TvltAudioPatchEmbeddings.forwardr   r,   r,   rt   r-   r   +  s    r   c                       .   e Zd Z fddZdd Zd	ddZ  ZS )
TvltSelfAttentionc                    s   t    |j|j dkrt|dstd|j d|j d|j| _t|j|j | _| j| j | _t	j
|j| j|jd| _t	j
|j| j|jd| _t	j
|j| j|jd| _t	|j| _d S )Nr   embedding_sizezThe hidden size z4 is not a multiple of the number of attention heads .bias)re   rf   rl   num_attention_headshasattrr   r<   attention_head_sizeall_head_sizer   Linearqkv_biasquerykeyvalueDropoutattention_probs_dropout_probdropoutrr   rt   r,   r-   rf   X  s   

zTvltSelfAttention.__init__c                 C   s6   |  d d | j| jf }|j| }|ddddS )NrG   r   r7   r      )r   r   r   rJ   permute)rs   xnew_x_shaper,   r,   r-   transpose_for_scoresj  s   
z&TvltSelfAttention.transpose_for_scoresNFc                 C   s   |  |}| | |}| | |}| |}t||dd}	|	t| j	 }	|d ur4|	| }	t
jdd|	}
| |
}
|d urI|
| }
t|
|}|dddd }| d d | jf }|j| }|rr||
f}|S |f}|S )NrG   rQ   r   r7   r   r   )r   r   r   r   r'   matmulr   mathsqrtr   r   Softmaxr   r   
contiguousr   r   rJ   )rs   r!   attention_mask	head_maskoutput_attentionsmixed_query_layer	key_layervalue_layerquery_layerattention_scoresattention_probscontext_layernew_context_layer_shapeoutputsr,   r,   r-   r|   o  s(   



zTvltSelfAttention.forwardNNF)r#   r$   r%   rf   r   r|   r   r,   r,   rt   r-   r   W  s    r   c                       sF   e Zd ZdZdeddf fddZdejdejdejfd	d
Z  Z	S )TvltSelfOutputz
    The residual connection is defined in TvltLayer instead of here (as is the case with other models), due to the
    layernorm applied before each block.
    rq   r   Nc                    s.   t    t|j|j| _t|j| _d S r}   )	re   rf   r   r   rl   denser   hidden_dropout_probr   rr   rt   r,   r-   rf        
zTvltSelfOutput.__init__r!   input_tensorc                 C      |  |}| |}|S r}   r   r   rs   r!   r   r,   r,   r-   r|        

zTvltSelfOutput.forward)
r#   r$   r%   r&   r   rf   r'   r   r|   r   r,   r,   rt   r-   r     s    $r   c                       r   )
TvltAttentionc                    s*   t    t|| _t|| _t | _d S r}   )re   rf   r   	attentionr   outputsetpruned_headsrr   rt   r,   r-   rf     s   


zTvltAttention.__init__c                 C   s   t |dkrd S t|| jj| jj| j\}}t| jj|| j_t| jj|| j_t| jj	|| j_	t| j
j|dd| j
_| jjt | | j_| jj| jj | j_| j|| _d S )Nr   r   rQ   )lenr   r   r   r   r   r   r   r   r   r   r   r   union)rs   headsrT   r,   r,   r-   prune_heads  s   zTvltAttention.prune_headsNFc                 C   s6   |  ||||}| |d |}|f|dd   }|S )Nr   r   )r   r   )rs   r!   r   r   r   self_outputsattention_outputr   r,   r,   r-   r|     s   zTvltAttention.forwardr   )r#   r$   r%   rf   r   r|   r   r,   r,   rt   r-   r     s    r   c                       s<   e Zd Zdeddf fddZdejdejfddZ  ZS )	TvltIntermediaterq   r   Nc                    sD   t    t|j|j| _t|jt	rt
|j | _d S |j| _d S r}   )re   rf   r   r   rl   intermediate_sizer   r   
hidden_actstrr   intermediate_act_fnrr   rt   r,   r-   rf     s
   
zTvltIntermediate.__init__r!   c                 C   r   r}   )r   r   rs   r!   r,   r,   r-   r|     r   zTvltIntermediate.forward	r#   r$   r%   r   rf   r'   r   r|   r   r,   r,   rt   r-   r     s    r   c                       sB   e Zd Zdeddf fddZdejdejdejfdd	Z  ZS )

TvltOutputrq   r   Nc                    s.   t    t|j|j| _t|j| _	d S r}   )
re   rf   r   r   r   rl   r   r   r   r   rr   rt   r,   r-   rf     r   zTvltOutput.__init__r!   r   c                 C   s    |  |}| |}|| }|S r}   r   r   r,   r,   r-   r|     s   

zTvltOutput.forwardr   r,   r,   rt   r-   r     s    $r   c                       s*   e Zd ZdZ fddZdddZ  ZS )		TvltLayerz?This corresponds to the Block class in the timm implementation.c                    sb   t    |j| _d| _t|| _t|| _t|| _	t
j|j|jd| _t
j|j|jd| _d S Nr   eps)re   rf   chunk_size_feed_forwardseq_len_dimr   r   r   intermediater   r   r   	LayerNormrl   layer_norm_epslayernorm_beforelayernorm_afterrr   rt   r,   r-   rf     s   



zTvltLayer.__init__NFc           	      C   sj   | j | ||||d}|d }|dd  }|||j }| |}| |}| ||}|f| }|S )Nr   r   r   )r   r   tor9   r   r   r   )	rs   r!   r   r   r   self_attention_outputsr   r   layer_outputr,   r,   r-   r|     s   


zTvltLayer.forwardr   r~   r,   r,   rt   r-   r     s    
r   c                       s0   e Zd Z fddZ					dddZ  ZS )	TvltEncoderc                    s:   t     | _t fddt jD | _d| _d S )Nc                       g | ]}t  qS r,   r   .0_rq   r,   r-   
<listcomp>      z(TvltEncoder.__init__.<locals>.<listcomp>F)	re   rf   rq   r   
ModuleListrangenum_hidden_layerslayergradient_checkpointingrr   rt   r  r-   rf     s   
 
zTvltEncoder.__init__NFTc                 C   s   |rdnd }|r
dnd }t | jD ])\}	}
|r||f }|d ur$||	 nd }|
||||}|d }|r:||d f }q|rB||f }|sPtdd |||fD S t|||dS )Nr,   r   r   c                 s       | ]	}|d ur|V  qd S r}   r,   r   vr,   r,   r-   	<genexpr>0      z&TvltEncoder.forward.<locals>.<genexpr>)r   r!   r"   )	enumerater  r+   r   )rs   r!   r   r   r   output_hidden_statesreturn_dictall_hidden_statesall_self_attentionsilayer_modulelayer_head_masklayer_outputsr,   r,   r-   r|     s(   	

zTvltEncoder.forward)NNFFTr#   r$   r%   rf   r|   r   r,   r,   rt   r-   r     s    	r   c                   @   s.   e Zd ZU dZeed< dZdZdZdd Z	dS )	TvltPreTrainedModelz
    An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
    models.
    rq   tvltr=   Tc                 C   st   t |tjtjfr#|jjjd| jjd |j	dur!|j	j
  dS dS t |tjr8|j	j
  |jjd dS dS )zInitialize the weights        )meanstdNg      ?)r   r   r   r   weightdatanormal_rq   initializer_ranger   zero_r   fill_)rs   moduler,   r,   r-   _init_weightsC  s   
z!TvltPreTrainedModel._init_weightsN)
r#   r$   r%   r&   r   r)   base_model_prefixmain_input_namesupports_gradient_checkpointingr%  r,   r,   r,   r-   r  8  s   
 r  aF  
    This model is a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass. Use it
    as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage and
    behavior.

    Parameters:
        config ([`TvltConfig`]): Model configuration class with all the parameters of the model.
            Initializing with a config file does not load the weights associated with the model, only the
            configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights.
a	  
    Args:
        pixel_values (`torch.FloatTensor` of shape `(batch_size, num_frames, num_channels, height, width)`):
            Pixel values. Pixel values can be obtained using [`TvltProcessor`]. See [`TvltProcessor.__call__`] for
            details.

        audio_values (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
            Audio values. Audio values can be obtained using [`TvltProcessor`]. See [`TvltProcessor.__call__`] for
            details.

        pixel_mask (`torch.FloatTensor` of shape `(batch_size, num_pixel_patches)`):
            Pixel masks. Pixel masks can be obtained using [`TvltProcessor`]. See [`TvltProcessor.__call__`] for
            details.

        audio_mask (`torch.FloatTensor` of shape `(batch_size, num_audio_patches)`):
            Audio masks. Audio masks can be obtained using [`TvltProcessor`]. See [`TvltProcessor.__call__`] for
            details.

        pixel_values_mixed (`torch.FloatTensor` of shape `(batch_size, num_frames, num_channels, height, width)`):
            Pixel values that mix positive and negative samples in Tvlt vision-audio matching. Pixel values mixed can
            be obtained using [`TvltProcessor`]. See [`TvltProcessor.__call__`] for details.

        pixel_mask_mixed (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
            Pixel masks of pixel_values_mixed. Pixel masks mixed can be obtained using [`TvltProcessor`]. See
            [`TvltProcessor.__call__`] for details.

        mask_pixel (`bool`, *optional*):
            Whether to mask pixel for MAE tasks. Only set to True in TvltForPreTraining.

        mask_audio (`bool`, *optional*):
            Whether to mask audio for MAE tasks. Only set to True in TvltForPreTraining.

        output_attentions (`bool`, *optional*):
            Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
            tensors for more detail.

        output_hidden_states (`bool`, *optional*):
            Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
            more detail.

        return_dict (`bool`, *optional*):
            Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
z^The bare TVLT Model transformer outputting raw hidden-states without any specific head on top.c                       s   e Zd Z fddZdd Zdd Zeeee	e
d									dd
ejdejdeej deej dededee dee dee deeej e	f fddZ  ZS )	TvltModelc                    sv   t  | || _t|| _t|| _t|| _t	
tdd|j| _|jr+d | _n
t	j|j|jd| _|   d S r   )re   rf   rq   rb   pixel_embeddingsr   audio_embeddingsr   encoderr   rj   r'   rk   rl   cls_embeddinguse_mean_pooling	layernormr   r   	post_initrr   rt   r,   r-   rf     s   


zTvltModel.__init__c                 C   s   | j j| jjfS r}   )r*  rh   r+  )rs   r,   r,   r-   get_input_embeddings  s   zTvltModel.get_input_embeddingsc                 C   s*   |  D ]\}}| jj| j| qdS )z
        Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base
        class PreTrainedModel
        N)itemsr,  r  r   r   )rs   heads_to_pruner  r   r,   r,   r-   _prune_heads  s   zTvltModel._prune_headsoutput_typeconfig_classNFr=   rK   r>   rL   
mask_pixel
mask_audior   r  r  r   c
                 C   s  |dur|n| j j}|dur|n| j j}|	dur|	n| j j}	| ||\}
}| ||\}}d}d}|rKt|
|| j jd\}}t|
|||d\}
}}}d}d}|rv| j j	| j j
d  }t||| j j| j j|d\}}t||||d\}}}}|d}t| j|dd|
|gd}|
d}d}|dur|durt|ddddf ||gd}| }d}|dur| ||}| j|||||	d}|d }| jdur| |}|dddd| f }|ddd| df }|	s|||||||f|dd  S t||||||||j|jd	S )	a  
        Returns:

        Examples:

        ```python
        >>> from transformers import TvltProcessor, TvltModel
        >>> import numpy as np
        >>> import torch

        >>> num_frames = 8
        >>> images = list(np.random.randn(num_frames, 3, 224, 224))
        >>> audio = list(np.random.randn(10000))

        >>> processor = TvltProcessor.from_pretrained("ZinengTang/tvlt-base")
        >>> model = TvltModel.from_pretrained("ZinengTang/tvlt-base")

        >>> input_dict = processor(images, audio, sampling_rate=44100, return_tensors="pt")

        >>> outputs = model(**input_dict)
        >>> loss = outputs.loss
        ```N)r>   r?   )rY   r   )rL   r?   rM   rN   r   )r   r   r  r  )	r   r   r   r   r   r   r    r!   r"   )rq   r   r  use_return_dictr*  r+  rD   pixel_mask_ratior`   r   r   rP   audio_mask_ratioaudio_mask_typer   r'   catr-  rI   get_extended_attention_maskr,  r/  r   r!   r"   )rs   r=   rK   r>   rL   r8  r9  r   r  r  pixel_embedding_outputaudio_embedding_outputr   r   pixel_mask_noisepixel_len_keepr   r    r   audio_mask_noiseaudio_len_keepr@   embedding_outputmasked_pixel_lenr   input_shapeextended_attention_maskencoder_outputssequence_outputpixel_sequence_outputaudio_sequence_outputr,   r,   r-   r|     s   %




"



zTvltModel.forward)NNFFNNN)r#   r$   r%   rf   r1  r4  r   TVLT_INPUTS_DOCSTRINGr   r   _CONFIG_FOR_DOCr'   r(   r   boolr   r+   r|   r   r,   r,   rt   r-   r)    sD    
	
r)  c                       s,   e Zd Z fddZ			dddZ  ZS )TvltDecoderc                    sv   t    t| |j _|j _|j _|j	 _
t fddt|jD | _tj|j|jd| _d| _|| _d S )Nc                    r   r,   r   r   decoder_configr,   r-   r  9  r  z(TvltDecoder.__init__.<locals>.<listcomp>r   F)re   rf   r   decoder_hidden_sizerl   decoder_num_hidden_layersr  decoder_num_attention_headsr   decoder_intermediate_sizer   r   r  r  decoder_layersr   r   r/  r	  rq   rr   rt   rR  r-   rf   0  s   

zTvltDecoder.__init__FTc                 C   s   |rdnd }|r
dnd }t | jD ]\}}|r||f }|||d}	|	d }|r/||	d f }q|r7||f }| |}
|sJtdd |
||fD S t|
||dS )Nr,   r   r   r   c                 s   r
  r}   r,   r  r,   r,   r-   r  ]  r  z&TvltDecoder.forward.<locals>.<genexpr>)r/   r!   r"   )r  rX  r/  r+   r.   )rs   r!   r   r  r  r  r  r  r  r  r/   r,   r,   r-   r|   A  s    


zTvltDecoder.forward)FFTr  r,   r,   rt   r-   rQ  /  s    rQ  zTThe TVLT Model transformer with the decoder on top for self-supervised pre-training.c                       s   e Zd Z fddZdd Zdd Zdd Zd	d
 Zdd Ze	e
eeed								ddejdejdeej deej deej deej deej dee dee dee deeej ef fddZ  ZS )TvltForPreTrainingc           	         s  t  | || _|j| _|j| _| js| jstdt|| _| jr(t|| _	| jrt
j|j|jdd| _t
tdd|j| _t
tdd|j| _t|| _|j}|j}| jjj}t
td||| _t
td|j|| _t
tdd|| _| jjj}|j|jd  }t
td|| || _ t
td||| _!t
tdd|| _"| jj#d d | jj$ }t%||| _&| jjd | jjd  | jj' }t%||| _(|| _|| _|| _)|j#| _#|j| _| *  d S )Nz;Must set at least one of matching task and MAE task to trueTr   r   r   r7   )+re   rf   rq   task_matchingtask_maer   r)  r  TvltMatchingHeadmatching_headr   r   rl   rT  encoder_to_decoderrj   r'   rk   pixel_mask_tokenaudio_mask_tokenrQ  decoderrn   r*  ri   decoder_pixel_pos_embeddecoder_temporal_embeddecoder_pixel_type_embedr+  r   r   r   decoder_audio_pos_embeddecoder_freq_embeddecoder_audio_type_embedr   r   TvltMAEHeadpixel_mae_headr   audio_mae_headr   r0  )	rs   rq   rT  rn   ri   num_audio_patchesr   pixel_mae_output_dimaudio_mae_output_dimrt   r,   r-   rf   f  sL   




zTvltForPreTraining.__init__c           
   	   C   s   |j \}}}}}|j d | jd  }|j d | jd  }|j||||| jd || jd fd}	td|	}	|	j||| | | jd | jd  | fd}	|	S )zJ
        pixel_values: [batch_size, num_frames, 3, height, width]
        r   r   r
   r   r:   zntchpwq->nthwpqc)r:   r   r   r'   einsum)
rs   r=   r@   rn   rx   ry   rz   num_patches_heightnum_patches_widthpatchified_pixel_valuesr,   r,   r-   patchify_pixel  s*   
z!TvltForPreTraining.patchify_pixelc           	      C   s   |j \}}}}|| jd  }|| jd  }|j|||| jd || jd fd}td|}|j||| | jd | jd  | fd}|S )z>
        audio_values: [batch_size, 1, height, width]
        r   r   rn  znchpwq->nhwpqc)r:   r   r   r'   ro  )	rs   rK   r@   rx   ry   rz   rp  rq  patchified_audio_valuesr,   r,   r-   patchify_audio  s(   
z!TvltForPreTraining.patchify_audioc                 C   :   |  |}|| d }|jdd}||  |  }|S Nr7   rG   rQ   )rs  r  sum)rs   r=   pixel_predictionsmaskrr  r1   r,   r,   r-   pixel_mae_loss  
   
z!TvltForPreTraining.pixel_mae_lossc                 C   rv  rw  )ru  r  rx  )rs   rK   audio_predictionsrz  rt  r1   r,   r,   r-   audio_mae_loss  r|  z!TvltForPreTraining.audio_mae_lossc           	      C   sZ   |j \}}}|||j d | d}tj||gdd}tj|d|ddd|d}|S )Nr   rQ   rG   rS   )r:   rI   r'   r>  rV   rH   )	rs   
mask_tokenrX   r\   r@   
seq_lengthrR   mask_tokenspadded_sequencer,   r,   r-   concatenate_mask  s   z#TvltForPreTraining.concatenate_maskr5  Nr=   rK   r>   rL   labelspixel_values_mixedpixel_mask_mixedr   r  r  r   c                  C   s  |
dur|
n| j j}
d}| jrF|du rtd|du rtd| j||||||	|
d}|d }| |}t }||d|d}||7 }d}d}| jr+| j	r+| j||||dd||	|
d		}|
re|j
n|d
 }|
rn|jn|d }|
rw|jn|d }|
r|jn|d }|
r|jn|d }|
r|jn|d }| |}| |}|d
}| | j||}|| jd
|d
 }|tj| jddd|f | jd
d }|| j }| |}| |j}| | j||}|d
| j }|| j d
|d
 }|tj| j!ddd|f | jd
d }|| j" }| |}| #|j}| $|||| %||| }||7 }|
sE|||f|dd  }|durC|f| S |S t&|||||j'|j(dS )aF  
        pixel_values_mixed (`torch.FloatTensor` of shape `(batch_size, num_frames, num_channels, height, width)`):
            Pixel values that mix positive and negative samples in Tvlt vision-audio matching. Audio values can be
            obtained using [`TvltProcessor`]. See [`TvltProcessor.__call__`] for details.

        pixel_mask_mixed (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
            Pixel masks of pixel_values_mixed. Pixel values mixed can be obtained using [`TvltProcessor`]. See
            [`TvltProcessor.__call__`] for details.

        labels (`torch.LongTensor` of shape `(batch_size, num_labels)`, *optional*):
            Labels for computing the vision audio matching loss. Indices should be in `[0, 1]`. num_labels has to be 1.

        Return:

        Examples:

        ```python
        >>> from transformers import TvltProcessor, TvltForPreTraining
        >>> import numpy as np
        >>> import torch

        >>> num_frames = 8
        >>> images = list(np.random.randn(num_frames, 3, 224, 224))
        >>> images_mixed = list(np.random.randn(num_frames, 3, 224, 224))
        >>> audio = list(np.random.randn(10000))
        >>> processor = TvltProcessor.from_pretrained("ZinengTang/tvlt-base")
        >>> model = TvltForPreTraining.from_pretrained("ZinengTang/tvlt-base")
        >>> input_dict = processor(
        ...     images, audio, images_mixed, sampling_rate=44100, mask_pixel=True, mask_audio=True, return_tensors="pt"
        ... )

        >>> outputs = model(**input_dict)
        >>> loss = outputs.loss
        ```Nr  zMatching task requires labelsz)Matching task requires pixel_values_mixedr>   rL   r   r  r  r   rG   T)r>   rL   r8  r9  r   r  r  r   r7   r   r
         rQ      )r1   r2   r3   r4   r!   r"   ))rq   r:  rZ  r   r  r]  r   rJ   r[  trainingr   r   r   r   r   r    r^  r   r  r_  rb  rI   r'   rw   rc  ri   rd  ra  ri  r/   r`  r   rf  re  rg  rj  r{  r~  r0   r!   r"   ) rs   r=   rK   r>   rL   r  r  r  r   r  r  
total_lossr   rK  r2   loss_fctr1   r3   r4   rL  rM  r   r   r   r    pixel_decoder_inputaudio_decoder_inputrn   pixel_decoder_outputsrO   audio_decoder_outputsr   r,   r,   r-   r|     s   1






zTvltForPreTraining.forward)NNNNNNNN)r#   r$   r%   rf   rs  ru  r{  r~  r  r   rN  r   r0   rO  r'   r(   r   r*   rP  r   r+   r|   r   r,   r,   rt   r-   rY  a  sP    6	
	
rY  c                       $   e Zd Z fddZdd Z  ZS )
TvltPoolerc                    s*   t    t|j|j| _t | _d S r}   )re   rf   r   r   rl   r   Tanh
activationrr   rt   r,   r-   rf   x  s   
zTvltPooler.__init__c                 C   s(   |d d df }|  |}| |}|S )Nr   )r   r  )rs   r!   first_token_tensorpooled_outputr,   r,   r-   r|   }  s   

zTvltPooler.forwardr  r,   r,   rt   r-   r  w      r  c                       r  )r\  c                    s(   t    t|| _t|jd| _d S rd   )re   rf   r  poolerr   r   rl   fcrr   rt   r,   r-   rf     s   

zTvltMatchingHead.__init__c                 C   s   |  | |}|S r}   )r  r  r   r,   r,   r-   r|     s   zTvltMatchingHead.forwardr  r,   r,   rt   r-   r\    r  r\  c                       s&   e Zd Zd fdd	Zdd Z  ZS )rh  Nc                    s$   t    || _t|j|| _d S r}   )re   rf   rq   r   r   rT  ra  )rs   rq   
output_dimrt   r,   r-   rf     s   
zTvltMAEHead.__init__c                 C   s   |  |}|S r}   )ra  r   r,   r,   r-   r|     s   
zTvltMAEHead.forwardr}   r  r,   r,   rt   r-   rh    s    rh  z
    Tvlt Model transformer with a classifier head on top (an MLP on top of the final hidden state of the [CLS] token)
    for audiovisual classification tasks, e.g. CMU-MOSEI Sentiment Analysis and Audio to Video Retrieval.
    c                       s   e Zd Z fddZeeeeed						dde	j
de	j
dee	j
 dee	j
 d	ee d
ee dee dee	j deee	j
 ef fddZ  ZS ) TvltForAudioVisualClassificationc              	      sp   t  | t|| _tt|j|jd tj|jd |j	dt
 t|jd |j| _|| _|   d S )Nr7   r   )re   rf   r)  r  r   
Sequentialr   rl   r   r   GELU
num_labels
classifierrq   r0  rr   rt   r,   r-   rf     s   
z)TvltForAudioVisualClassification.__init__r5  Nr=   rK   r>   rL   r   r  r  r  r   c	              	   C   s   |dur|n| j j}| j|||||||d}	|	d dddf }
| |
}d}|durH| j jdkr:t }|||}n| j jdkrHt }|||}|s^|f|	dd  }|dur\|f| S |S t|||	j|	j	dS )a  
        labels (`torch.LongTensor` of shape `(batch_size, num_labels)`, *optional*):
            Labels for computing the audiovisual loss. Indices should be in `[0, ..., num_classes-1]` where num_classes
            refers to the number of classes in audiovisual tasks.

        Return:

        Examples:
        ```python
        >>> from transformers import TvltProcessor, TvltForAudioVisualClassification
        >>> import numpy as np
        >>> import torch

        >>> num_frames = 8
        >>> images = list(np.random.randn(num_frames, 3, 224, 224))
        >>> audio = list(np.random.randn(10000))
        >>> processor = TvltProcessor.from_pretrained("ZinengTang/tvlt-base")
        >>> model = TvltForAudioVisualClassification.from_pretrained("ZinengTang/tvlt-base")
        >>> input_dict = processor(images, audio, sampling_rate=44100, return_tensors="pt")

        >>> outputs = model(**input_dict)
        >>> loss = outputs.loss
        ```Nr  r   
regressionclassificationr
   )r1   r/   r!   r"   )
rq   r:  r  r  	loss_typer	   r   r   r!   r"   )rs   r=   rK   r>   rL   r   r  r  r  r   rK  r/   r1   r  r   r,   r,   r-   r|     s:   $	

z(TvltForAudioVisualClassification.forward)NNNNNN)r#   r$   r%   rf   r   rN  r   r   rO  r'   r(   r   rP  r*   r   r+   r|   r   r,   r,   rt   r-   r    s:    
	
r  )r)  rY  r  r  )Nr5   )Nr5   rE   rF   r}   )Gr&   collections.abcr   r   copyr   dataclassesr   typingr   r   r'   r   torch.nnr   r   r	   activationsr   modeling_layersr   modeling_outputsr   r   modeling_utilsr   pytorch_utilsr   r   utilsr   r   r   r   r   configuration_tvltr   
get_loggerr#   loggerrO  _CHECKPOINT_FOR_DOCr   r.   r0   rD   rP   r`   Modulerb   r   rg   r   r   r   r   r   r   r   r   r  TVLT_START_DOCSTRINGrN  r)  rQ  rY  r  r\  rh  r  __all__r,   r,   r,   r-   <module>   s   
(
!
	
),<"&,- $2  Y