o
    
۾i                     @   sP  d Z ddlZddlmZmZ ddlmZ ddlZddlmZ ddl	m
Z
mZ ddlmZ ddlmZ dd	lmZmZmZmZ dd
lmZmZmZmZmZ ddlmZ ddlmZ ddlm Z  ddl!m"Z" ddl#m$Z$ ddl%m&Z&m'Z' ddl(m)Z)m*Z*m+Z+m,Z,m-Z- ddl.m/Z/ ddl0m1Z1m2Z2 ddl3m4Z4 ddl5m6Z6 ddl7m8Z8 ddl9m:Z: ddl;m<Z<m=Z= ddl>m?Z?m@Z@ ddlAmBZB ddlCmDZD ddlEmFZF ddlGmHZH ddlImJZJ ddlKmLZLmMZM d d!lNmOZOmPZPmQZQmRZR d d"lSmTZTmUZUmVZVmWZWmXZX eeYZZG d#d$ d$ej[Z\G d%d& d&ej[Z]G d'd( d(ej[Z^dMd)e_d*e_d+e_fd,d-Z`d.ead/e_d0ejbd+ejbfd1d2ZcG d3d4 d4ej[ZdG d5d6 d6ejj[e"ZeG d7d8 d8ej[ZfG d9d: d:ej[ZgG d;d< d<ej[ZheG d=d> d>ej[ZiG d?d@ d@eOZjG dAdB dBej[eRejeQePZkG dCdD dDekZlG dEdF dFekZmG dGdH dHekZndIe
eB dJeod+eadB fdKdLZpdS )Nz+Inference-only DeepseekV2/DeepseekV3 model.    N)CallableIterable)islice)nn)DeepseekV2ConfigDeepseekV3Config)rocm_aiter_ops)support_torch_compile)CacheConfigParallelConfig
VllmConfigget_current_vllm_config)get_ep_groupget_pp_groupget_tensor_model_parallel_rank$get_tensor_model_parallel_world_size tensor_model_parallel_all_gather)init_logger)
SiluAndMul)	Attention)AttentionLayerBase)SharedFusedMoE)	LayerNormRMSNorm)ColumnParallelLinearMergedColumnParallelLinearQKVParallelLinearReplicatedLinearRowParallelLinear)LogitsProcessor)
MLAModulesMultiHeadLatentAttentionWrapper)QuantizationConfig)per_token_group_quant_fp8)get_rope)SparseAttnIndexer)ParallelLMHeadVocabParallelEmbedding)default_weight_loadermaybe_remap_kv_scale_name)sequence_parallel_chunk)current_platform)IntermediateTensors)AttentionBackendDeepseekV32IndexerBackend)KVCacheSpecMLAAttentionSpec   )MixtureOfExpertsSupportsEagleSupportsLoRA
SupportsPP)PPMissingLayeris_pp_missing_parameter'make_empty_intermediate_tensors_factorymake_layersmaybe_prefixc                       sx   e Zd ZdZ				ddedeeB deded	ed
edB de	dB de
ddf fddZdejdejdejfddZ  ZS )DeepseekAttentionz.Normal MHA implementation used by Deepseek v1.    N vllm_configconfighidden_size	num_headsmax_position_embeddingscache_configquant_configprefixreturnc	              	      s2  t    || _t }
|| _| j|
 dksJ | j|
 | _|j| _| j|
kr0| j|
 dks/J n	|
| j dks9J td| j|
 | _	|| j | _
| j| j
 | _| j	| j
 | _| j
d | _|| _t|| j
| j| jd|d| _t| j| j
 |d|d| _t| j
