# LICENSE HEADER MANAGED BY add-license-header
#
# Copyright 2018 Kornia Team
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

from typing import Any, Dict, Iterator, List, Optional, Tuple, Union, cast

from torch import Size

import kornia.augmentation as K
from kornia.augmentation.auto.operations import OperationBase
from kornia.augmentation.container.base import ImageSequentialBase, TransformMatrixMinIn
from kornia.augmentation.container.ops import InputSequentialOps
from kornia.augmentation.container.params import ParamItem
from kornia.augmentation.utils import _transform_input, override_parameters
from kornia.core import Module, Tensor, as_tensor
from kornia.utils import eye_like


class PolicySequential(TransformMatrixMinIn, ImageSequentialBase):
    """Policy tuple for applying multiple operations.

    Args:
        operations: a list of operations to perform.

    """

    def __init__(self, *operations: OperationBase) -> None:
        self.validate_operations(*operations)
        super().__init__(*operations)
        self._valid_ops_for_transform_computation: Tuple[Any, ...] = (OperationBase,)

    def _update_transform_matrix_for_valid_op(self, module: Module) -> None:
        self._transform_matrices.append(module.transform_matrix)

    def clear_state(self) -> None:
        self._reset_transform_matrix_state()
        return super().clear_state()

    def validate_operations(self, *operations: OperationBase) -> None:
        invalid_ops: List[OperationBase] = []
        for op in operations:
            if not isinstance(op, OperationBase):
                invalid_ops.append(op)
        if len(invalid_ops) != 0:
            raise ValueError(f"All operations must be Kornia Operations. Got {invalid_ops}.")

    def identity_matrix(self, input: Tensor) -> Tensor:
        """Return identity matrix."""
        return eye_like(3, input)

    def get_transformation_matrix(
        self,
        input: Tensor,
        params: Optional[List[ParamItem]] = None,
        recompute: bool = False,
        extra_args: Optional[Dict[str, Any]] = None,
    ) -> Tensor:
        """Compute the transformation matrix according to the provided parameters.

        Args:
            input: the input tensor.
            params: params for the sequence.
            recompute: if to recompute the transformation matrix according to the params.
                default: False.
            extra_args: Optional dictionary of extra arguments with specific options for different input types.

        """
        if params is None:
            raise NotImplementedError("requires params to be provided.")
        named_modules: Iterator[Tuple[str, Module]] = self.get_forward_sequence(params)

        # Define as 1 for broadcasting
        res_mat: Tensor = self.identity_matrix(_transform_input(input))
        for (_, module), param in zip(named_modules, params if params is not None else []):
            module = cast(OperationBase, module)
            if isinstance(module.op, (K.GeometricAugmentationBase2D,)) and isinstance(param.data, dict):
                ori_shape = input.shape
                input = module.op.transform_tensor(input)
                # Standardize shape
                if recompute:
                    flags = override_parameters(module.op.flags, extra_args, in_place=False)
                    mat = module.op.generate_transformation_matrix(input, param.data, flags)
                elif module.op._transform_matrix is not None:
                    mat = as_tensor(module.transform_matrix, device=input.device, dtype=input.dtype)
                else:
                    raise RuntimeError(f"{module}.transform_matrix is None while `recompute=False`.")
                res_mat = mat @ res_mat
                input = module.op.transform_output_tensor(input, ori_shape)
                if module.op.keepdim and ori_shape != input.shape:
                    res_mat = res_mat.squeeze()
        return res_mat

    def is_intensity_only(self) -> bool:
        for module in self.children():
            module = cast(OperationBase, module)
            if isinstance(module.op, (K.GeometricAugmentationBase2D,)):
                return False
        return True

    def get_forward_sequence(self, params: Optional[List[ParamItem]] = None) -> Iterator[Tuple[str, Module]]:
        if params is not None:
            return super().get_children_by_params(params)
        return self.named_children()

    def forward_parameters(self, batch_shape: Size) -> List[ParamItem]:
        named_modules: Iterator[Tuple[str, Module]] = self.get_forward_sequence()

        params: List[ParamItem] = []
        mod_param: Union[Dict[str, Tensor], List[ParamItem]]
        for name, module in named_modules:
            module = cast(OperationBase, module)
            mod_param = module.op.forward_parameters(batch_shape)
            param = ParamItem(name, mod_param)
            params.append(param)
        return params

    def transform_inputs(
        self, input: Tensor, params: List[ParamItem], extra_args: Optional[Dict[str, Any]] = None
    ) -> Tensor:
        for param in params:
            module = self.get_submodule(param.name)
            input = InputSequentialOps.transform(input, module=module, param=param, extra_args=extra_args)
            self._update_transform_matrix_by_module(module)
        return input
