# LICENSE HEADER MANAGED BY add-license-header
#
# Copyright 2018 Kornia Team
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

from __future__ import annotations

from typing import ClassVar, Optional, Union

import torch

import kornia
from kornia.core import Tensor
from kornia.core.external import PILImage as Image
from kornia.models.base import ModelBase

__all__ = ["SemanticSegmentation"]


class SemanticSegmentation(ModelBase):
    """Semantic Segmentation is a module that wraps a semantic segmentation model.

    This module uses the SegmentationModel library for semantic segmentation. ``forward`` relies on the
    ``pre_processor``, ``model`` and ``post_processor`` attributes, while ``visualize``/``save`` rely on the
    ``_tensor_to_type`` and ``_save_outputs`` helpers. None of these are defined in this class, so they are
    presumably provided by :class:`ModelBase` or set up by whoever constructs the instance.
    """

    # Default ONNX export shapes; -1 presumably marks a dynamic dimension (batch / spatial / classes).
    ONNX_DEFAULT_INPUTSHAPE: ClassVar[list[int]] = [-1, 3, -1, -1]
    ONNX_DEFAULT_OUTPUTSHAPE: ClassVar[list[int]] = [-1, -1, -1, -1]

    @torch.inference_mode()
    def forward(self, images: Union[Tensor, list[Tensor]]) -> Union[Tensor, list[Tensor]]:
        """Run pre-processing, the segmentation model and post-processing on the input images.

        Args:
            images: Either a list (or tuple) of RGB images, each a Tensor with shape :math:`(3, H, W)`,
                or a single Tensor with shape :math:`(B, 3, H, W)`.

        Returns:
            For list/tuple input, a list with one post-processed output per image, with the leading
            batch dimension removed. For Tensor input, the post-processed output for the whole batch.

        """
        if isinstance(images, (list, tuple)):
            # Each image is run as its own batch of one (presumably because list entries may differ in size).
            outputs: list[Tensor] = []
            for image in images:
                batch = self.pre_processor(image[None])
                prediction = self.post_processor(self.model(batch))
                outputs.append(prediction[0])
            return outputs

        batch = self.pre_processor(images)
        return self.post_processor(self.model(batch))

    def get_colormap(self, num_classes: int, colormap: str = "random", manual_seed: int = 2147) -> Tensor:
        """Build a color map with one RGB color per class.

        Args:
            num_classes: The number of colors in the color map.
            colormap: The colormap to use. Only ``"random"`` is currently supported.
            manual_seed: Seed for the random generator; the same seed always yields the same colors.

        Returns:
            A CPU tensor of shape (num_classes, 3) with values in [0, 1).

        Raises:
            ValueError: If ``colormap`` is not ``"random"``.

        """
        if colormap != "random":
            raise ValueError(f"Unsupported colormap: {colormap}")

        # A dedicated generator keeps the colors reproducible without touching the global RNG state.
        generator = torch.Generator()
        generator.manual_seed(manual_seed)
        return torch.rand(num_classes, 3, generator=generator)

    def visualize_output(self, semantic_mask: Tensor, colors: Tensor) -> Tensor:
        """Colorize a per-class probability mask by painting each pixel with the color of its argmax class.

        Args:
            semantic_mask: The output of the segmentation model. Shape should be (C, H, W) or (B, C, H, W),
                and the values must sum to one over the class dimension (i.e. softmax probabilities).
            colors: The color map to use for visualizing the output of the segmentation model.
                Shape should be (num_classes, 3) with num_classes >= C. It may live on any device;
                it is moved to the device of ``semantic_mask``.

        Returns:
            A tensor of shape (3, H, W) or (B, 3, H, W) representing the visualized output of the segmentation model.

        Raises:
            ValueError: If the semantic mask is not of shape (C, H, W) or (B, C, H, W).
            ValueError: If the colors are not of shape (num_classes, 3) with at least C rows.
            ValueError: If the mask does not sum to one over the class dimension; only multiclass
                (softmax) segmentation is supported.

        """
        if semantic_mask.dim() == 3:
            channel_dim = 0
        elif semantic_mask.dim() == 4:
            channel_dim = 1
        else:
            raise ValueError(f"Semantic mask must be of shape (C, H, W) or (B, C, H, W), got {semantic_mask.shape}.")

        num_classes = semantic_mask.size(channel_dim)
        if colors.dim() != 2 or colors.size(1) != 3 or colors.size(0) < num_classes:
            raise ValueError(
                f"Colors must be of shape (num_classes, 3) with num_classes >= {num_classes}, got {colors.shape}."
            )

        # Softmax outputs sum to one over the class dimension; anything else is not multiclass segmentation.
        one = torch.tensor(1, dtype=semantic_mask.dtype, device=semantic_mask.device)
        if not torch.allclose(semantic_mask.sum(dim=channel_dim), one):
            raise ValueError(
                "Only muliclass segmentation is supported. Please ensure a softmax is used, or submit a PR."
            )

        # Class index with the highest probability: (H, W) or (B, H, W).
        labels = semantic_mask.argmax(dim=channel_dim)
        # get_colormap builds colors on the CPU; indexing requires the palette on the labels' device.
        palette = colors.to(device=semantic_mask.device)
        # (..., H, W, 3) -> (..., 3, H, W)
        return palette[labels].movedim(-1, -3)

    def visualize(
        self,
        images: Union[Tensor, list[Tensor]],
        semantic_masks: Optional[Union[Tensor, list[Tensor]]] = None,
        output_type: str = "torch",
        colormap: str = "random",
        manual_seed: int = 2147,
    ) -> Union[Tensor, list[Tensor], list[Image.Image]]:  # type: ignore
        """Visualize the segmentation masks.

        Args:
            images: If list of RGB images. Each image is a Tensor with shape :math:`(3, H, W)`.
                If Tensor, a Tensor with shape :math:`(B, 3, H, W)`. Only used to compute the masks
                when ``semantic_masks`` is not given.
            semantic_masks: If list of segmentation masks. Each mask is a Tensor with shape :math:`(C, H, W)`.
                If Tensor, a Tensor with shape :math:`(B, C, H, W)`. Computed with ``forward`` when None.
            output_type: The type of output, can be "torch" or "PIL".
            colormap: The colormap to use. Only "random" is currently supported.
            manual_seed: The manual seed to use for the colormap.

        Returns:
            The colorized masks, converted according to ``output_type``: a batch when the masks are a
            single Tensor, or one entry per mask when they are a list/tuple.

        Raises:
            ValueError: If a mask in a list/tuple is not of shape (C, H, W), or if ``visualize_output``
                / ``get_colormap`` reject their inputs.

        """
        if semantic_masks is None:
            semantic_masks = self.forward(images)

        outputs: Union[Tensor, list[Tensor]]
        if isinstance(semantic_masks, (list, tuple)):
            outputs = []
            for semantic_mask in semantic_masks:
                if semantic_mask.ndim != 3:
                    raise ValueError(f"Semantic mask must be of shape (C, H, W), got {semantic_mask.shape}.")
                # One color per class of this mask.
                colors = self.get_colormap(semantic_mask.size(0), colormap, manual_seed=manual_seed)
                outputs.append(self.visualize_output(semantic_mask, colors))
        else:
            # Batched (B, C, H, W) masks: dim 1 holds the classes.
            colors = self.get_colormap(semantic_masks.size(1), colormap, manual_seed=manual_seed)
            outputs = self.visualize_output(semantic_masks, colors)

        return self._tensor_to_type(outputs, output_type, is_batch=isinstance(outputs, Tensor))

    def save(
        self,
        images: Union[Tensor, list[Tensor]],
        semantic_masks: Optional[Union[Tensor, list[Tensor]]] = None,
        directory: Optional[str] = None,
        output_type: str = "torch",
        colormap: str = "random",
        manual_seed: int = 2147,
    ) -> None:
        """Save the source images, the colorized masks and a 50/50 overlay of both.

        Args:
            images: If list of RGB images. Each image is a Tensor with shape :math:`(3, H, W)`.
                If Tensor, a Tensor with shape :math:`(B, 3, H, W)`.
            semantic_masks: If list of segmentation masks. Each mask is a Tensor with shape :math:`(C, H, W)`.
                If Tensor, a Tensor with shape :math:`(B, C, H, W)`. Computed with ``forward`` when None.
            directory: The directory to save the results.
            output_type: The type of output passed to ``visualize``. The overlay is computed on tensors,
                so this must yield tensors (e.g. "torch").
            colormap: The colormap to use. Only "random" is currently supported.
            manual_seed: The manual seed to use for the colormap.

        Raises:
            ValueError: If ``images`` and the visualized masks are not both a Tensor or both a list/tuple
                of Tensors, or if the two lists have different lengths.

        """
        colored_masks = self.visualize(
            images, semantic_masks, output_type=output_type, colormap=colormap, manual_seed=manual_seed
        )
        # add_weighted computes src1 * alpha + src2 * beta + gamma, so gamma must be 0 for a plain 50/50
        # blend; a non-zero gamma would shift every pixel (out of [0, 1], assuming inputs in [0, 1]).
        overlaid: Union[Tensor, list[Tensor]]
        if isinstance(images, Tensor) and isinstance(colored_masks, Tensor):
            overlaid = kornia.enhance.add_weighted(images, 0.5, colored_masks, 0.5, 0.0)
        elif (
            isinstance(images, (list, tuple))
            and isinstance(colored_masks, (list, tuple))
            and all(isinstance(mask, Tensor) for mask in colored_masks)
        ):
            if len(images) != len(colored_masks):
                raise ValueError(
                    f"Got {len(images)} images but {len(colored_masks)} masks; they must match one-to-one."
                )
            # Blend each (3, H, W) pair as a batch of one and drop the batch dimension again.
            overlaid = [
                kornia.enhance.add_weighted(image[None], 0.5, mask[None], 0.5, 0.0)[0]
                for image, mask in zip(images, colored_masks)
            ]
        else:
            raise ValueError(
                "`images` and the visualized masks must both be a Tensor or both be a list of Tensors "
                f"(use a tensor `output_type` such as 'torch'). Got {type(images)} and {type(colored_masks)}."
            )

        self._save_outputs(images, directory, suffix="_src")
        self._save_outputs(colored_masks, directory, suffix="_mask")
        self._save_outputs(overlaid, directory, suffix="_overlay")
