# LICENSE HEADER MANAGED BY add-license-header
#
# Copyright 2018 Kornia Team
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

from typing import Any, Dict, Tuple

import torch
from torch import nn

from kornia.core import Module, Tensor


def conv_1x1_bn(inp: int, oup: int) -> Module:
    """Build a pointwise (1x1) convolution followed by batch normalization and SiLU.

    Args:
        inp: number of input channels.
        oup: number of output channels.

    Returns:
        an ``nn.Sequential`` made of a bias-free ``Conv2d``, a ``BatchNorm2d`` and a ``SiLU``.

    """
    # No bias: the following BatchNorm already provides a learnable shift.
    pointwise = nn.Conv2d(inp, oup, kernel_size=1, stride=1, padding=0, bias=False)
    norm = nn.BatchNorm2d(oup)
    return nn.Sequential(pointwise, norm, nn.SiLU())


def conv_nxn_bn(inp: int, oup: int, kernal_size: int = 3, stride: int = 1) -> Module:
    """Apply NxN Convolution with Batch Norm.

    Builds ``Conv2d -> BatchNorm2d -> SiLU``. The convolution is padded with ``kernal_size // 2``. For odd kernel
    sizes this preserves the spatial size when ``stride=1``; with ``stride=2`` the size is halved, rounding up.

    Args:
        inp: number of input channels.
        oup: number of output channels.
        kernal_size: convolution kernel size, defaults to 3 (parameter name kept as-is for backward compatibility).
        stride: convolution stride, defaults to 1.

    Returns:
        the ``nn.Sequential`` block.

    """
    # Padding was previously hard-coded to 1, which only gives "same" padding for 3x3 kernels.
    # Derive it from the kernel size so other odd kernel sizes keep the spatial size as well.
    padding = kernal_size // 2
    return nn.Sequential(
        nn.Conv2d(inp, oup, kernal_size, stride, padding, bias=False), nn.BatchNorm2d(oup), nn.SiLU()
    )


class PreNorm(Module):
    """Apply ``LayerNorm`` to the input and then call the wrapped module on the normalized tensor.

    Args:
        dim: size of the last axis that ``LayerNorm`` normalizes over.
        fn: module applied to the normalized input.

    """

    def __init__(self, dim: int, fn: Module) -> None:
        super().__init__()
        self.norm = nn.LayerNorm(dim)
        self.fn = fn

    # ``**kwargs: Any`` annotates each keyword *value*; the previous ``Dict[str, Any]`` wrongly claimed every
    # forwarded keyword argument had to be a dict.
    def forward(self, x: Tensor, **kwargs: Any) -> Tensor:
        return self.fn(self.norm(x), **kwargs)


class FeedForward(Module):
    """Two-layer MLP with SiLU activation and dropout, as used inside each Transformer layer.

    Args:
        dim: input and output feature size.
        hidden_dim: size of the hidden layer.
        dropout: dropout ratio applied after the activation and after the output projection.

    """

    def __init__(self, dim: int, hidden_dim: int, dropout: float = 0.0) -> None:
        super().__init__()
        # Create the projections in the same order as before so parameter initialization is unchanged.
        expand = nn.Linear(dim, hidden_dim)
        project = nn.Linear(hidden_dim, dim)
        self.net = nn.Sequential(expand, nn.SiLU(), nn.Dropout(dropout), project, nn.Dropout(dropout))

    def forward(self, x: Tensor) -> Tensor:
        out = self.net(x)
        return out


