o
    i5                  
   @   s   d dl Z d dlmZmZ d dlmZ d dlmZmZ d dl	Z	d dl	m
Z
 d dlmZmZmZ d dlmZ d dlmZ d d	lmZmZ d d
lmZ d dlmZmZ d dlmZ d dlmZ d dlm Z  d dl!m"Z" d dl#m$Z$ d dl%m&Z&m'Z'm(Z(m)Z) d dl*m+Z+ d dl,m-Z- d dl.m/Z/ d dl0m1Z1 d dl2m3Z3 d dl4m5Z5 d dl6m7Z7 d dl8m9Z9m:Z:m;Z; d dl<m=Z=m>Z>m?Z? d dl@mAZAmBZBmCZCmDZDmEZEmFZFmGZG d dlHmIZI d dlJmKZKmLZL ddlMmNZNmOZOmPZPmQZQmRZRmSZS dd lTmUZU dd!lVmWZWmXZXmYZY dd"lZm[Z[m\Z\ G d#d$ d$eKZ]G d%d& d&e
j^Z_G d'd( d(e
j^Z`d)d* ZaG d+d, d,e
j^ZbG d-d. d.e
j^ZcG d/d0 d0e
j^ZdG d1d2 d2e
j^ZeG d3d4 d4e
j^Zfed5d ie5d6G d7d8 d8e
j^ZgG d9d: d:eCZhG d;d< d<eBeh ZiG d=d> d>eAeh Zje7jkeiehejd?G d@dA dAe
j^eReSeNePeQZldS )B    N)IterableMapping)tee)	AnnotatedLiteral)nn)BatchFeatureLlama4ConfigLlama4VisionConfig)SizeDict)Llama4Processor)find_supported_resolutionsget_best_fit)support_torch_compile)
VllmConfigset_current_vllm_config)BaseDummyOptions)$get_tensor_model_parallel_world_size)set_forward_context)MMEncoderAttention)FusedMoE)ColumnParallelLinearQKVParallelLinearReplicatedLinearRowParallelLinear)QuantizationConfig)get_rope)initialize_model)default_weight_loader)MultiModelKeys)should_torch_compile_mm_vit)MULTIMODAL_REGISTRY)MultiModalDataDictMultiModalFieldConfigMultiModalKwargsItems)ImageProcessorItems	ImageSizeMultiModalDataItems)BaseDummyInputsBuilderBaseMultiModalProcessorBaseProcessingInfoInputProcessingContextPromptReplacementPromptUpdatePromptUpdateDetails)IntermediateTensors)TensorSchemaTensorShape   )MixtureOfExpertsMultiModalEmbeddingsSupportsEagle3SupportsLoRASupportsMultiModal
SupportsPP)Llama4ForCausalLM)AutoWeightsLoaderStageMissingLayermaybe_prefix)is_vit_use_data_parallelrun_dp_sharded_vision_modelc                   @   sn   e Zd ZU dZdZed ed< eej	e
ddddf ed< eej	e
df ed< 	 eej	e
dd	f ed
< dS )Llama4ImagePatchInputsz
    Dimensions:
        - batch_size: Batch size
        - total_num_chunks: Batch size * number of chunks
        - num_channels: Number of channels
        - image_size: Size of each image
    pixel_valuestypetotal_num_chunksnum_channels
image_size
batch_sizepatches_per_image   aspect_ratiosN)__name__
__module____qualname____doc__rA   r   __annotations__r   torchTensorr1    rP   rP   X/home/ubuntu/vllm_env/lib/python3.10/site-packages/vllm/model_executor/models/mllama4.pyr?   U   s   
 r?   c                       sZ   e Zd Z		ddededededededB d	ef fd
dZdej	dej	fddZ
  ZS )Llama4VisionMLPN 
input_sizeintermediate_sizeoutput_sizebiasoutput_activationquant_configprefixc           	         s\   t    t }t||||| d|d| _t||||| d|d| _t | _	|| _
d S )Nz.fc1)rT   rV   rW   rY   rZ   
disable_tpz.fc2)super__init__r=   r   fc1r   fc2r   GELUactivation_fnrX   )	selfrT   rU   rV   rW   rX   rY   rZ   use_data_parallel	__class__rP   rQ   r]   v   s(   



zLlama4VisionMLP.__init__hidden_statesreturnc                 C   s:   |  |\}}| |}| |\}}| jr| |S |S N)r^   ra   r_   rX   rb   rf   _rP   rP   rQ   forward   s   

zLlama4VisionMLP.forwardNrS   )rI   rJ   rK   intboolr   strr]   rN   rO   rk   __classcell__rP   rP   rd   rQ   rR   u   s&    rR   c                       s8   e Zd Z		d	dedB def fddZdd Z  ZS )
Llama4MultiModalProjectorNrS   rY   rZ   c                    s2   t    t|jj|jjd|d| dd| _d S )NFTz	.linear_1)rT   rV   rW   rY   gather_outputrZ   )r\   r]   r   vision_configvision_output_dimtext_confighidden_sizelinear_1rb   configrY   rZ   rd   rP   rQ   r]      s   
z"Llama4MultiModalProjector.__init__c                 C   s   |  |\}}|S rh   )rw   )rb   image_featuresrf   rj   rP   rP   rQ   rk      s   z!Llama4MultiModalProjector.forwardrl   )rI   rJ   rK   r   ro   r]   rk   rp   rP   rP   rd   rQ   rq      s    rq   c           
   	   C   s   | j \}}}tt|}| |||d} |  \}}}}| ||t|| t|| }|dddd }||t|| t|| t||d  }|dddd }||d|j d }	|	S )Nr   rG   r2      )shaperm   mathsqrtviewsizepermute
contiguous)
input_tensorshuffle_ratiorE   num_patcheschannels
patch_sizeheightwidthreshaped_tensoroutput_tensorrP   rP   rQ   pixel_shuffle   s"   

r   c                       sF   e Zd Z		ddedB def fddZdejdejfd	d
Z  Z	S )Llama4VisionPixelShuffleMLPNrS   rY   rZ   c              	      sZ   t    |j| _t|j| jd  | _|j| _t|j	|j|j|j
d|| dd| _d S )NrG   T.mlprT   rU   rV   rW   rX   rY   rZ   )r\   r]   pixel_shuffle_ratiorm   projector_input_dim	inner_dimprojector_output_dim
output_dimrR   rU   multi_modal_projector_biasmlprx   rd   rP   rQ   r]      s   
z$Llama4VisionPixelShuffleMLP.__init__encoded_patchesrg   c                 C   s   t || j}| |S rh   )r   r   r   )rb   r   rP   rP   rQ   rk      s   
z#Llama4VisionPixelShuffleMLP.forwardrl   )
rI   rJ   rK   r   ro   r]   rN   rO   rk   rp   rP   rP   rd   rQ   r      s    r   c                       H   e Zd Z	ddededB def fddZdejd	ejfd
dZ	  Z
S )Llama4VisionAttentionrS   ry   rY   NrZ   c                    s  t    || _t }|rdnt | _|j| _|j| _	|j| j	 | _
| j	| j dks,J | j	| j | _| j| j
 | _| j| j
 | _|j| _| j
d | _t| j| j
| j| dd| _|rt| j| jd| j  d|| dd	| _t| j	| j
 | jd|| d
d	| _n$t| j| j
| j	d|| dd	| _t| j	| j
 | jdd|| d
d| _d|jd dd}t| j
|j|j d |dtjd| _d S )Nr2   r         z.attnrZ   rG   Tz	.qkv_proj)rW   rY   rZ   z.o_proj)rW   input_is_parallelrY   rZ   mllama4
rope_thetag      ?)	rope_typer   partial_rotary_factorF)	head_sizemax_positionrope_parametersis_neox_styledtype)r\   r]   ry   r=   r   tp_sizerv   	embed_dimnum_attention_heads	num_headshead_dimnum_local_headsq_sizekv_sizeattention_dropoutscalingr   attnr   qkv_projo_projr   r   r   r   rD   r   rN   	complex64
rotary_emb)rb   ry   rY   rZ   rc   r   rd   rP   rQ   r]      s|   




zLlama4VisionAttention.__init__rf   rg   c           	      C   s   |j d d }| |\}}|j| j| j| jgdd\}}}||j d |j d | j| j}||j d |j d | j| j}| ||\}}||j d |j d d}||j d |j d d}| 	|||}|j
g |dR   }| |\}}|S )Nr{   dimr   r2   )r}   r   splitr   r   r   r   r   r   r   reshaper   r   )	rb   rf   input_shapeqkvrj   qkvattn_outputrP   rP   rQ   rk   7  s      zLlama4VisionAttention.forwardrS   rI   rJ   rK   r
   r   ro   r]   rN   rO   rk   rp   rP   rP   rd   rQ   r      s    Mr   c                       sB   e Zd Z	ddededB def fddZdejfd	d
Z	  Z
S )Llama4VisionEncoderLayerrS   ry   rY   NrZ   c              	      s|   t    |j| _|j| _|j| _t||| dd| _t|j|j|jdd|| dd| _t	
|j| _t	
|j| _d S )Nz
.self_attnrY   rZ   TFr   r   )r\   r]   rv   r   rU   r   	self_attnrR   r   r   	LayerNorminput_layernormpost_attention_layernormrx   rd   rP   rQ   r]   O  s(   

z!Llama4VisionEncoderLayer.__init__hidden_statec                 C   sJ   |}|  |}| |}|| }|}| |}| |}|| }|f}|S rh   )r   r   r   r   )rb   r   residualoutputsrP   rP   rQ   rk   l  s   



z Llama4VisionEncoderLayer.forwardr   r   rP   rP   rd   rQ   r   N  s    r   c                       r   )Llama4VisionEncoderrS   ry   rY   NrZ   c                    s8   t     | _t fddt jD | _d S )Nc                    s$   g | ]}t   d | dqS )z.layers.r   )r   ).0	layer_idxry   rZ   rY   rP   rQ   
<listcomp>  s    z0Llama4VisionEncoder.__init__.<locals>.<listcomp>)r\   r]   ry   r   
ModuleListrangenum_hidden_layerslayersrx   rd   r   rQ   r]     s   

zLlama4VisionEncoder.__init__rf   rg   c                 C   s    | j D ]
}||}|d }q|S )aR  
        Args:
            hidden_states: Input tensor of shape
                (batch_size, sequence_length, hidden_size).
                Hidden states from the model embeddings, representing
                the input tokens.
                associated vectors than the model's internal embedding
                lookup matrix.
        r   )r   )rb   rf   encoder_layerlayer_outputsrP   rP   rQ   rk     s   

zLlama4VisionEncoder.forwardr   r   rP   rP   rd   rQ   r     s    r   c                       J   e Zd Z		ddededB def fddZdejd	ejfd
dZ	  Z
S )Llama4UnfoldConvolutionNrS   ry   rY   rZ   c              	      sr   t    |j}t|tr||f}tjj||jd| _t	 }t
|j|d  |d  |jdd|| d|d| _d S )N)kernel_sizestrider   r2   FTz.linear)rT   rV   rW   rr   rY   rZ   r[   )r\   r]   r   
isinstancerm   rN   r   Unfoldunfoldr=   r   rC   rv   linear)rb   ry   rY   rZ   r   rc   rd   rP   rQ   r]     s   

z Llama4UnfoldConvolution.__init__rf   rg   c                 C   s*   |  |}|ddd}| |\}}|S )Nr   rG   r2   )r   r   r   ri   rP   rP   rQ   rk     s   
zLlama4UnfoldConvolution.forwardrl   r   rP   rP   rd   rQ   r     s    r   images_flattened)dynamic_arg_dims	enable_ifc                       r   )Llama4VisionModelNrS   ry   rY   rZ   c                    s   t    || _|j| _|j| _|j| _|j| _| j| j d d | _|jd | _t	||| dd| _
t| jt| j | _t| jt| j| j | _tj| jdd| _tj| jdd| _t||| dd| _t||| d	d
| _d S )NrG   r2   r   z.patch_embeddingr   gh㈵>)epsz.modelz.vision_adapterr   )r\   r]   ry   rD   r   rv   rC   r   scaler   patch_embeddingr   	ParameterrN   randnclass_embeddingpositional_embedding_vlmr   layernorm_prelayernorm_postr   modelr   vision_adapterrx   rd   rP   rQ   r]     s:   
zLlama4VisionModel.__init__r   rg   c                 C   s   |  |}|j\}}}| j|jd d|jd }tj||gdd}|d7 }||d||}| jj|j	|j
d}|| }| |}||d|}| |}| |}|d d d dd d f }| |}|S )Nr   r2   r{   r   )r   device)r   r}   r   expandrN   catr   r   tor   r   r   r   r   r   r   )rb   r   r   	num_tilesr   
hidden_dimr   positional_embeddingrP   rP   rQ   rk     s0   




zLlama4VisionModel.forwardrl   r   rP   rP   rd   rQ   r     s    ,r   c                       s   e Zd Zdeddf fddZdefddZdedefd	d
Z	de
eedB f fddZededefddZdefddZdefddZ  ZS )Mllama4ProcessingInfoctxrg   Nc                    s   t  | d S rh   )r\   r]   )rb   r   rd   rP   rQ   r]   "  s   zMllama4ProcessingInfo.__init__c                 C   s   | j tS rh   )r   get_hf_configr	   rb   rP   rP   rQ   r   %  s   z#Mllama4ProcessingInfo.get_hf_configkwargsc                 K   s    | j jtfd|ddi|S )Nuse_fastT)r   get_hf_processorr   pop)rb   r   rP   rP   rQ   r   (  s   
z&Mllama4ProcessingInfo.get_hf_processorc                 C   s   dd iS )NimagerP   r   rP   rP   rQ   get_supported_mm_limits-  s   z-Mllama4ProcessingInfo.get_supported_mm_limitsrs   c                 C   sX   | j }| j}|| dksJ d| dd|  ttd| jd  }|| d | S )Nr   zchunk size z should be multiple of zpatch_size g      ?rG   )rD   r   rm   roundr   )rs   rD   r   ds_ratiorP   rP   rQ   get_patch_per_chunk2  s   

z)Mllama4ProcessingInfo.get_patch_per_chunkc                 C   s   |   j}|jS rh   )r   image_processormax_patches)rb   r  rP   rP   rQ   get_max_num_tiles?  s   
z'Mllama4ProcessingInfo.get_max_num_tilesc                 C   s$   |   j}|j}t|  | |dS )Nr   r   )r   rs   rD   r&   r  )rb   rs   rD   rP   rP   rQ   !get_image_size_with_most_featuresC  s   
z7Mllama4ProcessingInfo.get_image_size_with_most_features)rI   rJ   rK   r+   r]   r	   r   objectr   r   r   ro   rm   r   staticmethodr
   r   r  r&   r  rp   rP   rP   rd   rQ   r   !  s    r   c                
       s   e Zd Zdedeeef deeef deeef def
 fddZded	eeef deeef fd
dZ	de
d	eeef dedee fddZ  ZS )Mllama4MultiModalProcessorpromptmm_data	mm_kwargs
tok_kwargsrg   c                    s  | j  }|d u r||ddS t j||||d}| j jdi |}|j | j  j}|dd urd|v s:J d|d }	| j j	d|	idd}
|

dt}|jt| j  td	d
 fdd|D }fdd|D }dd |D }t||d< t||d< |S )NF)add_special_tokens)r	  r
  r  r  r@   imagesz=images expected to be in mm_data when pixel_values is presentr   )validater  )max_num_chunksr   c                    s2   g | ]}t |jd  |jd ft jdqS )r2   r   )resize_to_max_canvas)r   r   rN   tensorr  )r   r   )r  possible_resolutionsrP   rQ   r   o  s    zAMllama4MultiModalProcessor._call_hf_processor.<locals>.<listcomp>c                    s$   g | ]}|d    |d   fqS r   r2   rP   )r   rD   )	tile_sizerP   rQ   r   x  s    c                 S   s,   g | ]\}}|| d krd nd ||  qS )r2   rP   )r   r_hr_wrP   rP   rQ   r   |  s     rH   rF   rP   )infoget_tokenizerr\   _call_hf_processorr   r  r   rs   getparse_mm_data	get_itemsr%   rD   r   r  r   rN   r  )rb   r	  r
  r  r  	tokenizerprocessed_outputs	processorrs   r  mm_itemsparsed_imagesbest_fit_sizesrH   rF   rd   )r  r  r  rQ   r  K  sH   



	z-Mllama4MultiModalProcessor._call_hf_processor	hf_inputshf_processor_mm_kwargsc                 C   s4   | dtd}ttd|tdtddS )NrF   r   r   )r@   rF   rH   )r  rN   emptydictr#   flat_from_sizesbatched)rb   r$  r%  rF   rP   rP   rQ   _get_mm_fields_config  s   z0Mllama4MultiModalProcessor._get_mm_fields_configr!  out_mm_kwargsc                    sb   | j  }|j}| j || j jdi |  j} jdtf fdd}td||dgS )Nitem_idxc                    s0   d |  }|d j } j|d}t|S )Nr   rH   )aspect_rationum_patches_per_chunk)data_prompt_split_imager.   select_text)r,  out_itemr-  replhf_processorimg_patch_tokenr.  r+  rP   rQ   get_replacement  s   
zGMllama4MultiModalProcessor._get_prompt_updates.<locals>.get_replacementr   )modalitytargetreplacementrP   )	r  r   rs   r   r   image_tokenr6  rm   r,   )rb   r!  r%  r+  ry   rs   r;  r7  rP   r4  rQ   _get_prompt_updates  s   
z.Mllama4MultiModalProcessor._get_prompt_updates)rI   rJ   rK   ro   r   r  r   r  r#   r*  r'   r$   listr-   r<  rp   rP   rP   rd   rQ   r  J  s8    


:



r  c                	   @   sX   e Zd Zdeeef defddZ	d
dedeeef deeef dB defdd	Z	dS )Mllama4DummyInputsBuilder	mm_countsrg   c                 C   s$   | dd}| j }|j}|| S )Nr   r   )r  r  r   fake_image_token)rb   r?  
num_imagesr   r;  rP   rP   rQ   get_dummy_text  s   
z(Mllama4DummyInputsBuilder.get_dummy_textNseq_len
mm_optionsc                 C   sB   | dd}| j \}}|r| dnd }d| j||||diS )Nr   r   )r   r   rA  	overrides)r  r  r  _get_dummy_images)rb   rC  r?  rD  rA  target_widthtarget_heightimage_overridesrP   rP   rQ   get_dummy_mm_data  s   z+Mllama4DummyInputsBuilder.get_dummy_mm_datarh   )
rI   rJ   rK   r   ro   rm   rB  r   r"   rJ  rP   rP   rP   rQ   r>    s    
r>  )r  dummy_inputsc                       s  e Zd Zg dddgdZdZedededed	B fd
dZddde	def fddZ
deedf dd	fddZdeedf fddZdejdejdejfddZdedefdd Zd!eded	B fd"d#Zd$edefd%d&Zdefd'd(Z				dId)ejd	B d*ejd+ed	B d,ejd	B d!edejeB fd-d.Zd/ejdejd	B fd0d1Zd2eeeejf  dedeeeeejf  eeeejf  f fd3d4Zd2eeeejf  deeeejf  fd5d6Zd7edefd8d9Zd2eeeejf  deeeeejf  eeeejf  f fd:d;Z d2eeeejf  d<e!deeeeejf  e"e f fd=d>Z#d?eeeejf  d<e!d@ede"e fdAdBZ$deeeeeef  fdCdDZ%d2eeeejf  de"e fdEdFZ&de'fdGdHZ(  Z)S )JLlama4ForConditionalGeneration)q_projk_projv_proj	gate_projup_proj)r   gate_up_projTr8  irg   Nc                 C   s   | drdS td)Nr   z	<|image|>z Only image modality is supported)
startswith
ValueError)clsr8  rS  rP   rP   rQ   get_placeholder_str  s   
z2Llama4ForConditionalGeneration.get_placeholder_strrS   r   vllm_configrZ   c             
      s  t    |jj}|j}|jj}|jdk| _|| _|| _	|| _|| _| 
|dP ddlm} t|* |ddd t|jd t|dd	| _W d    n1 sQw   Y  W d    n1 s`w   Y  t| j	d t|d
d	| _W d    n1 s{w   Y  | | t||jdgt|dtd| _W d    n1 sw   Y  | jj| _d| _| jj| _| jj| _| jj| _| jj| _| jj | _ | jj!| _!| jj"| _"t#| j"| _$d S )Nr/  r   r   )set_model_tagr   T)
is_encodervision_model)ry   rY   rZ   multi_modal_projectorLlamaForCausalLMlanguage_model)rX  rZ   model_classr2   )%r\   r]   model_config	hf_configrY   multimodal_configmm_encoder_tp_moderc   rX  ry   _mark_tower_modelvllm.compilation.backendsrY  r   r   rs   r<   r[  rq   r\  _mark_language_modelr   with_hf_configru   r9   r^  make_empty_intermediate_tensorsnum_expert_groupsnum_logical_expertsnum_physical_expertsnum_local_physical_expertsnum_routed_expertsnum_shared_expertsnum_redundant_experts
moe_layerslennum_moe_layers)rb   rX  rZ   ry   rY   rb  rY  rd   rP   rQ   r]     sb   


 










