o
    }oi-7                     @   sJ  d dl Z d dlZd dlZd dlZd dlmZ d dlmZ d dlm	Z	m
Z
mZmZ d dlmZ d dlZd dlmZ d dlmZ d dlmZ d dlmZ d d	lmZ d d
lmZ d dlmZ d dl m!Z! d dl"m#Z#m$Z$m%Z% d dl&m'Z' d dl(m)Z) d dl*m+Z+ d dl,m-Z-m.Z.m/Z/m0Z0m1Z1m2Z2m3Z3m4Z4m5Z5 e6e7Z8G dd dee+j9ZdS )    N)OrderedDict)Path)AnyDictOptionalUnion)CheckpointIO)_get_sharded_state_dict_context)rank_zero_info)
reset_seed)FSDPStrategy)	TrainerFn)STEP_OUTPUT)TransformerLayer)StateDictOptionsget_optimizer_state_dictset_state_dict)
DataLoader)override)io)	_destroy_dist_connectionckpt_to_dircreate_checkpoint_iofix_progress_barinit_model_parallelmcore_to_pyt_sharded_state_dictpyt_to_mcore_state_dictsetup_data_samplersetup_parallel_ranksc                
       s  e Zd ZdZehddddfdedef fddZed/d
dZede	j
d	df fddZdefddZd0ddZed0d	efddZed0d	efddZed0d	efddZed0d	efddZeded	efddZeed	efdd Zejd!ed	dfd"d Zed	efd#d$Zed%eeef d	dfd&d'Ze	d0d(e eef d%eeef d)e!e d	dfd*d+Z"ed,eeB d	e eef fd-d.Z#  Z$S )1r   a1  Megatron plugin for Pytorch Lightning.

    This strategy implements Fully-Sharded-Data-Parallel using PyTorch's native FSDP methods.
    Comparing with MegatronStrategy, FSDPStrategy is designed to be more lightweight, with
    minimal modifications over Lightning's FSDPStrategy but preserves necessary features to be
    compatible with nemo and mcore.
    By default, this strategy wraps FSDP per TransformerLayer.

    Note:
        This strategy is designed to work with NVIDIA's Megatron-LM framework and requires
        specific model implementations that are compatible with Megatron's parallelism techniques.
    Note:
        Due to the different optimizer structure (FSDP only uses torch native optimizers),
        MegatronStrategy cannot resume training from checkpoints saved by FSDPStrategy, and vice
        versa. However, the model weights structure is made compatible, so switching strategy is
        possible if users only need the weights not the optimizer states. (E.g. run pretrain with
        megatron 4D parallelism and run SFT with FSDP.)
    shardedTNckpt_load_optimizerckpt_save_optimizerc                    s4   t  jd||d| || _|| _|| _d | _d S )N)auto_wrap_policystate_dict_type )super__init__data_samplerr    r!   store)selfr"   r#   r    r!   r'   kwargs	__class__r$   c/home/ubuntu/.local/lib/python3.10/site-packages/nemo/lightning/pytorch/strategies/fsdp_strategy.pyr&   M   s
   	
zFSDPStrategy.__init__returnc              	   C   sJ  t |  | j| j t  |   |  | _| jdusJ t	j
 s'tdt	j
 r3td dS | j }| j }| jjtjd< t| jjtjd< td| d|d  d	|  t	j
j| j||| jd
 | jdkrstt td d| j d| dd d t| j dt!rddl"m#} |d| jd | jd< t$| j% dS )z6Initializes rank and process group for communications.NzOtorch.distributed is not available. Cannot initialize distributed process groupz7torch.distributed is already initialized. Exiting earlyMASTER_ADDRMASTER_PORTz'Initializing distributed: GLOBAL_RANK: z
, MEMBER:    /)rank
world_sizer(   ncclzd----------------------------------------------------------------------------------------------------z
distributed_backend=z5
All distributed processes registered. Starting with z processes

device_meshr   )init_device_meshcuda)&r   acceleratorsetup_deviceroot_devicer   set_world_ranks_get_process_group_backend_process_group_backendcluster_environmenttorchdistributedis_availableRuntimeErroris_initialized_loggerdebugglobal_rankr4   main_addressosenvironstr	main_portinfoinit_process_groupr(   atexitregisterr   r
   
isinstancer*   gettupletorch.distributed.device_meshr8   r   model)r)   rH   r4   r8   r$   r$   r-   setup_environment]   sD   





 

