# LICENSE HEADER MANAGED BY add-license-header
#
# Copyright 2018 Kornia Team
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

"""Implements several backbone networks."""

import functools
import operator
from typing import Any, Dict, List, NamedTuple, Optional, Tuple, Type, Union

import torch
import torch.nn.functional as F
from torch import nn
from torch.nn.functional import pixel_shuffle, softmax

from kornia.core import Module, Tensor


class HourglassConfig(NamedTuple):
    """Hyper-parameters consumed by :func:`hg` to build a :class:`HourglassNet`.

    Field order matters: instances are also created positionally.
    """

    depth: int  # number of recursive down/up-sampling levels inside each hourglass module
    num_stacks: int  # number of hourglass modules applied one after another
    num_blocks: int  # number of residual blocks chained in each residual unit
    num_classes: int  # number of channels produced by ``head`` (re-projected between stacks)
    input_channels: int  # number of channels of the input images
    head: Type[Module]  # head class, instantiated as ``head(channels)`` for every stack


# [Hourglass backbone classes]
class HourglassBackbone(Module):
    """Stacked-hourglass feature extractor, adapted from https://github.com/zhou13/lcnn.

    Args:
        input_channel: number of channels of the input images.
        depth: number of recursive down/up-sampling levels inside each hourglass module.
        num_stacks: number of hourglass modules stacked together.
        num_blocks: number of bottleneck blocks in each residual unit.
        num_classes: number of channels expected from the head of each stack (the default
            ``MultitaskHead`` emits 2 + 1 + 2 = 5 channels).

    """

    def __init__(
        self, input_channel: int = 1, depth: int = 4, num_stacks: int = 2, num_blocks: int = 1, num_classes: int = 5
    ) -> None:
        super().__init__()
        # Stored as a class, not an instance: the network instantiates one head per stack.
        self.head = MultitaskHead
        config = HourglassConfig(
            depth=depth,
            num_stacks=num_stacks,
            num_blocks=num_blocks,
            num_classes=num_classes,
            input_channels=input_channel,
            head=self.head,
        )
        self.net = hg(config)

    def forward(self, input_images: Tensor) -> Tensor:
        """Return the feature map produced by the last hourglass stack for ``input_images``."""
        features = self.net(input_images)
        return features


class MultitaskHead(Module):
    """Parallel convolutional heads whose outputs are concatenated along the channel axis.

    Three heads are built, emitting 2, 1 and 2 channels respectively, so the output has 5 channels.
    Each head is a 3x3 convolution to ``input_channels / 4`` channels, a ReLU and a 1x1 convolution.

    Args:
        input_channels: number of channels of the input feature map.

    """

    def __init__(self, input_channels: int) -> None:
        super().__init__()

        hidden = int(input_channels / 4)
        task_sizes = [[2], [1], [2]]
        # Flatten the nested per-task sizes into a single list of output widths.
        out_widths = [width for group in task_sizes for width in group]
        self.heads = nn.ModuleList(
            [
                nn.Sequential(
                    nn.Conv2d(input_channels, hidden, kernel_size=3, padding=1),
                    nn.ReLU(inplace=True),
                    nn.Conv2d(hidden, width, kernel_size=1),
                )
                for width in out_widths
            ]
        )

    def forward(self, x: Tensor) -> Tensor:
        """Apply every head to ``x`` and concatenate their outputs channel-wise, in head order."""
        per_head = [head(x) for head in self.heads]
        return torch.cat(per_head, dim=1)


class Bottleneck2D(Module):
    """Pre-activation bottleneck residual block: three BatchNorm -> ReLU -> Conv stages plus a shortcut.

    ``conv1`` (1x1) maps ``inplanes`` to ``planes``, ``conv2`` (3x3, strided) keeps ``planes`` and
    ``conv3`` (1x1) widens to ``planes * 2``. The shortcut is the raw input, or ``downsample(x)`` when a
    projection is given; it is added in place to the main branch, so both must have the same shape.

    Args:
        inplanes: number of input channels.
        planes: bottleneck width; the block outputs ``planes * 2`` channels.
        stride: stride of the 3x3 convolution.
        downsample: optional module applied to the input to form the shortcut.

    """

    def __init__(
        self, inplanes: int, planes: int, stride: Union[int, Tuple[int, int]] = 1, downsample: Optional[Module] = None
    ) -> None:
        super().__init__()

        out_planes = planes * 2
        self.bn1 = nn.BatchNorm2d(inplanes)
        self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1)
        self.bn2 = nn.BatchNorm2d(planes)
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, padding=1)
        self.bn3 = nn.BatchNorm2d(planes)
        self.conv3 = nn.Conv2d(planes, out_planes, kernel_size=1)
        self.relu = nn.ReLU(inplace=True)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x: Tensor) -> Tensor:
        """Run the three pre-activation stages on ``x`` and add the shortcut branch."""
        out = x
        for norm, conv in ((self.bn1, self.conv1), (self.bn2, self.conv2), (self.bn3, self.conv3)):
            # ``norm`` returns a fresh tensor, so the in-place ReLU never overwrites ``x``.
            out = conv(self.relu(norm(out)))

        shortcut = x if self.downsample is None else self.downsample(x)
        out += shortcut
        return out


