# LICENSE HEADER MANAGED BY add-license-header
#
# Copyright 2018 Kornia Team
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

import copy
from abc import ABCMeta, abstractmethod
from typing import Any, Callable, Dict, Generic, List, Optional, Type, TypeVar, Union

from typing_extensions import ParamSpec

import kornia.augmentation as K
from kornia.augmentation.base import _AugmentationBase
from kornia.constants import DataKey
from kornia.core import Module, Tensor
from kornia.geometry.boxes import Boxes
from kornia.geometry.keypoints import Keypoints

from .params import ParamItem

# A single piece of data handled by the sequential ops: a tensor, a list of tensors,
# or a Boxes / Keypoints container.
DataType = Union[Tensor, List[Tensor], Boxes, Keypoints]

# NOTE: shouldn't this SequenceDataType alias be equal to List[DataType]?
SequenceDataType = Union[List[Tensor], List[List[Tensor]], List[Boxes], List[Keypoints]]

# Type of the data handled by a concrete SequentialOpsInterface subclass.
T = TypeVar("T")


class SequentialOpsInterface(Generic[T], metaclass=ABCMeta):
    """Abstract interface for applying and inversing transformations on data of type ``T``.

    Subclasses implement :meth:`transform` and :meth:`inverse`. The two ``get_*_param`` helpers validate and
    unwrap the payload stored in a ``ParamItem``.
    """

    @classmethod
    def get_instance_module_param(cls, param: ParamItem) -> Dict[str, Tensor]:
        """Return ``param.data`` after checking that it is a dictionary.

        Args:
            param: the parameter item of a single (non-sequential) module.

        Raises:
            TypeError: if ``param`` is not a ``ParamItem`` or its ``data`` is not a ``dict``.
        """
        if not (isinstance(param, ParamItem) and isinstance(param.data, dict)):
            raise TypeError(f"Expected param (ParamItem.data) be a dictionary. Gotcha {param}.")
        return param.data

    @classmethod
    def get_sequential_module_param(cls, param: ParamItem) -> List[ParamItem]:
        """Return ``param.data`` after checking that it is a list.

        Args:
            param: the parameter item of a sequential module.

        Raises:
            TypeError: if ``param`` is not a ``ParamItem`` or its ``data`` is not a ``list``.
        """
        if not (isinstance(param, ParamItem) and isinstance(param.data, list)):
            raise TypeError(f"Expected param (ParamItem.data) be a list. Gotcha {param}.")
        return param.data

    @classmethod
    @abstractmethod
    def transform(cls, input: T, module: Module, param: ParamItem, extra_args: Optional[Dict[str, Any]] = None) -> T:
        """Apply ``module`` to ``input`` using the given parameters.

        Args:
            input: the data to transform.
            module: any torch Module but only kornia augmentation modules will count
                to apply transformations.
            param: the parameters corresponding to ``module``.
            extra_args: Optional dictionary of extra arguments with specific options for different input types.

        Returns:
            The transformed data, of the same type as ``input``.
        """
        raise NotImplementedError

    @classmethod
    @abstractmethod
    def inverse(cls, input: T, module: Module, param: ParamItem, extra_args: Optional[Dict[str, Any]] = None) -> T:
        """Undo the transformation of ``module`` on ``input`` using the given parameters.

        Args:
            input: the data to inverse.
            module: any torch Module but only kornia augmentation modules will count
                to apply transformations.
            param: the parameters corresponding to ``module``.
            extra_args: Optional dictionary of extra arguments with specific options for different input types.

        Returns:
            The inversed data, of the same type as ``input``.
        """
        raise NotImplementedError


