o
    	۷ie                     @   s  d dl Zd dlmZ d dlmZmZmZ d dlZd dl	m
Z
 ddlmZ ddlmZ ddlmZ ddlmZmZ dd	lmZmZ dd
lmZ ddlmZmZmZmZmZ ddlm Z  ddl!m"Z" ddl#m$Z$ ddl%m&Z& ddl'm(Z(m)Z)m*Z*m+Z+m,Z, ddl-m.Z.m/Z/ e0e1Z2	dBde
j3dej4dej4dej4deej4 de5de5fddZ6G dd de&Z7G d d! d!e$Z8eed"d#G d$d% d%eZ9G d&d' d'e
j3Z:G d(d) d)e
j3Z;G d*d+ d+e"Z<e
j=e7d,Z>G d-d. d.eZ?G d/d0 d0e
j3Z@eG d1d2 d2eZAeG d3d4 d4eAZBG d5d6 d6e,ZCdZDG d7d8 d8e
j3ZEG d9d: d:e+ZFG d;d< d<e*ZGG d=d> d>e(ZHG d?d@ d@e)ZIg dAZJdS )C    N)	dataclass)CallableOptionalUnion   )ACT2FN)Cache)GradientCheckpointingLayer)BaseModelOutputBaseModelOutputWithPooling)ALL_ATTENTION_FUNCTIONSPreTrainedModel)Unpack)TransformersKwargsauto_docstringcan_return_tuplelogging	torch_int)check_model_inputs   )CLIPMLP)JanusVisionAttention)LlamaRMSNorm)LlavaCausalLMOutputWithPastLlavaForConditionalGeneration
LlavaModelLlavaModelOutputWithPastLlavaPreTrainedModel   )InternVLConfigInternVLVisionConfig        modulequerykeyvalueattention_maskscalingdropoutc                 K   s   |}|}	t ||dd| }
|d ur+|d d d d d d d |jd f }|
| }
tjj|
dd}
tjj|
|| jd}
t |
|	}|dd	 }||
fS )Nr   r   dim)ptrainingr   )
torchmatmul	transposeshapenn
functionalsoftmaxr(   r.   
contiguous)r"   r#   r$   r%   r&   r'   r(   kwargs
key_statesvalue_statesattn_weightscausal_maskattn_output r=   c/home/ubuntu/vllm_env/lib/python3.10/site-packages/transformers/models/internvl/modular_internvl.pyeager_attention_forward0   s   
&r?   c                   @      e Zd ZdS )InternVLVisionRMSNormN__name__
__module____qualname__r=   r=   r=   r>   rA   K       rA   c                       sH   e Zd Zdef fddZ	d
dejdeej dee	 fdd	Z
  ZS )InternVLVisionAttentionconfigc                    sV   t  | | `d| _|j}|rt| jnt | _	|r$t| j| _
d S t | _
d S )NF)super__init__num_key_value_groups	is_causaluse_qk_normrA   	embed_dimr3   Identityq_normk_norm)selfrH   qk_norm	__class__r=   r>   rJ   P   s   "z InternVLVisionAttention.__init__Nhidden_statesr&   r7   c                 K   s  |  \}}}| |}| |}| |}	| |}| |}|||| j| j	dd}|||| j| j	dd}|	
||| j| j	dd}	t}
| jjdkrXt| jj }
|
| |||	|f| jsddn| j| jdd|\}}|||| j}| |}| |}||fS )Nr   r   eagerr!   F)r(   r'   rL   )sizeq_projk_projv_projrP   rQ   reshape	num_headshead_dimr1   viewr?   rH   _attn_implementationr   r.   attention_dropoutscalerN   projection_layerprojection_dropout)rR   rV   r&   r7   
batch_sizeseq_len_query_statesr8   r9   attention_interfacer<   r:   outputr=   r=   r>   forward[   s:   




	


zInternVLVisionAttention.forwardN)rC   rD   rE   r    rJ   r/   Tensorr   r   r   rk   __classcell__r=   r=   rT   r>   rG   O   s    rG   z7
    Class for outputs of [`InternVLVisionModel`].
    )custom_introc                   @   s   e Zd ZdZdS )$InternVLVisionModelOutputWithPoolingaF  
    pooler_output (`torch.FloatTensor` of shape `(batch_size, hidden_size)`):
        Average of the last layer hidden states of the patch tokens (excluding the *[CLS]* token) if
        *config.use_mean_pooling* is set to True. If set to False, then the final hidden state of the *[CLS]* token
        will be returned.
    N)rC   rD   rE   __doc__r=   r=   r=   r>   rp      s    rp   c                       s6   e Zd ZdZ fddZdejdejfddZ  ZS )InternVLVisionPatchEmbeddingsz
    This class turns `pixel_values` of shape `(batch_size, num_channels, height, width)` into the initial
    `hidden_states` (patch embeddings) of shape `(batch_size, seq_length, hidden_size)` to be consumed by a
    Transformer.
    c                    s   t    |j|j}}|j|j}}|d |d  |d |d   }|d |d  |d |d  f}|| _|| _|| _|| _|| _tj	||||d| _
d S )Nr   r   )kernel_sizestride)rI   rJ   
image_size
patch_sizenum_channelshidden_sizenum_patchespatch_shaper3   Conv2d
projection)rR   rH   ru   rv   rw   rx   ry   rz   rT   r=   r>   rJ      s   
  z&InternVLVisionPatchEmbeddings.__init__pixel_valuesreturnc           	      C   s^   |j \}}}}|| jkrtd| |}|j d |j d }}|ddd}|||ffS )NzeMake sure that the channel dimension of the pixel values match with the one set in the configuration.r   r   r   )r2   rw   
ValueErrorr|   flattenr1   )	rR   r}   re   rw   heightwidth
embeddingspatch_heightpatch_widthr=   r=   r>   rk      s   

z%InternVLVisionPatchEmbeddings.forward)	rC   rD   rE   rq   rJ   r/   rm   rk   rn   r=   r=   rT   r>   rr      s    rr   c                       sl   e Zd ZdZdeddf fddZdejded	edejfd
dZ		ddejde
ej dejfddZ  ZS )InternVLVisionEmbeddingszc
    Construct the CLS token, position and patch embeddings. Optionally, also the mask token.

    rH   r~   Nc                    s   t    ttdd|j| _|jr!ttdd|j| _	nd | _	t
|| _|j| _t|jtjjr8|jn|j|jf| _| jj}|jrUttd|d |j| _nd | _t|j| _d S )Nr   )rI   rJ   r3   	Parameterr/   zerosrx   	cls_tokenuse_mask_token
mask_tokenrr   patch_embeddingsrv   
isinstanceru   collectionsabcIterablery    use_absolute_position_embeddingsposition_embeddingsDropouthidden_dropout_probr(   )rR   rH   ry   rT   r=   r>   rJ      s    


z!InternVLVisionEmbeddings.__init__r   r   r   c                 C   s   |j d d }| jj d d }tj s||kr||kr| jS | jddddf }| jddddf }|j d }|| jd  }	|| jd  }
t|d }|d|||}|dddd}t	j
j||	|
fdd	d
}|dddddd|}tj||fddS )a   
        This method allows to interpolate the pre-trained position encodings, to be able to use the model on higher resolution
        images. This method is also adapted to support torch.jit tracing.

        Adapted from:
        - https://github.com/facebookresearch/dino/blob/de9ee3df6cf39fac952ab558447af1fa1365362a/vision_transformer.py#L174-L194, and
        - https://github.com/facebookresearch/dinov2/blob/e1277af2ba9496fbadf7aec6eba56e8d882d1e35/dinov2/models/vision_transformer.py#L179-L211
        r   Nr*   r         ?r   r   bicubicF)rX   modealign_cornersr+   )r2   r   r/   jit
is_tracingrv   r   r\   permuter3   r4   interpolater_   cat)rR   r   r   r   ry   num_positionsclass_pos_embedpatch_pos_embedr,   
new_height	new_widthsqrt_num_positionsr=   r=   r>   interpolate_pos_encoding   s(   

z1InternVLVisionEmbeddings.interpolate_pos_encodingr}   bool_masked_posc                 C   s   |j \}}}}| |\}\}}| \}	}
}|d ur5| j|	|
d}|d|}|d|  ||  }| j|	dd}tj	||fdd}| j
d urT|| ||| }| |}|||ffS )Nr*   r   r+   )r2   r   rX   r   expand	unsqueezetype_asr   r/   r   r   r   r(   )rR   r}   r   rg   r   r   r   r   r   re   rf   mask_tokensw
cls_tokensr=   r=   r>   rk      s   

z InternVLVisionEmbeddings.forwardrl   )rC   rD   rE   rq   r    rJ   r/   rm   intr   r   
BoolTensorrk   rn   r=   r=   rT   r>   r      s    +r   c                   @   r@   )InternVLVisionMLPNrB   r=   r=   r=   r>   r     rF   r   )
layer_normrms_normc                       sX   e Zd ZdZdeddf fddZdejdee	ej e	ejejf f fdd	Z
  ZS )
InternVLVisionLayerz?This corresponds to the Block class in the timm implementation.rH   r~   Nc                    s   t    |j| _d| _t|| _t|| _t|j	 |j
|jd| _t|j	 |j
|jd| _|j}tj|t|j
 dd| _tj|t|j
 dd| _t|j| _d S )Nr   epsT)requires_grad)rI   rJ   chunk_size_feed_forwardseq_len_dimrG   	attentionr   mlpNORM2FN	norm_typerx   layer_norm_epslayernorm_beforelayernorm_afterlayer_scale_init_valuer3   r   r/   oneslambda_1lambda_2r   r   r(   )rR   rH   init_valuesrT   r=   r>   rJ   "  s   


zInternVLVisionLayer.__init__rV   c                 C   sd   |  | |\}}| j| }|| }| |}| |}| |}| jd ur,| j| }|| }|S rl   )r   r   r   r   r   r(   r   )rR   rV   attention_outputrg   layer_outputr=   r=   r>   rk   1  s   





zInternVLVisionLayer.forward)rC   rD   rE   rq   r    rJ   r/   rm   r   tuplerk   rn   r=   r=   rT   r>   r     s    r   c                       sB   e Zd Zdeddf fddZdejdeee	f fddZ
  ZS )	InternVLVisionEncoderrH   r~   Nc                    s:   t     | _t fddt jD | _d| _d S )Nc                    s   g | ]}t  qS r=   )r   ).0irH   r=   r>   
<listcomp>Q  s    z2InternVLVisionEncoder.__init__.<locals>.<listcomp>F)	rI   rJ   rH   r3   
ModuleListrangenum_hidden_layerslayergradient_checkpointingrR   rH   rT   r   r>   rJ   N  s   
 
zInternVLVisionEncoder.__init__rV   c                 C   s   | j D ]}||}qt|dS )N)last_hidden_state)r   r
   )rR   rV   layer_moduler=   r=   r>   rk   T  s
   

zInternVLVisionEncoder.forward)rC   rD   rE   r    rJ   r/   rm   r   r   r
   rk   rn   r=   r=   rT   r>   r   M  s    
r   c                       sR   e Zd ZU eed< dZdZdZdgZdZ	dZ
dZdZeedZ fddZ  ZS )	InternVLVisionPreTrainedModelrH   internvl_visionr}   Tr   )rV   
attentionsc                    s   t  | t|tr+|jj  |jdur|jj  |jdur)|jj  dS dS t|t	rD|j
j| jj |jj| jj dS dS )zInitialize the weightsN)rI   _init_weightsr   r   r   datazero_r   r   r   r   fill_rH   r   r   )rR   r"   rT   r=   r>   r   q  s   



z+InternVLVisionPreTrainedModel._init_weights)rC   rD   rE   r    __annotations__base_model_prefixmain_input_namesupports_gradient_checkpointing_no_split_modules_supports_sdpa_supports_flash_attn_supports_flex_attn_supports_attention_backendr   rG   _can_record_outputsr   rn   r=   r=   rT   r>   r   `  s   
 r   c                       sf   e Zd Zdeddf fddZdd Zedd	e	dd
