o
    }oi                     @   s  d dl Z d dlZd dlZd dlZd dlmZ d dlmZ d dlm	Z	m
Z
mZmZmZmZ d dlZd dlmZ d dlmZ d dlmZmZ d dlmZ d d	lmZ d d
lmZ d dlmZ d dlm Z  d dl!m"Z" d dl#m$Z$ d dl%m&Z&m'Z'm(Z( d dl)m*Z*m+Z+ G dd deZ,dS )    N)deepcopy)Path)AnyDictIterableListOptionalUnion)proxy)get_filesystem)ModelCheckpoint_is_local_file_protocol)call)rank_zero_info)EMA)logging)AppState)AsyncFinalizableCheckpointIO)is_global_rank_zero)ckpt_to_dirinject_model_parallel_rankuninject_model_parallel_rank)import_multistorageclientis_multistorageclient_urlc                       s  e Zd ZdZdZ								dXded	ed
edededededef fddZdd Z	dd Z
deeef ddf fddZdeddf fddZ fddZ fdd Zdee fd!d"ZdYd#ee defd$d%Zd&eddfd'd(Zd)d*dee fd+d,Zd&eeef d-ee ddfd.d/Zd&eeef dee fd0d1ZdZd2d3Zed4eeef defd5d6Zed4eeef defd7d8Zed[d4eeef ddfd9d:Z ed[d4eeef ddfd;d<Z!	d\d&ed)d*d=edefd>d?Z"d)d*d&eddf fd@dAZ#d)d*d&edBefdCdDZ$	d[d)d*d&eddf fdEdFZ%d&edefdGdHZ&dIe'e defdJdKZ(d&eeef defdLdMZ)e*de'e fdNdOZ+edPeeef ddfdQdRZ,d)dSdTedUedefdVdWZ-  Z.S )]NeMoModelCheckpointa?  Light wrapper around Lightning's ModelCheckpoint to force a saved checkpoint on train_end.
    Extends Lightning's on_save_checkpoint func to save the .nemo file. Saves the .nemo file based
    on the best checkpoint saved (according to the monitor value).
    Also contains func to save the EMA copy of the model.
    z-unfinishedFT.nemoNalways_save_nemosave_nemo_on_train_endsave_best_modelpostfixn_resumemodel_parallel_size
async_savesave_last_n_optim_statesc	           
         s   || _ || _|| _|| _| jr| jstd || _d| _|| _|| _	d | _
g | _d|	v r4|	d| _nd| _t jdi |	 | jdkrR|rTtd |   d S d S d S )NzFound save_best_model is True and save_nemo_on_train_end is False. Set save_nemo_on_train_end to True to automatically save the best model. prefixr   zChecking previous runs )r   r   r   r$   r   warningr    previous_best_pathr"   r#   async_finalize_cbdeferred_ckpts_to_removepopr&   super__init__
save_top_kdebugnemo_topk_check_previous_run)
selfr   r   r   r    r!   r"   r#   r$   kwargs	__class__r'   ^/home/ubuntu/.local/lib/python3.10/site-packages/nemo/utils/callbacks/nemo_model_checkpoint.pyr.   0   s,   
zNeMoModelCheckpoint.__init__c                    sT  z j   j  j  j W n ty   tdw i  _ d _d _d _t fdd jD }|D ]Z}dt|v sBdt|v rFt|}t|}|dd d	ksZ|d
d dkr[q4|	 j
t j
 d }|t j
