# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

import logging
import traceback
from itertools import chain
from typing import TYPE_CHECKING

from vllm.utils.import_utils import resolve_obj_by_qualname
from vllm.utils.torch_utils import supports_xccl

from vllm_omni.platforms.interface import OmniPlatform, OmniPlatformEnum
from vllm_omni.plugins import (
    OMNI_PLATFORM_PLUGINS_GROUP,
    load_omni_plugins_by_group,
)

logger = logging.getLogger(__name__)


def cuda_omni_platform_plugin() -> str | None:
    """Detect whether the CUDA OmniPlatform should be activated.

    Initializes NVML (via vLLM's ``import_pynvml``) and asks it for the device
    count; NVML is always shut down again once the count has been queried.

    Returns:
        ``"vllm_omni.platforms.cuda.platform.CudaOmniPlatform"`` if NVML
        reported at least one device, otherwise ``None``. Any exception raised
        while probing is logged at debug level rather than propagated.
    """
    logger.debug("Checking if CUDA OmniPlatform is available.")
    has_gpu = False
    try:
        from vllm.utils.import_utils import import_pynvml

        nvml = import_pynvml()
        nvml.nvmlInit()
        try:
            has_gpu = nvml.nvmlDeviceGetCount() > 0
            if has_gpu:
                logger.debug("Confirmed CUDA OmniPlatform is available.")
            else:
                logger.debug("CUDA OmniPlatform is not available because no GPU is found.")
        finally:
            # Release NVML even if the device query raised.
            nvml.nvmlShutdown()
    except Exception as exc:
        logger.debug("CUDA OmniPlatform is not available because: %s", str(exc))

    if not has_gpu:
        return None
    return "vllm_omni.platforms.cuda.platform.CudaOmniPlatform"


def rocm_omni_platform_plugin() -> str | None:
    """Detect whether the ROCm OmniPlatform should be activated.

    Initializes ``amdsmi`` and checks whether it returns any processor handles;
    ``amdsmi_shut_down`` is always called after the query.

    Returns:
        ``"vllm_omni.platforms.rocm.platform.RocmOmniPlatform"`` if at least one
        processor handle was found, otherwise ``None``. Any exception raised
        while probing (including a missing ``amdsmi`` module) is logged at
        debug level rather than propagated.
    """
    logger.debug("Checking if ROCm OmniPlatform is available.")
    gpu_present = False
    try:
        import amdsmi

        amdsmi.amdsmi_init()
        try:
            processor_handles = amdsmi.amdsmi_get_processor_handles()
            gpu_present = len(processor_handles) > 0
            if gpu_present:
                logger.debug("Confirmed ROCm OmniPlatform is available.")
            else:
                logger.debug("ROCm OmniPlatform is not available because no GPU is found.")
        finally:
            # Always tear amdsmi down, even when the handle query fails.
            amdsmi.amdsmi_shut_down()
    except Exception as err:
        logger.debug("ROCm OmniPlatform is not available because: %s", str(err))

    if gpu_present:
        return "vllm_omni.platforms.rocm.platform.RocmOmniPlatform"
    return None


def npu_omni_platform_plugin() -> str | None:
    """Detect whether the NPU OmniPlatform should be activated.

    Checks for a ``torch.npu`` namespace and asks it whether a device is
    available. ``torch.npu`` is presumably provided by an NPU extension for
    PyTorch; when it is absent the platform is treated as unavailable.

    Returns:
        ``"vllm_omni.platforms.npu.platform.NPUOmniPlatform"`` when
        ``torch.npu.is_available()`` is truthy, otherwise ``None``. Any
        exception raised while probing is logged at debug level.
    """
    logger.debug("Checking if NPU OmniPlatform is available.")
    npu_ready = False
    try:
        import torch

        npu_ready = hasattr(torch, "npu") and torch.npu.is_available()
        if npu_ready:
            logger.debug("Confirmed NPU OmniPlatform is available.")
    except Exception as err:
        logger.debug("NPU OmniPlatform is not available because: %s", str(err))

    return "vllm_omni.platforms.npu.platform.NPUOmniPlatform" if npu_ready else None