ej	de
ej deeef fddZ  ZS )InternVLVisionModelrH   r~   Nc                    sT   t  | || _t|| _t|| _|jrt	 ntj
|j|jd| _|   d S )Nr   )rI   rJ   rH   r   r   r   encoderuse_mean_poolingr3   rO   	LayerNormrx   r   	layernorm	post_initr   rT   r=   r>   rJ     s   

zInternVLVisionModel.__init__c                 C   s   | j jS rl   )r   r   )rR   r=   r=   r>   get_input_embeddings  s   z(InternVLVisionModel.get_input_embeddingsF)tie_last_hidden_statesr}   r   c                 C   s@   | j ||d\}}| |}|d }| |}t||j|jdS )z
        bool_masked_pos (`torch.BoolTensor` of shape `(batch_size, num_patches)`, *optional*):
            Boolean masked positions. Indicates which patches are masked (1) and which aren't (0).
        )r   r   )r   rV   r   )r   r   r   rp   rV   r   )rR   r}   r   embedding_outputrg   encoder_outputssequence_outputr=   r=   r>   rk     s   

zInternVLVisionModel.forwardrl   )rC   rD   rE   r    rJ   r   r   r   r/   rm   r   r   r   r   rp   rk   rn   r=   r=   rT   r>   r     s    
r   c                   @   r@   )InternVLPreTrainedModelNrB   r=   r=   r=   r>   r     rF   r   c                       s*   e Zd Zdef fddZdd Z  ZS )InternVLMultiModalProjectorrH   c                    sz   t    t|jjtd|j d  | _t	|jjtd|j d  |j
j| _t|j | _t	|j
j|j
j| _d S )Nr   r   )rI   rJ   r3   r   vision_configrx   r   downsample_ratior   Lineartext_configlinear_1r   projector_hidden_actactlinear_2r   rT   r=   r>   rJ     s   
"z$InternVLMultiModalProjector.__init__c                 C   s,   |  |}| |}| |}| |}|S rl   )r   r   r   r   )rR   image_featuresrV   r=   r=   r>   rk     s
   



