o
    }oi                     @   s  d dl Z d dlmZmZ d dlmZ d dlmZmZm	Z	 d dl
Z
d dlZ
d dlmZ d dl
mZ d dlmZ d dlmZ d d	lmZmZ d d
lmZ d dlmZ d dlmZmZmZmZ d dlm Z  d dl!m"Z"m#Z# d dl$m%Z% d dl&m'Z' d dl(m)Z) eG dd deZ*eG dd deZ+eG dd deZ,eG dd de,Z-e".edG dd de"j/def Z0e"1edG dd de"j/ed f Z2d!efd"d#Z3d$d% Z4d&d' Z5d(d) Z6d*e"j7fd+d,Z8d*e"j7fd-d.Z9d*e"j7fd/d0Z:d1d2 Z;d3ed4ed5e<d6e<d7e<d8e<fd9d:Z=d;ed3ed4ed5e<d6e<d7e<d8e<fd<d=Z>d>ed5e<d6e<d7e<d8e<f
d?d@Z?d5e<d6e<d7e<d8e<fdAdBZ@dCdD ZAdEdF ZBd*e"j7fdGdHZCd*e"j7fdIdJZDd*e"j7fdKdLZEdMdN ZFdOdP ZGdS )Q    N)	dataclassfield)Path)DictOptionalTuple)TransformerConfig)Tensor)MllamaConfig)MllamaForConditionalGeneration)MllamaTextConfigMllamaVisionConfig)AutoTokenizer)TokenizerSpec)CrossAttentionTextConfigCrossAttentionVisionConfigMLlamaModelMLlamaModelConfig)load_distributed_model_weights)ioteardown)_ModelState)dtype_from_hf)loggingc                   @   B   e Zd ZU edd dZee ed< edd dZee ed< dS )MLlamaConfig11Bc                   C      t  S Nr    r   r   \/home/ubuntu/.local/lib/python3.10/site-packages/nemo/collections/vlm/mllama/model/mllama.py<lambda>/       zMLlamaConfig11B.<lambda>default_factorylanguage_model_configc                   C   
   t ddS )Ni  vision_chunk_sizer   r   r   r   r    r!   1      
 vision_model_configN	__name__
__module____qualname__r   r%   r   r   __annotations__r+   r   r   r   r    r   -   
   
 r   c                   @   r   )MLlamaConfig11BInstructc                   C   r   r   r   r   r   r   r    r!   7   r"   z MLlamaConfig11BInstruct.<lambda>r#   r%   c                   C   r&   )N0  r'   r)   r   r   r   r    r!   9   r*   r+   Nr,   r   r   r   r    r2   5   r1   r2   c                   @   r   )MLlamaConfig90Bc                   C   s   t ddddddS )N    i p  @   P      )hidden_sizeffn_hidden_sizenum_attention_heads
num_layersnum_cross_attention_layersr   r   r   r   r    r!   @   s    zMLlamaConfig90B.<lambda>r#   r%   c                   C   s   t dddS )Nr3   r5   )r(   text_hidden_sizer)   r   r   r   r    r!   I   s    r+   Nr,   r   r   r   r    r4   =   s   
 	r4   c                   @   s   e Zd ZdS )MLlamaConfig90BInstructN)r-   r.   r/   r   r   r   r    r?   M   s    r?   hfc                       s   e Zd ZdefddZddee def fddZdedefd	d
Zdd Z	e
dddZe
defddZdee fddZdee fddZ  ZS )HFMLlamaImporterreturnc                 C   s   t | j| jdS )N)	tokenizer)r   configrC   selfr   r   r    initT   s   zHFMLlamaImporter.initN	base_pathc                    s   t  |}|S r   )super
local_path)rF   rH   output_path	__class__r   r    rJ   W   s   zHFMLlamaImporter.local_pathrK   c                 C   sp   t jt| dd}t| }t|}|  }| |}| || | 	|| t
d|  t|| ~~|S )Nautotorch_dtypez/Converted Mllama model to Nemo, model saved to )r   from_pretrainedstr_rename_xattn_layer_nums_hf
state_dictr   rG   
nemo_setupconvert_state	nemo_saveprintr   )rF   rK   sourcerT   targettrainerr   r   r    apply\   s   

