# LICENSE HEADER MANAGED BY add-license-header
#
# Copyright 2018 Kornia Team
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

from typing import Tuple, Union

import torch.nn.functional as F

from kornia.core import Tensor, concatenate, tensor
from kornia.core.check import KORNIA_CHECK_IS_TENSOR, KORNIA_CHECK_SHAPE
from kornia.filters import filter2d
from kornia.filters.kernels import get_gaussian_kernel2d
from kornia.utils import create_meshgrid

# Public names exported by ``from <module> import *``.
__all__ = [
    "elastic_transform2d",
]


def elastic_transform2d(
    image: Tensor,
    noise: Tensor,
    kernel_size: Tuple[int, int] = (63, 63),
    sigma: Union[Tuple[float, float], Tensor] = (32.0, 32.0),
    alpha: Union[Tuple[float, float], Tensor] = (1.0, 1.0),
    align_corners: bool = False,
    mode: str = "bilinear",
    padding_mode: str = "zeros",
) -> Tensor:
    r"""Apply elastic transform of images as described in :cite:`Simard2003BestPF`.

    .. image:: _static/img/elastic_transform2d.png

    Each channel of ``noise`` is smoothed with an isotropic Gaussian kernel and scaled by ``alpha``.
    The result is added directly to the normalized sampling grid, so displacements are expressed in
    normalized coordinates. The image is then resampled with ``grid_sample`` at the displaced
    (and clamped to ``[-1, 1]``) locations.

    Args:
        image: Input image to be transformed with shape :math:`(B, C, H, W)`.
        noise: Noise image used to spatially transform the input image, with the same
          resolution as the input image and shape :math:`(B, 2, H, W)`. The channel order
          is expected to be x-y: channel 0 displaces along x and channel 1 along y.
        kernel_size: the size of the Gaussian kernel.
        sigma: The standard deviation of the isotropic Gaussians used to smooth the y and x
          displacements, respectively. Either a tuple, or a tensor with one element (shared)
          or two elements. Larger sigma results in smaller pixel displacements.
        alpha: The scaling factor that controls the intensity of the deformation
          in the x and y directions, respectively. Either a tuple, or a tensor with one
          element (shared) or two elements.
        align_corners: Interpolation flag used by ``grid_sample``.
        mode: Interpolation mode used by ``grid_sample``. Either ``'bilinear'`` or ``'nearest'``.
        padding_mode: The padding used by ``grid_sample``. Either ``'zeros'``, ``'border'`` or ``'reflection'``.

    Returns:
        the elastically transformed input image with shape :math:`(B,C,H,W)`.

    Example:
        >>> image = torch.rand(1, 3, 5, 5)
        >>> noise = torch.rand(1, 2, 5, 5, requires_grad=True)
        >>> image_hat = elastic_transform2d(image, noise, (3, 3))
        >>> image_hat.mean().backward()

        >>> image = torch.rand(1, 3, 5, 5)
        >>> noise = torch.rand(1, 2, 5, 5)
        >>> sigma = torch.tensor([4., 4.], requires_grad=True)
        >>> image_hat = elastic_transform2d(image, noise, (3, 3), sigma)
        >>> image_hat.mean().backward()

        >>> image = torch.rand(1, 3, 5, 5)
        >>> noise = torch.rand(1, 2, 5, 5)
        >>> alpha = torch.tensor([16., 32.], requires_grad=True)
        >>> image_hat = elastic_transform2d(image, noise, (3, 3), alpha=alpha)
        >>> image_hat.mean().backward()

    """
    KORNIA_CHECK_IS_TENSOR(image)
    KORNIA_CHECK_IS_TENSOR(noise)
    KORNIA_CHECK_SHAPE(image, ["B", "C", "H", "W"])
    # The displacement field needs exactly one x and one y channel.
    KORNIA_CHECK_SHAPE(noise, ["B", "2", "H", "W"])

    device, dtype = image.device, image.dtype

    # Bring the parameters onto the image device/dtype. ``.to`` keeps tensor inputs in the
    # autograd graph, and ``expand(2)`` also accepts a single shared value.
    if isinstance(sigma, Tensor):
        sigma_t = sigma.to(device=device, dtype=dtype).expand(2)
    else:
        sigma_t = tensor(sigma, device=device, dtype=dtype)

    if isinstance(alpha, Tensor):
        alpha_t = alpha.to(device=device, dtype=dtype).expand(2)
    else:
        alpha_t = tensor(alpha, device=device, dtype=dtype)

    # One isotropic Gaussian per displacement channel: sigma[0] smooths the y displacement and
    # sigma[1] the x displacement. Each sigma is passed as a (1, 2) tensor holding the same
    # value for both kernel axes.
    kernel_y = get_gaussian_kernel2d(kernel_size, sigma_t[0].expand(2).unsqueeze(0))
    kernel_x = get_gaussian_kernel2d(kernel_size, sigma_t[1].expand(2).unsqueeze(0))

    # Match the image dtype so ``grid + disp`` below does not promote the grid to a dtype that
    # ``grid_sample`` rejects for this image.
    noise_t = noise.to(device=device, dtype=dtype)

    # Convolve the random displacement fields and scale them with 'alpha'.
    disp_x = filter2d(noise_t[:, :1], kernel=kernel_x, border_type="constant") * alpha_t[0]
    disp_y = filter2d(noise_t[:, 1:], kernel=kernel_y, border_type="constant") * alpha_t[1]

    # Stack as (B, H, W, 2) with the last axis ordered (x, y), the layout grid_sample consumes.
    disp = concatenate([disp_x, disp_y], 1).permute(0, 2, 3, 1)

    # Warp the image: offset the normalized base grid and clamp to the valid [-1, 1] range.
    _, _, h, w = image.shape
    grid = create_meshgrid(h, w, device=device).to(dtype)
    warped = F.grid_sample(
        image, (grid + disp).clamp(-1, 1), align_corners=align_corners, mode=mode, padding_mode=padding_mode
    )

    return warped
