o
    	۷i`                     @   s  d Z ddlZddlmZ ddlmZ ddlZddlmZ ddlm	Z	 ddl
mZ dd	lmZmZmZ dd
lmZ ddlmZ ddlmZmZ ddlmZ ddlmZ eeZeeG dd deZG dd dejZ G dd dejZ!G dd dejZ"G dd dejZ#G dd dejZ$G dd dejZ%G dd dejZ&G d d! d!eZ'G d"d# d#ejZ(G d$d% d%ejZ)eG d&d' d'eZ*G d(d) d)ejZ+G d*d+ d+ejZ,e+e,d,Z-ed-d.G d/d0 d0e*Z.G d1d2 d2ejZ/ed3d.G d4d5 d5e*Z0g d6Z1dS )7zPyTorch TVP Model    N)	dataclass)Optional)nn   )ACT2FN)GradientCheckpointingLayer)BaseModelOutputBaseModelOutputWithPoolingModelOutput)PreTrainedModel)prune_linear_layer)auto_docstringlogging)load_backbone   )	TvpConfigc                   @   sj   e Zd ZU dZdZeej ed< dZ	eej ed< dZ
eeejdf  ed< dZeeejdf  ed< dS )TvpVideoGroundingOutputa  
    loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `return_loss` is `True`):
        Temporal-Distance IoU loss for video grounding.
    logits (`torch.FloatTensor` of shape `(batch_size, 2)`):
        Contains start_time/duration and end_time/duration. It is the time slot of the videos corresponding to the
        input texts.
    attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
        Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
        sequence_length)`.
    Nlosslogits.hidden_states
attentions)__name__
__module____qualname____doc__r   r   torchFloatTensor__annotations__r   r   tupler    r   r   Z/home/ubuntu/vllm_env/lib/python3.10/site-packages/transformers/models/tvp/modeling_tvp.pyr   %   s   
 r   c                       s@   e Zd ZdZ fddZdd Zdd Zdd	 Zd
d Z  Z	S )TvpLossa~  
    This class computes the losses for `TvpForVideoGrounding`. The process happens in two steps: 1) we compute
    hungarian assignment between ground truth boxes and the outputs of the model 2) we supervise each pair of matched
    ground-truth / prediction (supervise class and box).

    Args:
        losses (`list[str]`):
            List of all the losses to be applied.
    c                    sL   t    | j| j| jd| _|D ]}|| jvr td| dq|| _d S )NioudistancedurationzLoss z not supported)super__init__loss_iouloss_distanceloss_durationloss_map
ValueErrorlosses)selfr-   r   	__class__r   r    r'   D   s   


zTvpLoss.__init__c           	      C   sH   t ||t || }t ||t || }d|jdd|  }|S )z6
        Measure the intersection over union.
        r   r   min)r   r2   maxclamp)	r.   
start_timeend_timecandidates_start_timecandidates_end_timer%   interunionr#   r   r   r    r(   Q   s   zTvpLoss.loss_iouc           	      C   sT   t t ||d}t t ||d}t t ||t || |jdd}|S )z5
        Measure the distance of mid points.
        g       @g?r1   )r   divaddr3   r2   r4   )	r.   r5   r6   r7   r8   r%   mid_candidatesmid_groundtruthdistance_diffr   r   r    r)   [   s   zTvpLoss.loss_distancec           	      C   sB   t ||}t ||}t t t |||}|jdd}|S )z5
        Measure the difference of duration.
        g?r1   )r   subsquarer;   r4   )	r.   r5   r6   r7   r8   r%   duration_candidatesduration_groundtruthduration_diffr   r   r    r*   g   s
   zTvpLoss.loss_durationc              
   C   st   |\}}}t ||}|dddf  |dddf  }}i }	| jD ]}
