# Copyright (c) 2025, NVIDIA CORPORATION.  All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import warnings
from copy import deepcopy

import fiddle as fdl
import lightning.pytorch as pl
from lightning.pytorch.loops import _TrainingEpochLoop
from lightning.pytorch.loops.fetchers import _DataFetcher
from typing_extensions import Self

from nemo.lightning.fabric.conversion import to_fabric
from nemo.lightning.fabric.fabric import Fabric
from nemo.lightning.io.mixin import IOMixin, serialization, track_io


class NoValOnRestartTrainingLoop(_TrainingEpochLoop):
    """
    Extend the PTL Epoch loop to skip validation when restarting.
    This happens when resuming a checkpoint that has already run validation, but loading restores
    the training state before validation has run.
    """

    def _should_check_val_fx(self, data_fetcher: _DataFetcher) -> bool:
        """Suppress the pending validation check while the restart skip is armed."""
        if self.skip_val_on_restart:
            return False
        return super()._should_check_val_fx(data_fetcher)

    def load_state_dict(self, state_dict: dict, prefix: str = "") -> None:
        super().load_state_dict(state_dict, prefix)

        # Loop state is only restored when resuming from a checkpoint, so arm the
        # one-shot validation skip here.
        self.skip_val_on_restart = True

    def advance(self, data_fetcher: _DataFetcher) -> None:
        super().advance(data_fetcher)

        # Once a batch has actually run, validation checks follow the normal schedule again.
        self.skip_val_on_restart = False
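
# A minimal timeline sketch of the one-shot flag above (hypothetical resume;
# Lightning drives these calls in this order when restoring mid-run):
#
#     loop.load_state_dict(loop_state)  # arms skip_val_on_restart
#     loop.advance(fetcher)             # pending validation suppressed, flag disarmed
#     loop.advance(fetcher)             # later batches validate on the normal schedule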


def configure_no_restart_validation_training_loop(trainer: pl.Trainer) -> None:
    """Swap the trainer's stock epoch loop for one that skips validation on restart."""
    if type(trainer.fit_loop.epoch_loop) is not _TrainingEpochLoop:
        warnings.warn("Detected custom epoch loop. Skipping no validation on restart support.", UserWarning)
        return

    # Pass the trainer to the constructor so the new loop's trainer reference is not left as None.
    loop = NoValOnRestartTrainingLoop(trainer, trainer.min_steps, trainer.max_steps)
    trainer.fit_loop.epoch_loop = loop
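
# A minimal usage sketch (``model`` and the checkpoint are illustrative, not part
# of this module):
#
#     trainer = pl.Trainer(max_steps=1000, val_check_interval=100)
#     configure_no_restart_validation_training_loop(trainer)
#     trainer.fit(model, ckpt_path="last")  # resumes without re-running validation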


class Trainer(pl.Trainer, IOMixin):
    """A PTL ``Trainer`` that captures its constructor arguments so it can be re-serialized."""

    def add_io(self, obj):
        """Recurse to the leaves of a container and add io functionality to non-serializable leaves."""
        if isinstance(obj, (dict, list)):
            for item in obj.values() if isinstance(obj, dict) else obj:
                self.add_io(item)
        elif not serialization.find_node_traverser(type(obj)):
            track_io(type(obj))

    def io_init(self, **kwargs) -> fdl.Config[Self]:
        """Capture the trainer's constructor arguments as a fiddle ``Config``."""
        # Each argument of the trainer can be stateful, so deep-copy them to keep
        # the captured config independent of later mutation.
        cfg_kwargs = {k: deepcopy(v) for k, v in kwargs.items()}

        self.add_io(cfg_kwargs)
        return fdl.Config(type(self), **cfg_kwargs)
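
    # A round-trip sketch of what ``io_init`` enables (kwargs illustrative; ``__io__``
    # is the config IOMixin captures at construction time):
    #
    #     trainer = Trainer(max_steps=100)
    #     rebuilt = fdl.build(trainer.__io__)  # an equivalent, freshly built Trainer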

    def to_fabric(self, callbacks=None, loggers=None) -> Fabric:
        """Build a ``Fabric`` instance mirroring this trainer's captured configuration."""
        accelerator, devices, strategy, plugins, num_nodes = None, None, None, None, None
        if hasattr(self.__io__, "devices"):
            devices = self.__io__.devices
        if hasattr(self.__io__, "accelerator"):
            accelerator = self.__io__.accelerator
        if hasattr(self.__io__, "strategy"):
            strategy = self.__io__.strategy
            if isinstance(strategy, fdl.Config):
                strategy = fdl.build(strategy)
            # Convert the trainer-side strategy to its Fabric counterpart.
            strategy = to_fabric(strategy)
        if hasattr(self.__io__, "plugins"):
            plugins = self.__io__.plugins
            if isinstance(plugins, fdl.Config):
                plugins = fdl.build(plugins)
            plugins = to_fabric(plugins)

        if hasattr(self.__io__, "num_nodes"):
            num_nodes = self.__io__.num_nodes

        return Fabric(
            devices=devices,
            accelerator=accelerator,
            strategy=strategy,
            plugins=plugins,
            callbacks=callbacks,
            loggers=loggers,
            num_nodes=num_nodes,
        )
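

# A minimal usage sketch for ``Trainer.to_fabric`` (assumes the captured strategy
# and plugins have registered Fabric conversions; names are illustrative):
#
#     trainer = Trainer(devices=2, accelerator="gpu", strategy=some_strategy)
#     fabric = trainer.to_fabric()
#     fabric.launch()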