||jd| _t| j| j
| j| j	||| dd| _d S )	Nr   r2         F)biasrE   )max_positionrope_parameters.attnnum_kv_headsrD   rE   rF   )super__init__rA   r   total_num_headsrB   num_key_value_headstotal_num_kv_headsmaxrN   head_dimq_sizekv_sizescalingrC   r   qkv_projr   o_projr$   rK   
rotary_embr   attn)selfr?   r@   rA   rB   rC   rD   rE   rF   kwargstp_size	__class__ Z/home/ubuntu/.local/lib/python3.10/site-packages/vllm/model_executor/models/deepseek_v2.pyrP   b   sX   

	
zDeepseekAttention.__init__	positionshidden_statesc           
      C   s`   |  |\}}|j| j| j| jgdd\}}}| |||\}}| |||}| |\}	}|	S )Ndim)rY   splitrV   rW   r[   r\   rZ   )
r]   rd   re   qkv_qkvattn_outputoutputrb   rb   rc   forward   s    zDeepseekAttention.forward)r=   NNr>   )__name__
__module____qualname____doc__r   r   r   intr
   r"   strrP   torchTensorrq   __classcell__rb   rb   r`   rc   r<   _   s@    	Ar<   c                       sP   e Zd Z				ddededededB d	ed
eddf fddZdd Z  Z	S )DeepseekV2MLPNTFr>   rA   intermediate_size
hidden_actrE   reduce_resultsrF   rG   c              	      sn   t    t||gd d||| dd| _t||d|||| dd| _|dkr1td| d	t | _d S )
N   Fz.gate_up_proj)rI   rE   
disable_tprF   z
.down_proj)rI   rE   r~   r   rF   siluUnsupported activation: !. Only silu is supported for now.)	rO   rP   r   gate_up_projr   	down_proj
ValueErrorr   act_fn)r]   rA   r|   r}   rE   r~   is_sequence_parallelrF   r`   rb   rc   rP      s.   

	
zDeepseekV2MLP.__init__c                 C   s*   |  |\}}| |}| |\}}|S N)r   r   r   )r]   xgate_uprk   rb   rb   rc   rq      s   
zDeepseekV2MLP.forward)NTFr>   )
rr   rs   rt   rv   rw   r"   boolrP   rq   rz   rb   rb   r`   rc   r{      s*    	'r{   c                	       sR   e Zd Z		ddeeB dededB def fddZd	e	j
d
e	j
fddZ  ZS )DeepseekV2MoENr>   r@   parallel_configrE   rF   c              	      sh  t    t | _t | _t|dd| _t j	| _
t j| _| j
 | _|j| _|j| _|j| _|jdkr>td|j dt|j|jdd | dd| _t|d	d d
