# LICENSE HEADER MANAGED BY add-license-header
#
# Copyright 2018 Kornia Team
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

from typing import Dict, Optional

import torch
import torch.nn.functional as F
from torch import nn

from kornia.constants import pi
from kornia.core.check import KORNIA_CHECK_LAF, KORNIA_CHECK_SHAPE
from kornia.filters import SpatialGradient, get_gaussian_discrete_kernel1d, get_gaussian_kernel2d
from kornia.geometry import rad2deg

from .laf import extract_patches_from_pyramid, get_laf_orientation, set_laf_orientation

# Download locations of pretrained weights, keyed by model name.
urls: Dict[str, str] = {
    "orinet": "https://github.com/ducha-aiki/affnet/raw/master/pretrained/OriNet.pth",
}


class PassLAF(nn.Module):
    """Identity stand-in for a local feature orientation or affine shape estimator.

    Plug it in wherever a ``(laf, img) -> laf`` module is expected but no refinement is wanted.
    """

    def forward(self, laf: torch.Tensor, img: torch.Tensor) -> torch.Tensor:
        """Return the input LAFs untouched.

        Args:
            laf: :math:`(B, N, 2, 3)`
            img: :math:`(B, 1, H, W)`; unused.

        Returns:
            the very same ``laf`` tensor, :math:`(B, N, 2, 3)`

        """
        del img  # not needed; kept in the signature so the (laf, img) call shape matches
        return laf


class PatchDominantGradientOrientation(nn.Module):
    """Module, which estimates the dominant gradient orientation of the given patches, in radians.

    Zero angle points towards right.

    Gradients are computed with a Sobel filter. Their magnitudes are weighted by a Gaussian
    window with ``sigma = patch_size / 6`` and accumulated into a circular histogram of
    ``num_angular_bins`` orientation bins. Each pixel splits its weight linearly between the two
    nearest bins. The histogram is smoothed with a circular 5-tap discrete Gaussian
    (``sigma = 1.6``). The position of its maximum is then refined to sub-bin precision with a
    parabola through the peak bin and its two neighbours.

    Args:
        patch_size: size of the (square) input patch.
        num_angular_bins: number of histogram bins.
        eps: small constant added inside the gradient-magnitude square root and to ``gx`` in the
            ``atan2`` call to keep both numerically safe.

    """

    def __init__(self, patch_size: int = 32, num_angular_bins: int = 36, eps: float = 1e-8) -> None:
        super().__init__()
        self.patch_size = patch_size
        self.num_ang_bins = num_angular_bins
        self.gradient = SpatialGradient("sobel", 1)
        self.eps = eps
        # Circular padding: the orientation histogram wraps around, so its ends are neighbours.
        self.angular_smooth = nn.Conv1d(1, 1, kernel_size=5, padding=2, bias=False, padding_mode="circular")
        with torch.no_grad():
            self.angular_smooth.weight[:] = get_gaussian_discrete_kernel1d(5, 1.6)
        sigma: float = float(self.patch_size) / 6.0
        # Spatial Gaussian window applied to the gradient magnitudes.
        # NOTE: a plain tensor attribute (not a registered buffer); it is moved to the input's
        # dtype/device inside ``forward``.
        self.weighting = get_gaussian_kernel2d((self.patch_size, self.patch_size), (sigma, sigma), True)

    def __repr__(self) -> str:
        return (
            f"{self.__class__.__name__}(patch_size={self.patch_size}, num_ang_bins={self.num_ang_bins}, eps={self.eps})"
        )

    def forward(self, patch: torch.Tensor) -> torch.Tensor:
        """Run forward.

        Args:
            patch: :math:`(B, 1, H, W)` with ``H == W == patch_size``.

        Returns:
            angle in radians: :math:`(B)`

        Raises:
            TypeError: if the spatial size is not ``patch_size x patch_size``.

        """
        KORNIA_CHECK_SHAPE(patch, ["B", "1", "H", "W"])
        _, CH, H, W = patch.size()
        if (H != self.patch_size) or (W != self.patch_size) or (CH != 1):
            raise TypeError(f"input shape should be [Bx1x{self.patch_size}x{self.patch_size}]. Got {patch.size()}")
        # Keep the window and the smoothing kernel on the input's dtype/device.
        self.weighting = self.weighting.to(patch.dtype).to(patch.device)
        self.angular_smooth = self.angular_smooth.to(patch.dtype).to(patch.device)
        grads: torch.Tensor = self.gradient(patch)
        # unpack the x / y derivative components
        gx: torch.Tensor = grads[:, :, 0]
        gy: torch.Tensor = grads[:, :, 1]

        mag: torch.Tensor = torch.sqrt(gx * gx + gy * gy + self.eps) * self.weighting
        # atan2 lies in (-pi, pi]. The +2*pi here and the +pi below make the bin coordinate
        # strictly positive, so floor / modulo behave predictably. Net effect after the modulo:
        # bin position == num_bins * (atan2 + pi) / (2 * pi), wrapped into [0, num_bins).
        ori: torch.Tensor = torch.atan2(gy, gx + self.eps) + 2.0 * pi

        o_big = float(self.num_ang_bins) * (ori + 1.0 * pi) / (2.0 * pi)
        # Linear interpolation of each pixel's weight between its two nearest (wrapped) bins.
        bo0_big = torch.floor(o_big)
        wo1_big = o_big - bo0_big
        bo0_big = bo0_big % self.num_ang_bins
        bo1_big = (bo0_big + 1) % self.num_ang_bins
        wo0_big = (1.0 - wo1_big) * mag
        wo1_big = wo1_big * mag
        # Histogram bin i is the spatial mean of the weights that fall into it.
        ang_bins_list = []
        for i in range(self.num_ang_bins):
            ang_bins_i = F.adaptive_avg_pool2d(
                (bo0_big == i).to(patch.dtype) * wo0_big + (bo1_big == i).to(patch.dtype) * wo1_big, (1, 1)
            )
            ang_bins_list.append(ang_bins_i)
        ang_bins = torch.cat(ang_bins_list, 1).view(-1, 1, self.num_ang_bins)
        ang_bins = self.angular_smooth(ang_bins).view(-1, self.num_ang_bins)
        values, indices = ang_bins.max(1)
        # Circular neighbours of the peak bin.
        indices_left = (self.num_ang_bins + indices - 1) % self.num_ang_bins
        indices_right = (indices + 1) % self.num_ang_bins
        left = torch.gather(ang_bins, 1, indices_left.reshape(-1, 1)).reshape(-1)
        center = values
        right = torch.gather(ang_bins, 1, indices_right.reshape(-1, 1)).reshape(-1)
        # Sub-bin refinement: vertex of the parabola through (left, center, right).
        # ``center`` is the maximum, so the denominator is <= 0. It becomes zero for a flat peak
        # (left == center == right), where a plain division gives 0/0 = NaN. Use a zero offset
        # there. The safe denominator also keeps the torch.where backward pass free of NaN/inf.
        denom = left + right - 2.0 * center
        flat_peak = denom == 0
        safe_denom = torch.where(flat_peak, torch.ones_like(denom), denom)
        c_subpix = torch.where(flat_peak, torch.zeros_like(denom), 0.5 * (left - right) / safe_denom)
        # Map the refined bin position back to an angle in radians (sign flipped w.r.t. atan2).
        angle = -((2.0 * pi * (indices.to(patch.dtype) + c_subpix) / float(self.num_ang_bins)) - pi)
        return angle


