# LICENSE HEADER MANAGED BY add-license-header
#
# Copyright 2018 Kornia Team
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

from typing import Any, Callable, Optional, Tuple

from torch import Tensor, nn
from torch.autograd import Function

__all__ = ["STEFunction", "StraightThroughEstimator"]


class STEFunction(Function):
    """Straight-Through Estimation (STE) function.

    STE bridges the gradients between the input tensor and output tensor as if the function
    was an identity function. Meanwhile, advanced gradient functions are also supported. e.g.
    the output gradients can be mapped into [-1, 1] with ``F.hardtanh`` function.

    Args:
        grad_fn: function to restrain the gradient received. If None, no mapping will performed.

    Example:
        Let the gradients of ``torch.sign`` estimated from STE.
        >>> input = torch.randn(4, requires_grad = True)
        >>> output = torch.sign(input)
        >>> loss = output.mean()
        >>> loss.backward()
        >>> input.grad
        tensor([0., 0., 0., 0.])

        >>> with torch.no_grad():
        ...     output = torch.sign(input)
        >>> out_est = STEFunction.apply(input, output)
        >>> loss = out_est.mean()
        >>> loss.backward()
        >>> input.grad
        tensor([0.2500, 0.2500, 0.2500, 0.2500])

    """

    @staticmethod
    def forward(ctx: Any, input: Tensor, output: Tensor, grad_fn: Optional[Callable[..., Any]] = None) -> Tensor:
        ctx.in_shape = input.shape
        ctx.out_shape = output.shape
        ctx.grad_fn = grad_fn
        return output

    @staticmethod
    def backward(ctx: Any, grad_output: Tensor) -> Tuple[Tensor, Tensor, None]:
        if ctx.grad_fn is None:
            return grad_output.sum_to_size(ctx.in_shape), grad_output.sum_to_size(ctx.out_shape), None
        return (
            ctx.grad_fn(grad_output.sum_to_size(ctx.in_shape)),
            ctx.grad_fn(grad_output.sum_to_size(ctx.out_shape)),
            None,
        )

    # https://pytorch.org/docs/1.10.0/onnx.html#torch-autograd-functions
    # @staticmethod
    # def symbolic(g: torch._C.graph, input: torch._C.Value) -> torch._C.Value:
    #     raise NotImplementedError(
    #         "ONNX support is not implemented at the moment."
    #         "Feel free to contribute to https://github.com/kornia/kornia.")


class StraightThroughEstimator(nn.Module):
    """Straight-Through Estimation (STE) module.

    STE wraps the ``STEFunction`` to aid the back propagation of non-differentiable modules.
    It may also use to avoid gradient computation for differentiable operations. By default,
    STE bridges the gradients between the input tensor and output tensor as if the function
    was an identity function. Meanwhile, advanced gradient functions are also supported. e.g.
    the output gradients can be mapped into [-1, 1] with ``F.hardtanh`` function.

    Args:
        target_fn: the target function to wrap with.
        grad_fn: function to restrain the gradient received. If None, no mapping will performed.

    Example:
        ``RandomPosterize`` is a non-differentiable operation. Let STE estimate the gradients.
        >>> import kornia.augmentation as K
        >>> import torch.nn.functional as F
        >>> input = torch.randn(1, 1, 4, 4, requires_grad = True)
        >>> estimator = StraightThroughEstimator(K.RandomPosterize(3, p=1.), grad_fn=F.hardtanh)
        >>> out = estimator(input)
        >>> out.mean().backward()
        >>> input.grad
        tensor([[[[0.0625, 0.0625, 0.0625, 0.0625],
                  [0.0625, 0.0625, 0.0625, 0.0625],
                  [0.0625, 0.0625, 0.0625, 0.0625],
                  [0.0625, 0.0625, 0.0625, 0.0625]]]])

        This can be used to chain up the gradients within a ``Sequential`` block.
        >>> import kornia.augmentation as K
        >>> input = torch.randn(1, 1, 4, 4, requires_grad = True)
        >>> aug = K.ImageSequential(
        ...     K.RandomAffine((77, 77)),
        ...     StraightThroughEstimator(K.RandomPosterize(3, p=1.), grad_fn=None),
        ...     K.RandomRotation((15, 15)),
        ... )
        >>> aug(input).mean().backward()
        >>> input.grad
        tensor([[[[0.0422, 0.0626, 0.0566, 0.0422],
                  [0.0566, 0.0626, 0.0626, 0.0626],
                  [0.0626, 0.0626, 0.0626, 0.0566],
                  [0.0422, 0.0566, 0.0626, 0.0422]]]])

    """

    def __init__(self, target_fn: nn.Module, grad_fn: Optional[Callable[..., Any]] = None) -> None:
        super().__init__()
        self.target_fn = target_fn
        self.grad_fn = grad_fn

    def __repr__(self) -> str:
        return f"{self.__class__.__name__}(target_fn={self.target_fn}, grad_fn={self.grad_fn})"

    def forward(self, input: Tensor) -> Tensor:
        out = self.target_fn(input)
        if not isinstance(out, Tensor):
            raise NotImplementedError(
                "Only Tensor is supported at the moment. Feel free to contribute to https://github.com/kornia/kornia."
            )
        output = STEFunction.apply(input, out, self.grad_fn)
        return output
