o
    zii                     @   s0  d Z ddlZddlZddlZddlZddlZddlZddlmZ ddl	m
Z
 ddlmZ ddlmZmZmZmZmZmZ ddlmZ ddlZddlZddlmZ dd	lmZ ddlZdd
lmZmZm Z  ddl!m"Z" ddl#m$Z$ ddl%m&Z& ddl'm(Z(m)Z)m*Z* ddl+m,Z, e-e.Z/e( Z0G dd de$Z1dS )z`
Model Checkpointing
===================

Automatically save model checkpoints during training.
    N)deepcopy)	timedelta)Path)AnyDictLiteralOptionalSetUnion)proxy)Tensor)override)_is_dir_is_local_file_protocolget_filesystem)_PATH)
Checkpoint)MisconfigurationException)WarningCacherank_zero_inforank_zero_warn)STEP_OUTPUTc                       s2  e Zd ZdZdZdZdZdZdZ										
					
dtde	e
 de	e de	e dede	eeed f  dedededede	e de	e de	e de	e def fddZeedefddZedd d!d"d#eddfd$d%Zedud&d'Zedd d!d"d(ed)ed*eddfd+d,Zedud-d.Zedud/d0Zedeeef fd1d2Zed3eeef ddfd4d5Zdd d6eeef ddfd7d8Z dd d9eddfd:d;Z!e"dd d9ed<eddfd=d>Z#dd defd?d@Z$dd defdAdBZ%dvdCdDZ&de	e
 de	e ddfdEdFZ'deddfdGdHZ(de	e de	e de	e ddfdIdJZ)ede	e fdKdLZ*dwdd dMe	e defdNdOZ+	P	
