# LICENSE HEADER MANAGED BY add-license-header
#
# Copyright 2018 Kornia Team
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

import math
from functools import wraps
from typing import Any, Callable, List, Optional

import torch
from torch import nn
from torch.nn import functional as F

from kornia.core import Tensor


def image_to_tensor(image: Any, keepdim: bool = True) -> Tensor:
    """Convert a numpy image to a PyTorch tensor image in channels-first layout.

    Args:
        image: image of the form :math:`(H, W, C)`, :math:`(H, W)` or
            :math:`(B, H, W, C)`.
        keepdim: If ``False``, add a leading batch dimension so the output has
            shape :math:`(B, C, H, W)`. Ignored for 4D inputs, which are
            already batched.

    Returns:
        tensor of the form :math:`(B, C, H, W)` if keepdim is ``False``,
            :math:`(C, H, W)` otherwise.

    Example:
        >>> img = np.ones((3, 3))
        >>> image_to_tensor(img).shape
        torch.Size([1, 3, 3])

        >>> img = np.ones((4, 4, 1))
        >>> image_to_tensor(img).shape
        torch.Size([1, 4, 4])

        >>> img = np.ones((4, 4, 3))
        >>> image_to_tensor(img, keepdim=False).shape
        torch.Size([1, 3, 4, 4])

    """
    ndim = len(image.shape)
    if not 2 <= ndim <= 4:
        raise ValueError("Input size must be a two, three or four dimensional array")

    # torch.from_numpy shares memory with ``image``; no data is copied here.
    data: Tensor = torch.from_numpy(image)

    if ndim == 4:
        # (B, H, W, C) -> (B, C, H, W); already batched, so never add another dim.
        return data.permute(0, 3, 1, 2)

    # (H, W) -> (1, H, W)  or  (H, W, C) -> (C, H, W)
    chw = data.unsqueeze(0) if ndim == 2 else data.permute(2, 0, 1)
    return chw if keepdim else chw.unsqueeze(0)


def image_list_to_tensor(images: List[Any]) -> Tensor:
    """Stack a list of channels-last images into a single batched PyTorch tensor.

    Args:
        images: non-empty list of images, each of the form :math:`(H, W, C)`.
            All images must share the same shape. Entries that are already
            tensors are used as-is; anything else is converted with
            :func:`torch.as_tensor`.

    Returns:
        tensor of the form :math:`(B, C, H, W)` where ``B == len(images)``.

    Raises:
        ValueError: if the list is empty, the images are not 3D, or their
            shapes differ.

    Example:
        >>> imgs = [np.ones((4, 4, 1)), np.zeros((4, 4, 1))]
        >>> image_list_to_tensor(imgs).shape
        torch.Size([2, 1, 4, 4])

    """
    if not images:
        raise ValueError("Input list of images is empty")

    tensors = [img if torch.is_tensor(img) else torch.as_tensor(img) for img in images]

    ref_shape = tensors[0].shape
    if len(ref_shape) != 3:
        raise ValueError("Each image must have shape (H, W, C)")
    if not all(t.shape == ref_shape for t in tensors):
        raise ValueError("All images must have the same shape")

    batch = torch.stack(tensors, dim=0)  # (B, H, W, C)
    return batch.permute(0, 3, 1, 2)  # (B, C, H, W)