z#InternVLMultiModalProjector.forward)rC   rD   rE   r   rJ   rk   rn   r=   r=   rT   r>   r     s    	r   c                   @   r@   )InternVLModelOutputWithPastNrB   r=   r=   r=   r>   r     rF   r   c                   @   s   e Zd ZddejdefddZ		ddejdee	e
ee
 f  d	ee fd
dZee									ddeej deej deej deej dee deej dee	e
ee
 f  d	ee deej dee de	eef fddZdS )InternVLModelr   vision_featuresscale_factorc              	   C   s   |  \}}}}|| dks|| dkrtd|||t|| t|| }|dddd }||t|| t|| t||d  }|dddd }|S )a&  Perform pixel shuffle downsampling on vision features.

        Args:
            vision_features (`torch.Tensor`):
                Input tensor of shape (batch_size, width, height, channels).
            scale_factor (`float`, *optional*, defaults to `0.5`):
                Factor by which to downsample. Default is 0.5, which halves the dimensions.

        Returns:
            vision_features (`torch.Tensor`):
                Downsampled tensor of shape (batch_size, height*scale_factor, width*scale_factor, channels/(scale_factor^2)).
        r   zKHeight and width must be divisible by scale_factor for proper downsampling.r   r   r   )rX   r   r_   r   r   r6   )rR   r   r   re   r   r   channelsr=   r=   r>   pixel_shuffle  s   $zInternVLModel.pixel_shuffleNr}   vision_feature_layervision_feature_select_strategyc           
      K   s   |dur|n| j j}|dur|n| j j}|j| jd}| j j}|dkr+| j|dj}n	| j|dj	| }|dkrE|ddddddf }|j
