# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Optional, Union

import torch
from torch import Tensor

from torchmetrics.functional.image.utils import _uniform_filter
from torchmetrics.utilities.checks import _check_same_shape


def _rmse_sw_update(
    preds: Tensor,
    target: Tensor,
    window_size: int,
    rmse_val_sum: Optional[Tensor],
    rmse_map: Optional[Tensor],
    total_images: Optional[Tensor],
) -> tuple[Tensor, Tensor, Tensor]:
    """Calculate the sum of RMSE values and RMSE map for the batch of examples and update intermediate states.

    The squared error ``(target - preds) ** 2`` is passed through ``_uniform_filter`` with ``window_size`` and
    square-rooted to obtain a per-pixel RMSE map. A border of ``round(window_size / 2)`` pixels on each side of the
    map is excluded before averaging it into the scalar RMSE value. When ``rmse_val_sum``, ``rmse_map`` or
    ``total_images`` are given, they are updated in place (``+=``) and also returned.

    Args:
        preds: Deformed image, expected in ``BxCxHxW`` layout
        target: Ground truth image with the same shape and data type as ``preds``
        window_size: Sliding window used for rmse calculation
        rmse_val_sum: Running sum over images of the per-image mean windowed RMSE, or ``None`` to start a new sum
        rmse_map: Running sum of RMSE maps over all examples, or ``None`` to start a new sum
        total_images: Running count of processed images, or ``None`` to start a new count

    Return:
        Updated sum of per-image windowed RMSE values
        Updated sum of RMSE maps
        Updated total number of already processed images

    Raises:
        TypeError: If ``preds`` and ``target`` do not have the same data type.
        ValueError: If ``preds`` and ``target`` do not have ``BxCxHxW`` shape.
        ValueError: If ``round(window_size / 2)`` is greater or equal to width or height of the image.

    """
    if preds.dtype != target.dtype:
        raise TypeError(
            f"Expected `preds` and `target` to have the same data type. But got {preds.dtype} and {target.dtype}."
        )
    _check_same_shape(preds, target)
    if len(preds.shape) != 4:
        raise ValueError(f"Expected `preds` and `target` to have BxCxHxW shape. But got {preds.shape}.")

    # Note: ``round`` uses round-half-to-even, so ``window_size == 1`` gives a crop of 0.
    crop_slide = round(window_size / 2)
    if crop_slide >= target.shape[2] or crop_slide >= target.shape[3]:
        raise ValueError(
            f"Parameter `round(window_size / 2)` is expected to be smaller than {min(target.shape[2], target.shape[3])}"
            f" but got {crop_slide}."
        )

    if total_images is not None:
        total_images += target.shape[0]
    else:
        total_images = torch.tensor(target.shape[0], device=target.device)

    error = _uniform_filter((target - preds) ** 2, window_size)
    _rmse_map = torch.sqrt(error)

    # ``x[c:-c]`` collapses to the empty slice ``x[0:0]`` when ``c == 0``, which made the mean below NaN for
    # ``window_size == 1``. Using ``None`` as the stop keeps the full extent in that case and is identical to
    # ``-crop_slide`` for every positive crop.
    # NOTE(review): the check above only requires ``crop_slide < H`` and ``crop_slide < W``; when
    # ``2 * crop_slide >= H`` (or ``W``) the cropped region is still empty and the mean is NaN.
    stop = -crop_slide if crop_slide > 0 else None
    rmse_val = _rmse_map[:, :, crop_slide:stop, crop_slide:stop]

    # Summing over the batch and then averaging over the remaining dims equals the sum of per-image means.
    batch_rmse_val = rmse_val.sum(0).mean()
    if rmse_val_sum is not None:
        rmse_val_sum += batch_rmse_val
    else:
        rmse_val_sum = batch_rmse_val

    batch_rmse_map = _rmse_map.sum(0)
    if rmse_map is not None:
        rmse_map += batch_rmse_map
    else:
        rmse_map = batch_rmse_map

    return rmse_val_sum, rmse_map, total_images


def _rmse_sw_compute(
    rmse_val_sum: Optional[Tensor], rmse_map: Tensor, total_images: Tensor
) -> tuple[Optional[Tensor], Tensor]:
    """Compute RMSE from the aggregated RMSE value. Optionally also computes the mean value for RMSE map.

    Args:
        rmse_val_sum: Sum of RMSE over all examples
        rmse_map: Sum of RMSE map values over all examples
        total_images: Total number of images

    Return:
        RMSE using sliding window
        (Optionally) RMSE map

    """
    rmse = rmse_val_sum / total_images if rmse_val_sum is not None else None
    if rmse_map is not None:
        # prevent overwrite the inputs
        rmse_map = rmse_map / total_images
    return rmse, rmse_map


def root_mean_squared_error_using_sliding_window(
    preds: Tensor, target: Tensor, window_size: int = 8, return_rmse_map: bool = False
) -> Union[Optional[Tensor], tuple[Optional[Tensor], Tensor]]:
    """Compute Root Mean Squared Error (RMSE) using sliding window.

    Args:
        preds: Deformed image, expected in ``BxCxHxW`` layout
        target: Ground truth image with the same shape and data type as ``preds``
        window_size: Sliding window used for rmse calculation
        return_rmse_map: An indication whether the full rmse reduced image should be returned.

    Return:
        RMSE using sliding window
        (Optionally) RMSE map, averaged over the images in the batch

    Example:
        >>> from torch import rand
        >>> from torchmetrics.functional.image import root_mean_squared_error_using_sliding_window
        >>> preds = rand(4, 3, 16, 16)
        >>> target = rand(4, 3, 16, 16)
        >>> root_mean_squared_error_using_sliding_window(preds, target)
        tensor(0.4158)

    Raises:
        ValueError: If ``window_size`` is not a positive integer.
        TypeError: If ``preds`` and ``target`` do not have the same data type.
        ValueError: If ``preds`` and ``target`` do not have ``BxCxHxW`` shape.
        ValueError: If ``round(window_size / 2)`` is greater or equal to width or height of the image.

    """
    # ``bool`` is a subclass of ``int``, so it must be rejected explicitly: ``True`` is not a meaningful window size.
    if isinstance(window_size, bool) or not isinstance(window_size, int) or window_size < 1:
        raise ValueError("Argument `window_size` is expected to be a positive integer.")

    rmse_val_sum, rmse_map, total_images = _rmse_sw_update(
        preds, target, window_size, rmse_val_sum=None, rmse_map=None, total_images=None
    )
    rmse, rmse_map = _rmse_sw_compute(rmse_val_sum, rmse_map, total_images)

    if return_rmse_map:
        return rmse, rmse_map
    return rmse
