import asyncio
import functools
import json
import logging
import os
import platform
import re
import subprocess
import sys
from collections import defaultdict
from typing import (
    TYPE_CHECKING,
    Any,
    Dict,
    Iterator,
    List,
    Optional,
    Tuple,
    Union,
    cast,
)

import click

import wandb
import wandb.docker as docker
from wandb import util
from wandb.apis.internal import Api
from wandb.sdk.launch.errors import LaunchError
from wandb.sdk.launch.git_reference import GitReference
from wandb.sdk.launch.wandb_reference import WandbReference
from wandb.sdk.wandb_config import Config

from .builder.templates._wandb_bootstrap import (
    FAILED_PACKAGES_POSTFIX,
    FAILED_PACKAGES_PREFIX,
)

# Matches the failed-packages marker written by the bootstrap template into
# build logs: group(1) captures whatever sits between the prefix and postfix
# (reported to the user as the list of packages that failed to install).
FAILED_PACKAGES_REGEX = re.compile(
    f"{re.escape(FAILED_PACKAGES_PREFIX)}(.*){re.escape(FAILED_PACKAGES_POSTFIX)}"
)

if TYPE_CHECKING:  # pragma: no cover
    from wandb.sdk.launch.agent.job_status_tracker import JobAndRunStatusTracker


# TODO: this should be restricted to just Git repos and not S3 and stuff like that
# Heuristic git-remote detector: the first character must not be one of
# '/', '|', '^', '~', '.', and one of the listed host substrings must appear.
# NOTE(review): inside a character class `|` and `^` are literal, so
# `[^/|^~|^\.]` also rejects URIs starting with '|' or '^' — the intent was
# probably just "not a local path (/, ~, .)". Confirm before tightening.
_GIT_URI_REGEX = re.compile(
    r"^[^/|^~|^\.].*(git|bitbucket|dev\.azure\.com|\.visualstudio\.com)"
)
# http(s) URL whose host is four dot-separated digit groups, optional port.
# Octet values are not range-checked.
_VALID_IP_REGEX = r"^https?://[0-9]+(?:\.[0-9]+){3}(:[0-9]+)?"
# Characters accepted in a pip package name by diff_pip_requirements.
_VALID_PIP_PACKAGE_REGEX = r"^[a-zA-Z0-9_.-]+$"
# NOTE(review): the '.' in "(api.)" is unescaped and matches any character.
_VALID_WANDB_REGEX = r"^https?://(api.)?wandb"
# Either a wandb-hosted URL or a raw-IP URL counts as a wandb URI.
_WANDB_URI_REGEX = re.compile(r"|".join([_VALID_WANDB_REGEX, _VALID_IP_REGEX]))
_WANDB_QA_URI_REGEX = re.compile(
    r"^https?://ap\w.qa.wandb"
)  # for testing, not sure if we wanna keep this
_WANDB_DEV_URI_REGEX = re.compile(
    r"^https?://ap\w.wandb.test"
)  # for testing, not sure if we wanna keep this
_WANDB_LOCAL_DEV_URI_REGEX = re.compile(
    r"^https?://localhost"
)  # for testing, not sure if we wanna keep this

# Used by sanitize_wandb_api_key to redact "WANDB_API_KEY=<value>" from text.
API_KEY_REGEX = r"WANDB_API_KEY=\w+(-\w+)?"

# "${name}" macros; group(1) is the macro name (word characters only).
MACRO_REGEX = re.compile(r"\$\{(\w+)\}")

# Azure Container Registry image URI: group 1 is the registry name; named
# groups `repository` and optional `tag` (text after an optional ':').
AZURE_CONTAINER_REGISTRY_URI_REGEX = re.compile(
    r"^(?:https://)?([\w]+)\.azurecr\.io/(?P<repository>[\w\-]+):?(?P<tag>.*)"
)

# AWS ECR image URI with named groups `account`, `region`, `repository`
# (may contain '.' and '/') and `tag`.
ELASTIC_CONTAINER_REGISTRY_URI_REGEX = re.compile(
    r"^(?:https://)?(?P<account>[\w-]+)\.dkr\.ecr\.(?P<region>[\w-]+)\.amazonaws\.com/(?P<repository>[\.\/\w-]+):?(?P<tag>.*)$"
)

# GCP Artifact Registry URI with named groups `region`, `project`,
# `repository`, optional `image_name` and optional `tag` (includes the ':').
GCP_ARTIFACT_REGISTRY_URI_REGEX = re.compile(
    r"^(?:https://)?(?P<region>[\w-]+)-docker\.pkg\.dev/(?P<project>[\w-]+)/(?P<repository>[\w-]+)/?(?P<image_name>[\w-]+)?(?P<tag>:.*)?$",
    re.IGNORECASE,
)

# Bucket URIs: group 1 is the bucket; the remaining group(s) hold the key.
S3_URI_RE = re.compile(r"s3://([^/]+)(/(.*))?")
GCS_URI_RE = re.compile(r"gs://([^/]+)(?:/(.*))?")
# Azure blob URL: groups are storage account, container, and blob path.
AZURE_BLOB_REGEX = re.compile(
    r"^https://([^\.]+)\.blob\.core\.windows\.net/([^/]+)/?(.*)$"
)