|	|
| j|
 |||||i q%|	S )am  
        This performs the loss computation.

        Args:
            logits (`torch.FloatTensor`):
                The output logits of head module.
            labels (`list[torch.FloatTensor]`):
                List of tensors ([start, end, duration]), which contains start time, end time of the video corresponding to the text, and also the duration.
        Nr   r   )r   mulfloatr-   updater+   )r.   r   labelsr%   r5   r6   
candidatesr7   r8   losses_dictr   r   r   r    forwardr   s   

*
zTvpLoss.forward)
r   r   r   r   r'   r(   r)   r*   rK   __classcell__r   r   r/   r    r!   9   s    

r!   c                       $   e Zd Z fddZdd Z  ZS )TvpVisionModelc              	      s   t    t|| _|jd ur|jjd }n,t| jdr+t| jjdr+| jjjd }nt| jdr>t| jjdr>| jjj}nt	dt
j||jdddddd	| _d S )
Nconfighidden_sizeshidden_sizezBackbone config not foundr   r   F)kernel_sizestridepaddinggroupsbias)r&   r'   r   backbonebackbone_configrQ   hasattrrP   rR   r,   r   Conv2dgrid_encoder_conv)r.   rP   in_channelsr/   r   r    r'      s$   


zTvpVisionModel.__init__c                 C   s   |j \}}}}}||| |||}| |d d }| |}tjj|ddd}tjj|dd}|j dd  \}	}
}||||	|
|}|ddd	d
d}|S )Nfeature_mapsr      )rS   rT   T)inplacer   r      )	shapeviewrX   r\   r   
functional
max_pool2drelupermute)r.   pixel_values
batch_size
num_framesnum_channelsheightwidthgrid_feat_outputsgridnew_channel
new_height	new_widthr   r   r    rK      s   
zTvpVisionModel.forwardr   r   r   r'   rK   rL   r   r   r/   r    rN      s    rN   c                       s^   e Zd ZdZ fddZdejdededejfdd	Zdde	fddZ
dde	fddZ  ZS )TvpVisualInputEmbeddingz;
    Takes input of both image and video (multi-frame)
    c                    s   t    t|j|j| _t|j|j| _t|j	|j| _
td|j| _tj|j|jd| _t|j| _|j| _|j	| _	d S )Nr   eps)r&   r'   r   	Embeddingmax_position_embeddingsrR   position_embeddings max_grid_row_position_embeddingsrow_position_embeddings max_grid_col_position_embeddingscol_position_embeddingstoken_type_embeddings	LayerNormlayer_norm_eps
layer_normDropouthidden_dropout_probdropoutr.   rP   r/   r   r    r'      s   
z TvpVisualInputEmbedding.__init__	embeddingrm   rn   returnc                 C   sl   d }}|| j kr|| j  }|| jkr|| j }|dddd}tjj|||fddd}|dddd}|S )z
        This method allows to interpolate the pre-trained pad weights , to be able to use the model on collection of high
        resolution images (high resolution videos).

        r   r   r   r_   bicubicFscale_factormodealign_corners)r{   r}   rh   r   re   interpolate)r.   r   rm   rn   h0w0r   r   r    interpolate_pos_encoding   s   



z0TvpVisualInputEmbedding.interpolate_pos_encodingFr   c                 C   s   |j \}}}}t| j|}tj|tj|jd}| |}	dt|j d  |d|f }
|	j	|
 }	t| j
|}tj|tj|jd}| |}|d||f}|j	| }|	| }|rj|| jks_|| j
krj|| ||| }|S || }|S )af  
        Args:
            grid: (batch_size, height, width, hidden_dim)
            interpolate_pos_encoding: (`bool`, *optional*, defaults to `False`):
                Whether to interpolate the pre-trained position encodings.
        Returns:
            grid + col_position_embeddings.view(*col_shape): (batch_size, *, height, width, hidden_dim)
        dtypedevice)r   r   r   )rc   r2   r{   r   arangelongr   r|   lenrd   r}   r~   r   )r.   rp   r   rj   rm   rn   