krdttj|jtjd| j_nd | j_|j}|j| _|j | _!| j| _"| j"| j! | _#| j#| j | _$| j| j$ | _%| j%| j$ | _&t'( | _)t'* | _*|jd u s| j*rd | _+n|j,|j }t-|j||j|| jd| dd| _+t.d&i d| j+d| jd|jd|j/d|jd|j,ddd|j0d|dddt|dddt|ddd| ddt|dd d| j)sdn| jd!| jjd"| jd#| j!d$| jd%| j*r'|jnd | _1d S | _1d S )'Nrouted_scaling_factor      ?r   r   r   Fz.gaterI   rE   rF   topk_methodnoaux_tc)dtypez.shared_experts)rA   r|   r}   rE   r   r~   rF   shared_expertsgatenum_expertstop_krA   r|   r~   renormalizerE   use_grouped_topkTnum_expert_groupn_groupr2   
topk_grouprF   z.expertsscoring_funcsoftmaxe_score_correction_biasenable_eplbnum_redundant_expertsr   n_shared_expertsrb   )2rO   rP   r   r_   r   tp_rankgetattrr   r   device_groupep_grouprank_in_groupep_ranksizeep_sizen_routed_expertsr   use_sequence_parallel_moer   r}   r   r   rA   r   r   	Parameterrx   emptyfloat32r   eplb_configr   r   n_redundant_expertsn_logical_expertsn_physical_expertsn_local_physical_expertsphysical_expert_startphysical_expert_endr   is_fused_moe_enabledis_rocm_aiter_moe_enabled$is_fusion_moe_shared_experts_enabledr   moe_intermediate_sizer{   r   num_experts_per_toknorm_topk_probexperts)r]   r@   r   rE   rF   r   r|   r`   rb   rc   rP      s   







	




zDeepseekV2MoE.__init__re   rG   c           	      C   s  |j \}}|d|}| jrt|}| jjr| j||d}n| |\}}| j||d}|\}}| jd u r;|d u s;J |jt	j
krJ| jsI|| j9 }n| jd ur\|d usUJ |d| j 9 }| jd urk|d usgJ ||7 }| jrzt|d}|d | }n| jdkr| j|}|||S )Nrf   )re   router_logitsr   r   r2   )shapeviewr   r*   r   is_internal_routerr   r   r   rx   float16r   r   r   r_   &maybe_all_reduce_tensor_model_parallel)	r]   re   
num_tokens
hidden_dimfused_moe_outr   rk   shared_outputfinal_hidden_statesrb   rb   rc   rq   G  sF   





zDeepseekV2MoE.forward)Nr>   )rr   rs   rt   r   r   r   r"   rw   rP   rx   ry   rq   rz   rb   rb   r`   rc   r      s    gr   scalemscalerG   c                 C   s*   dd l }| dkr
dS d| ||  d S )Nr   r2   r   g?)mathlog)r   r   r   rb   rb   rc   yarn_get_mscale|  s   r    original_max_position_embeddingsscaling_betard   c              	   C   s(   d|t dt ||     }|d S )Nr2   ).NN)rx   r   floor)r   r   rd   rX   rb   rb   rc   _get_llama_4_scaling  s   r   c                       s   e Zd Z					ddedeeB dededed	ed
edededededB dedB de	j
dB deddf fddZde	j
de	j
de	j
dB de	j
fddZ  ZS )DeepseekV2Attentionr=   Nr>   r?   r@   rA   rB   qk_nope_head_dimqk_rope_head_dim
v_head_dimq_lora_rankkv_lora_rankrC   rD   rE   topk_indices_bufferrF   rG   c              	      sD  t    || _|| _|| _|| | _|| _|| _|	| _|| _	t
 }|| dks*J || | _| jd | _|
| _|d u s@J d| jd urot| j| jd|| dd| _t| j|jd| _t|| j	| j d|| dd| _nt| j| j	| j d|| d	d| _t| j| j| j d|| d
d| _t| j|jd| _t| j| j	| j| j  d|| dd| _t| j	| j | jd|| dd| _|jd dkr|jddrdnd|jd< t||
|jdd| _|jd dkr|jd dkr|jdd}|jd }t|t |}| j| | | _t!| j| j| j| j||| dd| _"d S )Nr   rH   zDtopk_indices_buffer is not         supported for DeepseekV2AttentionFz	.q_a_projr   eps	.q_b_proj.q_proj.kv_a_proj_with_mqa
.kv_b_proj.o_proj	rope_typedefaultapply_yarn_scalingTdeepseek_yarndeepseek_llama_scalingrJ   rK   is_neox_stylemscale_all_dimfactorrL   rM   )#rO   rP   rA   r   r   qk_head_dimr   r   r   rB   r   num_local_headsrX   rC   r   q_a_projr   rms_norm_epsq_a_layernormr   q_b_projq_projkv_a_proj_with_mqakv_a_layernorm	kv_b_projr   rZ   rK   getr$   r[   r   floatr   r\   )r]   r?   r@   rA   rB   r   r   r   r   r   rC   rD   rE   r   rF   r_   r   scaling_factorr   r`   rb   rc   rP     s   










zDeepseekV2Attention.__init__rd   re   llama_4_scalingc                 C   s  | j d ur | |d }| |}| |d d| j| j}n| |d d| j| j}|j| j	| j
gdd\}}| |d }|j| j| j
gdd\}}	|d}| |}| |d }
|
d| j| j	| j }
|
j| j	| jgdd\}}|d d d d | jd f }| |||\}}||d| j	d f< t|}||dd | j	f< ||d| j	d f< |d ur||9 }tjjj|d| j| j gddd| j| j }| |||}|d| j| jdd | jf d| j| j }| |\}}	|S )Nr   rf   rg   r2   .)value)r   r   r   r   r   r   r   r   ri   r   r   r   r   	unsqueezer   r   r   r[   rx   
empty_liker   
functionalpadr\   reshaperZ   )r]   rd   re   r   rl   q_nopeq_pelatent_cachekv_ark   kvk_nopern   k_perm   ro   rp   rb   rb   rc   rq     sH   