dxde	e dQeeef dRededef
dSdTZ,	dydQeeef de	e dUe	e defdVdWZ-dd de
fdXdYZ.dd de/e fdZd[Z0de
ddfd\d]Z1	dwd6eeef dd d^e	e defd_d`Z2dd deeef fdadbZ3dd d6eeef ddfdcddZ4dd d6eeef ddfdedfZ5dd d6eeef ddfdgdhZ6dMedd d6eeef ddfdidjZ7dwd9e	e
 ddfdkdlZ8d9e
dd defdmdnZ9dd doedMedefdpdqZ:dd d9eddfdrdsZ;  Z<S )zModelCheckpointaR$  Save the model periodically by monitoring a quantity. Every metric logged with
    :meth:`~pytorch_lightning.core.LightningModule.log` or :meth:`~pytorch_lightning.core.LightningModule.log_dict` is
    a candidate for the monitor key. For more information, see :ref:`checkpointing`.

    After training finishes, use :attr:`best_model_path` to retrieve the path to the
    best checkpoint file and :attr:`best_model_score` to retrieve its score.

    Args:
        dirpath: directory to save the model file.

            Example::

                # custom path
                # saves a file like: my/path/epoch=0-step=10.ckpt
                >>> checkpoint_callback = ModelCheckpoint(dirpath='my/path/')

            By default, dirpath is ``None`` and will be set at runtime to the location
            specified by :class:`~pytorch_lightning.trainer.trainer.Trainer`'s
            :paramref:`~pytorch_lightning.trainer.trainer.Trainer.default_root_dir` argument,
            and if the Trainer uses a logger, the path will also contain logger name and version.

        filename: checkpoint filename. Can contain named formatting options to be auto-filled.

            Example::

                # save any arbitrary metrics like `val_loss`, etc. in name
                # saves a file like: my/path/epoch=2-val_loss=0.02-other_metric=0.03.ckpt
                >>> checkpoint_callback = ModelCheckpoint(
                ...     dirpath='my/path',
                ...     filename='{epoch}-{val_loss:.2f}-{other_metric:.2f}'
                ... )

            By default, filename is ``None`` and will be set to ``'{epoch}-{step}'``, where "epoch" and "step" match
            the number of finished epoch and optimizer steps respectively.
        monitor: quantity to monitor. By default it is ``None`` which saves a checkpoint only for the last epoch.
        verbose: verbosity mode. Default: ``False``.
        save_last: When ``True``, saves a `last.ckpt` copy whenever a checkpoint file gets saved. Can be set to
            ``'link'`` on a local filesystem to create a symbolic link. This allows accessing the latest checkpoint
            in a deterministic manner. Default: ``None``.
        save_top_k: if ``save_top_k == k``,
            the best k models according to the quantity monitored will be saved.
            If ``save_top_k == 0``, no models are saved.
            If ``save_top_k == -1``, all models are saved.
            Please note that the monitors are checked every ``every_n_epochs`` epochs.
            If ``save_top_k >= 2`` and the callback is called multiple times inside an epoch, and the filename remains
            unchanged, the name of the saved file will be appended with a version count starting with ``v1`` to avoid
            collisions unless ``enable_version_counter`` is set to False. The version counter is unrelated to the top-k
            ranking of the checkpoint, and we recommend formatting the filename to include the monitored metric to avoid
            collisions.
        mode: one of {min, max}.
            If ``save_top_k != 0``, the decision to overwrite the current save file is made
            based on either the maximization or the minimization of the monitored quantity.
            For ``'val_acc'``, this should be ``'max'``, for ``'val_loss'`` this should be ``'min'``, etc.
        auto_insert_metric_name: When ``True``, the checkpoints filenames will contain the metric name.
            For example, ``filename='checkpoint_{epoch:02d}-{acc:02.0f}`` with epoch ``1`` and acc ``1.12`` will resolve
            to ``checkpoint_epoch=01-acc=01.ckpt``. Is useful to set it to ``False`` when metric names contain ``/``
            as this will result in extra folders.
            For example, ``filename='epoch={epoch}-step={step}-val_acc={val/acc:.2f}', auto_insert_metric_name=False``
        save_weights_only: if ``True``, then only the model's weights will be
            saved. Otherwise, the optimizer states, lr-scheduler states, etc are added in the checkpoint too.
        every_n_train_steps: Number of training steps between checkpoints.
            If ``every_n_train_steps == None or every_n_train_steps == 0``, we skip saving during training.
            To disable, set ``every_n_train_steps = 0``. This value must be ``None`` or non-negative.
            This must be mutually exclusive with ``train_time_interval`` and ``every_n_epochs``.
        train_time_interval: Checkpoints are monitored at the specified time interval.
            For all practical purposes, this cannot be smaller than the amount
            of time it takes to process a single training batch. This is not
            guaranteed to execute at the exact time specified, but should be close.
            This must be mutually exclusive with ``every_n_train_steps`` and ``every_n_epochs``.
        every_n_epochs: Number of epochs between checkpoints.
            This value must be ``None`` or non-negative.
            To disable saving top-k checkpoints, set ``every_n_epochs = 0``.
            This argument does not impact the saving of ``save_last=True`` checkpoints.
            If all of ``every_n_epochs``, ``every_n_train_steps`` and
            ``train_time_interval`` are ``None``, we save a checkpoint at the end of every epoch
            (equivalent to ``every_n_epochs = 1``).
            If ``every_n_epochs == None`` and either ``every_n_train_steps != None`` or ``train_time_interval != None``,
            saving at the end of each epoch is disabled
            (equivalent to ``every_n_epochs = 0``).
            This must be mutually exclusive with ``every_n_train_steps`` and ``train_time_interval``.
            Setting both ``ModelCheckpoint(..., every_n_epochs=V, save_on_train_epoch_end=False)`` and
            ``Trainer(max_epochs=N, check_val_every_n_epoch=M)``
            will only save checkpoints at epochs 0 < E <= N
            where both values for ``every_n_epochs`` and ``check_val_every_n_epoch`` evenly divide E.
        save_on_train_epoch_end: Whether to run checkpointing at the end of the training epoch.
            If this is ``False``, then the check runs at the end of the validation.
        enable_version_counter: Whether to append a version to the existing file name.
            If this is ``False``, then the checkpoint files will be overwritten.

    Note:
        For extra customization, ModelCheckpoint includes the following attributes:

        - ``CHECKPOINT_JOIN_CHAR = "-"``
        - ``CHECKPOINT_EQUALS_CHAR = "="``
        - ``CHECKPOINT_NAME_LAST = "last"``
        - ``FILE_EXTENSION = ".ckpt"``
        - ``STARTING_VERSION = 1``

        For example, you can change the default last checkpoint name by doing
        ``checkpoint_callback.CHECKPOINT_NAME_LAST = "{epoch}-last"``

        If you want to checkpoint every N hours, every M train batches, and/or every K val epochs,
        then you should create multiple ``ModelCheckpoint`` callbacks.

        If the checkpoint's ``dirpath`` changed from what it was before while resuming the training,
        only ``best_model_path`` will be reloaded and a warning will be issued.

    Raises:
        MisconfigurationException:
            If ``save_top_k`` is smaller than ``-1``,
            if ``monitor`` is ``None`` and ``save_top_k`` is none of ``None``, ``-1``, and ``0``, or
            if ``mode`` is none of ``"min"`` or ``"max"``.
        ValueError:
            If ``trainer.save_checkpoint`` is ``None``.

    Example::

        >>> from pytorch_lightning import Trainer
        >>> from pytorch_lightning.callbacks import ModelCheckpoint

        # saves checkpoints to 'my/path/' at every epoch
        >>> checkpoint_callback = ModelCheckpoint(dirpath='my/path/')
        >>> trainer = Trainer(callbacks=[checkpoint_callback])

        # save epoch and val_loss in name
        # saves a file like: my/path/sample-mnist-epoch=02-val_loss=0.32.ckpt
        >>> checkpoint_callback = ModelCheckpoint(
        ...     monitor='val_loss',
        ...     dirpath='my/path/',
        ...     filename='sample-mnist-{epoch:02d}-{val_loss:.2f}'
        ... )

        # save epoch and val_loss in name, but specify the formatting yourself (e.g. to avoid problems with Tensorboard
        # or Neptune, due to the presence of characters like '=' or '/')
        # saves a file like: my/path/sample-mnist-epoch02-val_loss0.32.ckpt
        >>> checkpoint_callback = ModelCheckpoint(
        ...     monitor='val/loss',
        ...     dirpath='my/path/',
        ...     filename='sample-mnist-epoch{epoch:02d}-val_loss{val/loss:.2f}',
        ...     auto_insert_metric_name=False
        ... )

        # retrieve the best checkpoint after training
        checkpoint_callback = ModelCheckpoint(dirpath='my/path/')
        trainer = Trainer(callbacks=[checkpoint_callback])
        model = ...
        trainer.fit(model)
        checkpoint_callback.best_model_path

    .. tip:: Saving and restoring multiple checkpoint callbacks at the same time is supported under variation in the
        following arguments:

        *monitor, mode, every_n_train_steps, every_n_epochs, train_time_interval*

        Read more: :ref:`Persisting Callback State <extensions/callbacks_state:save callback state>`

    -=lastz.ckpt   NFminTdirpathfilenamemonitorverbose	save_lastlink
save_top_ksave_weights_onlymodeauto_insert_metric_nameevery_n_train_stepstrain_time_intervalevery_n_epochssave_on_train_epoch_endenable_version_counterc                    s   t    || _|| _|| _|| _|| _|	| _|| _|| _	d| _
d | _d | _i | _d| _d | _d| _d| _d| _|  |  | | | || | |
|| |   d S )Nr    )super__init__r    r!   r"   r$   r%   r'   _save_on_train_epoch_end_enable_version_counter_last_global_step_saved_last_time_checkedcurrent_scorebest_k_modelskth_best_model_pathbest_model_scorebest_model_pathlast_model_path_last_checkpoint_saved#_ModelCheckpoint__init_monitor_mode_ModelCheckpoint__init_ckpt_dir_ModelCheckpoint__init_triggers-_ModelCheckpoint__validate_init_configuration)selfr   r   r    r!   r"   r$   r%   r&   r'   r(   r)   r*   r+   r,   	__class__ `/home/ubuntu/.local/lib/python3.10/site-packages/pytorch_lightning/callbacks/model_checkpoint.pyr/      s0   