# Six-field ARN; group 1 is the field right after "arn:" (the partition).
ARN_PARTITION_RE = re.compile(r"^arn:([^:]+):[^:]*:[^:]*:[^:]*:[^:]*$")

PROJECT_SYNCHRONOUS = "SYNCHRONOUS"

LAUNCH_CONFIG_FILE = "~/.config/wandb/launch-config.yaml"
LAUNCH_DEFAULT_PROJECT = "model-registry"

_logger = logging.getLogger(__name__)
LOG_PREFIX = f"{click.style('launch:', fg='magenta')} "

# Maximum environment-variable length per key (presumably runner class name —
# confirm with callers). Any key not set explicitly gets 32670 via the
# defaultdict; SageMaker is far more restrictive. Presumably consumed when
# sharding WANDB_CONFIG into WANDB_CONFIG_<n> (see load_wandb_config).
MAX_ENV_LENGTHS: Dict[str, int] = defaultdict(lambda: 32670)
MAX_ENV_LENGTHS["SageMakerRunner"] = 512

# NOTE(review): presumably the in-container mount point for job code — confirm.
CODE_MOUNT_DIR = "/mnt/wandb"


def load_wandb_config() -> Config:
    """Build a wandb Config from the WANDB_CONFIG environment variable(s).

    ``WANDB_CONFIG`` holds a JSON string. For environments that limit the
    length of a single environment variable, the same JSON may instead be
    split across ``WANDB_CONFIG_0``, ``WANDB_CONFIG_1``, ... which are
    concatenated in index order until the first missing index.
    ``WANDB_CONFIG`` wins when it is set.

    Returns:
        A Config updated with the decoded JSON values.

    Raises:
        LaunchError: If no config variable is present or the JSON is invalid.
    """
    raw = os.environ.get("WANDB_CONFIG")
    if raw is None:
        shards: List[str] = []
        # Stop at the first index that is not set; an empty shard still counts.
        while (shard := os.environ.get(f"WANDB_CONFIG_{len(shards)}")) is not None:
            shards.append(shard)
        if not shards:
            raise LaunchError(
                "No WANDB_CONFIG or WANDB_CONFIG_[0-9]+ environment variables found"
            )
        raw = "".join(shards)

    wandb_config = Config()
    try:
        decoded = json.loads(raw)
    except json.JSONDecodeError as e:
        raise LaunchError(f"Failed to parse WANDB_CONFIG: {e}") from e

    wandb_config.update(decoded)
    return wandb_config


def event_loop_thread_exec(func: Any) -> Any:
    """Wrap a blocking function so it runs in a thread and can be awaited.

    Example usage:
    ```
    def my_func(arg1, arg2):
        return arg1 + arg2


    future = event_loop_thread_exec(my_func)(2, 2)
    assert await future == 4
    ```

    The returned coroutine function must be called within an active event
    loop. Calls run on the loop's default executor.

    Args:
        func: The (synchronous) callable to run off the event loop.

    Returns:
        An async function with ``func``'s metadata (name, docstring) that
        forwards positional and keyword arguments to ``func``.
    """

    @functools.wraps(func)
    async def wrapper(*args: Any, **kwargs: Any) -> Any:
        # get_running_loop() always returns the loop executing this coroutine;
        # get_event_loop() goes through the loop policy instead.
        loop = asyncio.get_running_loop()
        # run_in_executor only forwards positional arguments, so bind kwargs
        # with functools.partial.
        return await loop.run_in_executor(
            None, functools.partial(func, *args, **kwargs)
        )

    return wrapper


def _is_wandb_uri(uri: str) -> bool:
    """Return True if ``uri`` matches any known wandb host pattern (prod, dev, local, QA)."""
    host_patterns = (
        _WANDB_URI_REGEX,
        _WANDB_DEV_URI_REGEX,
        _WANDB_LOCAL_DEV_URI_REGEX,
        _WANDB_QA_URI_REGEX,
    )
    return any(pattern.match(uri) is not None for pattern in host_patterns)


def _is_wandb_dev_uri(uri: str) -> bool:
    """Return True if ``uri`` starts like a wandb dev (``*.wandb.test``) URL."""
    return _WANDB_DEV_URI_REGEX.match(uri) is not None


def _is_wandb_local_uri(uri: str) -> bool:
    """Return True if ``uri`` is an http(s) URL pointing at localhost."""
    return _WANDB_LOCAL_DEV_URI_REGEX.match(uri) is not None


def _is_git_uri(uri: str) -> bool:
    """Return True if ``uri`` looks like a git remote according to _GIT_URI_REGEX."""
    return _GIT_URI_REGEX.match(uri) is not None


def sanitize_wandb_api_key(s: str) -> str:
    """Redact API key values: every ``WANDB_API_KEY=<value>`` becomes ``WANDB_API_KEY``."""
    redacted = re.sub(API_KEY_REGEX, "WANDB_API_KEY", s)
    return str(redacted)


def get_project_from_job(job: str) -> Optional[str]:
    """Return the middle component of a three-part ``a/b/c`` job string.

    Any other number of '/'-separated parts yields None.
    """
    segments = job.split("/")
    return segments[1] if len(segments) == 3 else None


