# Copyright (c) 2025, NVIDIA CORPORATION.  All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Callable, List, Optional

import lightning.pytorch as pl
import lightning.pytorch as L
from torch.optim import Optimizer
from torch.optim.optimizer import ParamsT

from nemo.lightning.megatron_parallel import MegatronParallel
from nemo.lightning.pytorch.optim.base import LRSchedulerModule, OptimizerModule


def _param_does_not_have_wd(param_name, param):
    return 'bias' in param_name


def _extract_model_params_for_optim(model, weight_decay=0, no_weight_decay_cond=None):
    params_with_wd, params_without_wd = [], []
    if no_weight_decay_cond is not None:
        for name, param in model.named_parameters():
            if not param.requires_grad:
                continue
            if no_weight_decay_cond(name, param):
                params_without_wd.append(param)
            else:
                params_with_wd.append(param)
    else:
        params_with_wd = list(filter(lambda x: x.requires_grad, model.parameters()))

    assert max(map(len, (params_with_wd, params_without_wd))) > 0, "Expected at least one optimizer with params"

    return [
        {'params': params, 'weight_decay': wd}
        for params, wd in zip((params_with_wd, params_without_wd), (weight_decay, 0))
    ]


class PytorchOptimizerModule(OptimizerModule):
    """A OptimizerModule for pytorch optimizers.

    Attributes:
        optimizer_fn (Callable[[ParamsT], Optimizer]): Configuration for the optimizer.
        no_weight_decay_cond (Optional[Callable]): Condition for no weight decay.
        scale_lr_cond (Optional[Callable]): Condition for scaling learning rate.
        lr_mult (float): Learning rate multiplier.

    Example::

        optimizer_fn = run.Partial(
            SGD,
            lr=lr,
            weight_decay=wd,
        )
        lr_scheduler = MyLRSchedulerModule(...)
        optimizer_module = PytorchOptimizerModule(optimizer_fn, lr_scheduler)

    Methods:
        setup(model): Sets up the optimizer.
        optimizers(model): Defines the optimizers.
    """

    def __init__(
        self,
        optimizer_fn: Callable[[ParamsT], Optimizer],
        lr_scheduler: Optional[LRSchedulerModule] = None,
        no_weight_decay_cond: Optional[Callable] = _param_does_not_have_wd,
        scale_lr_cond: Optional[Callable] = None,
        lr_mult: float = 1.0,
    ):
        """Initializes the PytorchOptimizerModule.

        Args:
            optimizer_fn (Callable[[ParamsT], Optimizer]): Configuration for the optimizer.
            lr_scheduler (Optional[LRSchedulerModule]): The learning rate scheduler module.
            no_weight_decay_cond (Optional[Callable]): Condition for no weight decay.
            scale_lr_cond (Optional[Callable]): Condition for scaling learning rate.
            lr_mult (float): Learning rate multiplier.
        """

        super().__init__(lr_scheduler=lr_scheduler)
        self.optimizer_fn = optimizer_fn
        self.no_weight_decay_cond = no_weight_decay_cond
        self.scale_lr_cond = scale_lr_cond
        self.lr_mult = lr_mult

    def on_fit_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule"):
        """nooop"""
        # Noop
        pass

    def optimizers(self, model) -> List[Optimizer]:
        """Defines the optimizers.

        Args:
            model (nn.Module): The model for which the optimizers are being defined.

        Returns:
            List[Optimizer]: The list of optimizers.

        Raises:
            ValueError: If the model is an instance of MegatronParallel.
        """

        if isinstance(model, MegatronParallel):
            raise ValueError("Model cannot be an instance of MegatronParallel")

        wd = self.optimizer_fn.keywords.get('weight_decay', 0)
        optim = self.optimizer_fn(_extract_model_params_for_optim(model, wd, self.no_weight_decay_cond))
        self._optimizers = optim
        if not isinstance(optim, list):
            optim = [optim]
        if self.lr_scheduler is None:
            return optim
        else:
            return [self.lr_scheduler.scheduler(model, opt) for opt in optim]

    def connect(self, model: L.LightningModule) -> None:
        """Connects the optimizer module to the model.

        Args:
            model (L.LightningModule): The model to which the optimizer module is being connected.
        """
        model.configure_optimizers = lambda: self.optimizers(model)