zModelCheckpoint.__init__returnc                 C   s   | j | j| j| j| j| jdS )N)r    r&   r(   r*   r)   )_generate_state_keyr    r&   _every_n_train_steps_every_n_epochs_train_time_intervalr?   rB   rB   rC   	state_key  s   zModelCheckpoint.state_keytrainer
pl.Trainer	pl_modulepl.LightningModulestagec                 C   sr   |  |}|j|}|| _t| jpd| _|jr#|dkr#| | j | jdkr5t	| js7t
d| dd S d S )Nr-   fitr#   zY`ModelCheckpoint(save_last='link')` is only supported for local file paths, got `dirpath=z`.)"_ModelCheckpoint__resolve_ckpt_dirstrategy	broadcastr   r   _fsis_global_zero'_ModelCheckpoint__warn_if_dir_not_emptyr"   r   
ValueError)r?   rK   rM   rO   r   rB   rB   rC   setup  s   

zModelCheckpoint.setupc                 C   s   t  | _d S N)time	monotonicr3   )r?   rK   rM   rB   rB   rC   on_train_start  s   zModelCheckpoint.on_train_startoutputsbatch	batch_idxc                 C   s   |  |rdS | jdk p|j| j dk}| j}d}t }	|r4| j}
|
du p-|	|
 | k }|j	|}|r:|r:dS |s?|	| _| 
