from __future__ import annotations

import cv2
import numpy as np
from tqdm.auto import tqdm
from ultralytics.engine.results import Results
from ultralytics.models.yolo.segment import SegmentationPredictor
from ultralytics.utils.ops import scale_image

import wandb
from wandb.integration.ultralytics.bbox_utils import (
    get_ground_truth_bbox_annotations,
    get_mean_confidence_map,
)


def instance_mask_to_semantic_mask(instance_mask, class_indices):
    """Flatten a stack of per-instance masks into a single class-index map.

    Args:
        instance_mask: Array indexed as ``(height, width, instance)``.
        class_indices: Sequence giving the class index of each instance channel.

    Returns:
        A ``(height, width)`` ``uint8`` array. Each pixel takes the class index
        of the last instance whose mask value there is exactly 1. Pixels that
        no instance covers stay 0, so they are indistinguishable from class 0.
    """
    rows, cols, _ = instance_mask.shape
    label_map = np.zeros((rows, cols), dtype=np.uint8)
    # Visit instances in channel order so later instances win where masks overlap.
    for position, plane in enumerate(np.moveaxis(instance_mask, -1, 0)):
        label_map[plane == 1] = class_indices[position]
    return label_map


def get_boxes_and_masks(result: Results) -> tuple[dict, dict | None, dict]:
    """Build W&B box and mask overlays for one Ultralytics segmentation result.

    Args:
        result: Prediction result whose tensors are already on the CPU
            (``.numpy()`` is called on them directly).

    Returns:
        A ``(boxes, masks, mean_confidence_map)`` tuple:

        - ``boxes``: ``{"predictions": {"box_data": [...], "class_labels": {...}}}``
          with one pixel-domain box per detection.
        - ``masks``: ``{"predictions": {"mask_data": ..., "class_labels": {...}}}``
          holding a single class-index map, or ``None`` when the result has no
          masks.
        - ``mean_confidence_map``: whatever ``get_mean_confidence_map`` returns.
    """
    xywh = result.boxes.xywh.long().numpy()
    classes = result.boxes.cls.long().numpy()
    confidence = result.boxes.conf.numpy()

    # Model classes keep their ids. One extra id, just past them, labels background.
    background_class_id = len(result.names)
    class_id_to_label = {int(k): str(v) for k, v in result.names.items()}
    class_id_to_label[background_class_id] = "background"
    mean_confidence_map = get_mean_confidence_map(
        classes, confidence, class_id_to_label
    )

    masks = None
    if result.masks is not None:
        # Reorder masks.data to (height, width, instance), the layout
        # instance_mask_to_semantic_mask unpacks, then resize to the original image.
        scaled_instance_mask = scale_image(
            np.transpose(result.masks.data.numpy(), (1, 2, 0)),
            result.orig_img.shape,
        )
        scaled_semantic_mask = instance_mask_to_semantic_mask(
            scaled_instance_mask, classes.tolist()
        )
        # Relabel pixels that no instance covers as background. This is derived
        # from the instance masks (using the same "== 1" test as the semantic
        # mask) rather than from "semantic == 0", because that test would also
        # erase every pixel predicted as class id 0.
        uncovered = ~np.any(scaled_instance_mask == 1, axis=-1)
        scaled_semantic_mask[uncovered] = background_class_id
        masks = {
            "predictions": {
                "mask_data": scaled_semantic_mask,
                "class_labels": class_id_to_label,
            }
        }

    box_data = [
        {
            "position": {
                "middle": [int(center_x), int(center_y)],
                "width": int(width),
                "height": int(height),
            },
            "domain": "pixel",
            "class_id": int(class_id),
            "box_caption": class_id_to_label[int(class_id)],
            "scores": {"confidence": float(score)},
        }
        for (center_x, center_y, width, height), class_id, score in zip(
            xywh, classes, confidence
        )
    ]

    boxes = {
        "predictions": {
            "box_data": box_data,
            "class_labels": class_id_to_label,
        },
    }
    return boxes, masks, mean_confidence_map