zDeepseekV2Attention.forward)r=   NNNr>   )rr   rs   rt   r   r   r   rv   r
   r"   rx   ry   rw   rP   rq   rz   rb   rb   r`   rc   r     s\    	
ur   c                       sX   e Zd Zdedejdedef fddZde	de
fd	d
Zdd ZdefddZ  ZS )DeepseekV32IndexerCacherU   r   rF   rD   c                    s^   t    tg g| _|| _|| _|| _|| _t	 j
}||jv r(td| | |j|< d S )NzDuplicate layer name: )rO   rP   rx   tensorkv_cacherU   rF   rD   r   r   compilation_configstatic_forward_contextr   )r]   rU   r   rF   rD   r  r`   rb   rc   rP   4  s   

z DeepseekV32IndexerCache.__init__r?   rG   c                 C   s   t | jjd| j| jdS )Nr2   )
block_sizerN   	head_sizer   )r1   rD   r
  rU   r   )r]   r?   rb   rb   rc   get_kv_cache_specB  s   z)DeepseekV32IndexerCache.get_kv_cache_specc                 C   s   d S r   rb   r]   rb   rb   rc   rq   J  s    zDeepseekV32IndexerCache.forwardc                 C   s   t S r   r.   r  rb   rb   rc   get_attn_backendL  s   z(DeepseekV32IndexerCache.get_attn_backend)rr   rs   rt   rv   rx   r   rw   r
   rP   r   r0   r  rq   r-   r  rz   rb   rb   r`   rc   r  3  s    r  c                       sp   e Zd Z	ddedeeB dedededB dedB d	e	j
dB d
ef fddZde	j
de	j
de	j
fddZ  ZS )Indexerr>   r?   r@   rA   r   rE   NrD   r   rF   c	           
   	      sD  t    || _|| _|j| _|j| _|j| _	|j
| _|| _t| j| j	| j d|| dd| _t|| j	d|| dd| _t| j	dd| _t|| jdd | dd| _| j	d | _d	| _d
| _|| _t| j	| j	| j d  tj| d|d| _|jj| _|| _ddlm}	 |	|| _ t!| j| j| j| j| j	| j| j | j| _"d S )NFz.wq_br   z.wkgư>r   z.weights_projrH   ue8m0      z.k_cache)rU   r   rF   rD   r   )get_max_prefill_buffer_size)#rO   rP   r?   r@   
index_topktopk_tokensindex_n_headsn_headindex_head_dimrU   r   rope_dimr   r   wq_bwkr   k_normweights_projsoftmax_scale	scale_fmtquant_block_sizer   r  rx   uint8k_cachemodel_configmax_model_lenrF   &vllm.v1.attention.backends.mla.indexerr  max_total_seq_lenr%   
indexer_op)
r]   r?   r@   rA   r   rE   rD   r   rF   r  r`   rb   rc   rP   Q  sl   




zIndexer.__init__re   qrrG   c                 C   sp  |  |\}}|d| j| j}tj|| j| j| j gdd\}}| |\}	}| |	}	tj|	| j| j| j gdd\}
}||||
	d\}}
|
d| j| j}|

dd| j}
tj||gdd}tj|
d|gdd}	|d| j}t|| jd| jd ud\}}|d| j| j}|d| jd}| |\}}|	d| | j | jd  }|d}| |||	|S )Nrf   rg   r2   F)column_major_scales	use_ue8m0rH   )r  r   r  rU   rx   ri   r  r  r  r   r   catsqueezer#   r   r  r  r  r'  )r]   re   r(  rd   r[   rl   rk   r   r   rm   r  r  q_fp8q_scaleweightsrb   rb   rc   rq     s:   




zIndexer.forward)r>   )rr   rs   rt   r   r   r   rv   r"   r
   rx   ry   rw   rP   rq   rz   rb   rb   r`   rc   r  P  s4    
	Jr  c                       s   e Zd ZdZ					ddedeeB deded	ed
