from __future__ import annotations

import copy
from datetime import datetime
from typing import Callable, Union

from packaging import version

try:
    import dill as pickle
except ImportError:
    import pickle

import wandb
from wandb.sdk.lib import telemetry

try:
    import torch
    import ultralytics
    from tqdm.auto import tqdm

    if version.parse(ultralytics.__version__) > version.parse("8.0.238"):
        wandb.termwarn(
            """This integration is tested and supported for ultralytics v8.0.238 and below.
            Please report any issues to https://github.com/wandb/wandb/issues with the tag `yolov8`.""",
            repeat=False,
        )

    from ultralytics.models import YOLO
    from ultralytics.models.sam.predict import Predictor as SAMPredictor
    from ultralytics.models.yolo.classify import (
        ClassificationPredictor,
        ClassificationTrainer,
        ClassificationValidator,
    )
    from ultralytics.models.yolo.detect import (
        DetectionPredictor,
        DetectionTrainer,
        DetectionValidator,
    )
    from ultralytics.models.yolo.pose import PosePredictor, PoseTrainer, PoseValidator
    from ultralytics.models.yolo.segment import (
        SegmentationPredictor,
        SegmentationTrainer,
        SegmentationValidator,
    )
    from ultralytics.utils.torch_utils import de_parallel

    try:
        from ultralytics.yolo.utils import RANK, __version__
    except ModuleNotFoundError:
        from ultralytics.utils import RANK, __version__

    from wandb.integration.ultralytics.bbox_utils import (
        plot_bbox_predictions,
        plot_detection_validation_results,
    )
    from wandb.integration.ultralytics.classification_utils import (
        plot_classification_predictions,
        plot_classification_validation_results,
    )
    from wandb.integration.ultralytics.mask_utils import (
        plot_mask_predictions,
        plot_sam_predictions,
        plot_segmentation_validation_results,
    )
    from wandb.integration.ultralytics.pose_utils import (
        plot_pose_predictions,
        plot_pose_validation_results,
    )
except Exception as e:
    wandb.Error(e)


# Type aliases used to annotate the callback hooks defined below. Each alias is
# the union of the task-specific ultralytics classes imported above.

# Any of the classification / detection / segmentation / pose trainers.
TRAINER_TYPE = Union[
    ClassificationTrainer, DetectionTrainer, SegmentationTrainer, PoseTrainer
]
# Any of the classification / detection / segmentation / pose validators.
VALIDATOR_TYPE = Union[
    ClassificationValidator, DetectionValidator, SegmentationValidator, PoseValidator
]
# Any of the task-specific predictors, plus the SAM predictor, which gets
# special handling in the prediction hooks.
PREDICTOR_TYPE = Union[
    ClassificationPredictor,
    DetectionPredictor,
    SegmentationPredictor,
    PosePredictor,
    SAMPredictor,
]