class Attention(Module):
    """Multi-head self-attention used inside the Transformer layers.

    Args:
        dim: token feature size (size of the last input axis).
        heads: number of attention heads.
        dim_head: size of each attention head.
        dropout: dropout ratio applied after the output projection.

    """

    def __init__(self, dim: int, heads: int = 8, dim_head: int = 64, dropout: float = 0.0) -> None:
        super().__init__()
        inner_dim = dim_head * heads
        # With a single head whose size equals ``dim`` the attention output already has ``dim`` features,
        # so the output projection is skipped.
        project_out = not (heads == 1 and dim_head == dim)

        self.heads = heads
        self.scale = dim_head**-0.5

        self.attend = nn.Softmax(dim=-1)
        self.to_qkv = nn.Linear(dim, inner_dim * 3, bias=False)

        self.to_out = nn.Sequential(nn.Linear(inner_dim, dim), nn.Dropout(dropout)) if project_out else nn.Identity()

    def forward(self, x: Tensor) -> Tensor:
        """Attend over the second-to-last axis of ``x``.

        Args:
            x: tokens of shape :math:`(*, N, dim)` with any number of leading batch axes, e.g.
                :math:`(B, P, N, dim)` as produced by ``MobileViTBlock`` or the ViT-style :math:`(B, N, dim)`.

        Returns:
            tensor with the same shape as ``x``.

        """
        # (*, N, 3 * inner_dim) -> three tensors of (*, heads, N, dim_head)
        q, k, v = (
            t.reshape(*t.shape[:-1], self.heads, -1).transpose(-3, -2) for t in self.to_qkv(x).chunk(3, dim=-1)
        )

        dots = torch.matmul(q, k.transpose(-1, -2)) * self.scale
        attn = self.attend(dots)
        out = torch.matmul(attn, v)
        # (*, heads, N, dim_head) -> (*, N, heads * dim_head)
        out = out.transpose(-3, -2).flatten(-2)
        return self.to_out(out)


class Transformer(Module):
    """Transformer block described in ViT.

    Each of the ``depth`` layers applies a pre-normalized multi-head attention and a pre-normalized
    feed-forward network, each wrapped in a residual connection.

    Paper: https://arxiv.org/abs/2010.11929
    Based on: https://github.com/lucidrains/vit-pytorch

    Args:
        dim: input dimension.
        depth: depth for transformer block.
        heads: number of heads in multi-head attention layer.
        dim_head: head size.
        mlp_dim: dimension of the FeedForward layer.
        dropout: dropout ratio, defaults to 0.

    """

    def __init__(self, dim: int, depth: int, heads: int, dim_head: int, mlp_dim: int, dropout: float = 0.0) -> None:
        super().__init__()
        # One [attention, feed-forward] pair per layer, created layer by layer.
        self.layers = nn.ModuleList(
            [
                nn.ModuleList(
                    [
                        PreNorm(dim, Attention(dim, heads, dim_head, dropout)),
                        PreNorm(dim, FeedForward(dim, mlp_dim, dropout)),
                    ]
                )
                for _ in range(depth)
            ]
        )

    def forward(self, x: Tensor) -> Tensor:
        for layer in self.layers:
            attention_block, feed_forward_block = layer
            x = x + attention_block(x)
            x = x + feed_forward_block(x)
        return x


class MV2Block(Module):
    """MV2 block described in MobileNetV2.

    An inverted residual block: an optional 1x1 expansion, a 3x3 depthwise convolution and a linear 1x1
    projection. A skip connection is added when the stride is 1 and the channel count is unchanged.

    Paper: https://arxiv.org/pdf/1801.04381
    Based on: https://github.com/tonylins/pytorch-mobilenet-v2

    Args:
        inp: input channel.
        oup: output channel.
        stride: stride for convolution, defaults to 1, set to 2 if down-sample.
        expansion: expansion ratio for hidden dimension, defaults to 4.

    """

    def __init__(self, inp: int, oup: int, stride: int = 1, expansion: int = 4) -> None:
        super().__init__()
        self.stride = stride

        hidden_dim = int(inp * expansion)
        self.use_res_connect = stride == 1 and inp == oup

        layers = []
        if expansion != 1:
            # pointwise expansion (skipped when the expansion ratio is 1)
            layers += [nn.Conv2d(inp, hidden_dim, 1, 1, 0, bias=False), nn.BatchNorm2d(hidden_dim), nn.SiLU()]
        layers += [
            # depthwise
            nn.Conv2d(hidden_dim, hidden_dim, 3, stride, 1, groups=hidden_dim, bias=False),
            nn.BatchNorm2d(hidden_dim),
            nn.SiLU(),
            # pointwise projection, no activation
            nn.Conv2d(hidden_dim, oup, 1, 1, 0, bias=False),
            nn.BatchNorm2d(oup),
        ]
        self.conv = nn.Sequential(*layers)

    def forward(self, x: Tensor) -> Tensor:
        out = self.conv(x)
        return x + out if self.use_res_connect else out


