from __future__ import annotations

from typing import Any, Callable

from ultralytics.yolo.engine.model import YOLO
from ultralytics.yolo.engine.trainer import BaseTrainer

try:
    from ultralytics.yolo.utils import RANK
    from ultralytics.yolo.utils.torch_utils import get_flops, get_num_params
except ModuleNotFoundError:
    from ultralytics.utils import RANK
    from ultralytics.utils.torch_utils import get_flops, get_num_params
from ultralytics.yolo.v8.classify.train import ClassificationTrainer

import wandb
from wandb.sdk.lib import telemetry


class WandbCallback:
    """An internal YOLO model wrapper that tracks metrics, and logs models to Weights & Biases.

    Usage:
    ```python
    from wandb.integration.yolov8.yolov8 import WandbCallback

    model = YOLO("yolov8n.pt")
    wandb_logger = WandbCallback(
        model,
    )
    for event, callback_fn in wandb_logger.callbacks.items():
        model.add_callback(event, callback_fn)
    ```
    """

    def __init__(
        self,
        yolo: YOLO,
        run_name: str | None = None,
        project: str | None = None,
        tags: list[str] | None = None,
        resume: str | None = None,
        **kwargs: Any | None,
    ) -> None:
        """A utility class to manage wandb run and various callbacks for the ultralytics YOLOv8 framework.

        Args:
            yolo: A YOLOv8 model that's inherited from `:class:ultralytics.yolo.engine.model.YOLO`
            run_name, str: The name of the Weights & Biases run, defaults to an auto generated run_name if `trainer.args.name` is not defined.
            project, str: The name of the Weights & Biases project, defaults to `"YOLOv8"` if `trainer.args.project` is not defined.
            tags, List[str]: A list of tags to be added to the Weights & Biases run, defaults to `["YOLOv8"]`.
            resume, str: Whether to resume a previous run on Weights & Biases, defaults to `None`.
            **kwargs: Additional arguments to be passed to `wandb.init()`.
        """
        self.yolo = yolo
        self.run_name = run_name
        self.project = project
        self.tags = tags
        self.resume = resume
        self.kwargs = kwargs

    def on_pretrain_routine_start(self, trainer: BaseTrainer) -> None:
        """Starts a new wandb run to track the training process and log to Weights & Biases.

        Args:
            trainer: A task trainer that's inherited from `:class:ultralytics.yolo.engine.trainer.BaseTrainer`
                    that contains the model training and optimization routine.
        """
        if wandb.run is None:
            self.run = wandb.init(
                name=self.run_name if self.run_name else trainer.args.name,
                project=self.project
                if self.project
                else trainer.args.project or "YOLOv8",
                tags=self.tags if self.tags else ["YOLOv8"],
                config=vars(trainer.args),
                resume=self.resume if self.resume else None,
                **self.kwargs,
            )
        else:
            self.run = wandb.run
        assert self.run is not None
        self.run.define_metric("epoch", hidden=True)
        self.run.define_metric(
            "train/*", step_metric="epoch", step_sync=True, summary="min"
        )

        self.run.define_metric(
            "val/*", step_metric="epoch", step_sync=True, summary="min"
        )

        self.run.define_metric(
            "metrics/*", step_metric="epoch", step_sync=True, summary="max"
        )
        self.run.define_metric(
            "lr/*", step_metric="epoch", step_sync=True, summary="last"
        )

        with telemetry.context(run=wandb.run) as tel:
            tel.feature.ultralytics_yolov8 = True

    def on_pretrain_routine_end(self, trainer: BaseTrainer) -> None:
        assert self.run is not None
        self.run.summary.update(
            {
                "model/parameters": get_num_params(trainer.model),
                "model/GFLOPs": round(get_flops(trainer.model), 3),
            }
        )

    def on_train_epoch_start(self, trainer: BaseTrainer) -> None:
        """On train epoch start we only log epoch number to the Weights & Biases run."""
        # We log the epoch number here to commit the previous step,
        assert self.run is not None
        self.run.log({"epoch": trainer.epoch + 1})

    def on_train_epoch_end(self, trainer: BaseTrainer) -> None:
        """On train epoch end we log all the metrics to the Weights & Biases run."""
        assert self.run is not None
        self.run.log(
            {
                **trainer.metrics,
                **trainer.label_loss_items(trainer.tloss, prefix="train"),
                **trainer.lr,
            },
        )
        # Currently only the detection and segmentation trainers save images to the save_dir
        if not isinstance(trainer, ClassificationTrainer):
            self.run.log(
                {
                    "train_batch_images": [
                        wandb.Image(str(image_path), caption=image_path.stem)
                        for image_path in trainer.save_dir.glob("train_batch*.jpg")
                    ]
                }
            )

    def on_fit_epoch_end(self, trainer: BaseTrainer) -> None:
        """On fit epoch end we log all the best metrics and model detail to Weights & Biases run summary."""
        assert self.run is not None
        if trainer.epoch == 0:
            speeds = [
                trainer.validator.speed.get(
                    key,
                )
                for key in (1, "inference")
            ]
            speed = speeds[0] if speeds[0] else speeds[1]
            if speed:
                self.run.summary.update(
                    {
                        "model/speed(ms/img)": round(speed, 3),
                    }
                )
        if trainer.best_fitness == trainer.fitness:
            self.run.summary.update(
                {
                    "best/epoch": trainer.epoch + 1,
                    **{f"best/{key}": val for key, val in trainer.metrics.items()},
                }
            )

    def on_train_end(self, trainer: BaseTrainer) -> None:
        """On train end we log all the media, including plots, images and best model artifact to Weights & Biases."""
        # Currently only the detection and segmentation trainers save images to the save_dir
        assert self.run is not None
        if not isinstance(trainer, ClassificationTrainer):
            assert self.run is not None
            self.run.log(
                {
                    "plots": [
                        wandb.Image(str(image_path), caption=image_path.stem)
                        for image_path in trainer.save_dir.glob("*.png")
                    ],
                    "val_images": [
                        wandb.Image(str(image_path), caption=image_path.stem)
                        for image_path in trainer.validator.save_dir.glob("val*.jpg")
                    ],
                },
            )

        if trainer.best.exists():
            assert self.run is not None
            self.run.log_artifact(
                str(trainer.best),
                type="model",
                name=f"{self.run.name}_{trainer.args.task}.pt",
                aliases=["best", f"epoch_{trainer.epoch + 1}"],
            )

    def on_model_save(self, trainer: BaseTrainer) -> None:
        """On model save we log the model as an artifact to Weights & Biases."""
        assert self.run is not None
        self.run.log_artifact(
            str(trainer.last),
            type="model",
            name=f"{self.run.name}_{trainer.args.task}.pt",
            aliases=["last", f"epoch_{trainer.epoch + 1}"],
        )

    def teardown(self, _trainer: BaseTrainer) -> None:
        """On teardown, we finish the Weights & Biases run and set it to None."""
        assert self.run is not None
        self.run.finish()
        self.run = None

    @property
    def callbacks(
        self,
    ) -> dict[str, Callable]:
        """Property contains all the relevant callbacks to add to the YOLO model for the Weights & Biases logging."""
        return {
            "on_pretrain_routine_start": self.on_pretrain_routine_start,
            "on_pretrain_routine_end": self.on_pretrain_routine_end,
            "on_train_epoch_start": self.on_train_epoch_start,
            "on_train_epoch_end": self.on_train_epoch_end,
            "on_fit_epoch_end": self.on_fit_epoch_end,
            "on_train_end": self.on_train_end,
            "on_model_save": self.on_model_save,
            "teardown": self.teardown,
        }