def xpu_omni_platform_plugin() -> str | None:
    """Check if XPU OmniPlatform should be activated.

    First picks the distributed backend: ``"xccl"`` when ``supports_xccl()``
    is true, otherwise ``"ccl"``. The ``"ccl"`` path requires
    ``oneccl_bindings_for_pytorch`` to be importable; if that import fails,
    XPU detection fails as a whole. When ``torch.xpu`` reports an available
    device, the chosen backend is stored on ``XPUOmniPlatform.dist_backend``.

    Returns:
        ``"vllm_omni.platforms.xpu.platform.XPUOmniPlatform"`` when an XPU
        device is available, otherwise ``None``. Exceptions raised while
        probing are logged at debug level rather than propagated.
    """
    is_xpu = False
    logger.debug("Checking if XPU OmniPlatform is available.")
    try:
        import torch

        if supports_xccl():
            dist_backend = "xccl"
        else:
            dist_backend = "ccl"
            # Raises ImportError (caught below) when the oneCCL bindings are missing.
            import oneccl_bindings_for_pytorch  # noqa: F401

        if hasattr(torch, "xpu") and torch.xpu.is_available():
            is_xpu = True
            # Import from the exact module named by the returned qualname so the
            # backend is configured on the same class object that
            # ``resolve_obj_by_qualname`` will later instantiate. Importing from
            # the package instead only works if the package re-exports the class;
            # otherwise the qualname is still returned but ``dist_backend`` is
            # never set.
            from vllm_omni.platforms.xpu.platform import XPUOmniPlatform

            XPUOmniPlatform.dist_backend = dist_backend
            logger.debug("Confirmed %s backend is available.", XPUOmniPlatform.dist_backend)
            logger.debug("Confirmed XPU platform is available.")
    except Exception as e:
        logger.debug("XPU omni platform is not available because: %s", str(e))

    return "vllm_omni.platforms.xpu.platform.XPUOmniPlatform" if is_xpu else None


# Built-in platform detectors keyed by plugin name. Each callable returns the
# qualified name of its OmniPlatform class, or None when that platform is not
# present. Insertion order (cuda, rocm, npu, xpu) is the order they are probed.
builtin_omni_platform_plugins = dict(
    cuda=cuda_omni_platform_plugin,
    rocm=rocm_omni_platform_plugin,
    npu=npu_omni_platform_plugin,
    xpu=xpu_omni_platform_plugin,
)


def _probe_omni_platform_plugin(name: str, func: object) -> str | None:
    """Run one platform detector; return its qualname, or ``None`` if it did not activate.

    A non-callable entry or a detector that raises is logged and treated as
    "not activated", so one broken plugin cannot block detection of the others.
    """
    if not callable(func):
        logger.warning("Ignoring OmniPlatform plugin %s because it is not callable.", name)
        return None
    try:
        return func()
    except Exception:
        # Previously swallowed silently; surface it so a misbehaving plugin is
        # diagnosable instead of just quietly failing to activate.
        logger.warning("OmniPlatform plugin %s raised during platform detection.", name, exc_info=True)
        return None


