from __future__ import annotations

import json
import os
import re
from typing import TYPE_CHECKING, Any

import wandb
from wandb import util
from wandb.sdk.launch.errors import LaunchError

if TYPE_CHECKING:
    from wandb.apis.public import Api as PublicApi

# Command template used by create_sweep_command when no command is supplied.
# Each entry is a "${...}" macro placeholder.
# NOTE(review): presumably these macros are expanded by the sweep agent before
# the program is launched — confirm against the caller.
DEFAULT_SWEEP_COMMAND: list[str] = [
    "${env}",
    "${interpreter}",
    "${program}",
    "${args}",
]
# Matches "${envvar:NAME}" macros; group 1 captures NAME, which is limited to
# uppercase letters, digits and underscores (and may be empty because of "*").
SWEEP_COMMAND_ENV_VAR_REGEX = re.compile(r"\$\{envvar\:([A-Z0-9_]*)\}")


def parse_sweep_id(parts_dict: dict) -> str | None:
    """Split a sweep path stored under ``parts_dict["name"]`` into its parts.

    The path may be ``sweep``, ``project/sweep`` or ``entity/project/sweep``.
    On success the dict is updated in place: ``name`` becomes the bare sweep
    id, while ``project`` and ``entity`` are set to the parsed values, or to
    ``None`` when the path does not provide them (or provides them empty).

    Arguments:
        parts_dict (dict): dict(entity=,project=,name=).  Modifies dict inplace.

    Returns:
        None on success, or a str describing the error. The dict is left
        untouched when an error is returned.
    """
    sweep_path = parts_dict.get("name")
    if not isinstance(sweep_path, str):
        return "Expected string sweep_id"

    pieces = sweep_path.split("/")
    if len(pieces) > 3:
        return (
            "Expected sweep_id in form of sweep, project/sweep, or entity/project/sweep"
        )

    # Left-pad with None so the pieces always line up as (entity, project, sweep).
    missing = 3 - len(pieces)
    raw_entity, raw_project, sweep_name = [None] * missing + pieces

    # Empty path components are normalised to None.
    parts_dict.update(
        name=sweep_name,
        project=raw_project or None,
        entity=raw_entity or None,
    )
    return None


def sweep_config_err_text_from_jsonschema_violations(violations: list[str]) -> str:
    """Consolidate schema violation strings from wandb/sweeps into a single string.

    The input list is not modified, so the same list can safely be passed
    again (or inspected by the caller) without the entries picking up the
    "Violation N." prefix.

    Parameters
    ----------
    violations: list of str
        The warnings to render.

    Returns:
    -------
    violation: str
        The consolidated violation text: a fixed header followed by one
        numbered line per violation (numbering starts at 1).

    """
    violation_base = (
        "Malformed sweep config detected! This may cause your sweep to behave in unexpected ways.\n"
        "To avoid this, please fix the sweep config schema violations below:"
    )

    # Build a fresh list rather than rewriting the caller's list in place.
    numbered = [
        f"  Violation {number}. {warning}"
        for number, warning in enumerate(violations, start=1)
    ]
    return "\n".join([violation_base, *numbered])


def handle_sweep_config_violations(warnings: list[str]) -> None:
    """Print sweep config schema violation warnings from Gorilla to the terminal.

    Nothing is printed when there are no warnings.

    Parameters
    ----------
    warnings: list of str
        The warnings to render.
    """
    message = sweep_config_err_text_from_jsonschema_violations(warnings)
    if not warnings:
        return
    wandb.termwarn(message)


def load_sweep_config(sweep_config_path: str) -> dict[str, Any] | None:
    """Load a sweep yaml from path.

    Arguments:
        sweep_config_path (str): path of the YAML sweep configuration file.

    Returns:
        The parsed configuration, or None if the file could not be opened,
        is not valid YAML, or parses to an empty/falsy value. In each of
        those cases an error is printed with ``wandb.termerror``.
    """
    import yaml

    try:
        yaml_file = open(sweep_config_path)
    except OSError:
        wandb.termerror(f"Couldn't open sweep file: {sweep_config_path}")
        return None
    # The context manager guarantees the handle is closed on every path,
    # including the YAML error return below.
    with yaml_file:
        try:
            config: dict[str, Any] | None = yaml.safe_load(yaml_file)
        except yaml.YAMLError as err:
            wandb.termerror(f"Error in configuration file: {err}")
            return None
    if not config:
        wandb.termerror("Configuration file is empty")
        return None
    return config