class AugmentationSequentialOps:
    """Route each input to the operation class that matches its data key.

    Args:
        data_keys: default data keys, used whenever a call does not supply its own.
    """

    def __init__(self, data_keys: Optional[List[DataKey]]) -> None:
        self._data_keys = data_keys

    @property
    def data_keys(self) -> Optional[List[DataKey]]:
        """Default data keys, or ``None`` when none are configured."""
        return self._data_keys

    @data_keys.setter
    def data_keys(self, data_keys: Optional[Union[List[DataKey], List[str], List[int]]]) -> None:
        # Both ``None`` and an empty list are stored as ``None``.
        self._data_keys = [DataKey.get(key) for key in data_keys] if data_keys else None

    def preproc_datakeys(self, data_keys: Optional[Union[List[str], List[int], List[DataKey]]] = None) -> List[DataKey]:
        """Resolve the data keys to use for a call.

        Args:
            data_keys: keys given for this call; when ``None`` the instance defaults are used.

        Returns:
            The given keys converted with ``DataKey.get``, or the stored default keys.

        Raises:
            ValueError: if no keys are given and the stored defaults are not a list.
        """
        if data_keys is not None:
            return [DataKey.get(key) for key in data_keys]
        if isinstance(self.data_keys, list):
            return self.data_keys
        raise ValueError("Sequential ops needs data keys to be able to process.")

    def _get_op(self, data_key: DataKey) -> Type[SequentialOpsInterface[Any]]:
        """Return the corresponding operation given a data key.

        Raises:
            RuntimeError: if no operation is registered for ``data_key``.
        """
        # Checked in order; the first group containing the key wins.
        dispatch = (
            ((DataKey.INPUT,), InputSequentialOps),
            ((DataKey.MASK,), MaskSequentialOps),
            ((DataKey.BBOX, DataKey.BBOX_XYWH, DataKey.BBOX_XYXY), BoxSequentialOps),
            ((DataKey.KEYPOINTS,), KeypointSequentialOps),
            ((DataKey.CLASS,), ClassSequentialOps),
        )
        for accepted_keys, op_cls in dispatch:
            if data_key in accepted_keys:
                return op_cls
        raise RuntimeError(f"Operation for `{data_key.name}` is not found.")

    def transform(
        self,
        *arg: DataType,
        module: Module,
        param: ParamItem,
        extra_args: Dict[DataKey, Dict[str, Any]],
        data_keys: Optional[Union[List[str], List[int], List[DataKey]]] = None,
    ) -> Union[DataType, SequenceDataType]:
        """Apply ``module`` to every input according to its data key.

        Inputs are paired positionally with the data keys; pairing stops at the shorter of the two.

        Args:
            *arg: the inputs to transform.
            module: the module to apply.
            param: the parameters corresponding to ``module``.
            extra_args: per data key dictionaries of extra options; missing keys mean no extra options.
            data_keys: keys for this call; falls back to the instance defaults when ``None``.

        Returns:
            The single transformed item when exactly one output is produced, otherwise the list of outputs.
        """
        keys = self.preproc_datakeys(data_keys)

        if isinstance(module, K.RandomTransplantation):
            # For transforms which require the full input to calculate the parameters (e.g. RandomTransplantation)
            param = ParamItem(
                name=param.name,
                data=module.params_from_input(
                    *arg,  # type: ignore[arg-type]
                    data_keys=keys,
                    params=param.data,  # type: ignore[arg-type]
                    extra_args=extra_args,
                ),
            )

        results = []
        for item, key in zip(arg, keys):
            op_cls = self._get_op(key)
            options = extra_args.get(key, {})
            # A list of masks is transformed mask by mask through the dedicated list entry point.
            if key.name == "MASK" and isinstance(item, list):
                transformed = MaskSequentialOps.transform_list(item, module, param=param, extra_args=options)
            else:
                transformed = op_cls.transform(item, module, param=param, extra_args=options)
            results.append(transformed)

        return results[0] if len(results) == 1 else results

    def inverse(
        self,
        *arg: DataType,
        module: Module,
        param: ParamItem,
        extra_args: Dict[DataKey, Dict[str, Any]],
        data_keys: Optional[Union[List[str], List[int], List[DataKey]]] = None,
    ) -> Union[DataType, SequenceDataType]:
        """Inverse ``module`` on every input according to its data key.

        Inputs are paired positionally with the data keys; pairing stops at the shorter of the two.

        Args:
            *arg: the inputs to inverse.
            module: the module whose transformation is undone.
            param: the parameters corresponding to ``module``.
            extra_args: per data key dictionaries of extra options; missing keys mean no extra options.
            data_keys: keys for this call; falls back to the instance defaults when ``None``.

        Returns:
            The single inversed item when exactly one output is produced, otherwise the list of outputs.
        """
        keys = self.preproc_datakeys(data_keys)

        results = []
        for item, key in zip(arg, keys):
            op_cls = self._get_op(key)
            options = extra_args.get(key, {})
            results.append(op_cls.inverse(item, module, param=param, extra_args=options))

        return results[0] if len(results) == 1 else results