z'Llama4ForConditionalGeneration.__init__r   .c                 C   s    t | jdsJ | j| dS )zBSet which layers should output auxiliary hidden states for EAGLE3.set_aux_hidden_state_layersN)hasattrr^  rs  )rb   r   rP   rP   rQ   rs  &  s   z:Llama4ForConditionalGeneration.set_aux_hidden_state_layersc                 C   s   t | jdsJ | j S )zGet the layer indices for auxiliary hidden state outputs.

        Note: The GPU model runner will override this with layers from
        the speculative config if available, providing dynamic configuration.
        "get_eagle3_aux_hidden_state_layers)rt  r^  ru  r   rP   rP   rQ   ru  ,  s   
zALlama4ForConditionalGeneration.get_eagle3_aux_hidden_state_layersexpert_load_viewlogical_to_physical_maplogical_replica_countc                 C   s   | j ||| | j j| _d S rh   )r^  set_eplb_stateexpert_weights)rb   rv  rw  rx  rP   rP   rQ   ry  6  s   z-Llama4ForConditionalGeneration.set_eplb_staterk  rl  c                 C   s   | j || d S rh   )r^   update_physical_experts_metadata)rb   rk  rl  rP   rP   rQ   r{  A  s   z?Llama4ForConditionalGeneration.update_physical_experts_metadatar   c                 K   s<   | dd }|d u rd S | d}| d}td|||dS )Nr@   rF   rH   )rA   r@   rF   rH   )r   r?   )rb   r   r@   rF   rH   rP   rP   rQ   _parse_and_validate_image_inputH  s   