zHFMLlamaImporter.applyc                 C   s  i }g }| ddddddddd	d
dddd |tjddtdtjddtdtjddtdtjddtdtjddtdtjddtdtjddtdg d}| i d| d d!| d"d#| d$d%| d&d'| d(d)| d*d+| d,d-| d.d/| d0d1| d2d3| d4d5| d6d7| d8d9| d:d;| d<d=| d>d?| d@i dA| dBdC| dDdE| dFdG| dHdI| dJdK| dLdM| dNdO| dPdQ| dRdS| dTdU| dVdW| dXdY| dZd[| d\d]| d^d_d`dadb |tjdc| ddt	dtjde| dft	dtjdg| dht
dg tj||||diS )jNAlanguage_model.decoder.layers.*.self_attention.linear_proj.weightHlanguage_model.decoder.xattn_layers.*.cross_attention.linear_proj.weightElanguage_model.decoder.xattn_layers.*.cross_attention.linear_q.weight-language_model.decoder.final_layernorm.weight"language_model.output_layer.weight@language_model.decoder.layers.*.mlp.linear_fc1.layer_norm_weight5language_model.decoder.layers.*.mlp.linear_fc2.weightKlanguage_model.decoder.layers.*.self_attention.linear_qkv.layer_norm_weightHlanguage_model.decoder.xattn_layers.*.cross_attention.k_layernorm.weightzPlanguage_model.decoder.xattn_layers.*.cross_attention.linear_q.layer_norm_weightHlanguage_model.decoder.xattn_layers.*.cross_attention.q_layernorm.weightzFlanguage_model.decoder.xattn_layers.*.mlp.linear_fc1.layer_norm_weightz;language_model.decoder.xattn_layers.*.mlp.linear_fc2.weight)5language_model.model.layers.*.self_attn.o_proj.weightz<language_model.model.xattn_layers.*.cross_attn.o_proj.weightz<language_model.model.xattn_layers.*.cross_attn.q_proj.weight language_model.model.norm.weightlanguage_model.lm_head.weight=language_model.model.layers.*.post_attention_layernorm.weight2language_model.model.layers.*.mlp.down_proj.weight4language_model.model.layers.*.input_layernorm.weightz<language_model.model.xattn_layers.*.cross_attn.k_norm.weightz:language_model.model.xattn_layers.*.input_layernorm.weightz<language_model.model.xattn_layers.*.cross_attn.q_norm.weightzClanguage_model.model.xattn_layers.*.post_attention_layernorm.weightz8language_model.model.xattn_layers.*.mlp.down_proj.weightz8language_model.model.xattn_layers.*.cross_attn_attn_gate/language_model.decoder.xattn_layers.*.gate_attn
source_key
target_keyfnz7language_model.model.xattn_layers.*.cross_attn_mlp_gate.language_model.decoder.xattn_layers.*.gate_ffnz5language_model.model.layers.*.self_attn.q_proj.weightz5language_model.model.layers.*.self_attn.k_proj.weightz5language_model.model.layers.*.self_attn.v_proj.weight@language_model.decoder.layers.*.self_attention.linear_qkv.weightz2language_model.model.layers.*.mlp.gate_proj.weightz0language_model.model.layers.*.mlp.up_proj.weight5language_model.decoder.layers.*.mlp.linear_fc1.weight)z<language_model.model.xattn_layers.*.cross_attn.k_proj.weightz<language_model.model.xattn_layers.*.cross_attn.v_proj.weightFlanguage_model.decoder.xattn_layers.*.cross_attention.linear_kv.weight)z8language_model.model.xattn_layers.*.mlp.gate_proj.weightz6language_model.model.xattn_layers.*.mlp.up_proj.weightz;language_model.decoder.xattn_layers.*.mlp.linear_fc1.weight(language_model.model.embed_tokens.weightz/language_model.embedding.word_embeddings.weightz)language_model.learnable_embedding.weightvision_model.vision_encoder@vision_model.global_transformer.layers.*.self_attn.o_proj.weight>.global_transformer.layers.*.self_attention.linear_proj.weight2vision_model.global_transformer.layers.*.gate_attn&.global_transformer.layers.*.gate_attn1vision_model.global_transformer.layers.*.gate_ffn%.global_transformer.layers.*.gate_ffn=vision_model.global_transformer.layers.*.input_layernorm.bias1.global_transformer.layers.*.input_layernorm.bias?vision_model.global_transformer.layers.*.input_layernorm.weight3.global_transformer.layers.*.input_layernorm.weightFvision_model.global_transformer.layers.*.post_attention_layernorm.bias3.global_transformer.layers.*.pre_mlp_layernorm.biasHvision_model.global_transformer.layers.*.post_attention_layernorm.weight5.global_transformer.layers.*.pre_mlp_layernorm.weight5vision_model.global_transformer.layers.*.mlp.fc1.bias0.global_transformer.layers.*.mlp.linear_fc1.bias7vision_model.global_transformer.layers.*.mlp.fc1.weight2.global_transformer.layers.*.mlp.linear_fc1.weight5vision_model.global_transformer.layers.*.mlp.fc2.bias0.global_transformer.layers.*.mlp.linear_fc2.bias7vision_model.global_transformer.layers.*.mlp.fc2.weight2.global_transformer.layers.*.mlp.linear_fc2.weight9vision_model.transformer.layers.*.self_attn.o_proj.weight7.transformer.layers.*.self_attention.linear_proj.weight6vision_model.transformer.layers.*.input_layernorm.bias*.transformer.layers.*.input_layernorm.bias8vision_model.transformer.layers.*.input_layernorm.weight,.transformer.layers.*.input_layernorm.weight?vision_model.transformer.layers.*.post_attention_layernorm.bias,.transformer.layers.*.pre_mlp_layernorm.biasAvision_model.transformer.layers.*.post_attention_layernorm.weight..transformer.layers.*.pre_mlp_layernorm.weight.vision_model.transformer.layers.*.mlp.fc1.bias).transformer.layers.*.mlp.linear_fc1.bias0vision_model.transformer.layers.*.mlp.fc1.weight+.transformer.layers.*.mlp.linear_fc1.weight.vision_model.transformer.layers.*.mlp.fc2.bias).transformer.layers.*.mlp.linear_fc2.bias0vision_model.transformer.layers.*.mlp.fc2.weight+.transformer.layers.*.mlp.linear_fc2.weightvision_model.class_embedding.class_embedding1vision_model.gated_positional_embedding.embedding.positional_embedding=vision_model.gated_positional_embedding.tile_embedding.weight'.gated_tile_positional_embedding.weight,vision_model.gated_positional_embedding.gate .gated_positional_embedding_gate vision_model.layernorm_post.bias.ln_post.bias"vision_model.layernorm_post.weight.ln_post.weightvision_model.layernorm_pre.bias.ln_pre.bias!vision_model.layernorm_pre.weight.ln_pre.weight<vision_model.post_tile_positional_embedding.embedding.weight%.post_tile_pos_embed.embedding.weight0vision_model.post_tile_positional_embedding.gate.post_tile_pos_embed.gate;vision_model.pre_tile_positional_embedding.embedding.weight$.pre_tile_pos_embed.embedding.weight/vision_model.pre_tile_positional_embedding.gate.pre_tile_pos_embed.gatemulti_modal_projector.bias+vision_model.vision_projection.encoder.biasmulti_modal_projector.weight-vision_model.vision_projection.encoder.weightz@vision_model.global_transformer.layers.*.self_attn.q_proj.weightz@vision_model.global_transformer.layers.*.self_attn.k_proj.weightz@vision_model.global_transformer.layers.*.self_attn.v_proj.weight=.global_transformer.layers.*.self_attention.linear_qkv.weightz9vision_model.transformer.layers.*.self_attn.q_proj.weightz9vision_model.transformer.layers.*.self_attn.k_proj.weightz9vision_model.transformer.layers.*.self_attn.v_proj.weight6.transformer.layers.*.self_attention.linear_qkv.weight#vision_model.patch_embedding.weight.conv1._linear.weightmapping
transforms)updateextendr   state_transform_import_gate_import_text_qkv_import_simple_concat_import_text_kv_import_embedding_hf_import_vision_qkv_import_patch_embedding_hfapply_transforms)rF   rY   rZ   r   r   vr   r   r    rV   m   s&  	8








	























 !"&		zHFMLlamaImporter.convert_stater   c                 C   s   t | t| S r   )r   save_hf_tokenizer_assetsrR   rE   r   r   r    rC         zHFMLlamaImporter.tokenizerc                 C   s2   ddl m} |t| }t| || |dS )Nr   )
