o
    Ti                     @   s  d dl Z d dlmZmZmZ d dl mZ d dlmZ d dlm	Z
 d dlmZ zd dlZd dlmZmZ eejZW n eyG   dZdZY nw d dlmZ dd	lmZmZ d
d Zdd Zde jde jde jde jdee je jf f
ddZ	d)dee j dee j de jde jdee je jf f
ddZG dd de jjZ G dd de j!j"Z#G dd dZ$G dd de j!j"Z%G dd  d e jjZ&e j'j(d!d" Z)e j'j(d#d$ Z*G d%d& d&e j!j"Z+G d'd( d(e j!j"Z,dS )*    N)OptionalAnyTuple)Tensor)version)get_accelerator)_flash_attn_forward_flash_attn_backward)	rearrange   )single_all_to_allapply_rotary_pos_embc                 C   s2   t | ddd} | jdd\}}tj|| fddS )Nz... (j d) -> ... j d   )jdim)r
   unbindtorchcat)xx1x2 r   Q/home/ubuntu/.local/lib/python3.10/site-packages/deepspeed/sequence/fpdt_layer.py_rotate_half_backward   s   r   c                 C   sl   |j d }| dd |f | d|d f }}|| t||  }|j d dkr+|}|S tj||fdd}|S )Nr   .r   r   )shaper   r   r   )grad_output	freqs_cos	freqs_sinrot_dimgrad	grad_passgrad_tr   r   r   apply_rotary_pos_emb_backward    s   
"r%   outlse	block_out	block_lsereturnc                 C   sh   | tj}|ddjdd}|tt||  }t|| |  t|| |  } |}| |fS )Nr   r   r   )tor   float32	transpose	unsqueezelog1pexp)r&   r'   r(   r)   new_lser   r   r   _update_out_and_lse(   s   $r2   c                 C   s   | d u r%|d urt d|tj} |ddd jdd }| |fS |d urH| | || }}t||||\}}||| |< ||< | |fS t| |||\} }| |fS )Nz4first update_out_and_lse should not pass slice_ argsr   r   r   r   r   )RuntimeErrorr+   r   r,   permute
contiguousr.   r2   )r&   r'   r(   r)   slice_	slice_out	slice_lser   r   r   update_out_and_lse:   s   r9   c                       s&   e Zd Zd fddZdd Z  ZS )FPDT_InputConstructr*   Nc	                    s   t t|   || _|| _|| _|| _|| _|jd }	|jd }
|	| dks(J |	|j	 dks1J |	|j	 }|	| }|| dksBJ || _
|| | _|| _|| _|	| _|| _|
| _|j| _d S )Nr   r   )superr:   __init__tokenslabels	loss_maskattention_maskposition_idsr   $ds_sequence_parallel_fpdt_chunk_sizenum_chunk_per_gpu
chunk_sizesp_sizesp_rankglobal_seq_lenlocal_seq_len
batch_sizedevice)selfr=   r>   r?   r@   rA   argsrE   rF   rG   rI   rC   rH   	__class__r   r   r<   Q   s*   



zFPDT_InputConstruct.__init__c                    s  | j }| j| j }tj| j|tjd| j }tj||tjd}|| jd 	 }|
 d	 }||k}|jdd}|d d df  |d d df t fddt|jd D }| jd urm| jd d |f n| j}|d| j| j| j | j| jd   
 	 }| jd d |f }	| jd ur| jd d |f n| j}
| jd ur| jn| j}| jd ur| jd d |f n| j}|	|
|||fS )	NrJ   dtyper   r   F)as_tupler   c                    s   g | ]} |k qS r   r   .0igather_indicestoken_chunk_indicesr   r   
<listcomp>w   s    z0FPDT_InputConstruct.generate.<locals>.<listcomp>)rJ   rG   rD   r   arangeintreshaperC   tr5   flattenr.   nonzeror   ranger   r?   rF   r=   r>   r@   rA   )rK   rJ   totalChunkstoken_chunk_idxchunk_to_gpugather_chunkmaskindicesload_balanced_loss_maskload_balanced_tokensload_balanced_labelsload_balanced_attention_maskload_balanced_position_idsr   rU   r   generatej   s6   $""