krtd||d }|r||||  d  }t| j |< q4t j dk rdS  jdkrdnd}t j  j j|d} jdur|d  rt| j }nt| j j  }nt| j }td|}td|    j}	t|D ]/}
|d} j |  | |	r j |r  | td|  q|d  _|d  _ j  j  _dS )z3
        Check if there are previous runs.
        zQLightning's ModelCheckpoint was updated. NeMoModelCheckpoint will need an update.r%   Nc                 3   s    | ]
}  |s|V  qd S N_is_ema_filepath).0pathr2   r'   r6   	<genexpr>p   s    zCNeMoModelCheckpoint.nemo_topk_check_previous_run.<locals>.<genexpr>mp_ranktp_ranki
-last.ckpt-last   z[A-z]minFTkeyreverser   zNumber of models to delete: r   Removed checkpoint: ) best_k_modelskth_best_model_pathbest_model_scorebest_model_pathAttributeErrorlist_saved_checkpoint_pathsstrr   findmonitorlenresearchstartfloatmodesortedgetr"   is_dirr/   maxr   r0   _has_ema_ckptsranger,   _del_model_without_trainer_fsexists_ema_format_filepath)r2   checkpoints
checkpointindexmatchvalue_reverserI   models_to_deleteema_enabled_modelr'   r<   r6   r1   `   s`   
 





z0NeMoModelCheckpoint.nemo_topk_check_previous_runc                    s   dt dtffdd  fddj D _tjdkrHjdk}tjjj|d	}|d
 _jj _	|d _
jj
 _d S d_d _	d_
d _d S )N	ckpt_pathreturnc                    s>   t j| pt jt| pt j| d}|o |  S )N.ckpt)osr;   isfiler   isdirremovesuffixis_checkpoint_unfinished)rm   ra   r<   r'   r6   __is_ckpt_ok   s   zKNeMoModelCheckpoint._remove_invalid_entries_from_topk.<locals>.__is_ckpt_okc                    s   i | ]\}} |r||qS r'   r'   )r:   kv) _NeMoModelCheckpoint__is_ckpt_okr'   r6   
<dictcomp>   s    zINeMoModelCheckpoint._remove_invalid_entries_from_topk.<locals>.<dictcomp>r   rD   rE   r   r%   )rP   boolrI   itemsrS   rX   rY   rZ   rJ   	kth_valuerL   rK   )r2   reverse_arrbest_k_models_arrr'   )rx   r2   r6   !_remove_invalid_entries_from_topk   s   



z5NeMoModelCheckpoint._remove_invalid_entries_from_topk
state_dictrn   c                    s   t  | |   dS )z&
        Load the state dict.
        N)r-   load_state_dictr   )r2   r   r4   r'   r6   r      s   z#NeMoModelCheckpoint.load_state_dictstagec                    sh   t  rtd t| j tj rtj	  t
 ||| |j|j}||_|j| j| _dS )z'
        Setup the checkpoint.
        z)Removing unfinished checkpoints if any...N)r   r   r0   r   _remove_unfinished_checkpointsdirpathtorchdistributedis_initializedbarrierr-   setupstrategy	broadcastrm   last_model_path)r2   trainer	pl_moduler   r;   r4   r'   r6   r      s   


zNeMoModelCheckpoint.setupc           	         s  t  |||}| js|S t }|jdur|jdkrtd |  |_|jdur4|jdkr4t	| j
}n| j
}| jrtj|sBdS | j
| jkrOtd |S | j
| _t| }tj|ddd}d|v ri|d }|j|d	d
 tj rztj  | |}|j|jd td|j  |j|d	d
 ntj rtj  | |}|j|jd td|j  |durt rtd|  t|| |S )z&
        Save the checkpoint.
        NrC   z@always_save_nemo will slow down training for model_parallel > 1.z*Best model has not changed, skipping save.cpuF)map_locationweights_onlyr   Tstrict	save_pathzNew best .nemo model saved to: zNew .nemo model saved to: Removing old .nemo backup )r-   on_save_checkpointr   r   r"   r   r(   _format_nemo_checkpoint_namemodel_restore_pathr   rL   r   rp   r;   ra   r)   r0   r   r   r   loadr   r   r   r   _backup_existing_nemo_ckptsave_toinfor   r   rm)	r2   r   r   rd   output	app_statemaybe_injected_best_model_pathold_state_dictbackup_pathr4   r'   r6   r      sJ   