edededB dedededB de	dB de
dejdB ddf fddZdejdejdejdB dejfddZ  ZS )DeepseekV2MLAAttentiona  
    Main reference: DeepseekV2 paper, and FlashInfer Implementation
    (https://arxiv.org/abs/2405.04434 and https://github.com/flashinfer-ai/flashinfer/pull/551).

        For more info see MLACommonImpl in:
        vllm/v1/attention/backends/mla/utils.py
    r=   Nr>   r?   r@   rA   rB   r   r   r   r   r   rC   rD   rE   rF   r   rG   c                    sB  t    || _|| _|| _|| | _|| _|| _|	| _|| _	t
 }|| dks*J || | _| jd | _|
| _| jd urTt| j| j| j| j gd|| ddd| _nt| j| j| j d|| dd| _| jd urt| j|jd	| _t| j| j	| j d|| d
d| _nt| j| j	| j d|| dd| _t| j|jd	| _t| j| j	| j| j  d|| dd| _t| j	| j | jd|| dd| _|jd dkr|jddrdnd|jd< t||
|jdd| _|jd dkr|jd dkr|jdd}|jd }t |t!|}| j| | | _t"|d| _#| j#r<t||
|jt$|dd d| _%t&|||||||| d| _'nd | _%d | _'t(| j| j| j| j| jd urT| jnd | jd u r^| jnd | jd urh| jnd | jd urr| jnd | jd u r|| jnd | j'| j%| j#|d}t)| j| j| j| j| j| j| j| j||||| _*d S )Nr   rH   Fz.fused_qkv_a_projT)rI   rE   rF   r   r   r   r   r   r   r   r   r   r   r   r   r   r   r   r   r  indexer_rope_interleavez.indexer)r   r   r[   rZ   fused_qkv_a_projr   r   r   r   indexerindexer_rotary_emb	is_sparser   )+rO   rP   rA   r   r   r   r   r   r   rB   r   r   rX   rC   r   r3  r   r   r   r   r   r   r   r   r   r   r   rZ   rK   r   r$   r[   r   r   hasattris_v32r   indexer_rope_embr  r4  r    r!   mla_attn)r]   r?   r@   rA   rB   r   r   r   r   r   rC   rD   rE   rF   r   r_   r   r   r   mla_modulesr`   rb   rc   rP     s  




	







zDeepseekV2MLAAttention.__init__rd   re   r   c                 C   s   |  |||S r   )r:  )r]   rd   re   r   rb   rb   rc   rq   z  s   zDeepseekV2MLAAttention.forward)r=   NNr>   N)rr   rs   rt   ru   r   r   r   rv   r
   r"   rw   rx   ry   rP   rq   rz   rb   rb   r`   rc   r1    s`    	
 (r1  c                       sv   e Zd Z		ddedededB dejdB ddf
 fddZ	dd	ejd
ejdejdB dejdB dejf
ddZ	  Z
S )DeepseekV2DecoderLayerNr?   rF   r@   r   rG   c                    s  t    |d u r|jj}|j}|j}|j}|j}|j| _t|dd}	t|dd}
t	|j
ddd }|| _t|dd	}t|d
d	}t|dd	}t|dd	}|jdkp]tdd ||fD }|| _|rft}n|jrlt}nt}|||| j|j|||t|dr|jnd ||	||| d|d| _|jd ur||jkr||
 d	krt|||| dd| _nt|j|j|j|| dd| _t|j|jd| _ t|j|jd| _!t|dd| _"d S )NrC   r=   moe_layer_freqr2   .)seprf   r   r   r   r   r   deepseekc                 s       | ]}|d kV  qdS r   Nrb   .0rh   rb   rb   rc   	<genexpr>      