zFPDT_InputConstruct.generate)r*   N)__name__
__module____qualname__r<   rk   __classcell__r   r   rM   r   r:   O   s    r:   c                   @   4   e Zd ZdZe		d
defddZedd Zd	S )_FPDTGPUAttentionImpl_F   Tctxc           ,      C   s  |j }|d ur#|d dddd|d dddd}}|| _|| _nd | _d | _t U |jd }|| }|| |ks@J |d u sFJ || _|| _|| _	|| _
|| _t  }|| _|j| _|	| _|| _g }g }g }|
d | _|| _d| _d | _|jd }dd t|D }d	d t|D }t|D ]p}|| }|| }t||| | | } | d d d d d |	f  | jd | jd d
|
dddd }!t|!||d|}!|!jd }"|d urt|!|d d |"| |"|d  f |d d |"| |"|d  f }!||! | d d d d |	|	| f  | jd | jd d
|
dddd }#t|#||d|}#|d ur_t|#|d d |"| |"|d  f |d d |"| |"|d  f }#||# | d d d d |	| d f  | jd | jd d
|
dddd }$t|$||d|}$||$ tt|D ]`}%||%k}&tt !dkrt"|| ||% ||% | j| j|&| jd| jdd
\}'}(}(}(}(})}(}(nt"|| ||% ||% | j| j|&| j| jdd	\}'}(}(}(}(})}(}(t#|| || |'|)\||< ||< q|| $|!j||< qdd t|D }*t|D ].}|| d d d d d d df ddd ||< t|| $| j ||d||*|< qtj%|*dd}*|*jd
 }+|r{| &| || _'|| _(|| _)|| _*|| _+|+| _,|| _-|| _.|| _/W d    |*S W d    |*S 1 sw   Y  |*S )Nr   r   r            ࿩r   r   c                 S      g | ]}d qS Nr   rS   _r   r   r   rX          z2_FPDTGPUAttentionImpl_.forward.<locals>.<listcomp>c                 S   rw   rx   r   ry   r   r   r   rX      r{   r   2.6.0        Fcausalwindow_sizesoftcapalibi_slopesreturn_softmaxr   r   r   r   c                 S   rw   rx   r   rR   r   r   r   rX     r{   r   )0requires_gradr4   pos_emb_cospos_emb_sinr   no_gradr   
num_chunkscpu_offloadingspgscatter_idx
gather_idxr   current_device_namerJ   rP   projection_sizekv_projection_sizesoftmax_scale	dropout_pr   r   r_   matmulr\   r5   r[   r   r   appendlenflash_attn_versionr   parser   r9   r+   r   save_for_backwardglobal_qglobal_kglobal_vattn_outputattn_lsehead_dimrI   qkv_linear_weightqkv_linear_bias),rs   layernorm_outputr@   inference_paramsrotary_pos_embr   r   r   hidden_sizer   hidden_size_per_attention_headr   r   r   dropoutr   r   do_saver   r   per_gpu_seq_lenrD   rJ   r   r   r   rI   global_o
global_lserT   sted	qkv_chunkq_chunkglobal_q_chunk_lenk_chunkv_chunkk_icausal_chunkr(   rz   r)   outputr   r   r   r   forward   s   *






"

"
&
2(


t
ttz_FPDTGPUAttentionImpl_.forwardc           0         s  | j }| j | j| j}| j}| j}| j}| j}| j}| j	}	| j
}
| j}| jd | j| j}| j}| j}| j}| j}| j}jd |  fddt|D }g }|jd | }t|D ]}|| }|| }|t|d d ||f  ||d| q`~ fddt|D } fddt|D } fddt|D }tj|j|jtjd}tj|j|jtjd}t|D ]}|| }|| }t|D ]}||k } | rq||k}!| }"|| }#|| }$|| }%tjd j d	}&tj|d j d	}'tj|d j d	}(ttd
kr*t|%|"|||#|$|&|'|(|||!|d|	dd d nt|%|"|||#|$|&|'|(|||!||	dd d ||  |&!tj ||  |'!tj ||  |(!tj q|| jd })| j"d urt#|| !| j"d d |)| |)|d  f | j$d d |)| |)|d  f ||< n	|| !||< || !||< t||  ||d|||< t||  ||d|||< | }*|* }+|*|+ %djd },|| &d'ddd||< || &d'ddd||< || jd || jd }-}.||
|
|   t(|| %|-|. d) |, ||
| d   t(|| %|-|. d) |, ||
|
|   || *d*d ||
| d   || *d*d ||  t(|| ||
|
|   ||  t(|| ||
| d   d ||< d ||< qt|D ]}|| jd }/| j"d urt#|| !| j"d d |/| |/|d  f | j$d d |/| |/|d  f ||< n	|| !||< t|| ! ||d|||< d  %djd },d  || &d'ddd||< || jd || jd }-}.|d |
  t(|| %|-|. d) |, |d |
  || *d*d ||  t(|| |d |
  d ||< qtj+|dd!d d d d d d d d d d |!|!d d d fS )Nr   c                    s.   g | ]}t jjd  jd f dqS )r   r   rO   )r   zerosr   ry   )rJ   rP   input_chunk_sizer   r   r   rX   9  s    z3_FPDTGPUAttentionImpl_.backward.<locals>.<listcomp>r   c                    $   g | ]}t jd  jt j dqS r   rP   rJ   r   r   r   floatry   rJ   r   r   r   rX   J     $ c                    r   r   r   ry   r   r   r   rX   K  r   c                    r   r   r   ry   r   r   r   rX   L  r   rO   r   r|   r}   Fr   r   deterministic	rng_stater   r   r   r   r   r   ),r   rJ   rP   r   r   r   r   r   r   r   r   r   saved_tensorsr   r   r   r   r   r   r   r   r_   r   r   r5   r   r   r   r   r   r   r	   add_r+   r   r%   r   r[   r]   r4   r   r\   sumr   )0rs   r   r   r   r   r   r   r   r   r   r   r   r   r   r   r'   r   r   grad_layernorm_outputgrad_global_attn_outputrD   rT   r   r   dqdkdvgrad_qkv_linear_weightgrad_qkv_linear_biasr   r   q_ino_computationr   r   attn_output_chunk	lse_chunkd_outdq_thisdk_thisdv_this