class WandBUltralyticsCallback:
    """Stateful callback for logging to W&B.

    In particular, it will log model checkpoints, predictions, and
    ground-truth annotations with interactive overlays for bounding boxes
    to Weights & Biases Tables during training, validation and prediction
    for an `ultralytics` workflow.

    Example:
        ```python
        from ultralytics.yolo.engine.model import YOLO
        from wandb.yolov8 import add_wandb_callback

        # initialize YOLO model
        model = YOLO("yolov8n.pt")

        # add wandb callback
        add_wandb_callback(
            model, max_validation_batches=2, enable_model_checkpointing=True
        )

        # train
        model.train(data="coco128.yaml", epochs=5, imgsz=640)

        # validate
        model.val()

        # perform inference
        model(["img1.jpeg", "img2.jpeg"])
        ```

    Args:
        model: (ultralytics.yolo.engine.model.YOLO) YOLO Model of type
            `ultralytics.yolo.engine.model.YOLO`.
        epoch_logging_interval: (int) interval to log the prediction visualizations
            during training.
        max_validation_batches: (int) maximum number of validation batches to log to
            a table per epoch.
        enable_model_checkpointing: (bool) enable logging model checkpoints as
            artifacts at the end of every epoch if set to `True`.
        visualize_skeleton: (bool) visualize pose skeleton by drawing lines connecting
            keypoints for human pose.
    """

    def __init__(
        self,
        model: YOLO,
        epoch_logging_interval: int = 1,
        max_validation_batches: int = 1,
        enable_model_checkpointing: bool = False,
        visualize_skeleton: bool = False,
    ) -> None:
        self.epoch_logging_interval = epoch_logging_interval
        self.max_validation_batches = max_validation_batches
        self.enable_model_checkpointing = enable_model_checkpointing
        self.visualize_skeleton = visualize_skeleton
        self.task = model.task
        self.task_map = model.task_map
        self.model_name = self._resolve_model_name(model.overrides)
        self._make_tables()
        self._make_predictor(model)
        self.supported_tasks = ["detect", "segment", "pose", "classify"]
        self.prompts = None
        self.run_id = None
        # Epoch most recently handled by `on_fit_epoch_end`; used to ignore
        # repeated invocations for the same epoch.
        self.train_epoch = None

    @staticmethod
    def _resolve_model_name(overrides) -> str | None:
        """Return the configured model path without its file extension.

        Args:
            overrides: mapping of model overrides; the `"model"` entry, when
                present, is converted to `str` before the extension is removed.

        Returns:
            The model path minus its trailing extension, or `None` when the
            mapping has no `"model"` entry.
        """
        import os

        if "model" not in overrides:
            return None
        # `splitext` only strips the final extension, so relative paths such as
        # "./weights/best.pt" and directories containing dots keep their prefix
        # instead of being cut at the first "." of the string.
        return os.path.splitext(str(overrides["model"]))[0]

    def _make_tables(self):
        """Create the training, validation and prediction tables for `self.task`.

        No tables are created for tasks other than detect, segment, classify
        and pose.
        """
        if self.task in ["detect", "segment"]:
            validation_columns = [
                "Data-Index",
                "Batch-Index",
                "Image",
                "Mean-Confidence",
                "Speed",
            ]
            train_columns = ["Epoch"] + validation_columns
            self.train_validation_table = wandb.Table(
                columns=["Model-Name"] + train_columns
            )
            self.validation_table = wandb.Table(
                columns=["Model-Name"] + validation_columns
            )
            self.prediction_table = wandb.Table(
                columns=[
                    "Model-Name",
                    "Image",
                    "Num-Objects",
                    "Mean-Confidence",
                    "Speed",
                ]
            )
        elif self.task == "classify":
            classification_columns = [
                "Image",
                "Predicted-Category",
                "Prediction-Confidence",
                "Top-5-Prediction-Categories",
                "Top-5-Prediction-Confindence",
                "Probabilities",
                "Speed",
            ]
            validation_columns = ["Data-Index", "Batch-Index"] + classification_columns
            # Place the ground-truth column directly after "Image".
            validation_columns.insert(3, "Ground-Truth-Category")
            self.train_validation_table = wandb.Table(
                columns=["Model-Name", "Epoch"] + validation_columns
            )
            self.validation_table = wandb.Table(
                columns=["Model-Name"] + validation_columns
            )
            self.prediction_table = wandb.Table(
                columns=["Model-Name"] + classification_columns
            )
        elif self.task == "pose":
            validation_columns = [
                "Data-Index",
                "Batch-Index",
                "Image-Ground-Truth",
                "Image-Prediction",
                "Num-Instances",
                "Mean-Confidence",
                "Speed",
            ]
            train_columns = ["Epoch"] + validation_columns
            self.train_validation_table = wandb.Table(
                columns=["Model-Name"] + train_columns
            )
            self.validation_table = wandb.Table(
                columns=["Model-Name"] + validation_columns
            )
            self.prediction_table = wandb.Table(
                columns=[
                    "Model-Name",
                    "Image-Prediction",
                    "Num-Instances",
                    "Mean-Confidence",
                    "Speed",
                ]
            )

    def _make_predictor(self, model: YOLO):
        """Build the task-specific predictor used to produce logged predictions.

        The predictor is created from a copy of the model overrides with the
        confidence threshold set to 0.1, has no callbacks of its own, and has
        saving of results, text files and crops switched off.
        """
        overrides = copy.deepcopy(model.overrides)
        overrides["conf"] = 0.1
        self.predictor = self.task_map[self.task]["predictor"](overrides=overrides)
        self.predictor.callbacks = {}
        self.predictor.args.save = False
        self.predictor.args.save_txt = False
        self.predictor.args.save_crop = False
        self.predictor.args.verbose = None

    def _save_model(self, trainer: TRAINER_TYPE):
        """Write a checkpoint for the current epoch and log it as an artifact.

        The checkpoint is saved to `trainer.wdir / "epoch{N}.pt"` and logged
        under the alias `epoch_{N}`. The stored model is `self.model`, the
        snapshot taken by `on_fit_epoch_end` for the same epoch.
        """
        model_checkpoint_artifact = wandb.Artifact(f"run_{wandb.run.id}_model", "model")
        checkpoint_dict = {
            "epoch": trainer.epoch,
            "best_fitness": trainer.best_fitness,
            "model": copy.deepcopy(de_parallel(self.model)).half(),
            "ema": copy.deepcopy(trainer.ema.ema).half(),
            "updates": trainer.ema.updates,
            "optimizer": trainer.optimizer.state_dict(),
            "train_args": vars(trainer.args),
            "date": datetime.now().isoformat(),
            "version": __version__,
        }
        checkpoint_path = trainer.wdir / f"epoch{trainer.epoch}.pt"
        torch.save(checkpoint_dict, checkpoint_path, pickle_module=pickle)
        model_checkpoint_artifact.add_file(checkpoint_path)
        wandb.log_artifact(
            model_checkpoint_artifact, aliases=[f"epoch_{trainer.epoch}"]
        )

    def _plot_validation_results(
        self, dataloader, class_label_map, table, **extra_kwargs
    ):
        """Dispatch to the validation plotting helper matching `self.task`.

        Args:
            dataloader: validation dataloader forwarded to the plotting helper.
            class_label_map: class names of the validator; not forwarded for
                the classification task.
            table: the `wandb.Table` the plotting helper receives.
            **extra_kwargs: additional keyword arguments forwarded untouched
                (the training hook passes `epoch`).

        Returns:
            Whatever the plotting helper returns, or `table` unchanged when
            the task has no plotting helper.
        """
        shared_kwargs = {
            "dataloader": dataloader,
            "model_name": self.model_name,
            "predictor": self.predictor,
            "table": table,
            "max_validation_batches": self.max_validation_batches,
        }
        shared_kwargs.update(extra_kwargs)
        if self.task == "pose":
            return plot_pose_validation_results(
                class_label_map=class_label_map,
                visualize_skeleton=self.visualize_skeleton,
                **shared_kwargs,
            )
        if self.task == "segment":
            return plot_segmentation_validation_results(
                class_label_map=class_label_map, **shared_kwargs
            )
        if self.task == "detect":
            return plot_detection_validation_results(
                class_label_map=class_label_map, **shared_kwargs
            )
        if self.task == "classify":
            return plot_classification_validation_results(**shared_kwargs)
        return table

    def on_train_start(self, trainer: TRAINER_TYPE):
        """Record telemetry, store the training arguments and remember the run id."""
        with telemetry.context(run=wandb.run) as tel:
            tel.feature.ultralytics_yolov8 = True
        wandb.config.train = vars(trainer.args)
        self.run_id = wandb.run.id

    @torch.no_grad()
    def on_fit_epoch_end(self, trainer: TRAINER_TYPE):
        """Log validation predictions and/or a checkpoint once per training epoch.

        Predictions are added to the training table every
        `epoch_logging_interval` epochs; a checkpoint is logged every epoch
        when `enable_model_checkpointing` is set. Repeated calls for an epoch
        that was already handled, and unsupported tasks, are ignored.
        """
        if self.task not in self.supported_tasks or self.train_epoch == trainer.epoch:
            return
        self.train_epoch = trainer.epoch

        # Resolve the device on every epoch: it is needed below even when this
        # epoch does not log predictions (epoch_logging_interval > 1).
        self.device = next(trainer.model.parameters()).device
        logging_due = (self.train_epoch + 1) % self.epoch_logging_interval == 0

        # Refresh the model snapshot whenever it is consumed this epoch, so a
        # checkpoint never stores weights left over from an earlier epoch.
        if logging_due or self.enable_model_checkpointing:
            if isinstance(trainer.model, torch.nn.parallel.DistributedDataParallel):
                model = trainer.model.module
            else:
                model = trainer.model
            self.model = copy.deepcopy(model).eval().to(self.device)

        if logging_due:
            validator = trainer.validator
            self.predictor.setup_model(model=self.model, verbose=False)
            self.train_validation_table = self._plot_validation_results(
                dataloader=validator.dataloader,
                class_label_map=validator.names,
                table=self.train_validation_table,
                epoch=trainer.epoch,
            )

        if self.enable_model_checkpointing:
            self._save_model(trainer)
        # NOTE(review): presumably ensures the training model is still on its
        # original device after the logging work above - confirm.
        trainer.model.to(self.device)

    def on_train_end(self, trainer: TRAINER_TYPE):
        """Log the accumulated training table under the key "Train-Table"."""
        if self.task in self.supported_tasks:
            wandb.log({"Train-Table": self.train_validation_table}, commit=False)

    def on_val_start(self, validator: VALIDATOR_TYPE):
        """Start a W&B run for validation when none is active."""
        if not wandb.run:
            wandb.init(
                project=validator.args.project or "YOLOv8",
                job_type="validation_" + validator.args.task,
            )

    @torch.no_grad()
    def on_val_end(self, trainer: VALIDATOR_TYPE):
        """Log predictions on the validation dataloader as "Validation-Table".

        Note: despite its name, the `trainer` argument is the validator.
        """
        if self.task in self.supported_tasks:
            validator = trainer
            self.validation_table = self._plot_validation_results(
                dataloader=validator.dataloader,
                class_label_map=validator.names,
                table=self.validation_table,
            )
            wandb.log({"Validation-Table": self.validation_table}, commit=False)

    def on_predict_start(self, predictor: PREDICTOR_TYPE):
        """Start a W&B run for prediction when none is active.

        For a SAM predictor the prompts are copied and the prediction table is
        replaced by one with a single "Image" column.
        """
        if not wandb.run:
            wandb.init(
                project=predictor.args.project or "YOLOv8",
                config=vars(predictor.args),
                job_type="prediction_" + predictor.args.task,
            )
        if isinstance(predictor, SAMPredictor):
            self.prompts = copy.deepcopy(predictor.prompts)
            self.prediction_table = wandb.Table(columns=["Image"])

    def on_predict_end(self, predictor: PREDICTOR_TYPE):
        """Add every prediction result to the table and log "Prediction-Table"."""
        wandb.config.prediction_configs = vars(predictor.args)
        if self.task in self.supported_tasks:
            for result in tqdm(predictor.results):
                if self.task == "pose":
                    self.prediction_table = plot_pose_predictions(
                        result,
                        self.model_name,
                        self.visualize_skeleton,
                        self.prediction_table,
                    )
                elif self.task == "segment":
                    # The segment task covers both regular segmentation
                    # predictors and SAM predictors; other predictor types
                    # add nothing to the table.
                    if isinstance(predictor, SegmentationPredictor):
                        self.prediction_table = plot_mask_predictions(
                            result, self.model_name, self.prediction_table
                        )
                    elif isinstance(predictor, SAMPredictor):
                        self.prediction_table = plot_sam_predictions(
                            result, self.prompts, self.prediction_table
                        )
                elif self.task == "detect":
                    self.prediction_table = plot_bbox_predictions(
                        result, self.model_name, self.prediction_table
                    )
                elif self.task == "classify":
                    self.prediction_table = plot_classification_predictions(
                        result, self.model_name, self.prediction_table
                    )

            wandb.log({"Prediction-Table": self.prediction_table}, commit=False)

    @property
    def callbacks(self) -> dict[str, Callable]:
        """Property contains all the relevant callbacks to add to the YOLO model for the Weights & Biases logging."""
        return {
            "on_train_start": self.on_train_start,
            "on_fit_epoch_end": self.on_fit_epoch_end,
            "on_train_end": self.on_train_end,
            "on_val_start": self.on_val_start,
            "on_val_end": self.on_val_end,
            "on_predict_start": self.on_predict_start,
            "on_predict_end": self.on_predict_end,
        }