AutoConfig)r%   r+   )transformersr   rQ   rR   r   _language_model_config_vision_model_config)rF   r   rY   r   r   r    rD     s   zHFMLlamaImporter.configc                 C   sn   dd }t |jjd||jj|jjt|jj|jj|jj|jj|jj	|jj
t|tjkt|tjkt|dS )Nc                 S   s   | t | S r   )len)num_hidden_layerscross_attention_layersr   r   r    _calculate_num_layers     zFHFMLlamaImporter._language_model_config.<locals>._calculate_num_layersr5   )rotary_base
seq_lengthr<   r=   r9   r:   r;   num_query_groups
vocab_sizefp16bf16params_dtype)r   text_config
rope_thetar   r   r   r9   intermediate_sizer;   num_key_value_headsr   r   torchfloat16bfloat16)rF   rY   r   r   r   r    r     s$   
z'HFMLlamaImporter._language_model_configc                 C   sJ   t |jj|jj|jj|jj|jj|jjt|t	j
kt|t	jkt|d	S )N)	r<   r9   r;   r(   vision_max_num_chunksr>   r   r   r   )r   vision_configr   r9   attention_heads
image_sizemax_num_tilesr   r   r   r   r   )rF   rY   r   r   r    r   $  s   z%HFMLlamaImporter._vision_model_configr   )rB   r   )r-   r.   r/   r   rG   r   r   rJ   r\   rV   propertyrC   r   rD   r   r   r   r   __classcell__r   r   rL   r    rA   R   s     
rA   r   c                   @   sx   e Zd ZdZejfdddZdedefddZd	d
 Z	e
