# LICENSE HEADER MANAGED BY add-license-header
#
# Copyright 2018 Kornia Team
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

from __future__ import annotations

import torch
from torch import nn

from kornia.core import ones, zeros
from kornia.utils import create_meshgrid
from kornia.utils.helpers import _torch_solve_cast

__all__ = ["get_tps_transform", "warp_image_tps", "warp_points_tps"]

# utilities for computing thin plate spline transforms


def _pair_square_euclidean(tensor1: torch.Tensor, tensor2: torch.Tensor) -> torch.Tensor:
    r"""Compute the pairwise squared euclidean distance matrices :math:`(B, N, M)` between two tensors.

    Tensors with shapes (B, N, C) and (B, M, C).
    """
    # ||t1-t2||^2 = (t1-t2)^T(t1-t2) = t1^T*t1 + t2^T*t2 - 2*t1^T*t2
    t1_sq: torch.Tensor = tensor1.mul(tensor1).sum(dim=-1, keepdim=True)
    t2_sq: torch.Tensor = tensor2.mul(tensor2).sum(dim=-1, keepdim=True).transpose(1, 2)
    t1_t2: torch.Tensor = tensor1.matmul(tensor2.transpose(1, 2))
    square_dist: torch.Tensor = -2 * t1_t2 + t1_sq + t2_sq
    square_dist = square_dist.clamp(min=0)  # handle possible numerical errors
    return square_dist


def _kernel_distance(squared_distances: torch.Tensor, eps: float = 1e-8) -> torch.Tensor:
    r"""Compute the TPS kernel distance function: :math:`r^2 log(r)`, where `r` is the euclidean distance.

    Since
    :math: `\log(r) = 1/2 \log(r^2)`, this function takes the squared distance matrix and calculates
    :math: `0.5 r^2 log(r^2)`.
    """
    # r^2 * log(r) = 1/2 * r^2 * log(r^2)
    return 0.5 * squared_distances * squared_distances.add(eps).log()