dk_seq_leninput_stinput_edinput_chunklb
dq_seq_lenr   )rJ   rP   r   r   r   r   backward  s2  
"  
&&
  
".""z_FPDTGPUAttentionImpl_.backwardNrr   Trl   rm   rn   generate_vmap_rulestaticmethodr   r   r   r   r   r   r   rq      s     rq   c                   @   sF   e Zd ZddejfddZdd Zdd	 Zd
d Zdd Z	dd Z
dS )SequenceChunkNFchunkc                 C   sx   |j | _|j| _|d u r|jn|| _tj|j |jddd}t |r+|j	|dd n|}|| _
|r7|| _d S d | _d S )NcpuTrP   rJ   
pin_memorynon_blocking)r   chunk_shaperP   chunk_dtyperJ   r   emptyr   on_acceleratorcopy_	cpu_chunk	gpu_chunk)rK   r   rJ   	is_in_user   r   r   r   r<     s   zSequenceChunk.__init__c                 C   sL   | j d u sJ | j d urd S tj| j| j| jd}|j| jdd || _ d S )NrO   Tr   )r   r   r   r   rJ   r   r   r   )rK   r   r   r   r   load_to_gpu  s   

zSequenceChunk.load_to_gpuc                 C   s"   | j d ur| j j| jksJ | j S rx   r   rJ   rK   r   r   r   get_gpu_chunk  s   zSequenceChunk.get_gpu_chunkc                 C   sB   | j d ur| j j| jksJ d| j d u d| j d| j j dS )Nz
gpu_chunk z shound be on z, but it is now on Tr   r   r   r   r   check_gpu_chunk  s
   
 zSequenceChunk.check_gpu_chunkc                 C   s*   | j d ur| j j| jksJ | ` d | _ d S rx   r   r   r   r   r   offload  s   
zSequenceChunk.offloadc                 C   s2   | j d ur| j j| jksJ | jj| j dd d S )NTr   )r   rJ   r   r   r   r   r   r   overwrite_to_cpu  s   zSequenceChunk.overwrite_to_cpu)NF)rl   rm   rn   r   r   r<   r   r   r   r   r   r   r   r   r   r     s    	r   c                   @   rp   ) _FPDTGPUOffloadingAttentionImpl_Frr   Trs   c           7      C   sd  |j }|d ur#|d dddd|d dddd}}|| _|| _nd | _d | _t X |jd }|| }|| |ks@J |d u sFJ || _|| _|| _	|| _
|| _|| _t  }|| _|j| _|	| _|| _g }g }g }|
d | _|| _d| _d | _|jd }g }g }g }g }t  }t  } t  }!d}"d}#t|D ]}$t|d | | | }%t |  |t|d |  W d    n1 sw   Y  ||d  }|%d d d d d |	f   |%jd |%jd d|
dddd }&t!|&||d|}&|&jd }'|%d d d d |	|	| f   |%jd |%jd d|
dddd }(t!|(||d|}(|%d d d d |	| d f   |%jd |%jd d|
dddd })t!|)||d|})t"#  | jd ur|d d |'|$ |'|$d  f }*|d d |'|$ |'|$d  f }+t$|&|*|+}&t$|(|*|+}(|!%| |!&  t |# |t|&dd	 |t|(dd	 |t|)dd	 W d    n	1 sw   Y  ~%d },d }-tt'|D ].}.|$|.k}/t |!e t(t)*d
krt+||" , ||# , ||# , | j| j|/| jd| jdd
\}0}1}1}1}1}2}1}1n%t+||" , ||# , ||# , | j| j|/| j| jdd	\}0}1}1}1}1}2}1}1t-|,|-|0|2\},}-W d    n	1 sVw   Y  d}3|.t'|d ksm|$|d kr|.t'|d kr{|.d }4nd}4|4|#krd}3n,|4t'|d krt | ||4 .  ||4 .  W d    n	1 sw   Y  |$|d kr|.|d krt |& |d .  |d .  |d .  |d .  |d .  W d    n	1 sw   Y  |!%| |!&  |3r||# /  ||# /  |4}#q||" /  |"d7 }"t!|,0| j ||d|}5||5 t | - |t|,0| j |t|-d d d d d d df ddd  W d    n	1 sdw   Y  q|!%|  |!&  tj1|dd}|jd }6W d    n	1 sw   Y  |r|| _2|| _3|| _4|| _5|| _6|| _7|6| _8|| _9|| _:|| _;|S )Nr   r   r   rt   ru   rv   r   Tr   r|   r}   Fr~   r   r   )<r   r4   r   r   r   r   r   r   r   r   r   r   rD   r   r   rJ   rP   r   r   r   r   r   r   Streamdefault_streamr_   r   r\   streamr   r   r5   r[   r   distbarrierr   wait_streamsynchronizer   r   r   r   r   r   r9   r   r   r+   r   r   r   r   r   r   r   r   rI   r   r   )7rs   r   r@   r   r   r   r   r   r   r   r   r   r   r   r   r   r   r   r   r   r   rD   rJ   r   r   r   rI   r   r   layernorm_output_cpufinal_outputoffload_streamgeneral_offload_streamcompute_streamq_compute_chunk_idxkv_compute_chunk_idxrT   r   r   r   r   r   pos_emb_cos_chunkpos_emb_sin_chunkcur_attn_outputcur_attn_lser   r   r(   rz   r)   can_offload_kvnext_kv_compute_chunk_idxall2all_outputr   r   r   r   r     s^  *







""  








 



6
  &z(_FPDTGPUOffloadingAttentionImpl_.forwardc           0         s
  | j }|j | j| j}| j}| j}| j}| j}| j}| j	}	| j
}
| j}| j| j| j}| j}| j}| j}| j}| j}t  }t  }t  }|jd | }|d jjd ks^J  fddt|D }dd t|D }d}d}d}t |$ d   tj|j|jtjd}tj|j|jtjd}W d    n1 sw   Y  t|d d d |f   ||d|}t !  |d d |d f }t |G t"|dd|d< t"tjd j#tj d	ddg fd
dt|d D  }tj|d j#tj d	}tj|d j#tj d	} W d    n	1 sw   Y  t|D ]}!t|D ]}"|"|!k }#|#r:q.|"|!k}$tjd j# d	}%tj|d j# d	}&tj|d j# d	}'t |j t$t%&dkrt'|| ( | ( || ( || ( || ( || ( |%|&|'|||$|d|	dd d n,t'|| ( | ( || ( || ( || ( || ( |%|&|'|||$||	dd d W d    n	1 sw   Y  |!t)|d kr|"t)d kr|"d }(n|!d }(d})|(|krd})nt |t |!dks|"dkr|)r||!kr|| *  ||(   |(   ||(   ||(   ||( d urA||(   ||( d u rpt|d d d |f   ||d|}t+,  |d d |d f }t"|dd||(< W d    n	1 s{w   Y  |-| |!  t |  || .  || j/0|% |0|& | 0|' W d    n	1 sw   Y  |-| t | || 1  W d    n	1 sw   Y  |)r| *  || *  || *  || *  |}|(}q.|-| |!  |jd }*| j2d urbt3|| ( 4| j2d d |*|! |*|!d  f | j5d d |*|! |*|!d  f }+t3|4| j2d d |*|! |*|!d  f | j5d d |*|! |*|!d  f }n|| ( 4}+|4}| 4} t|+  ||d|}+t|  ||d|}t|   ||d|} |!  |-| t+,  t |
 |! ( 6d|! j#d },|+7d8ddd}+|7d8ddd}| 7d8ddd} |jd |jd }-}.|d |
 0t9|+6|-|. d: |, ||
|
|  0t9|6|-|. d: |, ||
| d  0t9| 6|-|. d: |, |d |
 0|+;d;d ||
|
|  0|;d;d ||
| d  0| ;d;d ||! 0t9|+|d |
  ||! 0t9|||
|
|   ||! 0t9| ||
| d   ~+~~ tj||! j#tj d	}tj||! j#tj d	} || *  d ||< W d    n	1 sw   Y  |!t)|d kr|d }/t | ||/   ||/   W d    n	1 sw   Y  t | |/   W d    n	1 sw   Y  |-| |!  | *  || *  || *  |/}q'tj<|dd4d d d d d d d d d d |4|4d d d fS )Nr   r   c                    s"   g | ]}t jd  j dqS )r   rO   )r   r   r   ry   )rJ   rP   r   r   r   rX     s    z=_FPDTGPUOffloadingAttentionImpl_.backward.<locals>.<listcomp>c                 S   rw   rx   r   ry   r   r   r   rX     r{   rO   Tr   r   c              	      s,   g | ]}t tjd  jtjddd qS )r   r   Tr   )r   r   r   r   r   ry   r   r   r   rX     s    r|   r}   Fr   r   r   r   r   )=r   rJ   rP   r   r   r   r   r   r   r   r   r   r   r   r   r   r   r   r   r   r   r   r   r   r   r_   r   r   r   r   r   r   r5   r   r   r   r   r   r   r	   r   r   r   r   r   r   r   r   r   r   r   r%   r+   r   r[   r]   r4   r   r\   r   r   )0rs   r   r   r   r   r   r   r   r   r   r   r   r   r   r   r'   r   r   r  r  r  rD   r   r   r  r  last_q_accum_idxr   r   grad_global_attn_output_chunkr   dk_accumdv_accumrT   r   r   r   r   r   r   next_q_compute_chunk_idxcan_offload_qr   dq_accumr   r   r   r  r   )rJ   rP   r   r   r   r     s  


	











%







  
  


""
 
z)_FPDTGPUOffloadingAttentionImpl_.backwardNr   r   r   r   r   r   r     s     Pr   c                       sH   e Zd Z					ddededdf fd	d
Z	ddefddZ  ZS )FPDT_Attentionr   r   T   r   r   r*   Nc                    s   t t|   td u std u rtd|| _|| _|| _|| _	|j
|j | _| j|j | _|j
|j | _|j| _|| _|| _|| _|| _|	| _|j| _|
| _|| _d S )NzoDeepSpeed FPDT requires flash-attn 2.6.3. Please install it with `pip install flash-attn --no-build-isolation`.)r;   r  r<   r   r	   ImportErrorr   r   r   configkv_channelsnum_attention_headsr   r   num_key_value_headsr   r   r   r   qkv_dense_weightqkv_dense_biasreture_biasattention_dropoutr   rD   double_buffer)rK   r  first_weight
first_biassecond_weightsecond_biassequence_process_groupr   r   return_biasrD   enable_offloadingrM   r   r   r<     s*   
zFPDT_Attention.__init__c                 C   s   |j d t| j | j | _|r| jdkr6t||||| j| j| j	| j
| j| j| j| j| j| j| j|}nt||||| j| j| j	| j
| j| j| j| j| j| j| j|}|dddd }t|| j }| jsr|| j7 }|| jrz| jfS d fS )Nr   r   r   )r   r   get_world_sizer   rD   num_chunks_attnrq   applyr   r   r   r   r   r   r   r   r   r   r]   r4   r5   r   r   r  r\   r  r  )rK   r   r@   r   r   r   r   r   r   r   r     s(   
zFPDT_Attention.forward)r   r   Tr  T)T)rl   rm   rn   rZ   r<   r   r   ro   r   r   rM   r   r    s"    	-r  c                 C   s*   | d dt d|  dd|  |      S )N      ?g      ? e3E?r   Hm?r   tanh)r   r   r   r   	bias_gelu  s   *r0  c                 C   sV   t d| dd| |   }d| d||  dd| |    dd|   }||  S )Nr,  r   r-  r+  6vf?r.  )gr   tanh_outffr   r   r   bias_gelu_back  s   0r5  c                   @   .   e Zd ZdZedefddZedd ZdS )FPDT_FFNFrs   c                 C   s<  |j }|| _|j}	t | |jd | }
|
| _tj|j|	|jd}||
 |jd ks-J t	|
D ]:}|| }|| }t
||| | | }t|}|r^t
|| | |||< nt
|| |||< ~q1|r|	| _|j| _| ||||| |j| _W d    n1 sw   Y  ||j|s|fS d fS )Nr   rO   )r   add_biasrJ   r   r   r   	num_chunkr   rP   r_   r   r\   r0  r   grad_x_shaper+   )rs   r   w1b1w2b2r8  rD   r   rJ   r9  resultrT   r   r   x_r   r   r   r   #  s2   
zFPDT_FFN.forwardc                 C   s@  | j \}}}}}| j}| j}	| j}
| j}|jd | }|| |jd ks&J tj|j|tjd}tj|j|tjd}tj|j|tjd}tj|j|tjd}t	|D ]}|| }|| }||| }t
|| | }|d }td| dd|   }d| d||  dd|    dd|   }|t
||| d	|jd  |d d|  d	|jd  ~~~t
||| || }~|t
|d	|jd  |d	|jd  ||dd ||| t
|| ~|
r
|||| dd qR|||	||	||	||	d d fS )
Nr   rO   r   r,  r   r-  r+  r1  r   )r   rJ   rP   r8  r9  r   r   r   r   r_   r   r\   r/  r   r[   r   r   r+   )rs   r   	grad_biasr   r;  r<  r=  r>  rJ   rP   r8  r9  rD   grad_w2grad_b2grad_w1grad_b1rT   r   r   x_chunk
before_actbefore_act_2r3  r4  
grad_interr   r   r   r   A  sR   

 $*zFPDT_FFN.backwardNr   r   r   r   r   r7     s    r7  c                   @   r6  )FPDT_LogitsLossFrs   c                 C   s  |  }|jd | }|| |jd ksJ |jd |jd }	}
tj|	|
ftj|jd}|| _|| _|j| _|j| _|| _	|
| _
t o t|D ]V}|| }|| }t||| |   }|d}tjjj|dd}tjjj| d| |||d d f d dd}|||	  |d d ||f< ~qJ| |d	| || _W d    n1 sw   Y  |
| }|d}	|   }tj||	|j|jd }tj|||d
 |S )Nr   r   r   r   r   r   none)	reductionr   )group)r\   r   r   r   r   rJ   r9  rD   rP   rankrH   r   r_   r   sizenn
functionalsoftmaxnll_losslogr[   r5   r   r+   logit_weightsr   allgather_fn)rs   	lm_outputr>   rU  rN  spg_sizer   r9  rD   rI   rH   lossrT   r   r   logits_chunk
vocab_sizerR  
loss_chunkseqlenloss_allr   r   r   r   t  sB   

 
zFPDT_LogitsLoss.forwardc                 C   s  | j \}}| j}| j}| j}| j}| j}| j}	| j}
||	|
 |	d |
  }dd t|D }t	j
|j|jt	jd}t|D ]}|| }|| }||| |}t	||  }t	jjj|dd}|d}|}|d| }t	jd| d |d	}|||||d d f d f  d8  < ||d |d d f jdd ||}||d   }t	||}|||< |t	|d|jd  |d|jd  q>t	j|dd|d ||d d d d fS )
Nr   c                 S   rw   rx   r   ry   r   r   r   rX     r{   z,FPDT_LogitsLoss.backward.<locals>.<listcomp>rO   r   r   r   r   )startendrJ   )r   rU  rJ   rP   r9  rD   rN  rH   r_   r   r   r   r   r+   r   r\   rP  rQ  rR  rO  r[   r5   rY   mul_r.   r   r   )rs   r   rW  r>   rU  rJ   rP   r9  rD   rN  rH   grad_lm_outputgrad_logit_weightsrT   r   r   lm_output_chunkrZ  rR  r[  
grad_inputgrad_2d	arange_1dgrad_lm_output_chunkr   r   r   r     sD   

."
(zFPDT_LogitsLoss.backwardNr   r   r   r   r   rJ  q  s    )rJ  rx   )-r   typingr   r   r   r   	packagingr   deepspeed.commcommr   deepspeed.acceleratorr   
flash_attnflash_attn.flash_attn_interfacer   r	   r   __version__r   r  einopsr
   layerr   r   r   r%   r2   r9   rP  Moduler:   autogradFunctionrq   r   r   r  jitscriptr0  r5  r7  rJ  r   r   r   r   <module>   sr   

7  J0   PH

Q