hidden_dim
row_heightrow_position_idsr|   	row_shape	row_widthcol_position_idsr~   	col_shapepositional_embeddingsr   r   r    add_2d_positional_embeddings   s$   	



z4TvpVisualInputEmbedding.add_2d_positional_embeddingsc                 C   s   |j \}}}}}|d}| j||d}||d|}|j dd }	|j}
tj|	tj|
d}| |}|| }| 	|}| 
|}|S )a  
        Args:
            grid: Array of shape (batch_size, num_frames, height, width, num_channels).
                It contains processed frames extracted from videos, and is generated by Tvp image preprocessor. Note,
                num_frames can be 1
            interpolate_pos_encoding: (bool, *optional*, defaults to `False`):
                Whether to interpolate the pre-trained position encodings.

        Returns:
            embeddings: The embedding of grid with size (batch_size, height*width, num_channels)

        r   r   rO   Nr   )rc   meanr   rd   r   r   zerosr   r   r   r   )r.   rp   r   rj   rk   rm   rn   rl   visual_tokensvisual_tokens_shaper   token_type_idsr   
embeddingsr   r   r    rK     s   



zTvpVisualInputEmbedding.forwardF)r   r   r   r   r'   r   Tensorintr   boolr   rK   rL   r   r   r/   r    ru      s    )ru   c                       s*   e Zd ZdZ fddZdddZ  ZS )TvpTextInputEmbeddingszGConstruct the embeddings from word, position and token_type embeddings.c                    sl   t    tj|j|j|jd| _t|j|j| _	t|j
|j| _tj|j|jd| _t|j| _d S )N)padding_idxrv   )r&   r'   r   rx   
vocab_sizerR   pad_token_idword_embeddingsry   rz   type_vocab_sizer   r   r   r   r   r   r   r   r/   r   r    r'   %  s   
zTvpTextInputEmbeddings.__init__Nc                 C   s   |d ur	|  }n|  d d }|d }|d ur|jn|j}|d u r4tj|tj|d}|d|}|d u rAtj|tj|d}|d u rJ| |}| 	|}| 
|}	|| |	 }
| |
}
| |
}
|
S )NrO   r   r   r   )sizer   r   r   r   	unsqueezeexpandr   r   rz   r   r   r   )r.   	input_idsr   position_idsinputs_embedsinput_shape
seq_lengthr   rz   r   r   r   r   r    rK   -  s$   





zTvpTextInputEmbeddings.forward)NNNNr   r   r   r   r'   rK   rL   r   r   r/   r    r   "  s    r   c                       sV   e Zd Z fddZdd Zdejdedefdd	Z	
	
	
dde	e
 fddZ  ZS )TvpAttentionc                    s   t    |j|j dkrt|dstd|j d|j |j| _t|j|j | _| j| j | _t	
|j| j| _t	
|j| j| _t	
|j| j| _t	|j| _t	
|j|j| _t	j|j|jd| _t	|j| _t | _d S )Nr   embedding_sizezThe hidden size z4 is not a multiple of the number of attention heads rv   )r&   r'   rR   num_attention_headsrZ   r,   r   attention_head_sizeall_head_sizer   Linearquerykeyvaluer   attention_probs_dropout_probattn_dropoutdenser   r   r   r   r   setpruned_headsr   r/   r   r    r'   G  s    
zTvpAttention.__init__c                    s   t |dkrd S t| j| j}t|| j }|D ]  t fdd| jD   d| < q|d	 
d}tt ||  }t| j|| _t| j|| _t| j|| _t| j|dd| _| jt | | _| j| j | _| j|| _d S )Nr   c                 3   s     | ]}| k r
d ndV  qdS )r   r   Nr   ).0hheadr   r    	<genexpr>c  s    z+TvpAttention.prune_heads.<locals>.<genexpr>rO   r   dim)r   r   onesr   r   r   r   sumrd   
contiguouseqr   r   r   r   r   r   r   r   r:   )r.   headsmaskindexr   r   r    prune_heads\  s    
zTvpAttention.prune_headstensorsequence_lengthrj   c                 C   s    | ||| j| jdd S )Nr   r_   )rd   r   r   	transposer   )r.   r   r   rj   r   r   r    _reshapes  s   zTvpAttention._reshapeNoutput_attentionsc                 C   s   |j d d \}}| |}| |}| |}	| |||}
| |||}| |	||}t|
|dd}|t	| j
 }|d urG|| }tjj|dd}| |}|d ur\|| }t||}|dd }|||| j}| |}| |}| || }|r||f}|S |f}|S )Nr_   rO   r   r   )rc   r   r   r   r   r   matmulr   mathsqrtr   r   re   softmaxr   r   reshaper   r   r   r   )r.   r   attention_mask	head_maskr   rj   r   mixed_query_layermixed_key_layermixed_value_layerquery_layer	key_layervalue_layerattention_scoresattention_probsattn_outputoutputsr   r   r    rK   z  s2   





