# LICENSE HEADER MANAGED BY add-license-header
#
# Copyright 2018 Kornia Team
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

from __future__ import annotations

import datetime
import logging
import os
from pathlib import Path
from typing import Any, Optional, Union

from kornia.config import kornia_config
from kornia.core import Tensor, tensor
from kornia.core.external import boxmot
from kornia.core.external import numpy as np
from kornia.io import write_image
from kornia.models.detection.base import ObjectDetector
from kornia.models.detection.rtdetr import RTDETRDetectorBuilder
from kornia.utils.image import tensor_to_image

__all__ = ["BoxMotTracker"]

logger = logging.getLogger(__name__)


class BoxMotTracker:
    """BoxMotTracker is a module that wraps a detector and a tracker model.

    This module uses BoxMot library for tracking.

    Args:
        detector: ObjectDetector: The detector model.
        tracker_model_name: The name of the tracker model. Valid options are:
            - "BoTSORT"
            - "DeepOCSORT"
            - "OCSORT"
            - "HybridSORT"
            - "ByteTrack"
            - "StrongSORT"
            - "ImprAssoc"
        tracker_model_weights: Path to the model weights for ReID (Re-Identification).
        device: Device on which to run the model (e.g., 'cpu' or 'cuda').
        fp16: Whether to use half-precision (fp16) for faster inference on compatible devices.
        per_class: Whether to perform per-class tracking
        track_high_thresh: High threshold for detection confidence.
            Detections above this threshold are used in the first association round.
        track_low_thresh: Low threshold for detection confidence.
            Detections below this threshold are ignored.
        new_track_thresh: Threshold for creating a new track.
            Detections above this threshold will be considered as potential new tracks.
        track_buffer: Number of frames to keep a track alive after it was last detected.
        match_thresh: Threshold for the matching step in data association.
        proximity_thresh: Threshold for IoU (Intersection over Union) distance in first-round association.
        appearance_thresh: Threshold for appearance embedding distance in the ReID module.
        cmc_method: Method for correcting camera motion. Options include "sof" (simple optical flow).
        frame_rate: Frame rate of the video being processed. Used to scale the track buffer size.
        fuse_first_associate: Whether to fuse appearance and motion information during the first association step.
        with_reid: Whether to use ReID (Re-Identification) features for association.

    .. code-block:: python

        import kornia
        image = kornia.utils.sample.get_sample_images()[0][None]
        model = BoxMotTracker()
        for i in range(4):  # At least 4 frames are needed to initialize the tracking position
            model.update(image)
        model.save(image)

    .. note::
        At least 4 frames are needed to initialize the tracking position.

    """

    name: str = "boxmot_tracker"

    def __init__(
        self,
        detector: Union[ObjectDetector, str] = "rtdetr_r18vd",
        tracker_model_name: str = "DeepOCSORT",
        tracker_model_weights: str = "osnet_x0_25_msmt17.pt",
        device: str = "cpu",
        fp16: bool = False,
        **kwargs: Any,
    ) -> None:
        super().__init__()
        if isinstance(detector, str):
            if detector.startswith("rtdetr"):
                detector = RTDETRDetectorBuilder.build(model_name=detector)
            else:
                raise ValueError(
                    f"Detector `{detector}` not available. You may pass an ObjectDetector instance instead."
                )
        self.detector = detector
        os.makedirs(f"{kornia_config.hub_models_dir}/boxmot", exist_ok=True)
        self.tracker = getattr(boxmot, tracker_model_name)(
            model_weights=Path(os.path.join(f"{kornia_config.hub_models_dir}/boxmot", tracker_model_weights)),
            device=device,
            fp16=fp16,
            **kwargs,
        )

    def update(self, image: Tensor) -> None:
        """Update the tracker with a new image.

        Args:
            image: The input image.

        """
        if not (image.ndim == 4 and image.shape[0] == 1) and not image.ndim == 3:
            raise ValueError(f"Input tensor must be of shape (1, 3, H, W) or (3, H, W). Got {image.shape}")

        if image.ndim == 3:
            image = image.unsqueeze(0)

        detections_raw: Union[Tensor, list[Tensor]] = self.detector(image)

        detections = detections_raw[0].cpu().numpy()  # Batch size is 1

        detections = np.array(  # type: ignore
            [
                detections[:, 2],
                detections[:, 3],
                detections[:, 2] + detections[:, 4],
                detections[:, 3] + detections[:, 5],
                detections[:, 1],
                detections[:, 0],
            ]
        ).T

        if detections.shape[0] == 0:
            # empty N X (x, y, x, y, conf, cls)
            detections = np.empty((0, 6))  # type: ignore

        frame_raw = (tensor_to_image(image) * 255).astype(np.uint8)
        # --> M X (x, y, x, y, id, conf, cls, ind)
        return self.tracker.update(detections, frame_raw)

    def visualize(self, image: Tensor, show_trajectories: bool = True) -> Tensor:
        """Visualize the results of the tracker.

        Args:
            image: The input image.
            show_trajectories: Whether to show the trajectories.

        Returns:
            The image with the results of the tracker.

        """
        frame_raw = (tensor_to_image(image) * 255).astype(np.uint8)
        self.tracker.plot_results(frame_raw, show_trajectories=show_trajectories)

        return tensor(frame_raw).permute(2, 0, 1)

    def save(self, image: Tensor, show_trajectories: bool = True, directory: Optional[str] = None) -> None:
        """Save the model to ONNX format.

        Args:
            image: The input image.
            show_trajectories: Whether to visualize trajectories.
            directory: Where to save the file(s).

        """
        if directory is None:
            name = f"{self.name}_{datetime.datetime.now(tz=datetime.timezone.utc).strftime('%Y%m%d%H%M%S')!s}"
            directory = os.path.join("kornia_outputs", name)
        output = self.visualize(image, show_trajectories=show_trajectories)

        os.makedirs(directory, exist_ok=True)
        write_image(
            os.path.join(directory, f"{str(0).zfill(6)}.jpg"),
            output.byte(),
        )
        logger.info(f"Outputs are saved in {directory}")