class MobileViTBlock(Module):
    """MobileViT block mentioned in MobileViT.

    Local features are extracted with convolutions and split into non-overlapping ``patch_size`` patches.
    A Transformer then attends across patches at each within-patch position. The result is folded back to a
    feature map and fused with the block input.

    Args:
        dim: input dimension of Transformer.
        depth: depth of Transformer.
        channel: input channel.
        kernel_size: kernel size.
        patch_size: patch size ``(ph, pw)`` for folding and unfolding. The height and width of the feature map
            must be divisible by it.
        mlp_dim: dimension of the FeedForward layer in Transformer.
        dropout: dropout ratio, defaults to 0.
        heads: number of attention heads in Transformer, defaults to 4.
        dim_head: size of each attention head in Transformer, defaults to 8.

    """

    def __init__(
        self,
        dim: int,
        depth: int,
        channel: int,
        kernel_size: int,
        patch_size: Tuple[int, int],
        mlp_dim: int,
        dropout: float = 0.0,
        heads: int = 4,
        dim_head: int = 8,
    ) -> None:
        super().__init__()
        self.ph, self.pw = patch_size

        self.conv1 = conv_nxn_bn(channel, channel, kernel_size)
        self.conv2 = conv_1x1_bn(channel, dim)

        self.transformer = Transformer(dim, depth, heads, dim_head, mlp_dim, dropout)

        self.conv3 = conv_1x1_bn(dim, channel)
        self.conv4 = conv_nxn_bn(2 * channel, channel, kernel_size)

    def forward(self, x: Tensor) -> Tensor:
        # Keep the block input for the fusion step. The layers below never modify their input in place,
        # so holding a reference is enough and no copy is needed.
        y = x

        # Local representations
        x = self.conv1(x)
        x = self.conv2(x)

        b, d, h, w = x.shape
        if h % self.ph != 0 or w % self.pw != 0:
            raise RuntimeError(
                f"MobileViTBlock expects the feature map size ({h}, {w}) to be divisible by the patch size "
                f"({self.ph}, {self.pw})."
            )
        nh, nw = h // self.ph, w // self.pw

        # Global representations
        # [b, d, h, w] -> [b * d * nh, nw, ph, pw]
        x = x.reshape(b * d * nh, self.ph, nw, self.pw).transpose(1, 2)
        # [b * d * nh, nw, ph, pw] -> [b, (ph pw), (nh nw), d]
        x = x.reshape(b, d, nh * nw, self.ph * self.pw).transpose(1, 3)

        x = self.transformer(x)

        # [b, (ph pw), (nh nw), d] -> [b * d * nh, nw, ph, pw]
        x = x.transpose(1, 3).reshape(b * d * nh, nw, self.ph, self.pw)
        # [b * d * nh, nw, ph, pw] -> [b, d, h, w]
        x = x.transpose(1, 2).reshape(b, d, h, w)

        # Fusion
        x = self.conv3(x)
        x = torch.cat((x, y), 1)
        x = self.conv4(x)
        return x


