# LICENSE HEADER MANAGED BY add-license-header
#
# Copyright 2018 Kornia Team
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

from typing import Any, Dict, List, Optional, Tuple, Union, cast

import torch

import kornia.augmentation as K
from kornia.augmentation.base import _AugmentationBase
from kornia.augmentation.container.base import SequentialBase
from kornia.augmentation.container.image import ImageSequential, _get_new_batch_shape
from kornia.core import Module, Tensor
from kornia.geometry.boxes import Boxes
from kornia.geometry.keypoints import Keypoints

from .params import ParamItem

__all__ = ["VideoSequential"]


class VideoSequential(ImageSequential):
    r"""VideoSequential for processing 5-dim video data like (B, T, C, H, W) and (B, C, T, H, W).

    `VideoSequential` is used to replace `nn.Sequential` for processing video data augmentations.
    By default, `VideoSequential` enables `same_on_frame` to make sure the same augmentations happen
    across the temporal dimension. Meanwhile, it will not affect other augmentation behaviours like the
    settings on `same_on_batch`, etc.

    Internally, the video is flattened into an image batch of shape (B * T, C, H, W), ordered
    video-major (row ``b * T + t`` holds frame ``t`` of video ``b``), and the generated parameters are
    broadcast to match that layout.

    Args:
        *args: a list of augmentation modules.
        data_format: only BCTHW and BTCHW are supported.
        same_on_frame: apply the same transformation to every frame (temporal step) of a video.
        random_apply: randomly select a sublist (order agnostic) of args to
            apply transformation.
            If int, a fixed number of transformations will be selected.
            If (a,), x number of transformations (a <= x <= len(args)) will be selected.
            If (a, b), x number of transformations (a <= x <= b) will be selected.
            If False (default) or None, the whole list of args will be processed as a sequence.
        random_apply_weights: optional selection weights, one per module, used together with
            ``random_apply``. Forwarded unchanged to :class:`ImageSequential`.

    Note:
        Transformation matrix returned only considers the transformation applied in ``kornia.augmentation`` module.
        Those transformations in ``kornia.geometry`` will not be taken into account.

    Example:
        If set `same_on_frame` to True, we would expect the same augmentation has been applied to each
        timeframe.

        >>> import kornia
        >>> input = torch.randn(2, 3, 1, 5, 6).repeat(1, 1, 4, 1, 1)
        >>> aug_list = VideoSequential(
        ...     kornia.augmentation.ColorJiggle(0.1, 0.1, 0.1, 0.1, p=1.0),
        ...     kornia.color.BgrToRgb(),
        ...     kornia.augmentation.RandomAffine(360, p=1.0),
        ...     random_apply=10,
        ...     data_format="BCTHW",
        ...     same_on_frame=True)
        >>> output = aug_list(input)
        >>> (output[0, :, 0] == output[0, :, 1]).all()
        tensor(True)
        >>> (output[0, :, 1] == output[0, :, 2]).all()
        tensor(True)
        >>> (output[0, :, 2] == output[0, :, 3]).all()
        tensor(True)

        If set `same_on_frame` to False:

        >>> aug_list = VideoSequential(
        ...     kornia.augmentation.ColorJiggle(0.1, 0.1, 0.1, 0.1, p=1.0),
        ...     kornia.augmentation.RandomAffine(360, p=1.0),
        ...     kornia.augmentation.RandomMixUpV2(p=1.0),
        ...     data_format="BCTHW",
        ...     same_on_frame=False)
        >>> output = aug_list(input)
        >>> output.shape
        torch.Size([2, 3, 4, 5, 6])
        >>> (output[0, :, 0] == output[0, :, 1]).all()
        tensor(False)

        Reproduce with provided params.

        >>> out2 = aug_list(input, params=aug_list._params)
        >>> torch.equal(output, out2)
        True

        Perform ``OneOf`` transformation with ``random_apply=1`` and ``random_apply_weights`` in ``VideoSequential``.

        >>> import kornia
        >>> input, label = torch.randn(2, 3, 1, 5, 6).repeat(1, 1, 4, 1, 1), torch.tensor([0, 1])
        >>> aug_list = VideoSequential(
        ...     kornia.augmentation.ColorJiggle(0.1, 0.1, 0.1, 0.1, p=1.0),
        ...     kornia.augmentation.RandomAffine(360, p=1.0),
        ...     kornia.augmentation.RandomMixUpV2(p=1.0),
        ...     data_format="BCTHW",
        ...     same_on_frame=False,
        ...     random_apply=1,
        ...     random_apply_weights=[0.5, 0.3, 0.8]
        ... )
        >>> out = aug_list(input)
        >>> out.shape
        torch.Size([2, 3, 4, 5, 6])

    """

    # TODO: implement transform_matrix

    def __init__(
        self,
        *args: Module,
        data_format: str = "BTCHW",
        same_on_frame: bool = True,
        random_apply: Union[int, bool, Tuple[int, int]] = False,
        random_apply_weights: Optional[List[float]] = None,
    ) -> None:
        super().__init__(
            *args,
            same_on_batch=None,
            keepdim=None,
            random_apply=random_apply,
            random_apply_weights=random_apply_weights,
        )
        self.same_on_frame = same_on_frame
        self.data_format = data_format.upper()
        if self.data_format not in ["BCTHW", "BTCHW"]:
            raise AssertionError(f"Only `BCTHW` and `BTCHW` are supported. Got `{data_format}`.")
        # Index of the temporal (T) dimension in the 5-dim input for the chosen layout.
        self._temporal_channel: int = {"BCTHW": 2, "BTCHW": 1}[self.data_format]

    def __infer_channel_exclusive_batch_shape__(self, batch_shape: torch.Size, chennel_index: int) -> torch.Size:
        """Return ``batch_shape`` with the dimension at ``chennel_index`` removed."""
        return torch.Size((*batch_shape[:chennel_index], *batch_shape[chennel_index + 1 :]))

    def __repeat_param_across_channels__(self, param: Tensor, frame_num: int) -> Tensor:
        """Repeat every per-video parameter once per frame.

        The input is shaped as (B, ...) and the output as (B * frame_num, ...), which guarantees that
        every frame of a video receives the same transformation.

        (B1, B2, ..., Bn) => (B1, ..., B1, B2, ..., B2, ..., Bn, ..., Bn)
                              | frame_num | | frame_num | ... | frame_num |
        """
        return param.repeat_interleave(frame_num, dim=0)

    def __broadcast_param__(
        self, v: Tensor, batch_shape: torch.Size, frame_num: int, same_on_frame: bool, same_on_batch: bool
    ) -> Tensor:
        """Broadcast one generated parameter to the flattened ``(B * T, ...)`` batch.

        The flattened batch is ordered video-major: row ``b * T + t`` holds frame ``t`` of video ``b``.

        Args:
            v: parameter tensor produced by a child module. Its leading dimension is expected to be
                1 (``same_on_frame and same_on_batch``), B (``same_on_frame``), T (``same_on_batch``)
                or B * T (neither), matching the shape requested in :meth:`forward_parameters`.
            batch_shape: the channel-exclusive batch shape; only ``batch_shape[0]`` (B) is used.
            frame_num: number of frames T.
            same_on_frame: whether every frame of a video shares one parameter.
            same_on_batch: whether every video in the batch shares one parameter.

        Returns:
            the parameter expanded to ``B * T`` rows (empty tensors are returned unchanged).
        """
        if not v.numel():
            return v

        trailing = [1] * (v.ndim - 1)
        if same_on_frame and same_on_batch:
            # A single parameter shared by every frame of every video.
            return v.repeat(batch_shape[0] * frame_num, *trailing)
        if same_on_frame:
            # One parameter per video, repeated for each of its frames: row b * T + t -> v[b].
            return self.__repeat_param_across_channels__(v, frame_num)
        if same_on_batch:
            # One parameter per frame index, shared across videos: row b * T + t -> v[t].
            return v.repeat(batch_shape[0], *trailing)
        return v

    def _input_shape_convert_in(self, input: Tensor, frame_num: int) -> Tensor:
        """Flatten a 5-dim video tensor into a 4-dim image batch of shape (B * T, C, H, W).

        ``frame_num`` is not needed here; it is kept for symmetry with :meth:`_input_shape_convert_back`.
        """
        if self.data_format == "BCTHW":
            # Convert (B, C, T, H, W) to (B, T, C, H, W)
            input = input.transpose(1, 2)
        return input.reshape(-1, *input.shape[2:])

    def _input_shape_convert_back(self, input: Tensor, frame_num: int) -> Tensor:
        """Restore a flattened (B * T, C, H, W) batch to the 5-dim layout given by ``data_format``."""
        input = input.reshape(-1, frame_num, *input.shape[1:])
        if self.data_format == "BCTHW":
            # Convert (B, T, C, H, W) back to (B, C, T, H, W)
            input = input.transpose(1, 2)
        return input

    def forward_parameters(self, batch_shape: torch.Size) -> List[ParamItem]:
        """Generate the parameters of every child module for a video batch.

        Parameters are generated for the 4-dim shape obtained by dropping the temporal dimension and are
        then broadcast so they line up with the flattened ``(B * T, ...)`` batch used by
        :meth:`transform_inputs`.

        Args:
            batch_shape: the 5-dim shape of the video input, laid out according to ``data_format``.

        Returns:
            one :class:`ParamItem` per module in the forward sequence.

        Raises:
            ValueError: if a nested sequential container is used together with ``same_on_frame=True``.
        """
        frame_num = batch_shape[self._temporal_channel]
        named_modules = self.get_forward_sequence()
        # Got param generation shape to (B, C, H, W). Ignoring T.
        batch_shape = self.__infer_channel_exclusive_batch_shape__(batch_shape, self._temporal_channel)

        params: List[ParamItem] = []
        for name, module in named_modules:
            if isinstance(module, (K.RandomCrop, _AugmentationBase, K.MixAugmentationBaseV2)):
                is_same_on_batch = bool(getattr(module, "same_on_batch", False))

                # Request exactly as many distinct parameters as will be needed before broadcasting.
                if self.same_on_frame and is_same_on_batch:
                    mod_shape = torch.Size([1, *batch_shape[1:]])
                elif self.same_on_frame:
                    mod_shape = batch_shape
                elif is_same_on_batch:
                    mod_shape = torch.Size([frame_num, *batch_shape[1:]])
                else:
                    mod_shape = torch.Size([batch_shape[0] * frame_num, *batch_shape[1:]])

                mod_param = module.forward_parameters(mod_shape)

                if isinstance(mod_param, dict):
                    for k, v in mod_param.items():
                        # TODO: revise ColorJiggle and ColorJitter order param in the future to align the standard.
                        if k == "order" and isinstance(module, (K.ColorJiggle, K.ColorJitter)):
                            continue
                        # ``forward_input_shape`` is kept as generated; it is not broadcast per frame.
                        if k == "forward_input_shape":
                            continue
                        # Replacing the value of an existing key does not resize the dict, so this is
                        # safe while iterating over it.
                        mod_param[k] = self.__broadcast_param__(
                            v, batch_shape, frame_num, self.same_on_frame, is_same_on_batch
                        )

                param = ParamItem(name, mod_param)

            elif isinstance(module, SequentialBase):
                # Fail before generating (and consuming randomness for) parameters that are never used.
                if self.same_on_frame:
                    raise ValueError("Sequential is currently unsupported for ``same_on_frame``.")
                # Without ``same_on_frame`` every frame is an independent sample, matching the flattened
                # (B * T, ...) batch the nested container receives in ``transform_inputs``.
                seq_shape = torch.Size([batch_shape[0] * frame_num, *batch_shape[1:]])
                param = ParamItem(name, module.forward_parameters(seq_shape))

            else:
                param = ParamItem(name, None)

            batch_shape = _get_new_batch_shape(param, batch_shape)
            params.append(param)

        return params

    def transform_inputs(
        self, input: Tensor, params: List[ParamItem], extra_args: Optional[Dict[str, Any]] = None
    ) -> Tensor:
        """Apply the transformations to a 5-dim video tensor laid out as ``data_format``."""
        frame_num: int = input.size(self._temporal_channel)
        input = self._input_shape_convert_in(input, frame_num)

        input = super().transform_inputs(input, params, extra_args=extra_args)

        return self._input_shape_convert_back(input, frame_num)

    def inverse_inputs(
        self, input: Tensor, params: List[ParamItem], extra_args: Optional[Dict[str, Any]] = None
    ) -> Tensor:
        """Invert the transformations on a 5-dim video tensor laid out as ``data_format``."""
        frame_num: int = input.size(self._temporal_channel)
        input = self._input_shape_convert_in(input, frame_num)

        input = super().inverse_inputs(input, params, extra_args=extra_args)

        return self._input_shape_convert_back(input, frame_num)

    def transform_masks(
        self, input: Tensor, params: List[ParamItem], extra_args: Optional[Dict[str, Any]] = None
    ) -> Tensor:
        """Apply the transformations to 5-dim video masks laid out as ``data_format``."""
        frame_num: int = input.size(self._temporal_channel)
        input = self._input_shape_convert_in(input, frame_num)

        input = super().transform_masks(input, params, extra_args=extra_args)

        return self._input_shape_convert_back(input, frame_num)

    def inverse_masks(
        self, input: Tensor, params: List[ParamItem], extra_args: Optional[Dict[str, Any]] = None
    ) -> Tensor:
        """Invert the transformations on 5-dim video masks laid out as ``data_format``."""
        frame_num: int = input.size(self._temporal_channel)
        input = self._input_shape_convert_in(input, frame_num)

        input = super().inverse_masks(input, params, extra_args=extra_args)

        return self._input_shape_convert_back(input, frame_num)

    def transform_boxes(  # type: ignore[override]
        self, input: Union[Tensor, Boxes], params: List[ParamItem], extra_args: Optional[Dict[str, Any]] = None
    ) -> Union[Tensor, Boxes]:
        """Transform bounding boxes.

        Args:
            input: tensor with shape :math:`(B, T, N, 4, 2)`.
                If input is a `Boxes` type, the internal shape is :math:`(B * T, N, 4, 2)`.
            params: params for the sequence.
            extra_args: Optional dictionary of extra arguments with specific options for different input types.

        Returns:
            the transformed boxes, as a :math:`(B, T, N, 4, 2)` tensor for tensor input, else as `Boxes`.
        """
        if isinstance(input, Tensor):
            batchsize, frame_num = input.size(0), input.size(1)
            boxes = Boxes.from_tensor(
                input.reshape(-1, input.size(2), input.size(3), input.size(4)), mode="vertices_plus"
            )
            boxes = super().transform_boxes(boxes, params, extra_args=extra_args)
            return boxes.data.reshape(batchsize, frame_num, -1, 4, 2)
        return super().transform_boxes(input, params, extra_args=extra_args)

    def inverse_boxes(  # type: ignore[override]
        self, input: Union[Tensor, Boxes], params: List[ParamItem], extra_args: Optional[Dict[str, Any]] = None
    ) -> Union[Tensor, Boxes]:
        """Inverse transform bounding boxes.

        Args:
            input: tensor with shape :math:`(B, T, N, 4, 2)`.
                If input is a `Boxes` type, the internal shape is :math:`(B * T, N, 4, 2)`.
            params: params for the sequence.
            extra_args: Optional dictionary of extra arguments with specific options for different input types.

        Returns:
            the inverted boxes, as a :math:`(B, T, N, 4, 2)` tensor for tensor input, else as `Boxes`.
        """
        if isinstance(input, Tensor):
            batchsize, frame_num = input.size(0), input.size(1)
            boxes = Boxes.from_tensor(
                input.reshape(-1, input.size(2), input.size(3), input.size(4)), mode="vertices_plus"
            )
            boxes = super().inverse_boxes(boxes, params, extra_args=extra_args)
            return boxes.data.reshape(batchsize, frame_num, -1, 4, 2)
        return super().inverse_boxes(input, params, extra_args=extra_args)

    def transform_keypoints(  # type: ignore[override]
        self, input: Union[Tensor, Keypoints], params: List[ParamItem], extra_args: Optional[Dict[str, Any]] = None
    ) -> Union[Tensor, Keypoints]:
        """Transform keypoints.

        Args:
            input: tensor with shape :math:`(B, T, N, 2)`.
                If input is a `Keypoints` type, the internal shape is :math:`(B * T, N, 2)`.
            params: params for the sequence.
            extra_args: Optional dictionary of extra arguments with specific options for different input types.

        Returns:
            the transformed keypoints, as a :math:`(B, T, N, 2)` tensor for tensor input, else as `Keypoints`.
        """
        if isinstance(input, Tensor):
            batchsize, frame_num = input.size(0), input.size(1)
            keypoints = Keypoints(input.reshape(-1, input.size(2), input.size(3)))
            keypoints = super().transform_keypoints(keypoints, params, extra_args=extra_args)
            return keypoints.data.reshape(batchsize, frame_num, -1, 2)
        return super().transform_keypoints(input, params, extra_args=extra_args)

    def inverse_keypoints(  # type: ignore[override]
        self, input: Union[Tensor, Keypoints], params: List[ParamItem], extra_args: Optional[Dict[str, Any]] = None
    ) -> Union[Tensor, Keypoints]:
        """Inverse transform keypoints.

        Args:
            input: tensor with shape :math:`(B, T, N, 2)`.
                If input is a `Keypoints` type, the internal shape is :math:`(B * T, N, 2)`.
            params: params for the sequence.
            extra_args: Optional dictionary of extra arguments with specific options for different input types.

        Returns:
            the inverted keypoints, as a :math:`(B, T, N, 2)` tensor for tensor input, else as `Keypoints`.
        """
        if isinstance(input, Tensor):
            # Dim 0 is the batch and dim 1 the time axis, matching ``transform_keypoints``.
            batchsize, frame_num = input.size(0), input.size(1)
            keypoints = Keypoints(input.reshape(-1, input.size(2), input.size(3)))
            keypoints = super().inverse_keypoints(keypoints, params, extra_args=extra_args)
            return keypoints.data.reshape(batchsize, frame_num, -1, 2)
        return super().inverse_keypoints(input, params, extra_args=extra_args)

    def inverse(
        self, input: Tensor, params: Optional[List[ParamItem]] = None, extra_args: Optional[Dict[str, Any]] = None
    ) -> Tensor:
        """Inverse transformation.

        Used to inverse a tensor according to the performed transformation by a forward pass, or with respect to
        provided parameters.

        Raises:
            RuntimeError: if no ``params`` are given and no forward pass has stored any.
        """
        if params is None:
            if self._params is None:
                raise RuntimeError("No valid params to inverse the transformation.")
            params = self._params

        return self.inverse_inputs(input, params, extra_args=extra_args)

    def forward(
        self, input: Tensor, params: Optional[List[ParamItem]] = None, extra_args: Optional[Dict[str, Any]] = None
    ) -> Tensor:
        """Define the video computation performed.

        When ``params`` is None, fresh parameters are generated and stored in ``self._params`` so the
        call can be reproduced or inverted later.

        Raises:
            AssertionError: if ``input`` is not a 5-dim tensor.
        """
        if len(input.shape) != 5:
            raise AssertionError(f"Input must be a 5-dim tensor. Got {input.shape}.")

        if params is None:
            self._params = self.forward_parameters(input.shape)
            params = self._params

        return self.transform_inputs(input, params, extra_args=extra_args)