dddZdedeeef fddZdd Ze
dddZdS )HFMLlamaExportera  
    Exporter class for converting NeMo MLlama model to HuggingFace format.

    Inherits:
        io.ModelConnector: Connector interface to handle setup, save, and load using the Lightning framework.

    Methods:
        init: Initializes a new HuggingFace MLlama model instance.
        apply: Converts the NeMo model to HuggingFace format and saves it.
        convert_state: Maps and transforms the state dictionary from NeMo to HuggingFace format.
        config: Generates and returns the HuggingFace MLlama config for the model.
    rB   r   c                 C   sF   ddl m} |  tj| j|dW  d   S 1 sw   Y  dS )a-  
        Initializes a HuggingFace MllamaForConditionalGeneration model.

        Args:
            dtype: The data type to use for the model (default: torch.bfloat16)

        Returns:
            MllamaForConditionalGeneration: A HuggingFace MLlama model initialized with the configuration.
        r   )no_init_weightsrO   N)transformers.modeling_utilsr   r   _from_configrD   )rF   dtyper   r   r   r    rG   A  s   
$zHFMLlamaExporter.initrK   c                 C   s   t d | | \}}t d t d |  }t d | |||}| }|| z	| jj| W n tyF   t 	d Y nw t
d|  |S )a2  
        Converts the NeMo MLlama model to HuggingFace format and saves it to the specified path.

        Args:
            output_path (Path): The path where the converted HuggingFace model will be saved.

        Returns:
            Path: The output path where the HuggingFace model was saved.
        z8Loading MLlama NeMo checkpoint. This may take a while...zMLlama NeMo checkpoint loaded.zInitializing the HF model..zStart Converting the model..zFailed to save tokenizerz Converted MLlama model saved to )r   info	ckpt_loadrG   rV   cpusave_pretrainedrC   	ExceptionwarningrX   )rF   rK   rY   source_configrZ   r   r   r    r\   P  s    





zHFMLlamaExporter.applyc                 C   s  |  ||}i }g }|ddddddddd	d
d
 |tjddtdtjddtdtjddtdtjddtdtjddtdtjddt	dg d}|i | dd| dd| dd| d d!| d"d#| d$d%| d&d'| d(d)| d*d+| d,d-| d.d/| d0d1| d2d3| d4d5| d6d7| d8d9| d:d;i | d<d=| d>d?| d@dA| dBdC| dDdE| dFdG| dHdI| dJdK| dLdM| dNdO| dPdQ| dRdS| dTdU| dVdW| dXdYdZd[d\d] |tj| d^d_t
