o
    oiZ                     @   s0  d Z ddlZddlZddlZddlmZ ddlmZ ddlm	Z	m
Z
mZmZmZmZmZmZmZ ddlmZ ddlmZ ddlmZ ddlmZ dd	lmZmZmZ dd
lm Z  ddl!m"Z"m#Z# ddl$m%Z% ddl&m'Z' e	rzddl(m)Z) ddl*m+Z+ e,e-Z.edZ/dZ0dedefddZ1G dd de"Z2dS )z
Neptune Logger
--------------
    N)	Namespace)wraps)	TYPE_CHECKINGAnyCallableDict	GeneratorListOptionalSetUnion)RequirementCache)Tensor)override)_add_prefix_convert_params_sanitize_callable_params)
Checkpoint)Loggerrank_zero_experiment)ModelSummary)rank_zero_onlyRunHandlerzneptune>=1.0z*source_code/integrations/pytorch-lightningfuncreturnc                    s&   t  dtdtdtf fdd}|S )Nargskwargsr   c                     sH   ddl m} t|  | i |W  d    S 1 sw   Y  d S )Nr   )InactiveRunException)neptune.exceptionsr    
contextlibsuppress)r   r   r    r    U/home/ubuntu/.local/lib/python3.10/site-packages/lightning/pytorch/loggers/neptune.pywrapper7   s   $z _catch_inactive.<locals>.wrapper)r   r   )r   r'   r%   r$   r&   _catch_inactive6   s   r(   c                       s  e Zd ZdZdZdZdZdddddddd	ee d
ee dee dee	d  dee
 dedef fddZdKddZedefddZdedefddZed	ee d