def plot_mask_predictions(
    result: Results, model_name: str, table: wandb.Table | None = None
) -> wandb.Table | tuple[wandb.Image, dict | None, dict, dict]:
    """Render one segmentation result as a W&B image with box and mask overlays.

    Args:
        result: Prediction result. It is moved to the CPU before use.
        model_name: Value written to the first column of the table row.
        table: Optional table to append one row to.

    Returns:
        If ``table`` is given, the same table after appending
        ``[model_name, image, number_of_boxes, mean_confidence_map, result.speed]``.
        Otherwise ``(image, masks, box_predictions, mean_confidence_map)``, where
        ``masks`` is ``None`` when the result has no masks, and
        ``box_predictions`` is the ``{"box_data": ..., "class_labels": ...}`` dict.
    """
    result = result.to("cpu")
    boxes, masks, mean_confidence_map = get_boxes_and_masks(result)
    # Reverse the channel axis of orig_img before handing it to wandb.Image.
    image = wandb.Image(result.orig_img[:, :, ::-1], boxes=boxes, masks=masks)
    if table is not None:
        table.add_data(
            model_name,
            image,
            len(boxes["predictions"]["box_data"]),
            mean_confidence_map,
            result.speed,
        )
        return table
    return image, masks, boxes["predictions"], mean_confidence_map


def structure_prompts_and_image(
    image: np.ndarray, prompt: dict
) -> tuple[np.ndarray, dict]:
    """Attach a SAM prompt to an image for W&B logging.

    A box prompt becomes a W&B box under the ``"prompts"`` key. A point prompt
    is drawn onto a uint8 copy of the image as a filled circle of radius 5.

    Args:
        image: Image array. It is never modified in place.
        prompt: Dict with optional ``"bboxes"`` (four numbers) and ``"points"``
            (an ``(x, y)`` pair) entries. Missing or ``None`` entries are skipped.

    Returns:
        ``(image, box_payload)``: the (possibly annotated) image and a
        ``{"prompts": {"box_data": [...], "class_labels": {1: "Prompt-Box"}}}``
        dict suitable for ``wandb.Image(boxes=...)``.
    """
    box_data = []
    bbox = prompt.get("bboxes")
    if bbox is not None:
        # NOTE(review): the four values are read as (center_x, center_y, width,
        # height) to fill W&B's "middle" position. Confirm callers pass that
        # format and not corner (xyxy) coordinates.
        box_data.append(
            {
                "position": {
                    "middle": [bbox[0], bbox[1]],
                    "width": bbox[2],
                    "height": bbox[3],
                },
                "domain": "pixel",
                "class_id": 1,
                "box_caption": "Prompt-Box",
            }
        )
    point = prompt.get("points")
    if point is not None:
        # Draw on a fresh contiguous uint8 copy so the caller's array is left
        # untouched. cv2.circle only accepts integer pixel coordinates for the
        # center, so float prompts are truncated to int.
        image = np.array(image, dtype=np.uint8, order="C")
        center = tuple(int(coordinate) for coordinate in point)
        image = cv2.circle(image, center, 5, (0, 255, 0), -1, lineType=cv2.LINE_AA)
    wb_box_data = {
        "prompts": {
            "box_data": box_data,
            "class_labels": {1: "Prompt-Box"},
        }
    }
    return image, wb_box_data