def set_project_entity_defaults(
    uri: Optional[str],
    job: Optional[str],
    api: Api,
    project: Optional[str],
    entity: Optional[str],
    launch_config: Optional[Dict[str, Any]],
) -> Tuple[Optional[str], str]:
    """Fill in the target project and entity for a launch when not supplied.

    Project resolution order:
    1. the explicit ``project``;
    2. ``launch_config["project"]``;
    3. a name derived from the source: the project parsed from a wandb URI,
       the repository name of a git URI, or the project component of a
       three-part ``job`` string;
    4. ``""``.

    The entity comes from ``get_default_entity`` when not given.

    Side effect: logs the resolved ``entity[/project]`` via ``wandb.termlog``.

    Returns:
        ``(project, entity)``.
    """
    source_uri = None
    if uri is not None:
        if _is_wandb_uri(uri):
            _, source_uri, _ = parse_wandb_uri(uri)
        elif _is_git_uri(uri):
            # Strip a trailing "/" so ".../repo/" still yields "repo" instead
            # of an empty basename.
            source_uri = os.path.splitext(os.path.basename(uri.rstrip("/")))[0]
    elif job is not None:
        source_uri = get_project_from_job(job)
    if project is None:
        config_project = None
        if launch_config:
            config_project = launch_config.get("project")
        project = config_project or source_uri or ""
    if entity is None:
        entity = get_default_entity(api, launch_config)

    # Only print the emoji when stdout can encode it. The encoding name is
    # compared case-insensitively (CPython reports "utf-8", not "UTF-8"), and
    # stdout may be None or a stream without an `encoding` attribute.
    encoding = (getattr(sys.stdout, "encoding", None) or "").lower().replace("_", "-")
    prefix = ""
    if platform.system() != "Windows" and encoding in ("utf-8", "utf8"):
        prefix = "🚀 "
    wandb.termlog(
        f"{LOG_PREFIX}{prefix}Launching run into {entity}{'/' + project if project else ''}"
    )
    return project, entity


def get_default_entity(api: Api, launch_config: Optional[Dict[str, Any]]):
    """Return the launch config's ``entity`` if truthy, else ``api.default_entity``."""
    from_config = launch_config.get("entity") if launch_config else None
    if from_config:
        return from_config
    return api.default_entity


def strip_resource_args_and_template_vars(launch_spec: Dict[str, Any]) -> None:
    """Drop ``resource_args`` in place when ``template_variables`` is also set.

    Both keys must be truthy for anything to happen. In that case a warning is
    printed and ``template_variables`` is kept.
    """
    if not launch_spec.get("resource_args", None):
        return
    if not launch_spec.get("template_variables", None):
        return
    wandb.termwarn(
        "Launch spec contains both resource_args and template_variables, "
        "only one can be set. Using template_variables."
    )
    launch_spec.pop("resource_args")


def construct_launch_spec(
    uri: Optional[str],
    job: Optional[str],
    api: Api,
    name: Optional[str],
    project: Optional[str],
    entity: Optional[str],
    docker_image: Optional[str],
    resource: Optional[str],
    entry_point: Optional[List[str]],
    version: Optional[str],
    resource_args: Optional[Dict[str, Any]],
    launch_config: Optional[Dict[str, Any]],
    run_id: Optional[str],
    repository: Optional[str],
    author: Optional[str],
    sweep_id: Optional[str] = None,
) -> Dict[str, Any]:
    """Construct the launch specification from CLI arguments.

    Supplied arguments override values in ``launch_config``. Note that when
    ``launch_config`` is not None it is used *as* the spec dict (not copied),
    so the caller's dict is mutated in place.

    Guarantees the returned spec has ``entity``, ``project``, ``docker``,
    ``resource``, ``git`` and ``overrides`` keys. ``repository``, when given,
    is stored as ``spec["registry"]["url"]``.

    Raises:
        LaunchError: If ``overrides.args`` is present but not a list.
    """
    # override base config (if supplied) with supplied args
    launch_spec = launch_config if launch_config is not None else {}
    if uri is not None:
        launch_spec["uri"] = uri
    if job is not None:
        launch_spec["job"] = job
    project, entity = set_project_entity_defaults(
        uri,
        job,
        api,
        project,
        entity,
        launch_config,
    )
    launch_spec["entity"] = entity
    if author:
        launch_spec["author"] = author

    launch_spec["project"] = project
    if name:
        launch_spec["name"] = name
    if "docker" not in launch_spec:
        launch_spec["docker"] = {}
    if docker_image:
        launch_spec["docker"]["docker_image"] = docker_image
    if sweep_id:  # all runs in a sweep have this set
        launch_spec["sweep_id"] = sweep_id

    if "resource" not in launch_spec:
        launch_spec["resource"] = resource if resource else None

    if "git" not in launch_spec:
        launch_spec["git"] = {}
    if version:
        launch_spec["git"]["version"] = version

    if "overrides" not in launch_spec:
        launch_spec["overrides"] = {}

    if not isinstance(launch_spec["overrides"].get("args", []), list):
        raise LaunchError("override args must be a list of strings")

    if resource_args:
        launch_spec["resource_args"] = resource_args

    if entry_point:
        launch_spec["overrides"]["entry_point"] = entry_point

    if run_id is not None:
        launch_spec["run_id"] = run_id

    if repository:
        # Write the registry override onto the spec itself. When launch_config
        # was supplied, launch_spec *is* launch_config, so this updates the
        # same dict as before. When launch_config was None, the override was
        # previously written to a throwaway dict and lost.
        registry = launch_spec.get("registry")
        if registry:
            registry["url"] = repository
        else:
            launch_spec["registry"] = {"url": repository}

    # dont send both resource args and template variables
    strip_resource_args_and_template_vars(launch_spec)

    return launch_spec