d }t|d }|j
d }	||	||d}| j||d	}||	d|j
d }| |}|S )
a%  
        Obtains image last hidden states from the vision tower and apply multimodal projection.

        Args:
            pixel_values (`torch.FloatTensor]` of shape `(batch_size, channels, height, width)`)
               The tensors corresponding to the input images.
            vision_feature_layer (`int` or `list[int]`):
                Layer index or list of layer indices to extract features from.
        Returns:
            vision_features (`torch.Tensor`): Image feature tensor of shape `(num_images, image_length, embed_dim)`.
        N)dtyper*   )r}   defaultr   r   r   )r   )rH   r   r  tor  r   vision_towerr   vision_modelrV   r2   r   r\   r   multi_modal_projector)
rR   r}   r   r  r7   r   r   r   feature_sizere   r=   r=   r>   get_image_features  s*   


z InternVLModel.get_image_features	input_idsr&   position_idspast_key_valuesinputs_embedscache_positionr7   r~   c
                 K   s   |d ur|n| j j}|d ur|n| j j}|d u |d uA r td|d u r*|  |}|d urL| j|||d}||j|j}| j	|||d}|
||}| jd|||||	d|
}t|j|j|j|j|d urk|dS d dS )Nz:You must specify exactly one of input_ids or inputs_embeds)r}   r   r  )r  r   )r&   r  r  r  r  )r   r  rV   r   image_hidden_statesr=   )rH   r   r  r   r   r	  r  devicer  get_placeholder_maskmasked_scatterlanguage_modelr   r   r  rV   r   )rR   r
  r}   r&   r  r  r  r   r  r  r7   r   special_image_maskoutputsr=   r=   r>   rk   !  sN   	