def _to_bchw(tensor: Tensor) -> Tensor:
    """Convert a PyTorch tensor image to BCHW format.

    Missing leading dimensions are added with size 1, and extra leading
    dimensions are merged into the batch dimension:

    - :math:`(H, W)` -> :math:`(1, 1, H, W)`
    - :math:`(C, H, W)` -> :math:`(1, C, H, W)`
    - :math:`(B, C, H, W)` -> returned unchanged
    - :math:`(*, C, H, W)` -> :math:`(N, C, H, W)` with ``N`` the product of the leading dims

    Args:
        tensor (torch.Tensor): image of the form :math:`(*, H, W)` with at least
            two dimensions.

    Returns:
        input tensor of the form :math:`(B, C, H, W)`. A view of the input is
        returned whenever its memory layout allows it; otherwise (e.g. a
        permuted 5D tensor) the data is copied.

    Raises:
        TypeError: if ``tensor`` is not a Tensor.
        ValueError: if ``tensor`` has fewer than two dimensions.

    """
    if not isinstance(tensor, Tensor):
        raise TypeError(f"Input type is not a Tensor. Got {type(tensor)}")

    if len(tensor.shape) < 2:
        raise ValueError(f"Input size must be a two, three or four dimensional tensor. Got {tensor.shape}")

    # prepend singleton dims: (H, W) -> (1, 1, H, W), (C, H, W) -> (1, C, H, W)
    while tensor.dim() < 4:
        tensor = tensor.unsqueeze(0)

    if tensor.dim() > 4:
        # Merge every leading dim into one batch dim. Unlike ``view(-1, ...)``,
        # ``flatten`` also works for non-contiguous and zero-element inputs.
        tensor = tensor.flatten(0, -4)

    return tensor


def _to_bcdhw(tensor: Tensor) -> Tensor:
    """Convert a PyTorch tensor image to BCDHW format.

    Missing leading dimensions are added with size 1, and extra leading
    dimensions are merged into the batch dimension:

    - :math:`(D, H, W)` -> :math:`(1, 1, D, H, W)`
    - :math:`(C, D, H, W)` -> :math:`(1, C, D, H, W)`
    - :math:`(B, C, D, H, W)` -> returned unchanged
    - :math:`(*, C, D, H, W)` -> :math:`(N, C, D, H, W)` with ``N`` the product of the leading dims

    Args:
        tensor (torch.Tensor): image of the form :math:`(*, D, H, W)` with at
            least three dimensions.

    Returns:
        input tensor of the form :math:`(B, C, D, H, W)`. A view of the input is
        returned whenever its memory layout allows it; otherwise the data is copied.

    Raises:
        TypeError: if ``tensor`` is not a Tensor.
        ValueError: if ``tensor`` has fewer than three dimensions.

    """
    if not isinstance(tensor, Tensor):
        raise TypeError(f"Input type is not a Tensor. Got {type(tensor)}")

    if len(tensor.shape) < 3:
        raise ValueError(f"Input size must be a three, four or five dimensional tensor. Got {tensor.shape}")

    # prepend singleton dims: (D, H, W) -> (1, 1, D, H, W), (C, D, H, W) -> (1, C, D, H, W)
    while tensor.dim() < 5:
        tensor = tensor.unsqueeze(0)

    if tensor.dim() > 5:
        # Merge every leading dim into one batch dim. Unlike ``view(-1, ...)``,
        # ``flatten`` also works for non-contiguous and zero-element inputs.
        tensor = tensor.flatten(0, -5)

    return tensor


def tensor_to_image(tensor: Tensor, keepdim: bool = False, force_contiguous: bool = False) -> Any:
    """Convert a PyTorch tensor image to a numpy image.

    In case the tensor is in the GPU, it will be copied back to CPU. The tensor
    is detached from the autograd graph before conversion.

    Args:
        tensor: image of the form :math:`(H, W)`, :math:`(C, H, W)` or
            :math:`(B, C, H, W)`.
        keepdim: If ``False`` and the input is a single-image batch
            :math:`(1, C, H, W)`, drop the batch dimension so the result is
            :math:`(H, W, C)` or :math:`(H, W)`. Has no effect on 2D/3D inputs.
        force_contiguous: If ``True`` call ``contiguous`` on the tensor before
            converting it to numpy, so the returned array is C-contiguous.

    Returns:
        image of the form :math:`(H, W)`, :math:`(H, W, C)` or :math:`(B, H, W, C)`.
        Single-channel inputs lose their channel dimension (e.g. :math:`(B, H, W)`).

    Example:
        >>> img = torch.ones(1, 3, 3)
        >>> tensor_to_image(img).shape
        (3, 3)

        >>> img = torch.ones(3, 4, 4)
        >>> tensor_to_image(img).shape
        (4, 4, 3)

        >>> img = torch.ones(1, 1, 5)
        >>> tensor_to_image(img).shape
        (1, 5)

    """
    if not isinstance(tensor, Tensor):
        raise TypeError(f"Input type is not a Tensor. Got {type(tensor)}")

    if len(tensor.shape) > 4 or len(tensor.shape) < 2:
        raise ValueError("Input size must be a two, three or four dimensional tensor")

    input_shape = tensor.shape
    # detach first so the device-to-host copy is not recorded by autograd
    image = tensor.detach().cpu()

    if len(input_shape) == 3:
        # (C, H, W) -> (H, W, C)
        if input_shape[0] == 1:
            # Grayscale for proper plt.imshow needs to be (H, W). Squeeze only the
            # channel dim so that H == 1 or W == 1 is preserved.
            image = image.squeeze(0)
        else:
            image = image.permute(1, 2, 0)
    elif len(input_shape) == 4:
        # (B, C, H, W) -> (B, H, W, C)
        image = image.permute(0, 2, 3, 1)
        if input_shape[0] == 1 and not keepdim:
            image = image.squeeze(0)
        if input_shape[1] == 1:
            image = image.squeeze(-1)
    # 2D inputs (H, W) are already in image layout.

    # make sure the image is contiguous
    if force_contiguous:
        image = image.contiguous()

    return image.numpy()