z&NeMoModelCheckpoint.on_save_checkpointc                    s  |j rdS | jrR|jdkrRd}t|jtr|j|j dkrd}t|jtr/|j|j dkr/d}|rR| |}| j| 	|| j
krKtd| j d nt || t || | jr|jd | jdkrpt|  d	 n$tj| jd
d r| jd
d | _|j| j| _|j| j | jr| |}|j|  d |durt  rt!d|  t"|#| dS dS dS dS )z3
        Save the checkpoint on train end.
        Nr   FTzLast checkpoint z already savedz&SaveBestCheckpointConnector.resume_endr%   z was told to save the best checkpoint at the end of training, but no saved checkpoints were found. Saving latest model instead.ro   r   r   )$fast_dev_run	save_lastval_check_interval
isinstancerW   global_stepint_monitor_candidatesr   format_checkpoint_nameCHECKPOINT_NAME_LASTr   r0   r-   _save_last_checkpointon_train_endr   r   r   rL   r(   rp   r;   rr   splitr   _checkpoint_connectorrestorer   r   r   r   r   r   r   r   )r2   r   r   should_save_last_checkpointmonitor_candidatesr   r4   r'   r6   r     s@   


z NeMoModelCheckpoint.on_train_endc                 C   s   |   }|}| jr%| j}| j||ddr%|  |}|d7 }| j||dds||kr+dS |jrCt| d|  t|r=nt	|| |j
  |S )a  Search for an available name with version infix and rename existing checkpoint.

        NOTE: this behavior is slightly different from regular checkpoints.
        PTL creates new regular checkpoint with the first available name.
        Here, for backward compatibility, we create .nemo checkpoint as before
        and create a backup under the first available name.

        Args:
            trainer (Trainer): trainer instance.

        Returns:
            Path to the backup checkpoint or None, if no backup was created
        F)check_dist_ckptrC   Nz/ already exists, moving existing checkpoint to )r   _enable_version_counterSTARTING_VERSIONfile_existsis_global_zeror   r   r   shutilmover   r   )r2   r   	base_pathavailable_pathversion_cntr'   r'   r6   r   3  s"   

z.NeMoModelCheckpoint._backup_existing_nemo_ckptverc              	   C   sn   |d u rdn| j  d| }t| jr!| j d| j| | j  S tjtjtj	| j| j| | j S )Nr%   rw   /)
CHECKPOINT_JOIN_CHARr   r   r&   r    rp   r;   abspath
expanduserjoin)r2   r   version_infixr'   r'   r6   r   V  s   
"z0NeMoModelCheckpoint._format_nemo_checkpoint_namefilepathc                 C   s   t |}t| r5t r3zt|}tj|dd td|  W d S    td| d Y d S d S t }|j	d urF|j	dkrFt
|}t sS|j	d urt|jdkrvz| j| td|  W d S    td	| d Y d S d S d S )
NT)ignore_errorsz Removed distributed checkpoint: z(Tried to remove distributed checkpoint: z but failed.rC   r   rH   zTried to remove checkpoint: )r   r   r[   r   r   rmtreer   r   r   r"   r   data_parallel_rankr`   r   )r2   r   	dist_ckptr   r'   r'   r6   r_   ^  s*   	z.NeMoModelCheckpoint._del_model_without_trainerr   zlightning.pytorch.Trainerc                 C   s"   d }|j D ]	}t|tr|}q|S r7   )	callbacksr   r   )r2   r   ema_callbackcallbackr'   r'   r6   _ema_callback~  s   

