# LICENSE HEADER MANAGED BY add-license-header
#
# Copyright 2018 Kornia Team
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

from collections import OrderedDict
from itertools import zip_longest
from typing import Any, Dict, Iterator, List, Optional, Tuple

import torch
from torch import nn

import kornia.augmentation as K
from kornia.augmentation.base import _AugmentationBase
from kornia.core import Module, Tensor
from kornia.geometry.boxes import Boxes
from kornia.geometry.keypoints import Keypoints

from .ops import BoxSequentialOps, InputSequentialOps, KeypointSequentialOps, MaskSequentialOps
from .params import ParamItem

# Names exported by a star-import of this module.
# NOTE(review): ``TransformMatrixMinIn`` is defined in this module but is not listed here -
# confirm that the omission is intentional.
__all__ = ["BasicSequentialBase", "ImageSequentialBase", "SequentialBase"]


class BasicSequentialBase(nn.Sequential):
    r"""BasicSequential for creating kornia modulized processing pipeline.

    Every positional module is registered under the name ``"<ClassName>_<index>"``, where
    ``index`` is the position of the module in ``args``.

    Args:
        *args : a list of kornia augmentation and image operation modules.

    Raises:
        NotImplementedError: if any positional argument is not a ``Module``.

    """

    def __init__(self, *args: Module) -> None:
        # Name each module ``<ClassName>_<index>`` so that names stay unique even when the same
        # module class is passed more than once.
        named_modules = OrderedDict()
        for idx, mod in enumerate(args):
            if not isinstance(mod, Module):
                raise NotImplementedError(f"Only Module are supported at this moment. Got {mod}.")
            named_modules[f"{mod.__class__.__name__}_{idx}"] = mod
        super().__init__(named_modules)
        # Most recent list of parameters; starts as ``None`` and is reset by ``clear_state``.
        self._params: Optional[List[ParamItem]] = None

    def get_submodule(self, target: str) -> Module:
        """Get submodule.

        This code is taken from torch 1.9.0 since it is not introduced
        back to torch 1.7.1. We included this for maintaining more
        backward torch versions.

        Args:
            target: The fully-qualified string name of the submodule
                to look for, i.e. attribute names joined by ``"."``.
                An empty string refers to this module itself.

        Returns:
            Module: The submodule referenced by ``target``

        Raises:
            AttributeError: If the target string references an invalid
                path or resolves to something that is not an
                ``Module``

        """
        if not target:
            return self

        mod = self
        # Walk the dotted path one attribute at a time, validating every hop.
        for item in target.split("."):
            if not hasattr(mod, item):
                raise AttributeError(f"{mod._get_name()} has no attribute `{item}`")

            mod = getattr(mod, item)

            if not isinstance(mod, Module):
                raise AttributeError(f"`{item}` is not an Module")

        return mod

    def clear_state(self) -> None:
        """Reset self._params state to None."""
        self._params = None

    # TODO: Implement this for all submodules.
    def forward_parameters(self, batch_shape: torch.Size) -> List[ParamItem]:
        """Produce the parameters for a forward pass; must be implemented by subclasses."""
        raise NotImplementedError

    def get_children_by_indices(self, indices: Tensor) -> Iterator[Tuple[str, Module]]:
        """Yield ``(name, module)`` pairs of the direct children selected by ``indices``.

        Args:
            indices: positions into the list of named children, in the order they should be yielded.
                Each element is used directly as a list index.

        """
        modules = list(self.named_children())
        for idx in indices:
            yield modules[idx]

    def get_children_by_params(self, params: List[ParamItem]) -> Iterator[Tuple[str, Module]]:
        """Yield ``(name, module)`` pairs of the direct children referenced by ``params``.

        Args:
            params: items whose ``name`` attribute identifies a direct child of this container.

        Raises:
            ValueError: if a ``name`` does not match any direct child (raised while iterating).

        """
        # Build the name lookup once instead of re-listing the children for every param.
        children = dict(self.named_children())
        # TODO: Wrong params passed here when nested ImageSequential
        for param in params:
            if param.name not in children:
                raise ValueError(f"No child module named `{param.name}` in {self._get_name()}.")
            yield param.name, children[param.name]

    def get_params_by_module(self, named_modules: Iterator[Tuple[str, Module]]) -> Iterator[ParamItem]:
        """Yield a ``ParamItem(name, None)`` for every ``(name, module)`` pair given."""
        # This will not take module._params
        for name, _ in named_modules:
            yield ParamItem(name, None)


