o
    TiS                     @   s   d Z ddlZddlZddlmZ ddlmZ ddlT ddl	m
Z
mZmZmZmZ ddl	mZmZ ddlmZ G d	d
 d
eZG dd deZG dd deZdZdZG dd deZdS )zC
Functionality of swapping tensors to/from (NVMe) storage devices.
    N)comm)logger)*)swap_in_tensorsswap_out_tensorsMIN_AIO_BYTESAIO_ALIGNED_BYTESget_sized_buffers)SwapBufferManagerSwapBufferPool)get_acceleratorc                   @   s   e Zd Zdd ZdS )FlattenedTensorSwapInfoc                 C   s   || _ || _|| _d S N)pathoffsetlength)selfr   r   r    r   a/home/ubuntu/.local/lib/python3.10/site-packages/deepspeed/runtime/swap_tensor/optimizer_utils.py__init__   s   
z FlattenedTensorSwapInfo.__init__N)__name__
__module____qualname__r   r   r   r   r   r      s    r   c                   @   s$   e Zd Zdd Zdd Zdd ZdS )SwapTensorContextc                 C   s0   || _ t | _tj|t| d| _	d S )N.tensor.swp)
compute_tensortorchTensorswap_tensorosr   joinOptimizerSwapperparameter_id	swap_path)r   tensorswap_folderr   r   r   r      s   
 zSwapTensorContext.__init__c                 C   s   t  | j_t  | j_d S r   )r   r   r   datar   r   r   r   r   release_memory$   s   z SwapTensorContext.release_memoryc                 C   s   |j | j_ |j | j_ d S r   )r&   r   r   )r   compute_bufferswap_bufferr   r   r   set_buffers(   s   
zSwapTensorContext.set_buffersN)r   r   r   r   r(   r+   r   r   r   r   r      s    r   c                   @   s   e Zd Zdd Zdd Zdd Zdd Zd	d
 Zdd Zdd Z	dd Z
dd Zdd Zdd Zdd Zdd Zdd Zdd Zdd  Zd!d" Zd#d$ Zd%d& Zd'd( Zd)S )*OptimizerStateSwapInfoc                 C   sV   g | _ t|| _|| _i | _i | _|| _|j| _	|j
| _d| _g | _| |g d S )NF)tensorsr!   r"   param_idr%   swapped_gradientsunswapped_gradientstensor_numeldtypetensor_dtypedevicetensor_devicehas_state_tensorsswap_buffers_add_tensors)r   	parameternumelbase_folderr   r   r   r   /   s   zOptimizerStateSwapInfo.__init__c                 C      | j S r   )r1   r'   r   r   r   r:   <      zOptimizerStateSwapInfo.numelc                 C   s   t | jp	t | jS r   )boolr/   r0   r'   r   r   r   has_gradients?      z$OptimizerStateSwapInfo.has_gradientsc                 C   s"   |D ]}| j t|| j qd S r   )r-   appendr   r%   )r   tensor_listtr   r   r   r8   B   s   z#OptimizerStateSwapInfo._add_tensorsc                 C   s   d| _ | | d S )NT)r6   r8   )r   rB   r   r   r   add_state_tensorsF   s   z(OptimizerStateSwapInfo.add_state_tensorsc                 C   s
   t | jS r   )lenr-   r'   r   r   r   num_tensorsJ      
z"OptimizerStateSwapInfo.num_tensorsc                 C   r<   r   )r5   r'   r   r   r   r4   M   r=   zOptimizerStateSwapInfo.devicec                 C   r<   r   )r3   r'   r   r   r   r2   P   r=   zOptimizerStateSwapInfo.dtypec                 C   s   | j D ]}|  qd S r   )r-   r(   )r   rC   r   r   r   r(   S   s   

