# LICENSE HEADER MANAGED BY add-license-header
#
# Copyright 2018 Kornia Team
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

from __future__ import annotations

from pathlib import Path
from typing import Any

import torch
from torch.utils.dlpack import from_dlpack, to_dlpack

import kornia.color
from kornia.core import Device, Dtype, Tensor
from kornia.core.check import KORNIA_CHECK, KORNIA_CHECK_SHAPE
from kornia.image.base import ChannelsOrder, ColorSpace, ImageLayout, ImageSize, PixelFormat
from kornia.io.io import ImageLoadType, load_image, write_image
from kornia.utils.image_print import image_to_string

# Stand-in aliases for type annotations: they let this module name numpy arrays
# and DLPack capsules without importing the libraries that define them. Both
# resolve to ``typing.Any``, so no type checking is enforced through them.
DLPack = Any
np_ndarray = Any


class Image:
    r"""Class that holds an Image Tensor representation.

    .. note::

        Disclaimer: This class provides the minimum functionality for image manipulation. However, as soon
        as you start to experiment with advanced tensor manipulation, you might expect fancy
        polymorphic behaviours.

    .. warning::

        This API is experimental and might suffer changes in the future.

    Args:
        data: a torch tensor containing the image data.
        pixel_format: the pixel format of the image (color space and bit depth).
        layout: a dataclass containing the image layout information.

    Examples:
        >>> # from a torch.tensor
        >>> data = torch.randint(0, 255, (3, 4, 5), dtype=torch.uint8)  # CxHxW
        >>> pixel_format = PixelFormat(
        ...     color_space=ColorSpace.RGB,
        ...     bit_depth=8,
        ... )
        >>> layout = ImageLayout(
        ...     image_size=ImageSize(4, 5),
        ...     channels=3,
        ...     channels_order=ChannelsOrder.CHANNELS_FIRST,
        ... )
        >>> img = Image(data, pixel_format, layout)
        >>> assert img.channels == 3

        >>> # from a numpy array (like opencv)
        >>> data = np.ones((4, 5, 3), dtype=np.uint8)  # HxWxC
        >>> img = Image.from_numpy(data, color_space=ColorSpace.RGB)
        >>> assert img.channels == 3
        >>> assert img.width == 5
        >>> assert img.height == 4

    """

    def __init__(self, data: Tensor, pixel_format: PixelFormat, layout: ImageLayout) -> None:
        """Image constructor.

        Args:
            data: a torch tensor containing the image data. Its shape is checked against ``layout``:
                ``(C, H, W)`` for ``CHANNELS_FIRST`` and ``(H, W, C)`` for ``CHANNELS_LAST``.
            pixel_format: the pixel format of the image. ``bit_depth // 8`` must equal the element size
                (in bytes) of ``data``.
            layout: a dataclass containing the image layout information.

        Raises:
            NotImplementedError: if ``layout.channels_order`` is not a supported channels order.

        The shape and bit-depth validations are performed with ``KORNIA_CHECK_SHAPE`` / ``KORNIA_CHECK``
        and raise whatever those helpers raise on failure.

        """
        # TODO: move this to a function KORNIA_CHECK_IMAGE_LAYOUT
        if layout.channels_order == ChannelsOrder.CHANNELS_FIRST:
            shape = [str(layout.channels), str(layout.image_size.height), str(layout.image_size.width)]
        elif layout.channels_order == ChannelsOrder.CHANNELS_LAST:
            shape = [str(layout.image_size.height), str(layout.image_size.width), str(layout.channels)]
        else:
            raise NotImplementedError(f"Layout {layout.channels_order} not implemented.")

        KORNIA_CHECK_SHAPE(data, shape)
        KORNIA_CHECK(data.element_size() == pixel_format.bit_depth // 8, "Invalid bit depth.")

        self._data = data
        self._pixel_format = pixel_format
        self._layout = layout

    def __repr__(self) -> str:
        return f"Image data: {self.data}\nPixel Format: {self.pixel_format}\n Layout: {self.layout}"

    def _set_data(self, data: Tensor) -> None:
        """Replace the underlying tensor, keeping ``pixel_format.bit_depth`` in sync with its element size.

        A dtype change such as ``uint8 -> float32`` must also update the bit depth. If it does not, any
        later call that rebuilds an ``Image`` from this one fails the constructor's bit-depth check. Such
        calls include ``clone``, ``to_gray``, ``to_rgb`` and ``to_bgr``. The bit depth is only touched when
        the element size actually changes, so a pure device move keeps the original pixel format object.
        """
        if data.element_size() != self._data.element_size():
            self._pixel_format = PixelFormat(
                color_space=self._pixel_format.color_space,
                bit_depth=data.element_size() * 8,
            )
        self._data = data

    # TODO: explore use TensorWrapper
    def to(self, device: Device = None, dtype: Dtype = None) -> Image:
        """Move the image to the given device and/or cast it to the given dtype.

        The image is modified in place and ``self`` is returned to allow chaining. When the cast changes
        the element size of the data, ``pixel_format.bit_depth`` is updated accordingly.

        Args:
            device: the device to move the image to. A ``torch.dtype`` passed in this position is treated
                as ``dtype``, so ``img.to(torch.float32)`` works.
            dtype: the data type to cast the image to.

        Returns:
            Image: this image, moved to the given device and dtype.

        """
        # allow the positional form ``img.to(torch.float32)``
        if device is not None and isinstance(device, torch.dtype):
            dtype, device = device, None
        self._set_data(self.data.to(device, dtype))
        return self

    # TODO: explore use TensorWrapper
    def clone(self) -> Image:
        """Return a copy of the image.

        The tensor data is cloned. The ``pixel_format`` and ``layout`` objects are passed through to the
        new image as-is.
        """
        return Image(self.data.clone(), self.pixel_format, self.layout)

    @property
    def data(self) -> Tensor:
        """Return the underlying tensor data."""
        return self._data

    @property
    def shape(self) -> tuple[int, ...]:
        """Return the image shape."""
        return self.data.shape

    @property
    def dtype(self) -> torch.dtype:
        """Return the image data type."""
        return self.data.dtype

    @property
    def device(self) -> torch.device:
        """Return the image device."""
        return self.data.device

    @property
    def pixel_format(self) -> PixelFormat:
        """Return the pixel format."""
        return self._pixel_format

    @property
    def layout(self) -> ImageLayout:
        """Return the image layout."""
        return self._layout

    @property
    def channels(self) -> int:
        """Return the number channels of the image."""
        return self.layout.channels

    @property
    def image_size(self) -> ImageSize:
        """Return the image size."""
        return self.layout.image_size

    @property
    def height(self) -> int:
        """Return the image height (number of rows)."""
        return int(self.layout.image_size.height)

    @property
    def width(self) -> int:
        """Return the image width (number of columns)."""
        return int(self.layout.image_size.width)

    @property
    def channels_order(self) -> ChannelsOrder:
        """Return the channels order."""
        return self.layout.channels_order

    # TODO: figure out a better way map this function
    def float(self) -> Image:
        """Cast the image data to ``torch.float32`` in place and return ``self``.

        ``pixel_format.bit_depth`` is updated to match the new element size.
        """
        self._set_data(self.data.float())
        return self

    def to_gray(self) -> Image:
        """Convert the image to grayscale.

        Returns:
            ``self`` if the image is already ``GRAY``. Otherwise, a new single-channel ``Image`` with the
            same channels order and bit depth.

        Raises:
            ValueError: if the source color space is neither RGB nor BGR.

        """
        src = self._pixel_format.color_space
        data = self._data

        if src == ColorSpace.GRAY:
            return self

        # the kornia.color conversions operate on channels-first tensors
        is_channels_last = self._layout.channels_order == ChannelsOrder.CHANNELS_LAST
        if is_channels_last:
            data = data.permute(0, 3, 1, 2) if data.ndim == 4 else data.permute(2, 0, 1)

        # Perform the color space conversion
        if src == ColorSpace.RGB:
            out = kornia.color.rgb_to_grayscale(data)
        elif src == ColorSpace.BGR:
            out = kornia.color.bgr_to_grayscale(data)
        else:
            raise ValueError(f"Unsupported source color space for to_gray(): {src}")

        # restore the original channels order
        if is_channels_last:
            if out.ndim == 4:
                out = out.permute(0, 2, 3, 1)
            elif out.ndim == 3:
                out = out.permute(1, 2, 0)
            else:
                raise ValueError(f"Unexpected shape after grayscale conversion: {out.shape}")

        new_pf = PixelFormat(color_space=ColorSpace.GRAY, bit_depth=self._pixel_format.bit_depth)
        new_layout = ImageLayout(self._layout.image_size, channels=1, channels_order=self._layout.channels_order)
        return Image(out, new_pf, new_layout)

    def to_rgb(self) -> Image:
        """Convert the image to RGB.

        Returns:
            ``self`` if the image is already ``RGB``. Otherwise, a new 3-channel ``Image`` with the same
            channels order and bit depth.

        Raises:
            ValueError: if the source color space is neither GRAY nor BGR.

        """
        src = self._pixel_format.color_space
        data = self._data

        if src == ColorSpace.RGB:
            return self

        # work in channels-first so the channel axis is at a fixed position
        is_channels_last = self._layout.channels_order == ChannelsOrder.CHANNELS_LAST
        if is_channels_last:
            data = data.permute(0, 3, 1, 2) if data.ndim == 4 else data.permute(2, 0, 1)

        if src == ColorSpace.GRAY:
            out = kornia.color.grayscale_to_rgb(data)
        elif src == ColorSpace.BGR:
            # BGR -> RGB is a reversal of the channel axis
            out = data[:, [2, 1, 0], ...] if data.ndim == 4 else data[[2, 1, 0], ...]
        else:
            raise ValueError(f"Unsupported source color space for to_rgb(): {src}")

        if is_channels_last:
            out = out.permute(0, 2, 3, 1) if out.ndim == 4 else out.permute(1, 2, 0)

        new_pf = PixelFormat(color_space=ColorSpace.RGB, bit_depth=self._pixel_format.bit_depth)
        new_layout = ImageLayout(self._layout.image_size, channels=3, channels_order=self._layout.channels_order)
        return Image(out, new_pf, new_layout)

    def to_bgr(self) -> Image:
        """Convert the image to BGR.

        Returns:
            ``self`` if the image is already ``BGR``. Otherwise, a new 3-channel ``Image`` with the same
            channels order and bit depth.

        Raises:
            ValueError: if the source color space is neither GRAY nor RGB.

        """
        src = self._pixel_format.color_space
        data = self._data

        if src == ColorSpace.BGR:
            return self

        # work in channels-first so the channel axis is at a fixed position
        is_channels_last = self._layout.channels_order == ChannelsOrder.CHANNELS_LAST
        if is_channels_last:
            data = data.permute(0, 3, 1, 2) if data.ndim == 4 else data.permute(2, 0, 1)

        if src == ColorSpace.GRAY:
            rgb_data = kornia.color.grayscale_to_rgb(data)
            out = rgb_data[:, [2, 1, 0], ...] if rgb_data.ndim == 4 else rgb_data[[2, 1, 0], ...]
        elif src == ColorSpace.RGB:
            # RGB -> BGR is a reversal of the channel axis
            out = data[:, [2, 1, 0], ...] if data.ndim == 4 else data[[2, 1, 0], ...]
        else:
            raise ValueError(f"Unsupported source color space for to_bgr(): {src}")

        if is_channels_last:
            out = out.permute(0, 2, 3, 1) if out.ndim == 4 else out.permute(1, 2, 0)

        new_pf = PixelFormat(color_space=ColorSpace.BGR, bit_depth=self._pixel_format.bit_depth)
        new_layout = ImageLayout(self._layout.image_size, channels=3, channels_order=self._layout.channels_order)
        return Image(out, new_pf, new_layout)

    @classmethod
    def from_numpy(
        cls,
        data: np_ndarray,
        color_space: ColorSpace = ColorSpace.RGB,
        channels_order: ChannelsOrder = ChannelsOrder.CHANNELS_LAST,
    ) -> Image:
        """Construct an image tensor from a numpy array.

        The tensor is created with ``torch.from_numpy``, so it shares memory with ``data``. The bit depth is
        derived from the array's item size.

        Args:
            data: a 3-D numpy array containing the image data, laid out as ``(H, W, C)`` for
                ``CHANNELS_LAST`` or ``(C, H, W)`` for ``CHANNELS_FIRST``.
            color_space: the color space of the image.
            channels_order: what dimension the channels are in the image tensor.

        Raises:
            ValueError: if ``channels_order`` is neither ``CHANNELS_LAST`` nor ``CHANNELS_FIRST``.

        Example:
            >>> data = np.ones((4, 5, 3), dtype=np.uint8)  # HxWxC
            >>> img = Image.from_numpy(data, color_space=ColorSpace.RGB)
            >>> assert img.channels == 3
            >>> assert img.width == 5
            >>> assert img.height == 4

        """
        if channels_order == ChannelsOrder.CHANNELS_LAST:
            image_size = ImageSize(height=data.shape[0], width=data.shape[1])
            channels = data.shape[2]
        elif channels_order == ChannelsOrder.CHANNELS_FIRST:
            image_size = ImageSize(height=data.shape[1], width=data.shape[2])
            channels = data.shape[0]
        else:
            raise ValueError("channels_order must be either `CHANNELS_LAST` or `CHANNELS_FIRST`")

        # create the pixel format based on the input data
        pixel_format = PixelFormat(color_space=color_space, bit_depth=data.itemsize * 8)

        # create the image layout based on the input data
        layout = ImageLayout(image_size=image_size, channels=channels, channels_order=channels_order)

        # create the image tensor
        return cls(torch.from_numpy(data), pixel_format, layout)

    def to_numpy(self) -> np_ndarray:
        """Return a numpy array in cpu from the image tensor.

        When the data already lives on the CPU, the returned array may share memory with the tensor.
        """
        return self.data.cpu().detach().numpy()

    @classmethod
    def from_dlpack(cls, data: DLPack) -> Image:
        """Construct an image tensor from a DLPack capsule.

        The imported tensor is interpreted as a channels-first ``(C, H, W)`` RGB image. Its bit depth is
        derived from the tensor's element size.

        Args:
            data: a DLPack capsule from numpy, tvm or jax.

        Example:
            >>> x = np.ones((3, 4, 5))  # CxHxW
            >>> img = Image.from_dlpack(x.__dlpack__())
            >>> assert img.channels == 3

        """
        _data: Tensor = from_dlpack(data)

        pixel_format = PixelFormat(color_space=ColorSpace.RGB, bit_depth=_data.element_size() * 8)

        # create the image layout based on the input data
        layout = ImageLayout(
            image_size=ImageSize(height=_data.shape[1], width=_data.shape[2]),
            channels=_data.shape[0],
            channels_order=ChannelsOrder.CHANNELS_FIRST,
        )

        return cls(_data, pixel_format, layout)

    def to_dlpack(self) -> DLPack:
        """Return a DLPack capsule from the image tensor."""
        return to_dlpack(self.data)

    @classmethod
    def from_file(cls, file_path: str | Path) -> Image:
        """Construct an image tensor from a file.

        The file is loaded on the CPU as an 8-bit RGB image (``ImageLoadType.RGB8``). The result is laid
        out channels-first.

        Args:
            file_path: the path to the file to read the image from.

        """
        # TODO: allow user to specify the desired type and device
        data: Tensor = load_image(file_path, desired_type=ImageLoadType.RGB8, device="cpu")

        pixel_format = PixelFormat(color_space=ColorSpace.RGB, bit_depth=data.element_size() * 8)

        layout = ImageLayout(
            image_size=ImageSize(height=data.shape[1], width=data.shape[2]),
            channels=data.shape[0],
            channels_order=ChannelsOrder.CHANNELS_FIRST,
        )
        return cls(data, pixel_format, layout)

    def write(self, file_path: str | Path) -> None:
        """Write the image to a file.

        For now, only writing to the JPEG format is supported. Channels-last data is permuted to
        channels-first before being handed to ``write_image``.

        Args:
            file_path: the path to the file to write the image to.

        Example:
            >>> data = np.ones((4, 5, 3), dtype=np.uint8)  # HxWxC
            >>> img = Image.from_numpy(data)
            >>> img.write("test.jpg")

        """
        data = self.data
        if self.channels_order == ChannelsOrder.CHANNELS_LAST:
            data = data.permute(2, 0, 1)
        write_image(file_path, data)

    def print(self, max_width: int = 256) -> None:
        """Print the image tensor to the console.

        Args:
            max_width: the maximum width of the image to print.

        .. code-block:: python

            img = Image.from_file("panda.png")
            img.print()

        .. image:: https://github.com/kornia/data/blob/main/print_image.png?raw=true

        """
        print(image_to_string(self.data, max_width))