P = ParamSpec("P")


def make_input_only_sequential(module: "K.container.ImageSequentialBase") -> Callable[P, Tensor]:
    """Disable all other additional inputs (e.g. ) for ImageSequential."""

    def f(*args: P.args, **kwargs: P.kwargs) -> Tensor:
        return module(*args, **kwargs)

    return f


def get_geometric_only_param(module: "K.container.ImageSequentialBase", param: List[ParamItem]) -> List[ParamItem]:
    """Keep only the parameters that belong to geometric augmentations.

    Args:
        module: sequential container; ``module.get_forward_sequence(param)`` provides the ``(name, module)``
            pairs that are matched positionally against ``param``.
        param: the parameter items to filter.

    Returns:
        The items of ``param`` whose paired module is a 2D or 3D geometric augmentation, in original order.
    """
    geometric_bases = (K.GeometricAugmentationBase2D, K.GeometricAugmentationBase3D)
    return [
        item
        for (_, submodule), item in zip(module.get_forward_sequence(param), param)
        if isinstance(submodule, geometric_bases)
    ]


class InputSequentialOps(SequentialOpsInterface[Tensor]):
    """Apply and inverse transformations for input tensors."""

    @classmethod
    def transform(
        cls, input: Tensor, module: Module, param: ParamItem, extra_args: Optional[Dict[str, Any]] = None
    ) -> Tensor:
        """Apply ``module`` to the input tensor.

        Args:
            input: the input tensor.
            module: the module to apply; augmentations, sequential containers and auto operations receive
                their parameters, any other module is called on the input alone.
            param: the parameters corresponding to ``module``.
            extra_args: Optional dictionary of extra arguments forwarded to augmentations and containers.

        Returns:
            The output of ``module``.

        Raises:
            AssertionError: if ``module`` is not an augmentation-like module but ``param.data`` is not ``None``.
        """
        options = {} if extra_args is None else extra_args

        if isinstance(module, (_AugmentationBase, K.MixAugmentationBaseV2)):
            return module(input, params=cls.get_instance_module_param(param), data_keys=[DataKey.INPUT], **options)

        if isinstance(module, K.container.ImageSequentialBase):
            return module.transform_inputs(input, params=cls.get_sequential_module_param(param), extra_args=options)

        if isinstance(module, K.auto.operations.OperationBase):
            return module(input, params=cls.get_instance_module_param(param))

        # Any other module takes no parameters at all.
        if param.data is not None:
            raise AssertionError(f"Non-augmentaion operation {param.name} require empty parameters. Got {param}.")
        return module(input)

    @classmethod
    def inverse(
        cls, input: Tensor, module: Module, param: ParamItem, extra_args: Optional[Dict[str, Any]] = None
    ) -> Tensor:
        """Inverse the transformation of ``module`` on the input tensor.

        Args:
            input: the input tensor.
            module: the module to inverse; modules that match none of the handled kinds leave the
                input untouched.
            param: the parameters corresponding to ``module``.
            extra_args: Optional dictionary of extra arguments forwarded to the module.

        Returns:
            The inversed tensor, or ``input`` itself when ``module`` is not handled.

        Raises:
            NotImplementedError: if ``module`` is a 3D geometric augmentation.
        """
        options = {} if extra_args is None else extra_args

        if isinstance(module, K.GeometricAugmentationBase2D):
            return module.inverse(input, params=cls.get_instance_module_param(param), **options)

        if isinstance(module, K.GeometricAugmentationBase3D):
            raise NotImplementedError(
                "The support for 3d inverse operations are not yet supported. You are welcome to file a PR in our repo."
            )

        if isinstance(module, K.auto.operations.OperationBase):
            # Auto operations delegate to the augmentation they wrap.
            return InputSequentialOps.inverse(input, module=module.op, param=param, extra_args=options)

        # Both kinds of sequential container are inversed the same way.
        is_sequential = (
            isinstance(module, K.ImageSequential) and not module.is_intensity_only()
        ) or isinstance(module, K.container.ImageSequentialBase)
        if is_sequential:
            return module.inverse_inputs(input, params=cls.get_sequential_module_param(param), extra_args=options)

        return input