z%OptimizerStateSwapInfo.release_memoryc                 C      dd | j D S )Nc                 S      g | ]}|j qS r   )r   .0rC   r   r   r   
<listcomp>X       z>OptimizerStateSwapInfo.get_compute_tensors.<locals>.<listcomp>r-   r'   r   r   r   get_compute_tensorsW      z*OptimizerStateSwapInfo.get_compute_tensorsc                 C   rH   )Nc                 S   rI   r   )r#   rJ   r   r   r   rL   [   rM   z9OptimizerStateSwapInfo.get_swap_paths.<locals>.<listcomp>rN   r'   r   r   r   get_swap_pathsZ   rP   z%OptimizerStateSwapInfo.get_swap_pathsc                    sP   g }g } fdd| j D }|D ]}| r|jn|j ||j q||fS )Nc                    s"   g | ]}t  |j kr|qS r   r   	is_pinnedr   rJ   pinnedr   r   rL   `   s   " zEOptimizerStateSwapInfo.get_swap_buffers_and_paths.<locals>.<listcomp>)r-   rA   r   r   r#   )r   rU   r7   
swap_pathsselect_tensorsrC   r   rT   r   get_swap_buffers_and_paths]   s   z1OptimizerStateSwapInfo.get_swap_buffers_and_pathsc              
   C   sr   g }t ||D ]/\}}|| j vr-tj| j| j d| d| d}t|||| j|< |	| j| j q|S )N
_gradient__r   )
zipr/   keysr   r   r    r%   r.   r   rA   )r   offsetslengthsgradient_pathsr   r   r   r   r   r   get_or_create_gradient_pathsf   s   $z3OptimizerStateSwapInfo.get_or_create_gradient_pathsc           
      C   sd   t | j}|  g| }t||}|g| }t||}t| jD ]\}}	|	j|| || d q d S )N)r)   r*   )rE   r-   r:   r	   	enumerater+   )
r   buffersaligned_numelrF   compute_lengthscompute_buffersswap_lengthsr7   irC   r   r   r   set_swap_buffersq   s   



z'OptimizerStateSwapInfo.set_swap_buffersc                    s,   |      ks
J  fdd| j D S )Nc                    s   g | ]}  d |j|jqS r   )narrowr   r   rK   gradr*   r   r   rL   }   s    zDOptimizerStateSwapInfo.get_swap_gradient_buffers.<locals>.<listcomp>)r:   r/   values)r   r*   r   rm   r   get_swap_gradient_buffers{   s   z0OptimizerStateSwapInfo.get_swap_gradient_buffersc                 C   s   dd | j  D S )Nc                 S   rI   r   )r   rk   r   r   r   rL      rM   zBOptimizerStateSwapInfo.get_swap_gradient_paths.<locals>.<listcomp>)r/   rn   r'   r   r   r   get_swap_gradient_paths   r@   z.OptimizerStateSwapInfo.get_swap_gradient_pathsc                 C   rH   )Nc                 S   s    g | ]}t  |js|jqS r   rR   rJ   r   r   r   rL      s     zEOptimizerStateSwapInfo.get_unpinned_state_tensors.<locals>.<listcomp>rN   r'   r   r   r   get_unpinned_state_tensors   rP   z1OptimizerStateSwapInfo.get_unpinned_state_tensorsc                 C   sH   d}| j  D ]\}}|d|| }|j|j || 7 }q|S Nr   r0   itemsrj   r:   r&   copy_)r   dest_buffernum_elem_countr   grad_partition
dst_tensorr   r   r   read_unswapped_gradients      z/OptimizerStateSwapInfo.read_unswapped_gradientsc                 C   sH   d}| j  D ]\}}|d|| }|j|j || 7 }q|S rr   rs   )r   
src_bufferrw   r   rx   
src_tensorr   r   r   write_unswapped_gradients   r{   z0OptimizerStateSwapInfo.write_unswapped_gradientsc                 C   s
   i | _ d S r   )r0   r'   r   r   r   release_unswapped_gradients   rG   z2OptimizerStateSwapInfo.release_unswapped_gradientsN)r   r   r   r   r:   r?   r8   rD   rF   r4   r2   r(   rO   rQ   rX   r`   rh   ro   rp   rq   rz   r~   r   r   r   r   r   r,   -   s*    	
		r,   Fswap_out_gradientc                   @   s   e Zd Zedd Zdd Zdd Zd7dd	Zd