class OriNet(nn.Module):
    """Network, which estimates the canonical orientation of the given 32x32 patches, in radians.

    Zero angle points towards right. This is based on the original code from paper
    "Repeatability Is Not Enough: Learning Discriminative Affine Regions via Discriminability".
    See :cite:`AffNet2018` for more details.

    The network regresses a tanh-bounded, spatially averaged 2-vector per patch. The angle is
    ``atan2`` of its two components.

    Args:
        pretrained: Download and set pretrained weights to the model.
        eps: added to both components before ``atan2`` to avoid division by zero.

    Returns:
        Angle in radians.

    Shape:
        - Input: (B, 1, 32, 32)
        - Output: (B)

    Examples:
        >>> input = torch.rand(16, 1, 32, 32)
        >>> orinet = OriNet()
        >>> angle = orinet(input) # 16

    """

    def __init__(self, pretrained: bool = False, eps: float = 1e-8) -> None:
        super().__init__()
        self.features = nn.Sequential(
            nn.Conv2d(1, 16, kernel_size=3, padding=1, bias=False),
            nn.BatchNorm2d(16, affine=False),
            nn.ReLU(),
            nn.Conv2d(16, 16, kernel_size=3, stride=1, padding=1, bias=False),
            nn.BatchNorm2d(16, affine=False),
            nn.ReLU(),
            nn.Conv2d(16, 32, kernel_size=3, stride=2, padding=1, bias=False),
            nn.BatchNorm2d(32, affine=False),
            nn.ReLU(),
            nn.Conv2d(32, 32, kernel_size=3, stride=1, padding=1, bias=False),
            nn.BatchNorm2d(32, affine=False),
            nn.ReLU(),
            nn.Conv2d(32, 64, kernel_size=3, stride=2, padding=1, bias=False),
            nn.BatchNorm2d(64, affine=False),
            nn.ReLU(),
            nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1, bias=False),
            nn.BatchNorm2d(64, affine=False),
            nn.ReLU(),
            nn.Dropout(0.25),
            # 8x8 feature map -> 2 output channels, bounded by tanh, then averaged to 1x1.
            nn.Conv2d(64, 2, kernel_size=8, stride=1, padding=1, bias=True),
            nn.Tanh(),
            nn.AdaptiveAvgPool2d(1),
        )
        self.eps = eps
        # use torch.hub to load pretrained model
        if pretrained:
            pretrained_dict = torch.hub.load_state_dict_from_url(urls["orinet"], map_location=torch.device("cpu"))
            # strict=False: checkpoint keys that do not match this module are skipped silently.
            self.load_state_dict(pretrained_dict["state_dict"], strict=False)
        self.eval()

    @staticmethod
    def _normalize_input(x: torch.Tensor, eps: float = 1e-6) -> torch.Tensor:
        """Normalize each sample to zero mean / unit std over its (C, H, W) dimensions."""
        sp, mp = torch.std_mean(x, dim=(-3, -2, -1), keepdim=True)
        # WARNING: we need to .detach() input, otherwise the gradients produced by
        # the patches extractor with F.grid_sample are very noisy, making the detector
        # training totally unstable.
        return (x - mp.detach()) / (sp.detach() + eps)

    def forward(self, patch: torch.Tensor) -> torch.Tensor:
        """Run forward.

        Args:
            patch: :math:`(B, 1, H, W)`

        Returns:
            angle in radians: :math:`(B)`

        """
        xy = self.features(self._normalize_input(patch)).view(-1, 2)
        # Offset both components by the configured eps, so the constructor argument is honoured.
        angle = torch.atan2(xy[:, 0] + self.eps, xy[:, 1] + self.eps)
        return angle