def validate_launch_spec_source(launch_spec: Dict[str, Any]) -> None:
    """Require exactly one of ``job`` or ``docker.docker_image`` to be truthy.

    Raises:
        LaunchError: If both or neither source is set.
    """
    has_job = bool(launch_spec.get("job"))
    has_image = bool(launch_spec.get("docker", {}).get("docker_image"))
    if has_job != has_image:
        return
    raise LaunchError(
        "Exactly one of job or docker_image must be specified in the launch spec."
    )


def parse_wandb_uri(uri: str) -> Tuple[str, str, str]:
    """Split a wandb run URI into ``(entity, project, run_id)``.

    Raises:
        LaunchError: If the URI does not parse or any of the three parts is empty.
    """
    parsed = WandbReference.parse(uri)
    if parsed and parsed.entity and parsed.project and parsed.run_id:
        return (parsed.entity, parsed.project, parsed.run_id)
    raise LaunchError(f"Trouble parsing wandb uri {uri}")


def get_local_python_deps(
    dir: str, filename: str = "requirements.local.txt"
) -> Optional[str]:
    """Write the output of ``pip freeze`` to ``os.path.join(dir, filename)``.

    Returns:
        ``filename`` on success, or ``None`` (after printing an error) if the
        file could not be written, ``pip`` could not be started, or it exited
        with a non-zero status.
    """
    try:
        with open(os.path.join(dir, filename), "w") as f:
            # check_call (unlike call) raises CalledProcessError on a non-zero
            # exit, so a failed freeze is reported instead of returning a
            # partial/empty file as if it succeeded.
            subprocess.check_call(["pip", "freeze"], env=os.environ, stdout=f)
        return filename
    except (subprocess.CalledProcessError, OSError) as e:
        # OSError covers a missing `pip` executable (FileNotFoundError) and
        # failures opening the output file.
        wandb.termerror(f"Command failed: {e}")
        return None


def diff_pip_requirements(req_1: List[str], req_2: List[str]) -> Dict[str, str]:
    """Compare two pip requirement lists and report the packages that differ.

    Each list holds the lines of a requirements file (e.g. ``pip freeze``
    output). Blank lines and ``#`` comments (whole-line, or preceded by
    whitespace) are ignored. Package names are compared case-insensitively.
    Understood line forms: ``name``, ``name==ver``, ``name>=ver``,
    ``name>ver``, ``name @ url`` and VCS lines (``git+``/``hg+``) with an
    ``#egg=name`` fragment.

    Args:
        req_1: Lines of the first requirements file.
        req_2: Lines of the second requirements file.

    Returns:
        A dict keyed by lowercased package name.
        - A package found in only one list maps to the version parsed from
          that list (``None`` for an unpinned line).
        - A package found in both lists with different versions maps to
          ``"v<req_1 version> and v<req_2 version>"``.
        - Packages that agree are omitted.

    Raises:
        LaunchError: If any line cannot be parsed.
    """
    # "name [extras] @ url": PEP 508 direct reference, which pip freeze emits
    # for packages installed from local files or URLs.
    direct_ref_re = re.compile(r"^([A-Za-z0-9_.-]+)\s*(?:\[[^\]]*\])?\s*@\s*(\S+)")

    def _split_pin(line: str, op: str) -> Tuple[str, str]:
        # "name<op>version  # ..." -> (name, version)
        parts = line.split(op)
        return parts[0].strip(), parts[1].split("#")[0].strip()

    def _parse_req(req: List[str]) -> Dict[str, Optional[str]]:
        # TODO: This can be made more exhaustive, but for 99% of cases this is fine
        # see https://pip.pypa.io/en/stable/reference/requirements-file-format/#example
        d: Dict[str, Optional[str]] = {}
        for raw_line in req:
            # Whitespace followed by "#" starts an inline comment; a "#" with
            # no preceding whitespace (e.g. "#egg=") belongs to the requirement.
            line = re.sub(r"\s+#.*$", "", raw_line).strip()
            if not line or line.startswith("#"):
                continue
            version: Optional[str] = None
            direct_ref = direct_ref_re.match(line)
            if ("git+" in line or "hg+" in line) and "#egg=" in line:
                # "#egg=name" may be followed by "&subdirectory=..."
                name = line.split("#egg=")[1].split("&")[0].strip()
                version = line.split("@")[-1].split("#")[0]
            elif direct_ref is not None:
                name, version = direct_ref.group(1), direct_ref.group(2)
            elif "==" in line:
                name, version = _split_pin(line, "==")
            elif ">=" in line:
                name, version = _split_pin(line, ">=")
            elif ">" in line:
                name, version = _split_pin(line, ">")
            elif re.match(_VALID_PIP_PACKAGE_REGEX, line) is not None:
                name = line
            else:
                raise ValueError(f"Unable to parse pip requirements file line: {line}")
            # Explicit check instead of `assert`, which is stripped under -O.
            if not re.match(_VALID_PIP_PACKAGE_REGEX, name):
                raise ValueError(f"Invalid pip package name {name}")
            # pip package names are case-insensitive; normalize every form alike.
            d[name.lower()] = version
        return d

    try:
        req_1_dict = _parse_req(req_1)
        req_2_dict = _parse_req(req_2)
    except (ValueError, IndexError, KeyError) as e:
        raise LaunchError(f"Failed to parse pip requirements: {e}") from e

    # Walk names in sorted order so the result (and the "a and b" ordering)
    # is deterministic rather than dependent on set iteration order.
    pretty_diff: Dict[str, Any] = {}
    for name in sorted(set(req_1_dict) | set(req_2_dict)):
        in_1 = name in req_1_dict
        in_2 = name in req_2_dict
        if in_1 and in_2:
            v1, v2 = req_1_dict[name], req_2_dict[name]
            if v1 != v2:
                pretty_diff[name] = f"v{v1} and v{v2}"
        else:
            pretty_diff[name] = req_1_dict[name] if in_1 else req_2_dict[name]
    return pretty_diff