d Zdd Z	dd Z
dd Zdd Zdd Zdd Zdd Zdd Zdd Zdd Zd d! Zd"d# Zd$d% Zd&d' Zd(d) Zd*d+ Zd,d- Zd.d/ Zd0d1 Zd8d3d4Zd5d6 ZdS )9r!   c                 C   r<   r   )ds_id)paramr   r   r   r"      s   zOptimizerSwapper.parameter_idc	           	      C   s   || _ || _i | _tjg |d | _tj	|ddt
  | _tj| jdd || _tt|t | _t|t  | _| j| j | _| || _|| _t| j|j|d| _|| _t | _g d| _ d S )N)r2   	optimizerrankT)exist_ok)	num_elemscountr2   )r   swap_buffer_managerswap_params_infotimerstimer_names)!swap_config
aio_configr   r   r$   element_sizeswap_element_sizer   r   r    distget_rankr%   makedirsr   maxr   AIO_BLOCK_SIZEmin_aio_bytesr   AIO_INTRA_OP_PARALLELISMaligned_bytesnumel_alignment_io_aligned_numellargest_numelr2   r
   buffer_countr   r   setr   print_exclude_list)	r   r   r   r;   r   r   r4   r2   r   r   r   r   r      s&   zOptimizerSwapper.__init__c                 C   s(   | j  D ]}|jd g|_d|_qd S )Nr   F)r   rn   r-   r6   )r   	swap_infor   r   r   purge_state   s   zOptimizerSwapper.purge_stateNc                 C   sD   |d us|d usJ d|d ur| j | | j kS | j || j kS )Nz'Either tensor or numel must be provided)r   r:   r   )r   r$   r:   r   r   r   is_swappable_tensor   s   z$OptimizerSwapper.is_swappable_tensorc                 C   s   t  | _d S r   )r   r   r'   r   r   r   init_timers      zOptimizerSwapper.init_timersc                 C   s"   | j r| jt| j dd d S d S )NT)force)r   _log_timerslistr'   r   r   r   
log_timers   s   zOptimizerSwapper.log_timersc                 C   s   |    d S r   )r   r'   r   r   r   pre_backward   r   zOptimizerSwapper.pre_backwardc                 C   s   d S r   r   r'   r   r   r   post_backward   s   zOptimizerSwapper.post_backwardc                 C   sT   |  r(| t | }| j| | t | jt | j	|
  d S d S r   )has_buffers_start_timerSWAP_OUT_GRADIENT_TIMERrelease_buffersr   free_stop_timerr   addupdateget_timer_names)r   gradient_swapperpinned_buffersr   r   r   _flush_gradient_swapper   s   

z(OptimizerSwapper._flush_gradient_swapperc                 C   s  t || j vrd S | jt | }g }g }g }| j||d\}	}
| t t|	|
D ]!\}}| j|ds=||j	|< q-|
| |
| |
|  q-t|dkru| sh| jj| j| jd}|| |||}|j||d | t | jt d S )N)r-   r]   r$   r   r   r2   )rB   	path_list)r!   r"   r   r\   _adjust_for_misaligned_lengthsr   r   r[   r   r0   rA   r:   rE   r   r   allocate_allr   r2   add_buffersr`   r   r   r   r   )r   r9   gradient_offsetsgradient_tensorsr   r   swappable_tensorsswappable_offsetsswappable_lengthsaligned_gradientsaligned_offsetsr$   r   r   swappable_pathsr   r   r   _swap_out_gradients   s2   






z$OptimizerSwapper._swap_out_gradientsc                    s  t |t |ks
J t |t |ksJ tdd |D sJ  j||d} jj j jd}dd |D }t fdd|D sMJ d| d j t|}	t|}
d	}|t |k rψ j|||d  ||d  |
d
}t	
 d	krtrt|D ]$\}}|| }tdt||  d| d||  d||    qz j|||d  |	|d}|t |ksJ | dt | |
  |	  ||7 }|t |k s] j| d S )Nc                 S   s   g | ]}t  |qS r   )r   rS   )rK   bufferr   r   r   rL         zIOptimizerSwapper._initialize_from_swapped_fp16_params.<locals>.<listcomp>