|}| || | || dS )zTSave checkpoint on train batch end if we meet the criteria for `every_n_train_steps`Nr   r   T)_should_skip_saving_checkpointrF   global_steprH   rZ   r[   r3   total_secondsrR   rS   _monitor_candidates_save_topk_checkpoint_save_last_checkpoint)r?   rK   rM   r]   r^   r_   
skip_batchr)   	skip_timenowprev_time_checkmonitor_candidatesrB   rB   rC   on_train_batch_end  s"   


z"ModelCheckpoint.on_train_batch_endc                 C   s`   |  |s,| |r.| |}| jdkr$|jd | j dkr$| || | || dS dS dS )z3Save a checkpoint at the end of the training epoch.r   r   Nr`   _should_save_on_train_epoch_endrc   rG   current_epochrd   re   r?   rK   rM   rj   rB   rB   rC   on_train_epoch_end?     
z"ModelCheckpoint.on_train_epoch_endc                 C   s`   |  |s,| |s.| |}| jdkr$|jd | j dkr$| || | || dS dS dS )z5Save a checkpoint at the end of the validation stage.r   r   Nrl   ro   rB   rB   rC   on_validation_endH  rq   z!ModelCheckpoint.on_validation_endc              
   C   s*   | j | j| j| j| j| j| j| j| jd	S )N	r    r7   r8   r4   r   r5   r6   	kth_valuer9   rs   rI   rB   rB   rC   
state_dictQ  s   zModelCheckpoint.state_dictru   c                 C   s   | d| j}| j|kr2|d | _| d| j| _| d| j| _| d| j| _| d| j| _ntd|d| jd	 |d
 | _	d S )Nr   r7   r6   rt   r5   r9   zThe dirpath has changed from z to z, therefore `best_model_score`, `kth_best_model_path`, `kth_value`, `last_model_path` and `best_k_models` won't be reloaded. Only `best_model_path` will be reloaded.r8   )
getr   r7   r6   rt   r5   r9   warningswarnr8   )r?   ru   dirpath_from_ckptrB   rB   rC   load_state_dict_  s   