class MobileViT(Module):
    """Module MobileViT. Default arguments are for MobileViT XXS.

    Paper: https://arxiv.org/abs/2110.02178
    Based on: https://github.com/chinhsuanwu/mobilevit-pytorch

    Args:
        mode: 'xxs', 'xs' or 's', defaults to 'xxs'.
        in_channels: the number of channels for the input image.
        patch_size: patch size used by every MobileViT block. The blocks run on feature maps downsampled by
            factors of 8, 16 and 32, so those feature maps' height and width must be divisible by patch_size.
            For example, an image size divisible by ``32 * patch_size`` satisfies this.
        dropout: dropout ratio in Transformer.

    Raises:
        ValueError: if ``mode`` is not one of 'xxs', 'xs' or 's'.

    Example:
        >>> img = torch.rand(1, 3, 256, 256)
        >>> mvit = MobileViT(mode='xxs')
        >>> mvit(img).shape
        torch.Size([1, 320, 8, 8])

    """

    def __init__(
        self, mode: str = "xxs", in_channels: int = 3, patch_size: Tuple[int, int] = (2, 2), dropout: float = 0.0
    ) -> None:
        super().__init__()
        # mode -> (MV2 expansion ratio, Transformer dims of the three MobileViT blocks, channel schedule)
        configs = {
            "xxs": (2, [64, 80, 96], [16, 16, 24, 24, 48, 48, 64, 64, 80, 80, 320]),
            "xs": (4, [96, 120, 144], [16, 32, 48, 48, 64, 64, 80, 80, 96, 96, 384]),
            "s": (4, [144, 192, 240], [16, 32, 64, 64, 96, 96, 128, 128, 160, 160, 640]),
        }
        if mode not in configs:
            # Previously an unknown mode surfaced as an UnboundLocalError further down.
            raise ValueError(f"mode must be one of 'xxs', 'xs' or 's', but got {mode!r}.")
        expansion, dims, channels = configs[mode]

        kernel_size = 3
        depth = [2, 4, 3]

        self.conv1 = conv_nxn_bn(in_channels, channels[0], stride=2)

        self.mv2 = nn.ModuleList([])
        self.mv2.append(MV2Block(channels[0], channels[1], 1, expansion))
        self.mv2.append(MV2Block(channels[1], channels[2], 2, expansion))
        self.mv2.append(MV2Block(channels[2], channels[3], 1, expansion))
        # Repeat: its input is the output of the previous block (channels[3]); channels[2] == channels[3] in
        # every mode, so this matches the previous definition.
        self.mv2.append(MV2Block(channels[3], channels[3], 1, expansion))
        self.mv2.append(MV2Block(channels[3], channels[4], 2, expansion))
        self.mv2.append(MV2Block(channels[5], channels[6], 2, expansion))
        self.mv2.append(MV2Block(channels[7], channels[8], 2, expansion))

        self.mvit = nn.ModuleList([])
        self.mvit.append(
            MobileViTBlock(dims[0], depth[0], channels[5], kernel_size, patch_size, int(dims[0] * 2), dropout=dropout)
        )
        self.mvit.append(
            MobileViTBlock(dims[1], depth[1], channels[7], kernel_size, patch_size, int(dims[1] * 4), dropout=dropout)
        )
        self.mvit.append(
            MobileViTBlock(dims[2], depth[2], channels[9], kernel_size, patch_size, int(dims[2] * 4), dropout=dropout)
        )

        self.conv2 = conv_1x1_bn(channels[-2], channels[-1])

    def forward(self, x: Tensor) -> Tensor:
        x = self.conv1(x)  # stride 2
        x = self.mv2[0](x)

        x = self.mv2[1](x)  # stride 2
        x = self.mv2[2](x)
        x = self.mv2[3](x)  # Repeat

        x = self.mv2[4](x)  # stride 2 -> MobileViT block at 1/8 resolution
        x = self.mvit[0](x)

        x = self.mv2[5](x)  # stride 2 -> MobileViT block at 1/16 resolution
        x = self.mvit[1](x)

        x = self.mv2[6](x)  # stride 2 -> MobileViT block at 1/32 resolution
        x = self.mvit[2](x)
        x = self.conv2(x)
        return x