parametersr   r   c                 S      g | ]}|  qS r   r:   )rK   bufr   r   r   rL         c                    s   g | ]}| j kqS r   )r   )rK   r:   r'   r   r   rL     s    znumel of fp16 buffers z+ is too small for initializing fp32 params r   )
aio_handlefp16_num_elemsfp16_partitions_infofp16_swap_bufferszswap_in_fp16_param: fp32_id = 	 index = z orig_num_elem = , swap_num_elem = )r   fp32_swap_pathsfp32_swap_buffersfp16_pinned_tensorsz does not match )rE   all_get_swap_pathsr   r   r   r2   r   _swap_in_fp16_paramsr   r   SWAPPER_DEBUG_MODEra   r   infor!   r"   r:   _swap_out_fp16_paramsresetr   )r   r   r   r   fp16_pinned_buffersfp32_parametersr   fp32_pinned_buffersfp16_buffer_numelr   r   
curr_indexr   rg   r$   
true_indexswap_out_countr   r'   r   $_initialize_from_swapped_fp16_params  sJ   

0
z5OptimizerSwapper._initialize_from_swapped_fp16_paramsc                 C   s  t |dksJ g }g }g }g }g }	t|D ]H\}
}||d |\}}|d u r) n6|| d}||
 D ])\}}}|d||}|d u rO|| |	| n
|| || ||7 }q4qt |t | dkskJ t|||}t||	D ]\}}|j|j qvt ||	 ksJ |S rr   )
rE   ra   allocate_tensorrA   rj   r   r[   r&   ru   wait)r   r   r   r   r   swapped_fp16_tensorsswap_tensorsrV   unswapped_srcsunswapped_dstsrg   r:   pinned_tensorrZ   r   r$   partition_numelpartition_pathry   retsrcdstr   r   r   r   ;  s6   





z%OptimizerSwapper._swap_in_fp16_paramsc           
   	   C   s   t |t |ks
J d}t|D ].\}}|| s$|| |  |||| | | \}}	|d us:J |d7 }qt | dkrL|| |S )Nr      )	rE   ra   	has_spacer:   swap_outr   insert_tensorr   get_swap_tensors)
r   r   r   r   r   r   rg   fp16_tensorr   rZ   r   r   r   r   ^  s   


z&OptimizerSwapper._swap_out_fp16_paramsc           	   
   C   s   t |t |ks