class ClassSequentialOps(SequentialOpsInterface[Tensor]):
    """Apply and inverse transformations for class labels if needed."""

    @classmethod
    def transform(
        cls, input: Tensor, module: Module, param: ParamItem, extra_args: Optional[Dict[str, Any]] = None
    ) -> Tensor:
        """Return the class labels unchanged.

        Args:
            input: the class labels.
            module: the module being applied; only inspected to reject mix augmentations.
            param: unused.
            extra_args: unused.

        Raises:
            NotImplementedError: if ``module`` is a mix augmentation.
        """
        if not isinstance(module, K.MixAugmentationBaseV2):
            return input
        raise NotImplementedError(
            "The support for class labels for mix augmentations that change the class label is not yet supported."
        )

    @classmethod
    def inverse(
        cls, input: Tensor, module: Module, param: ParamItem, extra_args: Optional[Dict[str, Any]] = None
    ) -> Tensor:
        """Return the class labels unchanged; all other arguments are unused."""
        return input


class MaskSequentialOps(SequentialOpsInterface[Tensor]):
    """Apply and inverse transformations for mask tensors."""

    @classmethod
    def transform(
        cls, input: Tensor, module: Module, param: ParamItem, extra_args: Optional[Dict[str, Any]] = None
    ) -> Tensor:
        """Apply a transformation to a mask with respect to the parameters.

        Args:
            input: the mask tensor.
            module: any torch Module but only kornia augmentation modules will count
                to apply transformations.
            param: the corresponding parameters to the module.
            extra_args: Optional dictionary of extra arguments with specific options for different input types.

        Returns:
            The transformed mask, or ``input`` itself when ``module`` is not handled.

        Raises:
            NotImplementedError: if ``module`` is a 3D geometric augmentation.
        """
        options = {} if extra_args is None else extra_args

        if isinstance(module, K.GeometricAugmentationBase2D):
            return module.transform_masks(
                input,
                params=cls.get_instance_module_param(param),
                flags=module.flags,
                transform=module.transform_matrix,
                **options,
            )

        if isinstance(module, K.GeometricAugmentationBase3D):
            raise NotImplementedError(
                "The support for 3d mask operations are not yet supported. You are welcome to file a PR in our repo."
            )

        if isinstance(module, K.RandomTransplantation):
            # Called directly with the MASK data key rather than through ``transform_masks``.
            return module(input, params=cls.get_instance_module_param(param), data_keys=[DataKey.MASK], **options)

        if isinstance(module, _AugmentationBase):
            return module.transform_masks(
                input, params=cls.get_instance_module_param(param), flags=module.flags, **options
            )

        # Both kinds of sequential container are applied the same way.
        is_sequential = (
            isinstance(module, K.ImageSequential) and not module.is_intensity_only()
        ) or isinstance(module, K.container.ImageSequentialBase)
        if is_sequential:
            return module.transform_masks(input, params=cls.get_sequential_module_param(param), extra_args=options)

        if isinstance(module, K.auto.operations.OperationBase):
            # Auto operations delegate to the augmentation they wrap.
            return MaskSequentialOps.transform(input, module=module.op, param=param, extra_args=options)

        return input

    @classmethod
    def transform_list(
        cls, input: List[Tensor], module: Module, param: ParamItem, extra_args: Optional[Dict[str, Any]] = None
    ) -> List[Tensor]:
        """Apply a transformation to each mask of a list with respect to the parameters.

        For single augmentations the parameters are deep-copied once and the ``"batch_prob"`` entry is
        replaced by its ``i``-th element before transforming the ``i``-th mask.

        Args:
            input: list of input tensors.
            module: any torch Module but only kornia augmentation modules will count
                to apply transformations.
            param: the corresponding parameters to the module.
            extra_args: Optional dictionary of extra arguments with specific options for different input types.

        Returns:
            A new list with the transformed masks, or ``input`` itself when ``module`` is not handled.

        Raises:
            NotImplementedError: if ``module`` is a 3D geometric augmentation or an auto operation.
        """
        options = {} if extra_args is None else extra_args

        if isinstance(module, K.GeometricAugmentationBase2D):
            shared_params = cls.get_instance_module_param(param)
            mask_params = copy.deepcopy(shared_params)
            transformed = []
            for index, mask in enumerate(input):
                mask_params["batch_prob"] = shared_params["batch_prob"][index]
                transformed.append(
                    module.transform_masks(
                        mask, params=mask_params, flags=module.flags, transform=module.transform_matrix, **options
                    )
                )
            return transformed

        if isinstance(module, K.GeometricAugmentationBase3D):
            raise NotImplementedError(
                "The support for 3d mask operations are not yet supported. You are welcome to file a PR in our repo."
            )

        if isinstance(module, _AugmentationBase):
            shared_params = cls.get_instance_module_param(param)
            mask_params = copy.deepcopy(shared_params)
            transformed = []
            for index, mask in enumerate(input):
                mask_params["batch_prob"] = shared_params["batch_prob"][index]
                transformed.append(module.transform_masks(mask, params=mask_params, flags=module.flags, **options))
            return transformed

        # Both kinds of sequential container are applied the same way, with the same parameters for every mask.
        is_sequential = (
            isinstance(module, K.ImageSequential) and not module.is_intensity_only()
        ) or isinstance(module, K.container.ImageSequentialBase)
        if is_sequential:
            seq_params = cls.get_sequential_module_param(param)
            return [module.transform_masks(mask, params=seq_params, extra_args=options) for mask in input]

        if isinstance(module, K.auto.operations.OperationBase):
            raise NotImplementedError(
                "The support for list of masks under auto operations are not yet supported. You are welcome to file a"
                " PR in our repo."
            )

        return input

    @classmethod
    def inverse(
        cls, input: Tensor, module: Module, param: ParamItem, extra_args: Optional[Dict[str, Any]] = None
    ) -> Tensor:
        """Inverse a transformation on a mask with respect to the parameters.

        Args:
            input: the mask tensor.
            module: any torch Module but only kornia augmentation modules will count
                to apply transformations.
            param: the corresponding parameters to the module.
            extra_args: Optional dictionary of extra arguments with specific options for different input types.

        Returns:
            The inversed mask, or ``input`` itself when ``module`` is not handled.

        Raises:
            ValueError: if a 2D geometric augmentation has no transformation matrix.
            NotImplementedError: if ``module`` is a 3D geometric augmentation.
        """
        options = {} if extra_args is None else extra_args

        if isinstance(module, K.GeometricAugmentationBase2D):
            if module.transform_matrix is None:
                raise ValueError(f"No valid transformation matrix found in {module.__class__}.")
            inverse_matrix = module.compute_inverse_transformation(module.transform_matrix)
            return module.inverse_masks(
                input,
                params=cls.get_instance_module_param(param),
                flags=module.flags,
                transform=inverse_matrix,
                **options,
            )

        if isinstance(module, K.GeometricAugmentationBase3D):
            raise NotImplementedError(
                "The support for 3d mask operations are not yet supported. You are welcome to file a PR in our repo."
            )

        if isinstance(module, K.container.ImageSequentialBase):
            return module.inverse_masks(input, params=cls.get_sequential_module_param(param), extra_args=options)

        if isinstance(module, K.auto.operations.OperationBase):
            # Auto operations delegate to the augmentation they wrap.
            return MaskSequentialOps.inverse(input, module=module.op, param=param, extra_args=options)

        return input