def load_launch_sweep_config(config: str | None) -> Any:
    """Load a launch-sweep config via ``util.load_json_yaml_dict``.

    Arguments:
        config (str | None): config reference handed to the loader; a falsy
            value (None or empty string) means "no config".

    Returns:
        An empty dict when no config is given, otherwise the loaded config.

    Raises:
        LaunchError: if the loader returns None for the given config.
    """
    if config:
        loaded = util.load_json_yaml_dict(config)
        if loaded is not None:
            return loaded
        raise LaunchError(f"Could not load config from {config}. Check formatting")
    return {}


def construct_scheduler_args(
    sweep_config: dict[str, Any],
    queue: str,
    project: str,
    author: str | None = None,
    return_job: bool = False,
) -> list[str] | dict[str, str] | None:
    """Construct sweep scheduler args.

    The sweep config must contain exactly one non-empty top-level ``job`` or
    ``image_uri`` key; otherwise an error is logged and None is returned.

    Arguments:
        sweep_config: sweep configuration holding ``job`` or ``image_uri``.
        queue: queue name passed through to the scheduler.
        project: project name passed through to the scheduler.
        author: optional author, included only when truthy.
        return_job: when True return the args as a dict, otherwise as a
            list of CLI strings.

    Returns:
        None if misconfigured, a dict of args when ``return_job`` is True,
        or a flat list of CLI flags and values otherwise. In the list form
        queue, project, author and job are rendered with ``repr``, while
        ``image_uri`` is passed through unchanged.
    """
    job = sweep_config.get("job")
    image_uri = sweep_config.get("image_uri")

    # Empty strings count as "not provided".
    has_job = bool(job)
    has_image = bool(image_uri)
    if not (has_job or has_image):
        wandb.termerror(
            "No 'job' nor 'image_uri' top-level key found in sweep config, exactly one is required for a launch-sweep"
        )
        return None
    if has_job and has_image:
        wandb.termerror(
            "Sweep config has both 'job' and 'image_uri' but a launch-sweep can use only one"
        )
        return None

    # Scheduler runs as a job: hand back keyword-style args.
    if return_job:
        scheduler_kwargs: dict[str, str] = {
            "sweep_id": "WANDB_SWEEP_ID",
            "queue": queue,
            "project": project,
        }
        # Exactly one of job / image_uri is set at this point.
        if has_job:
            scheduler_kwargs["job"] = job
        else:
            scheduler_kwargs["image_uri"] = image_uri
        if author:
            scheduler_kwargs["author"] = author
        return scheduler_kwargs

    # Scheduler runs through the CLI: hand back a flat parameter list.
    cli_args = ["--queue", repr(queue), "--project", repr(project)]
    if author:
        cli_args.extend(["--author", repr(author)])
    if has_job:
        cli_args.extend(["--job", repr(job)])
    else:
        cli_args.extend(["--image_uri", image_uri])
    return cli_args


def create_sweep_command(command: list | None = None) -> list:
    """Return sweep command, filling in environment variable macros.

    Every ``${envvar:NAME}`` macro found in a command chunk is replaced by the
    value of the environment variable ``NAME``. When that variable is not
    set, the macro is replaced by the bare name ``NAME``.

    Arguments:
        command (list | None): command chunks; when None or empty,
            DEFAULT_SWEEP_COMMAND is used.

    Returns:
        A new list with the macros expanded. Neither the list passed in nor
        the module-level default is modified, and the default is never
        returned by reference. Chunks without macros are returned unchanged
        (keeping their original type); chunks with macros come back as str.
    """
    # Work on a copy so that neither the caller's list nor the shared
    # module-level default can be mutated through this function or its result.
    resolved = list(command or DEFAULT_SWEEP_COMMAND)

    def _expand(match: re.Match) -> str:
        name = match.group(1)
        # Fall back to the bare variable name when it is not set.
        return os.environ.get(name, name)

    for i, chunk in enumerate(resolved):
        # Chunks may be of any type (ex: int); match against their string
        # form consistently for both detection and substitution.
        text = str(chunk)
        if SWEEP_COMMAND_ENV_VAR_REGEX.search(text):
            resolved[i] = SWEEP_COMMAND_ENV_VAR_REGEX.sub(_expand, text)
    return resolved