J | j|dd |D d}d}| | | jj| j| jd}|d us-J | j||||d t	 dkr^t
r^t|D ]\}}tdt||  d	| d
||    qB| j| | | | |g d S )Nc                 S   r   r   r   )rK   r   r   r   r   rL   t  r   z;OptimizerSwapper._initialize_parameters.<locals>.<listcomp>r   swap_init_writer   )r   unpinned_tensors
dest_pathsr   r   zcopy_in_fp16_param: fp32_id = r   r   )rE   r   r   r   r   r   r2   _swap_out_unpinned_tensorsr   r   r   ra   r   r   r!   r"   r:   r   r   r   )	r   r   src_tensorsr   rV   SWAP_INIT_TIMERr   rg   r$   r   r   r   _initialize_parametersq  s&   
&
z'OptimizerSwapper._initialize_parametersc                    s>    fddt ||D }t|t|ksJ dd |D }|S )Nc                    s   g | ]\}} j ||d qS ))r9   r:   )_create_param_swap_info)rK   pr:   r'   r   r   rL     s    z4OptimizerSwapper._get_swap_paths.<locals>.<listcomp>c                 S   s   g | ]}|j d  jqS ri   )r-   r#   )rK   r   r   r   r   rL     r   )r[   rE   )r   r   r   swap_info_listrV   r   r'   r   r     s   
z OptimizerSwapper._get_swap_pathsc                    s   t |}t |}td||D ]R}t|| |}||||  }	dd |	D }
t||
}t||	D ]\}}|j|j q0 fdd|	D }t||}||||  }t||| | |ks`J qd S )Nr   c                 S   r   r   r   rJ   r   r   r   rL     r   z?OptimizerSwapper._swap_out_unpinned_tensors.<locals>.<listcomp>c                    s   g | ]	}  | qS r   )r   r:   rJ   r'   r   r   rL     s    )	rE   rangeminr	   r[   r&   ru   r   r   )r   r   r   r   r   swap_buffer_countunpinned_tensor_countrg   swap_tensor_countr   rd   re   r   r   rf   r7   rV   r   r'   r   r     s   

z+OptimizerSwapper._swap_out_unpinned_tensorsc           	      C   s   g }g }t ||D ]U\}}| j|ds|| || q	| | j }|dkr4|| || q	| | j | j }||dd| || ||d|| |||  q	||fS )Nr   r   )r[   r   rA   r:   r   rj   )	r   r-   r]   new_tensorsnew_offsetsorig_tensororig_offset	remainderaligned_lengthr   r   r   r     s$   




z/OptimizerSwapper._adjust_for_misaligned_lengthsc                 C   sj   d}|  | t|j}||}| | | |g |  tr3t	d|j
 d| d|  d S d S )Nunswapped_read_gradientsz.optimizer_retrieve_unswapped_gradients: param=z tensor_count=z elem_count=)r   rE   r0   rz   r   r   r   r   r   r   r.   )r   r   rv   UNSWAPPED_READ_GRADIENTStensor_countrw   r   r   r   #_retrieve_unswapped_grad_partitions  s   



z4OptimizerSwapper._retrieve_unswapped_grad_partitionsc                 C   sb   || j jvrg S g }| j j|  D ]\}}t|r.| j|dr.|d |j |_|| q|S )Nr   -)r   statert   r   	is_tensorr   r   rA   )r   r9   rB   
state_namevaluer   r   r   _get_state_tensors  s   
z#OptimizerSwapper._get_state_tensorsc                 C   s*   |j s| |}|r|| d S d S d S r   )r6   r  rD   )r   r   r9   state_tensorsr   r   r   _update_param_state_info  s   
z)OptimizerSwapper._update_param_state_infoc                 C   sH   t |}|| jvsJ t||| jd| j|< | j| }| || |S )N)r9   r:   r;   )r!   r"   r   r,   r%   r  )r   r9   r:   r.   r   r   r   r   r     s   

z(OptimizerSwapper._create_param_swap_infoc                 C   s0   t |}| j|d }|d ur| || |S r   )r!   r"   r   getr  )r   r9   r.   r   r   r   r   _get_param_swap_info  s
   
z%OptimizerSwapper._get_param_swap_infoc                 C      | j r|  |  d S d S r   )r   startr   namer   r   r   r        zOptimizerSwapper._start_timerc                 C   r  r   )r   stopr  r   r   r   r     r  zOptimizerSwapper._stop_timerFc                 C   s&   | j rts|r| j | d S d S d S r   )r   r   log)r   	name_listr   r   r   r   r   	  s   zOptimizerSwapper._log_timersc                 C   s$   || j  }|dkr|S || j  | S rr   )r   )r   r:   r	  r   r   r   r     s   
z"OptimizerSwapper._io_aligned_numel)NN)F)r   r   r   staticmethodr"   r   r   r   r   r   r   r   r   r   r   r   r   r   r   r   r   r  r  r  r   r  r   r   r   r   r   r   r   r   r!      s8    
%
	$,#	
r!   )__doc__r   r   	deepspeedr   r   deepspeed.utils.loggingr   'deepspeed.runtime.swap_tensor.constants#deepspeed.runtime.swap_tensor.utilsr   r   r   r   r	   r
   r   deepspeed.acceleratorr   objectr   r   r,   r   r   r!   r   r   r   r   <module>   s   n