def validate_wandb_python_deps(
    requirements_file: Optional[str],
    dir: str,
) -> None:
    """Warn if local python dependencies differ from the given requirements file.

    Reads ``dir/requirements_file``, snapshots the local environment with
    ``get_local_python_deps`` and logs a warning listing the packages whose
    versions differ. Logs "Unable to validate ..." when there is no
    requirements file or the local snapshot could not be produced.

    Raises:
        LaunchError: Propagated from ``diff_pip_requirements`` if either
            file cannot be parsed.
    """
    if requirements_file is not None:
        requirements_path = os.path.join(dir, requirements_file)
        with open(requirements_path) as f:
            wandb_python_deps: List[str] = f.read().splitlines()

        local_python_file = get_local_python_deps(dir)
        if local_python_file is not None:
            local_python_deps_path = os.path.join(dir, local_python_file)
            with open(local_python_deps_path) as f:
                local_python_deps: List[str] = f.read().splitlines()

            # The diff used to be computed and discarded, so mismatches were
            # never reported despite this function's purpose.
            diff = diff_pip_requirements(wandb_python_deps, local_python_deps)
            if diff:
                _logger.warning(
                    "Local python dependencies differ from %s: %s",
                    requirements_file,
                    diff,
                )
            return
    _logger.warning("Unable to validate local python dependencies")


def apply_patch(patch_string: str, dst_dir: str) -> None:
    """Write ``patch_string`` to ``dst_dir/diff.patch`` and apply it with ``patch -p1``.

    ``patch`` runs with ``--directory=dst_dir``, so the relative ``-i diff.patch``
    refers to the file just written.

    Raises:
        wandb.Error: If ``patch`` exits with a non-zero status.
    """
    _logger.info("Applying diff.patch")
    patch_path = os.path.join(dst_dir, "diff.patch")
    with open(patch_path, "w") as patch_file:
        patch_file.write(patch_string)

    patch_cmd = ["patch", "-s", f"--directory={dst_dir}", "-p1", "-i", "diff.patch"]
    try:
        subprocess.check_call(patch_cmd)
    except subprocess.CalledProcessError:
        raise wandb.Error("Failed to apply diff.patch associated with run.")


def _fetch_git_repo(dst_dir: str, uri: str, version: Optional[str]) -> Optional[str]:
    """Fetch the git repo at ``uri`` into ``dst_dir`` via GitReference.

    ``version`` is handed to GitReference. Authentication is expected to come
    from the environment (e.g. a git credential helper); nothing here supplies
    credentials.

    Returns:
        ``version`` if given, otherwise the reference's ``ref`` read after the
        fetch.

    Raises:
        LaunchError: If no reference object is produced for ``uri``.
    """
    _logger.info("Fetching git repo")
    git_ref = GitReference(uri, version)
    if git_ref is None:
        raise LaunchError(f"Unable to parse git uri: {uri}")
    git_ref.fetch(dst_dir)
    return version if version is not None else git_ref.ref


def convert_jupyter_notebook_to_script(fname: str, project_dir: str) -> str:
    """Convert ``project_dir/fname`` (a notebook) into a Python script beside it.

    Lines beginning with ``!`` (shell escapes) are dropped from code cells
    before export with nbconvert's ``PythonExporter``.

    Args:
        fname: Notebook path relative to ``project_dir``.
        project_dir: Directory containing the notebook.

    Returns:
        The script's path relative to ``project_dir``: ``fname`` with its
        extension replaced by ``.py``.
    """
    nbconvert = wandb.util.get_module(
        "nbconvert", "nbformat and nbconvert are required to use launch with notebooks"
    )
    nbformat = wandb.util.get_module(
        "nbformat", "nbformat and nbconvert are required to use launch with notebooks"
    )

    _logger.info("Converting notebook to script")
    # Replace only the trailing extension. str.replace would also rewrite an
    # earlier ".ipynb" in the path (e.g. a directory name), and for a name
    # without ".ipynb" it would return fname unchanged, so the notebook
    # itself would be overwritten.
    root, _ = os.path.splitext(fname)
    new_name = root + ".py"

    # Notebook files are JSON encoded as UTF-8; don't rely on the locale default.
    with open(os.path.join(project_dir, fname), encoding="utf-8") as fh:
        nb = nbformat.reads(fh.read(), nbformat.NO_CONVERT)

    for cell in nb.cells:
        if cell.cell_type == "code":
            kept_lines = [
                line for line in cell.source.split("\n") if not line.startswith("!")
            ]
            cell.source = "\n".join(kept_lines)

    exporter = nbconvert.PythonExporter()
    source, _ = exporter.from_notebook_node(nb)

    # Python reads source files as UTF-8 by default, so write them that way.
    with open(os.path.join(project_dir, new_name), "w+", encoding="utf-8") as fh:
        fh.write(source)
    return new_name