dtj| d`dat
dtj| dbdctdg tj||||ddS )ea  
        Maps and transforms the state dictionary from NeMo to HuggingFace format.

        Args:
            source: The source NeMo model.
            target: The target HuggingFace model.

        Returns:
            The target HuggingFace model with the converted state.
        rg   rl   rj   rk   z6language_model.model.layers.*.cross_attn.q_norm.weightz6language_model.model.layers.*.cross_attn.q_proj.weightz6language_model.model.layers.*.cross_attn.k_norm.weightz6language_model.model.layers.*.cross_attn.o_proj.weightrh   ri   )
r]   rd   rb   rc   rf   r_   re   r^   r`   ra   rm   z2language_model.model.layers.*.cross_attn_attn_gatern   rr   z1language_model.model.layers.*.cross_attn_mlp_gatert   rs   rv   ru   rw   )z6language_model.model.layers.*.cross_attn.k_proj.weightz6language_model.model.layers.*.cross_attn.v_proj.weightry   rx   rz   r|   r{   r~   r}   r   r   r   r   r   r   r   r   r   r   r   r   r   r   r   r   r   r   r   r   r   r   r   r   r   r   r   r   r   r   r   r   r   r   r   r   r   r   r   r   r   r   r   r   r   r   r   r   r   r   r   r   r   r   r   r   r   r   r   r   r   r   r   r   r   r   r   r   r   r   r   )_modify_mllama_source_stater   r   r   r   _export_gate_export_text_qkv_export_simple_split_export_text_kv_export_embedding_hf_export_vision_qkv_export_patch_embedding_hfr   )rF   rY   rZ   r   r   r   r   r   r   r    rV   m  s  	/








	























 !"&		zHFMLlamaExporter.convert_stater   c                 C   s   t jt| ddjS )z~
        Gets the tokenizer from the loaded model context.

        Returns:
            The tokenizer specification.
        modelsubpath)r   load_contextrR   rC   rE   r   r   r    rC     s   zHFMLlamaExporter.tokenizerpathc              
   C   s  t jt| dd}|d }i }|jj}|jj}t|d }|D ]\\}}	d|v r)q |dd}
d|
v r[|		d	|ksA|		d	|kr[t
|		d	D ]}|	| ||
dd
t| < qHnd|
v rxt
|		d	D ]}|	| ||
dd
t| < qf|	||
< q ||fS )a  
        This function loads the state dict directly from a distributed checkpoint, and modify the state dict
        so that it is consistent with the key names you would get from loading the checkpoint into a model.
        This is a more memory-efficient method to obtain a state dict without initializing the nemo model.

        Args:
            path (Path): The path from which the model will be loaded.

        Returns
        -------
            Tuple[Dict, Dict]: The loaded state dict and the yaml config dict.
        model.configr	  weightsT_extra_statezmodule. layersr   layers.zglobal_transformer.layers)r   r  rR   r%   r<   r+   r   itemsreplacesizerange)rF   r  rD   dist_ckpt_folderrT   langauge_layersvision_layersdistributed_model_weightskr   new_kir   r   r    r     s&   $  
zHFMLlamaExporter.ckpt_loadc                    s   fdd}|j }|j|j  |j|j }d}i }t|D ]}|d  d  }	|d  d  dkr|	  d }
|| d|
 d|| d	| d< || d|
 d
|| d	| d
< || d|
 d|| d	| d< || d|
 d|| d	| d< q||	 d }|| d	| d|| d	| d< || d	| d
|| d	| d
< || d	| d|| d	| d< || d	| d|| d	| d< q| D ]\}}|||< qi }| D ]\}}d|v r||td||< q|||< qt|}|S )a#  
        - Modify state dict to integrate cross-attention layers into self-attention layer.
        e.g. 11B: 32 self-attn + 8 cross-attn -> 40 layers, 90B: 80 self-attn + 20 cross-attn -> 100 layers
        - Change the layer index to match the cross_attention_layers in the model config.
        e.g. 11B: [3, 7, 11, 15, 19, 23, 27, 31] -> [3, 8, 13, 18, 23, 28, 33, 38]

        Args:
            state_dict: Source model state dict
            source_config: Model config dict

        Returns:
            _ModelState: Modified state
        c                    sT   t | d}|d   }|d   dkr|| }d| dS td| d  d)N      r   .zUnexpected layer_num: z0 (does not align with cross_attention_frequency=))intgroup
