# LICENSE HEADER MANAGED BY add-license-header
#
# Copyright 2018 Kornia Team
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

from typing import Any, Dict, Iterator, List, Optional, Tuple, Union, cast

import torch

import kornia.augmentation as K
from kornia.augmentation.base import _AugmentationBase
from kornia.augmentation.utils import override_parameters
from kornia.core import ImageModule, Module, Tensor, as_tensor
from kornia.core.module import ImageModuleMixIn
from kornia.utils import eye_like

from .base import ImageSequentialBase
from .params import ParamItem

__all__ = ["ImageSequential"]


class ImageModuleForSequentialMixIn(ImageModuleMixIn):
    _disable_features: bool = False

    @property
    def disable_features(self) -> bool:
        return self._disable_features

    @disable_features.setter
    def disable_features(self, value: bool = True) -> None:
        self._disable_features = value

    def disable_item_features(self, *args: Module) -> None:
        for arg in args:
            if isinstance(arg, (ImageModule,)):
                arg.disable_features = True


class ImageSequential(ImageSequentialBase, ImageModuleForSequentialMixIn):
    r"""Sequential for creating kornia image processing pipeline.

    Args:
        *args : a list of kornia augmentation and image operation modules.
        same_on_batch: apply the same transformation across the batch.
            If None, it will not overwrite the function-wise settings.
        keepdim: whether to keep the output shape the same as input (True) or broadcast it
            to the batch form (False). If None, it will not overwrite the function-wise settings.
        random_apply: randomly select a sublist (order agnostic) of args to
            apply transformation. The selection probability aligns to the ``random_apply_weights``.
            If int, a fixed number of transformations will be selected.
            If (a,), x number of transformations (a <= x <= len(args)) will be selected.
            If (a, b), x number of transformations (a <= x <= b) will be selected.
            If True, the whole list of args will be processed as a sequence in a random order.
            If False, the whole list of args will be processed as a sequence in original order.
        random_apply_weights: a list of selection weights for each operation. The length shall be as
            same as the number of operations. By default, operations are sampled uniformly.

    .. note::
        Transformation matrix returned only considers the transformation applied in ``kornia.augmentation`` module.
        Those transformations in ``kornia.geometry`` will not be taken into account.

    Examples:
        >>> _ = torch.manual_seed(77)
        >>> import kornia
        >>> input = torch.randn(2, 3, 5, 6)
        >>> aug_list = ImageSequential(
        ...     kornia.color.BgrToRgb(),
        ...     kornia.augmentation.ColorJiggle(0.1, 0.1, 0.1, 0.1, p=1.0),
        ...     kornia.filters.MedianBlur((3, 3)),
        ...     kornia.augmentation.RandomAffine(360, p=1.0),
        ...     kornia.enhance.Invert(),
        ...     kornia.augmentation.RandomMixUpV2(p=1.0),
        ...     same_on_batch=True,
        ...     random_apply=10,
        ... )
        >>> out = aug_list(input)
        >>> out.shape
        torch.Size([2, 3, 5, 6])

        Reproduce with provided params.
        >>> out2 = aug_list(input, params=aug_list._params)
        >>> torch.equal(out, out2)
        True

    Perform ``OneOf`` transformation with ``random_apply=1`` and ``random_apply_weights`` in ``ImageSequential``.

        >>> import kornia
        >>> input = torch.randn(2, 3, 5, 6)
        >>> aug_list = ImageSequential(
        ...     kornia.color.BgrToRgb(),
        ...     kornia.augmentation.ColorJiggle(0.1, 0.1, 0.1, 0.1, p=1.0),
        ...     kornia.filters.MedianBlur((3, 3)),
        ...     kornia.augmentation.RandomAffine(360, p=1.0),
        ...     random_apply=1,
        ...     random_apply_weights=[0.5, 0.3, 0.2, 0.5]
        ... )
        >>> out= aug_list(input)
        >>> out.shape
        torch.Size([2, 3, 5, 6])

    """

    def __init__(
        self,
        *args: Module,
        same_on_batch: Optional[bool] = None,
        keepdim: Optional[bool] = None,
        random_apply: Union[int, bool, Tuple[int, int]] = False,
        random_apply_weights: Optional[List[float]] = None,
        if_unsupported_ops: str = "raise",
        disable_item_features: bool = True,
        disable_sequential_features: bool = False,
    ) -> None:
        if disable_item_features:
            self.disable_item_features(*args)
        if disable_sequential_features:
            self.disable_features = True
        super().__init__(*args, same_on_batch=same_on_batch, keepdim=keepdim)

        self.random_apply = self._read_random_apply(random_apply, len(args))
        if random_apply_weights is not None and len(random_apply_weights) != len(self):
            raise ValueError(
                "The length of `random_apply_weights` must be as same as the number of operations."
                f"Got {len(random_apply_weights)} and {len(self)}."
            )
        self.random_apply_weights = as_tensor(random_apply_weights or torch.ones((len(self),)))
        self.if_unsupported_ops = if_unsupported_ops

    def _read_random_apply(
        self, random_apply: Union[int, bool, Tuple[int, int]], max_length: int
    ) -> Union[Tuple[int, int], bool]:
        """Process the scenarios for random apply."""
        if isinstance(random_apply, (bool,)) and random_apply is False:
            random_apply = False
        elif isinstance(random_apply, (bool,)) and random_apply is True:
            random_apply = (max_length, max_length + 1)
        elif isinstance(random_apply, (int,)):
            random_apply = (random_apply, random_apply + 1)
        elif (
            isinstance(random_apply, (tuple,))
            and len(random_apply) == 2
            and isinstance(random_apply[0], (int,))
            and isinstance(random_apply[1], (int,))
        ):
            random_apply = (random_apply[0], random_apply[1] + 1)
        elif isinstance(random_apply, (tuple,)) and len(random_apply) == 1 and isinstance(random_apply[0], (int,)):
            random_apply = (random_apply[0], max_length + 1)
        else:
            raise ValueError(f"Non-readable random_apply. Got {random_apply}.")
        if random_apply is not False and not (
            isinstance(random_apply, (tuple,))
            and len(random_apply) == 2
            and isinstance(random_apply[0], (int,))
            and isinstance(random_apply[0], (int,))
        ):
            raise AssertionError(f"Expect a tuple of (int, int). Got {random_apply}.")
        return random_apply

    def get_random_forward_sequence(self, with_mix: bool = True) -> Tuple[Iterator[Tuple[str, Module]], bool]:
        """Get a forward sequence when random apply is in need.

        Args:
            with_mix: if to require a mix augmentation for the sequence.

        Note:
            Mix augmentations (e.g. RandomMixUp) will be only applied once even in a random forward.

        """
        if isinstance(self.random_apply, tuple):
            num_samples = int(torch.randint(*self.random_apply, (1,)).item())
        else:
            raise TypeError(f"random apply should be a tuple. Gotcha {type(self.random_apply)}")

        multinomial_weights = self.random_apply_weights.clone()
        # Mix augmentation can only be applied once per forward
        mix_indices = self.get_mix_augmentation_indices(self.named_children())
        # kick out the mix augmentations
        multinomial_weights[mix_indices] = 0
        indices = torch.multinomial(
            multinomial_weights,
            num_samples,
            # enable replacement if non-mix augmentation is less than required
            replacement=num_samples > multinomial_weights.sum().item(),
        )

        mix_added = False
        if with_mix and len(mix_indices) != 0:
            # Make the selection fair.
            if (torch.rand(1) < ((len(mix_indices) + len(indices)) / len(self))).item():
                indices[-1] = torch.multinomial((~multinomial_weights.bool()).float(), 1)
                indices = indices[torch.randperm(len(indices))]
                mix_added = True

        return self.get_children_by_indices(indices), mix_added

    def get_mix_augmentation_indices(self, named_modules: Iterator[Tuple[str, Module]]) -> List[int]:
        """Get all the mix augmentations since they are label-involved.

        Special operations needed for label-involved augmentations.
        """
        # NOTE: MixV2 will not be a special op in the future.
        return [idx for idx, (_, child) in enumerate(named_modules) if isinstance(child, K.MixAugmentationBaseV2)]

    def get_forward_sequence(self, params: Optional[List[ParamItem]] = None) -> Iterator[Tuple[str, Module]]:
        if params is None:
            # Mix augmentation can only be applied once per forward
            mix_indices = self.get_mix_augmentation_indices(self.named_children())

            if self.random_apply:
                return self.get_random_forward_sequence()[0]

            if len(mix_indices) > 1:
                raise ValueError(
                    "Multiple mix augmentation is prohibited without enabling random_apply."
                    f"Detected {len(mix_indices)} mix augmentations."
                )

            return self.named_children()

        return self.get_children_by_params(params)

    def forward_parameters(self, batch_shape: torch.Size) -> List[ParamItem]:
        named_modules: Iterator[Tuple[str, Module]] = self.get_forward_sequence()

        params: List[ParamItem] = []
        mod_param: Union[Dict[str, Tensor], List[ParamItem]]
        for name, module in named_modules:
            if isinstance(module, (_AugmentationBase, K.MixAugmentationBaseV2, ImageSequentialBase)):
                mod_param = module.forward_parameters(batch_shape)
                param = ParamItem(name, mod_param)
            else:
                param = ParamItem(name, None)
            batch_shape = _get_new_batch_shape(param, batch_shape)
            params.append(param)
        return params

    def identity_matrix(self, input: Tensor) -> Tensor:
        """Return identity matrix."""
        return eye_like(3, input)

    def get_transformation_matrix(
        self,
        input: Tensor,
        params: Optional[List[ParamItem]] = None,
        recompute: bool = False,
        extra_args: Optional[Dict[str, Any]] = None,
    ) -> Optional[Tensor]:
        """Compute the transformation matrix according to the provided parameters.

        Args:
            input: the input tensor.
            params: params for the sequence.
            recompute: if to recompute the transformation matrix according to the params.
                default: False.
            extra_args: Optional dictionary of extra arguments with specific options for different input types.
        """
        if params is None:
            raise NotImplementedError("requires params to be provided.")
        if extra_args is None:
            extra_args = {}
        named_modules: Iterator[Tuple[str, Module]] = self.get_forward_sequence(params)

        # Define as 1 for broadcasting
        res_mat: Optional[Tensor] = None
        for (_, module), param in zip(named_modules, params if params is not None else []):
            if isinstance(module, (K.GeometricAugmentationBase2D,)) and isinstance(param.data, dict):
                ori_shape = input.shape
                try:
                    input = module.transform_tensor(input)
                except ValueError:
                    # Ignore error for 5-dim video
                    pass
                # Standardize shape
                if recompute:
                    flags = override_parameters(module.flags, extra_args, in_place=False)
                    mat = module.generate_transformation_matrix(input, param.data, flags)
                elif module._transform_matrix is not None:
                    mat = as_tensor(module._transform_matrix, device=input.device, dtype=input.dtype)
                else:
                    raise RuntimeError(f"{module}._transform_matrix is None while `recompute=False`.")
                res_mat = mat if res_mat is None else mat @ res_mat
                input = module.transform_output_tensor(input, ori_shape)
                if module.keepdim and ori_shape != input.shape:
                    res_mat = res_mat.squeeze()
            elif isinstance(module, (ImageSequentialBase,)):
                # If not augmentationSequential
                if isinstance(module, (K.AugmentationSequential,)) and not recompute:
                    mat = as_tensor(module._transform_matrix, device=input.device, dtype=input.dtype)
                else:
                    maybe_param_data = cast(Optional[List[ParamItem]], param.data)
                    _mat = module.get_transformation_matrix(
                        input, maybe_param_data, recompute=recompute, extra_args=extra_args
                    )
                    mat = module.identity_matrix(input) if _mat is None else _mat
                res_mat = mat if res_mat is None else mat @ res_mat
        return res_mat

    # TODO: Make this as a class property to avoid running every time.
    def is_intensity_only(self, strict: bool = True) -> bool:
        """Check if all transformations are intensity-based.

        Args:
            strict: if strict is False, it will allow non-augmentation Modules to be passed.
                e.g. `kornia.enhance.AdjustBrightness` will be recognized as non-intensity module
                if strict is set to True.

        Note: patch processing would break the continuity of labels (e.g. bbounding boxes, masks).

        """
        for arg in self.children():
            if isinstance(arg, (ImageSequential,)) and not arg.is_intensity_only(strict):
                return False
            elif isinstance(arg, (ImageSequential,)):
                pass
            elif isinstance(arg, K.IntensityAugmentationBase2D):
                pass
            elif strict:
                # disallow non-registered ops if in strict mode
                # TODO: add an ops register module
                return False
        return True

    def __call__(
        self,
        *inputs: Any,
        input_names_to_handle: Optional[List[Any]] = None,
        output_type: str = "tensor",
        **kwargs: Any,
    ) -> Any:
        """Overwrite the __call__ function to handle various inputs.

        Args:
            inputs: Inputs to operate on.
            input_names_to_handle: List of input names to convert, if None, handle all inputs.
            output_type: Desired output type ('tensor', 'numpy', or 'pil').
            kwargs: Additional arguments.

        Returns:
            Callable: Decorated function with converted input and output types.

        """
        # Wrap the forward method with the decorator
        if not self._disable_features:
            decorated_forward = self.convert_input_output(
                input_names_to_handle=input_names_to_handle, output_type=output_type
            )(super().__call__)
            _output_image = decorated_forward(*inputs, **kwargs)
            if output_type == "tensor":
                self._output_image = self._detach_tensor_to_cpu(_output_image)
            else:
                self._output_image = _output_image
        else:
            _output_image = super().__call__(*inputs, **kwargs)
        return _output_image


def _get_new_batch_shape(param: ParamItem, batch_shape: torch.Size) -> torch.Size:
    """Get the new batch shape if the augmentation changes the image size.

    Note:
       Augmentations that change the image size must provide the parameter `output_size`.

    """
    data = param.data
    if data is None:
        return batch_shape

    # If data is a list, process all subitems (exit early if all subitems are None)
    if isinstance(data, list):
        for p in data:
            batch_shape = _get_new_batch_shape(p, batch_shape)
        return batch_shape

    # Carefully avoid evaluating expression multiple times; batch_prob is often a 1-element tensor
    if "output_size" in data:
        # Inline check for common PyTorch float tensor case
        batch_prob = data.get("batch_prob", None)
        if batch_prob is not None:
            # Avoid repeated indexing, always fetch scalar efficiently
            prob = batch_prob.item() if batch_prob.numel() == 1 else batch_prob[0].item()
            if prob <= 0.5:
                return batch_shape
        else:
            # batch_prob missing, fallback do not update shape
            return batch_shape
        # Mutate only last two dims
        new_batch_shape = list(batch_shape)
        new_batch_shape[-2:] = data["output_size"][0]
        return torch.Size(new_batch_shape)

    return batch_shape