class SequentialBase(BasicSequentialBase):
    r"""Sequential container that pushes shared settings down to its child modules.

    Args:
        *args : a list of kornia augmentation and image operation modules.
        same_on_batch: apply the same transformation across the batch.
            If None, it will not overwrite the function-wise settings.
        keepdim: whether to keep the output shape the same as input (True) or broadcast it
            to the batch form (False). If None, it will not overwrite the function-wise settings.

    """

    def __init__(self, *args: Module, same_on_batch: Optional[bool] = None, keepdim: Optional[bool] = None) -> None:
        super().__init__(*args)
        # Remember the container-level values, then propagate them to the children.
        self._same_on_batch = same_on_batch
        self._keepdim = keepdim
        self.update_attribute(same_on_batch, keepdim=keepdim)

    def update_attribute(
        self,
        same_on_batch: Optional[bool] = None,
        return_transform: Optional[bool] = None,
        keepdim: Optional[bool] = None,
    ) -> None:
        """Overwrite ``same_on_batch`` / ``keepdim`` on child augmentations.

        Values that are ``None`` leave the corresponding child attribute untouched. Children that are
        themselves ``SequentialBase`` instances receive the same call, so the update is applied
        recursively.

        Args:
            same_on_batch: new value for the children's ``same_on_batch``, or None to keep it.
            return_transform: not applied to any child here; only forwarded to nested containers.
            keepdim: new value for the children's ``keepdim``, or None to keep it.

        """
        for child in self.children():
            # MixAugmentation does not have return transform
            is_augmentation = isinstance(child, (_AugmentationBase, K.MixAugmentationBaseV2))
            if is_augmentation and same_on_batch is not None:
                child.same_on_batch = same_on_batch
            if is_augmentation and keepdim is not None:
                child.keepdim = keepdim
            if isinstance(child, SequentialBase):
                child.update_attribute(same_on_batch, return_transform, keepdim)

    @property
    def same_on_batch(self) -> Optional[bool]:
        """Container-level ``same_on_batch`` value; assigning it also updates the children."""
        return self._same_on_batch

    @same_on_batch.setter
    def same_on_batch(self, same_on_batch: Optional[bool]) -> None:
        self._same_on_batch = same_on_batch
        self.update_attribute(same_on_batch=same_on_batch)

    @property
    def keepdim(self) -> Optional[bool]:
        """Container-level ``keepdim`` value; assigning it also updates the children."""
        return self._keepdim

    @keepdim.setter
    def keepdim(self, keepdim: Optional[bool]) -> None:
        self._keepdim = keepdim
        self.update_attribute(keepdim=keepdim)

    def autofill_dim(self, input: Tensor, dim_range: Tuple[int, int] = (2, 4)) -> Tuple[torch.Size, torch.Size]:
        """Compute the shape of ``input`` once padded to the upper bound of ``dim_range``.

        Leading dimensions are prepended until the number of dimensions equals ``dim_range[1]``.
        The caller's tensor is not modified; only the shapes are returned.

        Args:
            input: the tensor whose shape is inspected.
            dim_range: inclusive ``(min, max)`` number of dimensions accepted.

        Returns:
            A tuple ``(original_shape, filled_shape)``.

        Raises:
            RuntimeError: if the number of dimensions of ``input`` lies outside ``dim_range``.

        """
        ori_shape = input.shape
        if not (dim_range[0] <= len(ori_shape) <= dim_range[1]):
            raise RuntimeError(f"input shape expected to be in {dim_range} while got {ori_shape}.")
        # One leading axis is added per missing dimension.
        for _ in range(dim_range[1] - len(ori_shape)):
            input = input[None]
        return ori_shape, input.shape