zTvpAttention.forwardNNN)r   r   r   r'   r   r   r   r   r   r   r   rK   rL   r   r   r/   r    r   F  s    
r   c                       2   e Zd Z fddZdejdejfddZ  ZS )TvpIntermediatec                    sD   t    t|j|j| _t|jt	rt
|j | _d S |j| _d S N)r&   r'   r   r   rR   intermediate_sizer   
isinstance
hidden_actstrr   intermediate_act_fnr   r/   r   r    r'     s
   
zTvpIntermediate.__init__r   r   c                 C   s   |  |}| |}|S r   )r   r   )r.   r   r   r   r    rK     s   

zTvpIntermediate.forwardr   r   r   r'   r   r   rK   rL   r   r   r/   r    r     s    r   c                       s8   e Zd Z fddZdejdejdejfddZ  ZS )TvpOutputLayerc                    sB   t    t|j|j| _tj|j|jd| _	t
|j| _d S )Nrv   )r&   r'   r   r   r   rR   r   r   r   r   r   r   r   r   r/   r   r    r'     s   
zTvpOutputLayer.__init__r   input_tensorr   c                 C   s&   |  |}| |}| || }|S r   )r   r   r   )r.   r   r   r   r   r    rK     s   

zTvpOutputLayer.forwardr   r   r   r/   r    r     s    $r   c                       s6   e Zd Z fddZ			ddee fddZ  ZS )TvpEncodeLayerc                    s,   t    t|| _t|| _t|| _d S r   )r&   r'   r   	attentionr   intermediater   outputr   r/   r   r    r'     s   


zTvpEncodeLayer.__init__Nr   c           
      C   sJ   | j ||||d}|d }|dd  }| |}| ||}	|	f| }|S )N)r   r   r   )r   r   r   )
r.   r   r   r   r   self_attention_outputsattention_outputr   intermediate_outputlayer_outputr   r   r    rK     s   

zTvpEncodeLayer.forwardr   )r   r   r   r'   r   r   rK   rL   r   r   r/   r    r     s    	r   c                
       sT   e Zd Z fddZ					d
deej dee dee dee fdd	Z  Z	S )
TvpEncoderc                    s:   t     | _t fddt jD | _d| _d S )Nc                    s   g | ]}t  qS r   )r   )r   _rP   r   r    
<listcomp>  s    z'TvpEncoder.__init__.<locals>.<listcomp>F)	r&   r'   rP   r   
ModuleListrangenum_hidden_layerslayergradient_checkpointingr   r/   r  r    r'     s   
 