# TODO: Add epoch interval
def add_wandb_callback(
    model: YOLO,
    epoch_logging_interval: int = 1,
    enable_model_checkpointing: bool = False,
    enable_train_validation_logging: bool = True,
    enable_validation_logging: bool = True,
    enable_prediction_logging: bool = True,
    max_validation_batches: int | None = 1,
    visualize_skeleton: bool | None = True,
):
    """Attach the `WandBUltralyticsCallback` hooks to a `YOLO` model.

    Hooks are only registered when `RANK` is 0 or -1; for any other rank an
    error is printed and the model is returned without modification.

    Example:
        ```python
        from ultralytics.yolo.engine.model import YOLO
        from wandb.yolov8 import add_wandb_callback

        # initialize YOLO model
        model = YOLO("yolov8n.pt")

        # add wandb callback
        add_wandb_callback(
            model, max_validation_batches=2, enable_model_checkpointing=True
        )

        # train
        model.train(data="coco128.yaml", epochs=5, imgsz=640)

        # validate
        model.val()

        # perform inference
        model(["img1.jpeg", "img2.jpeg"])
        ```

    Args:
        model: (ultralytics.yolo.engine.model.YOLO) the YOLO model that
            receives the callbacks.
        epoch_logging_interval: (int) interval, in epochs, at which prediction
            visualizations are logged during training.
        enable_model_checkpointing: (bool) log model checkpoints as artifacts
            at the end of every epoch if set to `True`.
        enable_train_validation_logging: (bool) when `False`, the
            `on_fit_epoch_end` and `on_train_end` hooks are not registered,
            so no per-epoch training table is logged.
        enable_validation_logging: (bool) when `False`, the `on_val_start`
            and `on_val_end` hooks are not registered, so no validation table
            is logged.
        enable_prediction_logging: (bool) when `False`, the
            `on_predict_start` and `on_predict_end` hooks are not registered,
            so no prediction table is logged.
        max_validation_batches: (Optional[int]) maximum number of validation
            batches to log to a table per epoch.
        visualize_skeleton: (Optional[bool]) visualize pose skeleton by
            drawing lines connecting keypoints for human pose.

    Returns:
        The `model` that was passed in, with the enabled hooks registered.
    """
    # Guard clause: every rank other than 0 / -1 gets no callbacks.
    if RANK not in [-1, 0]:
        wandb.termerror(
            "The RANK of the process to add the callbacks was neither 0 or "
            "-1. No Weights & Biases callbacks were added to this instance "
            "of the YOLO model."
        )
        return model

    # The callback object is built from a deep copy of the model.
    hooks = WandBUltralyticsCallback(
        copy.deepcopy(model),
        epoch_logging_interval=epoch_logging_interval,
        max_validation_batches=max_validation_batches,
        enable_model_checkpointing=enable_model_checkpointing,
        visualize_skeleton=visualize_skeleton,
    ).callbacks

    # Collect the events switched off by the `enable_*` flags.
    skipped_events = set()
    if not enable_train_validation_logging:
        skipped_events.update(("on_fit_epoch_end", "on_train_end"))
    if not enable_validation_logging:
        skipped_events.update(("on_val_start", "on_val_end"))
    if not enable_prediction_logging:
        skipped_events.update(("on_predict_start", "on_predict_end"))

    # Register the remaining hooks in their original order.
    for event_name, hook in hooks.items():
        if event_name not in skipped_events:
            model.add_callback(event_name, hook)
    return model