class ImageToTensor(nn.Module):
    """Module wrapper around :func:`image_to_tensor` (numpy image -> PyTorch tensor image).

    Args:
        keepdim: If ``False``, add a leading batch dimension so that 2D/3D inputs
            come out with shape :math:`(B, C, H, W)`. 4D inputs are always
            returned as :math:`(B, C, H, W)`.

    """

    def __init__(self, keepdim: bool = False) -> None:
        super(ImageToTensor, self).__init__()
        self.keepdim: bool = keepdim

    def forward(self, x: Any) -> Tensor:
        """Convert ``x`` using the ``keepdim`` flag given at construction time."""
        keep = self.keepdim
        return image_to_tensor(x, keepdim=keep)


def make_grid(tensor: Tensor, n_row: Optional[int] = None, padding: int = 2) -> Tensor:
    """Convert a batched tensor to one image with padding in between.

    Images are laid out row-major: image ``i`` is placed in grid row
    ``i // n_col`` and column ``i % n_col``. Unused cells at the end of the
    grid are left as zeros.

    Args:
        tensor: A batched tensor of shape (B, C, H, W) with ``B >= 1``.
        n_row: Number of rows in the grid. The number of columns is
            ``ceil(B / n_row)``. Defaults to ``ceil(sqrt(B))``.
        padding: Number of zero pixels inserted between neighbouring images.
            No padding is added around the outer border.

    Returns:
        Tensor: The combined image grid of shape
        ``(C, n_row * (H + padding) - padding, n_col * (W + padding) - padding)``.

    Raises:
        TypeError: if ``tensor`` is not a PyTorch tensor.
        ValueError: if ``tensor`` is not 4D or has an empty batch, if ``n_row``
            is smaller than 1, or if ``padding`` is negative.

    """
    if not isinstance(tensor, torch.Tensor):
        raise TypeError("Input tensor must be a PyTorch tensor.")
    if tensor.dim() != 4:
        raise ValueError(f"Input tensor must have shape (B, C, H, W). Got {tuple(tensor.shape)}")

    B, C, H, W = tensor.shape
    if B == 0:
        raise ValueError("Input tensor must contain at least one image.")
    if padding < 0:
        raise ValueError(f"padding must be non-negative. Got {padding}")

    if n_row is None:
        n_row = math.ceil(math.sqrt(B))
    if n_row < 1:
        raise ValueError(f"n_row must be a positive integer. Got {n_row}")
    n_col = (B + n_row - 1) // n_row  # ceil(B / n_row)

    padded_H = H + padding
    padded_W = W + padding

    # pad each image on right and bottom; the trailing pad of the last row/column is cropped below
    tiles = F.pad(tensor, (0, padding, 0, padding))

    # fill the unused cells at the end of the grid with blank tiles
    missing = n_row * n_col - B
    if missing > 0:
        blank = torch.zeros((missing, C, padded_H, padded_W), dtype=tiles.dtype, device=tiles.device)
        tiles = torch.cat((tiles, blank), dim=0)

    # (n_row * n_col, C, pH, pW) -> (n_row, n_col, C, pH, pW)
    # -> (C, n_row, pH, n_col, pW) -> (C, n_row * pH, n_col * pW)
    grid = tiles.reshape(n_row, n_col, C, padded_H, padded_W)
    grid = grid.permute(2, 0, 3, 1, 4).reshape(C, n_row * padded_H, n_col * padded_W)

    # crop the trailing right/bottom padding
    return grid[:, : n_row * padded_H - padding, : n_col * padded_W - padding]


