# Reconstruction of nemo/collections/nlp/modules/common/megatron/attention.py,
# recovered from a compiled .pyc. Imports, class and method signatures, docstrings,
# and error messages come from the recoverable constant pool; method bodies are
# sketches of the recoverable control flow, not a byte-exact copy of the upstream file.

import math

import torch
import torch.nn.functional as F
from einops import rearrange, repeat

from nemo.collections.nlp.modules.common.megatron.adapters.parallel_adapters import (
    AdapterName,
    InfusedAdapterConfig,
    LoraDenseAttentionAdapterConfig,
    LoraKQVAdapterConfig,
    LoraKQVAdapterWeightTyingConfig,
    LoraKVAdapterConfig,
    LoraQAdapterConfig,
)
from nemo.collections.nlp.modules.common.megatron.fused_softmax import MatchedScaleMaskSoftmax
from nemo.collections.nlp.modules.common.megatron.module import MegatronModule
from nemo.collections.nlp.modules.common.megatron.position_embedding import XPOSPositionEmbedding
from nemo.collections.nlp.modules.common.megatron.position_embedding.rotary_position_embedding import (
    apply_rotary_pos_emb,
)
from nemo.collections.nlp.modules.common.megatron.utils import (
    ApexGuardDefaults,
    _cast_if_autocast_enabled,
    attention_mask_func,
)
from nemo.core import adapter_mixins

try:
    from apex.transformer.enums import AttnMaskType, AttnType
    from apex.transformer.utils import divide as safe_divide

    HAVE_APEX = True
except (ImportError, ModuleNotFoundError):
    HAVE_APEX = False
    # fake missing classes with None attributes
    ModelType = AttnMaskType = AttnType = LayerType = ApexGuardDefaults()

try:
    from megatron.core import ModelParallelConfig, parallel_state, tensor_parallel

    HAVE_MEGATRON_CORE = True
except (ImportError, ModuleNotFoundError):
    ModelParallelConfig = ApexGuardDefaults
    HAVE_MEGATRON_CORE = False

try:
    # Triton implementation of flash attention (supports an additive attention bias).
    from flash_attn.flash_attn_triton import flash_attn_func as flash_attn_func_triton
except (ImportError, ModuleNotFoundError):
    flash_attn_func_triton = None

try:
    # FlashAttention 1.X
    from flash_attn.bert_padding import pad_input, unpad_input
    from flash_attn.flash_attn_interface import flash_attn_unpadded_func

    HAVE_FLASH_ATTENTION = True
    flash_attn_func = None
except (ImportError, ModuleNotFoundError):
    try:
        # FlashAttention 2.X
        from flash_attn import flash_attn_func
        from flash_attn.bert_padding import pad_input, unpad_input
        from flash_attn.flash_attn_interface import flash_attn_varlen_func as flash_attn_unpadded_func

        HAVE_FLASH_ATTENTION = True
    except (ImportError, ModuleNotFoundError):
        HAVE_FLASH_ATTENTION = False
        flash_attn_unpadded_func, flash_attn_func = None, None
        unpad_input, pad_input = None, None

try:
    from flash_attn import flash_attn_with_kvcache
except (ImportError, ModuleNotFoundError):
    flash_attn_with_kvcache = None


class ParallelAttention(MegatronModule, adapter_mixins.AdapterModuleMixin):
    """Parallel self-attention layer abstract class.

    Self-attention layer takes input with size [s, b, h]
    and returns output of the same size.
    """

    def __init__(
        self,
        config: ModelParallelConfig,
        init_method,
        output_layer_init_method,
        layer_number,
        num_attention_heads,
        hidden_size,
        attention_type=AttnType.self_attn,
        attn_mask_type=AttnMaskType.padding,
        precision=16,
        apply_query_key_layer_scaling=False,
        kv_channels=None,
        masked_softmax_fusion=True,
        attention_dropout=0.1,
        layer_type=None,
        megatron_legacy=False,
        bias=True,
        headscale=False,
        position_embedding_type='learned_absolute',
        multi_query_attention=False,
        normalize_attention_scores=True,
        use_flash_attention=False,
    ):
        super(ParallelAttention, self).__init__(config=config)

        self.layer_number = max(1, layer_number)
        self.attention_type = attention_type
        self.attn_mask_type = attn_mask_type
        self.normalize_attention_scores = normalize_attention_scores
        self.position_embedding_type = position_embedding_type
        self.multi_query_attention = multi_query_attention
        self.use_flash_attention = use_flash_attention
        self.megatron_legacy = megatron_legacy

        self.set_accepted_adapter_types(
            [
                InfusedAdapterConfig._target_,
                LoraKQVAdapterConfig._target_,
                LoraQAdapterConfig._target_,
                LoraKVAdapterConfig._target_,
                LoraKQVAdapterWeightTyingConfig._target_,
                LoraDenseAttentionAdapterConfig._target_,
            ]
        )

        if kv_channels is None:
            assert (
                hidden_size % num_attention_heads == 0
            ), 'hidden_size must be divisible by num_attention_heads if kv_channels is None'
            kv_channels = hidden_size // num_attention_heads
        projection_size = kv_channels * num_attention_heads

        # Per-attention-head and per-partition values.
        world_size = parallel_state.get_tensor_model_parallel_world_size()
        self.hidden_size_per_attention_head = safe_divide(projection_size, num_attention_heads)
        self.num_attention_heads_per_partition = safe_divide(num_attention_heads, world_size)
        self.num_attention_heads_partition_offset = (
            self.num_attention_heads_per_partition * parallel_state.get_tensor_model_parallel_rank()
        )

        # Strided linear layers.
        if attention_type == AttnType.self_attn:
            self.query_key_value = tensor_parallel.ColumnParallelLinear(
                hidden_size,
                3 * projection_size,
                config=config,
                gather_output=False,
                init_method=init_method,
                bias=bias,
            )
        else:
            assert attention_type == AttnType.cross_attn
            self.query = tensor_parallel.ColumnParallelLinear(
                hidden_size,
                projection_size,
                config=config,
                gather_output=False,
                init_method=init_method,
                bias=bias,
            )
            self.key_value = tensor_parallel.ColumnParallelLinear(
                hidden_size,
                2 * projection_size,
                config=config,
                gather_output=False,
                init_method=init_method,
                bias=bias,
            )

        self.core_attention = CoreAttention(
            config=config,
            layer_number=self.layer_number,
            num_attention_heads=num_attention_heads,
            hidden_size=hidden_size,
            attention_type=self.attention_type,
            attn_mask_type=self.attn_mask_type,
            precision=precision,
            apply_query_key_layer_scaling=apply_query_key_layer_scaling,
            kv_channels=kv_channels,
            masked_softmax_fusion=masked_softmax_fusion,
            attention_dropout=attention_dropout,
            multi_query_attention=multi_query_attention,
            normalize_attention_scores=normalize_attention_scores,
            position_embedding_type=position_embedding_type,
            use_flash_attention=use_flash_attention,
        )

        # Output projection.
        self.dense = tensor_parallel.RowParallelLinear(
            projection_size,
            hidden_size,
            config=config,
            input_is_parallel=True,
            init_method=output_layer_init_method,
            skip_bias_add=True,
            bias=bias,
        )

        self.headscale = headscale
        if headscale:
            self.head_scale_tensor = torch.nn.Parameter(
                torch.ones(1, self.num_attention_heads_per_partition, 1, 1), requires_grad=True
            )

        # Inference key-value memory.
        self.inference_key_memory = None
        self.inference_value_memory = None
        self.inference_current_sequence_len = 0

        # Relative position embedding.
        self.layer_type = layer_type
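
    # Adapter hooks used by forward() below: LORA_KQV_ADAPTER adds a low-rank
    # delta to the fused query_key_value projection; LORA_Q_ADAPTER and
    # LORA_KV_ADAPTER do the same for the separate cross-attention query and
    # key/value projections; KEY_INFUSED and VALUE_INFUSED rescale keys and
    # values (IA3-style); LORA_DENSE_ATTENTION_ADAPTER adjusts the output
    # (dense) projection.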

    def _checkpointed_attention_forward(
        self,
        query_layer,
        key_layer,
        value_layer,
        attention_mask,
        rotary_pos_emb=None,
        relative_position_bias=None,
        headscale_tensor=None,
        inference_mode=None,
    ):
        """Forward method with activation checkpointing."""

        def custom_forward(*inputs):
            if len(inputs) == 7:
                query_layer = inputs[0]
                key_layer = inputs[1]
                value_layer = inputs[2]
                attention_mask = inputs[3]
                rotary_pos_emb = inputs[4]
                relative_position_bias = inputs[5]
                headscale_tensor = inputs[6]
            elif len(inputs) == 8:
                query_layer = inputs[0]
                key_layer = inputs[1]
                value_layer = inputs[2]
                attention_mask = inputs[3]
                rotary_pos_emb = (inputs[4], inputs[5])
                relative_position_bias = inputs[6]
                headscale_tensor = inputs[7]
            else:
                raise ValueError('unexpected number of inputs')
            output_ = self.core_attention(
                query_layer,
                key_layer,
                value_layer,
                attention_mask,
                rotary_pos_emb=rotary_pos_emb,
                relative_position_bias=relative_position_bias,
                headscale_tensor=headscale_tensor,
                inference_mode=inference_mode,
            )
            return output_

        if rotary_pos_emb is None:
            rot_tuple = (rotary_pos_emb,)
        else:
            rot_tuple = (rotary_pos_emb[0], rotary_pos_emb[1])

        hidden_states = tensor_parallel.checkpoint(
            custom_forward,
            False,
            query_layer,
            key_layer,
            value_layer,
            attention_mask,
            *rot_tuple,
            relative_position_bias,
            headscale_tensor,
        )
        return hidden_states

    def _allocate_memory(self, inference_max_sequence_len, batch_size, dtype, device):
        return torch.empty(
            inference_max_sequence_len,
            batch_size,
            self.num_attention_heads_per_partition,
            self.hidden_size_per_attention_head,
            dtype=dtype,
            device=device,
        )

    def _transpose_last_dim(self, mixed_layer, num_splits, num_splits_first):
        input_shape = mixed_layer.size()
        if num_splits_first:
            # [s, b, num_splits * np * hn] -> (view) [s, b, num_splits, np, hn]
            # -> (transpose) [s, b, np, num_splits, hn] -> (view) [s, b, np * num_splits * hn]
            intermediate_shape = input_shape[:-1] + (
                num_splits,
                self.num_attention_heads_per_partition,
                self.hidden_size_per_attention_head,
            )
            mixed_layer = mixed_layer.view(*intermediate_shape)
            mixed_layer = mixed_layer.transpose(-2, -3).contiguous()
        else:
            # [s, b, np * hn * num_splits] -> (view) [s, b, np, hn, num_splits]
            # -> (transpose) [s, b, np, num_splits, hn] -> (view) [s, b, np * num_splits * hn]
            intermediate_shape = input_shape[:-1] + (
                self.num_attention_heads_per_partition,
                self.hidden_size_per_attention_head,
                num_splits,
            )
            mixed_layer = mixed_layer.view(*intermediate_shape)
            mixed_layer = mixed_layer.transpose(-1, -2).contiguous()
        mixed_layer = mixed_layer.view(*input_shape)
        return mixed_layer
    def forward(
        self,
        hidden_states,
        attention_mask,
        layer_past=None,
        get_key_value=False,
        encoder_output=None,
        set_inference_key_value_memory=False,
        inference_max_sequence_len=None,
        rotary_pos_emb=None,
        relative_position_bias=None,
        checkpoint_core_attention=False,
        return_scores=False,
    ):
        # hidden_states: [sq, b, h]

        # Pre-allocate memory for key-values for inference.
        if set_inference_key_value_memory:
            assert inference_max_sequence_len and inference_max_sequence_len > 0
            self.inference_key_memory = self._allocate_memory(
                inference_max_sequence_len, hidden_states.size(1), hidden_states.dtype, hidden_states.device
            )
            self.inference_value_memory = self._allocate_memory(
                inference_max_sequence_len, hidden_states.size(1), hidden_states.dtype, hidden_states.device
            )
            self.inference_current_sequence_len = 0
        if inference_max_sequence_len:
            assert self.inference_current_sequence_len < self.inference_key_memory.size(0)
            assert inference_max_sequence_len == self.inference_key_memory.size(0)
        else:
            # Make sure no stale memory is left over from a previous inference run.
            self.inference_key_memory = None
            self.inference_value_memory = None

        # Query, key, and value projections.
        if self.attention_type == AttnType.self_attn:
            # Attention heads: [sq, b, h] -> [sq, b, (np * 3 * hn)]
            mixed_x_layer, _ = self.query_key_value(hidden_states)
            if self.is_adapter_available():
                lora_kqv_adapter = self.get_adapter_module(AdapterName.LORA_KQV_ADAPTER)
                if lora_kqv_adapter and self.adapter_cfg[AdapterName.LORA_KQV_ADAPTER]['enabled']:
                    lora_mixed_x_layer = lora_kqv_adapter(hidden_states)
                    mixed_x_layer = mixed_x_layer + lora_mixed_x_layer

            # [sq, b, (np * 3 * hn)] -> [sq, b, np, 3 * hn]
            new_tensor_shape = mixed_x_layer.size()[:-1] + (
                self.num_attention_heads_per_partition,
                3 * self.hidden_size_per_attention_head,
            )
            if self.megatron_legacy:
                mixed_x_layer = self._transpose_last_dim(mixed_x_layer, 3, True)
            mixed_x_layer = mixed_x_layer.view(*new_tensor_shape)

            # [sq, b, np, 3 * hn] -> 3 x [sq, b, np, hn]
            (query_layer, key_layer, value_layer) = tensor_parallel.split_tensor_along_last_dim(
                mixed_x_layer, 3, contiguous_split_chunks=True
            )
        else:
            # Cross attention: keys/values come from the encoder output
            # (optional LORA_KV_ADAPTER), queries from hidden_states
            # (optional LORA_Q_ADAPTER).
            mixed_kv_layer, _ = self.key_value(encoder_output)
            if self.is_adapter_available():
                lora_kv_adapter = self.get_adapter_module(AdapterName.LORA_KV_ADAPTER)
                if lora_kv_adapter and self.adapter_cfg[AdapterName.LORA_KV_ADAPTER]['enabled']:
                    mixed_kv_layer = mixed_kv_layer + lora_kv_adapter(encoder_output)

            # [sk, b, (np * 2 * hn)] -> [sk, b, np, 2 * hn]
            new_tensor_shape = mixed_kv_layer.size()[:-1] + (
                self.num_attention_heads_per_partition,
                2 * self.hidden_size_per_attention_head,
            )
            if self.megatron_legacy:
                mixed_kv_layer = self._transpose_last_dim(mixed_kv_layer, 2, True)
            mixed_kv_layer = mixed_kv_layer.view(*new_tensor_shape)

            # [sk, b, np, 2 * hn] -> 2 x [sk, b, np, hn]
            (key_layer, value_layer) = tensor_parallel.split_tensor_along_last_dim(
                mixed_kv_layer, 2, contiguous_split_chunks=True
            )

            # Attention head: [sq, b, h] -> [sq, b, np, hn]
            query_layer, _ = self.query(hidden_states)
            if self.is_adapter_available():
                lora_q_adapter = self.get_adapter_module(AdapterName.LORA_Q_ADAPTER)
                if lora_q_adapter and self.adapter_cfg[AdapterName.LORA_Q_ADAPTER]['enabled']:
                    query_layer = query_layer + lora_q_adapter(hidden_states)
            new_tensor_shape = query_layer.size()[:-1] + (
                self.num_attention_heads_per_partition,
                self.hidden_size_per_attention_head,
            )
            query_layer = query_layer.view(*new_tensor_shape)

        # IA3-style (infused) adapters rescale keys and values.
        if self.is_adapter_available():
            key_infused_adapter = self.get_adapter_module(AdapterName.KEY_INFUSED)
            value_infused_adapter = self.get_adapter_module(AdapterName.VALUE_INFUSED)
            if key_infused_adapter:
                assert value_infused_adapter, "Expected value_infused_adapter not found!"
                kls = key_layer.shape
                key_layer = key_infused_adapter(key_layer.reshape(kls[0], kls[1], -1)).reshape(kls)
            if value_infused_adapter:
                assert key_infused_adapter, "Expected key_infused_adapter not found!"
                vls = value_layer.shape
                value_layer = value_infused_adapter(value_layer.reshape(vls[0], vls[1], -1)).reshape(vls)

        # Duplicate the pos_emb for self attention.
        if rotary_pos_emb is not None:
            rotary_pos_emb = rotary_pos_emb if isinstance(rotary_pos_emb, tuple) else ((rotary_pos_emb,) * 2)

        # Adjust key, value, attention mask, and rotary embedding for inference.
        if inference_max_sequence_len:
            start = self.inference_current_sequence_len
            self.inference_current_sequence_len += key_layer.size(0)
            end = self.inference_current_sequence_len
            self.inference_key_memory[start:end, ...] = key_layer
            self.inference_value_memory[start:end, ...] = value_layer
            key_layer = self.inference_key_memory[:end, ...]
            value_layer = self.inference_value_memory[:end, ...]
            if attention_mask is not None and self.attention_type == AttnType.self_attn:
                attention_mask = attention_mask[..., start:end, :end]
            if rotary_pos_emb is not None:
                q_pos_emb, k_pos_emb = rotary_pos_emb
                if not set_inference_key_value_memory:
                    # At decode time one token is processed per step; select its embedding.
                    q_pos_emb = q_pos_emb[end - 1 : end]
                else:
                    q_pos_emb = q_pos_emb[:end, :, :, :]
                k_pos_emb = k_pos_emb[:end, :, :, :]
                rotary_pos_emb = (q_pos_emb, k_pos_emb)

        if layer_past is not None:
            past_key, past_value = layer_past
            key_layer = torch.cat((past_key.type_as(key_layer), key_layer), dim=0)
            value_layer = torch.cat((past_value.type_as(value_layer), value_layer), dim=0)

        if get_key_value:
            present = (key_layer, value_layer)

        # Core attention computation.
        if (
            flash_attn_with_kvcache is not None
            and self.use_flash_attention
            and rotary_pos_emb is not None
            and inference_max_sequence_len
            and not set_inference_key_value_memory
        ):
            # Fused decode path (mainly used with sq == 1).
            q = _cast_if_autocast_enabled(rearrange(query_layer, 'sq b np hn -> b sq np hn'))
            k = _cast_if_autocast_enabled(rearrange(key_layer, 'sk b np hn -> b sk np hn'))
            v = _cast_if_autocast_enabled(rearrange(value_layer, 'sk b np hn -> b sk np hn'))
            context_layer = flash_attn_with_kvcache(
                q=q, k_cache=k, v_cache=v, causal=self.attn_mask_type == AttnMaskType.causal,
            )
            context_layer = rearrange(context_layer, 'b sq np hn -> sq b (np hn)')
        elif checkpoint_core_attention:
            context_layer = self._checkpointed_attention_forward(
                query_layer,
                key_layer,
                value_layer,
                attention_mask,
                rotary_pos_emb=rotary_pos_emb,
                relative_position_bias=relative_position_bias,
                headscale_tensor=self.head_scale_tensor if self.headscale else None,
                inference_mode=inference_max_sequence_len is not None and query_layer.shape[0] == 1,
            )
        else:
            context_layer = self.core_attention(
                query_layer,
                key_layer,
                value_layer,
                attention_mask,
                layer_past=layer_past,
                get_key_value=get_key_value,
                rotary_pos_emb=rotary_pos_emb,
                relative_position_bias=relative_position_bias,
                headscale_tensor=self.head_scale_tensor if self.headscale else None,
                inference_mode=inference_max_sequence_len is not None and query_layer.shape[0] == 1,
                return_scores=return_scores,
            )
            if return_scores:
                context_layer, attention_probs = context_layer

        # Output: [sq, b, h].
        output, bias = self.dense(context_layer)
        if self.is_adapter_available():
            lora_dense_adapter = self.get_adapter_module(AdapterName.LORA_DENSE_ATTENTION_ADAPTER)
            if lora_dense_adapter and self.adapter_cfg[AdapterName.LORA_DENSE_ATTENTION_ADAPTER]['enabled']:
                lora_dense_output = lora_dense_adapter(context_layer)
                output = output + lora_dense_output

        if get_key_value:
            output = [output, present]
        if return_scores:
            output = [output, attention_probs]

        return output, bias
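
# Illustrative sketch, not part of the upstream module: the tensor-parallel head
# partitioning arithmetic used by ParallelAttention above, with plain integers
# standing in for the parallel_state queries. All numbers are hypothetical.
def _example_head_partitioning(hidden_size=4096, num_attention_heads=32, world_size=4, rank=1):
    # kv_channels defaults to hidden_size // num_attention_heads.
    kv_channels = hidden_size // num_attention_heads          # 128
    projection_size = kv_channels * num_attention_heads       # 4096
    heads_per_partition = num_attention_heads // world_size   # 8 heads per rank
    # Offset of this rank's block of heads, used to slice relative position biases.
    partition_offset = heads_per_partition * rank             # heads 8..15 on rank 1
    return kv_channels, projection_size, heads_per_partition, partition_offset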

class ParallelChunkedCrossAttention(MegatronModule):
    """Parallel chunked cross-attention layer class.

    Cross-attention layer takes input with size [b, s, h]
    and returns output of the same size.
    """

    def __init__(
        self,
        config: ModelParallelConfig,
        init_method,
        output_layer_init_method,
        layer_number,
        num_attention_heads,
        hidden_size,
        precision=16,
        apply_query_key_layer_scaling=False,
        kv_channels=None,
        masked_softmax_fusion=True,
        attention_dropout=0.1,
        megatron_legacy=False,
        chunk_size=64,  # each chunk, how many tokens
        bias=True,
        headscale=False,
        normalize_attention_scores=True,
    ):
        super(ParallelChunkedCrossAttention, self).__init__(config=config)
        self.cross_attention = ParallelAttention(
            config=config,
            init_method=init_method,
            output_layer_init_method=output_layer_init_method,
            layer_number=layer_number,
            num_attention_heads=num_attention_heads,
            hidden_size=hidden_size,
            attention_type=AttnType.cross_attn,
            attn_mask_type=AttnMaskType.padding,
            precision=precision,
            apply_query_key_layer_scaling=apply_query_key_layer_scaling,
            kv_channels=kv_channels,
            masked_softmax_fusion=masked_softmax_fusion,
            attention_dropout=attention_dropout,
            megatron_legacy=megatron_legacy,
            bias=bias,
            headscale=headscale,
            normalize_attention_scores=normalize_attention_scores,
        )
        self.chunk_size = chunk_size
    def forward(
        self,
        hidden_states,
        attention_mask,
        encoder_output=None,
        set_inference_key_value_memory=False,
        inference_max_sequence_len=None,
        rotary_pos_emb=None,
        checkpoint_core_attention=False,
    ):
        if checkpoint_core_attention:
            raise ValueError(
                'checkpoint_core_attention during forward not implemented yet for ParallelChunkedCrossAttention'
            )

        # hidden_states: [token length, batch, dimension];
        # encoder_output (the retrieved context):
        # [num_chunks, num_neighbors, context_token_len, batch, dimension]
        context = encoder_output
        chunk_size = self.chunk_size
        b, n, dim = hidden_states.shape[1], hidden_states.shape[0], hidden_states.shape[2]
        default_bias = self.cross_attention.dense.bias

        if set_inference_key_value_memory:
            seq_index = (n // chunk_size) * chunk_size
            self.current_len = n
        elif inference_max_sequence_len is not None:
            # Only handles a single-token increment at a time.
            assert n == 1
            self.current_len += n
            chunk_id = self.current_len // chunk_size
            if chunk_id <= 0:
                # Sequence is still shorter than a chunk: early return.
                return torch.zeros_like(hidden_states), default_bias
            causal_padding = chunk_size - 1
            # Pad the token out to a full chunk, placed at the end of the chunk.
            hidden_states = F.pad(hidden_states, (0, 0, 0, 0, causal_padding, 0), value=0.0)
            # Only use the relevant retrieved context and its attention mask.
            context = context[chunk_id - 1 : chunk_id, :, :, :, :]
            attention_mask = rearrange(attention_mask, '(b k) 1 q v -> b k 1 q v', b=b)
            attention_mask = attention_mask[:, chunk_id - 1]
            seq_index = chunk_size
        else:
            # Normal forward without inference.
            seq_index = (n // chunk_size) * chunk_size

        # If the sequence is shorter than a chunk, do an early return.
        if n < self.chunk_size and set_inference_key_value_memory and inference_max_sequence_len is not None:
            return torch.zeros_like(hidden_states), default_bias

        num_chunks, num_retrieved = context.shape[-5], context.shape[-4]

        # Causal padding: shift queries so a token only attends to neighbors
        # retrieved for earlier chunks.
        causal_padding = chunk_size - 1
        x = F.pad(hidden_states, (0, 0, 0, 0, -causal_padding, causal_padding), value=0.0)

        # Remove the part of the sequence that is ahead of the retrieved neighbors.
        x, x_remainder = x[:seq_index], x[seq_index:]
        seq_remain_len = x_remainder.shape[0]

        if rotary_pos_emb is not None:
            q_pos_emb, k_pos_emb = rotary_pos_emb
            if inference_max_sequence_len is not None and not set_inference_key_value_memory:
                token_pos = (self.current_len - 1) % chunk_size
                q_pos_emb = F.pad(
                    q_pos_emb, (0, 0, 0, 0, 0, 0, -causal_padding - token_pos, -causal_padding + token_pos), value=0.0
                )
            else:
                q_pos_emb = F.pad(q_pos_emb, (0, 0, 0, 0, 0, 0, -causal_padding, 0), value=0.0)
            k_pos_emb = repeat(k_pos_emb, 'n b h d -> (r n) b h d', r=num_retrieved)
            rotary_pos_emb = (q_pos_emb, k_pos_emb)

        # Make sure the number of context chunks is sufficient.
        assert x.shape[0] // chunk_size == num_chunks

        # Reshape so attention is chunk-to-chunk without breaking causality.
        x = rearrange(x, '(k n) b d -> n (b k) d', k=num_chunks)
        context = rearrange(context, 'k r n b d -> (r n) (b k) d')

        out, bias = self.cross_attention(x, attention_mask, encoder_output=context, rotary_pos_emb=rotary_pos_emb)

        # Reshape back to the original sequence and pad with zeros at the start
        # (they are added to the residual stream, which is fine).
        out = rearrange(out, 'n (b k) d -> (k n) b d', b=b)
        out = F.pad(out, (0, 0, 0, 0, causal_padding, -causal_padding + seq_remain_len), value=0.0)
        if not set_inference_key_value_memory and inference_max_sequence_len is not None:
            out = out[-1:]
        return out, bias

class CoreAttention(MegatronModule):
    """Region where selective activation recomputation is applied.
    See Figure 3. in Reducing Activation Recomputation in Large Transformer Models
    https://arxiv.org/pdf/2205.05198.pdf for more details.

    """

    def __init__(
        self,
        config: ModelParallelConfig,
        layer_number,
        num_attention_heads,
        hidden_size,
        attention_type=AttnType.self_attn,
        attn_mask_type=AttnMaskType.padding,
        precision=16,
        apply_query_key_layer_scaling=False,
        kv_channels=None,
        masked_softmax_fusion=True,
        attention_dropout=0.1,
        multi_query_attention=False,
        normalize_attention_scores=True,
        position_embedding_type='learned_absolute',
        use_flash_attention=False,
    ):
        super(CoreAttention, self).__init__(config=config)

        self.precision = precision
        self.fp16 = False
        self.bf16 = False
        if precision in ['bf16', 'bf16-mixed']:
            self.bf16 = True
        elif precision in [16, '16', '16-mixed']:
            self.fp16 = True
        self.multi_query_attention = multi_query_attention
        self.position_embedding_type = position_embedding_type

        self.apply_query_key_layer_scaling = apply_query_key_layer_scaling
        self.attention_softmax_in_fp32 = False
        if self.apply_query_key_layer_scaling:
            self.attention_softmax_in_fp32 = True
        self.layer_number = max(1, layer_number)
        self.attention_type = attention_type
        self.attn_mask_type = attn_mask_type
        self.sequence_parallel = config.sequence_parallel
        # If True, scale attention scores by 1/sqrt(hidden_size_per_attention_head);
        # provided mostly to support weight conversion of HuggingFace models (e.g. T5v1.1).
        self.normalize_attention_scores = normalize_attention_scores

        if kv_channels is None:
            assert (
                hidden_size % num_attention_heads == 0
            ), 'hidden_size must be divisible by num_attention_heads if kv_channels is None'
            kv_channels = hidden_size // num_attention_heads
        projection_size = kv_channels * num_attention_heads

        # Per-attention-head and per-partition values.
        world_size = parallel_state.get_tensor_model_parallel_world_size()
        self.hidden_size_per_partition = safe_divide(projection_size, world_size)
        self.hidden_size_per_attention_head = safe_divide(projection_size, num_attention_heads)
        self.num_attention_heads_per_partition = safe_divide(num_attention_heads, world_size)
        self.num_attention_heads_partition_offset = (
            self.num_attention_heads_per_partition * parallel_state.get_tensor_model_parallel_rank()
        )

        coeff = None
        self.norm_factor = math.sqrt(self.hidden_size_per_attention_head)
        if self.apply_query_key_layer_scaling:
            coeff = self.layer_number
            self.norm_factor *= coeff

        self.scale_mask_softmax = MatchedScaleMaskSoftmax(
            self.fp16,
            self.bf16,
            self.attn_mask_type,
            masked_softmax_fusion,
            attention_mask_func,
            self.attention_softmax_in_fp32,
            coeff,
        )

        # Dropout: for a single iteration this produces different outputs for
        # different numbers of parallel partitions, but on average it is not
        # partition dependent.
        self.attention_dropout_p = attention_dropout
        self.attention_dropout = torch.nn.Dropout(attention_dropout)

        if use_flash_attention:
            self.attn_fn = self.flash_attention
        else:
            self.attn_fn = self.torch_attention

        if position_embedding_type.lower() == 'xpos':
            self.xpos = XPOSPositionEmbedding(kv_channels)
    def forward(
        self,
        query_layer,
        key_layer,
        value_layer,
        attention_mask,
        layer_past=None,
        get_key_value=False,
        rotary_pos_emb=None,
        relative_position_bias=None,
        headscale_tensor=None,
        inference_mode=None,
        return_scores=None,
    ):
        b, np, sq, sk, hn = (
            query_layer.size(1),
            query_layer.size(2),
            query_layer.size(0),
            key_layer.size(0),
            query_layer.size(3),
        )

        # Update the attention mask for inference: [b, np, sq, sk].
        if get_key_value:
            with torch.no_grad():
                if layer_past is not None:
                    attention_mask = attention_mask[..., sq - 1, :sk].unsqueeze(2)
                else:
                    attention_mask = attention_mask[..., :sq, :sk]

        # Slice the relative position bias down to this rank's heads.
        if relative_position_bias is not None:
            relative_position_bias = relative_position_bias[
                :,
                self.num_attention_heads_partition_offset : self.num_attention_heads_partition_offset
                + self.num_attention_heads_per_partition,
                -sq:,
                -sk:,
            ]

        # Apply relative positional encoding (rotary embedding).
        if rotary_pos_emb is not None:
            q_pos_emb, k_pos_emb = rotary_pos_emb
            query_layer = apply_rotary_pos_emb(query_layer, q_pos_emb)
            key_layer = apply_rotary_pos_emb(key_layer, k_pos_emb)

        if self.position_embedding_type.lower() == 'xpos':
            query_layer = self.xpos(query_layer, offset=key_layer.shape[-2] - query_layer.shape[-2], downscale=False)
            key_layer = self.xpos(key_layer, offset=0, downscale=True)

        if not return_scores:
            context_layer = self.attn_fn(
                query_layer, key_layer, value_layer, attention_mask, relative_position_bias, inference_mode
            )
        else:
            # Scores are only available from the torch path.
            context_layer = self.torch_attention_with_prior(
                query_layer,
                key_layer,
                value_layer,
                attention_mask,
                relative_position_bias,
                inference_mode,
                return_scores=return_scores,
            )
            context_layer, attention_probs = context_layer

        if headscale_tensor is not None:
            context_layer = context_layer * headscale_tensor

        # [b, np, sq, hn] -> [sq, b, np, hn] -> [sq, b, hp]
        context_layer = context_layer.permute(2, 0, 1, 3).contiguous()
        new_context_layer_shape = context_layer.size()[:-2] + (self.hidden_size_per_partition,)
        context_layer = context_layer.view(*new_context_layer_shape)

        if return_scores:
            return context_layer, attention_probs
        return context_layer

    def torch_attention(self, query_layer, key_layer, value_layer, attention_mask, attention_bias, inference_mode):
        sq, b, np, hn = query_layer.shape
        sk = key_layer.shape[0]

        if self.multi_query_attention:
            query_layer = rearrange(query_layer, 'sq b np hn -> b (np sq) hn')
            key_layer = rearrange(key_layer, 'sk b 1 hn -> b hn sk')
            value_layer = rearrange(value_layer, 'sv b np hn -> (b np) sv hn')
        else:
            query_layer = rearrange(query_layer, 'sq b np hn -> (b np) sq hn')
            key_layer = rearrange(key_layer, 'sk b np hn -> (b np) hn sk')
            value_layer = rearrange(value_layer, 'sv b np hn -> (b np) sv hn')

        matmul_input_buffer = torch.empty(
            query_layer.shape[0], query_layer.shape[1], key_layer.shape[2],
            dtype=query_layer.dtype, device=query_layer.device,
        )

        matmul_result = torch.baddbmm(
            matmul_input_buffer,
            query_layer,
            key_layer,
            beta=0.0,
            alpha=(1.0 / self.norm_factor) if self.normalize_attention_scores else 1.0,
        )

        # Change view to [b, np, sq, sk].
        attention_scores = matmul_result.view(b, np, sq, sk)

        if attention_bias is not None:
            attention_scores += attention_bias

        attention_probs = self.scale_mask_softmax(attention_scores, attention_mask)

        if not self.sequence_parallel:
            with tensor_parallel.random.get_cuda_rng_tracker().fork():
                attention_probs = self.attention_dropout(attention_probs)
        else:
            attention_probs = self.attention_dropout(attention_probs)

        attention_probs = rearrange(attention_probs, 'b np sq sk -> (b np) sq sk')
        context_layer = torch.bmm(attention_probs, value_layer)
        context_layer = rearrange(context_layer, '(b np) sq hn -> b np sq hn', np=np)
        return context_layer

    def torch_attention_with_prior(
        self, query_layer, key_layer, value_layer, attention_mask, attention_bias, inference_mode, return_scores=False
    ):
        sq, b, np, hn = query_layer.shape
        sk = key_layer.shape[0]

        if self.multi_query_attention:
            query_layer = rearrange(query_layer, 'sq b np hn -> b (np sq) hn')
            key_layer = rearrange(key_layer, 'sk b 1 hn -> b hn sk')
            value_layer = rearrange(value_layer, 'sv b np hn -> (b np) sv hn')
        else:
            query_layer = rearrange(query_layer, 'sq b np hn -> (b np) sq hn')
            key_layer = rearrange(key_layer, 'sk b np hn -> (b np) hn sk')
            value_layer = rearrange(value_layer, 'sv b np hn -> (b np) sv hn')

        matmul_input_buffer = torch.empty(
            query_layer.shape[0], query_layer.shape[1], key_layer.shape[2],
            dtype=query_layer.dtype, device=query_layer.device,
        )
        matmul_result = torch.baddbmm(
            matmul_input_buffer,
            query_layer,
            key_layer,
            beta=0.0,
            alpha=(1.0 / self.norm_factor) if self.normalize_attention_scores else 1.0,
        )
        attention_scores = matmul_result.view(b, np, sq, sk)

        if attention_bias is not None:
            # Treat the bias as a log-prior over the attention distribution.
            attention_scores = torch.log_softmax(attention_scores, dim=-1) + attention_bias

        _attention_probs = self.scale_mask_softmax(attention_scores, attention_mask)
        if not self.sequence_parallel:
            with tensor_parallel.random.get_cuda_rng_tracker().fork():
                attention_probs = self.attention_dropout(_attention_probs)
        else:
            attention_probs = self.attention_dropout(_attention_probs)

        attention_probs = rearrange(attention_probs, 'b np sq sk -> (b np) sq sk')
        context_layer = torch.bmm(attention_probs, value_layer)
        context_layer = rearrange(context_layer, '(b np) sq hn -> b np sq hn', np=np)

        if return_scores:
            # Return the pre-dropout probabilities alongside the context.
            return context_layer, _attention_probs
        return context_layer

    def flash_attention(self, query_layer, key_layer, value_layer, attention_mask, attention_bias, inference_mode):
        query_layer = rearrange(query_layer, 'sq b np hn -> b sq np hn')
        key_layer = rearrange(key_layer, 'sk b np hn -> b sk np hn')
        value_layer = rearrange(value_layer, 'sv b np hn -> b sv np hn')

        # Ensure casting to fp16/bf16 under autocast.
        query_layer = _cast_if_autocast_enabled(query_layer)
        key_layer = _cast_if_autocast_enabled(key_layer)
        value_layer = _cast_if_autocast_enabled(value_layer)
        attention_bias = _cast_if_autocast_enabled(attention_bias)

        is_causal = self.attn_mask_type == AttnMaskType.causal and not inference_mode

        if attention_bias is not None:
            # Only the Triton kernel supports an additive attention bias.
            return self.flash_attention_triton(
                query_layer, key_layer, value_layer, attention_mask, attention_bias, is_causal
            )
        else:
            return self.flash_attention_cuda(query_layer, key_layer, value_layer, attention_mask, is_causal)

    def flash_attention_cuda(self, query_layer, key_layer, value_layer, attention_mask, is_causal):
        batch_size, seqlen, nheads, _ = query_layer.shape

        # Boolean masks: True = attend, False = skip.
        if attention_mask is None:
            attention_mask_q = torch.ones(batch_size, query_layer.shape[1], device=query_layer.device).bool()
            attention_mask_kv = torch.ones(batch_size, key_layer.shape[1], device=query_layer.device).bool()
        elif len(attention_mask.shape) == 4:
            # [b, 1, sq, sk] -> [b, sq] and [b, sk]
            attention_mask_q = torch.any(torch.eq(attention_mask, False), dim=3).squeeze(1)
            attention_mask_kv = torch.any(torch.eq(attention_mask, False), dim=2).squeeze(1)
        else:
            assert len(attention_mask.shape) == 2
            attention_mask_q = attention_mask
            attention_mask_kv = attention_mask

        seqlens_q_in_batch = len(attention_mask_q.sum(dim=-1, dtype=torch.int32).unique())
        seqlens_kv_in_batch = len(attention_mask_kv.sum(dim=-1, dtype=torch.int32).unique())
        if seqlens_q_in_batch == 1 and seqlens_kv_in_batch == 1 and flash_attn_func is not None:
            # All sequences share one length: use the dense kernel directly.
            context_layer = flash_attn_func(
                query_layer,
                key_layer,
                value_layer,
                dropout_p=self.attention_dropout_p if self.training else 0.0,
                causal=is_causal,
            )
        else:
            # Ragged batch: unpad, run the varlen kernel, then pad back.
            q, indices_q, cu_seqlens_q, max_seqlen_q = unpad_input(query_layer, attention_mask_q)
            k, _, cu_seqlens_k, max_seqlen_k = unpad_input(key_layer, attention_mask_kv)
            v, _, _, _ = unpad_input(value_layer, attention_mask_kv)
            context_layer = flash_attn_unpadded_func(
                q,
                k,
                v,
                cu_seqlens_q,
                cu_seqlens_k,
                max_seqlen_q,
                max_seqlen_k,
                dropout_p=self.attention_dropout_p if self.training else 0.0,
                causal=is_causal,
            )
            # [b, sq, np, hn]
            context_layer = pad_input(context_layer, indices_q, batch_size, seqlen)

        # [b, sq, np, hn] -> [b, np, sq, hn]
        context_layer = context_layer.permute(0, 2, 1, 3)
        return context_layer

    def flash_attention_triton(self, query_layer, key_layer, value_layer, attention_mask, attention_bias, is_causal):
        if self.attention_dropout_p > 0.0:
            raise NotImplementedError('attention_dropout not implemented for flash_attention with attention bias')

        if attention_mask is not None:
            if len(attention_mask.shape) == 4:
                # [b, 1, sq, sk] -> [b, 1, sq, 1] and [b, 1, 1, sk]
                attention_mask_q = torch.any(torch.eq(attention_mask, False), dim=3, keepdim=True)
                attention_mask_kv = torch.any(torch.eq(attention_mask, False), dim=2, keepdim=True)
            else:
                assert len(attention_mask.shape) == 2
                attention_mask_q = attention_mask.unsqueeze(1).unsqueeze(3)
                attention_mask_kv = attention_mask.unsqueeze(1).unsqueeze(2)

            # Fold the padding mask into the additive bias.
            if attention_bias.shape[2] == attention_mask_q.shape[2]:
                attention_bias = attention_bias.masked_fill(~attention_mask_q, torch.finfo(query_layer.dtype).min)
            if attention_bias.shape[3] == attention_mask_kv.shape[3]:
                attention_bias = attention_bias.masked_fill(~attention_mask_kv, torch.finfo(query_layer.dtype).min)

        context_layer = flash_attn_func_triton(query_layer, key_layer, value_layer, attention_bias, is_causal)

        # [b, sq, np, hn] -> [b, np, sq, hn]
        context_layer = context_layer.permute(0, 2, 1, 3)

        if attention_mask is not None:
            context_layer = context_layer * attention_mask_q

        return context_layer
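
# Illustrative sketch, not part of the upstream module: the core math of
# CoreAttention.torch_attention on plain tensors, without tensor parallelism,
# masking, or dropout. Tensors follow the module's [sq, b, np, hn] convention;
# all sizes are hypothetical.
def _example_torch_attention(sq=5, sk=5, b=2, heads=4, hn=8):
    q = torch.randn(sq, b, heads, hn)
    k = torch.randn(sk, b, heads, hn)
    v = torch.randn(sk, b, heads, hn)

    q_ = rearrange(q, 'sq b np hn -> (b np) sq hn')
    k_ = rearrange(k, 'sk b np hn -> (b np) hn sk')
    v_ = rearrange(v, 'sv b np hn -> (b np) sv hn')

    # Scores via a batched GEMM, scaled by 1/sqrt(hn) as norm_factor does above.
    # With beta=0.0 the (uninitialized) input buffer is ignored by baddbmm.
    scores = torch.baddbmm(
        torch.empty(b * heads, sq, sk), q_, k_, beta=0.0, alpha=1.0 / math.sqrt(hn)
    )
    probs = torch.softmax(scores, dim=-1)  # [(b np), sq, sk]
    context = torch.bmm(probs, v_)         # [(b np), sq, hn]
    return rearrange(context, '(b np) sq hn -> b np sq hn', np=heads)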