def resolve_current_omni_platform_cls_qualname() -> str:
    """Resolve the current OmniPlatform class qualified name.

    Every built-in detector and every out-of-tree (OOT) plugin loaded from
    ``OMNI_PLATFORM_PLUGINS_GROUP`` is invoked exactly once. The result is
    chosen as follows:

    * Exactly one OOT plugin activated: use it. OOT plugins take precedence
      over built-in detection.
    * Otherwise, exactly one built-in detector activated: use it.
    * Nothing activated: fall back to ``UnspecifiedOmniPlatform``.

    Returns:
        The fully qualified class name to pass to ``resolve_obj_by_qualname``.

    Raises:
        RuntimeError: If two or more OOT plugins activate. Also raised if no OOT
            plugin activates and two or more built-in detectors do.
    """
    platform_plugins = load_omni_plugins_by_group(OMNI_PLATFORM_PLUGINS_GROUP)

    # Results are bucketed by origin rather than by name-set intersection. An
    # OOT plugin that reuses a built-in name (e.g. "cuda") therefore cannot be
    # misattributed, and the qualname each detector returned is kept so that
    # no detector has to be called a second time.
    activated_builtin: dict[str, str] = {}
    activated_oot: dict[str, str] = {}
    probes = chain(
        ((name, func, activated_builtin) for name, func in builtin_omni_platform_plugins.items()),
        ((name, func, activated_oot) for name, func in platform_plugins.items()),
    )
    for name, func, activated in probes:
        qualname = _probe_omni_platform_plugin(name, func)
        if qualname is not None:
            activated[name] = qualname

    if len(activated_oot) >= 2:
        raise RuntimeError(f"Only one OmniPlatform plugin can be activated, but got: {list(activated_oot)}")
    if len(activated_oot) == 1:
        name, platform_cls_qualname = next(iter(activated_oot.items()))
        logger.info("OmniPlatform plugin %s is activated", name)
        return platform_cls_qualname
    if len(activated_builtin) >= 2:
        raise RuntimeError(f"Only one OmniPlatform plugin can be activated, but got: {list(activated_builtin)}")
    if len(activated_builtin) == 1:
        name, platform_cls_qualname = next(iter(activated_builtin.items()))
        logger.debug("Automatically detected OmniPlatform %s.", name)
        return platform_cls_qualname

    logger.debug("No platform detected, vLLM-Omni is running on UnspecifiedOmniPlatform")
    return "vllm_omni.platforms.interface.UnspecifiedOmniPlatform"


# Cached platform instance. It stays None until ``current_omni_platform`` is
# first read through the module-level ``__getattr__``.
_current_omni_platform: "OmniPlatform | None" = None

# Formatted call stack (``traceback.format_stack``) captured when the platform
# instance is first created; presumably kept to debug unexpected early init.
_init_trace: str = ""

if TYPE_CHECKING:
    # Declares the lazily provided attribute for static type checkers only.
    current_omni_platform: OmniPlatform


def __getattr__(name: str):
    """Module-level attribute fallback (PEP 562).

    Python only calls this when normal module attribute lookup fails.
    ``current_omni_platform`` is built lazily on first access. The platform
    class qualname is resolved and then instantiated, and the creating call
    stack is stored in ``_init_trace``. Any other name is looked up in the
    module globals.

    Raises:
        AttributeError: If ``name`` is neither ``current_omni_platform`` nor a
            module global.
    """
    global _current_omni_platform, _init_trace

    if name != "current_omni_platform":
        module_globals = globals()
        if name not in module_globals:
            raise AttributeError(f"No attribute named '{name}' exists in {__name__}.")
        return module_globals[name]

    if _current_omni_platform is None:
        qualname = resolve_current_omni_platform_cls_qualname()
        platform_cls = resolve_obj_by_qualname(qualname)
        _current_omni_platform = platform_cls()
        _init_trace = "".join(traceback.format_stack())
    return _current_omni_platform


def __setattr__(name: str, value):  # noqa: N807
    """Set a module attribute, routing ``current_omni_platform`` to its cache.

    NOTE: Python never calls a module-level ``__setattr__`` for
    ``module.attr = value`` assignments (PEP 562 only covers ``__getattr__``
    and ``__dir__``). This function therefore only takes effect when it is
    called explicitly.

    Raises:
        AttributeError: If ``name`` is neither ``current_omni_platform`` nor an
            existing module global.
    """
    global _current_omni_platform

    if name == "current_omni_platform":
        _current_omni_platform = value
        return

    module_globals = globals()
    if name not in module_globals:
        raise AttributeError(f"No attribute named '{name}' exists in {__name__}.")
    module_globals[name] = value


__all__ = [
    "OmniPlatform",
    "OmniPlatformEnum",
    "current_omni_platform",
    "_init_trace",
]
