# Copyright (c) 2021, NVIDIA CORPORATION.  All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations

import copy
import inspect
import os
import pathlib
import uuid
from abc import abstractmethod
from os import path
from pathlib import Path
from typing import Any, Callable, Dict, Iterator, List, Optional, Tuple, Union

import hydra
import torch

from nemo.core.classes.module import NeuralModule
from nemo.utils.msc_utils import import_multistorageclient, is_multistorageclient_url

try:
    from megatron.core.optimizer import OptimizerConfig, get_megatron_optimizer
    from megatron.core.utils import get_model_config

    HAVE_MEGATRON_CORE = True

except (ImportError, ModuleNotFoundError):

    HAVE_MEGATRON_CORE = False

from lightning.pytorch import LightningModule, Trainer
from lightning.pytorch.utilities import model_summary, rank_zero_only
from omegaconf import DictConfig, OmegaConf, open_dict

from nemo import package_info
from nemo.core import optim
from nemo.core.classes.common import Model
from nemo.core.connectors.save_restore_connector import SaveRestoreConnector
from nemo.core.optim import McoreDistributedOptimizer, prepare_lr_scheduler
from nemo.lightning.callback_group import CallbackGroup
from nemo.utils import logging, model_utils
from nemo.utils.app_state import AppState
from nemo.utils.debug_hook import register_debug_hooks
from nemo.utils.exceptions import NeMoBaseException
from nemo.utils.get_rank import get_rank, is_global_rank_zero

__all__ = ['ModelPT']


# Register a `multiply` resolver so that a config value can be expressed as the product of
# two (possibly interpolated) values, e.g. `${multiply:${a},${b}}`.
# `replace=True` lets this registration overwrite a resolver already registered under this name.
OmegaConf.register_new_resolver("multiply", lambda x, y: x * y, replace=True)