zModelCheckpoint.load_state_dictrj   c                 C   s   | j dkrd S | jd ur8| j|vr0d| jdt| d| jd}|jjjjr+t|t	| | 
|| d S | || d S )Nr   z`ModelCheckpoint(monitor=z=)` could not find the monitored key in the returned metrics: z. HINT: Did you call `log(z#, value)` in the `LightningModule`?)r$   r    listfit_loop
epoch_loopval_loop_has_runr   warning_cacherx   _save_monitor_checkpoint_save_none_monitor_checkpoint)r?   rK   rj   mrB   rB   rC   rd   r  s   




z%ModelCheckpoint._save_topk_checkpointfilepathc                 C   sD   | || j |j| _|| _|jr|jD ]}|t|  qd S d S rY   )	save_checkpointr%   ra   r2   r:   rU   loggersafter_save_checkpointr   )r?   rK   r   loggerrB   rB   rC   _save_checkpoint  s   
z ModelCheckpoint._save_checkpointlinkpathc                 C   s   | j rAtj|stj|rt| ntj|r t| zt	tj
|tj|| W n ty@   t|| Y nw | j  d S rY   )rU   ospathislinkisfileremoveisdirshutilrmtreesymlinkrelpathdirnameOSErrorcopyrR   barrier)rK   r   r   rB   rB   rC   _link_checkpoint  s   
"z ModelCheckpoint._link_checkpointc                 C   s6   ddl m} t|jp|jj|jkp|jp| j|j	kS )Nr   )	TrainerFn)
 pytorch_lightning.trainer.statesr   boolfast_dev_runstatefnFITTINGsanity_checkingr2   ra   )r?   rK   r   rB   rB   rC   r`     s   

z.ModelCheckpoint._should_skip_saving_checkpointc                 C   sP   | j d ur| j S |jdkrdS t|jtrt|jn|j}|dkr#dS |jdkS )Nr   Fr   Tg      ?)r0   check_val_every_n_epoch
isinstancenum_val_batchesr{   sumval_check_interval)r?   rK   r   rB   rB   rC   rm     s   


z/ModelCheckpoint._should_save_on_train_epoch_endc                 C   s   | j dk rtd| j  d| jdk rtd| j d| jdk r*td| j d| jdk}| jdk}| jd u}|| | dkrRtd	| j d
| j d| j d| jd u re| j dvrgtd| j  dd S d S )NzInvalid value for save_top_k=z. Must be >= -1r   z&Invalid value for every_n_train_steps=z. Must be >= 0z!Invalid value for every_n_epochs=r   z.Combination of parameters every_n_train_steps=z, every_n_epochs=z and train_time_interval=z should be mutually exclusive.)r   r   r   zModelCheckpoint(save_top_k=zM, monitor=None) is not a valid configuration. No quantity for top_k to track.)r$   r   rF   rG   rH   r    )r?   every_n_train_steps_triggeredevery_n_epochs_triggeredtrain_time_interval_triggeredrB   rB   rC   __validate_init_configuration  s0   






z-ModelCheckpoint.__validate_init_configurationc                 C   sJ   t |r|nd| _|rt|r|ndrtjtj|}|| _|| _d S )Nr-   )	r   rT   r   r   r   realpath
expanduserr   r   )r?   r   r   rB   rB   rC   __init_ckpt_dir  s
   
zModelCheckpoint.__init_ckpt_dirc                 C   sZ   t t j}|df| dfd}||vr#tdd|  d| || \| _| _d S )Nr   maxr   r   z`mode` can be z, z	 but got )torchtensorinfr   joinkeysrt   r&   )r?   r&   	torch_inf	mode_dictrB   rB   rC   __init_monitor_mode  s
   z#ModelCheckpoint.__init_monitor_modec                 C   sR   |d u r|d u r|d u rd}d}t d n|pd}|pd}|| _|| _|| _d S )Nr   r   zQBoth every_n_train_steps and every_n_epochs are not set. Setting every_n_epochs=1)logdebugrH   rG   rF   )r?   r(   r*   r)   rB   rB   rC   __init_triggers  s   
zModelCheckpoint.__init_triggersc                 C   s   | j S rY   )rG   rI   rB   rB   rC   r*     s   zModelCheckpoint.every_n_epochscurrentc                 C   sl   |d u rdS | j dkrdS t| j| j k }|rdS tjtjd| j }||| j| j }|j	t
|}|S )NFr   Tr   )r$   lenr5   r   ltgtr&   r6   rR   reduce_boolean_decisionr   )r?   rK   r   less_than_k_models
monitor_opshould_update_best_and_saverB   rB   rC   check_monitor_top_k  s   
z#ModelCheckpoint.check_monitor_top_kr-   metricsprefixc                 C   s   |s	d| j  d }td|}t|dd dd}|D ],}|dd  }|r1|||| j d	 | }||d
| d}||vrFtd||< q||}|rV| j 	||g}|S )Nz{epoch}z{step}z(\{.*?)[:\}]c                 S   s   t | S rY   )r   )xrB   rB   rC   <lambda>%  s    z9ModelCheckpoint._format_checkpoint_name.<locals>.<lambda>T)keyreverser   {z{0[]r   )
CHECKPOINT_JOIN_CHARrefindallsortedreplaceCHECKPOINT_EQUALS_CHARr   r   formatr   )r?   r   r   r   r'   groupsgroupnamerB   rB   rC   _format_checkpoint_name  s    
z'ModelCheckpoint._format_checkpoint_nameverc                 C   sb   |p| j }| j||| jd}|dur| j|d| f}| | j }| jr/tj| j|S |S )a  Generate a filename according to the defined template.

        Example::

            >>> tmpdir = os.path.dirname(__file__)
            >>> ckpt = ModelCheckpoint(dirpath=tmpdir, filename='{epoch}')
            >>> os.path.basename(ckpt.format_checkpoint_name(dict(epoch=0)))
            'epoch=0.ckpt'
            >>> ckpt = ModelCheckpoint(dirpath=tmpdir, filename='{epoch:03d}')
            >>> os.path.basename(ckpt.format_checkpoint_name(dict(epoch=5)))
            'epoch=005.ckpt'
            >>> ckpt = ModelCheckpoint(dirpath=tmpdir, filename='{epoch}-{val_loss:.2f}')
            >>> os.path.basename(ckpt.format_checkpoint_name(dict(epoch=2, val_loss=0.123456)))
            'epoch=2-val_loss=0.12.ckpt'
            >>> os.path.basename(ckpt.format_checkpoint_name(dict(epoch=2, val_loss=0.12), filename='{epoch:d}'))
            'epoch=2.ckpt'
            >>> ckpt = ModelCheckpoint(dirpath=tmpdir,
            ... filename='epoch={epoch}-validation_loss={val_loss:.2f}',
            ... auto_insert_metric_name=False)
            >>> os.path.basename(ckpt.format_checkpoint_name(dict(epoch=2, val_loss=0.123456)))
            'epoch=2-validation_loss=0.12.ckpt'
            >>> ckpt = ModelCheckpoint(dirpath=tmpdir, filename='{missing:d}')
            >>> os.path.basename(ckpt.format_checkpoint_name({}))
            'missing=0.ckpt'
            >>> ckpt = ModelCheckpoint(filename='{step}')
            >>> os.path.basename(ckpt.format_checkpoint_name(dict(step=0)))
            'step=0.ckpt'

        )r'   Nv)	r   r   r'   r   r   FILE_EXTENSIONr   r   r   )r?   r   r   r   	ckpt_namerB   rB   rC   format_checkpoint_name9  s   
 z&ModelCheckpoint.format_checkpoint_namec                 C   s   | j dur| j S t|jdkrF|jd jdur|jd j}n|j}|jd j}|jd j}t|tr4|nd| }t	j
|t||d}|S t	j
|jd}|S )a  Determines model checkpoint save directory at runtime. Reference attributes from the trainer's logger to
        determine where to save checkpoints. The path for saving weights is set in this priority:

        1.  The ``ModelCheckpoint``'s ``dirpath`` if passed in
        2.  The ``Logger``'s ``log_dir`` if the trainer has loggers
        3.  The ``Trainer``'s ``default_root_dir`` if the trainer has no loggers

        The path gets extended with subdirectory "checkpoints".

        Nr   version_checkpoints)r   r   r   save_dirdefault_root_dirr   versionr   strr   r   r   )r?   rK   r   r   r   	ckpt_pathrB   rB   rC   __resolve_ckpt_dirb  s   
z"ModelCheckpoint.__resolve_ckpt_dirc                    s`    |}dj ddtdtffdd j|r- fddjj|d	d
D S t S )N^z	(-(\d+))?r   rD   c                    s   | j jkott | jS rY   )suffixr   r   r   matchstem)r   )last_patternr?   rB   rC   _is_last  s   z8ModelCheckpoint._find_last_checkpoints.<locals>._is_lastc                    s$   h | ]} t |rtj|qS rB   )r   r   r   normpath).0p)r   rB   rC   	<setcomp>  s   $ z9ModelCheckpoint._find_last_checkpoints.<locals>.<setcomp>F)detail)rQ   CHECKPOINT_NAME_LASTr   r   rT   existslsset)r?   rK   r   rB   )r   r   r?   rC   _find_last_checkpoints  s   
z&ModelCheckpoint._find_last_checkpointsc                 C   sN   | j dkr!t| j|ddr#t| j|dkr%td| d d S d S d S d S )Nr   T)strictzCheckpoint directory z exists and is not empty.)r$   r   rT   r   r   r   )r?   r   rB   rB   rC   __warn_if_dir_not_empty  s   .z'ModelCheckpoint.__warn_if_dir_not_emptydel_filepathc                 C   sX   |  |}| jr*| j}| ||r*||kr*| j ||d}|d7 }| ||r*||ks|S )Nr   r   )r   r1   STARTING_VERSIONfile_exists)r?   rj   rK   r   r   version_cntrB   rB   rC   &_get_metric_interpolated_filepath_name  s   
z6ModelCheckpoint._get_metric_interpolated_filepath_namec                 C   sf   t |j}|d}t|tr| nt|j|d< |d}t|tr)| nt|j	|d< |S )Nepochstep)
r   callback_metricsrv   r   r   intr   r   rn   ra   )r?   rK   rj   r   r   rB   rB   rC   rc     s   