z2DeepseekV2DecoderLayer.__init__.<locals>.<genexpr>r   z
.self_attn)r?   r@   rA   rB   r   r   r   r   r   rC   rD   rE   rF   r   z.mlp)r@   r   rE   rF   )rA   r|   r}   rE   rF   r   r   r   )#rO   rP   r#  	hf_configrD   rE   r   rA   r   rv   ri   	layer_idx
model_typealluse_mhar<   use_mlar1  r   num_attention_headsr7  r   	self_attnr   first_k_dense_replacer   mlpr{   r|   r}   r   r   input_layernormpost_attention_layernormr   )r]   r?   rF   r@   r   r#  rD   rE   r   rC   r=  rH  r   r   r   r   rK  attn_clsr`   rb   rc   rP     s|   



zDeepseekV2DecoderLayer.__init__rd   re   residualr   c                 C   s   |d u r|  }| |}n| ||\}}||d}| js"||d< | jdi |}t| jtsI|jtjkrI|d| j	 9 }| j
dkrI|d| j	 9 }| ||\}}| |}t| jtri|jtjkri|d| j	 9 }||fS )N)rd   re   r   r   r   rb   )clonerQ  rK  rN  
isinstancer<   r   rx   r   r   rH  rR  rP  r{   )r]   rd   re   rT  r   attn_kwargsrb   rb   rc   rq     s*   


zDeepseekV2DecoderLayer.forwardNNr   )rr   rs   rt   r   rw   r   rx   ry   rP   rq   rz   rb   rb   r`   rc   r<    s4    Xr<  c                       s   e Zd ZdZdddedef fddZdejd	ejfd
dZ		ddejdB dejde
dB dejdB d	eje
B f
ddZ  ZS )DeepseekV2ModelFr>   rF   r?   rF   c                   s   t    jj}j}|| _tj| _|j	| _	t
|d| _| jr1|j}tjjj|tj| jd nd  t jrFt|j	|j|| dd| _nt | _t|j fdd| dd\| _| _| _t jrnt|j|jd	| _ nt | _ t!d
dg|j| _"d S )Nr  )r   devicez.embed_tokensrE   rF   c                    s   t |  dS )N)r   )r<  rZ  r   r?   rb   rc   <lambda>-  s    z*DeepseekV2Model.__init__.<locals>.<lambda>z.layersrZ  r   re   rT  )#rO   rP   r#  rG  rE   r@   r+   device_typer[  
vocab_sizer7  r8  r  rx   r   scheduler_configmax_num_batched_tokensint32r   is_first_rankr'   rA   embed_tokensr7   r:   num_hidden_layersstart_layer	end_layerlayersis_last_rankr   r   normr9   make_empty_intermediate_tensors)r]   r?   rF   r@   rE   r  r`   r]  rc   rP     sF   



zDeepseekV2Model.__init__	input_idsrG   c                 C   s
   |  |S r   )re  r]   rm  rb   rb   rc   embed_input_ids;  s   
zDeepseekV2Model.embed_input_idsNrd   intermediate_tensorsinputs_embedsc                 C   s   t  jr|d ur|}n| |}d }n|d usJ |d }|d }t| jdd }|d ur8t|d |d |d}nd }t| j| j| j	D ]}	|	||||\}}qCt  j
sZt||dS | ||\}}
|S )Nre   rT  r   r   beta)r   r   rd   )re   rT  )r   rd  ro  r   r@   r   r   ri  rg  rh  rj  r,   rk  )r]   rm  rd   rp  rq  re   rT  llama_4_scaling_configr   layerrk   rb   rb   rc   rq   >  s8   