def plot_sam_predictions(
    result: Results, prompt: dict, table: wandb.Table
) -> wandb.Table:
    """Append one SAM prediction, with its prompt overlaid, to a W&B table.

    Args:
        result: SAM prediction result. It is moved to the CPU before use.
        prompt: Prompt dict forwarded to ``structure_prompts_and_image``.
        table: Table that receives a single-column row holding the image.

    Returns:
        The same ``table`` after the row is appended.
    """
    result = result.to("cpu")
    image, wb_box_data = structure_prompts_and_image(
        result.orig_img[:, :, ::-1], prompt
    )
    masks = None
    if result.masks is not None:
        mask_data = result.masks.data.numpy().astype(int)
        # masks.data stacks one mask per prediction along axis 0. A W&B mask
        # must be a single 2-D class map, so overlay all predictions into one
        # foreground (1) / background (0) map. A single mask is unchanged.
        if mask_data.ndim == 3:
            mask_data = np.any(mask_data, axis=0).astype(int)
        masks = {
            "predictions": {
                "mask_data": mask_data,
                "class_labels": {0: "Background", 1: "Prediction"},
            }
        }
    table.add_data(wandb.Image(image, boxes=wb_box_data, masks=masks))
    return table


def plot_segmentation_validation_results(
    dataloader,
    class_label_map,
    model_name: str,
    predictor: SegmentationPredictor,
    table: wandb.Table,
    max_validation_batches: int,
    epoch: int | None = None,
):
    """Run ``predictor`` on validation batches and log overlay images to ``table``.

    Each row is ``[model_name, (epoch,) data_idx, batch_idx, image,
    mean_confidence_map, speed]``. The ``epoch`` column is present only when
    ``epoch`` is not ``None``. The image shows the ground-truth boxes, the
    predicted boxes, and the predicted masks.

    Args:
        dataloader: Loader with ``dataset`` and ``batch_size`` attributes whose
            batches are dicts containing at least ``"im_file"``.
        class_label_map: Class-id to label mapping for the ground-truth boxes.
        model_name: Value written to the first column of every row.
        predictor: Called on each batch's list of image files.
        table: Table that receives the rows.
        max_validation_batches: Upper bound on the number of batches to visualize.
            Values <= 0 visualize nothing.
        epoch: Optional epoch number added as a column.

    Returns:
        The same ``table`` with the new rows appended.
    """
    if max_validation_batches <= 0:
        return table
    data_idx = 0
    skipped_images = 0
    batch_size = dataloader.batch_size
    # Use ceiling division: a trailing partial batch still counts as a batch.
    num_dataloader_batches = (len(dataloader.dataset) + batch_size - 1) // batch_size
    max_validation_batches = min(max_validation_batches, num_dataloader_batches)
    for batch_idx, batch in enumerate(dataloader):
        image_files = batch["im_file"]
        prediction_results = predictor(image_files)
        progress_bar_result_iterable = tqdm(
            enumerate(prediction_results),
            total=len(prediction_results),
            desc=f"Generating Visualizations for batch-{batch_idx + 1}/{max_validation_batches}",
        )
        for img_idx, prediction_result in progress_bar_result_iterable:
            prediction_result = prediction_result.to("cpu")
            (
                _,
                prediction_mask_data,
                prediction_box_data,
                mean_confidence_map,
            ) = plot_mask_predictions(prediction_result, model_name)
            try:
                ground_truth_data = get_ground_truth_bbox_annotations(
                    img_idx, image_files[img_idx], batch, class_label_map
                )
                wandb_image = wandb.Image(
                    image_files[img_idx],
                    boxes={
                        "ground-truth": {
                            "box_data": ground_truth_data,
                            "class_labels": class_label_map,
                        },
                        "predictions": prediction_box_data,
                    },
                    masks=prediction_mask_data,
                )
                table_rows = [
                    data_idx,
                    batch_idx,
                    wandb_image,
                    mean_confidence_map,
                    prediction_result.speed,
                ]
                if epoch is not None:
                    table_rows = [epoch] + table_rows
                table.add_data(model_name, *table_rows)
                data_idx += 1
            except TypeError:
                # Best effort: an image whose annotations or row cannot be built
                # is skipped rather than aborting the whole visualization. The
                # skips are counted and reported once below.
                skipped_images += 1
        if batch_idx + 1 >= max_validation_batches:
            break
    if skipped_images:
        wandb.termwarn(
            f"Skipped {skipped_images} validation image(s) that could not be "
            "visualized (TypeError while building annotations or table rows)."
        )
    return table