def get_tps_transform(points_src: torch.Tensor, points_dst: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
    r"""Compute the TPS transform parameters that warp source points to target points.

    The input to this function is a tensor of :math:`(x, y)` source points :math:`(B, N, 2)` and a corresponding
    tensor of target :math:`(x, y)` points :math:`(B, N, 2)`.

    Args:
        points_src: batch of source points :math:`(B, N, 2)` as :math:`(x, y)` coordinate vectors.
        points_dst: batch of target points :math:`(B, N, 2)` as :math:`(x, y)` coordinate vectors.

    Returns:
        :math:`(B, N, 2)` tensor of kernel weights and :math:`(B, 3, 2)`
            tensor of affine weights. The last dimension contains the x-transform and y-transform weights
            as separate columns.

    Example:
        >>> points_src = torch.rand(1, 5, 2)
        >>> points_dst = torch.rand(1, 5, 2)
        >>> kernel_weights, affine_weights = get_tps_transform(points_src, points_dst)

    .. note::
        This function is often used in conjunction with :func:`warp_points_tps`, :func:`warp_image_tps`.

    """
    if not isinstance(points_src, torch.Tensor):
        raise TypeError(f"Input points_src is not torch.Tensor. Got {type(points_src)}")

    if not isinstance(points_dst, torch.Tensor):
        raise TypeError(f"Input points_dst is not torch.Tensor. Got {type(points_dst)}")

    if not len(points_src.shape) == 3:
        raise ValueError(f"Invalid shape for points_src, expected BxNx2. Got {points_src.shape}")

    if not len(points_dst.shape) == 3:
        raise ValueError(f"Invalid shape for points_dst, expected BxNx2. Got {points_dst.shape}")

    device, dtype = points_src.device, points_src.dtype
    batch_size, num_points = points_src.shape[:2]

    # set up and solve linear system
    # [K   P] [w] = [dst]
    # [P^T 0] [a]   [ 0 ]
    pair_distance: torch.Tensor = _pair_square_euclidean(points_src, points_dst)
    k_matrix: torch.Tensor = _kernel_distance(pair_distance)

    zero_mat: torch.Tensor = zeros(batch_size, 3, 3, device=device, dtype=dtype)
    one_mat: torch.Tensor = ones(batch_size, num_points, 1, device=device, dtype=dtype)
    dest_with_zeros: torch.Tensor = torch.cat((points_dst, zero_mat[:, :, :2]), 1)
    p_matrix: torch.Tensor = torch.cat((one_mat, points_src), -1)
    p_matrix_t: torch.Tensor = torch.cat((p_matrix, zero_mat), 1).transpose(1, 2)
    l_matrix: torch.Tensor = torch.cat((k_matrix, p_matrix), -1)
    l_matrix = torch.cat((l_matrix, p_matrix_t), 1)

    weights = _torch_solve_cast(l_matrix, dest_with_zeros)
    kernel_weights: torch.Tensor = weights[:, :-3]
    affine_weights: torch.Tensor = weights[:, -3:]

    return kernel_weights, affine_weights


def warp_points_tps(
    points_src: torch.Tensor, kernel_centers: torch.Tensor, kernel_weights: torch.Tensor, affine_weights: torch.Tensor
) -> torch.Tensor:
    r"""Warp a tensor of coordinate points using the thin plate spline defined by arguments.

    The source points should be a :math:`(B, N, 2)` tensor of :math:`(x, y)` coordinates. The kernel centers are
    a :math:`(B, K, 2)` tensor of :math:`(x, y)` coordinates. The kernel weights are a :math:`(B, K, 2)` tensor,
    and the affine weights are a :math:`(B, 3, 2)` tensor.  For the weight tensors, tensor[..., 0] contains the
    weights for the x-transform and tensor[..., 1] the weights for the y-transform.

    Args:
        points_src: tensor of source points :math:`(B, N, 2)`.
        kernel_centers: tensor of kernel center points :math:`(B, K, 2)`.
        kernel_weights: tensor of kernl weights :math:`(B, K, 2)`.
        affine_weights: tensor of affine weights :math:`(B, 3, 2)`.

    Returns:
        The :math:`(B, N, 2)` tensor of warped source points, from applying the TPS transform.

    Example:
        >>> points_src = torch.rand(1, 5, 2)
        >>> points_dst = torch.rand(1, 5, 2)
        >>> kernel_weights, affine_weights = get_tps_transform(points_src, points_dst)
        >>> warped = warp_points_tps(points_src, points_dst, kernel_weights, affine_weights)
        >>> warped_correct = torch.allclose(warped, points_dst)

    .. note::
        This function is often used in conjunction with :func:`get_tps_transform`.

    """
    if not isinstance(points_src, torch.Tensor):
        raise TypeError(f"Input points_src is not torch.Tensor. Got {type(points_src)}")

    if not isinstance(kernel_centers, torch.Tensor):
        raise TypeError(f"Input kernel_centers is not torch.Tensor. Got {type(kernel_centers)}")

    if not isinstance(kernel_weights, torch.Tensor):
        raise TypeError(f"Input kernel_weights is not torch.Tensor. Got {type(kernel_weights)}")

    if not isinstance(affine_weights, torch.Tensor):
        raise TypeError(f"Input affine_weights is not torch.Tensor. Got {type(affine_weights)}")

    if not len(points_src.shape) == 3:
        raise ValueError(f"Invalid shape for points_src, expected BxNx2. Got {points_src.shape}")

    if not len(kernel_centers.shape) == 3:
        raise ValueError(f"Invalid shape for kernel_centers, expected BxNx2. Got {kernel_centers.shape}")

    if not len(kernel_weights.shape) == 3:
        raise ValueError(f"Invalid shape for kernel_weights, expected BxNx2. Got {kernel_weights.shape}")

    if not len(affine_weights.shape) == 3:
        raise ValueError(f"Invalid shape for affine_weights, expected BxNx2. Got {affine_weights.shape}")

    # f_{x|y}(v) = a_0 + [a_x a_y].v + \sum_i w_i * U(||v-u_i||)
    pair_distance: torch.Tensor = _pair_square_euclidean(points_src, kernel_centers)
    k_matrix: torch.Tensor = _kernel_distance(pair_distance)

    # broadcast the kernel distance matrix against the x and y weights to compute the x and y
    # transforms simultaneously
    k_mul_kernel = k_matrix[..., None].mul(kernel_weights[:, None]).sum(-2)
    points_mul_affine = points_src[..., None].mul(affine_weights[:, None, 1:]).sum(-2)
    warped: torch.Tensor = k_mul_kernel + points_mul_affine + affine_weights[:, None, 0]

    return warped


def warp_image_tps(
    image: torch.Tensor,
    kernel_centers: torch.Tensor,
    kernel_weights: torch.Tensor,
    affine_weights: torch.Tensor,
    align_corners: bool = False,
    padding_mode: str = "zeros",
) -> torch.Tensor:
    r"""Warp an image tensor according to the thin plate spline transform defined by arguments.

    .. image:: _static/img/warp_image_tps.png

    The transform is applied to each pixel coordinate in the output image to obtain a point in the input
    image for interpolation of the output pixel. So the TPS parameters should correspond to a warp from
    output space to input space.

    The input `image` is a :math:`(B, C, H, W)` tensor. The kernel centers, kernel weight and affine weights
    are the same as in `warp_points_tps`.

    Args:
        image: input image tensor :math:`(B, C, H, W)`.
        kernel_centers: kernel center points :math:`(B, K, 2)`.
        kernel_weights: tensor of kernl weights :math:`(B, K, 2)`.
        affine_weights: tensor of affine weights :math:`(B, 3, 2)`.
        align_corners: interpolation flag used by `grid_sample`.
        padding_mode: padding flag used by `grid_sample`.

    Returns:
        warped image tensor :math:`(B, C, H, W)`.

    Example:
        >>> points_src = torch.rand(1, 5, 2)
        >>> points_dst = torch.rand(1, 5, 2)
        >>> image = torch.rand(1, 3, 32, 32)
        >>> # note that we are getting the reverse transform: dst -> src
        >>> kernel_weights, affine_weights = get_tps_transform(points_dst, points_src)
        >>> warped_image = warp_image_tps(image, points_src, kernel_weights, affine_weights)

    .. note::
        This function is often used in conjunction with :func:`get_tps_transform`.

    """
    if not isinstance(image, torch.Tensor):
        raise TypeError(f"Input image is not torch.Tensor. Got {type(image)}")

    if not isinstance(kernel_centers, torch.Tensor):
        raise TypeError(f"Input kernel_centers is not torch.Tensor. Got {type(kernel_centers)}")

    if not isinstance(kernel_weights, torch.Tensor):
        raise TypeError(f"Input kernel_weights is not torch.Tensor. Got {type(kernel_weights)}")

    if not isinstance(affine_weights, torch.Tensor):
        raise TypeError(f"Input affine_weights is not torch.Tensor. Got {type(affine_weights)}")

    if not len(image.shape) == 4:
        raise ValueError(f"Invalid shape for image, expected BxCxHxW. Got {image.shape}")

    if not len(kernel_centers.shape) == 3:
        raise ValueError(f"Invalid shape for kernel_centers, expected BxNx2. Got {kernel_centers.shape}")

    if not len(kernel_weights.shape) == 3:
        raise ValueError(f"Invalid shape for kernel_weights, expected BxNx2. Got {kernel_weights.shape}")

    if not len(affine_weights.shape) == 3:
        raise ValueError(f"Invalid shape for affine_weights, expected BxNx2. Got {affine_weights.shape}")

    batch_size, _, h, w = image.shape
    coords: torch.Tensor = create_meshgrid(h, w, device=image.device, dtype=image.dtype)
    coords = coords.reshape(-1, 2).expand(batch_size, -1, -1)
    warped: torch.Tensor = warp_points_tps(coords, kernel_centers, kernel_weights, affine_weights)
    warped = warped.view(-1, h, w, 2)
    warped_image: torch.Tensor = nn.functional.grid_sample(
        image, warped, padding_mode=padding_mode, align_corners=align_corners
    )

    return warped_image