z>Llama4ForConditionalGeneration._parse_and_validate_image_inputimage_inputc                 C   sd   | j r| jsJ |d }|d  }| jrt|| j }n|  |}| |}dd |j|ddD S )Nr@   rF   c                 S   s   g | ]}| d dqS r  )flatten)r   imgrP   rP   rQ   r   k  s    
zGLlama4ForConditionalGeneration._process_image_input.<locals>.<listcomp>r   r   )r[  r\  tolistrc   r>   r   )rb   r}  r@   rF   vision_embeddings_flatrP   rP   rQ   _process_image_inputZ  s   

z3Llama4ForConditionalGeneration._process_image_inputc                 K   sV   | j di |}|d u rg S td | j | |W  d    S 1 s$w   Y  d S )NrP   )r|  r   rX  r  )rb   r   r}  rP   rP   rQ   embed_multimodalp  s   
$z/Llama4ForConditionalGeneration.embed_multimodal	input_ids	positionsintermediate_tensorsinputs_embedsc                 K   s   |d urd }|  ||||S rh   )r^  )rb   r  r  r  r  r   rP   rP   rQ   rk   z  s
   z&Llama4ForConditionalGeneration.forwardrf   c                 C   s   | j |S rh   )r^  compute_logits)rb   rf   rP   rP   rQ   r    s   z-Llama4ForConditionalGeneration.compute_logitsweightsc                    s^   t |d\dttttjf  f fdd}dttttjf  f fdd}| | fS )NrG   rg   c                  3   s(    D ]\} }|   r| |fV  qd S rh   rT  namer/  )rZ   weights1rP   rQ   get_prefix_weights     