zInternVLModel.forward)r   )NN)	NNNNNNNNN)rC   rD   rE   r/   rm   floatr   FloatTensorr   r   r   liststrr	  r   r   
LongTensorr   r   r   r   r   rk   r=   r=   r=   r>   r     sZ    &
6	

r   c                   @   r@   )InternVLCausalLMOutputWithPastNrB   r=   r=   r=   r>   r  ]  rF   r  c                       s   e Zd Z fddZ  ZS ) InternVLForConditionalGenerationc                     s   t  jdi |  dS )ac  
        Example:

        ```python
        >>> import torch
        >>> from transformers import AutoProcessor, AutoModelForImageTextToText

        >>> torch_device = "cuda"
        >>> processor = AutoProcessor.from_pretrained("OpenGVLab/InternVL3-1B-hf")
        >>> model = AutoModelForImageTextToText.from_pretrained(
        ...     "OpenGVLab/InternVL3-1B-hf", dtype=torch.bfloat16, device_map=torch_device
        ... )

        >>> messages = [
        ...     {
        ...         "role": "user",
        ...         "content": [
        ...             {
        ...                 "type": "image",
        ...                 "url": "https://cdn.britannica.com/61/93061-050-99147DCE/Statue-of-Liberty-Island-New-York-Bay.jpg",
        ...             },
        ...             {
        ...                 "type": "image",
        ...                 "url": "https://thumbs.dreamstime.com/b/golden-gate-bridge-san-francisco-purple-flowers-california-echium-candicans-36805947.jpg",
        ...             },
        ...             {"type": "text", "text": "These images depict two different landmarks. Can you identify them?"},
        ...         ],
        ...     },
        ... ]

        >>> inputs = processor.apply_chat_template(messages, add_generation_prompt=True, tokenize=True, return_dict=True, return_tensors="pt").to(torch_device)
        >>> generate_ids = model.generate(**inputs, max_new_tokens=200)
        >>> print(processor.decode(generate_ids[0, inputs["input_ids"].shape[1] :], skip_special_tokens=True))
        The images depict the Statue of Liberty and the Golden Gate Bridge.
        ```Nr=   )rI   rk   )super_kwargsrT   r=   r>   rk   b  s   $z(InternVLForConditionalGeneration.forward)rC   rD   rE   rk   rn   r=   r=   rT   r>   r  a  s    r  )r   r   r   r   r  )r!   )Kcollections.abcr   dataclassesr   typingr   r   r   r/   torch.nnr3   activationsr   cache_utilsr   modeling_layersr	   modeling_outputsr
   r   modeling_utilsr   r   processing_utilsr   utilsr   r   r   r   r   utils.genericr   clip.modeling_clipr   janus.modeling_janusr   llama.modeling_llamar   llava.modeling_llavar   r   r   r   r   configuration_internvlr   r    
get_loggerrC   loggerModulerm   r  r?   rA   rG   rp   rr   r   r   r   r   r   r   r   r   r   INTERNVL_INPUTS_DOCSTRINGr   r   r   r  r  __all__r=   r=   r=   r>   <module>   sz   


6	&^.* (