# Copyright (c) 2025, NVIDIA CORPORATION.  All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from functools import singledispatch
from typing import Any, TypeVar

from lightning.fabric import plugins as fl_plugins
from lightning.fabric import strategies as fl_strategies
from lightning.pytorch import plugins as pl_plugins
from lightning.pytorch import strategies as pl_strategies

T = TypeVar('T')
FabricT = TypeVar('FabricT')


@singledispatch
def to_fabric(obj: Any) -> Any:
    """
    Convert a PyTorch Lightning object to its Fabric equivalent.

    Args:
        obj: The object to convert.

    Returns:
        The Fabric equivalent of the input object.

    Raises:
        NotImplementedError: If no converter is registered for the object's type.

    Example:
        >>> from lightning.pytorch.strategies import Strategy as PLStrategy
        >>> from lightning.fabric.strategies import Strategy as FabricStrategy
        >>> from nemo.lightning.fabric.conversion import to_fabric
        >>>
        >>> # Define a custom PyTorch Lightning strategy
        >>> class CustomPLStrategy(PLStrategy):
        ...     def __init__(self, custom_param: str):
        ...         super().__init__()
        ...         self.custom_param = custom_param
        >>>
        >>> # Define a custom Fabric strategy
        >>> class CustomFabricStrategy(FabricStrategy):
        ...     def __init__(self, custom_param: str):
        ...         super().__init__()
        ...         self.custom_param = custom_param
        >>>
        >>> # Register a custom conversion
        >>> @to_fabric.register(CustomPLStrategy)
        ... def _custom_converter(strategy: CustomPLStrategy) -> CustomFabricStrategy:
        ...     return CustomFabricStrategy(custom_param=strategy.custom_param)
        >>>
        >>> # Use the custom conversion
        >>> pl_strategy = CustomPLStrategy(custom_param="test")
        >>> fabric_strategy = to_fabric(pl_strategy)
        >>> assert isinstance(fabric_strategy, CustomFabricStrategy)
        >>> assert fabric_strategy.custom_param == "test"
    """
    raise NotImplementedError(
        f"No Fabric converter registered for {type(obj).__name__}. "
        f"To register a new conversion, use the @to_fabric.register decorator:\n\n"
        f"from nemo.lightning.fabric.conversion import to_fabric\n"
        f"from lightning.fabric import strategies as fl_strategies\n\n"
        f"@to_fabric.register({type(obj).__name__})\n"
        f"def _{type(obj).__name__.lower()}_converter(obj: {type(obj).__name__}) -> fl_strategies.Strategy:\n"
        f"    return fl_strategies.SomeStrategy(\n"
        f"        # Map relevant attributes from 'obj' to Fabric equivalent\n"
        f"        param1=obj.param1,\n"
        f"        param2=obj.param2,\n"
        f"        # ... other parameters ...\n"
        f"    )\n\n"
        f"Add this code to the appropriate module (e.g., nemo/lightning/fabric/conversion.py)."
    )


@to_fabric.register(pl_strategies.DDPStrategy)
def _ddp_converter(strategy: pl_strategies.DDPStrategy) -> fl_strategies.DDPStrategy:
    return fl_strategies.DDPStrategy(
        accelerator=strategy.accelerator,
        parallel_devices=strategy.parallel_devices,
        cluster_environment=strategy.cluster_environment,
        process_group_backend=strategy.process_group_backend,
        timeout=strategy._timeout,
        start_method=strategy._start_method,
        **strategy._ddp_kwargs,
    )


@to_fabric.register(pl_strategies.FSDPStrategy)
def _fsdp_converter(strategy: pl_strategies.FSDPStrategy) -> fl_strategies.FSDPStrategy:
    return fl_strategies.FSDPStrategy(
        cpu_offload=strategy.cpu_offload,
        parallel_devices=strategy.parallel_devices,
        cluster_environment=strategy.cluster_environment,
        process_group_backend=strategy.process_group_backend,
        timeout=strategy._timeout,
        **strategy.kwargs,
    )


@to_fabric.register(pl_plugins.MixedPrecision)
def _mixed_precision_converter(plugin: pl_plugins.MixedPrecision) -> fl_plugins.MixedPrecision:
    return fl_plugins.MixedPrecision(
        precision=plugin.precision,
        device=plugin.device,
        scaler=plugin.scaler,
    )


@to_fabric.register(pl_plugins.FSDPPrecision)
def _fsdp_precision_converter(plugin: pl_plugins.FSDPPrecision) -> fl_plugins.FSDPPrecision:
    return fl_plugins.FSDPPrecision(
        precision=plugin.precision,
    )