class Hourglass(Module):
    def __init__(self, block: Type[Bottleneck2D], num_blocks: int, planes: int, depth: int, expansion: int = 2) -> None:
        super().__init__()
        self.depth = depth
        self.block = block
        self.expansion = expansion
        self.hg = self._make_hour_glass(block, num_blocks, planes, depth)

    def _make_residual(self, block: Type[Bottleneck2D], num_blocks: int, planes: int) -> Module:
        layers = []
        for _ in range(0, num_blocks):
            layers.append(block(planes * self.expansion, planes))
        return nn.Sequential(*layers)

    def _make_hour_glass(self, block: Type[Bottleneck2D], num_blocks: int, planes: int, depth: int) -> nn.ModuleList:
        hgl = []
        for i in range(depth):
            res = []
            for _ in range(3):
                res.append(self._make_residual(block, num_blocks, planes))
            if i == 0:
                res.append(self._make_residual(block, num_blocks, planes))
            hgl.append(nn.ModuleList(res))
        return nn.ModuleList(hgl)

    def _hour_glass_forward(self, n: int, x: Tensor) -> Tensor:
        up1 = self.hg[n - 1][0](x)  # type: ignore[index]
        low1 = F.max_pool2d(x, 2, stride=2)
        low1 = self.hg[n - 1][1](low1)  # type: ignore[index]

        if n > 1:
            low2 = self._hour_glass_forward(n - 1, low1)
        else:
            low2 = self.hg[n - 1][3](low1)  # type: ignore[index]
        low3 = self.hg[n - 1][2](low2)  # type: ignore[index]
        up2 = F.interpolate(low3, size=up1.shape[2:])
        out = up1 + up2
        return out

    def forward(self, x: Tensor) -> Tensor:
        return self._hour_glass_forward(self.depth, x)


class HourglassNet(Module):
    """Hourglass model from Newell et al ECCV 2016."""

    def __init__(
        self,
        block: Type[Bottleneck2D],
        head: Type[Module],
        depth: int,
        num_stacks: int,
        num_blocks: int,
        num_classes: int,
        input_channels: int,
        expansion: int = 2,
    ) -> None:
        super().__init__()

        self.inplanes = 64
        self.num_feats = 128
        self.num_stacks = num_stacks
        self.expansion = expansion
        self.conv1 = nn.Conv2d(input_channels, self.inplanes, kernel_size=7, stride=2, padding=3)
        self.bn1 = nn.BatchNorm2d(self.inplanes)
        self.relu = nn.ReLU(inplace=True)
        self.layer1 = self._make_residual(block, self.inplanes, 1)
        self.layer2 = self._make_residual(block, self.inplanes, 1)
        self.layer3 = self._make_residual(block, self.num_feats, 1)
        self.maxpool = nn.MaxPool2d(2, stride=2)

        # Build hourglass modules
        ch = self.num_feats * self.expansion
        hgl, res, fc, score, fc_, score_ = [], [], [], [], [], []
        for i in range(num_stacks):
            hgl.append(Hourglass(block, num_blocks, self.num_feats, depth))
            res.append(self._make_residual(block, self.num_feats, num_blocks))
            fc.append(self._make_fc(ch, ch))
            score.append(head(ch))
            if i < num_stacks - 1:
                fc_.append(nn.Conv2d(ch, ch, kernel_size=1))
                score_.append(nn.Conv2d(num_classes, ch, kernel_size=1))
        self.hg = nn.ModuleList(hgl)
        self.res = nn.ModuleList(res)
        self.fc = nn.ModuleList(fc)
        self.score = nn.ModuleList(score)
        self.fc_ = nn.ModuleList(fc_)
        self.score_ = nn.ModuleList(score_)

    def _make_residual(
        self, block: Type[Bottleneck2D], planes: int, blocks: int, stride: Union[int, Tuple[int, int]] = 1
    ) -> Module:
        downsample = None
        if stride != 1 or self.inplanes != planes * self.expansion:
            downsample = nn.Sequential(nn.Conv2d(self.inplanes, planes * self.expansion, kernel_size=1, stride=stride))

        layers = []
        layers.append(block(self.inplanes, planes, stride, downsample))
        self.inplanes = planes * self.expansion
        for _ in range(1, blocks):
            layers.append(block(self.inplanes, planes))

        return nn.Sequential(*layers)

    def _make_fc(self, inplanes: int, outplanes: int) -> Module:
        bn = nn.BatchNorm2d(inplanes)
        conv = nn.Conv2d(inplanes, outplanes, kernel_size=1)
        return nn.Sequential(conv, bn, self.relu)

    def forward(self, x: Tensor) -> Tensor:
        out = []
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)

        x = self.layer1(x)
        x = self.maxpool(x)
        x = self.layer2(x)
        x = self.layer3(x)

        for i in range(self.num_stacks):
            y = self.hg[i](x)
            y = self.res[i](y)
            y = self.fc[i](y)
            score = self.score[i](y)
            out.append(score)
            if i < self.num_stacks - 1:
                fc_ = self.fc_[i](y)
                score_ = self.score_[i](score)
                x = x + fc_ + score_

        return y