ValueError)match	layer_numx_numnew_layer_numcross_attention_frequencyr   r    convert_layer_num9  s   zGHFMLlamaExporter._modify_mllama_source_state.<locals>.convert_layer_numzlanguage_model.decoderr  r  r   z.xattn_layers.z!.mlp.linear_fc1.layer_norm_weightz.layers.z.mlp.linear_fc2.weightz+.cross_attention.linear_q.layer_norm_weightz,.self_attention.linear_qkv.layer_norm_weightz.mlp.linear_fc1.weightxattn_layersz	\.(\d+)\.)	r%   r<   r=   r  popr  resubr   )rF   rT   r   r+  r   total_num_layerprefixnew_state_dictr  	cross_numxattn_index
attn_indexr  r   rY   r   r)  r    r   *  sX   


z,HFMLlamaExporter._modify_mllama_source_stateHFMllamaConfigc                 C   s   t jt| dd}|j}|j}t|j|j|j|j	|j
dd}dd t||jD }t|j|j|j |j||j|j|j|j|jddd	d
ddg ddd}t||ddS )z
        Generates the configuration for the HuggingFace MLlama model based on the NeMo model.

        Returns:
            HFMllamaConfig: A configuration object for the HuggingFace MLlama model.
        r  r	  r   )r   r9   r   r   r   rP   c                 S   s   g | ]\}}|| qS r   r   ).0r  xr   r   r    
<listcomp>  s    z+HFMLlamaExporter.config.<locals>.<listcomp>g       @g      @g      ?r5   llama3)factorhigh_freq_factorlow_freq_factor original_max_position_embeddings	rope_type)i i i	 )r   r   tie_word_embeddingsr   r9   r   r;   r   r   rope_scalingeos_token_idrP   )r   r   rP   )r   r  rR   r+   r%   r   r<   r9   r;   r(   r   	enumerate_init_fusion_scheduler=   r   r   #share_embeddings_and_output_weightsr:   r   r   r6  )rF   rY   r+   language_configr   r   r   r   r   r    rD   w  sD   
zHFMLlamaExporter.configN)rB   r   )rB   r   )rB   r6  )r-   r.   r/   __doc__r   r   rG   r   r\   rV   r   rC   r   r   r   r   rD   r   r   r   r    r   2  s     	"Mr   r   rY   c                    sN    fdd}d i }|   D ]\}}d|v r ||td||< q|||< q|S )Nc                    sb   t | d}|d  d  }|d  d  dkr%|  d }d| dS || d }d| dS )Nr  r  r   zxattn_layers.r   r  )r"  r#  )r%  r&  r3  r(  r)  r   r    r+    s   z6_rename_xattn_layer_nums_hf.<locals>.convert_layer_num   language_modelzlayers\.(\d+)\.)r  r.  r/  )rY   r+  output_dictr  r   r   r)  r    rS     s   

rS   c                 C   s   t j| | jd d ddS )Nr      dim)r   splitshapear   r   r    r     s   r   c                 C   s   |  | jd dS )Nr   reshaperO  rP  r   r   r    r     s   r   c                 C      | dd S Nr   r  r   gater   r   r    r     r   r   ctxc           	      C   6   | j jj}|j}|j}|j}|j}t|||||||S r   )rZ   rD   r+   r;   r   kv_channelsr9   
_merge_qkv)	rY  qr  r   r   head_numr   	head_sizer9   r   r   r    r        
r   c           	      C   rZ  r   )rZ   rD   r%   r;   r   r[  r9   r\  )	rY  r]  r  r   r   r^  r   r_  r9   r   r   r    r     r`  r   c                 C   s4   | j jj}|j}|j}|j}|j}t||||||S r   )rZ   rD   r%   r;   r   r[  r9   	_merge_kv)rY  r  r   r   r^  r   r_  r9   r   r   r    r     s   
r   c                 C      t j| |fddS )Nr   rL  r   cat)rQ  br   r   r    r     r   r   r  r   r^  r   r_  r9   c           	      C   s   |   }||f|dd   }| j| } |j| }tj| |fdd}|jdg|dd  R  }|jdks8J |j|jd d| ksFJ |j|jd |ksRJ |j|jd |d ks`J |j||d | |g}|S )Nr  rL  rR  r  r      )r  viewr   stackrT  ndimrO  )	r  r   r^  r   r_  r9   old_tensor_shapenew_kv_tensor_shape
kv_weightsr   r   r    ra    s   