zFSDPStrategy.setup_environmenttrainerc                    s(   || _ t| j  t| t | dS )zJConnect strategy to trainer and handle adjustments before the loop starts.N)rX   r   r   r%   setup)r)   rX   r+   r$   r-   rY      s   
zFSDPStrategy.setup	step_typec                 C   s4   | ddfD ]}t | j|rt| j|  S qd S )N_loss_reductionloss_reduction)hasattrlightning_modulegetattr)r)   rZ   fn_namer$   r$   r-   _get_loss_reduction   s
   z FSDPStrategy._get_loss_reductionc                 C   sf   | d}| j | jkr| | j | j|||}n	t| j|||}| |}|r-|||S |d|ifS )N_stepavg)rV   r^   _forward_redirectionr_   ra   forward)r)   rZ   batch	batch_idxmethod_namelossr[   r$   r$   r-   _step_proxy   s   

zFSDPStrategy._step_proxyc                 C   s   | j dusJ | jdusJ | j 6 | d||\}}| j jd| jjdddd | j d| jj | j jd|d	 dddd |W  d   S 1 sLw   Y  dS )
z#Run training step and logs results.Ntrainingglobal_stepTr1   )prog_barrank_zero_only
batch_sizestepreduced_train_lossrc   )r^   rV   precision_plugintrain_step_contextrj   logrX   rl   r)   rf   rg   ri   reducedr$   r$   r-   training_step   s(   $zFSDPStrategy.training_stepc                 C   x   | j dusJ | jdusJ | j  | d||\}}| j jd|d ddd |W  d   S 1 s5w   Y  dS )z%Run validation step and logs results.N
validationval_lossrc   Tr1   rn   ro   )r^   rV   rr   val_step_contextrj   rt   ru   r$   r$   r-   validation_step   s   $zFSDPStrategy.validation_stepc                 C   rx   )zRun test step and logs results.Ntest	test_lossrc   Tr1   r{   )r^   rV   rr   test_step_contextrj   rt   ru   r$   r$   r-   	test_step   s   $zFSDPStrategy.test_stepc                 C   s`   | j dusJ | jdusJ | j  | d||\}}|W  d   S 1 s)w   Y  dS )zRun prediction step.Npredict)r^   rV   rr   predict_step_contextrj   ru   r$   r$   r-   predict_step   s   $zFSDPStrategy.predict_step
dataloaderc                 C   s   | j r	| j |S |S )z"Transform dataloader with sampler.)r'   transform_dataloader)r)   r   r$   r$   r-   process_dataloader   s   zFSDPStrategy.process_dataloaderc                 C   s   | j st | _ | j S )zGet CheckpointIO.)_checkpoint_ior   r)   r$   r$   r-   checkpoint_io   s   zFSDPStrategy.checkpoint_ior   c                 C   s
   || _ dS )zSet CheckpointIO.N)r   )r)   r   r$   r$   r-   r      s   
c                 C   s*   t | jjjjjjjjj	| jjjj
jjj	S )z8
        Get the value of step within an epoch.
        )maxrX   fit_loop
epoch_loopautomatic_optimizationoptim_progress	optimizerrp   current	completedmanual_optimizationoptim_step_progressr   r$   r$   r-   current_epoch_step   s   zFSDPStrategy.current_epoch_stepfilepathc                 C   s:   t |}| jrtj|rt| dS t| dS dS )zDelete checkpoint at filepath.N)r   is_global_zerorJ   pathislinkunlinkshutilrmtree)r)   r   ckptr$   r$   r-   remove_checkpoint   s   zFSDPStrategy.remove_checkpoint
checkpointstorage_optionsc                 C   s   t |d|d< tg |d< d|v r5| jjjtjkr5i |d< | jr5t	| j
| j|d< t |d d dd | jj|||d d	S )
zPConverts PyT checkpoints to MCore format and save using MCore dist ckpt library.
state_dictsharded_state_dictoptimizer_statesr   stateoptimizer.state.prefix)r   N)r   popr   rX   r   fnr   FITTINGr!   r   rV   
optimizersr   save_checkpoint)r)   r   r   r   r$   r$   r-   r   
  s   zFSDPStrategy.save_checkpointcheckpoint_pathc                 C   s$  t | |}tj  i }t| j | j }t| ||d< W d   n1 s+w   Y  | j	rS| j
jjtjkrSt| j| jtddd}t|d dd ||d	< | jj||d
}t|d | | j	rx| j
jjtjkrxt|d	 d |d  t| j| j	r| jng |d | j	r|d	 ndd |S )a  PTL method which we override to integrate distributed checkpoints for FSDP models.
        Different from MegatronStrategy, both model and optimizer states are restore within
        this method.

        The logic here is slightly more complicated:
        1. Obtain PyT state dicts (sharded & unflattened) for model and optim -> torch::ShardedTensor
        2. Convert to MCore state dicts -> mcore::ShardedTensor
        3. Load from checkpoint using MCore dist ckpt API -> torch::Tensor
        4. Convert to PyT state dicts (sharded & unflattened) -> torch::ShardedTensor
        5. Load into model and optim using PyT dist ckpt API
        6. Return the loaded checkpoint for lightning to load other metadata
        r   NT)cpu_offload)optionsr   r   r   r   )r   )model_state_dictoptim_state_dict)r   	broadcastrA   r9   empty_cacher	   rV   r   r   r    rX   r   r   r   r   r   r   r   r   load_checkpointr   r   )r)   r   r   r   msdosdr   r$   r$   r-   r      s.   


zFSDPStrategy.load_checkpoint)r.   N)N)%__name__
__module____qualname____doc__r   boolr&   r   rW   plTrainerrY   rL   ra   rj   r   rw   r   r}   r   r   r   r   propertyr   r   setterintr   r   r   r   r   r   r   r   __classcell__r$   r$   r+   r-   r   9   sd    .
	
	


(r   ):rP   loggingrJ   r   collectionsr   pathlibr   typingr   r   r   r   lightning.pytorchpytorchr   rA   lightning.fabric.pluginsr    lightning.fabric.strategies.fsdpr	   $lightning.fabric.utilities.rank_zeror
   lightning.fabric.utilities.seedr   !lightning.pytorch.strategies.fsdpr   PLFSDPStrategy lightning.pytorch.trainer.statesr   !lightning.pytorch.utilities.typesr   +megatron.core.transformer.transformer_layerr   'torch.distributed.checkpoint.state_dictr   r   r   torch.utils.datar   typing_extensionsr   nemo.lightningr   'nemo.lightning.pytorch.strategies.utilsr   r   r   r   r   r   r   r   r   	getLoggerr   rF   IOMixinr$   r$   r$   r-   <module>   s0   ,