def hg(cfg: HourglassConfig) -> HourglassNet:
    """Build a :class:`HourglassNet` that uses :class:`Bottleneck2D` residual blocks.

    Every field of ``cfg`` (depth, num_stacks, num_blocks, num_classes, input_channels, head) is forwarded as the
    keyword argument of the same name; ``expansion`` keeps its default.
    """
    network_kwargs = cfg._asdict()
    return HourglassNet(Bottleneck2D, **network_kwargs)


# [Backbone decoders]
class SuperpointDecoder(Module):
    """Junction decoder based on the SuperPoint architecture.

    A stride-2 3x3 convolution is followed by a 1x1 convolution that predicts, at every location, one channel per
    pixel of a ``grid_size x grid_size`` cell plus one extra trailing channel. After a softmax over channels, the
    extra channel is dropped and the remaining ones are rearranged into a dense map with ``pixel_shuffle``.

    Args:
        input_feat_dim: channel size of the input features.
        grid_size: side length of the cell predicted by each output location; must be at least 1.

    Returns:
        the junction heatmap, with shape (B, H, W), where H x W is the stride-2 feature size times ``grid_size``.

    Raises:
        ValueError: if ``grid_size`` is smaller than 1.

    """

    def __init__(self, input_feat_dim: int = 128, grid_size: int = 8) -> None:
        super().__init__()
        if grid_size < 1:
            raise ValueError(f"grid_size must be at least 1, got {grid_size}.")
        self.relu = nn.ReLU(inplace=True)
        # Perform strided convolution when using lcnn backbone.
        self.convPa = nn.Conv2d(input_feat_dim, 256, kernel_size=3, stride=2, padding=1)
        # grid_size**2 cell channels plus one extra channel (65 for the default grid of 8), so that
        # ``pixel_shuffle`` below yields exactly one output channel for any grid size.
        self.convPb = nn.Conv2d(256, grid_size * grid_size + 1, kernel_size=1, stride=1, padding=0)
        self.grid_size = grid_size

    def forward(self, input_features: Tensor) -> Tensor:
        feat = self.relu(self.convPa(input_features))
        semi = self.convPb(feat)

        # Convert from semi-dense to dense heatmap
        junc_prob = softmax(semi, dim=1)
        # Drop the extra channel, unfold each cell into a grid_size x grid_size patch, and drop the channel axis.
        junc_pred = pixel_shuffle(junc_prob[:, :-1, :, :], self.grid_size)[:, 0]
        return junc_pred