zKLlama4ForConditionalGeneration.separate_weights.<locals>.get_prefix_weightsc                  3   s(    D ]\} }|   s| |fV  qd S rh   r  r  )rZ   weights2rP   rQ   get_other_weights  r  zJLlama4ForConditionalGeneration.separate_weights.<locals>.get_other_weights)r   r   tuplero   rN   rO   )rb   r  rZ   r  r  rP   )rZ   r  r  rQ   separate_weights  s   ""z/Llama4ForConditionalGeneration.separate_weightsc                 c   s    dddd}i }|D ]/\}}|  D ]!\}}||vrq||d}||vr-d gd ||< ||| |<  n||fV  q|  D ]\}	}
tj|
dd}|	|fV  q?d S )Nr   r2   rG   ).self_attn.q_proj.self_attn.k_proj.self_attn.v_proj.self_attn.qkv_projr|   r   )itemsreplacerN   r   )rb   r  qkv_idx_mappingsqkv_weightsr  loaded_weightweight_nameidxnew_namekeyweight
qkv_weightrP   rP   rQ   _consolidate_qkv_weights  s*   
z7Llama4ForConditionalGeneration._consolidate_qkv_weightsr  c                 C   s   | ds
| drr| dr|dddn|}d|v rNd|v s$d|v rNd|v r.|ddS d	|v r8|d	d
S d|v rB|ddS d|v rL|ddS |S d|v rpd|v sZd|v rpd|v rd|ddS d|v rn|ddS |S |S | dr}|ddS |S )zKRename weights from ModelOpt llama4 fp8 checkpoints to vLLM
        format.zmodel.zlanguage_model.model.r2   feed_forward.experts._input_scale_weight_scaledown_proj_input_scalew2_input_scaledown_proj_weight_scalew2_weight_scalegate_up_proj_input_scalew13_input_scalegate_up_proj_weight_scalew13_weight_scalez