class ImageSequentialBase(SequentialBase):
    """Base container that applies, and inverts, a parameterised sequence of modules.

    The ``transform_*`` methods walk ``params`` in order and look each module up by ``param.name``.
    The ``inverse_*`` methods walk the forward sequence and ``params`` in reverse order. The actual
    per-module work is delegated to the matching ``*SequentialOps`` class.
    """

    def identity_matrix(self, input: Tensor) -> Tensor:
        """Return identity matrix."""
        raise NotImplementedError

    def get_transformation_matrix(
        self,
        input: Tensor,
        params: Optional[List[ParamItem]] = None,
        recompute: bool = False,
        extra_args: Optional[Dict[str, Any]] = None,
    ) -> Optional[Tensor]:
        """Compute the transformation matrix according to the provided parameters.

        Args:
            input: the input tensor.
            params: params for the sequence.
            recompute: if to recompute the transformation matrix according to the params.
                default: False.
            extra_args: Optional dictionary of extra arguments with specific options for different input types.
        """
        raise NotImplementedError

    def forward_parameters(self, batch_shape: torch.Size) -> List[ParamItem]:
        """Produce the parameters for a forward pass; must be implemented by subclasses."""
        raise NotImplementedError

    def get_forward_sequence(self, params: Optional[List[ParamItem]] = None) -> Iterator[Tuple[str, Module]]:
        """Get module sequence by input params."""
        raise NotImplementedError

    def transform_inputs(
        self, input: Tensor, params: List[ParamItem], extra_args: Optional[Dict[str, Any]] = None
    ) -> Tensor:
        """Apply ``InputSequentialOps.transform`` for every item of ``params``, in order."""
        output = input
        for item in params:
            output = InputSequentialOps.transform(
                output, module=self.get_submodule(item.name), param=item, extra_args=extra_args
            )
        return output

    def inverse_inputs(
        self, input: Tensor, params: List[ParamItem], extra_args: Optional[Dict[str, Any]] = None
    ) -> Tensor:
        """Apply ``InputSequentialOps.inverse`` over the forward sequence and ``params``, both reversed."""
        forward_sequence = list(self.get_forward_sequence(params))
        output = input
        for (_, module), item in zip_longest(reversed(forward_sequence), params[::-1]):
            output = InputSequentialOps.inverse(output, module=module, param=item, extra_args=extra_args)
        return output

    def transform_masks(
        self, input: Tensor, params: List[ParamItem], extra_args: Optional[Dict[str, Any]] = None
    ) -> Tensor:
        """Apply ``MaskSequentialOps.transform`` for every item of ``params``, in order."""
        output = input
        for item in params:
            output = MaskSequentialOps.transform(
                output, module=self.get_submodule(item.name), param=item, extra_args=extra_args
            )
        return output

    def inverse_masks(
        self, input: Tensor, params: List[ParamItem], extra_args: Optional[Dict[str, Any]] = None
    ) -> Tensor:
        """Apply ``MaskSequentialOps.inverse`` over the forward sequence and ``params``, both reversed."""
        forward_sequence = list(self.get_forward_sequence(params))
        output = input
        for (_, module), item in zip_longest(reversed(forward_sequence), params[::-1]):
            output = MaskSequentialOps.inverse(output, module=module, param=item, extra_args=extra_args)
        return output

    def transform_boxes(
        self, input: Boxes, params: List[ParamItem], extra_args: Optional[Dict[str, Any]] = None
    ) -> Boxes:
        """Apply ``BoxSequentialOps.transform`` for every item of ``params``, in order."""
        output = input
        for item in params:
            output = BoxSequentialOps.transform(
                output, module=self.get_submodule(item.name), param=item, extra_args=extra_args
            )
        return output

    def inverse_boxes(
        self, input: Boxes, params: List[ParamItem], extra_args: Optional[Dict[str, Any]] = None
    ) -> Boxes:
        """Apply ``BoxSequentialOps.inverse`` over the forward sequence and ``params``, both reversed."""
        forward_sequence = list(self.get_forward_sequence(params))
        output = input
        for (_, module), item in zip_longest(reversed(forward_sequence), params[::-1]):
            output = BoxSequentialOps.inverse(output, module=module, param=item, extra_args=extra_args)
        return output

    def transform_keypoints(
        self, input: Keypoints, params: List[ParamItem], extra_args: Optional[Dict[str, Any]] = None
    ) -> Keypoints:
        """Apply ``KeypointSequentialOps.transform`` for every item of ``params``, in order."""
        output = input
        for item in params:
            output = KeypointSequentialOps.transform(
                output, module=self.get_submodule(item.name), param=item, extra_args=extra_args
            )
        return output

    def inverse_keypoints(
        self, input: Keypoints, params: List[ParamItem], extra_args: Optional[Dict[str, Any]] = None
    ) -> Keypoints:
        """Apply ``KeypointSequentialOps.inverse`` over the forward sequence and ``params``, both reversed."""
        forward_sequence = list(self.get_forward_sequence(params))
        output = input
        for (_, module), item in zip_longest(reversed(forward_sequence), params[::-1]):
            output = KeypointSequentialOps.inverse(output, module=module, param=item, extra_args=extra_args)
        return output

    def inverse(
        self, input: Tensor, params: Optional[List[ParamItem]] = None, extra_args: Optional[Dict[str, Any]] = None
    ) -> Tensor:
        """Inverse transformation.

        Used to inverse a tensor according to the performed transformation by a forward pass, or with respect to
        provided parameters.

        Raises:
            ValueError: if ``params`` is None and no parameters were stored by a previous forward pass.
        """
        if params is None:
            # Fall back to the parameters stored by the last forward pass.
            params = self._params
            if params is None:
                raise ValueError(
                    "No parameters available for inversing, please run a forward pass first "
                    "or passing valid params into this function."
                )

        return self.inverse_inputs(input, params, extra_args=extra_args)

    def forward(
        self, input: Tensor, params: Optional[List[ParamItem]] = None, extra_args: Optional[Dict[str, Any]] = None
    ) -> Tensor:
        """Transform ``input`` and store the parameters used.

        When ``params`` is None they are generated with ``forward_parameters`` from the shape of
        ``input`` padded to four dimensions. The parameters actually used are stored on
        ``self._params`` after the transform.
        """
        self.clear_state()

        if params is None:
            _, filled_shape = self.autofill_dim(input, dim_range=(2, 4))
            params = self.forward_parameters(filled_shape)

        output = self.transform_inputs(input, params=params, extra_args=extra_args)

        self._params = params
        return output