class LAFOrienter(nn.Module):
    """Module, which extracts patches using input images and local affine frames (LAFs).

    Then runs :class:`~kornia.feature.PatchDominantGradientOrientation` or
    :class:`~kornia.feature.OriNet` on patches and then rotates the LAFs by the estimated angles.

    Args:
        patch_size: side length of the square patch extracted for each LAF.
        num_angular_bins: number of orientation histogram bins. Only used when ``angle_detector``
          is ``None`` and the default :class:`~kornia.feature.PatchDominantGradientOrientation` is built.
        angle_detector: Patch orientation estimator, e.g. :class:`~kornia.feature.PatchDominantGradientOrientation`
          or OriNet. It receives patches of shape ``(B * N, 1, patch_size, patch_size)`` and must
          return one angle in radians per patch.

    """  # pylint: disable

    def __init__(
        self, patch_size: int = 32, num_angular_bins: int = 36, angle_detector: Optional[nn.Module] = None
    ) -> None:
        super().__init__()
        self.patch_size = patch_size
        self.num_ang_bins = num_angular_bins
        self.angle_detector: nn.Module
        if angle_detector is None:
            self.angle_detector = PatchDominantGradientOrientation(self.patch_size, self.num_ang_bins)
        else:
            self.angle_detector = angle_detector

    def __repr__(self) -> str:
        return f"{self.__class__.__name__}(patch_size={self.patch_size}, angle_detector={self.angle_detector})"

    def forward(self, laf: torch.Tensor, img: torch.Tensor) -> torch.Tensor:
        """Run forward.

        Args:
            laf: :math:`(B, N, 2, 3)`
            img: :math:`(B, 1, H, W)`

        Returns:
            LAF_out: :math:`(B, N, 2, 3)`

        Raises:
            ValueError: if ``laf`` and ``img`` have different batch sizes.

        """
        KORNIA_CHECK_LAF(laf)
        # NOTE(review): only B/C/H/W rank is checked here. Patches are reshaped to a single
        # channel below, so multi-channel images will fail at the later ``view(B, N)``.
        KORNIA_CHECK_SHAPE(img, ["B", "C", "H", "W"])
        if laf.size(0) != img.size(0):
            raise ValueError(f"Batch size of laf and img should be the same. Got {laf.size(0)}, {img.size(0)}")
        B, N = laf.shape[:2]
        # One single-channel patch per LAF, flattened over batch and keypoints.
        patches: torch.Tensor = extract_patches_from_pyramid(img, laf, self.patch_size).view(
            -1, 1, self.patch_size, self.patch_size
        )
        angles_radians: torch.Tensor = self.angle_detector(patches).view(B, N)
        # The current LAF orientation is combined with the estimate in degrees (hence rad2deg).
        prev_angle = get_laf_orientation(laf).view_as(angles_radians)
        laf_out: torch.Tensor = set_laf_orientation(laf, rad2deg(angles_radians) + prev_angle)
        return laf_out