ee dee dee	d  deddfddZdeeef fddZdeeef ddfd d!ZeedLd#d$ZeedLd%d&Zeeed'e	eeef ef ddfd(d)ZeeedMd*eee	ee f f d+ee! ddfd,d-Z"eeed.eddf fd/d0Z#eedee fd1d2Z$eedNd4d5d6e!ddfd7d8Z%eeed9e&ddfd:d;Z'ed<ed9e&defd=d>Z(e)d?eeef d@ede*e fdAdBZ+e)dMdCeeef dDee de,fdEdFZ-eedee fdGdHZ.eedee fdIdJZ/  Z0S )ONeptuneLoggera  Log using `Neptune <https://docs.neptune.ai/integrations/lightning/>`_.

    Install it with pip:

    .. code-block:: bash

        pip install neptune

    or conda:

    .. code-block:: bash

        conda install -c conda-forge neptune-client

    **Quickstart**

    Pass a NeptuneLogger instance to the Trainer to log metadata with Neptune:

    .. code-block:: python


        from lightning.pytorch import Trainer
        from lightning.pytorch.loggers import NeptuneLogger
        import neptune

        neptune_logger = NeptuneLogger(
            api_key=neptune.ANONYMOUS_API_TOKEN,  # replace with your own
            project="common/pytorch-lightning-integration",  # format "workspace-name/project-name"
            tags=["training", "resnet"],  # optional
        )
        trainer = Trainer(max_epochs=10, logger=neptune_logger)

    **How to use NeptuneLogger?**

    Use the logger anywhere in your :class:`~lightning.pytorch.core.LightningModule` as follows:

    .. code-block:: python

        from neptune.types import File
        from lightning.pytorch import LightningModule


        class LitModel(LightningModule):
            def training_step(self, batch, batch_idx):
                # log metrics
                acc = ...
                self.append("train/loss", loss)

            def any_lightning_module_function_or_hook(self):
                # log images
                img = ...
                self.logger.experiment["train/misclassified_images"].append(File.as_image(img))

                # generic recipe
                metadata = ...
                self.logger.experiment["your/metadata/structure"] = metadata

    Note that the syntax ``self.logger.experiment["your/metadata/structure"].append(metadata)`` is specific to
    Neptune and extends the logger capabilities. It lets you log various types of metadata, such as
    scores, files, images, interactive visuals, and CSVs.
    Refer to the `Neptune docs <https://docs.neptune.ai/logging/methods>`_
    for details.
    You can also use the regular logger methods ``log_metrics()``, and ``log_hyperparams()`` with NeptuneLogger.

    **Log after fitting or testing is finished**

    You can log objects after the fitting or testing methods are finished:

    .. code-block:: python

        neptune_logger = NeptuneLogger(project="common/pytorch-lightning-integration")

        trainer = pl.Trainer(logger=neptune_logger)
        model = ...
        datamodule = ...
        trainer.fit(model, datamodule=datamodule)
        trainer.test(model, datamodule=datamodule)

        # Log objects after `fit` or `test` methods
        # model summary
        neptune_logger.log_model_summary(model=model, max_depth=-1)

        # generic recipe
        metadata = ...
        neptune_logger.experiment["your/metadata/structure"] = metadata

    **Log model checkpoints**

    If you have :class:`~lightning.pytorch.callbacks.ModelCheckpoint` configured,
    the Neptune logger automatically logs model checkpoints.
    Model weights will be uploaded to the "model/checkpoints" namespace in the Neptune run.
    You can disable this option with:

    .. code-block:: python

        neptune_logger = NeptuneLogger(log_model_checkpoints=False)

    **Pass additional parameters to the Neptune run**

    You can also pass ``neptune_run_kwargs`` to add details to the run, like ``tags`` or ``description``:

    .. testcode::
        :skipif: not _NEPTUNE_AVAILABLE

        from lightning.pytorch import Trainer
        from lightning.pytorch.loggers import NeptuneLogger

        neptune_logger = NeptuneLogger(
            project="common/pytorch-lightning-integration",
            name="lightning-run",
            description="mlp quick run with pytorch-lightning",
            tags=["mlp", "quick-run"],
        )
        trainer = Trainer(max_epochs=3, logger=neptune_logger)

    Check `run documentation <https://docs.neptune.ai/api/neptune/#init_run>`_
    for more info about additional run parameters.

    **Details about Neptune run structure**

    Runs can be viewed as nested dictionary-like structures that you can define in your code.
    Thanks to this you can easily organize your metadata in a way that is most convenient for you.

    The hierarchical structure that you apply to your metadata is reflected in the Neptune web app.

    See also:
        - Read about
          `what objects you can log to Neptune <https://docs.neptune.ai/logging/what_you_can_log/>`_.
        - Check out an `example run <https://app.neptune.ai/o/common/org/pytorch-lightning-integration/e/PTL-1/all>`_
          with multiple types of metadata logged.
        - For more detailed examples, see the
          `user guide <https://docs.neptune.ai/integrations/lightning/>`_.

    Args:
        api_key: Optional.
            Neptune API token, found on https://www.neptune.ai upon registration.
            You should save your token to the `NEPTUNE_API_TOKEN`
            environment variable and leave the api_key argument out of your code.
            Instructions: `Setting your API token <https://docs.neptune.ai/setup/setting_api_token/>`_.
        project: Optional.
            Name of a project in the form "workspace-name/project-name", for example "tom/mask-rcnn".
            If ``None``, the value of `NEPTUNE_PROJECT` environment variable is used.
            You need to create the project on https://www.neptune.ai first.
        name: Optional. Editable name of the run.
            The run name is displayed in the Neptune web app.
        run: Optional. Default is ``None``. A Neptune ``Run`` object.
            If specified, this existing run will be used for logging, instead of a new run being created.
            You can also pass a namespace handler object; for example, ``run["test"]``, in which case all
            metadata is logged under the "test" namespace inside the run.
        log_model_checkpoints: Optional. Default is ``True``. Log model checkpoint to Neptune.
            Works only if ``ModelCheckpoint`` is passed to the ``Trainer``.
        prefix: Optional. Default is ``"training"``. Root namespace for all metadata logging.
        \**neptune_run_kwargs: Additional arguments like ``tags``, ``description``, ``capture_stdout``, etc.
            used when a run is created.

    Raises:
        ModuleNotFoundError:
            If the required Neptune package is not installed.
        ValueError:
            If an argument passed to the logger's constructor is incorrect.

    /hyperparams	artifactsNTtraining)api_keyprojectnamerunlog_model_checkpointsprefixr.   r/   r0   r1   )r   r   r2   r3   neptune_run_kwargsc          
         s   t sttt | ||||| t   || _|| _|| _|| _	|| _
|| _|| _d | _| jd urP|   ddlm} | j}	t|	|rI|	 }	tj|	t< d S d S )Nr   r   )_NEPTUNE_AVAILABLEModuleNotFoundErrorstr_verify_input_argumentssuper__init___log_model_checkpoints_prefix	_run_name_project_name_api_key_run_instance_neptune_run_kwargs_run_short_id_retrieve_run_dataneptune.handlerr   
isinstanceget_root_objectpl__version___INTEGRATION_VERSION_KEY)
selfr.   r/   r0   r1   r2   r3   r4   r   root_obj	__class__r%   r&   r:      s(   


zNeptuneLogger.__init__r   c                 C   st   ddl m} | jd usJ | j}t||r| }|  |dr2|d  | _|d  | _	d S d| _d| _	d S )Nr   r   zsys/idzsys/nameOFFLINEzoffline-name)
rD   r   r@   rE   rF   waitexistsfetchrB   r=   )rJ   r   rK   r%   r%   r&   rC     s   


