# Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES.  All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import sys
from typing import Optional, Union

from lightning.fabric.utilities.exceptions import MisconfigurationException
from lightning.pytorch import Trainer
from lightning.pytorch.plugins.environments import TorchElasticEnvironment
from omegaconf import DictConfig, open_dict

from nemo.collections.common.metrics.perf_metrics import FLOPsMeasurementCallback
from nemo.collections.common.parts.nlp_overrides import (
    CustomProgressBar,
    FSDPMixedPrecisionPlugin,
    GradScaler,
    MegatronHalfPrecisionPlugin,
    NLPDDPStrategy,
    NLPDDPStrategyNotebook,
    NLPFSDPStrategy,
    PipelineMixedPrecisionPlugin,
)
from nemo.utils import logging
from nemo.utils.callbacks.dist_ckpt_io import (
    AsyncFinalizableCheckpointIO,
    AsyncFinalizerCallback,
    DistributedCheckpointIO,
)


class MegatronTrainerBuilder:
    """
    Builder type to hide complex configuration of PTL Trainers for Megatron LLM models.
    Can be extended to change behavior for a specific model.

    The builder reads (and in a few places mutates) the ``cfg`` it is constructed with:
    ``cfg.trainer`` is forwarded to ``Trainer`` as keyword arguments, while ``cfg.model`` and
    ``cfg.exp_manager`` select the strategy, precision plugins, checkpoint IO and callbacks.
    """

    def __init__(self, cfg: DictConfig) -> None:
        self.cfg = cfg

    def _uses_distributed_optimizer(self) -> bool:
        """
        Returns True if ``cfg.model.optim.name`` selects a distributed optimizer
        (``distributed_fused_adam`` or ``mcore_distributed_optim``).

        A missing or empty ``cfg.model.optim`` section counts as "not distributed".
        """
        optim_cfg = self.cfg.model.get('optim')
        if not optim_cfg:
            return False
        return optim_cfg.get('name') in ('distributed_fused_adam', 'mcore_distributed_optim')

    def _async_save_enabled(self) -> bool:
        """
        Returns ``cfg.exp_manager.checkpoint_callback_params.async_save`` (default False).

        ``exp_manager == None`` is valid and indicates no exp_manager should be initialized;
        a null ``checkpoint_callback_params`` is treated the same as an absent one.
        """
        exp_manager_cfg = self.cfg.get('exp_manager', {}) or {}
        ckpt_params = exp_manager_cfg.get('checkpoint_callback_params', {}) or {}
        return bool(ckpt_params.get('async_save', False))

    def _training_strategy(self) -> Union[NLPDDPStrategy, NLPFSDPStrategy]:
        """
        Returns a DDP or a FSDP strategy passed to Trainer.strategy.

        - Interactive session with ``trainer.devices == 1``: ``NLPDDPStrategyNotebook``.
        - ``model.fsdp`` enabled: ``NLPFSDPStrategy``. If ``model.megatron_amp_O2`` was set it is
          switched to False in-place on ``cfg``.
        - Otherwise: ``NLPDDPStrategy``.

        Raises:
            AssertionError: if FSDP is combined with a distributed optimizer, or a sharded FSDP
                checkpoint is requested with tensor parallel size > 1.
        """
        # `sys.ps1` is only defined in an interactive interpreter; `sys.flags.interactive` is set by `python -i`.
        is_interactive = hasattr(sys, "ps1") or bool(sys.flags.interactive)
        if is_interactive and self.cfg.trainer.devices == 1:
            logging.info("Detected interactive environment, using NLPDDPStrategyNotebook")
            return NLPDDPStrategyNotebook(
                no_ddp_communication_hook=True,
                find_unused_parameters=False,
            )

        if self.cfg.model.get('fsdp', False):
            assert not self._uses_distributed_optimizer(), 'Distributed optimizer cannot be used with FSDP.'
            sharded_checkpoint = self.cfg.model.get('fsdp_sharded_checkpoint', False)
            if self.cfg.model.get('tensor_model_parallel_size', 1) > 1:
                assert not sharded_checkpoint, 'FSDP sharded checkpoint is not supported when TP size > 1.'
            if self.cfg.model.get('megatron_amp_O2', False):
                logging.info('Torch FSDP is not compatible with O2 precision recipe. Setting O2 `False`.')
                self.cfg.model.megatron_amp_O2 = False
            return NLPFSDPStrategy(
                limit_all_gathers=self.cfg.model.get('fsdp_limit_all_gathers', True),
                sharding_strategy=self.cfg.model.get('fsdp_sharding_strategy', 'full'),
                cpu_offload=self.cfg.model.get('fsdp_cpu_offload', False),
                grad_reduce_dtype=self.cfg.model.get('fsdp_grad_reduce_dtype', 32),
                sharded_checkpoint=sharded_checkpoint,
                precision=self.cfg.trainer.precision,
                nccl_communicator_config_path=self.cfg.model.get('nccl_communicator_config_path', None),
                sharp=self.cfg.model.get('sharp', False),
                use_orig_params=self.cfg.model.get('fsdp_use_orig_params', False),
            )

        return NLPDDPStrategy(
            no_ddp_communication_hook=True,
            gradient_as_bucket_view=self.cfg.model.gradient_as_bucket_view,
            find_unused_parameters=False,
            nccl_communicator_config_path=self.cfg.model.get('nccl_communicator_config_path', None),
            sharp=self.cfg.model.get('sharp', False),
            dist_ckpt_parallel_save=self.cfg.model.get('dist_ckpt_parallel_dist_opt', True),
        )

    def _grad_scaler(self) -> GradScaler:
        """
        Returns a scaler for precision plugins, configured from ``model.native_amp_init_scale``,
        ``model.native_amp_growth_interval`` and ``model.hysteresis``.
        """
        return GradScaler(
            init_scale=self.cfg.model.get('native_amp_init_scale', 2**32),
            growth_interval=self.cfg.model.get('native_amp_growth_interval', 1000),
            hysteresis=self.cfg.model.get('hysteresis', 2),
        )

    def _plugins(self) -> list:
        """
        Side effect: when a precision plugin is added, ``cfg.trainer.precision`` is set to None.

        Returns:
            plugins: list of plugins passed to Trainer.plugins including precision plugins.

        Raises:
            MisconfigurationException: if async checkpoint saving is requested without
                distributed checkpoints.
        """
        megatron_amp_O2 = self.cfg.model.get('megatron_amp_O2', False)
        with_distributed_adam = self._uses_distributed_optimizer()
        use_fsdp = self.cfg.model.get('fsdp', False)

        plugins = []
        if self.cfg.trainer.precision in [16, '16', 'bf16', '16-mixed', 'bf16-mixed']:
            scaler = None
            if self.cfg.trainer.precision in [16, '16', '16-mixed']:
                # A GradScaler is only created for fp16 outside of FSDP.
                if not use_fsdp:
                    scaler = self._grad_scaler()
                plugin_precision = '16-mixed'
            else:
                plugin_precision = 'bf16-mixed'

            if megatron_amp_O2 and not with_distributed_adam:
                plugins.append(MegatronHalfPrecisionPlugin(precision=plugin_precision, device='cuda', scaler=scaler))
            elif use_fsdp:
                plugins.append(FSDPMixedPrecisionPlugin(precision=plugin_precision, scaler=scaler))
            else:
                plugins.append(PipelineMixedPrecisionPlugin(precision=plugin_precision, device='cuda', scaler=scaler))
            # The precision plugin now carries the precision; clear the flag so Trainer is not
            # given both. create_trainer() restores the original value afterwards.
            self.cfg.trainer.precision = None

        if self.cfg.get('cluster_type', None) == 'BCP':
            plugins.append(TorchElasticEnvironment())

        # Use dist-ckpt for non-FSDP MCore models
        use_dist_ckpt = not use_fsdp and (
            self.cfg.model.get('mcore_gpt', False) or self.cfg.model.get('mcore_bert', False)
        )
        async_save = self._async_save_enabled()
        if use_dist_ckpt:
            checkpoint_io = DistributedCheckpointIO.from_config(self.cfg.model, async_save)
            if async_save:
                checkpoint_io = AsyncFinalizableCheckpointIO(checkpoint_io)
            plugins.append(checkpoint_io)
        elif async_save:
            raise MisconfigurationException(
                'exp_manager.checkpoint_callback_params.async_save=True without '
                'distributed checkpoints is currently not supported'
            )

        return plugins

    def _callbacks(self, callbacks: Optional[list]) -> list:
        """
        Args:
            callbacks: optional initial callbacks; they are copied, not mutated.

        Returns:
            callbacks: list of callbacks passed to Trainer.callbacks.
        """
        # Work on a copy so the caller's list is not modified by the appends below.
        callbacks = [] if callbacks is None else list(callbacks)
        # enable_progress_bar is True by default. If cfg.trainer.enable_progress_bar=False,
        # CustomProgressBar is not appended to callbacks
        if 'enable_progress_bar' not in self.cfg.trainer or self.cfg.trainer.enable_progress_bar:
            callbacks.append(CustomProgressBar())

        if self._async_save_enabled():
            callbacks.append(AsyncFinalizerCallback())

        # exp_manager == None is valid and indicates no exp_manager should be initialized
        if (self.cfg.get('exp_manager', {}) or {}).get('log_tflops_per_sec_per_gpu', True):
            callbacks.append(FLOPsMeasurementCallback(self.cfg))

        return callbacks

    def create_trainer(self, callbacks=None) -> Trainer:
        """
        Builds a PTL ``Trainer`` from ``cfg``.

        If ``model.skip_train`` is set, ``cfg`` is mutated to run a single dummy step with no
        sanity validation and no checkpoint callback.

        Args:
            callbacks: optional initial list of callbacks; builder callbacks are added to a copy.

        Returns:
            The configured ``Trainer``.
        """
        # Make a dummy train step if skip_train
        if self.cfg.model.get("skip_train", False):
            self.cfg.trainer.max_steps = 1
            self.cfg.trainer.val_check_interval = 1
            # Set num_sanity_val_steps to 0
            with open_dict(self.cfg.trainer):
                self.cfg.trainer.num_sanity_val_steps = 0
            # exp_manager == None is valid (no exp_manager); there is no checkpoint callback to disable then.
            exp_manager_cfg = self.cfg.get('exp_manager', None)
            if exp_manager_cfg is not None:
                with open_dict(exp_manager_cfg):
                    exp_manager_cfg.create_checkpoint_callback = False

        # _plugins() sets cfg.trainer.precision to None when it adds a precision plugin, so that
        # Trainer does not receive both a precision value and a precision plugin. Save the value
        # and restore it afterwards, even if building the Trainer fails.
        precision = self.cfg.trainer.precision
        try:
            strategy = self._training_strategy()
            plugins = self._plugins()
            callbacks = self._callbacks(callbacks)
            trainer = Trainer(plugins=plugins, strategy=strategy, **self.cfg.trainer, callbacks=callbacks)
        finally:
            self.cfg.trainer.precision = precision
        return trainer


class MegatronLMPPTrainerBuilder(MegatronTrainerBuilder):
    """Builder for scripts where grad scaler is turned off for pipeline parallel LM model. E.g. PEFT tuning scripts"""

    def _grad_scaler(self) -> GradScaler:
        """
        Returns a scaler for precision plugins that is disabled when
        ``model.pipeline_model_parallel_size`` > 1 (a missing key is treated as 1).
        """
        pipeline_parallel_size = self.cfg.model.get("pipeline_model_parallel_size", 1)
        return GradScaler(
            init_scale=self.cfg.model.get("native_amp_init_scale", 2**32),
            growth_interval=self.cfg.model.get("native_amp_growth_interval", 1000),
            hysteresis=self.cfg.model.get("hysteresis", 2),
            enabled=pipeline_parallel_size <= 1,
        )