self_attn.z.k_scalez.v_scalez.k_proj.k_scalez.attn.k_scalez.v_proj.v_scalez.attn.v_scalezlm_head.weightzlanguage_model.lm_head.weight)rT  r  )rb   r  renamedrP   rP   rQ   &_rename_weight_for_modelopt_checkpoint  s<   
zELlama4ForConditionalGeneration._rename_weight_for_modelopt_checkpointc                 C   sr   g }g }|D ].\}}|  |}|ddd }tt| |tr q|dr-|||f q|||f q||fS )zORename weights and separate them into language_model and other
        weights..r2   r   zlanguage_model.)r  r   r   getattrr;   rT  append)rb   r  language_model_weightsother_weightsr  r  r  attrrP   rP   rQ   _separate_and_rename_weights  s   

z;Llama4ForConditionalGeneration._separate_and_rename_weightsparams_dictc           	      C   s   g }g }t  }|D ]G\}}d|v rId|v rId|vrI||v rA|| }t|drA|j dkrA| dkrA|j|  || q	|||f q	|||f q	|||fS )zHandle expert scale parameters that need broadcasting.

        ModelOpt checkpoints use a single value tensor scalar for BMM style
        experts, vLLM expects the scale to be broadcasted across all experts.
        r  r   z.shared_expertr/  r2   )setrt  r/  numelfill_itemaddr  )	rb   r  r  regular_weightsexpert_scale_weightsupdated_paramsr  r  paramrP   rP   rQ   !_handle_expert_scale_broadcasting  s&   