def create_sweep_command_args(command: dict) -> dict[str, Any]:
    """Create various formats of command arguments for the agent.

    ``command["args"]`` maps each parameter name to a mapping holding its
    ``"value"``. The same parameters are rendered in several formats, which
    are returned together keyed by format name.

    Returns:
        dict with keys ``args``, ``args_no_equals``, ``args_no_hyphens``,
        ``args_no_boolean_flags``, ``args_json``, ``args_dict``,
        ``args_append_hydra`` and ``args_override_hydra``.

    Raises:
        ValueError: improperly formatted command dict

    """
    if "args" not in command:
        raise ValueError(f'No "args" found in command: {command}')

    # One accumulator per output format:
    with_hyphens: list[str] = []  # --foo=bar
    without_hyphens: list[str] = []  # foo=bar
    without_booleans: list[str] = []  # --foo for True, omitted for False
    as_dict: dict[str, Any] = {}  # {"foo": bar}, also dumped as json
    without_equals: list[str] = []  # --foo bar
    hydra_append: list[str] = []  # +foo=bar
    hydra_override: list[str] = []  # ++foo=bar

    for name, spec in command["args"].items():
        # allow 'None' as a valid value, but error if no value is found
        try:
            value: Any = spec["value"]
        except KeyError:
            raise ValueError(f'No "value" found for command["args"]["{name}"]')

        assignment: str = f"{name}={value}"
        dashed = "--" + assignment

        with_hyphens.append(dashed)
        without_hyphens.append(assignment)
        without_equals.extend([f"--{name}", str(value)])
        hydra_append.append("+" + assignment)
        hydra_override.append("++" + assignment)

        if not isinstance(value, bool):
            without_booleans.append(dashed)
        elif value:
            # True booleans become a bare flag; False booleans are dropped.
            without_booleans.append("--" + name)

        as_dict[name] = value

    return {
        "args": with_hyphens,
        "args_no_equals": without_equals,
        "args_no_hyphens": without_hyphens,
        "args_no_boolean_flags": without_booleans,
        "args_json": [json.dumps(as_dict)],
        "args_dict": as_dict,
        "args_append_hydra": hydra_append,
        "args_override_hydra": hydra_override,
    }


def make_launch_sweep_entrypoint(
    args: dict[str, Any], command: list[str] | None
) -> tuple[list[str] | None, Any]:
    """Use args dict from create_sweep_command_args to construct entrypoint.

    The command is first passed through create_sweep_command. Then, for each
    key of ``args``, a ``${key}`` macro found as a whole element of the
    entrypoint is removed (first occurrence only) and the matching args value
    is returned separately. Only one macro per entrypoint is supported: if
    several match, the value of the last matching key is the one returned.

    Returns:
        ``(None, None)`` when no command is given; otherwise a tuple of the
        entrypoint with macros removed (None if nothing is left) and the
        macro args (an empty dict if no macro matched).
    """
    if not command:
        return None, None

    entry_point = create_sweep_command(command)
    macro_args: Any = {}
    for macro, macro_value in args.items():
        placeholder = "${" + macro + "}"
        if placeholder not in entry_point:
            continue
        position = entry_point.index(placeholder)
        # only supports 1 macro per entrypoint
        macro_args = macro_value
        entry_point = entry_point[:position] + entry_point[position + 1 :]

    remaining = entry_point if len(entry_point) > 0 else None
    return remaining, macro_args


def check_job_exists(public_api: PublicApi, job: str | None) -> bool:
    """Check if the job exists using the public api.

    The job is looked up with ``public_api.job(job)``; any exception raised
    by that call is reported with ``wandb.termerror`` and treated as the job
    not being available.

    Returns: True if no job is passed, or if the job exists.
    Returns: False if the job is misformatted or doesn't exist.
    """
    if job:
        try:
            public_api.job(job)
        except Exception as err:
            wandb.termerror(f"Failed to load job. {err}")
            return False
    return True


def get_previous_args(
    run_spec: dict[str, Any],
) -> tuple[dict[str, Any], dict[str, Any]]:
    """Parse through previous scheduler run_spec.

    Reads ``run_spec["overrides"]["run_config"]`` for its ``scheduler`` and
    ``settings`` entries, and pipes the top-level ``resource`` and
    ``resource_args`` (when truthy) into the scheduler args.

    ``run_spec`` is not modified: the returned dicts are shallow copies, so
    callers may update them freely. Missing or None levels
    (``overrides``, ``run_config``, ``scheduler``, ``settings``) are treated
    as empty.

    Returns:
        Tuple of (scheduler_args, settings).
    """
    run_config = (run_spec.get("overrides") or {}).get("run_config") or {}

    # Copy before adding keys so the caller's nested dict is left untouched.
    scheduler_args = dict(run_config.get("scheduler") or {})
    # also pipe through top level resource setup
    for key in ("resource", "resource_args"):
        if run_spec.get(key):
            scheduler_args[key] = run_spec[key]

    settings = dict(run_config.get("settings") or {})

    return scheduler_args, settings