z NeptuneLogger._retrieve_run_datac                 C   s   i }t t | j}W d    n1 sw   Y  | jd ur$| j|d< | jd ur.| j|d< | jd ur8| j|d< t t | jd urP| j|d< W d    |S W d    |S 1 s[w   Y  |S )Nr/   	api_tokenr1   r0   )r"   r#   AttributeErrorrA   r>   r?   rB   r=   )rJ   r   r%   r%   r&   _neptune_init_args   s(   








z NeptuneLogger._neptune_init_argskeysc                 G   s&   | j r| j| j g|S | j|S )zXReturn sequence of keys joined by `LOGGER_JOIN_CHAR`, started with `_prefix` if defined.)r<   LOGGER_JOIN_CHARjoin)rJ   rU   r%   r%   r&   _construct_path_with_prefix7  s   z)NeptuneLogger._construct_path_with_prefixc                 C   sn   ddl m} ddlm} |d urt|||fstdtdd | ||fD p(|}|d ur3|r5tdd S d S )Nr   r   r   zQRun parameter expected to be of type `neptune.Run`, or `neptune.handler.Handler`.c                 s   s    | ]}|d uV  qd S Nr%   ).0argr%   r%   r&   	<genexpr>M  s    z8NeptuneLogger._verify_input_arguments.<locals>.<genexpr>zlWhen an already initialized run object is provided, you can't provide other `neptune.init_run()` parameters.)neptuner   rD   r   rE   
ValueErrorany)r.   r/   r0   r1   r4   r   r   any_neptune_init_arg_passedr%   r%   r&   r8   =  s   z%NeptuneLogger._verify_input_argumentsc                 C   s   | j  }d |d< |S )Nr@   )__dict__copy)rJ   stater%   r%   r&   __getstate__T  s   
zNeptuneLogger.__getstate__rc   c                 C   s&   dd l }|| _|jdi | j| _d S Nr   r%   )r]   ra   init_runrT   r@   )rJ   rc   r]   r%   r%   r&   __setstate__Z  s   zNeptuneLogger.__setstate__r   c                 C      | j S )aK  Actual Neptune run object. Allows you to use neptune logging features in your
        :class:`~lightning.pytorch.core.LightningModule`.

        Example::

            class LitModel(LightningModule):
                def training_step(self, batch, batch_idx):
                    # log metrics
                    acc = ...
                    self.logger.experiment["train/acc"].append(acc)

                    # log images
                    img = ...
                    self.logger.experiment["train/misclassified_images"].append(File.as_image(img))

        Note that the syntax ``self.logger.experiment["your/metadata/structure"].append(metadata)``
        is specific to Neptune and extends the logger capabilities.
        It lets you log various types of metadata, such as scores, files,
        images, interactive visuals, and CSVs. Refer to the
        `Neptune docs <https://docs.neptune.ai/logging/methods>`_
        for more detailed explanations.
        You can also use the regular logger methods ``log_metrics()``, and ``log_hyperparams()``
        with NeptuneLogger.

        )r1   rJ   r%   r%   r&   