z!NeMoModelCheckpoint._ema_callbackstorage_optionsc           	      C   s  |  |}d}t|| j d }t|| jkr|| }td| d |jj|dd}| || |d u r<tdd}nd|d< |j	| | d	| j
|d
 t ra|j| t| | | tj rktj  |jjt|dd}| || td| d d S d S )Nz	-no-optimrC   z	Loading 'z(' checkpoint to drop optimizer states...F)checkpoint_pathload_optimizer_states)include_optimizerr   ro   r   z+Successfully dropped optimizer states for 'z' checkpoint.)_get_checkpoints_listrS   r$   r   r   r   load_checkpoint_load_current_state_dictdictsave_checkpointsave_weights_onlyr   remove_checkpointr   r   r   r   r   r   r   )	r2   r   r   r   rc   suffixcheckpoint_indexr   rd   r'   r'   r6   _drop_optimizer_states  s2   


z*NeMoModelCheckpoint._drop_optimizer_statesc                    sJ   t j|  fddt  D }t|dd d} fdd|D }|S )Nc                    s.   g | ]}t jt j |rd |vr|qS )rB   )rp   r;   rr   r   r:   dcheckpoints_dirr'   r6   
<listcomp>  s    z=NeMoModelCheckpoint._get_checkpoints_list.<locals>.<listcomp>c                 S   s   t | dd dd S )Nz-step=rC   -r   )r   r   )xr'   r'   r6   <lambda>      z;NeMoModelCheckpoint._get_checkpoints_list.<locals>.<lambda>)rF   c                    s   g | ]	}t j |qS r'   )rp   r;   r   )r:   rd   r   r'   r6   r     s    )rp   r;   dirnamelistdirrY   )r2   r   rc   r'   r   r6   r     s   
z)NeMoModelCheckpoint._get_checkpoints_listc                 C   s&   t |d| |jj||jjd d S )Non_load_checkpointr   )r   _call_lightning_module_hookr   load_model_state_dictlightning_modulestrict_loading)r2   r   rd   r'   r'   r6   r     s
   
z,NeMoModelCheckpoint._load_current_state_dictr   c                 C   s8   t t| }|d}|d}|d}t|tj S )a  Format the path to the unfinished checkpoint marker file.

        If the marker file exists, corresponding checkpoint is considered unfinished/incomplete.
        NOTE: Marker path for the EMA checkpoint part is the same as for the original checkpoint.

        Args:
            checkpoint_path: Path to the checkpoint file or dir.
              Does not need to exist.

        Returns:
            Path to the unfinished checkpoint marker file.
        r   ro   -EMA)rP   r   rs   r   r   UNFINISHED_CHECKPOINT_SUFFIX)r   marker_filepathr'   r'   r6   (format_checkpoint_unfinished_marker_path  s
   


z<NeMoModelCheckpoint.format_checkpoint_unfinished_marker_pathc                 C   s   t |  S )zCheck if the checkpoint is unfinished.

        Args:
            checkpoint_path: Path to the checkpoint file or dir.
              Does not need to exist.

        Returns:
            True if the checkpoint is unfinished, False otherwise.
        )r   r   ra   )r   r'   r'   r6   rt     s   z,NeMoModelCheckpoint.is_checkpoint_unfinishedc                 C   sL   t  rt| }|jjddd |  |r"tj r$tj	  dS dS dS )a  Marks given checkpoint as unfinished.

        Args:
            checkpoint_filepath: Path to the checkpoint file or dir.
              Does not need to exist.
            barrier_after: Synchronize ranks after writing the marker file.
              Defaults to False.
        T)parentsexist_okN)
r   r   r   parentmkdirtouchr   r   r   r   )r   barrier_aftermarker_pathr'   r'   r6    set_checkpoint_unfinished_marker  s   

z4NeMoModelCheckpoint.set_checkpoint_unfinished_markerc                 C   sX   z%|rt j rt j  t r t| }| r#|  W dS W dS W dS    Y dS )a  Clear unfinished marker for given checkpoint.

        Args:
            checkpoint_path: Path to the checkpoint file or dir.
              Does not need to exist.
            barrier_before: Synchronize ranks before removing the marker file.
              Defaults to False.
        N)	r   r   r   r   r   r   r   ra   unlink)r   barrier_beforer   r'   r'   r6   #remove_checkpoint_unfinished_marker  s   