zDeepseekV2Model.forwardr   )rr   rs   rt   fall_back_to_pt_during_loadr   rw   rP   rx   ry   ro  r,   rq   rz   rb   rb   r`   rc   rY  	  s     .rY  c                   @   sD   e Zd ZU ee ed< 	 dedB fddZdededdfd	d
ZdS )DeepseekV2MixtureOfExpertsmoe_mlp_layersexample_moeNc                 C   sz   |d u r#d| _ d| _d| _d| _d| _d| _d| _d| _t	d d S |j
| _|j| _|j| _|j| _|j| _|j| _d S )Nr   z9DeepSeekV2: No DeepseekV2MoE layer found in model.layers.)num_moe_layersnum_expert_groupsnum_logical_expertsnum_physical_expertsnum_local_physical_expertsnum_routed_expertsnum_shared_expertsr   loggerwarningr   r   r   r   r   r   )r]   rx  rb   rb   rc   extract_moe_parametersr  s    z1DeepseekV2MixtureOfExperts.extract_moe_parametersr|  r}  rG   c                 C   sT   | j |ksJ || _|| _ || j | _| jD ]}||_||_| j|_|j	  qd S r   )
r}  r|  r{  r   rw  r   r   r   r   update_expert_map)r]   r|  r}  moerb   rb   rc    update_physical_experts_metadata  s   
z;DeepseekV2MixtureOfExperts.update_physical_experts_metadata)	rr   rs   rt   listr   __annotations__r  rv   r  rb   rb   rb   rc   rv  l  s   
 rv  c                       s   e Zd ZdddgiZeZdddedef fdd	Zd
d Z	de
jde
jfddZ		dde
jdB de
jdedB de
jdB de
jeB f
ddZde
jde
jdB fddZdeeeeeef  fddZdeeee
jf  dee fddZ  ZS )DeepseekV2ForCausalLMr   	gate_projup_projr>   rZ  r?   rF   c                   s  t    |jj}|j}|| _|| _t|dd}t|dd}|jdkp-tdd ||fD | _	| j	r9g d| j
d< t|d	oB|jd u| _| jrNd
dg| j
d< | j|t|dd| _t jrlt|j|j|t|dd| _nt | _t|j| _| jj| _| jj| jj | _|   d S )Nr   r   r   r@  c                 s   rA  rB  rb   rC  rb   rb   rc   rE    rF  z1DeepseekV2ForCausalLM.__init__.<locals>.<genexpr>)r   k_projv_projrY   r   r   r   r3  model)r?   rF   lm_headr\  )rO   rP   r#  rG  rE   r@   r   rI  rJ  rK  packed_modules_mappingr7  r   fuse_qkv_a_proj	model_clsr;   r  r   rj  r&   r`  rA   r  r7   r   logits_processorrl  rf  rO  ry  set_moe_parameters)r]   r?   rF   r@   rE   r   r   r`   rb   rc   rP     sF   




zDeepseekV2ForCausalLM.__init__c                 C   s   g | _ t| jdd| _g | _g | _d }| jjD ]'}t|t	rqt|t
s&J t|jtr>|j}| j|j | j|jj q| | d S )Nr   r2   )expert_weightsr   r@   rz  
moe_layersrw  r  ri  rV  r7   r<  rP  r   appendr   r  )r]   rx  rt  rb   rb   rc   r    s   
z(DeepseekV2ForCausalLM.set_moe_parametersrm  rG   c                 C   s   | j |S r   )r  ro  rn  rb   rb   rc   ro    s   z%DeepseekV2ForCausalLM.embed_input_idsNrd   rp  rq  c                 C   s   |  ||||}|S r   )r  )r]   rm  rd   rp  rq  re   rb   rb   rc   rq     s   zDeepseekV2ForCausalLM.forwardre   c                 C   s   |  | j|}|S r   )r  r  )r]   re   logitsrb   rb   rc   compute_logits  s   z$DeepseekV2ForCausalLM.compute_logitsc                 C   s   t j| ddd| jjddS )Nr  r   r  r   ckpt_gate_proj_nameckpt_down_proj_nameckpt_up_proj_namer   r   )r   make_expert_params_mappingr@   r   r  rb   rb   rc   get_expert_mapping  s   z(DeepseekV2ForCausalLM.get_expert_mappingr0  c                 C   sT  t  }ddg}ddg}g d}| jr|| n|| tj| ddd| jj|r-| jjnd	 | j	d
}t
|  }t }|D ]g\}	}
d|	v rIq?t| j|	}|d urTq?|oYd|	v }|D ]E\}}}||	vrfq\d|	v ro|	|vroq\|rrq\|	||}|dkr||vrq\|}	|	dr|	|vrq\t|	| rq\||	 }|j}|||
|  nd}d}|rt| jddpd}d|	v r|
jdkrdnd	}|
j| }|| d	ksJ d| d| || }t|D ]}|	}|
}|rt|| |d | }|
jdkr|
| }n|d	kr|
|d d f }n|
d d |f }|	dd| jj|  }|D ]I}|\}}}}||vr-qd}|||}t|| r=q|| }ttdtf |j}||||||dd}|rf|s_|}	n||  n3q|rlq|	drx|	|vrxqt|	|}	|	d u rqt|	| rq||	 }t|dt}|||
 q|	d ur|s||	 q?|S )N)r   r  r   )r   r  r2   )r3  r   r   )r3  r   r2   ))rY   r   rl   )rY   r  rm   )rY   r  rn   r  r   r  r   r  zrotary_emb.inv_freqzmlp.shared_expertszmlp.experts.r3  z.biasFr2   r   zdown_proj.weightzShared expert weight dim z not divisible by num_chunks T.)shard_id	expert_idreturn_successweight_loader)r   r   rK  extendr   r  r@   r   r   r   dictnamed_parametersset#get_spec_layer_idx_from_weight_namereplaceendswithr8   r  r   ndimr   rangeslicetypingcastr   r   addr)   r(   )r]   r0  $rocm_aiter_moe_shared_expert_enabledstacked_params_mappingmla_params_mappingmha_params_mappingexpert_params_mappingparams_dictloaded_paramsnameloaded_weight
spec_layer"is_fusion_moe_shared_experts_layer
param_nameweight_namer  name_mappedparamr  is_expert_weight
num_chunks	split_dimtotal
chunk_sizej
chunk_nameweight_to_loadchunk_slicemappingr  successrb   rb   rc   load_weights  s   














z"DeepseekV2ForCausalLM.load_weightsrX  )rr   rs   rt   r  rY  r  r   rw   rP   r  rx   ry   ro  r,   rq   r  r  tuplerv   r  r   r  r  rz   rb   rb   r`   rc   r    s6    3