"
"z#ModelCheckpoint._monitor_candidatesc                 C   s   | j sd S | || j}| jr5| j}| ||r5|| jkr5| j|| j|d}|d7 }| ||r5|| jks| j|}| _| j dkrR| jrR| jdkrR| 	|| j| n| 
|| |ri| |||rk| || d S d S d S )Nr   r   r#   r   )r"   r   r   r1   r   r   r9   r:   r$   r   r   _should_remove_checkpoint_remove_checkpoint)r?   rK   rj   r   r   previousrB   rB   rC   re     s    z%ModelCheckpoint._save_last_checkpointc              	   C   s   | j sJ || j }| ||r |d usJ | ||| d S | jrA|d }|d }td|dd|dd| j d| j  d S d S )Nr   r   Epoch d, global step : z was not in top )r    rv   r   _update_best_and_saver!   r   r$   )r?   rK   rj   r   r   r   rB   rB   rC   r     s   
,z(ModelCheckpoint._save_monitor_checkpointc                 C   sb   |  ||| j}| j|}| _| || | jdkr+|r-| |||r/| || d S d S d S d S )Nr   )r   r8   r   r$   r  r  )r?   rK   rj   r   r  rB   rB   rC   r     s   z-ModelCheckpoint._save_none_monitor_checkpointc           
      C   s  | j dkrt| jd n| j }d }t| j|kr%|dkr%| j}| j| t|tr@t|r@tj	t