experiment`  s   zNeptuneLogger.experimentc                 C   s<   dd l }| js|jdi | j| _|   tj| jt< | jS re   )r]   r@   rf   rT   rC   rG   rH   rI   )rJ   r]   r%   r%   r&   r1   ~  s   zNeptuneLogger.runparamsc                 C   s>   ddl m} t|}t|}| j}| |}||| j|< dS )a  Log hyperparameters to the run.

        Hyperparameters will be logged under the "<prefix>/hyperparams" namespace.

        Note:

            You can also log parameters by directly using the logger instance:
            ``neptune_logger.experiment["model/hyper-parameters"] = params_dict``.

            In this way you can keep hierarchical structure of the parameters.

        Args:
            params: `dict`.
                Python dictionary structure with parameters.

        Example::

            from lightning.pytorch.loggers import NeptuneLogger
            import neptune

            PARAMS = {
                "batch_size": 64,
                "lr": 0.07,
                "decay_factor": 0.97,
            }

            neptune_logger = NeptuneLogger(
                api_key=neptune.ANONYMOUS_API_TOKEN,
                project="common/pytorch-lightning-integration"
            )

            neptune_logger.log_hyperparams(PARAMS)

        r   )stringify_unsupportedN)neptune.utilsrl   r   r   PARAMETERS_KEYrX   r1   )rJ   rk   rl   parameters_keyr%   r%   r&   log_hyperparams  s   &
zNeptuneLogger.log_hyperparamsmetricsstepc                 C   sL   t jdkr	tdt|| j| j}| D ]\}}| j| j||d qdS )zLog metrics (numeric values) in Neptune runs.

        Args:
            metrics: Dictionary with metric names as keys and measured quantities as values.
            step: Step number at which the metrics should be recorded

        r   z&run tried to log from global_rank != 0)rr   N)	r   rankr^   r   r<   rV   itemsr1   append)rJ   rq   rr   keyvalr%   r%   r&   log_metrics  s   
zNeptuneLogger.log_metricsstatusc                    s.   | j sd S |r|| j| d< t | d S )Nry   )r@   r1   rX   r9   finalize)rJ   ry   rL   r%   r&   rz     s
   zNeptuneLogger.finalizec                 C   s   t jt  dS )zGets the save directory of the experiment which in this case is ``None`` because Neptune does not save
        locally.

        Returns:
            the root directory where experiment logs get saved

        z.neptune)ospathrW   getcwdri   r%   r%   r&   save_dir  s   
zNeptuneLogger.save_dirmodelzpl.LightningModule	max_depthc                 C   s:   ddl m} tt||d}|j|dd| j| d< d S )Nr   File)r   r   txt)content	extensionzmodel/summary)neptune.typesr   r7   r   from_contentr1   rX   )rJ   r   r   r   	model_strr%   r%   r&   log_model_summary  s
   zNeptuneLogger.log_model_summarycheckpoint_callbackc                 C   s  | j sdS ddlm} t }| d}t|drJ|jrJ| |j|}|| t	|jd}|
|| j| d| < W d   n1 sEw   Y  t|drm|jD ]}| ||}|| | j| d|  | qRt|d	r|jr|j| j| d
< | |j|}|| t	|jd}|
|| j| d| < W d   n1 sw   Y  | j|r| j }	| |	|}
t|
| D ]}| j| d| = qt|dr|jr|j   | j| d< dS dS dS )zAutomatically log checkpointed model. Called after model checkpoint callback saves a new checkpoint.

        Args:
            checkpoint_callback: the model checkpoint callback instance

        Nr   r   zmodel/checkpointslast_model_pathrbr*   best_k_modelsbest_model_pathzmodel/best_model_pathbest_model_scorezmodel/best_model_score)r;   r   r   setrX   hasattrr   _get_full_model_nameaddopenfrom_streamr1   r   uploadr   rP   get_structure(_get_full_model_names_from_exp_structurelistr   cpudetachnumpy)rJ   r   r   
file_namescheckpoints_namespacemodel_last_namefprv   
model_nameexp_structureuploaded_model_namesfile_to_dropr%   r%   r&   after_save_checkpoint  s@   







z#NeptuneLogger.after_save_checkpoint
model_pathc                 C   s~   t |dr8tj| } tj|j}| |s!t|  d| dtj| t|d d \}}|	tj
dS | 	tj
dS )zZReturns model name which is string `model_path` appended to `checkpoint_callback.dirpath`.dirpathz was expected to start with .   Nr*   )r   r{   r|   normpathr   
startswithr^   splitextlenreplacesep)r   r   expected_model_pathfilepath_r%   r%   r&   r   (  s   

 z"NeptuneLogger._get_full_model_namer   	namespacec                 C   s0   | | j}|D ]}|| }q|}t| |S )zHReturns all paths to properties which were already logged in `namespace`)splitrV   r   _dict_paths)clsr   r   structure_keysrv   uploaded_models_dictr%   r%   r&   r   5  s
   
z6NeptuneLogger._get_full_model_names_from_exp_structuredpath_in_buildc                 c   sV    |  D ]#\}}|d ur| d| n|}t|ts|V  q| ||E d H  qd S )Nr*   )rt   rE   dictr   )r   r   r   kvr|   r%   r%   r&   r   >  s   
zNeptuneLogger._dict_pathsc                 C   rh   )zMReturn the experiment name or 'offline-name' when exp is run in offline mode.)r=   ri   r%   r%   r&   r0   G  s   zNeptuneLogger.namec                 C   rh   )zMReturn the experiment version.

        It's Neptune Run's short_id

        )rB   ri   r%   r%   r&   versionM  s   zNeptuneLogger.version)r   N)r   r   rY   )r   )1__name__
__module____qualname____doc__rV   rn   ARTIFACTS_KEYr
   r7   r   boolr   r:   rC   propertyr   rT   rX   staticmethodr   r8   rd   rg   r   rj   r1   r   r   r(   r   rp   r   floatintrx   rz   r~   r   r   r   r   classmethodr   r   r   r   r0   r   __classcell__r%   r%   rL   r&   r)   A   s     $
	
&
(-2

4$&r)   )3r   r"   loggingr{   argparser   	functoolsr   typingr   r   r   r   r   r	   r
   r   r    lightning_utilities.core.importsr   torchr   typing_extensionsr   lightning.pytorchpytorchrG   !lightning.fabric.utilities.loggerr   r   r   lightning.pytorch.callbacksr    lightning.pytorch.loggers.loggerr   r   )lightning.pytorch.utilities.model_summaryr   %lightning.pytorch.utilities.rank_zeror   r]   r   rD   r   	getLoggerr   logr5   rI   r(   r)   r%   r%   r%   r&   <module>   s0   ,