,r  c                   @      e Zd ZdS )DeepseekForCausalLMNrr   rs   rt   rb   rb   rb   rc   r        r  c                   @   r  )DeepseekV3ForCausalLMNr  rb   rb   rb   rc   r    r  r  c                   @   r  )GlmMoeDsaForCausalLMNr  rb   rb   rb   rc   r    r  r  r@   r  c                 C   sP   t | dr&| jdkr&| j}t| jD ]}|d||  dr%||   S qd S )Nnum_nextn_predict_layersr   zmodel.layers.r>  )r7  r  rf  r  
startswith)r@   r  rH  irb   rb   rc   r    s   
r  )r2   r2   )qru   r  collections.abcr   r   	itertoolsr   rx   r   transformersr   r   vllm._aiter_opsr   vllm.compilation.decoratorsr	   vllm.configr
   r   r   r   vllm.distributedr   r   r   r   r   vllm.loggerr   %vllm.model_executor.layers.activationr   $vllm.model_executor.layers.attentionr   /vllm.model_executor.layers.attention_layer_baser   $vllm.model_executor.layers.fused_moer   $vllm.model_executor.layers.layernormr   r   !vllm.model_executor.layers.linearr   r   r   r   r   +vllm.model_executor.layers.logits_processorr   vllm.model_executor.layers.mlar    r!   'vllm.model_executor.layers.quantizationr"   7vllm.model_executor.layers.quantization.utils.fp8_utilsr#   +vllm.model_executor.layers.rotary_embeddingr$   .vllm.model_executor.layers.sparse_attn_indexerr%   3vllm.model_executor.layers.vocab_parallel_embeddingr&   r'   -vllm.model_executor.model_loader.weight_utilsr(   r)    vllm.model_executor.models.utilsr*   vllm.platformsr+   vllm.sequencer,   vllm.v1.attention.backendr-   r%  r/   vllm.v1.kv_cache_interfacer0   r1   
interfacesr3   r4   r5   r6   utilsr7   r8   r9   r:   r;   rr   r  Moduler<   r{   r   r   r   rv   ry   r   r   r  r  r1  r<  rY  rv  r  r  r  r  rw   r  rb   rb   rb   rc   <module>   s   Q/ 

 &z : b
)  C