class BoxSequentialOps(SequentialOpsInterface[Boxes]):
    """Apply and inverse transformations for bounding box tensors.

    This is for transform boxes in the format (B, N, 4, 2).
    """

    @classmethod
    def transform(
        cls, input: Boxes, module: Module, param: ParamItem, extra_args: Optional[Dict[str, Any]] = None
    ) -> Boxes:
        """Apply a transformation with respect to the parameters.

        Args:
            input: the boxes to transform; the module operates on a clone of them.
            module: any torch Module but only kornia augmentation modules will count
                to apply transformations.
            param: the corresponding parameters to the module.
            extra_args: Optional dictionary of extra arguments with specific options for different input types.

        Returns:
            The transformed boxes, or an untouched clone of ``input`` when ``module`` is not handled.

        Raises:
            NotImplementedError: if ``module`` is a 3D geometric augmentation.
        """
        if extra_args is None:
            extra_args = {}
        _input = input.clone()

        if isinstance(module, (K.GeometricAugmentationBase2D,)):
            _input = module.transform_boxes(
                _input,
                cls.get_instance_module_param(param),
                module.flags,
                transform=module.transform_matrix,
                **extra_args,
            )

        elif isinstance(module, (K.GeometricAugmentationBase3D,)):
            raise NotImplementedError(
                "The support for 3d box operations are not yet supported. You are welcome to file a PR in our repo."
            )

        elif isinstance(module, K.ImageSequential) and not module.is_intensity_only():
            _input = module.transform_boxes(
                _input, params=cls.get_sequential_module_param(param), extra_args=extra_args
            )

        elif isinstance(module, K.container.ImageSequentialBase):
            _input = module.transform_boxes(
                _input, params=cls.get_sequential_module_param(param), extra_args=extra_args
            )

        elif isinstance(module, (K.auto.operations.OperationBase,)):
            # Auto operations delegate to the augmentation they wrap; the recursive call clones again.
            return BoxSequentialOps.transform(input, module=module.op, param=param, extra_args=extra_args)

        return _input

    @classmethod
    def inverse(
        cls, input: Boxes, module: Module, param: ParamItem, extra_args: Optional[Dict[str, Any]] = None
    ) -> Boxes:
        """Inverse a transformation with respect to the parameters.

        Args:
            input: the boxes to inverse; the module operates on a clone of them.
            module: any torch Module but only kornia augmentation modules will count
                to apply transformations.
            param: the corresponding parameters to the module.
            extra_args: Optional dictionary of extra arguments with specific options for different input types.

        Returns:
            The inversed boxes, or an untouched clone of ``input`` when ``module`` is not handled.

        Raises:
            ValueError: if a 2D geometric augmentation has no transformation matrix.
            TypeError: if a 2D geometric augmentation is given parameters that are not a dictionary.
            NotImplementedError: if ``module`` is a 3D geometric augmentation.
        """
        if extra_args is None:
            extra_args = {}
        _input = input.clone()

        if isinstance(module, (K.GeometricAugmentationBase2D,)):
            if module.transform_matrix is None:
                raise ValueError(f"No valid transformation matrix found in {module.__class__}.")
            transform = module.compute_inverse_transformation(module.transform_matrix)
            # Validate the parameters the same way ``transform`` and the other ops classes do,
            # instead of forwarding the raw ``param.data``.
            _input = module.inverse_boxes(
                _input,
                cls.get_instance_module_param(param),
                module.flags,
                transform=transform,
                **extra_args,
            )

        elif isinstance(module, (K.GeometricAugmentationBase3D,)):
            raise NotImplementedError(
                "The support for 3d box operations are not yet supported. You are welcome to file a PR in our repo."
            )

        elif isinstance(module, K.ImageSequential) and not module.is_intensity_only():
            _input = module.inverse_boxes(_input, params=cls.get_sequential_module_param(param), extra_args=extra_args)

        elif isinstance(module, K.container.ImageSequentialBase):
            _input = module.inverse_boxes(_input, params=cls.get_sequential_module_param(param), extra_args=extra_args)

        elif isinstance(module, (K.auto.operations.OperationBase,)):
            # Auto operations delegate to the augmentation they wrap; the recursive call clones again.
            return BoxSequentialOps.inverse(input, module=module.op, param=param, extra_args=extra_args)
        return _input