zTvpEncoder.__init__Nr   r   output_hidden_statesreturn_dictc                 C   s   |d ur|n| j j}|d ur|n| j j}|d ur|n| j j}d}d}t| jD ]!\}	}
|r2||f }|
||||	 |}|d }|rH||d f }q'|rP||f }|se|f}|r\||f }|rc||f }|S t||rk|nd |rr|dS d dS )Nr   r   r   )last_hidden_stater   r   )rP   r  r   r  	enumerater
  r   )r.   r   r   r   r   r  r  all_hidden_statesall_attentionsilayer_modulelayer_outputsr   r   r   r    rK     s<   	




zTvpEncoder.forward)NNNNN)
r   r   r   r'   r   r   r   r   rK   rL   r   r   r/   r    r    s     	r  c                       r   )	TvpPoolerc                    s*   t    t|j|j| _t | _d S r   )r&   r'   r   r   rR   r   Tanh
activationr   r/   r   r    r'     s   
zTvpPooler.__init__r   r   c                 C   s(   |d d df }|  |}| |}|S )Nr   )r   r  )r.   r   first_token_tensorpooled_outputr   r   r    rK     s   

zTvpPooler.forwardr   r   r   r/   r    r    s    r  c                   @   s.   e Zd ZU eed< dZdZdejfddZ	dS )TvpPreTrainedModelrP   modelTmodulec                 C   s0  t |tjtjfr|jjjd| jjd n>t |tj	r)|j
j  |jjd n*t |tjrGtjj|jddd |j
durFtj|j
d nt |trStj|j t |tjrd|j
durd|j
j  t|d	rptj|j t|d
r|tj|j t|drtj|j t|drtj|j dS dS )zInitialize the weights        )r   stdg      ?fan_outrg   )r   nonlinearityNr   pad_uppad_downpad_left	pad_right)r   r   r   rx   weightdatanormal_rP   initializer_ranger   rW   zero_fill_r[   initkaiming_normal_	constant_TvpModeltext_promptrZ   r!  r"  r#  r$  )r.   r  r   r   r    _init_weights-  s.   





z TvpPreTrainedModel._init_weightsN)
r   r   r   r   r   base_model_prefixsupports_gradient_checkpointingr   Moduler0  r   r   r   r    r  '  s
   
 r  c                       s(   e Zd ZdZ fddZdd Z  ZS )TvpFrameDownPadPrompterz>
    Pad frames extracted from videos only at the bottom.
    c              	      sb   |j dvr	tdt   |j| _|j| _|j| _|j | _ tt	
d|jd|j|jg| _d S )Nr<   replaceremove9`visual_prompter_apply` must be in (add, replace, remove)r   r   )visual_prompter_applyr,   r&   r'   visual_prompt_size	frame_nummax_img_sizer   	Parameterr   randnr"  r   r/   r   r    r'   N  s   


z TvpFrameDownPadPrompter.__init__c                 C   s   | j dkr&tj| j| jg|j|jd}d|| j| j | jd d f< ||9 }| j dkrctj|jd |jd d| j| jg|jd}| j| j }| j	|d d d d d d || jd d f< ||
|j7 }|S )	Nr<   r   r  r7  r   r   r   r   )r9  r   r   r<  r   r   r:  r   rc   r"  to)r.   ri   visual_prompt_maskpromptstart_pointr   r   r    rK   \  s   

*zTvpFrameDownPadPrompter.forwardr   r   r   r/   r    r4  I  s    r4  c                       sN   e Zd ZdZ fddZdejdededejfdd	Zdde	fddZ
  ZS )TvpFramePadPrompterz?
    Pad frames extracted from videos in the surroundings.
    c              
      s   |j dvr	tdt   |j| _|j| _|j | _ |j|jd  | _t	t
d|jd|j|jg| _t	t
d|jd|j|jg| _t	t
d|jd|j|jd  |jg| _t	t
d|jd|j|jd  |jg| _d S )Nr5  r8  r_   r   r   )r9  r,   r&   r'   rk   r<  r:  	base_sizer   r=  r   r>  r!  r"  r#  r$  r   r/   r   r    r'   s  sB   