class PixelShuffleDecoder(Module):
    """Pixel shuffle decoder used to predict the line heatmap.

    Every convolution block except the output one is followed by a 2x ``PixelShuffle``, which divides the number
    of channels by 4. The output is therefore upsampled by ``2 ** num_upsample`` w.r.t. the input features.

    Args:
        input_feat_dim: channel size of the input features.
        num_upsample: how many 2x upsamples are performed; between 1 and 4.
        output_channel: number of output channels before the softmax; must be at least 2 because the
            probability of channel 1 is returned.

    Returns:
        the (B, H, W) line heatmap, i.e. the softmax probability of channel 1.

    Raises:
        ValueError: if ``num_upsample`` is outside ``[1, 4]``.

    """

    def __init__(self, input_feat_dim: int = 128, num_upsample: int = 2, output_channel: int = 2) -> None:
        super().__init__()
        # Get channel parameters
        self.channel_conf = self.get_channel_conf(num_upsample)

        # Define the pixel shuffle
        self.pixshuffle = nn.PixelShuffle(2)

        # Process the feature
        conv_block_lst = []
        # The input block
        conv_block_lst.append(
            nn.Sequential(
                nn.Conv2d(input_feat_dim, self.channel_conf[0], kernel_size=3, stride=1, padding=1),
                nn.BatchNorm2d(self.channel_conf[0]),
                nn.ReLU(inplace=True),
            )
        )

        # Intermediate block: each one sees the channel count left after the previous pixel shuffle.
        for channel in self.channel_conf[1:-1]:
            conv_block_lst.append(
                nn.Sequential(
                    nn.Conv2d(channel, channel, kernel_size=3, stride=1, padding=1),
                    nn.BatchNorm2d(channel),
                    nn.ReLU(inplace=True),
                )
            )

        # Output block
        conv_block_lst.append(
            nn.Sequential(nn.Conv2d(self.channel_conf[-1], output_channel, kernel_size=1, stride=1, padding=0))
        )
        self.conv_block_lst = nn.ModuleList(conv_block_lst)

    def get_channel_conf(self, num_upsample: int) -> List[int]:
        """Get num of channels based on number of upsampling.

        Starts at 256 and divides by 4 (the channel reduction of a 2x pixel shuffle) per upsample, e.g.
        ``[256, 64, 16]`` for 2 upsamples and ``[256, 64, 16, 4]`` for 3.

        Raises:
            ValueError: if ``num_upsample`` is outside ``[1, 4]`` (beyond 4 the channel count would reach 0).
        """
        if not 1 <= num_upsample <= 4:
            raise ValueError(f"num_upsample must be between 1 and 4, got {num_upsample}.")
        return [256 // 4**i for i in range(num_upsample + 1)]

    def forward(self, input_features: Tensor) -> Tensor:
        # Iterate til output block
        out = input_features
        for block in self.conv_block_lst[:-1]:
            out = block(out)
            out = self.pixshuffle(out)

        # Output layer
        out = self.conv_block_lst[-1](out)
        heatmap = softmax(out, dim=1)[:, 1, :, :]

        return heatmap


class SuperpointDescriptor(Module):
    """Descriptor decoder based on the SuperPoint architecture.

    A 3x3 convolution to 256 channels with ReLU, followed by a 1x1 convolution to 128 channels. Both
    convolutions use stride 1 and keep the spatial size of their input.

    Args:
        input_feat_dim: channel size of the input features.

    Returns:
        the semi-dense descriptors with shape (B, 128, H', W'), where H' x W' is the spatial size of the input
        features (roughly H/4 x W/4 of the image when fed by the hourglass backbone).

    """

    def __init__(self, input_feat_dim: int = 128) -> None:
        super().__init__()
        self.relu = nn.ReLU(inplace=True)
        self.convPa = nn.Conv2d(input_feat_dim, 256, kernel_size=3, stride=1, padding=1)
        self.convPb = nn.Conv2d(256, 128, kernel_size=1, stride=1, padding=0)

    def forward(self, input_features: Tensor) -> Tensor:
        hidden = self.convPa(input_features)
        return self.convPb(self.relu(hidden))


# [Combination of all previous models in one]


class SOLD2Net(Module):
    """Full SOLD² network: hourglass backbone plus junction, line-heatmap and optional descriptor decoders.

    Args:
        model_cfg: configuration dict. ``model_cfg["backbone_cfg"]`` is unpacked into :class:`HourglassBackbone`,
            ``model_cfg["grid_size"]`` is the cell size of the junction decoder, and the descriptor decoder is
            added whenever the key ``"use_descriptor"`` is present.

    Returns:
        a Dict with the following values:
            junctions: heatmap of junctions.
            heatmap: line heatmap.
            descriptors: semi-dense descriptors (only when ``"use_descriptor"`` is a key of the config).

    """

    def __init__(self, model_cfg: Dict[str, Any]) -> None:
        super().__init__()
        self.cfg = model_cfg
        # Every decoder consumes the backbone output, which has 256 channels (128 features x expansion 2).
        channels = 256

        self.backbone_net = HourglassBackbone(**self.cfg["backbone_cfg"])
        self.junction_decoder = SuperpointDecoder(channels, self.cfg["grid_size"])
        self.heatmap_decoder = PixelShuffleDecoder(channels, num_upsample=2)

        # NOTE(review): this tests key *presence*, so {"use_descriptor": False} still builds the decoder.
        # Confirm before switching to a value check — existing configs / state dicts may depend on it.
        if "use_descriptor" in self.cfg:
            self.descriptor_decoder = SuperpointDescriptor(channels)

    def forward(self, input_images: Tensor) -> Dict[str, Tensor]:
        """Run the backbone once and feed its features to every decoder."""
        features = self.backbone_net(input_images)
        outputs = {
            "junctions": self.junction_decoder(features),
            "heatmap": self.heatmap_decoder(features),
        }
        if "use_descriptor" in self.cfg:
            outputs["descriptors"] = self.descriptor_decoder(features)
        return outputs
