"""xgboost init!"""

from __future__ import annotations

import json
import warnings
from pathlib import Path
from typing import TYPE_CHECKING, Callable, Dict, List, NamedTuple, Tuple, Union, cast

import xgboost as xgb
import xgboost.callback
from typing_extensions import TypeAlias, override

import wandb
from wandb.sdk.lib import telemetry as wb_telemetry

MINIMIZE_METRICS = [
    "rmse",
    "rmsle",
    "mae",
    "mape",
    "mphe",
    "logloss",
    "error",
    "error@t",
    "merror",
]

MAXIMIZE_METRICS = ["auc", "aucpr", "ndcg", "map", "ndcg@n", "map@n"]


if TYPE_CHECKING:

    class CallbackEnv(NamedTuple):
        evaluation_result_list: list

    # Copied from xgboost's source code. These types are not exported.
    _ScoreList = Union[List[float], List[Tuple[float, float]]]
    _EvalsLog: TypeAlias = Dict[str, Dict[str, _ScoreList]]


def wandb_callback() -> Callable:
    """Old style callback that will be deprecated in favor of WandbCallback. Please try the new logger for more features."""
    warnings.warn(
        "wandb_callback will be deprecated in favor of WandbCallback. Please use WandbCallback for more features.",
        UserWarning,
        stacklevel=2,
    )

    with wb_telemetry.context() as tel:
        tel.feature.xgboost_old_wandb_callback = True

    def callback(env: CallbackEnv) -> None:
        for k, v in env.evaluation_result_list:
            wandb.log({k: v}, commit=False)
        wandb.log({})

    return callback


class WandbCallback(xgboost.callback.TrainingCallback):
    """`WandbCallback` automatically integrates XGBoost with wandb.

    Args:
        log_model: (boolean) if True save and upload the model to Weights & Biases Artifacts
        log_feature_importance: (boolean) if True log a feature importance bar plot
        importance_type: (str) one of {weight, gain, cover, total_gain, total_cover} for tree model. weight for linear model.
        define_metric: (boolean) if True (default) capture model performance at the best step, instead of the last step, of training in your `wandb.summary`.

    Passing `WandbCallback` to XGBoost will:

    - log the booster model configuration to Weights & Biases
    - log evaluation metrics collected by XGBoost, such as rmse, accuracy etc. to Weights & Biases
    - log training metric collected by XGBoost (if you provide training data to eval_set)
    - log the best score and the best iteration
    - save and upload your trained model to Weights & Biases Artifacts (when `log_model = True`)
    - log feature importance plot when `log_feature_importance=True` (default).
    - Capture the best eval metric in `wandb.summary` when `define_metric=True` (default).

    Example:
        ```python
        bst_params = dict(
            objective="reg:squarederror",
            colsample_bytree=0.3,
            learning_rate=0.1,
            max_depth=5,
            alpha=10,
            n_estimators=10,
            tree_method="hist",
            callbacks=[WandbCallback()],
        )

        xg_reg = xgb.XGBRegressor(**bst_params)
        xg_reg.fit(
            X_train,
            y_train,
            eval_set=[(X_test, y_test)],
        )
        ```
    """

    def __init__(
        self,
        log_model: bool = False,
        log_feature_importance: bool = True,
        importance_type: str = "gain",
        define_metric: bool = True,
    ):
        super().__init__()

        self.log_model: bool = log_model
        self.log_feature_importance: bool = log_feature_importance
        self.importance_type: str = importance_type
        self.define_metric: bool = define_metric

        if wandb.run is None:
            raise wandb.Error("You must call wandb.init() before WandbCallback()")

        with wb_telemetry.context() as tel:
            tel.feature.xgboost_wandb_callback = True

    @override
    def before_training(self, model: xgb.Booster) -> xgb.Booster:
        """Run before training is finished."""
        # Update W&B config
        config = model.save_config()
        wandb.config.update(json.loads(config))

        return model

    @override
    def after_training(self, model: xgb.Booster) -> xgb.Booster:
        """Run after training is finished."""
        # Log the booster model as artifacts
        if self.log_model:
            self._log_model_as_artifact(model)

        # Plot feature importance
        if self.log_feature_importance:
            self._log_feature_importance(model)

        # Log the best score and best iteration
        if model.attr("best_score") is not None:
            wandb.log(
                {
                    "best_score": float(cast(str, model.attr("best_score"))),
                    "best_iteration": int(cast(str, model.attr("best_iteration"))),
                }
            )

        return model

    @override
    def after_iteration(
        self,
        model: xgb.Booster,
        epoch: int,
        evals_log: _EvalsLog,
    ) -> bool:
        """Run after each iteration. Return True when training should stop."""
        # Log metrics
        for data, metric in evals_log.items():
            for metric_name, log in metric.items():
                if self.define_metric:
                    self._define_metric(data, metric_name)
                    wandb.log({f"{data}-{metric_name}": log[-1]}, commit=False)
                else:
                    wandb.log({f"{data}-{metric_name}": log[-1]}, commit=False)

        wandb.log({"epoch": epoch})

        self.define_metric = False

        return False

    def _log_model_as_artifact(self, model: xgb.Booster) -> None:
        model_name = f"{wandb.run.id}_model.json"  # type: ignore
        model_path = Path(wandb.run.dir) / model_name  # type: ignore
        model.save_model(str(model_path))

        model_artifact = wandb.Artifact(name=model_name, type="model")
        model_artifact.add_file(str(model_path))
        wandb.log_artifact(model_artifact)

    def _log_feature_importance(self, model: xgb.Booster) -> None:
        fi = model.get_score(importance_type=self.importance_type)
        fi_data = [[k, fi[k]] for k in fi]
        table = wandb.Table(data=fi_data, columns=["Feature", "Importance"])
        wandb.log(
            {
                "Feature Importance": wandb.plot.bar(
                    table, "Feature", "Importance", title="Feature Importance"
                )
            }
        )

    def _define_metric(self, data: str, metric_name: str) -> None:
        if "loss" in str.lower(metric_name):
            wandb.define_metric(f"{data}-{metric_name}", summary="min")
        elif str.lower(metric_name) in MINIMIZE_METRICS:
            wandb.define_metric(f"{data}-{metric_name}", summary="min")
        elif str.lower(metric_name) in MAXIMIZE_METRICS:
            wandb.define_metric(f"{data}-{metric_name}", summary="max")
        else:
            pass