z7NeMoModelCheckpoint.remove_checkpoint_unfinished_markerr   c                 C   sB   t |r| j|}n| j|p|o| jt|}|j|S )zLChecks if a file or a file without a suffix (distributed checkpoint) exists.)r   r`   ra   r   r   r   )r2   r   r   r   ra   r'   r'   r6   r     s    zNeMoModelCheckpoint.file_existsc                    s  | j |dd | |}|d urf| jrtd|| t || W d    n1 s.w   Y  || | |}| j	rHt
d|  t || W d    n1 sYw   Y  | j|dd nM| |||j}| jr|jj}t|ts~tdt|d}| jg  nd }td|j d	t  d
 |j|| j|d | jrtd|  n|  | jdkrd|v r| ||| d S d S d S )NTr   z!async_save with EMA not supportedz*Saving EMA weights to separate checkpoint r   z1Async save requires async compatible CheckpointIO)finalize_fnzCheckpoint save for step z started at .r   z$Scheduled async checkpoint save for r   rB   )r   r   r#   
ValueErrorsave_original_optimizer_stater-   _save_checkpointsave_ema_modelrb   verboser   r   &_get_finalize_save_checkpoint_callbackr   r   checkpoint_ior   r   r   r+   appendr   r   timer   r   r$   r   )r2   r   r   r   r   r  r   r4   r'   r6   r    s>   



z$NeMoModelCheckpoint._save_checkpointr   c                    s    fdd}|S )zLCreates a callback that can be used to finalize async (and sync) ckpt saves.c               	      s   t d d   _ _jr!jD ]	} | t qj dd j	s-d S t 
d d  dt  d jsCJ jd	}t d
|  |D ]
}j|dd qSd S )Nz"Finalize callback called for step z, filepath Tr   zAsync checkpoint save for step z (z) finalized successfully at r   r   zCheckpoints to remove: )override_async)r   r0   _last_global_step_saved_last_checkpoint_savedr   loggersafter_save_checkpointr
   r   r#   r   r  r+   r,   _remove_checkpoint)loggerckpts_to_removeckpt_to_remover   r   r2   r   r'   r6   _cbL  s$   

zGNeMoModelCheckpoint._get_finalize_save_checkpoint_callback.<locals>._cbr'   )r2   r   r   r   r  r'   r  r6   r  G  s   z:NeMoModelCheckpoint._get_finalize_save_checkpoint_callbackc                    sv   | j r|s| jd | dS | j|dd t || | |}|dur2| |}t || | j|dd dS )a  Performs checkpoint removal or deferred removal.

        With async save, `self._remove_checkpoint` is called before the checkpoint
        is actually finished so we can't remove it. Instead we add it to
        `self.deferred_ckpts_to_remove` for future removal.
        r   NTr   r   )	r#   r+   r  r   r-   r  r   rb   r   )r2   r   r   r	  r   r4   r'   r6   r  k  s   
	

z&NeMoModelCheckpoint._remove_checkpointc                 C   s   | | jd| j S Nr   )replaceFILE_EXTENSIONr2   r   r'   r'   r6   rb        z(NeMoModelCheckpoint._ema_format_filepathrc   c                    s   t  fdd|D S )Nc                 3   s    | ]}  |V  qd S r7   r8   )r:   r   r<   r'   r6   r=     s    z5NeMoModelCheckpoint._has_ema_ckpts.<locals>.<genexpr>)any)r2   rc   r'   r<   r6   r]     r  z"NeMoModelCheckpoint._has_ema_ckptsc                 C   s   t |d| j S r  )rP   endswithr  r  r'   r'   r6   r9     r  z$NeMoModelCheckpoint._is_ema_filepathc                    s~   t  jrt }| j dS dd t jdD }|r)t fdd|S dd t jdD }t fd	d|S )
Nz/*.ckptc                 S   s   g | ]}|  r|qS r'   )r[   r   r'   r'   r6   r     s    z?NeMoModelCheckpoint._saved_checkpoint_paths.<locals>.<listcomp>*c                         |  S r7   rt   pr<   r'   r6   r         z=NeMoModelCheckpoint._saved_checkpoint_paths.<locals>.<lambda>c                 S   s   g | ]}|qS r'   r'   r:   fr'   r'   r6   r     s    *.ckptc                    r  r7   r  r  r<   r'   r6   r     r   )r   r   r   globr   filterrglob)r2   mscdist_checkpointscheckpoint_filesr'   r<   r6   rO     s   
z+NeMoModelCheckpoint._saved_checkpoint_pathscheckpoint_dirc           
      C   s"  t  stdt| r)t }||  dtj }t| }|D ]}|| qd S t	| } dd | dtj D }dd | 
dD }|D ]}t|}||v r_td|  t| qGdd | dD }|D ]}t|}||v rtd	|  t| ql|D ]}	t|	 qd S )
Nz8_remove_unfinished_checkpoints should run only on rank 0r  c                 S      h | ]
}|  r| qS r'   )is_fileresolver!  r'   r'   r6   	<setcomp>  s    zENeMoModelCheckpoint._remove_unfinished_checkpoints.<locals>.<setcomp>c                 S   s   h | ]}|  qS r'   )r-  r!  r'   r'   r6   r.    s    r#  z Removing unfinished checkpoint: c                 S   r+  r'   )r[   r-  r   r'   r'   r6   r.    r   z%Removing unfinished dist checkpoint: )r   AssertionErrorr   r   r$  r   r   r   r   r   r&  r   r   r(   rp   remover   r   )
r*  r'  existing_marker_filepathsfsckpt_filepathcheckpoint_filepathspossible_marker_pathall_dirpathsckpt_dirpathr   r'   r'   r6   r     s@   



z2NeMoModelCheckpoint._remove_unfinished_checkpointsz
pl.Trainerpreviouscurrentc                 C   s   ||krdS t |sdS t| }|jdurt|j nd}|dur8||kr8t|dr6|jdr6ndS | jdu rEt| j	 dt| j }||j
v S )a  Checks if the previous checkpoint should be deleted.
        A checkpoint won't be deleted if any of the cases apply:
        - The previous checkpoint is the same as the current checkpoint (means the old was already overwritten by new)
        - The previous checkpoint is not in the current checkpoint directory and the filesystem is local
        - The previous checkpoint is the checkpoint the Trainer resumed from and the filesystem is local
            and the resumed from checkpoint is not the last checkpoint
        FTNr@   z.dirpath is None.)r   r   absoluterm   rP   r  namer   r   r5   r   )r2   r   r8  r9  resume_pathr   r'   r'   r6   _should_remove_checkpoint  s   

z-NeMoModelCheckpoint._should_remove_checkpoint)FTFr   FNFr   r7   )rn   N)F)T)/__name__
__module____qualname____doc__r   rz   rP   r   r.   r1   r   r   r   r   r   r   r   r   r   r   r_   r   r   r	   r   r   r   r   r   staticmethodr   rt   r   r   r   r  r  r  rb   r   r]   r9   propertyrO   r   r=  __classcell__r'   r'   r4   r6   r   '   s    	0B2-# "+

)
%",r   )-rp   rT   r   r  copyr   pathlibr   typingr   r   r   r   r   r	   r   _weakrefr
   #lightning.fabric.utilities.cloud_ior   ,lightning.pytorch.callbacks.model_checkpointr   r   lightning.pytorch.trainerr   lightning.pytorch.utilitiesr   !nemo.collections.common.callbacksr   
nemo.utilsr   nemo.utils.app_stater   !nemo.utils.callbacks.dist_ckpt_ior   nemo.utils.get_rankr   nemo.utils.model_utilsr   r   r   nemo.utils.msc_utilsr   r   r   r'   r'   r'   r6   <module>   s*    