# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
#

import functools
import os
from importlib import metadata
from typing import TYPE_CHECKING, Iterable, List, Mapping, Tuple, Type

from ..core import logger

if TYPE_CHECKING:
    # Breaks the import cycle
    from ..core.core import Executor
    from ..core.job_environment import JobEnvironment


def _iter_submitit_entrypoints():
    """Return an iterable of EntryPoint objects in the 'submitit' group
    compatible with Python 3.8+ and the backport."""

    # 3.10+ API: EntryPoints with .select
    eps = metadata.entry_points()
    if hasattr(eps, "select"):
        return eps.select(group="submitit")

    # importlib_metadata backport newer signature: entry_points("submitit")
    try:
        return metadata.entry_points()["submitit"]
    except TypeError:
        pass  # older API; fall through

    # 3.8/3.9 legacy: mapping {group: [EntryPoint, ...]}
    if hasattr(eps, "get"):
        return eps.get("submitit", [])

    # old style (should in theory never get here if 3.8+): flat iterable; filter by .group
    return [ep for ep in eps if getattr(ep, "group", None) == "submitit"]


@functools.lru_cache()
def _get_plugins() -> Tuple[List[Type["Executor"]], List["JobEnvironment"]]:
    """Discover available executors and job environments.

    Returns a tuple ``(executor_classes, job_environment_instances)`` made of
    the built-in slurm/local/debug implementations plus anything registered
    under the 'submitit' entry-point group. A plugin that fails to load or to
    initialize is skipped with a logged error rather than aborting discovery.
    """
    # pylint: disable=cyclic-import,import-outside-toplevel
    from ..local import debug, local
    from ..slurm import slurm

    executors: List[Type["Executor"]] = [slurm.SlurmExecutor, local.LocalExecutor, debug.DebugExecutor]
    job_envs = [slurm.SlurmJobEnvironment(), local.LocalJobEnvironment(), debug.DebugJobEnvironment()]
    for entry_point in _iter_submitit_entrypoints():
        if entry_point.name not in ("executor", "job_environment"):
            # Tell the user what was ignored instead of logging a bare "name = value".
            logger.warning(
                f"Skipping unknown submitit entry point: {entry_point.name} = {entry_point.value}"
            )
            continue

        module_name = entry_point.value.split(":", 1)[0]
        try:
            # call `load` rather than `resolve`.
            # `load` also checks the module and its dependencies are correctly installed.
            obj = entry_point.load()
        except Exception as e:
            # This may happen if the plugin hasn't been correctly installed
            logger.exception(f"Failed to load submitit plugin '{module_name}': {e}")
            continue

        if entry_point.name == "executor":
            executors.append(obj)
        else:
            try:
                job_env = obj()
            except Exception as e:
                name = getattr(obj, "name", getattr(obj, "__name__", str(obj)))
                logger.exception(
                    f"Failed to init JobEnvironment '{name}' ({obj}) from submitit plugin '{module_name}': {e}"
                )
                continue
            job_envs.append(job_env)

    return (executors, job_envs)


@functools.lru_cache()
def get_executors() -> Mapping[str, Type["Executor"]]:
    """Return the mapping from executor name to executor class,
    including both built-ins and discovered plugins."""
    # TODO: check collisions between executor names
    executor_classes = _get_plugins()[0]
    return {cls.name(): cls for cls in executor_classes}


def get_job_environment() -> "JobEnvironment":
    """Return the JobEnvironment instance the current job is running in.

    The environment is either forced through the ``_TEST_CLUSTER_`` environment
    variable (useful for tests) or auto-detected by asking each known
    environment whether it is activated.

    Raises:
        RuntimeError: if no known environment reports itself as active.
    """
    # Don't cache this function. It makes testing harder.
    # The slow part is the plugin discovery anyway.
    envs = get_job_environments()
    # bypassing can be helpful for testing
    if "_TEST_CLUSTER_" in os.environ:
        c = os.environ["_TEST_CLUSTER_"]
        assert c in envs, f"Unknown $_TEST_CLUSTER_='{c}', available: {envs.keys()}."
        return envs[c]
    for env in envs.values():
        # TODO? handle the case where several envs are valid
        if env.activated():
            return env
    raise RuntimeError(
        f"Could not figure out which environment the job is running in. Known environments: {', '.join(envs.keys())}."
    )


@functools.lru_cache()
def get_job_environments() -> Mapping[str, "JobEnvironment"]:
    """Return the mapping from environment name to JobEnvironment instance,
    including both built-ins and discovered plugins."""
    environments = _get_plugins()[1]
    return {environment.name(): environment for environment in environments}