class ModelPT(LightningModule, Model):
    """
    Interface for Pytorch-lightning based NeMo models
    """

    def __init__(self, cfg: DictConfig, trainer: Trainer = None):
        """
        Base class from which all NeMo models should inherit.

        The constructor normalizes the config, stores it on the model, attaches the trainer,
        optionally sets up the data loaders described in the config and finally prepares the
        profiling hooks.

        Args:
            cfg (DictConfig):  configuration object.
                The cfg object should have (optionally) the following sub-configs:

                * train_ds - to instantiate training dataset
                * validation_ds - to instantiate validation dataset
                * test_ds - to instantiate testing dataset
                * optim - to instantiate optimizer with learning rate scheduler

            trainer (Optional): Pytorch Lightning Trainer instance

        Raises:
            ValueError: If `trainer` is neither None nor a `lightning.pytorch.Trainer`, or if the
                converted `cfg` contains a top-level `model` key.
        """
        # Validate the trainer before any callback is notified or any state is created
        if trainer is not None and not isinstance(trainer, Trainer):
            raise ValueError(
                f"trainer constructor argument must be either None or lightning.pytorch.Trainer. "
                f"But got {type(trainer)} instead."
            )

        # Track model init start
        CallbackGroup.get_instance().on_model_init_start()

        super().__init__()

        """
        Internal global flags that determine core functionality of ModelPT.

        _MODEL_IS_RESTORED:
            This flag determines the context of the model - whether the model is currently being
            restored or not.
            -   When set, it can be assumed that the model's will disable all automatic methods -
                setup_training_data(), setup_validation/test_data() and their multi equivalents.
            -   If a model is being restored from a archive file (tarfile), it can be assumed that
                under this context, the cwd is *inside* the tarfile itself.

        _MODEL_RESTORE_PATH:
            A string path to a a file from which the model is being restored.
            This file can either be a PyTorch Lightning Checkpoint, or a archive (tarfile) that contains
            artifact objects.
            If it is an archive file, during restoration, the cwd will be temporarily moved to inside the
            archive itself.
        """
        # set global vars in AppState
        app_state = AppState()

        # Convert config to a DictConfig
        cfg = model_utils.convert_model_config_to_dict_config(cfg)

        # Convert config to support Hydra 1.0+ instantiation
        cfg = model_utils.maybe_update_config_version(cfg)

        # A nested `model` node is rejected up front (see the error message for the reason)
        if 'model' in cfg:
            raise ValueError(
                "Creating model config node is forbidden due to collision problem when loading from checkpoint."
            )

        if 'target' not in cfg:
            # This is for Jarvis service.
            # Struct mode is switched off only for the duration of the assignment so that the new
            # `target` key (fully qualified class name of this model) can be added.
            OmegaConf.set_struct(cfg, False)
            cfg.target = "{0}.{1}".format(self.__class__.__module__, self.__class__.__name__)
            OmegaConf.set_struct(cfg, True)

        # Record the NeMo version in the config unless the config already carries one
        if 'nemo_version' not in cfg:
            with open_dict(cfg):
                cfg.nemo_version = package_info.__version__

        self._cfg = cfg

        # init mapping submodule attribute -> config_field for nested NeMo models
        self._nemo_submodule_name_to_config_field = dict()

        self.save_hyperparameters("cfg")
        # Data loaders, optimizer and scheduler all start out unset
        self._train_dl = None
        self._validation_dl = None
        self._test_dl = None
        self._optimizer_param_groups = None
        self._optimizer = None
        self._scheduler = None
        self.set_trainer(trainer)

        # Instance-level connector used by save_to() / register_artifact()
        self._save_restore_connector = SaveRestoreConnector()

        self._set_model_guid()

        # Set device_id in AppState
        if torch.cuda.is_available() and torch.cuda.current_device() is not None:
            app_state.device_id = torch.cuda.current_device()

        CallbackGroup.get_instance().on_model_init_end()
        CallbackGroup.get_instance().on_dataloader_init_start()
        if self._cfg is not None and not self._is_model_being_restored():
            # Setup data loaders now (default) or defer setup to `self.setup()`
            # if `defer_setup` is set in the config of the corresponding dataloader.
            if (
                'train_ds' in self._cfg
                and self._cfg.train_ds is not None
                and not self._cfg.train_ds.get('defer_setup', False)
            ):
                self.setup_training_data(self._cfg.train_ds)

            if (
                'validation_ds' in self._cfg
                and self._cfg.validation_ds is not None
                and not self._cfg.validation_ds.get('defer_setup', False)
            ):
                self.setup_multiple_validation_data(val_data_config=cfg.validation_ds)

            if (
                'test_ds' in self._cfg
                and self._cfg.test_ds is not None
                and not self._cfg.test_ds.get('defer_setup', False)
            ):
                self.setup_multiple_test_data(test_data_config=cfg.test_ds)

        else:
            # Data loaders are not set up automatically on this path (e.g. while the model is being
            # restored); warn about every dataset config that is present so the caller knows which
            # setup method has to be called manually.
            if 'train_ds' in self._cfg and self._cfg.train_ds is not None:
                logging.warning(
                    f"If you intend to do training or fine-tuning, please call the ModelPT.setup_training_data() "
                    f"method and provide a valid configuration file to setup the train data loader.\n"
                    f"Train config : \n{OmegaConf.to_yaml(self._cfg.train_ds)}"
                )

            if 'validation_ds' in self._cfg and self._cfg.validation_ds is not None:
                logging.warning(
                    f"If you intend to do validation, please call the ModelPT.setup_validation_data() or "
                    f"ModelPT.setup_multiple_validation_data() method "
                    f"and provide a valid configuration file to setup the validation data loader(s). \n"
                    f"Validation config : \n{OmegaConf.to_yaml(self._cfg.validation_ds)}"
                )
            if 'test_ds' in self._cfg and self._cfg.test_ds is not None:
                logging.warning(
                    f"Please call the ModelPT.setup_test_data() or ModelPT.setup_multiple_test_data() method "
                    f"and provide a valid configuration file to setup the test data loader(s).\n"
                    f"Test config : \n{OmegaConf.to_yaml(self._cfg.test_ds)}"
                )

        CallbackGroup.get_instance().on_dataloader_init_end()

        # Placeholder for the validation step outputs. It starts as None (not as a list)
        # because self._validation_dl can still be None at this stage.
        self._validation_step_outputs = None

        # Placeholder for the test step outputs. It starts as None (not as a list)
        # because self._test_dl can still be None at this stage.
        self._test_step_outputs = None

        # ModelPT wrappers over subclass implementations
        self.training_step = model_utils.wrap_training_step(self.training_step)

        # Setup nsys profiling if it has been enabled in the model config
        self._setup_profiling()

        # Flags tracking the progress of nsys and memory profile generation
        self._nsys_profile_started = False
        self._nsys_profile_complete = False
        self._memory_profile_started = False
        self._memory_profile_complete = False

        # Setup chakra profiling if it has been enabled in the model config
        self._setup_chakra_profiling()

        # Flag tracking whether chakra profile generation is in progress
        self._chakra_profile_in_progress = False

    def __init_subclass__(cls, **kwargs) -> None:
        """
        Give every subclass its own class-level ``SaveRestoreConnector``.

        The hook is cooperative: class keyword arguments are forwarded to the next
        ``__init_subclass__`` in the MRO, so mixins that define their own hook keep working
        when combined with ``ModelPT``.

        Args:
            **kwargs: Keyword arguments given in the class statement of the subclass.
        """
        super().__init_subclass__(**kwargs)
        cls._save_restore_connector = SaveRestoreConnector()

    def on_fit_start(self) -> None:
        """
        Register the debug hooks when `dump_debug_info` is enabled in the model config,
        then delegate to the parent implementation of `on_fit_start`.
        """
        debug_requested = self.cfg.get("dump_debug_info", False)
        if debug_requested:
            dump_to_file = self.cfg.get("dump_debug_info_to_file", False)
            register_debug_hooks(self.model, self.trainer, self.log, dump_to_file)
        return super().on_fit_start()

    def register_artifact(
        self,
        config_path: str,
        src: str,
        verify_src_exists: bool = True,
    ):
        """Register a file as a model artifact so that it is packed into the .nemo file
        produced by model.save_to("mymodel.nemo").

        Behaviour:

        1. If `src` is None or "", nothing is registered and `src` is returned unchanged.
        2. Otherwise the registration is delegated to the save/restore connector, which adds a
           (config_path, model_utils.ArtifactItem()) pair to self.artifacts and returns a path:

            .. code-block::

                If "src" is local existing path:
                    then it will be returned in absolute path form.
                elif "src" starts with "nemo_file:unique_artifact_name":
                    .nemo will be untarred to a temporary folder location and an actual existing path will be returned
                else:
                    an error will be raised.

        WARNING: call .register_artifact from your model's constructor.
        The returned path is not guaranteed to exist once the constructor has returned.

        Args:
            config_path (str): Artifact key. Usually corresponds to the model config.
            src (str): Path to artifact.
            verify_src_exists (bool): If set to False, then the artifact is optional and register_artifact will return
                                      None even if src is not found. Defaults to True.

        Returns:
            str: `src` itself when it is None or empty; otherwise the path returned by the connector.

        Raises:
            NeMoBaseException: If `src` points to a `.nemo` file (nested models must use
                `register_nemo_submodule` instead).
        """
        # Nothing to register: hand the empty value straight back to the caller.
        nothing_to_register = src is None or src == ""
        if nothing_to_register:
            return src

        artifact_suffix = Path(src).suffix
        if artifact_suffix == ".nemo":
            raise NeMoBaseException(
                "Registering .nemo files as artifacts not supported. "
                "If you are trying to make a nested model, use `register_nemo_submodule`."
            )

        # Create the artifact registry on first use (attribute missing or explicitly None).
        if getattr(self, 'artifacts', None) is None:
            self.artifacts = {}

        already_registered = config_path in self.artifacts
        if already_registered:
            logging.warning(
                f"You tried to register an artifact under config key={config_path} but an artifact for "
                f"it has already been registered."
            )

        connector = self._save_restore_connector
        return connector.register_artifact(self, config_path, src, verify_src_exists)

    def has_artifacts(self) -> bool:
        """Tell whether at least one artifact is registered on this model."""
        registry = getattr(self, 'artifacts', None)
        return registry is not None and len(registry) > 0

    def has_native_or_submodules_artifacts(self) -> bool:
        """Tell whether any ModelPT yielded by `self.modules()` has registered artifacts."""

        def _holds_artifacts(candidate) -> bool:
            # Only NeMo models can own artifacts; other modules are ignored.
            if not isinstance(candidate, ModelPT):
                return False
            registry = getattr(candidate, 'artifacts', None)
            return registry is not None and len(registry) > 0

        return any(_holds_artifacts(module) for module in self.modules())

    def has_nemo_submodules(self) -> bool:
        """Tell whether at least one NeMo submodule has been registered on this model."""
        return bool(self._nemo_submodule_name_to_config_field)

    def register_nemo_submodule(self, name: str, config_field: str, model: "ModelPT") -> None:
        """
        Attach a NeMo model to this model (`self`) as a nested submodule.

        The submodule becomes reachable as the attribute `name` of the parent, and the pair
        (`name`, `config_field`) is recorded in the parent's submodule mapping.
        When saving, the parent model is treated as a single model that also carries the artifacts
        of the child, and the child's config is stored under `config_field` of the parent config.
        Typical usage when building a nested model:

        .. code-block:: python

            class ParentModel(ModelPT):
                def __init__(self, cfg, trainer=None):
                    super().__init__(cfg=cfg, trainer=trainer)

                    # annotate type for autocompletion and type checking (optional)
                    self.child_model: Optional[ChildModel] = None
                    if cfg.get("child_model") is not None:
                        self.register_nemo_submodule(
                            name="child_model",
                            config_field="child_model",
                            model=ChildModel(self.cfg.child_model, trainer=trainer),
                        )
                    # ... other code

        Args:
            name: attribute name under which the submodule is stored
            config_field: config field in which the submodule config should be saved
            model: the NeMo model to nest, an instance of ModelPT

        Raises:
            NeMoBaseException: If `model` is not a ModelPT, or if the parent has not been
                initialized yet (its submodule mapping does not exist).
        """
        # Only genuine NeMo models may be nested.
        is_nemo_model = isinstance(model, ModelPT)
        if not is_nemo_model:
            raise NeMoBaseException(
                f"Model is not and instance of ModelPT, so can't be registered. Got {type(model).__name__}"
            )

        # The mapping is created by ModelPT.__init__, so its absence means __init__ has not run.
        parent_initialized = hasattr(self, "_nemo_submodule_name_to_config_field")
        if not parent_initialized:
            raise NeMoBaseException(
                "You are trying to register a submodule before the model is initialized. This is not allowed. "
                "Did you forget to call `super().__init__`?"
            )

        # Expose the child as an attribute first, then remember where its config lives.
        setattr(self, name, model)
        self._nemo_submodule_name_to_config_field[name] = config_field

    def named_nemo_modules(
        self, prefix_name: str = "", prefix_config: str = ""
    ) -> Iterator[Tuple[str, str, "ModelPT"]]:
        """
        Walk this model and all registered NeMo submodules recursively (depth first).

        Each item is a tuple (attribute path, path in config, module). The first item is always
        (prefix_name, prefix_config, self); dotted paths are built for the nested modules.

        Args:
            prefix_name: prefix for the name path
            prefix_config: prefix for the path in config

        Returns:
            Iterator over (attribute path, path in config, submodule), starting from (prefix, self)

        Raises:
            NeMoBaseException: If the submodule mapping does not exist yet, i.e. the method is
                used before `__init__` has run.
        """
        if not hasattr(self, "_nemo_submodule_name_to_config_field"):
            raise NeMoBaseException(
                "Model is not fully initialized. Calling `named_nemo_modules` before __init__ not allowed. "
                "Did you forget to call `super().__init__`?"
            )

        # The model itself always comes first.
        yield prefix_name, prefix_config, self

        # Then descend into every registered child, extending both dotted paths.
        for attr_name, cfg_field in self._nemo_submodule_name_to_config_field.items():
            if prefix_name:
                child_attr_path = f"{prefix_name}.{attr_name}"
            else:
                child_attr_path = attr_name

            if prefix_config:
                child_cfg_path = f"{prefix_config}.{cfg_field}"
            else:
                child_cfg_path = cfg_field

            child = getattr(self, attr_name)
            yield from child.named_nemo_modules(prefix_name=child_attr_path, prefix_config=child_cfg_path)

    def save_to(self, save_path: str):
        """
        Saves model instance (weights and configuration) into .nemo file
        You can use "restore_from" method to fully restore instance from .nemo file.

        .nemo file is an archive (tar.gz) with the following:
            model_config.yaml - model configuration in .yaml format. You can deserialize this into cfg argument for
                                model's constructor
            model_weights.ckpt - model checkpoint

        Local paths are expanded (`~`) and resolved before saving; multistorageclient URLs are
        passed through unchanged and no directory is created for them.

        When `AppState().model_parallel_size` is set, global rank zero creates the parent directory,
        all ranks synchronize on a barrier (if torch.distributed is initialized) and then every rank
        calls the connector. Otherwise only global rank zero creates the directory and saves.

        Args:
            save_path: Path to .nemo file where model instance should be saved

        Raises:
            ValueError: If the model parallel size is greater than 1 while the connector is exactly
                the default `SaveRestoreConnector` (subclasses are accepted).
        """

        def maybe_make_save_dir(target: Union[str, 'pathlib.Path']):
            # Only local paths have a parent directory to create.
            if not is_multistorageclient_url(target):
                # exist_ok avoids the check-then-create race of a separate `exists()` test.
                target.parent.mkdir(parents=True, exist_ok=True)

        if not is_multistorageclient_url(save_path):
            save_path = Path(save_path).expanduser().resolve()

        app_state = AppState()
        if app_state.model_parallel_size is not None:
            # Exact type check on purpose: subclasses of the default connector are allowed.
            uses_default_connector = type(self._save_restore_connector) is SaveRestoreConnector
            if app_state.model_parallel_size > 1 and uses_default_connector:
                raise ValueError(
                    'Default NeMo SaveRestoreConnector will not work in model parallel mode. You should use a '
                    'connector which supports model parallel mode, such as NLPSaveRestoreConnector in NLP. You '
                    'can also use a custom one.'
                )
            if is_global_rank_zero():
                maybe_make_save_dir(save_path)
            # Make sure the directory exists before any other rank starts writing.
            if torch.distributed.is_initialized():
                torch.distributed.barrier()
            # connector checks for ranks properly, no need to check here
            self._save_restore_connector.save_to(self, str(save_path))  # downstream tasks expect str, not Path
        elif is_global_rank_zero():
            maybe_make_save_dir(save_path)
            self._save_restore_connector.save_to(self, str(save_path))  # downstream tasks expect str, not Path
    @classmethod
    def restore_from(
        cls,
        restore_path: str,
        override_config_path: Optional[Union[OmegaConf, str]] = None,
        map_location: Optional[torch.device] = None,
        strict: bool = True,
        return_config: bool = False,
        save_restore_connector: SaveRestoreConnector = None,
        trainer: Optional[Trainer] = None,
        validate_access_integrity: bool = True,
    ):
        """
        Restores model instance (weights and configuration) from .nemo file.

        The restore path is checked for existence, recorded in `AppState().model_restore_path`
        and the actual restoration is delegated to the save/restore connector.

        Args:
            restore_path: path to .nemo file from which model should be instantiated. For local
                paths, the connector's `model_extracted_dir` is used instead when it is set.
            override_config_path: path to a yaml config that will override the internal
                config file or an OmegaConf / DictConfig object representing the model config.
            map_location: Optional torch.device() to map the instantiated model to a device.
                By default (None), it will select a GPU if available, falling back to CPU otherwise.
            strict: Passed to load_state_dict. By default True.
            return_config: If set to true, will return just the underlying config of the restored
                model as an OmegaConf DictConfig object without instantiating the model.
            save_restore_connector (SaveRestoreConnector): Can be overridden to add custom save and restore logic.
                A default `SaveRestoreConnector` is created when None is given.
            trainer: Optional, a pytorch lightning Trainer object that will be forwarded to the
                instantiated model's constructor.
            validate_access_integrity: Forwarded unchanged to the connector's `restore_from`.

            Example:
                ```
                model = nemo.collections.asr.models.EncDecCTCModel.restore_from('asr.nemo')
                assert isinstance(model, nemo.collections.asr.models.EncDecCTCModel)
                ```

        Returns:
            An instance of type cls or its underlying config (if return_config is set).

        Raises:
            FileNotFoundError: If the resolved restore path does not exist.
        """
        if save_restore_connector is None:
            save_restore_connector = SaveRestoreConnector()

        if is_multistorageclient_url(restore_path):
            # Remote location: existence is checked through multistorageclient, path left as is.
            msc = import_multistorageclient()
            source_found = msc.os.path.exists(restore_path)
        else:
            # Local location: an already extracted model directory takes precedence.
            extracted_dir = save_restore_connector.model_extracted_dir
            local_source = restore_path if extracted_dir is None else extracted_dir
            restore_path = os.path.abspath(os.path.expanduser(local_source))
            source_found = path.exists(restore_path)

        if not source_found:
            raise FileNotFoundError(f"Can't find {restore_path}")

        AppState().model_restore_path = restore_path

        cls.update_save_restore_connector(save_restore_connector)
        restored = cls._save_restore_connector.restore_from(
            cls,
            restore_path,
            override_config_path,
            map_location,
            strict,
            return_config,
            trainer,
            validate_access_integrity,
        )
        # With return_config the result is a config, so only real models get the connector attached.
        if isinstance(restored, ModelPT):
            restored._save_restore_connector = save_restore_connector

        return restored

    @classmethod
    def load_from_checkpoint(
        cls,
        checkpoint_path: str,
        *args,
        map_location: Optional[Union[Dict[str, str], str, torch.device, int, Callable]] = None,
        hparams_file: Optional[str] = None,
        strict: bool = True,
        **kwargs,
    ):
        """
        Loads ModelPT from checkpoint while the model restore state is flagged as active.
        The restore flag is always reset afterwards, even when loading fails.
        For argument documentation, please refer to LightningModule.load_from_checkpoint().
        """
        # Telemetry: announce that checkpoint loading is about to begin
        CallbackGroup.get_instance().on_load_checkpoint_start()

        try:
            cls._set_model_restore_state(is_being_restored=True)
            restored = super().load_from_checkpoint(
                *args,
                checkpoint_path=checkpoint_path,
                map_location=map_location,
                hparams_file=hparams_file,
                strict=strict,
                **kwargs,
            )
        finally:
            cls._set_model_restore_state(is_being_restored=False)

        # Telemetry: announce that checkpoint loading finished (reached only on success)
        CallbackGroup.get_instance().on_load_checkpoint_end()

        return restored

    @abstractmethod
    def setup_training_data(self, train_data_config: Union[DictConfig, Dict]):
        """
        Set up the data loader used for training. Subclasses must implement this method.

        Args:
            train_data_config: training data layer parameters.
        """

    @abstractmethod
    def setup_validation_data(self, val_data_config: Union[DictConfig, Dict]):
        """
        Set up the data loader used for validation. Subclasses must implement this method.

        Args:
            val_data_config: validation data layer parameters.
        """

    def setup_test_data(self, test_data_config: Union[DictConfig, Dict]):
        """
        (Optionally) set up the data loader used for testing.

        The base implementation is not provided; subclasses that support testing override it.

        Args:
            test_data_config: test data layer parameters.

        Raises:
            NotImplementedError: Always, in this base implementation.
        """
        raise NotImplementedError

    def setup_multiple_validation_data(self, val_data_config: Union[DictConfig, Dict]):
        """
        (Optionally) set up the validation data loader(s), with support for multiple data loaders.

        Args:
            val_data_config: validation data layer parameters.
        """
        # Placeholders; the resolver helper below is expected to overwrite them
        self._val_dl_idx = 0
        self._validation_names = None
        self._validation_dl = None  # torch.utils.data.DataLoader once resolved

        # Keep the given config on the model
        self._update_dataset_config(dataset_name='validation', config=val_data_config)

        # The multi-dataset flag is only raised while the loaders are being resolved
        try:
            self._multi_dataset_mode = True
            model_utils.resolve_validation_dataloaders(model=self)
        finally:
            self._multi_dataset_mode = False

        # Fall back to positional names when the helper left the names unset
        # and produced a plain list/tuple of loaders.
        if self._validation_names is None:
            loaders = self._validation_dl
            if type(loaders) in (list, tuple):
                self._validation_names = [f'val_{idx}_' for idx in range(len(loaders))]

    def setup_multiple_test_data(self, test_data_config: Union[DictConfig, Dict]):
        """
        (Optionally) set up the test data loader(s), with support for multiple data loaders.

        Args:
            test_data_config: test data layer parameters.
        """
        # Placeholders; the resolver helper below is expected to overwrite them
        self._test_dl_idx = 0
        self._test_names = None
        self._test_dl = None  # torch.utils.data.DataLoader once resolved

        # Keep the given config on the model
        self._update_dataset_config(dataset_name='test', config=test_data_config)

        # The multi-dataset flag is only raised while the loaders are being resolved
        try:
            self._multi_dataset_mode = True
            model_utils.resolve_test_dataloaders(model=self)
        finally:
            self._multi_dataset_mode = False

        # Fall back to positional names when the helper left the names unset
        # and produced a plain list/tuple of loaders.
        if self._test_names is None:
            loaders = self._test_dl
            if type(loaders) in (list, tuple):
                self._test_names = [f'test_{idx}_' for idx in range(len(loaders))]

    def setup_megatron_optimization(self, optim_config: Union[Dict[str, Any], DictConfig]):
        """
        Setup mcore optimizer config.

        Precision settings are read from the model config of `self.model[0]`, the optimizer
        hyperparameters from `optim_config`, gradient clipping from the trainer and the
        `overlap_param_gather_with_optimizer_step` option from `self.cfg.optim`.

        Args:
            optim_config: Nemo optim args used to set up Mcore optimizer options. Must provide
                `lr`, `weight_decay` and `betas`; `eps` is optional.

        Returns:
            The populated megatron-core `OptimizerConfig`.

        Raises:
            ImportError: If megatron-core is not installed.
        """
        # Fail with an explicit message instead of a NameError on the missing megatron symbols.
        if not HAVE_MEGATRON_CORE:
            raise ImportError(
                "megatron-core was not found, but it is required to set up the Megatron optimizer config. "
                "Please install megatron-core to use this feature."
            )

        config = get_model_config(self.model[0])

        megatron_optim_config = OptimizerConfig(
            fp16=config.fp16,
            bf16=config.bf16,
            params_dtype=config.params_dtype,
            lr=optim_config['lr'],
            weight_decay=optim_config['weight_decay'],
            adam_beta1=optim_config['betas'][0],
            adam_beta2=optim_config['betas'][1],
            # Fall back to the OptimizerConfig default when no eps is configured
            adam_eps=optim_config.get('eps', OptimizerConfig.adam_eps),
            clip_grad=self.trainer.gradient_clip_val,
            use_distributed_optimizer=self.use_mcore_dist_optim,
            overlap_param_gather_with_optimizer_step=self.cfg.optim.get(
                'overlap_param_gather_with_optimizer_step', False
            ),
        )
        return megatron_optim_config

    def setup_optimization(
        self,
        optim_config: Optional[Union[DictConfig, Dict]] = None,
        optim_kwargs: Optional[Dict[str, Any]] = None,
    ):
        """Prepares an optimizer from a string name and its optional config parameters.

        Args:
            optim_config: A dictionary containing the following keys:

                * "lr": mandatory key for learning rate. Will raise ValueError if not provided.
                * "optimizer": string name pointing to one of the available optimizers in the registry. \
                If not provided, defaults to "adam".
                * "opt_args": Optional list of strings, in the format "arg_name=arg_value". \
                The list of "arg_value" will be parsed and a dictionary of optimizer kwargs \
                will be built and supplied to instantiate the optimizer.

            optim_kwargs: A dictionary with additional kwargs for the
                optimizer. Used for non-primitive types that are not
                compatible with OmegaConf.

        """
        # Setup the optimizer parameter groups (by default use all parameters that are trainable)
        self.setup_optimizer_param_groups()

        # Make copy of the config so it can be modified later
        optim_config = self._optim_config_copy(optim_config)

        # If config is still None, return without instantiation
        if optim_config is None:
            logging.info('No optimizer config provided, therefore no optimizer was created')
            return

        # See if internal config has `optim` namespace before preservation
        if self._cfg is not None and hasattr(self._cfg, 'optim'):
            self._cfg.optim = optim_config

        # Setup optimizer and scheduler
        if optim_config is not None:
            optim_config = OmegaConf.to_container(optim_config, resolve=True)

        if self._trainer is None:
            logging.warning("Trainer wasn't specified in model constructor. Make sure that you really wanted it.")

        if 'sched' in optim_config and self._trainer is not None:

            if not isinstance(self._trainer.accumulate_grad_batches, int):
                raise ValueError("We do not currently support gradient acculumation that is not an integer.")
            if self.trainer.max_steps < 0:
                # Store information needed to calculate max_steps
                optim_config['sched']['t_max_epochs'] = self._trainer.max_epochs
                optim_config['sched']['t_accumulate_grad_batches'] = self._trainer.accumulate_grad_batches
                optim_config['sched']['t_limit_train_batches'] = self._trainer.limit_train_batches

                app_state = AppState()
                if app_state.data_parallel_size is not None:
                    optim_config['sched']['t_num_workers'] = app_state.data_parallel_size
                elif app_state.model_parallel_size is None:
                    optim_config['sched']['t_num_workers'] = self._trainer.num_devices * self._trainer.num_nodes
                else:
                    optim_config['sched']['t_num_workers'] = (
                        self._trainer.num_devices * self._trainer.num_nodes
                    ) / app_state.model_parallel_size
            else:
                optim_config['sched']['max_steps'] = self._trainer.max_steps

        # Force into DictConfig from nested structure
        optim_config = OmegaConf.create(optim_config)
        # Get back nested dict so we its mutable
        optim_config = OmegaConf.to_container(optim_config, resolve=True)

        # Extract scheduler config if inside optimizer config
        if 'sched' in optim_config:
            scheduler_config = optim_config.pop('sched')
        else:
            scheduler_config = None

        # Check if caller provided optimizer name, default to Adam otherwise
        optimizer_cls = optim_config.get('_target_', None)

        if optimizer_cls is None:
            # Try to get optimizer name for dynamic resolution, defaulting to Adam
            # Use or instead of default as None will also results in default value not used.
            optimizer_name = optim_config.get('name') or 'adam'
        else:
            if inspect.isclass(optimizer_cls):
                optimizer_name = optimizer_cls.__name__.lower()
            else:
                # resolve the class name (lowercase) from the class path if not provided
                optimizer_name = optimizer_cls.split(".")[-1].lower()

        # We are guarenteed to have lr since it is required by the argparser
        # But maybe user forgot to pass it to this function
        lr = optim_config.get('lr', None)

        # Check if caller has optimizer kwargs, default to empty dictionary
        if 'args' in optim_config:
            optimizer_args = optim_config.pop('args')
            optimizer_args = optim.parse_optimizer_args(optimizer_name, optimizer_args)
        else:
            optimizer_args = copy.deepcopy(optim_config)

            # Remove extra parameters from optimizer_args nest
            # Assume all other parameters are to be passed into optimizer constructor
            optimizer_args.pop('name', None)
            optimizer_args.pop('cls', None)
            optimizer_args.pop('lr', None)

        # Include user-provided kwargs
        if optim_kwargs is not None:
            optimizer_args.update(optim_kwargs)

        # Adaptive schedulers don't need `lr`
        if lr is not None:
            optimizer_args['lr'] = lr

        # Actually instantiate the optimizer
        if optimizer_cls is not None:
            if inspect.isclass(optimizer_cls):
                optimizer = optimizer_cls(self._optimizer_param_groups, **optimizer_args)
                logging.info("Optimizer config = %s", str(optimizer))

                self._optimizer = optimizer

            else:
                # Attempt class path resolution
                try:
                    optimizer_cls = OmegaConf.create({'_target_': optimizer_cls})
                    if lr is not None:
                        optimizer_config = {'lr': lr}
                    else:
                        optimizer_config = {}
                    optimizer_config.update(optimizer_args)

                    optimizer_instance = hydra.utils.instantiate(
                        optimizer_cls, self._optimizer_param_groups, **optimizer_config
                    )  # type: DictConfig

                    logging.info("Optimizer config = %s", str(optimizer_instance))

                    self._optimizer = optimizer_instance

                except Exception as e:
                    logging.error(
                        "Could not instantiate class path - {} with kwargs {}".format(
                            optimizer_cls, str(optimizer_config)
                        )
                    )
                    raise e

        else:
            if optimizer_name == 'mcore_distributed_optim':
                # setup megatron_optim_config and get Mcore based optimizer with the wrapper
                megatron_optim_config = self.setup_megatron_optimization(optimizer_args)
                _megatron_optimizer = get_megatron_optimizer(
                    megatron_optim_config,
                    self.model,
                )
                optimizer = McoreDistributedOptimizer(_megatron_optimizer)

            else:
                optimizer = optim.get_optimizer(optimizer_name)
                optimizer = optimizer(self._optimizer_param_groups, **optimizer_args)

                logging.info("Optimizer config = %s", str(optimizer))

            self._optimizer = optimizer

        self._scheduler = prepare_lr_scheduler(
            optimizer=self._optimizer, scheduler_config=scheduler_config, train_dataloader=self._train_dl
        )

        # Return the optimizer with/without scheduler
        # This return allows multiple optimizers or schedulers to be created
        return self._optimizer, self._scheduler

    def setup_optimizer_param_groups(self):
        """
        Used to create param groups for the optimizer.
        As an example, this can be used to specify per-layer learning rates:

        optim.SGD([
                    {'params': model.base.parameters()},
                    {'params': model.classifier.parameters(), 'lr': 1e-3}
                    ], lr=1e-2, momentum=0.9)

        See https://pytorch.org/docs/stable/optim.html for more information.
        By default, ModelPT will use self.parameters().
        Override this method to add custom param groups.
        In the config file, add 'optim_param_groups' to support different LRs
        for different components (unspecified params will use the default LR):

        model:
            optim_param_groups:
                encoder:
                    lr: 1e-4
                    momentum: 0.8
                decoder:
                    lr: 1e-3
            optim:
                lr: 3e-3
                momentum: 0.9

        The result is stored in ``self._optimizer_param_groups``:
        ``None`` when the object exposes no ``parameters`` attribute, otherwise a list of
        param-group dicts. When ``optim_param_groups`` is configured, parameters that do not
        belong to any configured component are collected into a leading group without overrides.

        Raises:
            ValueError: if a configured group name is not an attribute of the model, or the
                attribute does not expose ``parameters``.
        """
        if not hasattr(self, "parameters"):
            self._optimizer_param_groups = None
            return

        if "optim_param_groups" not in self.cfg:
            self._optimizer_param_groups = [{"params": list(self.parameters())}]
            return

        param_groups = []
        for group, group_cfg in self.cfg.optim_param_groups.items():
            module = getattr(self, group, None)
            if module is None:
                raise ValueError(f"{group} not found in model.")
            if not hasattr(module, "parameters"):
                raise ValueError(f"{group} does not have parameters.")

            new_group = {"params": list(module.parameters())}
            # Per-group overrides (lr, momentum, ...) are forwarded to the optimizer untouched
            for k, v in group_cfg.items():
                new_group[k] = v
            param_groups.append(new_group)

        # Decide membership by parameter identity rather than by name prefix: a prefix test
        # would treat e.g. `encoder_proj.weight` as part of group `encoder` and silently leave
        # it out of every param group, so it would never be optimized.
        grouped_param_ids = {id(p) for param_group in param_groups for p in param_group["params"]}
        other_params = [p for _, p in self.named_parameters() if id(p) not in grouped_param_ids]

        if other_params:
            param_groups = [{"params": other_params}] + param_groups

        self._optimizer_param_groups = param_groups

    def configure_optimizers(self):
        """
        Build the optimizer (and scheduler, if any) and hand them back to the caller.

        Optimizer construction is delegated to `setup_optimization()`; the callback group is
        notified immediately before and after that call.

        Returns:
            The optimizer alone when no scheduler was created, otherwise a pair of
            single-element lists `([optimizer], [scheduler])`.
        """
        # Notify callbacks that optimizer initialization is starting
        CallbackGroup.get_instance().on_optimizer_init_start()
        self.setup_optimization()
        # Notify callbacks that optimizer initialization has finished
        CallbackGroup.get_instance().on_optimizer_init_end()

        if self._scheduler is not None:
            return [self._optimizer], [self._scheduler]
        return self._optimizer

    def propagate_model_guid(self):
        """
        Stamp `self.model_guid` onto every NeMo module of this model and onto all of
        their descendants (as reported by `children()`).
        """
        for _, _, nemo_module in self.named_nemo_modules():
            nemo_module.model_guid = self.model_guid

            # Walk the subtree with an explicit stack instead of recursion
            pending = [nemo_module]
            while pending:
                current = pending.pop()
                current.model_guid = self.model_guid
                pending.extend(current.children())

    def setup(self, stage: Optional[str] = None):
        """Called at the beginning of fit, validate, test, or predict.
        This is called on every process when using DDP.

        Args:
            stage: fit, validate, test or predict
        """
        self.propagate_model_guid()
        if stage == 'fit':
            train_deferred_setup = (
                'train_ds' in self._cfg
                and self._cfg.train_ds is not None
                and self._cfg.train_ds.get('defer_setup', False)
            )
            no_train_dataloader = self.train_dataloader() is None or (
                isinstance(self.train_dataloader(), list) and len(self.train_dataloader()) == 0
            )
            if no_train_dataloader and train_deferred_setup:
                self.setup_training_data(self._cfg.train_ds)

        if stage in ('fit', 'validate'):
            val_deferred_setup = (
                'validation_ds' in self._cfg
                and self._cfg.validation_ds is not None
                and self._cfg.validation_ds.get('defer_setup', False)
            )
            no_val_dataloader = self.val_dataloader() is None or (
                isinstance(self.val_dataloader(), list) and len(self.val_dataloader()) == 0
            )
            if no_val_dataloader and val_deferred_setup:
                self.setup_multiple_validation_data(val_data_config=self._cfg.validation_ds)

        if stage == 'test':
            test_deferred_setup = (
                'test_ds' in self._cfg
                and self._cfg.test_ds is not None
                and self._cfg.test_ds.get('defer_setup', False)
            )
            no_test_dataloader = self.test_dataloader() is None or (
                isinstance(self.test_dataloader(), list) and len(self.test_dataloader()) == 0
            )
            if no_test_dataloader and test_deferred_setup:
                self.setup_multiple_test_data(test_data_config=self._cfg.test_ds)

        if stage == 'fit':
            CallbackGroup.get_instance().update_config(nemo_version='v1', trainer=self._trainer)

    def train_dataloader(self):
        """
        Return the stored training dataloader, or None when none has been set up.
        """
        return self._train_dl

    def val_dataloader(self):
        """
        Return the stored validation dataloader(s).

        A missing (None) dataloader is replaced by an empty list, which is stored and returned.
        """
        loader = self._validation_dl
        if loader is None:
            # None dataloader no longer supported in PTL2.0
            loader = self._validation_dl = []
        return loader

    def test_dataloader(self):
        """
        Get the test dataloader.
        """
        if self._test_dl is None:
            # None dataloader no longer supported in PTL2.0
            self._test_dl = []

        return self._test_dl

    def on_validation_epoch_end(self, sync_metrics: bool = False) -> Optional[Dict[str, Dict[str, torch.Tensor]]]:
        """
        Default validation epoch-end hook which automatically supports multiple data loaders
        via `multi_validation_epoch_end`.

        The step outputs are read from `self.validation_step_outputs` and cleared once consumed.
        If that list holds dicts it is treated as the output of a single data loader; otherwise
        it is treated as a list of per-data-loader output lists.

        If multi dataset support is not required, override this method entirely in base class.
        In such a case, there is no need to implement `multi_validation_epoch_end` either.

        .. note::
            If more than one data loader exists, and they all provide `val_loss`,
            only the `val_loss` of the data loader at index `self._val_dl_idx` is kept
            un-prefixed. This index can be changed by passing the special key `val_dl_idx: int`
            inside the `validation_ds` config.

        Args:
            sync_metrics: Forwarded as `sync_dist` to `self.log_dict` when logging the merged logs.

        Returns:
            An empty dict when there are no step outputs. For a single data loader, whatever
            `multi_validation_epoch_end` returned (possibly None) with its `log` entry removed.
            For multiple data loaders, a dictionary containing the union of all items from the
            individual data loaders (keys prefixed with the data loader name); the merged
            `log` entries are logged and not returned.
        """
        # Case where we dont provide data loaders (nothing accumulated, or outputs unset)
        if not self.validation_step_outputs:
            return {}

        # Case where we provide exactly 1 data loader
        if isinstance(self.validation_step_outputs[0], dict):
            output_dict = self.multi_validation_epoch_end(self.validation_step_outputs, dataloader_idx=0)

            if output_dict is not None and 'log' in output_dict:
                self.log_dict(output_dict.pop('log'), on_epoch=True, sync_dist=sync_metrics)

            self.validation_step_outputs.clear()  # free memory
            return output_dict

        # Case where we provide more than 1 data loader
        output_dict = {'log': {}}

        # The output is a list of list of dicts, outer list corresponds to dataloader idx
        for dataloader_idx, val_outputs in enumerate(self.validation_step_outputs):
            # Get prefix and dispatch call to multi epoch end
            dataloader_prefix = self.get_validation_dataloader_prefix(dataloader_idx)
            dataloader_logs = self.multi_validation_epoch_end(val_outputs, dataloader_idx=dataloader_idx)

            # If result was not provided, generate empty dict
            dataloader_logs = dataloader_logs or {}

            # Perform `val_loss` resolution first (if provided outside logs)
            if 'val_loss' in dataloader_logs:
                if 'val_loss' not in output_dict and dataloader_idx == self._val_dl_idx:
                    output_dict['val_loss'] = dataloader_logs['val_loss']

            for k, v in dataloader_logs.items():
                if k != 'log':
                    # Values stored outside 'log' are simply stored under the prefixed name
                    output_dict[dataloader_prefix + k] = v
                    continue

                # Parse every element of the log, and attach the prefix name of the data loader
                log_dict = {}
                for k_log, v_log in v.items():
                    if k_log not in output_dict['log'] and dataloader_idx == self._val_dl_idx:
                        # Metric of the primary data loader: keep the un-prefixed key and also
                        # store a prefixed duplicate for ease of comparison / to avoid name clash
                        new_k_log = k_log
                        log_dict[dataloader_prefix + k_log] = v_log
                    else:
                        # Simply prepend prefix to key and save
                        new_k_log = dataloader_prefix + k_log

                    log_dict[new_k_log] = v_log

                # Merge into the global log storage
                output_dict['log'].update(log_dict)

            self.validation_step_outputs[dataloader_idx].clear()  # free memory

        if 'log' in output_dict:
            self.log_dict(output_dict.pop('log'), on_epoch=True, sync_dist=sync_metrics)

        # return everything else
        return output_dict

    def on_test_epoch_end(self) -> Optional[Dict[str, Dict[str, torch.Tensor]]]:
        """
        Default test epoch-end hook which automatically supports multiple data loaders
        via `multi_test_epoch_end`.

        The step outputs are read from `self.test_step_outputs` and cleared once consumed.
        If that list holds dicts it is treated as the output of a single data loader; otherwise
        it is treated as a list of per-data-loader output lists.

        If multi dataset support is not required, override this method entirely in base class.
        In such a case, there is no need to implement `multi_test_epoch_end` either.

        .. note::
            If more than one data loader exists, and they all provide `test_loss`,
            only the `test_loss` of the data loader at index `self._test_dl_idx` is kept
            un-prefixed. This index can be changed by passing the special key `test_dl_idx: int`
            inside the `test_ds` config.

        Returns:
            An empty dict when there are no step outputs. For a single data loader, whatever
            `multi_test_epoch_end` returned (possibly None) with its `log` entry removed.
            For multiple data loaders, a dictionary containing the union of all items from the
            individual data loaders (keys prefixed with the data loader name); the merged
            `log` entries are logged and not returned.
        """
        # Case where we dont provide data loaders (nothing accumulated, or outputs unset)
        if not self.test_step_outputs:
            return {}

        # Case where we provide exactly 1 data loader
        if isinstance(self.test_step_outputs[0], dict):
            output_dict = self.multi_test_epoch_end(self.test_step_outputs, dataloader_idx=0)

            if output_dict is not None and 'log' in output_dict:
                self.log_dict(output_dict.pop('log'), on_epoch=True)

            self.test_step_outputs.clear()  # free memory
            return output_dict

        # Case where we provide more than 1 data loader
        output_dict = {'log': {}}

        # The output is a list of list of dicts, outer list corresponds to dataloader idx
        for dataloader_idx, test_outputs in enumerate(self.test_step_outputs):
            # Get prefix and dispatch call to multi epoch end
            dataloader_prefix = self.get_test_dataloader_prefix(dataloader_idx)
            dataloader_logs = self.multi_test_epoch_end(test_outputs, dataloader_idx=dataloader_idx)

            # If result was not provided, generate empty dict
            dataloader_logs = dataloader_logs or {}

            # Perform `test_loss` resolution first (if provided outside logs)
            if 'test_loss' in dataloader_logs:
                if 'test_loss' not in output_dict and dataloader_idx == self._test_dl_idx:
                    output_dict['test_loss'] = dataloader_logs['test_loss']

            for k, v in dataloader_logs.items():
                if k != 'log':
                    # Values stored outside 'log' are simply stored under the prefixed name
                    output_dict[dataloader_prefix + k] = v
                    continue

                # Parse every element of the log, and attach the prefix name of the data loader
                log_dict = {}
                for k_log, v_log in v.items():
                    if k_log not in output_dict['log'] and dataloader_idx == self._test_dl_idx:
                        # Metric of the primary data loader: keep the un-prefixed key and also
                        # store a prefixed duplicate for ease of comparison / to avoid name clash
                        new_k_log = k_log
                        log_dict[dataloader_prefix + k_log] = v_log
                    else:
                        # Simply prepend prefix to key and save
                        new_k_log = dataloader_prefix + k_log

                    log_dict[new_k_log] = v_log

                # Merge into the global log storage
                output_logs = output_dict.get('log', {})
                output_logs.update(log_dict)
                output_dict['log'] = output_logs

            self.test_step_outputs[dataloader_idx].clear()  # free memory

        if 'log' in output_dict:
            self.log_dict(output_dict.pop('log'), on_epoch=True)

        # return everything else
        return output_dict

    def multi_validation_epoch_end(
        self, outputs: List[Dict[str, torch.Tensor]], dataloader_idx: int = 0
    ) -> Optional[Dict[str, Dict[str, torch.Tensor]]]:
        """
        Per-dataloader validation epoch-end hook, meant to be overridden by subclasses
        so that appropriate logs are produced for each of the dataloaders.

        This base implementation only emits a warning and returns None.

        Args:
            outputs: The validation step outputs collected for a single dataloader.
            dataloader_idx: int representing the index of the dataloader.

        Returns:
            A dictionary of values, optionally containing a sub-dict `log`,
            such that the values in the log will be pre-pended by the dataloader prefix.
        """
        not_implemented_message = "".join(
            (
                "Multi data loader support has been enabled, but ",
                "`multi_validation_epoch_end(outputs, dataloader_idx) has not been implemented.\n",
                "If you require multi data loader support for validation sets, please override this method.\n",
                "If you do not require multi data loader support, please instead override ",
                "`on_validation_epoch_end(outputs).",
            )
        )
        logging.warning(not_implemented_message)

    def multi_test_epoch_end(
        self, outputs: List[Dict[str, torch.Tensor]], dataloader_idx: int = 0
    ) -> Optional[Dict[str, Dict[str, torch.Tensor]]]:
        """
        Adds support for multiple test datasets. Should be overriden by subclass,
        so as to obtain appropriate logs for each of the dataloaders.

        This base implementation only emits a warning and returns None.

        Args:
            outputs: The test step outputs collected for a single dataloader.
            dataloader_idx: int representing the index of the dataloader.

        Returns:
            A dictionary of values, optionally containing a sub-dict `log`,
            such that the values in the log will be pre-pended by the dataloader prefix.
        """
        # The guidance below refers to test sets; it used to mention validation sets,
        # a leftover from the validation counterpart of this method.
        logging.warning(
            "Multi data loader support has been enabled, but "
            "`multi_test_epoch_end(outputs, dataloader_idx) has not been implemented.\n"
            "If you require multi data loader support for test sets, please override this method.\n"
            "If you do not require multi data loader support, please instead override "
            "`on_test_epoch_end(outputs)."
        )

    def get_validation_dataloader_prefix(self, dataloader_idx: int = 0) -> str:
        """
        Look up the name of a validation data loader; this name is prepended to all of its logs.

        Args:
            dataloader_idx: Index of the data loader.

        Returns:
            str name of the data loader at index provided.
        """
        validation_names = self._validation_names
        return validation_names[dataloader_idx]

    def get_test_dataloader_prefix(self, dataloader_idx: int = 0) -> str:
        """
        Look up the name of a test data loader; this name is prepended to all of its logs.

        Args:
            dataloader_idx: Index of the data loader.

        Returns:
            str name of the data loader at index provided.
        """
        test_names = self._test_names
        return test_names[dataloader_idx]

    def load_part_of_state_dict(self, state_dict, include, exclude, load_from_string=None):
        """
        Load a filtered subset of `state_dict` into the model with `strict=False`.

        A key is loaded when it contains at least one string from `include` and no string
        from `exclude`. Every key that contains a string from `exclude` is reported in the log.

        Args:
            state_dict: mapping of parameter names to values.
            include: strings; a key must contain at least one of them to be loaded.
            exclude: strings; a key containing any of them is never loaded.
            load_from_string: optional description of the source, used only in log messages.
        """
        excluded_param_names = []
        dict_to_load = {}
        for param_name, param_value in state_dict.items():
            matches_include = any(pattern in param_name for pattern in include)
            matches_exclude = any(pattern in param_name for pattern in exclude)
            if matches_exclude:
                # exclusion always wins over inclusion
                excluded_param_names.append(param_name)
            elif matches_include:
                dict_to_load[param_name] = param_value

        # Restore checkpoint part into current model
        self.load_state_dict(dict_to_load, strict=False)

        has_exclusions = len(excluded_param_names) > 0
        if load_from_string is None:
            if has_exclusions:
                logging.info(
                    f'The following parameters were excluded when loading checkpoint : {excluded_param_names}'
                )
            return

        logging.info(f'Model checkpoint partially restored from {load_from_string}')
        if has_exclusions:
            logging.info(
                'The following parameters were excluded when loading from '
                f'{load_from_string} : {excluded_param_names}'
            )
            logging.info('Make sure that this is what you wanted!')

    @rank_zero_only
    def maybe_init_from_pretrained_checkpoint(self, cfg: OmegaConf, map_location: str = 'cpu'):
        """
        Initializes a given model with the parameters obtained via specific config arguments.
        The state dict of the provided model will be updated with `strict=False` setting so as to prevent
        requirement of exact model parameters matching.

        Initializations:
            init_from_nemo_model: Str path to a .nemo model in order to load state_dict from single nemo file;
            if loading from multiple files, pass in a dict where the values have the following fields:

                path: Str path to .nemo model

                include: Optional list of strings, at least one of which needs to be contained in parameter name
                to be loaded from this .nemo file. Default: everything is included.

                exclude: Optional list of strings, which can be used to exclude any parameter containing one of
                these strings from being loaded from this .nemo file. Default: nothing is excluded.

                hydra usage example:

                init_from_nemo_model:
                    model0:
                        path:<path/to/model1>
                        include:["encoder"]
                    model1:
                        path:<path/to/model2>
                        include:["decoder"]
                        exclude:["embed"]

            init_from_pretrained_model: Str name of a pretrained model checkpoint (obtained via cloud).
                The model will be downloaded (or a cached copy will be used), instantiated and then
                its state dict will be extracted. If loading from multiple models, you can pass in a dict
                with the same format as for init_from_nemo_model, except with "name" instead of "path"

            init_from_ptl_ckpt: Str name of a Pytorch Lightning checkpoint file. It will be loaded and
                the state dict will extracted. If loading from multiple files, you can pass in a dict
                with the same format as for init_from_nemo_model.

        A key that is present but set to None is treated as not provided.

        Args:
            cfg: The config used to instantiate the model. It need only contain one of the above keys.
            map_location: str or torch.device() which represents where the intermediate state dict
                (from the pretrained model or checkpoint) will be loaded.

        Raises:
            ValueError: if more than one of the initialization keys is set to a non-null value.
            TypeError: if the value of the selected key is neither a string nor a dict.
        """
        args = [
            'init_from_nemo_model',
            'init_from_pretrained_model',
            'init_from_ptl_ckpt',
        ]
        # Count a key only when its VALUE is non-null. Testing the key name itself against None
        # is always true and made `key: null` entries count as provided.
        arg_matches = [(1 if arg in cfg and cfg.get(arg) is not None else 0) for arg in args]

        if sum(arg_matches) == 0:
            # model weights do not need to be restored
            return

        if sum(arg_matches) > 1:
            raise ValueError(
                f"Cannot pass more than one model initialization arguments to config!\n"
                f"Found : {[args[idx] for idx, arg_present in enumerate(arg_matches) if arg_present]}"
            )

        CallbackGroup.get_instance().on_load_checkpoint_start()

        if 'init_from_nemo_model' in cfg and cfg.init_from_nemo_model is not None:
            with open_dict(cfg):
                if isinstance(cfg.init_from_nemo_model, str):
                    model_path = cfg.init_from_nemo_model
                    # Restore model
                    restored_model = self.restore_from(
                        model_path, map_location=map_location, strict=cfg.get("init_strict", True)
                    )
                    # Restore checkpoint into current model
                    self.load_state_dict(restored_model.state_dict(), strict=False)
                    logging.info(f'Model checkpoint restored from nemo file with path : `{model_path}`')
                    del restored_model
                elif isinstance(cfg.init_from_nemo_model, (DictConfig, dict)):
                    model_load_dict = cfg.init_from_nemo_model
                    for model_load_cfg in model_load_dict.values():
                        model_path = model_load_cfg.path
                        # Restore model
                        restored_model = self.restore_from(
                            model_path, map_location=map_location, strict=cfg.get("init_strict", True)
                        )

                        # Default include of [""] matches every parameter name
                        include = model_load_cfg.pop('include', [""])
                        exclude = model_load_cfg.pop('exclude', [])

                        self.load_part_of_state_dict(
                            restored_model.state_dict(), include, exclude, f'nemo file with path `{model_path}`'
                        )

                        del restored_model
                else:
                    raise TypeError("Invalid type: init_from_nemo_model is not a string or a dict!")

        if 'init_from_pretrained_model' in cfg and cfg.init_from_pretrained_model is not None:
            with open_dict(cfg):
                # Restore model

                if isinstance(cfg.init_from_pretrained_model, str):
                    model_name = cfg.pop('init_from_pretrained_model')

                    # Check if model is being resumed or not - only works if `Trainer` is attached to model
                    if hasattr(self, 'trainer') and self.trainer is not None:
                        trainer = self.trainer
                        if (
                            hasattr(trainer, 'resume_from_checkpoint')
                            and trainer._checkpoint_connector.resume_checkpoint_path is not None
                        ):
                            logging.info(
                                "Model training is being resumed via Pytorch Lightning.\n"
                                "Initialization from pretrained model (via cloud) will be skipped."
                            )
                            # NOTE(review): this early return skips `on_load_checkpoint_end()` although
                            # `on_load_checkpoint_start()` was already signalled - confirm this is intended.
                            return

                    restored_model = self.from_pretrained(
                        model_name, map_location=map_location, strict=cfg.get("init_strict", True)
                    )

                    # Restore checkpoint into current model
                    self.load_state_dict(restored_model.state_dict(), strict=False)
                    logging.info(f'Model checkpoint restored from pretrained checkpoint with name : `{model_name}`')

                    del restored_model
                elif isinstance(cfg.init_from_pretrained_model, (DictConfig, dict)):
                    model_load_dict = cfg.init_from_pretrained_model
                    for model_load_cfg in model_load_dict.values():
                        model_name = model_load_cfg.name
                        # Restore model
                        restored_model = self.from_pretrained(
                            model_name, map_location=map_location, strict=cfg.get("init_strict", True)
                        )

                        include = model_load_cfg.pop('include', [""])
                        exclude = model_load_cfg.pop('exclude', [])

                        self.load_part_of_state_dict(
                            restored_model.state_dict(),
                            include,
                            exclude,
                            f'pretrained checkpoint with name `{model_name}`',
                        )

                        del restored_model
                else:
                    raise TypeError("Invalid type: init_from_pretrained_model is not a string or a dict!")

        if 'init_from_ptl_ckpt' in cfg and cfg.init_from_ptl_ckpt is not None:
            with open_dict(cfg):
                if isinstance(cfg.init_from_ptl_ckpt, str):
                    # Restore checkpoint
                    ckpt_path = cfg.pop('init_from_ptl_ckpt')
                    # NOTE(security): torch.load unpickles the file; only use checkpoints from trusted sources
                    ckpt = torch.load(ckpt_path, map_location=map_location)

                    # Restore checkpoint into current model
                    self.load_state_dict(ckpt['state_dict'], strict=False)
                    logging.info(
                        f'Model checkpoint restored from pytorch lightning checkpoint with path : `{ckpt_path}`'
                    )

                    del ckpt
                elif isinstance(cfg.init_from_ptl_ckpt, (DictConfig, dict)):
                    model_load_dict = cfg.init_from_ptl_ckpt
                    for model_load_cfg in model_load_dict.values():
                        ckpt_path = model_load_cfg.path
                        # Restore checkpoint
                        # NOTE(security): torch.load unpickles the file; only use checkpoints from trusted sources
                        ckpt = torch.load(ckpt_path, map_location=map_location)

                        include = model_load_cfg.pop('include', [""])
                        exclude = model_load_cfg.pop('exclude', [])

                        # Describe the source as a Lightning checkpoint (it was mislabelled as a nemo file)
                        self.load_part_of_state_dict(
                            ckpt['state_dict'],
                            include,
                            exclude,
                            f'pytorch lightning checkpoint with path `{ckpt_path}`',
                        )

                        del ckpt
                else:
                    raise TypeError("Invalid type: init_from_ptl_ckpt is not a string or a dict!")

        # Track load checkpoint end
        CallbackGroup.get_instance().on_load_checkpoint_end()

    def teardown(self, stage: str):
        """
        Lightning hook invoked when a stage (for example fit or test) has finished.

        After a ``'fit'`` stage the ``PL_TRAINER_GPUS`` environment variable is dropped,
        then the parent implementation is invoked.

        Args:
            stage: name of the stage being torn down, e.g. 'fit' or 'test'.
        """
        if stage == 'fit':
            # Temporary workaround: if training ran on multiple GPUs, a later trainer.test()
            # would still attempt DDP even with a brand new single-GPU Trainer. Removing the
            # env variable avoids that; a missing variable is simply ignored.
            os.environ.pop('PL_TRAINER_GPUS', None)

        super().teardown(stage)

    @classmethod
    def extract_state_dict_from(
        cls,
        restore_path: str,
        save_dir: str,
        split_by_module: bool = False,
        save_restore_connector: SaveRestoreConnector = None,
    ):
        """
        Pull the state dict(s) out of a .nemo tarfile and write them into a directory.

        Args:
            restore_path: path to the .nemo file whose state dict(s) should be extracted.
            save_dir: directory in which the extracted state dict(s) are stored.
            split_by_module: when True, one checkpoint is written per module of the Model;
                when False, a single checkpoint is written for the entire Model.
            save_restore_connector (SaveRestoreConnector): can be overridden to add custom save and
                restore logic. A default ``SaveRestoreConnector()`` is created when omitted.

        Example:
            Convert the .nemo tarfile into a single Model level PyTorch checkpoint
            ::
            state_dict = nemo.collections.asr.models.EncDecCTCModel.extract_state_dict_from('asr.nemo', './asr_ckpts')


            Restore a model from a Model level checkpoint
            ::
            model = nemo.collections.asr.models.EncDecCTCModel(cfg)  # or any other method of restoration
            model.load_state_dict(torch.load("./asr_ckpts/model_weights.ckpt"))


            Convert the .nemo tarfile into multiple Module level PyTorch checkpoints
            ::
            state_dict = nemo.collections.asr.models.EncDecCTCModel.extract_state_dict_from(
                            'asr.nemo',
                            './asr_ckpts',
                            split_by_module=True
                        )


            Restore individual modules from Module level checkpoints
            ::
            model = nemo.collections.asr.models.EncDecCTCModel(cfg)  # or any other method of restoration

            # load the individual components
            model.preprocessor.load_state_dict(torch.load("./asr_ckpts/preprocessor.ckpt"))
            model.encoder.load_state_dict(torch.load("./asr_ckpts/encoder.ckpt"))
            model.decoder.load_state_dict(torch.load("./asr_ckpts/decoder.ckpt"))

        Returns:
            The state dict that was loaded from the original .nemo checkpoint.

        Raises:
            FileExistsError: if `restore_path` does not exist.
        """
        connector = SaveRestoreConnector() if save_restore_connector is None else save_restore_connector

        if not path.exists(restore_path):
            raise FileExistsError(f"Can't find {restore_path}")

        # Register the connector on the class, then delegate the extraction to it.
        cls.update_save_restore_connector(connector)
        return cls._save_restore_connector.extract_state_dict_from(restore_path, save_dir, split_by_module)

    def prepare_test(self, trainer: 'Trainer') -> bool:
        """
        Helper method to check whether the model can safely be tested
        on a dataset after training (or loading a checkpoint).

        ::

            trainer = Trainer()
            if model.prepare_test(trainer):
                trainer.test(model)

        Args:
            trainer: the Trainer intended for testing. May be None, in which case no
                device-count check is performed.

        Returns:
            bool which declares the model safe to test. Logs a message if it has to
            return False to guide the user. When True is returned, `trainer` has been
            attached to the model via `set_trainer`.
        """
        if not hasattr(self._cfg, 'test_ds'):
            logging.info("No `test_ds` config found within the manifest.")
            return False

        # Replace ddp multi-gpu until PTL has a fix.
        # Built from adjacent string literals so the logged text contains neither stray
        # quote characters nor source-code indentation.
        DDP_WARN = (
            "\n\nDuring testing, it is currently advisable to construct a new Trainer "
            "with single GPU and no DDP to obtain accurate results.\n"
            "Following pattern should be used:\n"
            "trainer = Trainer(devices=1, accelerator='gpu')\n"
            "if model.prepare_test(trainer):\n"
            "  trainer.test(model)\n\n"
        )

        if trainer is not None and trainer.num_devices > 1:
            logging.warning(DDP_WARN)
            return False

        # Assign trainer to the model
        self.set_trainer(trainer)
        return True

    def set_trainer(self, trainer: Trainer):
        """
        Attach a Trainer object to this model and refresh the world size derived from it.

        Args:
            trainer: PyTorch Lightning Trainer object.
        """
        # Store the trainer both through the `trainer` attribute and in the private `_trainer` slot
        self.trainer = trainer
        self._trainer = trainer
        # Recompute `world_size` (and the AppState copy) from the newly attached trainer
        self.set_world_size(trainer)

    def set_world_size(self, trainer: Trainer):
        """
        Derive the world size from a PyTorch Lightning Trainer and publish it to AppState.

        The world size defaults to 1. It becomes ``num_devices * num_nodes`` only when
        `trainer` is a `Trainer` instance and both values are truthy. Any other non-None
        object triggers a warning and leaves the default in place.

        Args:
            trainer (Trainer): PyTorch Lightning Trainer object
        """
        # Start from the single-process default
        self.world_size = 1

        if trainer is not None:
            if not isinstance(trainer, Trainer):
                logging.warning('World size can only be set by PyTorch Lightning Trainer.')
            elif trainer.num_devices and trainer.num_nodes:
                self.world_size = trainer.num_devices * trainer.num_nodes

        # Update AppState with world information from trainer
        state = AppState()
        state.world_size = self.world_size

    def summarize(self, max_depth: int = 1) -> model_summary.ModelSummary:
        """Summarize this LightningModule.

        Thin wrapper that forwards to ``model_summary.summarize``.

        Args:
            max_depth: The maximum depth of layer nesting that the summary will include. A value of 0 turns the
                layer summary off. Default: 1.

        Returns:
            The model summary object
        """
        return model_summary.summarize(self, max_depth=max_depth)

    def _update_dataset_config(self, dataset_name: str, config: Optional[Union[DictConfig, Dict]]):
        """
        Store `config` as the dataset config for `dataset_name` inside the model config.

        Nothing happens when the model is in multi-dataset mode or when `config` is None.

        Args:
            dataset_name: str name of the dataset whose config is being updated.
                Can be one of `train`, `validation` and `test`.
            config: Optional DictConfig or dict. If None is passed, this method simply returns.
                If dict is passed, it is cast into a DictConfig.
                The internal config is updated with the passed config under the key
                ``<dataset_name>_ds``.

        Raises:
            ValueError: if `dataset_name` is not one of the supported names.
        """
        if getattr(self, '_multi_dataset_mode', None) is True:
            return

        if config is None:
            return

        if not isinstance(config, DictConfig):
            config = OmegaConf.create(config)

        if dataset_name not in ('train', 'validation', 'test'):
            raise ValueError("`dataset_name` when updating config must be one of [train, validation, test]")

        # Temporarily lift struct mode so that a new key may be written
        OmegaConf.set_struct(self.cfg, False)
        self.cfg[f"{dataset_name}_ds"] = config
        OmegaConf.set_struct(self.cfg, True)

        # Update hyper parameters by calling property setter
        self.cfg = self._cfg

    @property
    def num_weights(self):
        """
        Utility property that returns the number of trainable parameters of the Model,
        i.e. the summed element count of every parameter with ``requires_grad`` set.
        """
        return sum(param.numel() for param in self.parameters() if param.requires_grad)

    @property
    def cfg(self):
        """
        Property that holds the finalized internal config of the model (stored in `self._cfg`).

        Note:
            Changes to this config are not reflected in the state of the model.
            Please create a new model using an updated config to properly update the model.
        """
        return self._cfg

    @LightningModule.trainer.getter
    def trainer(self):
        """
        Get the trainer object.

        Replaces the getter of the `LightningModule.trainer` property so that the value
        stored in `self._trainer` is returned as-is.
        """
        return self._trainer

    @cfg.setter
    def cfg(self, cfg):
        """
        Replace the internal config of the model and refresh the stored hyper parameters.

        Args:
            cfg: the new config object to keep in `self._cfg`.

        Note:
            Changes to this config are not reflected in the state of the model.
            Please create a new model using an updated config to properly update the model.
        """
        self._cfg = cfg
        wrapped = OmegaConf.create({'cfg': self._cfg})
        self._set_hparams(wrapped)

        # TODO: Remove in NeMo 1.7 (or when PTL fixes this on their end)
        if not hasattr(self, '_hparams_initial'):
            return
        if 'cfg' in self._hparams_initial:
            self._hparams_initial['cfg'] = OmegaConf.to_object(self._cfg)

    @property
    def hparams(self):
        """
        Override of the default hparams property that always reflects the latest model config.

        Without this override, hparams would return a stale config whenever `self._cfg` was
        changed directly (e.g., in self.setup_optimization()) instead of via `self.cfg = new_cfg`.
        """
        # Re-register the current config as the hyper parameters on every access
        self._set_hparams(OmegaConf.create({'cfg': self._cfg}))

        if hasattr(self, '_hparams_initial'):
            initial = self._hparams_initial
            # Convert a DictConfig stored under 'cfg' via OmegaConf.to_object
            if 'cfg' in initial and isinstance(initial['cfg'], DictConfig):
                initial['cfg'] = OmegaConf.to_object(initial['cfg'])

        return super().hparams

    @property
    def validation_step_outputs(self):
        """
        Cached outputs of validation_step, created lazily on first access.

        The cache is a flat list for a single data loader, or a list holding one empty
        list per data loader when several validation data loaders are present.

        Returns:
            List of outputs of validation_step.
        """
        cached = self._validation_step_outputs
        if cached is not None:
            return cached

        # Initialize new output list
        outputs = []
        self._validation_step_outputs = outputs

        # Only a list/tuple with more than one entry counts as "multiple data loaders":
        # a single dataloader can arrive wrapped as [<Dataloader obj>] when ds_item in
        # the config has 1 item passed in a list.
        loaders = self._validation_dl
        if isinstance(loaders, (list, tuple)) and len(loaders) > 1:
            outputs.extend([] for _ in range(len(loaders)))

        return self._validation_step_outputs

    @validation_step_outputs.setter
    def validation_step_outputs(self, value):
        """Replace the cached validation_step outputs; assigning None makes the getter rebuild the cache."""
        self._validation_step_outputs = value

    @property
    def test_step_outputs(self):
        """
        Cached outputs of test_step, created lazily on first access.

        The cache is a flat list for a single data loader, or a list holding one empty
        list per data loader when several test data loaders are present.

        Returns:
            List of outputs of test_step.
        """
        cached = self._test_step_outputs
        if cached is not None:
            return cached

        # Initialize new output list
        outputs = []
        self._test_step_outputs = outputs

        # Only a list/tuple with more than one entry counts as "multiple data loaders":
        # a single dataloader can arrive wrapped as [<Dataloader obj>] when ds_item in
        # the config has 1 item passed in a list.
        loaders = self._test_dl
        if isinstance(loaders, (list, tuple)) and len(loaders) > 1:
            outputs.extend([] for _ in range(len(loaders)))

        return self._test_step_outputs

    @test_step_outputs.setter
    def test_step_outputs(self, value):
        """Replace the cached test_step outputs; assigning None makes the getter rebuild the cache."""
        self._test_step_outputs = value

    @staticmethod
    def _is_model_being_restored() -> bool:
        """Return the `is_model_being_restored` flag read from `AppState()`."""
        return AppState().is_model_being_restored

    @staticmethod
    def _set_model_restore_state(is_being_restored: bool, folder: str = None):
        """
        Record the restoration status on `AppState()`.

        Args:
            is_being_restored: value written to `is_model_being_restored`.
            folder: value written to `nemo_file_folder` (None when omitted).
        """
        state = AppState()
        state.is_model_being_restored = is_being_restored
        state.nemo_file_folder = folder

    def _set_model_guid(self):
        """
        Assign a unique `model_guid` to this instance (once) and register it with AppState.

        When the model is currently being restored, the AppState's `model_restore_path`
        is registered alongside the guid; otherwise the restoration path is None.
        """
        if hasattr(self, 'model_guid'):
            # A guid was already assigned; never regenerate it
            return

        app_state = AppState()

        # Generate a unique uuid for the instance
        self.model_guid = str(uuid.uuid4())

        # Also determine if the model is being restored or not, and preserve the path
        restore_path = app_state.model_restore_path if self._is_model_being_restored() else None
        app_state.register_model_guid(self.model_guid, restoration_path=restore_path)

    @classmethod
    def update_save_restore_connector(cls, save_restore_connector):
        """
        Update the save_restore_connector for the model.

        The connector is stored on the class itself as `_save_restore_connector`,
        whether or not such an attribute was already present.
        """
        cls._save_restore_connector = save_restore_connector

    def _setup_chakra_profiling(self):
        """Enables chakra profiling
        To use, add the following options to the model config:
        ## Chakra profiling options
        chakra_profile:
            enabled: False
            start_step: 2  # Global step at which profiling starts
            end_step: 2 # Global step at which profiling ends
            warmup_steps: 0  # `warmup` value of the torch profiler schedule
            active_steps: 1  # `active` value of the torch profiler schedule
            trace_dir: None # Existing directory under which the trace folders are created

        Nothing is changed on the model when the `chakra_profile` section is missing or
        not enabled. The whole configuration is validated before any directory is created
        or any attribute is set, so an invalid config leaves no partial state behind.

        Raises:
            ValueError: if `trace_dir` is unset or not an existing directory, if
                `start_step`/`end_step` are not ints, or if `end_step < start_step`.
            Exception: if nsys profiling is enabled at the same time.
        """
        chakra_cfg = self.cfg.get('chakra_profile', None)
        if chakra_cfg is None or not chakra_cfg.get('enabled', False):
            return

        start_step = chakra_cfg.get('start_step', 0)
        end_step = chakra_cfg.get('end_step', 0)
        trace_dir = chakra_cfg.get('trace_dir', None)
        warmup_steps = chakra_cfg.get('warmup_steps', 0)
        active_steps = chakra_cfg.get('active_steps', 1)

        if trace_dir is None or not os.path.isdir(trace_dir):
            raise ValueError(f'chakra profile output path ({trace_dir}) is not set or does not exist.')

        if not isinstance(start_step, int):
            raise ValueError(f'chakra start_step must be of type int. Found: {type(start_step)}')
        logging.info(f'chakra profiling setup with start_step: {start_step}')

        if not isinstance(end_step, int):
            raise ValueError(f'chakra end_step must be of type int. Found: {type(end_step)}')
        logging.info(f'chakra profiling setup with end_step: {end_step}')

        if end_step < start_step:
            raise ValueError('chakra end_step must be greater than or equal to chakra start_step')

        # The `nsys_profile` section is optional: look it up defensively so that a config
        # without it does not fail here.
        nsys_cfg = self.cfg.get('nsys_profile', None)
        if nsys_cfg is not None and nsys_cfg.get('enabled', False):
            raise Exception(
                "Profiler conflict: Chakra profiling and Nsys profiling cannot be enabled at the same time."
            )

        # Imported lazily so they are only required when chakra profiling is actually enabled
        from torch.profiler import ExecutionTraceObserver

        from nemo.utils.env_var_parsing import get_envint

        job_id = get_envint("SLURM_JOB_ID", 0)
        trace_root = Path(trace_dir)

        self._chakra_profile_start_step = start_step
        self._chakra_profile_end_step = end_step
        self._chakra_trace_dir = trace_root / f'{job_id}_chakra'
        self._kineto_trace_dir = trace_root / f'{job_id}_kineto'

        self._chakra_trace_dir.mkdir(parents=True, exist_ok=True)
        self._kineto_trace_dir.mkdir(parents=True, exist_ok=True)

        self._et = ExecutionTraceObserver()
        self._prof = torch.profiler.profile(
            activities=[
                torch.profiler.ProfilerActivity.CPU,
                torch.profiler.ProfilerActivity.CUDA,
            ],
            schedule=torch.profiler.schedule(wait=0, warmup=warmup_steps, active=active_steps),
            execution_trace_observer=self._et,
        )

        # Flip the switch last, once everything the training hooks rely on exists
        self._chakra_profile_enabled = True

    def _setup_profiling(self):
        """Enables nsys profiling
        To use, add the following options to the model config:
        ## Nsys profiling options
        nsys_profile:
            enabled: False
            start_step: 10  # Global batch to start profiling
            end_step: 10 # Global batch to end profiling
            ranks: [0] # Global rank IDs to profile
            gen_shape: False # Generate model and kernel details including input shapes
        And then wrap the model training script with:
        nsys profile -s none -o <profile filepath>  -t cuda,nvtx --force-overwrite true
            --capture-range=cudaProfilerApi --capture-range-end=stop python ./examples/...
        See more options at: https://docs.nvidia.com/nsight-systems/UserGuide/index.html#cli-profiling

        Enables CUDA memory profiling
        To use, add the following options to the model config:
        ## CUDA memory profiling options
        memory_profile:
            enabled: True
            start_step: 10  # Global batch to start profiling
            end_step: 10 # Global batch to end profiling
            rank: 0 # Global rank ID to profile
            output_path: None # Path to store the profile output file

        Each section is optional and is only processed when its `enabled` flag is truthy.

        Raises:
            ValueError: if a `start_step`/`end_step` value is not exactly of type int, if
                `end_step < start_step`, or if the memory profile `output_path` is unset
                or not an existing directory.
        """
        nsys_cfg = self.cfg.get('nsys_profile', None)
        if nsys_cfg is not None and nsys_cfg.get('enabled', False):
            # Nsys profiling options
            self._nsys_profile_enabled = True
            self._nsys_profile_start_step = nsys_cfg.get('start_step', 0)
            self._nsys_profile_end_step = nsys_cfg.get('end_step', 0)
            self._nsys_profile_ranks = nsys_cfg.get('ranks', [0])
            self._nsys_profile_gen_shape = nsys_cfg.get('gen_shape', False)

            # Exact type match on purpose: bool is a subclass of int and must be rejected
            if type(self._nsys_profile_start_step) is not int:
                raise ValueError(f'Nsys start_step must be of type int. Found: {type(self._nsys_profile_start_step)}')
            logging.info(f'Nsys profiling setup with start_step: {self._nsys_profile_start_step}')

            if type(self._nsys_profile_end_step) is not int:
                raise ValueError(f'Nsys end_step must be of type int. Found: {type(self._nsys_profile_end_step)}')
            logging.info(f'Nsys profiling setup with end_step: {self._nsys_profile_end_step}')

            if self._nsys_profile_end_step < self._nsys_profile_start_step:
                raise ValueError('Nsys end_step must be greater than or equal to nsys start_step')

        memory_cfg = self.cfg.get('memory_profile', None)
        if memory_cfg is not None and memory_cfg.get('enabled', False):
            # CUDA memory profiling options
            self._memory_profile_enabled = True
            self._memory_profile_start_step = memory_cfg.get('start_step', 0)
            self._memory_profile_end_step = memory_cfg.get('end_step', 0)
            self._memory_profile_rank = memory_cfg.get('rank', 0)
            self._memory_profile_output_path = memory_cfg.get('output_path', None)

            if type(self._memory_profile_start_step) is not int:
                raise ValueError(
                    f'CUDA memory start_step must be of type int. Found: {type(self._memory_profile_start_step)}'
                )
            # This message previously claimed to be about Nsys profiling
            logging.info(f'CUDA memory profiling setup with start_step: {self._memory_profile_start_step}')

            if type(self._memory_profile_end_step) is not int:
                raise ValueError(
                    f'CUDA memory end_step must be of type int. Found: {type(self._memory_profile_end_step)}'
                )
            logging.info(f'CUDA memory profiling setup with end_step: {self._memory_profile_end_step}')

            if self._memory_profile_end_step < self._memory_profile_start_step:
                raise ValueError('CUDA memory end_step must be greater than or equal to memory start_step')

            if self._memory_profile_output_path is None or not os.path.isdir(self._memory_profile_output_path):
                raise ValueError(
                    f'Memory profile output path ({self._memory_profile_output_path}) is not set '
                    'or does not exist.'
                )

    def on_train_start(self):
        """PyTorch Lightning hook:
        https://pytorch-lightning.readthedocs.io/en/stable/common/lightning_module.html#on-train-start
        Used here to snapshot the `freeze_updates` config into `self._freeze_cfg` for dynamic freezing.

        `_freeze_cfg` is set to None when dynamic freezing is not configured or not enabled.
        An already existing `_freeze_cfg` is left untouched.
        """
        # dynamic freezing
        # should fire only once, on the very first batch of training and never again
        if hasattr(self, '_freeze_cfg'):
            return

        freeze_updates = getattr(self.cfg, 'freeze_updates', None)
        if freeze_updates is None or not freeze_updates.get('enabled', False):
            self._freeze_cfg = None
            return

        # Work on a plain-container copy and track the frozen state of every listed module
        self._freeze_cfg = OmegaConf.to_container(freeze_updates)
        self._freeze_cfg['is_frozen'] = dict.fromkeys(self._freeze_cfg['modules'], False)

    def on_train_batch_start(self, batch: Any, batch_idx: int, unused: int = 0) -> Optional[int]:
        """PyTorch Lightning hook:
        https://pytorch-lightning.readthedocs.io/en/stable/common/lightning_module.html#on-train-batch-start
        Used here to start profiling (chakra / nsys / CUDA memory) and to apply dynamic freezing.

        Chakra profiling is gated on `self.trainer.global_step`, whereas nsys and CUDA memory
        profiling are gated on `batch_idx`. Profilers are only started on CUDA devices.
        """
        if self.device.type == 'cuda':
            if getattr(self, '_chakra_profile_enabled', False) and not self._chakra_profile_in_progress:
                if (
                    self.trainer.global_step >= self._chakra_profile_start_step
                    and self.trainer.global_step <= self._chakra_profile_end_step
                ):
                    logging.info(f"====== Start chakra profiling from global_step {self.trainer.global_step} ======")
                    self._et.register_callback(str(self._chakra_trace_dir / f'rank-{get_rank()}.json'))
                    self._prof.start()
                    self._chakra_profile_in_progress = True

            if getattr(self, '_nsys_profile_enabled', False) and not self._nsys_profile_started:
                if batch_idx >= self._nsys_profile_start_step and get_rank() in self._nsys_profile_ranks:
                    logging.info("====== Start nsys profiling ======")
                    torch.cuda.cudart().cudaProfilerStart()
                    if self._nsys_profile_gen_shape:
                        torch.autograd.profiler.emit_nvtx(record_shapes=True).__enter__()
                    self._nsys_profile_started = True

            if getattr(self, '_memory_profile_enabled', False) and not self._memory_profile_started:
                if batch_idx >= self._memory_profile_start_step and get_rank() == self._memory_profile_rank:
                    logging.info("====== Start CUDA memory profiling ======")
                    torch.cuda.memory._record_memory_history(max_entries=100000)
                    self._memory_profile_started = True

        # dynamic freezing
        freeze_cfg = getattr(self, '_freeze_cfg', None)
        if freeze_cfg is None:
            return
        if not (self.training and hasattr(self, "trainer") and self.trainer is not None):
            return

        num_updates = self.trainer.global_step + 1
        frozen_state = self._freeze_cfg['is_frozen']

        for module_name, schedule in self._freeze_cfg['modules'].items():
            # we could do hasattr check here, but it's too expensive for each step
            # consequently you'll throw an error if the module name doesn't exist
            # or was spelled wrong in the config.yaml
            if isinstance(schedule, list):
                # [first, last] window of updates; last == -1 means "no upper bound"
                assert len(schedule) == 2, "freeze_updates modules list cannot have more than two elements"
                should_freeze = (num_updates >= schedule[0]) and (num_updates <= schedule[1] or schedule[1] == -1)
            else:
                # scalar: frozen up to and including that update; -1 means "always"
                should_freeze = num_updates <= schedule or schedule == -1

            currently_frozen = frozen_state[module_name]
            if should_freeze and not currently_frozen:
                getattr(self, module_name).freeze()
                getattr(self, module_name).train()
                frozen_state[module_name] = True
            elif not should_freeze and currently_frozen:
                getattr(self, module_name).unfreeze()
                frozen_state[module_name] = False

    def on_train_batch_end(self, outputs, batch: Any, batch_idx: int, unused: int = 0) -> None:
        """PyTorch Lightning hook:
        https://pytorch-lightning.readthedocs.io/en/stable/common/lightning_module.html#on-train-batch-end
        Used here to step/stop chakra profiling and to stop nsys and CUDA memory profiling.

        Nothing happens unless the model lives on a CUDA device.
        """
        if self.device.type != 'cuda':
            return

        if getattr(self, '_chakra_profile_enabled', False) and self._chakra_profile_in_progress:
            # self.trainer.global_step is increased before on_train_batch_end
            if self.trainer.global_step - 1 >= self._chakra_profile_end_step:
                logging.info(f"====== End chakra profiling at global_step {self.trainer.global_step} ======")
                self._prof.stop()
                self._prof.export_chrome_trace(str(self._kineto_trace_dir / f'rank-{get_rank()}.json'))
                self._et.unregister_callback()
                self._chakra_profile_in_progress = False
            elif self.trainer.global_step - 1 >= self._chakra_profile_start_step:
                self._prof.step()

        if getattr(self, '_nsys_profile_enabled', False) and not self._nsys_profile_complete:
            if batch_idx >= self._nsys_profile_end_step and get_rank() in self._nsys_profile_ranks:
                logging.info("====== End nsys profiling ======")
                torch.cuda.cudart().cudaProfilerStop()
                self._nsys_profile_complete = True

        if getattr(self, '_memory_profile_enabled', False) and not self._memory_profile_complete:
            if batch_idx >= self._memory_profile_end_step and get_rank() == self._memory_profile_rank:
                logging.info("====== End CUDA memory profiling ======")
                # Dump the recorded history first, then switch recording off
                snapshot_file = (
                    f'{self._memory_profile_output_path}/memory_profile_rank{self._memory_profile_rank}.pickle'
                )
                torch.cuda.memory._dump_snapshot(snapshot_file)
                torch.cuda.memory._record_memory_history(enabled=None)
                self._memory_profile_complete = True

    def _cleanup_on_execution_end(self):
        """
        Reset per-run module state once an execution (train / test / predict) has ended.

        Removes the dynamic freezing config if present and clears the cached
        validation and test step outputs.
        """
        # dynamic freezing cleanup
        if hasattr(self, '_freeze_cfg'):
            del self._freeze_cfg

        # Clear up the val and test output caches
        self._validation_step_outputs = self._test_step_outputs = None

    def on_train_end(self):
        """PyTorch Lightning hook:
        https://pytorch-lightning.readthedocs.io/en/stable/common/lightning_module.html#on-train-end
        We use it here to cleanup the dynamic freezing config and the cached
        validation / test step outputs.
        """

        self._cleanup_on_execution_end()

    def on_test_end(self):
        """PyTorch Lightning hook:
        https://pytorch-lightning.readthedocs.io/en/stable/common/lightning_module.html#on-test-end
        We use it here to run the shared end-of-execution cleanup.
        """

        self._cleanup_on_execution_end()

    def on_predict_end(self):
        """PyTorch Lightning hook:
        https://pytorch-lightning.readthedocs.io/en/stable/common/lightning_module.html#on-predict-end
        We use it here to run the shared end-of-execution cleanup.
        """

        self._cleanup_on_execution_end()

    def _optim_config_copy(self, optim_config: Optional[Union[DictConfig, Dict]]) -> Optional[DictConfig]:
        """
        Produce an independent DictConfig copy of an optimizer config.

        Args:
            optim_config: the config to copy. When None, the `optim` entry of the
                internal model config is used instead, if there is one.

        Returns:
            A deep copy when the source is a DictConfig, a newly created DictConfig when it
            is anything else, or None when no optimizer config is available at all.
        """
        source = optim_config

        # See if internal config has `optim` namespace
        if source is None and self._cfg is not None and hasattr(self._cfg, 'optim'):
            source = self._cfg.optim

        if source is None:
            return None

        return copy.deepcopy(source) if isinstance(source, DictConfig) else OmegaConf.create(source)
