# LICENSE HEADER MANAGED BY add-license-header
#
# Copyright 2018 Kornia Team
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

from typing import List, Optional

import torch
import torch.nn.functional as F
from torch import nn
from typing_extensions import TypedDict

from kornia.core import Module, Tensor, concatenate
from kornia.filters import SpatialGradient
from kornia.geometry.transform import pyrdown

from .scale_space_detector import Detector_config, MultiResolutionDetector, get_default_detector_config


class KeyNet_conf(TypedDict):
    """Typed configuration dictionary read by :class:`KeyNet` and :class:`KeyNetDetector`."""

    # Multiplied by ``num_levels`` to size the input channels of KeyNet's final convolution.
    num_filters: int
    # Number of resolution levels KeyNet extracts features at (the first level is the input itself).
    num_levels: int
    # Kernel size of KeyNet's final convolution; its padding is derived as ``kernel_size // 2``.
    kernel_size: int
    # Detector settings that KeyNetDetector forwards to MultiResolutionDetector.
    Detector_conf: Detector_config


# Default configuration used by KeyNet and KeyNetDetector when the caller does not pass one.
# NOTE: this dict object is the default argument of both constructors, so it is shared
# between instances; treat it as read-only.
keynet_default_config: KeyNet_conf = {
    # Key.Net Model
    # 8 matches the default output width of _KeyNetConvBlock; the final convolution expects
    # ``num_filters * num_levels`` input channels.
    "num_filters": 8,
    "num_levels": 3,
    "kernel_size": 5,
    # Extraction Parameters
    "Detector_conf": get_default_detector_config(),
}

# Checkpoint fetched through torch.hub when KeyNet is built with ``pretrained=True``.
KeyNet_URL = "https://github.com/axelBarroso/Key.Net-Pytorch/raw/main/model/weights/keynet_pytorch.pth"


class _FeatureExtractor(Module):
    """Feature extraction stage of KeyNet.

    Chains the handcrafted block with the learnable block: the input first goes through
    :class:`_HandcraftedBlock` and the result is fed to :class:`_LearnableBlock`.
    """

    def __init__(self) -> None:
        super().__init__()

        # The attribute names are part of the parameter names in the module's state dict,
        # so they have to stay stable for checkpoints to load.
        self.hc_block = _HandcraftedBlock()
        self.lb_block = _LearnableBlock()

    def forward(self, x: Tensor) -> Tensor:
        """Return the learnable-block output computed on the handcrafted features of ``x``."""
        return self.lb_block(self.hc_block(x))


class _HandcraftedBlock(Module):
    """Helper class for KeyNet holding the fixed (non-learned) filters of the handcrafted block.

    All features are derived from repeated applications of a single
    ``SpatialGradient("sobel", 1)`` operator.
    """

    def __init__(self) -> None:
        super().__init__()
        self.spatial_gradient = SpatialGradient("sobel", 1)

    def forward(self, x: Tensor) -> Tensor:
        """Return ten derivative-based feature maps of ``x`` concatenated along dim 1."""
        # First-order derivatives: slots 0 and 1 of dim 2 of the gradient output.
        grad = self.spatial_gradient(x)
        dx = grad[:, :, 0, :, :]
        dy = grad[:, :, 1, :, :]

        # Second-order derivatives, obtained by differentiating dx and dy again.
        grad_of_dx = self.spatial_gradient(dx)
        dxx = grad_of_dx[:, :, 0, :, :]
        dxy = grad_of_dx[:, :, 1, :, :]

        grad_of_dy = self.spatial_gradient(dy)
        dyy = grad_of_dy[:, :, 1, :, :]

        # The order of the maps is significant: the learnable block's first convolution
        # consumes them as input channels.
        feature_maps = [
            dx,
            dy,
            dx**2.0,
            dy**2.0,
            dx * dy,
            dxy,
            dxy**2.0,
            dxx,
            dyy,
            dxx * dyy,
        ]
        return concatenate(feature_maps, 1)


class _LearnableBlock(nn.Sequential):
    """Helper class for KeyNet: the stack of three learnable convolutional blocks.

    Args:
        in_channels: number of channels of the tensor fed to the first block. The remaining
            blocks use the defaults of :func:`_KeyNetConvBlock`.
    """

    def __init__(self, in_channels: int = 10) -> None:
        super().__init__()

        # Only the first block needs a custom input width; the others take the default.
        self.conv0 = _KeyNetConvBlock(in_channels)
        self.conv1 = _KeyNetConvBlock()
        self.conv2 = _KeyNetConvBlock()

    def forward(self, x: Tensor) -> Tensor:
        """Run ``x`` through ``conv0``, ``conv1`` and ``conv2`` in that order."""
        x = self.conv0(x)
        x = self.conv1(x)
        return self.conv2(x)