class TransformMatrixMinIn:
    """Mix-in that accumulates a composed transformation matrix across modules.

    Matrices are composed by left-multiplication: each newly supplied matrix ``M`` updates the
    running result ``T`` to ``M @ T``.

    The behaviour for modules that are not instances of ``_valid_ops_for_transform_computation``
    depends on the configured mode:

    - ``"silent"`` / ``"silence"``: the module is ignored.
    - ``"rigid"``: a ``RuntimeError`` is raised.
    - ``"skip"``: no module updates the matrix at all, valid or not.
    """

    # Module types handed to ``_update_transform_matrix_for_valid_op``; empty by default.
    _valid_ops_for_transform_computation: Tuple[Any, ...] = ()
    # Active mode; see ``_parse_transformation_matrix_mode`` for the accepted values.
    _transformation_matrix_arg: str = "silent"

    def __init__(self, *args, **kwargs) -> None:  # type:ignore
        super().__init__(*args, **kwargs)
        # Composed matrix, computed lazily by the ``transform_matrix`` property.
        self._transform_matrix: Optional[Tensor] = None
        # Individual matrices waiting to be composed; entries may be None.
        self._transform_matrices: List[Optional[Tensor]] = []

    def _parse_transformation_matrix_mode(self, transformation_matrix_mode: str) -> None:
        """Validate and store the transformation matrix mode.

        Raises:
            ValueError: if the mode is not one of ``"silence"``, ``"silent"``, ``"rigid"``, ``"skip"``.
        """
        _valid_transformation_matrix_args = {"silence", "silent", "rigid", "skip"}
        if transformation_matrix_mode not in _valid_transformation_matrix_args:
            raise ValueError(
                f"`transformation_matrix` has to be one of {_valid_transformation_matrix_args}. "
                f"Got {transformation_matrix_mode}."
            )
        self._transformation_matrix_arg = transformation_matrix_mode

    @property
    def transform_matrix(self) -> Optional[Tensor]:
        """Return the composed matrix, or None if no matrix has been supplied."""
        # In AugmentationSequential, the parent class is accessed first.
        # So that it was None in the beginning. We hereby use lazy computation here.
        if self._transform_matrix is None:
            for mat in self._transform_matrices:
                self._update_transform_matrix(mat)
        return self._transform_matrix

    def _update_transform_matrix_for_valid_op(self, module: Module) -> None:
        """Update the matrix state from a supported module; must be implemented by subclasses."""
        raise NotImplementedError(module)

    def _update_transform_matrix_by_module(self, module: Module) -> None:
        """Update the matrix state for ``module`` according to the configured mode.

        Raises:
            RuntimeError: in ``"rigid"`` mode, if ``module`` is not a supported op.
        """
        if self._transformation_matrix_arg == "skip":
            return
        if isinstance(module, self._valid_ops_for_transform_computation):
            self._update_transform_matrix_for_valid_op(module)
        elif self._transformation_matrix_arg == "rigid":
            raise RuntimeError(
                f"Non-rigid module `{module}` is not supported under `rigid` computation mode. "
                "Please either update the module or change the `transformation_matrix` argument."
            )

    def _update_transform_matrix(self, transform_matrix: Optional[Tensor]) -> None:
        """Left-multiply ``transform_matrix`` onto the running matrix.

        A ``None`` argument leaves the running matrix unchanged.
        """
        if transform_matrix is None:
            # Without this guard, a None arriving after a tensor would evaluate ``None @ tensor``.
            return
        if self._transform_matrix is None:
            self._transform_matrix = transform_matrix
        else:
            self._transform_matrix = transform_matrix @ self._transform_matrix

    def _reset_transform_matrix_state(self) -> None:
        """Forget the composed matrix and all pending matrices."""
        self._transform_matrix = None
        self._transform_matrices = []