def to_camel_case(maybe_snake_str: str) -> str:
    """Convert ``snake_case`` to ``PascalCase``; strings without '_' are returned as-is.

    Each non-empty '_'-separated piece is title-cased. Each empty piece
    (from leading, trailing or doubled underscores) becomes a literal '_'.
    """
    if "_" not in maybe_snake_str:
        return maybe_snake_str
    pieces = []
    for segment in maybe_snake_str.split("_"):
        pieces.append(segment.title() if segment else "_")
    return "".join(pieces)


def validate_build_and_registry_configs(
    build_config: Dict[str, Any], registry_config: Dict[str, Any]
) -> None:
    """Reject build and registry configs that declare different credentials.

    Only checked when both configs have a truthy ``credentials`` entry.

    Raises:
        LaunchError: If both are set and not equal.
    """
    build_creds = build_config.get("credentials", {})
    registry_creds = registry_config.get("credentials", {})
    if not build_creds or not registry_creds:
        return
    if build_creds != registry_creds:
        raise LaunchError("registry and build config credential mismatch")


async def get_kube_context_and_api_client(
    kubernetes: Any,
    resource_args: Dict[str, Any],
) -> Tuple[Any, Any]:
    """Resolve the kube context and build an API client for it.

    If ``resource_args["configFile"]`` is set or ``~/.kube/config`` exists,
    contexts are read from that kubeconfig. The context named by
    ``resource_args["context"]`` is used, or the active context if none is
    named. Otherwise in-cluster configuration is loaded and the returned
    context is ``None``.

    Args:
        kubernetes: The kubernetes client module. Its ``config.load_kube_config``
            and ``config.new_client_from_config`` are awaited, so presumably
            ``kubernetes_asyncio`` — confirm with callers.
        resource_args: Runner resource args; ``configFile`` and ``context``
            are read.

    Returns:
        ``(context, api_client)``.

    Raises:
        LaunchError: If a requested context is not present in the kubeconfig.
    """
    config_file = resource_args.get("configFile", None)
    context = None
    if config_file is not None or os.path.exists(os.path.expanduser("~/.kube/config")):
        # context only exist in the non-incluster case
        (
            all_contexts,
            active_context,
        ) = kubernetes.config.list_kube_config_contexts(config_file)
        if resource_args.get("context"):
            context_name = resource_args["context"]
            for c in all_contexts:
                if c["name"] == context_name:
                    context = c
                    break
            else:
                # Only reached when the loop finished without `break`. The
                # raise previously sat after the loop and fired even when the
                # context was found.
                raise LaunchError(f"Specified context {context_name} was not found.")
        else:
            context = active_context
        # TODO: We should not really be performing this check if the user is not
        # using EKS but I don't see an obvious way to make an eks specific code path
        # right here.
        util.get_module(
            "awscli",
            "awscli is required to load a kubernetes context "
            "from eks. Please run `pip install wandb[launch]` to install it.",
        )
        await kubernetes.config.load_kube_config(config_file, context["name"])
        api_client = await kubernetes.config.new_client_from_config(
            config_file, context=context["name"]
        )
        return context, api_client
    else:
        kubernetes.config.load_incluster_config()
        api_client = kubernetes.client.api_client.ApiClient()
        return context, api_client


def resolve_build_and_registry_config(
    default_launch_config: Optional[Dict[str, Any]],
    build_config: Optional[Dict[str, Any]],
    registry_config: Optional[Dict[str, Any]],
) -> Tuple[Dict[str, Any], Dict[str, Any]]:
    """Pick the effective builder and registry configs and validate them.

    An explicit ``build_config`` / ``registry_config`` wins. Otherwise the
    ``"builder"`` / ``"registry"`` section of ``default_launch_config`` is
    used, or ``{}`` when that section is missing or null.

    Returns:
        ``(build_config, registry_config)``.

    Raises:
        LaunchError: If the two configs declare conflicting credentials.
    """

    def _resolve(explicit: Optional[Dict[str, Any]], section: str) -> Dict[str, Any]:
        # Pick one config: explicit arg, else the default config's section, else {}.
        if explicit is not None:
            return explicit
        if default_launch_config is None:
            return {}
        value = default_launch_config.get(section)
        # A section present with a null value (e.g. a bare `builder:` key in a
        # YAML config) is treated as missing, rather than crashing later when
        # validation calls `.get` on None.
        return value if value is not None else {}

    resolved_build_config = _resolve(build_config, "builder")
    resolved_registry_config = _resolve(registry_config, "registry")
    validate_build_and_registry_configs(resolved_build_config, resolved_registry_config)
    return resolved_build_config, resolved_registry_config


def check_logged_in(api: Api) -> bool:
    """Confirm the current API key works by loading the viewer.

    Costs one viewer request (roughly 0.1-0.2 seconds).

    Returns:
        True when the viewer loads.

    Raises:
        LaunchError: If the viewer comes back empty (likely a broken API key).
    """
    viewer = api.api.viewer()
    if viewer:
        return True
    raise LaunchError(
        "Could not connect with current API-key. "
        "Please relogin using `wandb login --relogin`"
        " and try again (see `wandb login --help` for more options)"
    )


