from __future__ import annotations

from typing import Any

import torch
from tqdm.auto import tqdm
from ultralytics.engine.results import Results
from ultralytics.models.yolo.detect import DetectionPredictor
from ultralytics.utils import ops

import wandb


def scale_bounding_box_to_original_image_shape(
    box: torch.Tensor,
    resized_image_shape: tuple,
    original_image_shape: tuple,
    ratio_pad: bool,
) -> list[int]:
    """YOLOv8 resizes images during training and the label values are normalized based on this resized shape.

    This function rescales the bounding box labels to the original
    image shape.

    Reference: https://github.com/ultralytics/ultralytics/blob/main/ultralytics/yolo/utils/callbacks/comet.py#L105
    """
    resized_image_height, resized_image_width = resized_image_shape
    # Convert normalized xywh format predictions to xyxy in resized scale format
    box = ops.xywhn2xyxy(box, h=resized_image_height, w=resized_image_width)
    # Scale box predictions from resized image scale back to original image scale
    box = ops.scale_boxes(resized_image_shape, box, original_image_shape, ratio_pad)
    # # Convert bounding box format from xyxy to xywh for Comet logging
    box = ops.xyxy2xywh(box)
    return box.tolist()


def get_ground_truth_bbox_annotations(
    img_idx: int, image_path: str, batch: dict, class_name_map: dict = None
) -> list[dict[str, Any]]:
    """Get ground truth bounding box annotation data in the form required for `wandb.Image` overlay system."""
    indices = batch["batch_idx"] == img_idx
    bboxes = batch["bboxes"][indices]
    if len(batch["cls"][indices]):
        cls_labels = batch["cls"][indices].squeeze(1).tolist()
    else:
        cls_labels = []

    class_name_map_reverse = {v: k for k, v in class_name_map.items()}

    if len(bboxes) == 0:
        wandb.termwarn(
            f"Image: {image_path} has no bounding boxes labels", repeat=False
        )
        return None

    if len(batch["cls"][indices]):
        cls_labels = batch["cls"][indices].squeeze(1).tolist()
    else:
        cls_labels = []

    if class_name_map:
        cls_labels = [str(class_name_map[label]) for label in cls_labels]

    original_image_shape = batch["ori_shape"][img_idx]
    resized_image_shape = batch["resized_shape"][img_idx]
    ratio_pad = batch["ratio_pad"][img_idx]

    data = []
    for box, label in zip(bboxes, cls_labels):
        box = scale_bounding_box_to_original_image_shape(
            box, resized_image_shape, original_image_shape, ratio_pad
        )
        data.append(
            {
                "position": {
                    "middle": [int(box[0]), int(box[1])],
                    "width": int(box[2]),
                    "height": int(box[3]),
                },
                "domain": "pixel",
                "class_id": class_name_map_reverse[label],
                "box_caption": label,
            }
        )

    return data


def get_mean_confidence_map(
    classes: list, confidence: list, class_id_to_label: dict
) -> dict[str, float]:
    """Get Mean-confidence map from the predictions to be logged into a `wandb.Table`."""
    confidence_map = {v: [] for _, v in class_id_to_label.items()}
    for class_idx, confidence_value in zip(classes, confidence):
        confidence_map[class_id_to_label[class_idx]].append(confidence_value)
    updated_confidence_map = {}
    for label, confidence_list in confidence_map.items():
        if len(confidence_list) > 0:
            updated_confidence_map[label] = sum(confidence_list) / len(confidence_list)
        else:
            updated_confidence_map[label] = 0
    return updated_confidence_map


def get_boxes(result: Results) -> tuple[dict, dict]:
    """Convert an ultralytics prediction result into metadata for the `wandb.Image` overlay system."""
    boxes = result.boxes.xywh.long().numpy()
    classes = result.boxes.cls.long().numpy()
    confidence = result.boxes.conf.numpy()
    class_id_to_label = {int(k): str(v) for k, v in result.names.items()}
    mean_confidence_map = get_mean_confidence_map(
        classes, confidence, class_id_to_label
    )
    box_data = []
    for idx in range(len(boxes)):
        box_data.append(
            {
                "position": {
                    "middle": [int(boxes[idx][0]), int(boxes[idx][1])],
                    "width": int(boxes[idx][2]),
                    "height": int(boxes[idx][3]),
                },
                "domain": "pixel",
                "class_id": int(classes[idx]),
                "box_caption": class_id_to_label[int(classes[idx])],
                "scores": {"confidence": float(confidence[idx])},
            }
        )
    boxes = {
        "predictions": {
            "box_data": box_data,
            "class_labels": class_id_to_label,
        },
    }
    return boxes, mean_confidence_map


def plot_bbox_predictions(
    result: Results, model_name: str, table: wandb.Table | None = None
) -> wandb.Table | tuple[wandb.Image, dict, dict]:
    """Plot the images with the W&B overlay system.

    The `wandb.Image` is either added to a `wandb.Table` or returned.
    """
    result = result.to("cpu")
    boxes, mean_confidence_map = get_boxes(result)
    image = wandb.Image(result.orig_img[:, :, ::-1], boxes=boxes)
    if table is not None:
        table.add_data(
            model_name,
            image,
            len(boxes["predictions"]["box_data"]),
            mean_confidence_map,
            result.speed,
        )
        return table
    return image, boxes["predictions"], mean_confidence_map


def plot_detection_validation_results(
    dataloader: Any,
    class_label_map: dict,
    model_name: str,
    predictor: DetectionPredictor,
    table: wandb.Table,
    max_validation_batches: int,
    epoch: int | None = None,
) -> wandb.Table:
    """Plot validation results in a table."""
    data_idx = 0
    num_dataloader_batches = len(dataloader.dataset) // dataloader.batch_size
    max_validation_batches = min(max_validation_batches, num_dataloader_batches)
    for batch_idx, batch in enumerate(dataloader):
        prediction_results = predictor(batch["im_file"])
        progress_bar_result_iterable = tqdm(
            enumerate(prediction_results),
            total=len(prediction_results),
            desc=f"Generating Visualizations for batch-{batch_idx + 1}/{max_validation_batches}",
        )
        for img_idx, prediction_result in progress_bar_result_iterable:
            prediction_result = prediction_result.to("cpu")
            _, prediction_box_data, mean_confidence_map = plot_bbox_predictions(
                prediction_result, model_name
            )
            try:
                ground_truth_data = get_ground_truth_bbox_annotations(
                    img_idx, batch["im_file"][img_idx], batch, class_label_map
                )
                wandb_image = wandb.Image(
                    batch["im_file"][img_idx],
                    boxes={
                        "ground-truth": {
                            "box_data": ground_truth_data,
                            "class_labels": class_label_map,
                        },
                        "predictions": {
                            "box_data": prediction_box_data["box_data"],
                            "class_labels": class_label_map,
                        },
                    },
                )
                table_rows = [
                    data_idx,
                    batch_idx,
                    wandb_image,
                    mean_confidence_map,
                    prediction_result.speed,
                ]
                table_rows = [epoch] + table_rows if epoch is not None else table_rows
                table_rows = [model_name] + table_rows
                table.add_data(*table_rows)
                data_idx += 1
            except TypeError:
                pass
        if batch_idx + 1 == max_validation_batches:
            break
    return table