def perform_keep_shape_image(f: Callable[..., Tensor]) -> Callable[..., Tensor]:
    """Apply `f` to an image of arbitrary leading dimensions `(*, C, H, W)`.

    It works by first viewing the image as `(B, C, H, W)`, applying the function and re-viewing the image as original
    shape.

    Inputs of shape ``(H, W)`` or ``(C, H, W)`` get the dimensions added for ``f``
    removed from its output again. Inputs with more than four dimensions have their
    original leading dimensions restored. 4D outputs are returned unchanged.

    The returned wrapper raises ``TypeError`` if its first argument is not a Tensor
    and ``ValueError`` if that tensor has no elements.
    """

    @wraps(f)
    def _wrapper(input: Tensor, *args: Any, **kwargs: Any) -> Tensor:
        if not isinstance(input, Tensor):
            raise TypeError(f"Input input type is not a Tensor. Got {type(input)}")

        if input.numel() == 0:
            raise ValueError("Invalid input tensor, it is empty.")

        input_shape = input.shape
        input = _to_bchw(input)  # view input as (B, C, H, W)
        output = f(input, *args, **kwargs)

        ndim = len(input_shape)
        if ndim == 2:
            # (H, W) was lifted to (1, 1, H, W): drop both added dims
            output = output[0, 0]
        elif ndim == 3:
            # (C, H, W) was lifted to (1, C, H, W): drop the added batch dim
            output = output[0]
        elif ndim > 4:
            # Restore the original leading dims. ``reshape`` (unlike ``view``)
            # also handles non-contiguous outputs returned by ``f``.
            output = output.reshape(input_shape[:-3] + output.shape[-3:])

        return output

    return _wrapper


def perform_keep_shape_video(f: Callable[..., Tensor]) -> Callable[..., Tensor]:
    """Apply `f` to a video/volume of arbitrary leading dimensions `(*, C, D, H, W)`.

    It works by first viewing the input as `(B, C, D, H, W)`, applying the function and re-viewing the result as the
    original shape.

    Inputs of shape ``(D, H, W)`` or ``(C, D, H, W)`` get the dimensions added for
    ``f`` removed from its output again. Inputs with more than five dimensions have
    their original leading dimensions restored. 5D outputs are returned unchanged.

    The returned wrapper raises ``TypeError`` if its first argument is not a Tensor
    and ``ValueError`` if that tensor has no elements.
    """

    @wraps(f)
    def _wrapper(input: Tensor, *args: Any, **kwargs: Any) -> Tensor:
        if not isinstance(input, Tensor):
            raise TypeError(f"Input input type is not a Tensor. Got {type(input)}")

        if input.numel() == 0:
            raise ValueError("Invalid input tensor, it is empty.")

        input_shape = input.shape
        input = _to_bcdhw(input)  # view input as (B, C, D, H, W)
        output = f(input, *args, **kwargs)

        ndim = len(input_shape)
        if ndim == 3:
            # (D, H, W) was lifted to (1, 1, D, H, W): drop both added dims
            output = output[0, 0]
        elif ndim == 4:
            # (C, D, H, W) was lifted to (1, C, D, H, W): drop the added batch dim
            output = output[0]
        elif ndim > 5:
            # Restore the original leading dims. ``reshape`` (unlike ``view``)
            # also handles non-contiguous outputs returned by ``f``.
            output = output.reshape(input_shape[:-4] + output.shape[-4:])

        return output

    return _wrapper