z@Llama4ForConditionalGeneration._handle_expert_scale_broadcastingr  stacked_params_mappingc                 C   s   t  }| jr| |}|D ]A\}}|D ]&\}}}	||vs| jr q|||}|| }
|| |
j}||
||	  n|| }
t|
dt}||
| || q|S )z6Load non-language-model weights with stacking support.weight_loader)r  rc   r  r  r  r  r  r   )rb   r  r  r  r  r  r  
param_namer  shard_idr  r  rP   rP   rQ   _load_other_weights!  s&   



z2Llama4ForConditionalGeneration._load_other_weightsc                 C   s   t j| ddd| jjj| jdS )NrP  	down_projrQ  )ckpt_gate_proj_nameckpt_down_proj_nameckpt_up_proj_namenum_expertsro  )r   make_expert_params_mappingry   ru   num_local_expertsro  r   rP   rP   rQ   get_expert_mappingA  s   z1Llama4ForConditionalGeneration.get_expert_mappingc                 C   s   g d}t |  }t }| |\}}| ||\}}}	||	 t| }
|
|}|d us1J || |rD|
|}|rD|| || ||| |S )N))r  r  r   )r  r  r   )r  r  r   ).shared_expert.gate_up_projz.shared_expert.gate_projr   )r  z.shared_expert.up_projr2   ).feed_forward.gate_up_projz.feed_forward.gate_projr   )r  z.feed_forward.up_projr2   )	r'  named_parametersr  r  r  updater:   load_weightsr  )rb   r  r  r  r  r  r  r  r  updated_params_from_expertsloaderloaded_language_model_paramsloaded_expert_scale_paramsrP   rP   rQ   r  M  s*   