def make_name_dns_safe(name: str) -> str:
    """Reduce ``name`` to characters allowed in a DNS subdomain name.

    Underscores become hyphens, the result is lowercased, and everything
    other than ``a-z``, ``0-9``, ``.`` and ``-`` is removed. The result is
    truncated to 200 characters. It is not guaranteed to start or end with an
    alphanumeric character, and may be empty.
    """
    resp = name.replace("_", "-").lower()
    # Digits are valid in DNS-1123 names (as make_k8s_label_safe also allows),
    # so keep them; removing them mangled names like "job-v2" and made
    # collisions between distinct names more likely.
    resp = re.sub(r"[^a-z0-9\.\-]", "", resp)
    # Actual length limit is 253, but we want to leave room for the generated suffix
    resp = resp[:200]
    return resp


def make_k8s_label_safe(value: str) -> str:
    """Return a Kubernetes label/identifier safe string (DNS-1123 label).

    See:
    https://kubernetes.io/docs/concepts/overview/working-with-objects/names/#dns-label-names

    Steps: lowercase and map '_' to '-', drop anything outside ``[a-z0-9-]``,
    collapse runs of '-', cut to 63 characters, then trim '-' from both ends.

    Raises:
        LaunchError: If nothing valid remains.
    """
    lowered = value.lower().replace("_", "-")
    allowed_only = re.sub(r"[^a-z0-9\-]", "", lowered)
    collapsed = re.sub(r"-+", "-", allowed_only)
    # Truncate before stripping so the stripped result never ends in '-'.
    label = collapsed[:63].strip("-")
    if label:
        return label
    raise LaunchError(f"Invalid value for Kubernetes label: {value}")


def warn_failed_packages_from_build_logs(
    log: str, image_uri: str, api: Api, job_tracker: Optional["JobAndRunStatusTracker"]
) -> None:
    """Surface packages the build reported as failed to install.

    Searches ``log`` for the failed-packages marker. On a match it prints a
    warning and, when ``job_tracker`` is given, saves the message as
    ``failed-packages.log`` and attaches a warning to the run queue item.
    """
    found = FAILED_PACKAGES_REGEX.search(log)
    if not found:
        return
    warning_text = f"Failed to install the following packages: {found.group(1)} for image: {image_uri}. Will attempt to launch image without them."
    wandb.termwarn(warning_text)
    if job_tracker is None:
        return
    saved = job_tracker.saver.save_contents(
        warning_text, "failed-packages.log", "warning"
    )
    api.update_run_queue_item_warning(
        job_tracker.run_queue_item_id,
        "Some packages were not successfully installed during the build",
        "build",
        saved,
    )


def docker_image_exists(docker_image: str, should_raise: bool = False) -> bool:
    """Return whether ``docker image inspect`` succeeds for ``docker_image``.

    Raises:
        docker.DockerError / ValueError: Re-raised from the inspect call when
            ``should_raise`` is True; otherwise the failure returns False.
    """
    _logger.info("Checking if base image exists...")
    try:
        docker.run(["docker", "image", "inspect", docker_image])
    except (docker.DockerError, ValueError):
        if not should_raise:
            _logger.info("Base image not found. Generating new base image")
            return False
        raise
    return True


def pull_docker_image(docker_image: str) -> None:
    """Run ``docker pull`` for ``docker_image``.

    Raises:
        LaunchError: If the docker command reports an error.
    """
    pull_cmd = ["docker", "pull", docker_image]
    try:
        docker.run(pull_cmd)
    except docker.DockerError as e:
        raise LaunchError(f"Docker server returned error: {e}")


def macro_sub(original: str, sub_dict: Dict[str, Optional[str]]) -> str:
    """Substitute ``${macro}`` occurrences in a string.

    Each macro whose name is a key of ``sub_dict`` is replaced by ``str()`` of
    its value. A value of ``None`` is therefore rendered as ``"None"``. Macros
    not in ``sub_dict`` are left unchanged.

    Args:
        original: The string to substitute macros in.
        sub_dict: A dictionary mapping macro names to their values.

    Returns:
        The string with the macros substituted.
    """

    def _expand(found: Any) -> str:
        placeholder = found.group(0)
        return str(sub_dict.get(found.group(1), placeholder))

    return MACRO_REGEX.sub(_expand, original)


def recursive_macro_sub(source: Any, sub_dict: Dict[str, Optional[str]]) -> Any:
    """Apply ``macro_sub`` to every string leaf of a parsed JSON/YAML blob.

    Lists and dicts are rebuilt (dict keys are not substituted). Any other
    value is returned unchanged.

    Arguments:
        source: The JSON or YAML blob to substitute macros in.
        sub_dict: A dictionary mapping macro names to their values.

    Returns:
        The blob with the macros substituted.
    """
    if isinstance(source, dict):
        return {k: recursive_macro_sub(v, sub_dict) for k, v in source.items()}
    if isinstance(source, list):
        return [recursive_macro_sub(element, sub_dict) for element in source]
    if isinstance(source, str):
        return macro_sub(source, sub_dict)
    return source