| jdkr9dnd|jd}| |||}|| _|| j|< t| j|krp| jdkr]tnt}|| j| jjd| _| j| j | _| jdkrwtnt}|| j| jjd| _| j| j | _| jr|d	 }|d
 }	td|dd|	dd| jd|dd| jdd|d|  | || |r| |||r| || d S d S d S )Nr   r   r   r   r   z-inf)device)r   r   r   r  r  r  r  z	 reached z0.5fz (best z), saving model to z as top )r$   r   r5   r6   popr   r   r   isnanr   floatr&   r	  r   r4   r   r   rv   rt   r8   r7   r!   r   r    r   r  r  )
r?   r   rK   rj   kr   r   _opr   r   rB   rB   rC   r    sB   "
"z%ModelCheckpoint._update_best_and_savec                 C   sv   dd | j  D }|du r| jsJ tj| jd}| j|d}t	|| W d   dS 1 s4w   Y  dS )ztSaves the `best_k_models` dict containing the checkpoint paths with the corresponding scores to a YAML
        file.c                 S   s   i | ]	\}}||  qS rB   )item)r   r  r   rB   rB   rC   
<dictcomp>  s    z+ModelCheckpoint.to_yaml.<locals>.<dictcomp>Nzbest_k_models.yamlw)
r5   itemsr   r   r   r   rT   openyamldump)r?   r   best_kfprB   rB   rC   to_yaml  s   
"zModelCheckpoint.to_yamlc                 C   s   | j |}|j|S )zChecks if a file exists on rank 0 and broadcasts the result to all other ranks, preventing the internal
        state to diverge between ranks.)rT   r   rR   rS   )r?   r   rK   r   rB   rB   rC   r     s   zModelCheckpoint.file_existsr  c                 C   sz   ||krdS t |sdS t| }|jdurt|j nd}|dur*||kr*dS | jdus1J t| j }||jv S )a  Checks if the previous checkpoint should be deleted.

        A checkpoint won't be deleted if any of the cases apply:
        - The previous checkpoint is the same as the current checkpoint (means the old was already overwritten by new)
        - The previous checkpoint is not in the current checkpoint directory and the filesystem is local
        - The previous checkpoint is the checkpoint the Trainer resumed from and the filesystem is local

        FTN)r   r   absoluter   r   parents)r?   rK   r  r   resume_pathr   rB   rB   rC   r  
  s   	