zTvpFramePadPrompter.__init__rB  rm   rn   r   c                 C   sh   || j  || j  }}|j\}}}}	}
||| ||	|
}tjj|||fddd}||||||}|S )z
        This method allows to interpolate the pre-trained pad weights, to be able to use the model on collection of high
        resolution images (high resolution videos).

        r   Fr   )r<  rc   r   r   re   r   )r.   rB  rm   rn   r   r   batchrk   channelsprompt_heightprompt_widthr   r   r    interpolate_pad_encoding  s   z,TvpFramePadPrompter.interpolate_pad_encodingFrJ  c                 C   s   |r|j d |j d fn| j| jf\}}| jdvr!td| j | jdv r6tj||g|j|jd}||9 }| jdv r~tjd| j	d	| j
| j
|jd
}tj| j|| jgdd}tj| j|| jgd	d}t|d|g }|rv| |||}|||j }|S )Nr   rO   )r<   r7  r6  z$Invalid visual_prompter_apply value )r6  r7  r   )r6  r<   r   r   r?  rb   r   r   )rc   r<  r9  r,   r   r   r   r   r   rk   rE  catr#  r$  r!  r"  r   rJ  r@  )r.   ri   rJ  rm   rn   rA  baserB  r   r   r    rK     s$   



zTvpFramePadPrompter.forwardr   )r   r   r   r   r'   r   r   r   rJ  r   rK   rL   r   r   r/   r    rD  n  s
    &rD  )framedownpadframepadzw
    The bare Tvp Model transformer outputting BaseModelOutputWithPooling object without any specific head on top.
    )custom_introc                       s   e Zd Z fddZdd Zdd Zdd Ze															
ddee	j
 dee	j dee	j
 dee	j dee dee dee defddZ  ZS )r.  c                    s   t  | || _t|| _t|| _t|| _t	|| _
t|| _ttdd|jg| _t|j| _|jtvr?tdt|j || _|   d S )Nr   
   z:`visual_prompter_type` must be in (framedownpad, framepad))r&   r'   rP   rN   vision_modelr   r   ru   visual_embeddingsr  encoderr  poolerr   r=  r   r>  rR   r/  r   r   r   visual_prompter_typeTVP_PROMPTER_CLASSES_MAPPINGr,   visual_prompter	post_initr   r/   r   r    r'     s   





zTvpModel.__init__c                 C   s   | j jS r   r   r   )r.   r   r   r    get_input_embeddings  s   zTvpModel.get_input_embeddingsc                 C   s   || j _d S r   rY  )r.   r   r   r   r    set_input_embeddings  s   zTvpModel.set_input_embeddingsc                 C   s*   |  D ]\}}| jj| j| qdS )zPrunes heads of the model.
        heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base class PreTrainedModel
        N)itemsrS  r
  r   r   )r.   heads_to_pruner
  r   r   r   r    _prune_heads  s   zTvpModel._prune_headsNFr   ri   r   r   r   r  r  r   c	                 C   sR  |dur|n| j j}| | j||d}| j|d}	| j||d}
|durU||
jdd }t	|jd dj
|j|jd}tj|||gd	d
}| || 
|j}| j|	jd d	d	}tj||	|
gdd
}| j||| || j j|||d}|r|jn|d }| |}| |}| |}|s||f|dd  S t|||j|jdS )a  
        Examples:
        ```python
        >>> import torch
        >>> from transformers import AutoConfig, AutoTokenizer, TvpModel

        >>> model = TvpModel.from_pretrained("Jiqing/tiny-random-tvp")

        >>> tokenizer = AutoTokenizer.from_pretrained("Jiqing/tiny-random-tvp")

        >>> pixel_values = torch.rand(1, 1, 3, 448, 448)
        >>> text_inputs = tokenizer("This is an example input", return_tensors="pt")
        >>> output = model(text_inputs.input_ids, pixel_values, text_inputs.attention_mask)
        ```N)rJ  )r   r   r_   r   rP  )r   r   rO   r   r   )r   r   r   r  r  )r  pooler_outputr   r   )rP   r  rQ  rW  r   rR  new_onesrc   r   r   r@  r   r   rK  get_extended_attention_maskr   r/  r   rS  get_head_maskr	  r  rT  r   r	   r   r   )r.   r   ri   r   r   r   r  r  r   text_embedding_outputvisual_embedding_outputvisual_attention_maskpt_maskr/  embedding_outputencoder_outputsr  r  r   r   r    rK     sJ   