def fetch_and_validate_template_variables(
    runqueue: Any, fields: dict
) -> Dict[str, Any]:
    """Parse ``key=value`` overrides and check them against a queue's template variables.

    Args:
        runqueue: Object exposing ``name`` and ``template_variables`` (an
            iterable of dicts with ``"name"`` and a JSON ``"schema"`` string).
        fields: Iterable of ``"key=value"`` strings. Only the first ``=``
            separates key from value, so values may contain ``=``.

    Returns:
        Mapping of key to value. Values are converted to ``int`` when the
        schema type is ``"integer"`` and to ``float`` when it is ``"number"``.
        Otherwise they are kept as strings.

    Raises:
        LaunchError: On a malformed field, an unknown key, or a value that
            does not convert to its schema type.
    """
    variable_schemas = {
        tv["name"]: json.loads(tv["schema"]) for tv in runqueue.template_variables
    }

    template_variables: Dict[str, Any] = {}
    for field in fields:
        # Split on the first "=" only; "key=a=b" used to be rejected.
        key, sep, raw_val = field.partition("=")
        if not sep:
            raise LaunchError(
                f'--set-var value must be in the format "--set-var key1=value1", instead got: {field}'
            )
        if key not in variable_schemas:
            raise LaunchError(
                f"Queue {runqueue.name} does not support overriding {key}."
            )
        field_type = variable_schemas[key].get("type")
        val: Any = raw_val
        try:
            if field_type == "integer":
                val = int(raw_val)
            elif field_type == "number":
                val = float(raw_val)
        except ValueError as e:
            raise LaunchError(f"Value for {key} must be of type {field_type}.") from e
        template_variables[key] = val
    return template_variables


def get_entrypoint_file(entrypoint: List[str]) -> Optional[str]:
    """Get the entrypoint file from the given command.

    If the first token ends in ``.py`` or ``.sh`` it is the file. Otherwise
    the second token is assumed to be (e.g. ``["python", "train.py"]``).

    Args:
        entrypoint (List[str]): List of command and arguments.

    Returns:
        Optional[str]: The entrypoint file if found, otherwise None.
    """
    if not entrypoint:
        return None
    head = entrypoint[0]
    if head.endswith((".py", ".sh")):
        return head
    return entrypoint[1] if len(entrypoint) >= 2 else None


def get_current_python_version() -> Tuple[str, str]:
    """Return ``("<major>.<minor>", "<major>")`` for the running interpreter.

    Parsed from the first token of ``sys.version``. If no minor part is
    present, ``".0"`` is appended to the major version.
    """
    numbers = sys.version.split()[0].split(".")
    major = numbers[0]
    if len(numbers) < 2:
        return major + ".0", major
    return f"{numbers[0]}.{numbers[1]}", major


def yield_containers(root: Union[dict, list]) -> Iterator[dict]:
    """Yield all container specs in a manifest.

    Walks nested dicts and lists. Whenever a dict has a ``"containers"`` key
    whose value is a list, that list's items are yielded; the walk does not
    descend into them. Other values are traversed recursively.
    """
    if isinstance(root, list):
        for element in root:
            yield from yield_containers(element)
        return
    if not isinstance(root, dict):
        return
    for key, value in root.items():
        if key != "containers":
            if isinstance(value, (dict, list)):
                yield from yield_containers(value)
        elif isinstance(value, list):
            yield from value


def sanitize_identifiers_for_k8s(root: Any) -> None:
    """Rewrite identifier ``name`` fields of a manifest in place with make_k8s_label_safe.

    The following are sanitized:
    - ``metadata.name`` at every level;
    - the ``name`` of every container found by ``yield_containers``;
    - any other string ``name`` key found while walking nested dicts/lists.

    Exception: the items of an ``env`` list keep their own ``name``, because
    environment-variable names are case-sensitive and commonly contain '_'.
    Rewriting them (e.g. ``WANDB_API_KEY`` -> ``wandb-api-key``) would change
    which variables the container sees.

    Raises:
        LaunchError: Propagated from make_k8s_label_safe when a name has no
            valid characters.
    """
    _sanitize_identifiers_for_k8s(root, sanitize_name_key=True)


def _sanitize_identifiers_for_k8s(root: Any, sanitize_name_key: bool) -> None:
    # Recursive worker; sanitize_name_key is False only for the dicts directly
    # inside an "env" list (EnvVar entries), whose "name" must be preserved.
    if isinstance(root, list):
        for item in root:
            _sanitize_identifiers_for_k8s(item, sanitize_name_key)
        return

    # Only dicts have metadata and nested structures we need to sanitize.
    if not isinstance(root, dict):
        return

    metadata = root.get("metadata")
    if isinstance(metadata, dict):
        if name := metadata.get("name"):
            metadata["name"] = make_k8s_label_safe(str(name))

    for container in yield_containers(root):
        if name := container.get("name"):
            container["name"] = make_k8s_label_safe(str(name))

    # nested names
    for key, value in root.items():
        if isinstance(value, (dict, list)):
            # Entries of an "env" list are env vars; deeper levels (e.g.
            # valueFrom.secretKeyRef) are sanitized as before.
            _sanitize_identifiers_for_k8s(value, sanitize_name_key=key != "env")
        elif key == "name" and isinstance(value, str) and sanitize_name_key:
            root[key] = make_k8s_label_safe(value)