z)ModelCheckpoint._should_remove_checkpointc                 C   s   |j | dS )z1Calls the strategy to remove the checkpoint file.N)rR   remove_checkpoint)r?   rK   r   rB   rB   rC   r    s   z"ModelCheckpoint._remove_checkpoint)NNNFNr   Fr   TNNNNT)rK   rL   rM   rN   rD   N)rD   NrY   )r-   T)NN)=__name__
__module____qualname____doc__r   r   r   r   r   r   r   r   r   r
   r   r   r   r/   propertyr   rJ   rX   r\   r   r   rk   rp   rr   r   ru   rz   r   rd   r   staticmethodr   r`   rm   r>   r<   r;   r=   r*   r   r   r   rQ   r	   r   rV   r   rc   re   r   r   r  r  r   r  r  __classcell__rB   rB   r@   rC   r   2   s4    	
+	 

		


%

)


	

*
r   )2r   loggingr   r   r   rZ   rw   r   r   datetimer   pathlibr   typingr   r   r   r   r	   r
   weakrefr   r   r  r   typing_extensionsr   pytorch_lightningpl#lightning_fabric.utilities.cloud_ior   r   r    lightning_fabric.utilities.typesr   pytorch_lightning.callbacksr   &pytorch_lightning.utilities.exceptionsr   %pytorch_lightning.utilities.rank_zeror   r   r   !pytorch_lightning.utilities.typesr   	getLoggerr  r   r   r   rB   rB   rB   rC   <module>   s4    