def add_callbacks(
    yolo: YOLO,
    run_name: str | None = None,
    project: str | None = None,
    tags: list[str] | None = None,
    resume: str | None = None,
    **kwargs: Any | None,
) -> YOLO:
    """A YOLO model wrapper that tracks metrics, and logs models to Weights & Biases.

    Args:
        yolo: A YOLOv8 model that's inherited from `:class:ultralytics.yolo.engine.model.YOLO`
        run_name, str: The name of the Weights & Biases run, defaults to an auto generated name if `trainer.args.name` is not defined.
        project, str: The name of the Weights & Biases project, defaults to `"YOLOv8"` if `trainer.args.project` is not defined.
        tags, List[str]: A list of tags to be added to the Weights & Biases run, defaults to `["YOLOv8"]`.
        resume, str: Whether to resume a previous run on Weights & Biases, defaults to `None`.
        **kwargs: Additional arguments to be passed to `wandb.init()`.

    Usage:
    ```python
    from wandb.integration.yolov8 import add_callbacks as add_wandb_callbacks

    model = YOLO("yolov8n.pt")
    add_wandb_callbacks(
        model,
    )
    model.train(
        data="coco128.yaml",
        epochs=3,
        imgsz=640,
    )
    ```
    """
    wandb.termwarn(
        """The wandb callback is currently in beta and is subject to change based on updates to `ultralytics yolov8`.
        The callback is tested and supported for ultralytics v8.0.43 and above.
        Please report any issues to https://github.com/wandb/wandb/issues with the tag `yolov8`.
        """,
        repeat=False,
    )
    wandb.termwarn(
        """This wandb callback is no longer functional and would be deprecated in the near future.
        We recommend you to use the updated callback using `from wandb.integration.ultralytics import add_wandb_callback`.
        The updated callback is tested and supported for ultralytics 8.0.167 and above.
        You can refer to https://docs.wandb.ai/guides/integrations/ultralytics for the updated documentation.
        Please report any issues to https://github.com/wandb/wandb/issues with the tag `yolov8`.
        """,
        repeat=False,
    )

    if RANK in [-1, 0]:
        wandb_logger = WandbCallback(
            yolo, run_name=run_name, project=project, tags=tags, resume=resume, **kwargs
        )
        for event, callback_fn in wandb_logger.callbacks.items():
            yolo.add_callback(event, callback_fn)
        return yolo
    else:
        wandb.termerror(
            "The RANK of the process to add the callbacks was neither 0 or -1."
            "No Weights & Biases callbacks were added to this instance of the YOLO model."
        )
    return yolo