ra  r]  c                 C   sl  || }|   }||f|dd   }	||f|dd   }
| j|	 } |j|
 }|j|
 }g }t|D ]<}|| || |d | d d d d f  ||||d d d d d f  ||||d d d d d f  q1t|}|jdks}J |j|jd |d | ksJ |j|jd |ksJ |j|jd |d ksJ |j|||d|   |g}|S )Nr  r  r   rf  )	r  rg  r  appendr   rd  ri  rO  rT  )r]  r  r   r^  r   r_  r9   heads_per_grouprj  new_q_tensor_shaperk  qkv_weights_lr  qkv_weightsr   r   r    r\    s&   	


,$&
 r\  kvc                 C   sd   d| }|  |||g}td|d}td|d}||  d| }	||  d| }
|	|
fS )Nrf  r   r  rR  )rT  r   aranger   )rr  r^  r   r_  r9   kv_total_dim	linear_kvk_slicev_slicek_projv_projr   r   r    	_split_kv(  s   rz  c                    s   ||  |d|  }|  |||g}t fddt|D }t | d }t d | d }	||  d| }
||  d| }||	  d| }|
||fS )Nrf  c                    s,   g | ]}t  d  |  d  |   qS )rf  )r   rs  )r7  r  rn  r   r    r9  B  s    z_split_qkv.<locals>.<listcomp>r  rR  )rT  r   rd  r  rs  r   )qkvr^  r   r_  r9   qkv_total_dim
linear_qkvq_slicerv  rw  q_projrx  ry  r   r{  r    
_split_qkv<  s   

r  c                 C   rU  rV  r   rW  r   r   r    r  Q  r   r  c                 C   s   |  | jd dddS )Nr   r     rS  rP  r   r   r    r  U  s   r  c                 C   s4   | j jj}|j}|j}|j}|| }t|||||S r   )rZ   rD   r   r   r9   r  )rY  r|  r   r^  r   r9   r_  r   r   r    r  Y     
r  c                 C   4   | j jj}|j}|j}|j}|| }t|||||S r   )rZ   rD   r   r;   r   r9   rz  )rY  rr  r   r^  r   r9   r_  r   r   r    r  c  r  r  c                 C   r  r   )rZ   rD   r   r;   r   r9   r  )rY  r|  r   r^  r   r9   r_  r   r   r    r  m  r  r  c                 C   s   t j| ddd\}}||fS )z\Splits NeMo's fused MLP linear_fc1 weight into gate_proj and up_proj for HuggingFace format.rf  r   rL  )r   chunk)
linear_fc1	gate_projup_projr   r   r    r  w  s   r  c                 C   rb  )z?Transforms the word embeddings from NeMo to HuggingFace format.r   rL  rc  )word_embeddingslearnable_embeddingr   r   r    r  }  r   r  )Hr.  dataclassesr   r   pathlibr   typingr   r   r   r   torch.distributedmegatron.core.transformerr   r	   r   r
   r6  r   /transformers.models.mllama.configuration_mllamar   r   =nemo.collections.common.tokenizers.huggingface.auto_tokenizerr   1nemo.collections.common.tokenizers.tokenizer_specr   &nemo.collections.vlm.mllama.model.baser   r   r   r   .nemo.export.trt_llm.nemo_ckpt_loader.nemo_filer   nemo.lightningr   r   nemo.lightning.io.stater   nemo.lightning.pytorch.utilsr   
nemo.utilsr   r   r2   r4   r?   model_importerModelConnectorrA   model_exporterr   rS   r   r   r   TransformCTXr   r   r   r   r"  ra  r\  rz  r  r  r  r  r  r  r  r  r   r   r   r    <module>   s   
 
`  w




"