class KeypointSequentialOps(SequentialOpsInterface[Keypoints]):
    """Apply and inverse transformations for keypoints tensors.

    This is for transform keypoints in the format (B, N, 2).
    """

    @classmethod
    def transform(
        cls, input: Keypoints, module: Module, param: ParamItem, extra_args: Optional[Dict[str, Any]] = None
    ) -> Keypoints:
        """Apply a transformation to keypoints with respect to the parameters.

        Args:
            input: the keypoints to transform; the module operates on a clone of them.
            module: any torch Module but only kornia augmentation modules will count
                to apply transformations.
            param: the corresponding parameters to the module.
            extra_args: Optional dictionary of extra arguments with specific options for different input types.

        Returns:
            The transformed keypoints, or an untouched clone of ``input`` when ``module`` is not handled.

        Raises:
            NotImplementedError: if ``module`` is a 3D geometric augmentation.
        """
        options = {} if extra_args is None else extra_args
        cloned = input.clone()

        if isinstance(module, K.GeometricAugmentationBase2D):
            return module.transform_keypoints(
                cloned,
                cls.get_instance_module_param(param),
                module.flags,
                transform=module.transform_matrix,
                **options,
            )

        if isinstance(module, K.GeometricAugmentationBase3D):
            raise NotImplementedError(
                "The support for 3d keypoint operations are not yet supported. "
                "You are welcome to file a PR in our repo."
            )

        # Both kinds of sequential container are applied the same way.
        is_sequential = (
            isinstance(module, K.ImageSequential) and not module.is_intensity_only()
        ) or isinstance(module, K.container.ImageSequentialBase)
        if is_sequential:
            return module.transform_keypoints(
                cloned, params=cls.get_sequential_module_param(param), extra_args=options
            )

        if isinstance(module, K.auto.operations.OperationBase):
            # Auto operations delegate to the augmentation they wrap; the recursive call clones again.
            return KeypointSequentialOps.transform(input, module=module.op, param=param, extra_args=options)

        return cloned

    @classmethod
    def inverse(
        cls, input: Keypoints, module: Module, param: ParamItem, extra_args: Optional[Dict[str, Any]] = None
    ) -> Keypoints:
        """Inverse a transformation on keypoints with respect to the parameters.

        Args:
            input: the keypoints to inverse; the module operates on a clone of them.
            module: any torch Module but only kornia augmentation modules will count
                to apply transformations.
            param: the corresponding parameters to the module.
            extra_args: Optional dictionary of extra arguments with specific options for different input types.

        Returns:
            The inversed keypoints, or an untouched clone of ``input`` when ``module`` is not handled.

        Raises:
            ValueError: if a 2D geometric augmentation has no transformation matrix.
            NotImplementedError: if ``module`` is a 3D geometric augmentation.
        """
        options = {} if extra_args is None else extra_args
        cloned = input.clone()

        if isinstance(module, K.GeometricAugmentationBase2D):
            if module.transform_matrix is None:
                raise ValueError(f"No valid transformation matrix found in {module.__class__}.")
            inverse_matrix = module.compute_inverse_transformation(module.transform_matrix)
            return module.inverse_keypoints(
                cloned, cls.get_instance_module_param(param), module.flags, transform=inverse_matrix, **options
            )

        if isinstance(module, K.GeometricAugmentationBase3D):
            raise NotImplementedError(
                "The support for 3d keypoint operations are not yet supported. "
                "You are welcome to file a PR in our repo."
            )

        # Both kinds of sequential container are inversed the same way.
        is_sequential = (
            isinstance(module, K.ImageSequential) and not module.is_intensity_only()
        ) or isinstance(module, K.container.ImageSequentialBase)
        if is_sequential:
            return module.inverse_keypoints(
                cloned, params=cls.get_sequential_module_param(param), extra_args=options
            )

        if isinstance(module, K.auto.operations.OperationBase):
            # Auto operations delegate to the augmentation they wrap; the recursive call clones again.
            return KeypointSequentialOps.inverse(input, module=module.op, param=param, extra_args=options)

        return cloned
