"""config."""

import logging
from typing import Optional

import wandb
from wandb.util import (
    _is_artifact_representation,
    check_dict_contains_nested_artifact,
    json_friendly_val,
)

from . import wandb_helper
from .lib import config_util

# Module-level logger; uses the fixed "wandb" logger name rather than __name__.
logger = logging.getLogger("wandb")


# TODO(jhr): consider a callback for persisting changes?
# if this is done right we might make sure this is pickle-able
# we might be able to do this on other objects like Run?
class Config:
    """Config object.

    Config objects are intended to hold all of the hyperparameters associated
    with a wandb run and are saved with the run object when `wandb.init` is
    called.

    We recommend setting the config once when initializing your run by passing
    the `config` parameter to `init`:

    ```
    wandb.init(config=my_config_dict)
    ```

    You can create a file called `config-defaults.yaml`, and it will
    automatically be loaded as each run's config. You can also pass the name
    of the file as the `config` parameter to `init`:

    ```
    wandb.init(config="my_config.yaml")
    ```

    See https://docs.wandb.com/guides/track/config#file-based-configs.

    Examples:
        Basic usage
        ```
        with wandb.init(config={"epochs": 4}) as run:
            for x in range(run.config.epochs):
                # train
        ```

        Nested values
        ```
        with wandb.init(config={"train": {"epochs": 4}}) as run:
            for x in range(run.config["train"]["epochs"]):
                # train
        ```

        Using absl flags
        ```
        flags.DEFINE_string("model", None, "model to run")  # name, default, help
        with wandb.init() as run:
            run.config.update(flags.FLAGS)  # adds all absl flags to config
        ```

        Argparse flags
        ```python
        with wandb.init(config={"epochs": 4}) as run:
            parser = argparse.ArgumentParser()
            parser.add_argument(
                "-b",
                "--batch-size",
                type=int,
                default=8,
                metavar="N",
                help="input batch size for training (default: 8)",
            )
            args = parser.parse_args()
            run.config.update(args)
        ```

        Using TensorFlow flags (deprecated in tensorflow v2)
        ```python
        flags = tf.app.flags
        flags.DEFINE_string("data_dir", "/tmp/data")
        flags.DEFINE_integer("batch_size", 128, "Batch size.")

        with wandb.init() as run:
            run.config.update(flags.FLAGS)
        ```
    """

    def __init__(self):
        # Internal state must be written with object.__setattr__ because
        # __setattr__ is aliased to __setitem__ below: a plain
        # `self._items = ...` would be stored as a config item instead.
        object.__setattr__(self, "_items", dict())  # key -> stored value
        object.__setattr__(self, "_locked", dict())  # key -> locking user id
        object.__setattr__(self, "_users", dict())  # user -> user id
        object.__setattr__(self, "_users_inv", dict())  # user id -> user
        object.__setattr__(self, "_users_cnt", 0)  # next user id to hand out
        object.__setattr__(self, "_callback", None)
        object.__setattr__(self, "_settings", None)
        object.__setattr__(self, "_artifact_callback", None)

        self._load_defaults()

    def _set_callback(self, cb):
        """Install the callback invoked whenever config items change."""
        object.__setattr__(self, "_callback", cb)

    def _set_artifact_callback(self, cb):
        """Install the callback `_sanitize` uses for artifact representations."""
        object.__setattr__(self, "_artifact_callback", cb)

    def _set_settings(self, settings):
        """Attach the settings object consulted by `_sanitize`."""
        object.__setattr__(self, "_settings", settings)

    def __repr__(self):
        # dict(self) goes through keys()/__getitem__, so keys starting with
        # "_" are not part of the representation.
        return str(dict(self))

    def keys(self):
        """Return a list of the keys that do not start with an underscore."""
        return [k for k in self._items.keys() if not k.startswith("_")]

    def _as_dict(self):
        """Return the backing dict itself (not a copy), private keys included."""
        return self._items

    def as_dict(self):
        """Return a new dict holding the items exposed through `keys()`."""
        # TODO: add telemetry, deprecate, then remove
        return dict(self)

    def __getitem__(self, key):
        return self._items[key]

    def __iter__(self):
        # Unlike keys(), this iterates over every stored key, including the
        # ones starting with "_".
        return iter(self._items)

    def _check_locked(self, key, ignore_locked=False) -> bool:
        """Report whether `key` is locked.

        Args:
            key: Config key to look up in the lock table.
            ignore_locked: When truthy, suppress the warning that is otherwise
                printed for a locked key.

        Returns:
            True if the key is locked, False otherwise.
        """
        locked = self._locked.get(key)
        if locked is None:
            return False
        if not ignore_locked:
            # The owning user is only needed for the warning text.
            locked_user = self._users_inv[locked]
            wandb.termwarn(
                f"Config item '{key}' was locked by '{locked_user}' (ignored update)."
            )
        return True

    def __setitem__(self, key, val):
        """Set a single config item and notify the callback.

        Updates to locked keys are ignored (with a warning).

        Raises:
            ValueError: If `val` is a dict for which the nested-artifact check
                fails.
            TypeError: If `val` is a WBValue.
            config_util.ConfigError: If the key already holds a different
                value and value changes are not allowed (see `_sanitize`).
        """
        if self._check_locked(key):
            return
        with wandb.sdk.lib.telemetry.context() as tel:
            tel.feature.set_config_item = True
        self._raise_value_error_on_nested_artifact(val, nested=True)
        key, val = self._sanitize(key, val)
        self._items[key] = val
        logger.info("config set %s = %s - %s", key, val, self._callback)
        if self._callback:
            self._callback(key=key, val=val)

    def items(self):
        """Return a list of (key, value) pairs, skipping keys starting with "_"."""
        return [(k, v) for k, v in self._items.items() if not k.startswith("_")]

    # Attribute assignment (`config.lr = 0.1`) stores a config item.
    __setattr__ = __setitem__

    def __getattr__(self, key):
        """Expose config items as attributes (`config.epochs`).

        Python only calls this after normal attribute lookup has failed.

        Raises:
            AttributeError: If `key` is not a config item, or if the instance
                has no item store.
        """
        # `_items` is a real instance attribute once __init__ has run, so
        # landing here for it means the instance was created without
        # __init__ (e.g. via `__new__`). Resolving it through __getitem__
        # would re-enter __getattr__ until RecursionError, so fail fast.
        if key == "_items":
            raise AttributeError(
                f"{self.__class__!r} object has no attribute {key!r}"
            )
        try:
            return self.__getitem__(key)
        except KeyError as ke:
            raise AttributeError(
                f"{self.__class__!r} object has no attribute {key!r}"
            ) from ke

    def __contains__(self, key):
        # Membership considers every stored key, including "_"-prefixed ones.
        return key in self._items

    def _update(self, d, allow_val_change=None, ignore_locked=None):
        """Merge `d` into the config without invoking the callback.

        Args:
            d: Input accepted by `wandb_helper.parse_config`.
            allow_val_change: Forwarded to `_sanitize` for every key.
            ignore_locked: When truthy, locked keys are skipped silently
                instead of with a warning.

        Returns:
            The sanitized dict that was applied (locked keys excluded).
        """
        parsed_dict = wandb_helper.parse_config(d)
        locked_keys = set()
        for key in list(parsed_dict):
            if self._check_locked(key, ignore_locked=ignore_locked):
                locked_keys.add(key)
        sanitized = self._sanitize_dict(
            parsed_dict, allow_val_change, ignore_keys=locked_keys
        )
        self._items.update(sanitized)
        return sanitized

    def update(self, d, allow_val_change=None):
        """Merge `d` into the config and pass the applied items to the callback.

        Args:
            d: Input accepted by `wandb_helper.parse_config`.
            allow_val_change: When falsy, changing the value of an existing
                key raises `config_util.ConfigError` (see `_sanitize` for the
                Jupyter exception).
        """
        sanitized = self._update(d, allow_val_change)
        if self._callback:
            self._callback(data=sanitized)

    def get(self, *args):
        """Look up a key in the item store; arguments are passed to `dict.get`."""
        return self._items.get(*args)

    def persist(self):
        """Call the callback if it's set."""
        if self._callback:
            self._callback(data=self._as_dict())

    def setdefaults(self, d):
        """Add the items of `d` whose keys are not configured yet.

        Args:
            d: Input accepted by `wandb_helper.parse_config`.
        """
        d = wandb_helper.parse_config(d)
        # strip out keys already configured; compare the normalized form of
        # the key (see `_sanitize`) so that e.g. "--lr" is recognized as the
        # existing "lr" instead of reaching `_sanitize` as a value change.
        d = {k: v for k, v in d.items() if k.strip("-") not in self._items}
        d = self._sanitize_dict(d)
        self._items.update(d)
        if self._callback:
            self._callback(data=d)

    def _get_user_id(self, user) -> int:
        """Return the integer id for `user`, assigning the next free id if new."""
        if user not in self._users:
            self._users[user] = self._users_cnt
            self._users_inv[self._users_cnt] = user
            object.__setattr__(self, "_users_cnt", self._users_cnt + 1)

        return self._users[user]

    def update_locked(self, d, user=None, _allow_val_change=None):
        """Shallow-update config with `d` and lock config updates on d's keys.

        Args:
            d: Mapping of keys to values to store and lock.
            user: Identifier recorded as the owner of the locks.
            _allow_val_change: Forwarded to `_sanitize` for every key.
        """
        num = self._get_user_id(user)
        # Collect what is actually stored so the callback sees the same
        # normalized keys and sanitized values as the config itself.
        callback_d = {}

        for k, v in d.items():
            k, v = self._sanitize(k, v, allow_val_change=_allow_val_change)
            self._locked[k] = num
            self._items[k] = v
            callback_d[k] = v

        if self._callback:
            self._callback(data=callback_d)

    def merge_locked(self, d, user=None, _allow_val_change=None):
        """Recursively merge-update config with `d` and lock config updates on d's keys.

        Args:
            d: Mapping of keys to values to merge and lock.
            user: Identifier recorded as the owner of the locks.
            _allow_val_change: Forwarded to `_sanitize` for every key.
        """
        num = self._get_user_id(user)
        callback_d = {}

        for k, v in d.items():
            k, v = self._sanitize(k, v, allow_val_change=_allow_val_change)
            self._locked[k] = num

            # Only dict-into-dict is merged (via config_util.merge_dicts);
            # every other combination replaces the stored value.
            if (
                k in self._items
                and isinstance(self._items[k], dict)
                and isinstance(v, dict)
            ):
                self._items[k] = config_util.merge_dicts(self._items[k], v)
            else:
                self._items[k] = v

            callback_d[k] = self._items[k]

        if self._callback:
            self._callback(data=callback_d)

    def _load_defaults(self):
        """Apply `config-defaults.yaml` if the config-file loader returns a dict."""
        conf_dict = config_util.dict_from_config_file("config-defaults.yaml")
        if conf_dict is not None:
            self.update(conf_dict)

    def _sanitize_dict(
        self,
        config_dict,
        allow_val_change=None,
        ignore_keys: Optional[set] = None,
    ):
        """Sanitize every entry of `config_dict` and return the result.

        Args:
            config_dict: Mapping of raw keys to raw values.
            allow_val_change: Forwarded to `_sanitize` for every key.
            ignore_keys: Raw keys to leave out of the result.

        Returns:
            A new dict of sanitized keys and values.

        Raises:
            ValueError: If the nested-artifact check on `config_dict` fails.
        """
        sanitized = {}
        self._raise_value_error_on_nested_artifact(config_dict)
        for k, v in config_dict.items():
            if ignore_keys and k in ignore_keys:
                continue
            k, v = self._sanitize(k, v, allow_val_change)
            sanitized[k] = v
        return sanitized

    def _sanitize(self, key, val, allow_val_change=None):
        """Normalize one key/value pair before it is stored.

        Args:
            key: Raw key; leading and trailing "-" characters are removed.
            val: Raw value.
            allow_val_change: When falsy, an existing key may not receive a
                different value. `None` is treated as True when the attached
                settings report a Jupyter environment.

        Returns:
            The `(key, val)` tuple to store.

        Raises:
            TypeError: If `val` is a WBValue.
            config_util.ConfigError: If the key already holds a different
                value and value changes are not allowed.
        """
        # TODO: enable WBValues in the config in the future
        # refuse all WBValues which is all Media and Histograms
        if isinstance(val, wandb.sdk.data_types.base_types.wb_value.WBValue):
            raise TypeError("WBValue objects cannot be added to the run config")
        # Let jupyter change config freely by default
        if self._settings and self._settings._jupyter and allow_val_change is None:
            allow_val_change = True
        # We always normalize keys by stripping '-'
        key = key.strip("-")
        if _is_artifact_representation(val):
            val = self._artifact_callback(key, val)
        # if the user inserts an artifact into the config
        if not isinstance(val, wandb.Artifact):
            val = json_friendly_val(val)
        if not allow_val_change:
            if key in self._items and val != self._items[key]:
                raise config_util.ConfigError(
                    f'Attempted to change value of key "{key}" '
                    f"from {self._items[key]} to {val}\n"
                    "If you really want to do this, pass"
                    " allow_val_change=True to config.update()"
                )
        return key, val

    def _raise_value_error_on_nested_artifact(self, v, nested=False):
        """Raise ValueError if `v` is a dict that fails the nested-artifact check.

        Non-dict values are never rejected here. `nested` is passed through to
        `check_dict_contains_nested_artifact`.
        """
        # we can't swap nested artifacts because their root key can be locked by other values
        # best if we don't allow nested artifacts until we can lock nested keys in the config
        if isinstance(v, dict) and check_dict_contains_nested_artifact(v, nested):
            raise ValueError(
                "Instances of wandb.Artifact can only be top level keys in"
                " a run's config"
            )


class ConfigStatic:
    """Read-only snapshot of a config.

    The given config is copied into a plain dict that becomes the instance
    namespace, so each top-level key can be read as an attribute or by item
    access. Attribute and item assignment both raise `AttributeError`.
    """

    def __init__(self, config):
        snapshot = dict(config)
        # Install the namespace through object.__setattr__: this class's own
        # __setattr__ rejects every assignment.
        object.__setattr__(self, "__dict__", snapshot)

    # --- read access -----------------------------------------------------

    def keys(self):
        """Return a view of the snapshot's keys."""
        return vars(self).keys()

    def __getitem__(self, key):
        return vars(self)[key]

    def __str__(self):
        return str(vars(self))

    # --- write access is rejected ------------------------------------------

    def __setattr__(self, name, value):
        raise AttributeError("Error: run.config_static is a readonly object")

    def __setitem__(self, key, val):
        raise AttributeError("Error: run.config_static is a readonly object")
