import itertools
import shutil
from contextlib import ExitStack
from datetime import timedelta
from pathlib import Path
from typing import (
    TYPE_CHECKING,
    Any,
    Callable,
    ContextManager,
    Dict,
    Generator,
    Literal,
    Optional,
    TypeVar,
    Union,
)

import torch
from lightning_utilities.core.rank_zero import rank_zero_only as utils_rank_zero_only
from torch import Tensor
from torch.nn import Module
from torch.optim import Optimizer
from typing_extensions import TypeGuard, override

from lightning_fabric.plugins import CheckpointIO
from lightning_fabric.plugins.collectives.torch_collective import default_pg_timeout
from lightning_fabric.strategies.fsdp import (
    _distributed_checkpoint_load,
    _distributed_checkpoint_save,
    _get_full_state_dict_context,
    _is_full_checkpoint,
    _is_sharded_checkpoint,
)
from lightning_fabric.strategies.launchers.subprocess_script import _SubprocessScriptLauncher
from lightning_fabric.strategies.parallel import ParallelStrategy
from lightning_fabric.strategies.strategy import (
    TBroadcast,
    _apply_filter,
    _BackwardSyncControl,
    _validate_keys_for_strict_loading,
)
from lightning_fabric.utilities.distributed import (
    ReduceOp,
    _distributed_is_initialized,
    _get_default_process_group_backend_for_device,
    _init_dist_connection,
    _sync_ddp_if_available,
)
from lightning_fabric.utilities.distributed import group as _group
from lightning_fabric.utilities.imports import _TORCH_GREATER_EQUAL_2_3, _TORCH_GREATER_EQUAL_2_4
from lightning_fabric.utilities.init import _materialize_distributed_module
from lightning_fabric.utilities.load import _METADATA_FILENAME, _lazy_load, _move_state_into
from lightning_fabric.utilities.rank_zero import rank_zero_only
from lightning_fabric.utilities.seed import reset_seed
from lightning_fabric.utilities.types import _PATH, _Stateful

if TYPE_CHECKING:
    from torch.distributed.device_mesh import DeviceMesh

TModel = TypeVar("TModel", bound=Module)


class ModelParallelStrategy(ParallelStrategy):
    """Enables user-defined parallelism applied to a model.

    .. warning::  This is an :ref:`experimental <versioning:Experimental API>` feature.

    Currently supports up to 2D parallelism. Specifically, it supports the combination of
    Fully Sharded Data-Parallel 2 (FSDP2) with Tensor Parallelism (DTensor). These PyTorch APIs are currently still
    experimental in PyTorch. Requires PyTorch 2.4 or newer.

    Arguments:
        parallelize_fn: A function that applies parallelisms to a module. The strategy will provide the
            model and device mesh as input.
        data_parallel_size: The number of devices within a data-parallel group. Defaults to ``"auto"``, which
            sets this size to the number of nodes in the cluster.
        tensor_parallel_size: The number of devices within a tensor-parallel group. Defaults to ``"auto"``, which
            sets this size to the number of GPUs in a single node.
        save_distributed_checkpoint: If ``True``, each rank saves its shard of weights and optimizer states to a file.
            The checkpoint is a folder with as many files as the world size.
            If ``False``, the full weights and optimizer states get assembled on rank 0 and saved to a single file.
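
    Example (illustrative sketch only; the model and tensor-parallel plan below are placeholders)::

        from torch.distributed._composable.fsdp import fully_shard
        from torch.distributed.tensor.parallel import ColwiseParallel, parallelize_module

        def my_parallelize_fn(model, device_mesh):
            # shard selected submodules across the "tensor_parallel" mesh dimension ...
            parallelize_module(model, device_mesh["tensor_parallel"], {"linear": ColwiseParallel()})
            # ... then apply FSDP2 data-parallel sharding across the "data_parallel" dimension
            fully_shard(model, mesh=device_mesh["data_parallel"])
            return model

        strategy = ModelParallelStrategy(parallelize_fn=my_parallelize_fn)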

    """

    def __init__(
        self,
        parallelize_fn: Callable[[TModel, "DeviceMesh"], TModel],
        data_parallel_size: Union[Literal["auto"], int] = "auto",
        tensor_parallel_size: Union[Literal["auto"], int] = "auto",
        save_distributed_checkpoint: bool = True,
        process_group_backend: Optional[str] = None,
        timeout: Optional[timedelta] = default_pg_timeout,
    ) -> None:
        super().__init__()
        if not _TORCH_GREATER_EQUAL_2_4:
            raise ImportError(f"{type(self).__name__} requires PyTorch 2.4 or higher.")
        self._parallelize_fn = parallelize_fn
        self._data_parallel_size = data_parallel_size
        self._tensor_parallel_size = tensor_parallel_size
        self._num_nodes = 1
        self._save_distributed_checkpoint = save_distributed_checkpoint
        self._process_group_backend: Optional[str] = process_group_backend
        self._timeout: Optional[timedelta] = timeout
        self._backward_sync_control = _ParallelBackwardSyncControl()
        self._device_mesh: Optional["DeviceMesh"] = None

    @property
    def device_mesh(self) -> "DeviceMesh":
        if self._device_mesh is None:
            raise RuntimeError("Accessing the device mesh before processes have initialized is not allowed.")
        return self._device_mesh

    @property
    @override
    def checkpoint_io(self) -> CheckpointIO:
        raise NotImplementedError(f"The `{type(self).__name__}` does not use the `CheckpointIO` plugin interface.")

    @checkpoint_io.setter
    @override
    def checkpoint_io(self, io: CheckpointIO) -> None:
        raise NotImplementedError(f"The `{type(self).__name__}` does not support setting a `CheckpointIO` plugin.")

    @property
    @override
    def root_device(self) -> torch.device:
        assert self.parallel_devices is not None
        return self.parallel_devices[self.local_rank]

    @property
    def num_nodes(self) -> int:
        return self._num_nodes

    @num_nodes.setter
    def num_nodes(self, num_nodes: int) -> None:
        self._num_nodes = num_nodes

    @property
    def num_processes(self) -> int:
        return len(self.parallel_devices) if self.parallel_devices is not None else 0

    @property
    @override
    def distributed_sampler_kwargs(self) -> Dict[str, Any]:
        assert self.device_mesh is not None
        data_parallel_mesh = self.device_mesh["data_parallel"]
        return {"num_replicas": data_parallel_mesh.size(), "rank": data_parallel_mesh.get_local_rank()}

    @property
    def process_group_backend(self) -> Optional[str]:
        return self._process_group_backend

    @override
    def _configure_launcher(self) -> None:
        assert self.cluster_environment is not None
        if not self.cluster_environment.creates_processes_externally:
            self._launcher = _SubprocessScriptLauncher(self.cluster_environment, self.num_processes, self.num_nodes)

    @override
    def setup_environment(self) -> None:
        super().setup_environment()
        self._setup_distributed()
        if self._data_parallel_size == "auto":
            self._data_parallel_size = self.num_nodes
        if self._tensor_parallel_size == "auto":
            self._tensor_parallel_size = self.num_processes
        self._device_mesh = _setup_device_mesh(
            self._data_parallel_size, self._tensor_parallel_size, self.world_size, self.root_device
        )

    @override
    def setup_module(self, module: TModel) -> TModel:
        from torch.distributed.fsdp import FullyShardedDataParallel

        if any(isinstance(mod, FullyShardedDataParallel) for mod in module.modules()):
            raise TypeError(
                "Found modules that are wrapped with `torch.distributed.fsdp.FullyShardedDataParallel`."
                f" The `{self.__class__.__name__}` only supports the new FSDP2 APIs in PyTorch >= 2.4."
            )
        module = self._parallelize_fn(module, self.device_mesh)
        if not isinstance(module, Module):
            raise TypeError(
                f"The `parallelize_fn` must return a `nn.Module` instance, but got: {type(module).__name__}"
            )
        _materialize_distributed_module(module, self.root_device)
        return module

    @override
    def module_to_device(self, module: Module) -> None:
        pass

    @override
    def module_init_context(self, empty_init: Optional[bool] = None) -> ContextManager:
        precision_init_ctx = self.precision.module_init_context()
        stack = ExitStack()
        if empty_init:
            # materialization of the parameters happens later, in `setup_module()`
            stack.enter_context(torch.device("meta"))
        stack.enter_context(precision_init_ctx)
        return stack

    @override
    def all_reduce(
        self, tensor: Tensor, group: Optional[Any] = None, reduce_op: Optional[Union[ReduceOp, str]] = "mean"
    ) -> Tensor:
        if isinstance(tensor, Tensor):
            return _sync_ddp_if_available(tensor, group, reduce_op=reduce_op)
        return tensor

    @override
    def barrier(self, *args: Any, **kwargs: Any) -> None:
        if not _distributed_is_initialized():
            return
        if torch.distributed.get_backend() == "nccl":
            torch.distributed.barrier(device_ids=[self.root_device.index])
        else:
            torch.distributed.barrier()

    @override
    def broadcast(self, obj: TBroadcast, src: int = 0) -> TBroadcast:
        if not _distributed_is_initialized():
            return obj
        obj = [obj]
        torch.distributed.broadcast_object_list(obj, src, group=_group.WORLD)
        return obj[0]

    @override
    def save_checkpoint(
        self,
        path: _PATH,
        state: Dict[str, Union[Module, Optimizer, Any]],
        storage_options: Optional[Any] = None,
        filter: Optional[Dict[str, Callable[[str, Any], bool]]] = None,
    ) -> None:
        """Save model, optimizer, and other state to a checkpoint on disk.

        If distributed checkpointing is enabled (default), the checkpoint gets saved as a directory containing one
        file per process, with model- and optimizer shards stored per file. Additionally, it creates a metadata file
        `meta.pt` with the rest of the user's state (only saved from rank 0).
        If distributed checkpointing is disabled (``save_distributed_checkpoint=False``), the checkpoint will be
        written to a single file containing the weights, optimizer state, and other metadata.

        """
        if storage_options is not None:
            raise TypeError(
                f"`{type(self).__name__}.save_checkpoint(..., storage_options=...)` is not supported because"
                f" `{type(self).__name__}` does not use the `CheckpointIO`."
            )
        if filter is not None and self._save_distributed_checkpoint:
            raise NotImplementedError(
                f"`{type(self).__name__}` doesn't support loading distributed filtered checkpoints,"
                " so saving them is disabled."
            )
        # broadcast the path from rank 0 to ensure all ranks save to a common path
        path = Path(self.broadcast(path))
        _save_checkpoint(
            path=path,
            state=state,
            full_state_dict=(not self._save_distributed_checkpoint),
            rank=self.global_rank,
            filter=filter,
        )

    @override
    def load_checkpoint(
        self,
        path: _PATH,
        state: Optional[Union[Module, Optimizer, Dict[str, Union[Module, Optimizer, Any]]]] = None,
        strict: bool = True,
    ) -> Dict[str, Any]:
        """Load the contents from a checkpoint and restore the state of the given objects."""
        if not state:
            raise ValueError(
                f"Got `{type(self).__name__}.load_checkpoint(..., state={state!r})` but a state with at least"
                " a model instance to reload is required. Pass it in like so:"
                f" `{type(self).__name__}.load_checkpoint(..., state={{'model': model, ...}})`"
            )
        # broadcast the path from rank 0 to ensure all ranks load from a common path
        path = Path(self.broadcast(path))
        if isinstance(state, Module):
            _load_raw_module_state_from_path(path, module=state, world_size=self.world_size, strict=strict)
            return {}
        if isinstance(state, Optimizer):
            raise NotImplementedError(
                f"Loading a single optimizer object from a checkpoint is not supported yet with"
                f" `{type(self).__name__}`."
            )
        return _load_checkpoint(path=path, state=state, strict=strict)

    def _setup_distributed(self) -> None:
        reset_seed()
        self._set_world_ranks()
        self._process_group_backend = self._get_process_group_backend()
        assert self.cluster_environment is not None
        _init_dist_connection(self.cluster_environment, self._process_group_backend, timeout=self._timeout)

    def _get_process_group_backend(self) -> str:
        return self._process_group_backend or _get_default_process_group_backend_for_device(self.root_device)

    def _set_world_ranks(self) -> None:
        if self.cluster_environment is not None:
            self.cluster_environment.set_global_rank(self.node_rank * self.num_processes + self.local_rank)
            self.cluster_environment.set_world_size(self.num_nodes * self.num_processes)
        # `LightningEnvironment.set_global_rank` will do this too, but we cannot rely on that implementation detail
        rank_zero_only.rank = utils_rank_zero_only.rank = self.global_rank


class _ParallelBackwardSyncControl(_BackwardSyncControl):
    @override
    def no_backward_sync(self, module: Module, enabled: bool) -> ContextManager:
        """Blocks gradient synchronization inside the FSDP2 modules."""
        return _FSDPNoSync(module=module, enabled=enabled)
z_FSDPNoSync.__init__requires_grad_syncc                 C   s8   ddl m} | j D ]}t||r|j|dd qd S )Nr   )
FSDPModuleFrecurse)"torch.distributed._composable.fsdpr   r   r   r{   set_requires_gradient_sync)rN   r   r   r}   rQ   rQ   rR   _set_requires_grad_syncK  s   
z#_FSDPNoSync._set_requires_grad_syncc                 C   s   |  | j  d S r]   r   r   rT   rQ   rQ   rR   	__enter__R  s   z_FSDPNoSync.__enter__exc_type	exc_value	tracebackc                 C   s   |  | j d S r]   r   )rN   r   r   r   rQ   rQ   rR   __exit__U  r   z_FSDPNoSync.__exit__r   )
rC   r   r   r   r   r@   r   r   r   r   rQ   rQ   rQ   rR   r   F  s
    
r   r   r   r   ri   r   r=   c                 C   s  |   r|rt| std|  dd | D }t|dkr$tdt|dkr.td|d }ddlm}m}m	}	 ||d	d
}
i }i }|
 D ]7\}}t|tr\|||
d}|}nt|trk|	|||
d}|}nt|trt| n|}|}t||p}i || qJ|rt| rt|  || |dkrt||  d S d S |  r|   | jd	d	d t||  |dkrt|| t  d S d S )Nz/The checkpoint path exists and is a directory: c                 S   s   g | ]}t |r|qS rQ   _has_dtensor_modules)r|   rv   rQ   rQ   rR   


def _save_checkpoint(
    path: Path,
    state: Dict[str, Union[Module, Optimizer, Any]],
    full_state_dict: bool,
    rank: int,
    filter: Optional[Dict[str, Callable[[str, Any], bool]]] = None,
) -> None:
    if path.is_dir() and full_state_dict and not _is_sharded_checkpoint(path):
        raise IsADirectoryError(f"The checkpoint path exists and is a directory: {path}")

    modules = [module for module in state.values() if _has_dtensor_modules(module)]
    if len(modules) == 0:
        raise ValueError(
            "Could not find a distributed model in the provided checkpoint state. Please provide the model as"
            " part of the state like so: `save_checkpoint(..., state={'model': model, ...})`. Make sure"
            " you set up the model (and optimizers if any) through the strategy before saving the checkpoint."
        )
    if len(modules) > 1:
        raise ValueError(
            "Found multiple distributed models in the given state. Saving distributed checkpoints is"
            " currently limited to a single model per checkpoint. To save multiple models, call the"
            " save method for each model separately with a different path."
        )
    from torch.distributed.checkpoint.state_dict import (
        StateDictOptions,
        get_model_state_dict,
        get_optimizer_state_dict,
    )

    module = modules[0]
    state_dict_options = StateDictOptions(full_state_dict=full_state_dict, cpu_offload=True)

    # replace the modules and optimizer objects in the state with their local state dict
    # and separate the user's metadata
    converted_state: Dict[str, Any] = {}
    metadata: Dict[str, Any] = {}
    for key, obj in state.items():
        converted: Any
        if isinstance(obj, Module):
            converted = get_model_state_dict(obj, options=state_dict_options)
            target_dict = converted_state
        elif isinstance(obj, Optimizer):
            converted = get_optimizer_state_dict(module, obj, options=state_dict_options)
            target_dict = converted_state
        else:  # everything not a module or optimizer is considered metadata
            converted = obj.state_dict() if isinstance(obj, _Stateful) else obj
            target_dict = metadata
        _apply_filter(key, filter or {}, converted, target_dict)

    if full_state_dict:
        if _is_sharded_checkpoint(path):
            shutil.rmtree(path)
        converted_state.update(metadata)
        if rank == 0:
            torch.save(converted_state, path)
    else:
        if path.is_file():
            path.unlink()
        path.mkdir(parents=True, exist_ok=True)
        _distributed_checkpoint_save(converted_state, path)
        if rank == 0:
            torch.save(metadata, path / _METADATA_FILENAME)


def _load_checkpoint(
    path: Path,
    state: Dict[str, Union[Module, Optimizer, Any]],
    strict: bool = True,
    optimizer_states_from_list: bool = False,
) -> Dict[str, Any]:
    from torch.distributed.checkpoint.state_dict import (
        StateDictOptions,
        get_model_state_dict,
        get_optimizer_state_dict,
        set_optimizer_state_dict,
    )

    modules = {key: module for key, module in state.items() if _has_dtensor_modules(module)}
    if len(modules) == 0:
        raise ValueError(
            "Could not find a distributed model in the provided checkpoint state. Please provide the model as"
            " part of the state like so: `load_checkpoint(..., state={'model': model, ...})`. Make sure"
            " you set up the model (and optimizers if any) through the strategy before loading the checkpoint."
        )
    optimizers = {key: optim for key, optim in state.items() if isinstance(optim, Optimizer)}
    if len(modules) > 1:
        raise ValueError(
            "Found multiple distributed models in the given state. Loading distributed checkpoints is"
            " currently limited to a single model per checkpoint. To load multiple models, call the"
            " load method for each model separately with a different path."
        )
    module_key, module = list(modules.items())[0]

    if _is_sharded_checkpoint(path):
        state_dict_options = StateDictOptions(cpu_offload=True)

        module_state = {module_key: get_model_state_dict(module)}
        _distributed_checkpoint_load(module_state, path)
        module.load_state_dict(module_state[module_key], strict=strict)

        # the optimizer states must be loaded separately
        for optim_key, optim in optimizers.items():
            optim_state = {optim_key: get_optimizer_state_dict(module, optim)}
            _distributed_checkpoint_load(optim_state, path)
            set_optimizer_state_dict(
                module, optim, optim_state_dict=optim_state[optim_key], options=state_dict_options
            )

        # load metadata (anything not a module or optimizer)
        metadata = torch.load(path / _METADATA_FILENAME)
        requested_metadata_keys = state.keys() - modules.keys() - optimizers.keys()
        _validate_keys_for_strict_loading(requested_metadata_keys, metadata.keys(), strict=strict)
        for key in requested_metadata_keys:
            if key not in metadata:
                continue
            state[key] = metadata.pop(key)

        # return the remaining metadata that wasn't requested as part of `state`
        return metadata

    if _is_full_checkpoint(path):
        checkpoint = torch.load(path, mmap=True, map_location="cpu", weights_only=False)
        _load_raw_module_state(checkpoint.pop(module_key), module, strict=strict)

        state_dict_options = StateDictOptions(
            broadcast_from_rank0=True,
            full_state_dict=True,
            strict=strict,
        )

        for optimizer_idx, (optimizer_name, optimizer) in enumerate(optimizers.items()):
            if optimizer_states_from_list:
                # the checkpoint stores a list of optimizer states
                optimizer_state = checkpoint["optimizer_states"][optimizer_idx]
            else:
                optimizer_state = checkpoint.pop(optimizer_name)
            optimizer_state = _rekey_optimizer_state_if_needed(optimizer_state, module)
            set_optimizer_state_dict(module, optimizer, optim_state_dict=optimizer_state, options=state_dict_options)

        requested_metadata_keys = state.keys() - modules.keys() - optimizers.keys()
        _validate_keys_for_strict_loading(requested_metadata_keys, checkpoint.keys(), strict=strict)

        # load metadata (anything not a module or optimizer)
        _move_state_into(source=checkpoint, destination=state, keys=requested_metadata_keys)

        # return the remaining metadata that wasn't requested as part of `state`
        return checkpoint

    raise ValueError(
        f"The path {str(path)!r} does not point to a valid checkpoint. Make sure the path points to either a"
        " directory with distributed checkpoint shards, or a single file with a full checkpoint."
    )


def _setup_device_mesh(
    data_parallel_size: int,
    tensor_parallel_size: int,
    world_size: int,
    device: torch.device,
) -> "DeviceMesh":
    from torch.distributed.device_mesh import init_device_mesh

    if data_parallel_size * tensor_parallel_size != world_size:
        raise RuntimeError(
            f"The sizes `data_parallel_size={data_parallel_size}` and"
            f" `tensor_parallel_size={tensor_parallel_size}` multiplied should equal the world size"
            f" ({world_size})."
        )
    return init_device_mesh(
        device_type=device.type,
        mesh_shape=(data_parallel_size, tensor_parallel_size),
        mesh_dim_names=("data_parallel", "tensor_parallel"),
    )


def _has_dtensor_modules(module: object) -> TypeGuard[Module]:
    from torch.distributed._tensor import DTensor

    return isinstance(module, Module) and any(isinstance(t, DTensor) for t in module.parameters())


def _load_raw_module_state_from_path(path: Path, module: Module, world_size: int, strict: bool = True) -> None:
    """Loads the state dict from a file path into the FSDP module."""
    if not _is_full_checkpoint(path):
        raise ValueError(
            "Failed to load checkpoint directly into the model. The given path must be a single file containing the"
            f" full state dict: {path}"
        )
    state_dict = torch.load(path, mmap=True, map_location="cpu") if _TORCH_GREATER_EQUAL_2_3 else _lazy_load(path)
    _load_raw_module_state(state_dict=state_dict, module=module, world_size=world_size, strict=strict)


def _load_raw_module_state(
    state_dict: Dict[str, Any], module: Module, world_size: int = 1, strict: bool = True
) -> None:
    """Loads the state dict into the module by gathering all weights first and then writing back to each shard."""
    from torch.distributed.fsdp import FullyShardedDataParallel as FSDP

    if _has_dtensor_modules(module):
        from torch.distributed.checkpoint.state_dict import StateDictOptions, set_model_state_dict

        state_dict_options = StateDictOptions(
            broadcast_from_rank0=True,
            full_state_dict=True,
            strict=strict,
        )

        for submodule_name, submodule in module.named_modules():
            for param_name, _ in _named_parameters_and_buffers_to_load(submodule):
                full_param_name = f"{submodule_name}{'.' if submodule_name else ''}{param_name}"
                if full_param_name not in state_dict:
                    if not strict:
                        continue
                    raise KeyError(
                        f"The model contains a key '{full_param_name}' that does not exist in the loaded checkpoint."
                        " To disable strict loading, set `strict=False`."
                    )
                local_state_dict = {param_name: state_dict[full_param_name]}
                set_model_state_dict(submodule, local_state_dict, options=state_dict_options)

    elif isinstance(module, FSDP):
        with _get_full_state_dict_context(module, world_size=world_size, rank0_only=False):
            module.load_state_dict(state_dict, strict=strict)
    else:
        module.load_state_dict(state_dict, strict=strict)


def _named_parameters_and_buffers_to_load(module: Module) -> Generator:
    """Returns parameters and buffers, with non-persistent buffers excluded."""
    for param_name, param in itertools.chain(
        module.named_buffers(recurse=False), module.named_parameters(recurse=False)
    ):
        if param_name in module._non_persistent_buffers_set:
            continue
        yield param_name, param


def _rekey_optimizer_state_if_needed(optimizer_state_dict: Dict[str, Any], module: Module) -> Dict[str, Any]:
    """Handles the case where the optimizer state is saved from a normal optimizer and converts the keys to parameter
    names."""
    from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
    from torch.distributed.fsdp import OptimStateKeyType

    if isinstance(list(optimizer_state_dict["state"].keys())[0], int):
        optimizer_state_dict = FSDP.rekey_optim_state_dict(optimizer_state_dict, OptimStateKeyType.PARAM_NAME, module)
    return optimizer_state_dict
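

# Worked example of the constraint enforced by `_setup_device_mesh` above
# (assumed setup: 2 nodes with 4 GPUs each, i.e. world_size=8):
# data_parallel_size=2 and tensor_parallel_size=4 is valid since 2 * 4 == 8,
# yielding a mesh equivalent to
# init_device_mesh("cuda", (2, 4), mesh_dim_names=("data_parallel", "tensor_parallel")).
#
# And an illustrative example of the rekeying in `_rekey_optimizer_state_if_needed`:
# a plain optimizer checkpoint keys its state by integer parameter id, e.g.
# {"state": {0: {...}}, "param_groups": [...]}, whereas loading through
# `set_optimizer_state_dict` expects parameter-name keys, e.g.
# {"state": {"layer.weight": {...}}, "param_groups": [...]}.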