zTvpModel.forward)NNNNNNNF)r   r   r   r'   rZ  r[  r^  r   r   r   
LongTensorr   r   rK   rL   r   r   r/   r    r.    s>    	r.  c                       rM   )TvpVideoGroundingHeadc                    sL   t    t|j|jd | _t|jd d| _t | _t	 | _
d S )Nr_   )r&   r'   r   r   rR   layer_0layer_1ReLUactivation_0Sigmoidactivation_1r   r/   r   r    r'   <  s
   

zTvpVideoGroundingHead.__init__c                 C   s$   |  | |}| | |}|S r   )rn  rk  rp  rl  )r.   r_  r   r   r   r    rK   C  s   zTvpVideoGroundingHead.forwardrt   r   r   r/   r    rj  ;  s    rj  zb
    Tvp Model with a video grounding head on top computing IoU, distance, and duration loss.
    c                       s   e Zd Z fddZe									ddeej deej deej dee	ej
  d	eej d
ee dee dee defddZ  ZS )TvpForVideoGroundingc                    s2   t  | || _t|| _t|| _|   d S r   )r&   r'   rP   r.  r  rj  video_grounding_headrX  r   r/   r   r    r'   O  s
   

zTvpForVideoGrounding.__init__NFr   ri   r   rH   r   r   r  r  r   c
              
   C   s   |dur|n| j j}| j||||||||	d}
|
d }| |}d}|durKtg d}|| j |||}|d | j j|d   | j j|d   }|sa|f|
dd  }
|dur_|f|
 }
|
S t	|||
j
|
jd	S )
a  
        labels (`torch.FloatTensor` of shape `(batch_size, 3)`, *optional*):
            The labels contains duration, start time, and end time of the video corresponding to the text.

        Examples:
        ```python
        >>> import torch
        >>> from transformers import AutoConfig, AutoTokenizer, TvpForVideoGrounding

        >>> model = TvpForVideoGrounding.from_pretrained("Jiqing/tiny-random-tvp")

        >>> tokenizer = AutoTokenizer.from_pretrained("Jiqing/tiny-random-tvp")

        >>> pixel_values = torch.rand(1, 1, 3, 448, 448)
        >>> text_inputs = tokenizer("This is an example input", return_tensors="pt")
        >>> output = model(text_inputs.input_ids, pixel_values, text_inputs.attention_mask)
        ```N)r   r   r  r  r   r   r"   r#   r$   r%   r_   )r   r   r   r   )rP   r  r  rr  r!   r@  r   distance_loss_weightduration_loss_weightr   r   r   )r.   r   ri   r   rH   r   r   r  r  r   r   r_  r   r   	criterion	loss_dictr   r   r    rK   W  sF   



zTvpForVideoGrounding.forward)	NNNNNNNNF)r   r   r   r'   r   r   r   ri  r   r   r   r   rK   rL   r   r   r/   r    rq  I  s>    	
rq  )r.  r  rq  )2r   r   dataclassesr   typingr   r   r   activationsr   modeling_layersr   modeling_outputsr   r	   r
   modeling_utilsr   pytorch_utilsr   utilsr   r   utils.backbone_utilsr   configuration_tvpr   
get_loggerr   loggerr   r3  r!   rN   ru   r   r   r   r   r   r  r  r  r4  rD  rV  r.  rj  rq  __all__r   r   r   r    <module>   sZ   
P(q$c6!%[hM