z+Llama4ForConditionalGeneration.load_weightsc                 C   s   t jddddS )z<
        Get the module prefix in multimodal models
        r^  zmulti_modal_projector.zvision_model.)r^  	connectortower_model)r   from_string_fieldr   rP   rP   rQ   get_mm_mappingx  s
   z-Llama4ForConditionalGeneration.get_mm_mapping)NN)*rI   rJ   rK   packed_modules_mappingsupports_encoder_tp_dataclassmethodro   rm   rW  r   r]   r  rs  ru  rN   rO   ry  r{  r  r?   r|  r4   r  r  r/   rk   r  r   r  r  r  r=  r  r'  r  r  r  r  r  r   r  rp   rP   rP   rd   rQ   rL    s    7






&

,&

%
 $+rL  )mr~   collections.abcr   r   	itertoolsr   typingr   r   rN   r   transformersr   r	   r
   transformers.image_utilsr   transformers.models.llama4r   7transformers.models.llama4.image_processing_llama4_fastr   r   vllm.compilation.decoratorsr   vllm.configr   r   vllm.config.multimodalr   vllm.distributedr   vllm.forward_contextr   $vllm.model_executor.layers.attentionr   $vllm.model_executor.layers.fused_moer   !vllm.model_executor.layers.linearr   r   r   r   'vllm.model_executor.layers.quantizationr   +vllm.model_executor.layers.rotary_embeddingr   &vllm.model_executor.model_loader.utilsr   -vllm.model_executor.model_loader.weight_utilsr   )vllm.model_executor.models.module_mappingr   !vllm.model_executor.models.visionr    vllm.multimodalr!   vllm.multimodal.inputsr"   r#   r$   vllm.multimodal.parser%   r&   r'   vllm.multimodal.processingr(   r)   r*   r+   r,   r-   r.   vllm.sequencer/   vllm.utils.tensor_schemar0   r1   
interfacesr3   r4   r5   r6   r7   r8   llama4r9   utilsr:   r;   r<   visionr=   r>   r?   ModulerR   rq   r   r   r   r   r   r   r   r   r  r>  register_processorrL  rP   rP   rP   rQ   <module>   s~   $	  )e2)W)k