def _KeyNetConvBlock(
    in_channels: int = 8,
    out_channels: int = 8,
    kernel_size: int = 5,
    stride: int = 1,
    padding: int = 2,
    dilation: int = 1,
) -> nn.Sequential:
    """Build the default learnable convolutional block for KeyNet.

    The block is a ``Conv2d`` followed by ``BatchNorm2d`` and an in-place ``ReLU``.

    Args:
        in_channels: number of input channels of the convolution.
        out_channels: number of output channels of the convolution and of the batch norm.
        kernel_size: convolution kernel size.
        stride: convolution stride.
        padding: convolution padding. It is not derived from ``kernel_size``; the default
            of 2 keeps the spatial size only for the default 5x5 kernel with stride 1.
        dilation: convolution dilation.

    Returns:
        The three layers wrapped in an ``nn.Sequential``.
    """
    convolution = nn.Conv2d(
        in_channels=in_channels,
        out_channels=out_channels,
        kernel_size=kernel_size,
        stride=stride,
        padding=padding,
        dilation=dilation,
    )
    layers = [convolution, nn.BatchNorm2d(out_channels), nn.ReLU(inplace=True)]
    return nn.Sequential(*layers)


class KeyNet(Module):
    """Key.Net model definition -- local feature detector (response function).

    This is based on the original code
    from paper "Key.Net: Keypoint Detection by Handcrafted and Learned CNN Filters". See :cite:`KeyNet2019` for
    more details.

    .. image:: _static/img/KeyNet.png

    Args:
        pretrained: Download and set pretrained weights to the model.
        keynet_conf: Dict with initialization parameters. Do not pass it, unless you know what you are doing.

    Returns:
        KeyNet response score.

    Shape:
        - Input: :math:`(B, 1, H, W)`
        - Output: :math:`(B, 1, H, W)`

    """

    def __init__(self, pretrained: bool = False, keynet_conf: KeyNet_conf = keynet_default_config) -> None:
        super().__init__()

        filters_per_level = keynet_conf["num_filters"]
        self.num_levels = keynet_conf["num_levels"]
        head_kernel = keynet_conf["kernel_size"]
        head_padding = head_kernel // 2

        # NOTE: the feature extractor is built with its own defaults and does not read
        # ``num_filters``; the configured value must match the extractor's output width
        # for the final convolution below to accept the concatenated features.
        self.feature_extractor = _FeatureExtractor()

        # Final scoring head: one convolution over the features of all levels, then ReLU.
        score_conv = nn.Conv2d(
            in_channels=filters_per_level * self.num_levels,
            out_channels=1,
            kernel_size=head_kernel,
            padding=head_padding,
        )
        self.last_conv = nn.Sequential(score_conv, nn.ReLU(inplace=True))

        # use torch.hub to load pretrained model
        if pretrained:
            checkpoint = torch.hub.load_state_dict_from_url(KeyNet_URL, map_location=torch.device("cpu"))
            self.load_state_dict(checkpoint["state_dict"], strict=True)

        # The module is always left in evaluation mode after construction.
        self.eval()

    def forward(self, x: Tensor) -> Tensor:
        """Compute the response map for the input image ``x``.

        Features are extracted from ``x`` and from ``num_levels - 1`` successively
        ``pyrdown``-ed versions of it (``factor=1.2``); the features of the reduced levels
        are resized back to the input's spatial size before all levels are concatenated
        along dim 1 and passed to the final convolution.
        """
        full_res = x.shape
        per_level: List[Tensor] = [self.feature_extractor(x)]
        current = x
        for _ in range(self.num_levels - 1):
            # Each level is derived from the previous one, not from the original input.
            current = pyrdown(current, factor=1.2)
            level_feats = self.feature_extractor(current)
            level_feats = F.interpolate(level_feats, size=(full_res[2], full_res[3]), mode="bilinear")
            per_level.append(level_feats)
        return self.last_conv(concatenate(per_level, 1))


class KeyNetDetector(MultiResolutionDetector):
    """Multi-scale feature detector based on KeyNet.

    This is based on the original code from paper
    "Key.Net: Keypoint Detection by Handcrafted and Learned CNN Filters".
    See :cite:`KeyNet2019` for more details.

    .. image:: _static/img/keynet.jpg

    Args:
        pretrained: Download and set pretrained weights to the model.
        num_features: Number of features to detect.
        keynet_conf: Dict with initialization parameters. Do not pass it, unless you know what you are doing.
        ori_module: for local feature orientation estimation. Default: :class:`~kornia.feature.PassLAF`,
           which does nothing. See :class:`~kornia.feature.LAFOrienter` for details.
        aff_module: for local feature affine shape estimation. Default: :class:`~kornia.feature.PassLAF`,
            which does nothing. See :class:`~kornia.feature.LAFAffineShapeEstimator` for details.

    """

    def __init__(
        self,
        pretrained: bool = False,
        num_features: int = 2048,
        keynet_conf: KeyNet_conf = keynet_default_config,
        ori_module: Optional[Module] = None,
        aff_module: Optional[Module] = None,
    ) -> None:
        # The response model is created first and handed to the base detector, which also
        # receives the detector-specific part of the configuration.
        response_model = KeyNet(pretrained=pretrained, keynet_conf=keynet_conf)
        detector_conf = keynet_conf["Detector_conf"]
        super().__init__(response_model, num_features, detector_conf, ori_module, aff_module)
