# LICENSE HEADER MANAGED BY add-license-header
#
# Copyright 2018 Kornia Team
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

import math
from typing import Dict, List, Tuple

import torch

from kornia.core import Device, Tensor, tensor
from kornia.geometry.camera import PinholeCamera
from kornia.nerf.camera_utils import cameras_for_ids
from kornia.utils._compat import torch_meshgrid
from kornia.utils.helpers import _torch_inverse_cast


class RaySampler:
    r"""Class to manage spatial ray sampling.

    Ray parameters (origins, directions, camera ids and pixel coordinates) are only available after
    ``_calc_ray_params`` has been run; until then ``len(sampler)`` is 0.

    Args:
        min_depth: sampled rays minimal depth from cameras: float
        max_depth: sampled rays maximal depth from cameras: float
        ndc: convert ray parameters to normalized device coordinates: bool
        device: device for ray tensors: Union[str, torch.device]
        dtype: data type used for the tensors created by the sampler: torch.dtype

    """

    _origins: Tensor  # Ray origins in world coordinates (*, 3)
    _directions: Tensor  # Ray directions in world coordinates (*, 3)
    _directions_cam: Tensor  # Ray directions in camera coordinates (*, 3)
    _origins_cam: Tensor  # Ray origins in camera coordinates (*, 3)
    _camera_ids: Tensor  # Ray camera ID
    _points_2d: Tensor  # Ray intersection with image plane in camera coordinates

    def __init__(self, min_depth: float, max_depth: float, ndc: bool, device: Device, dtype: torch.dtype) -> None:
        self._min_depth = min_depth
        self._max_depth = max_depth
        self._ndc = ndc
        self._device = device
        self._dtype = dtype

    @property
    def origins(self) -> Tensor:
        """Ray origins, as stored by ``_calc_ray_params`` (NDC origins when ``ndc`` is set)."""
        return self._origins

    @property
    def directions(self) -> Tensor:
        """Ray directions, as stored by ``_calc_ray_params`` (NDC directions when ``ndc`` is set)."""
        return self._directions

    @property
    def camera_ids(self) -> Tensor:
        """Camera id of each ray."""
        return self._camera_ids

    @property
    def points_2d(self) -> Tensor:
        """Integer pixel coordinates of each ray."""
        return self._points_2d

    def __len__(self) -> int:
        """Return the number of rays currently held, or 0 when no rays have been computed yet."""
        # `_origins` is only annotated on the class and assigned by `_calc_ray_params`, so it has to be looked
        # up defensively: a plain attribute access raises AttributeError on a freshly constructed sampler.
        origins = getattr(self, "_origins", None)
        if origins is None:
            return 0
        return origins.shape[0]

    def _calc_ray_directions_cam(self, cameras: PinholeCamera, points_2d: Tensor) -> Tensor:
        r"""Calculate ray directions in camera coordinates from pixel coordinates.

        Args:
            cameras: cameras providing the intrinsics ``fx``, ``fy``, ``cx`` and ``cy``.
            points_2d: pixel coordinates, x in ``[..., 0]`` and y in ``[..., 1]``. The intrinsics are broadcast
              against it with a trailing singleton axis (presumably one intrinsic value per camera and a
              :math:`(num\_cams, num\_points, 2)` point tensor -- confirm against callers).

        Returns:
            directions ``((x - cx) / fx, (y - cy) / fy, 1)`` flattened to :math:`(-1, 3)`.

        """
        # FIXME: This function should call perspective.unproject_points or, implement in PinholeCamera unproject to
        # camera coordinates that will call perspective.unproject_points
        fx = cameras.fx
        fy = cameras.fy
        cx = cameras.cx
        cy = cameras.cy
        directions_x = (points_2d[..., 0] - cx[..., None]) / fx[..., None]
        directions_y = (points_2d[..., 1] - cy[..., None]) / fy[..., None]
        directions_z = torch.ones_like(directions_x)
        directions_cam = torch.stack([directions_x, directions_y, directions_z], dim=-1)
        return directions_cam.reshape(-1, 3)

    class Points2D:
        r"""A class to hold ray 2d pixel coordinates and a camera id for each.

        Args:
            points_2d: tensor with ray pixel coordinates (the coordinates in the image plane that correspond to the
              ray), grouped per camera: :math:`(num\_cams, num\_points, 2)`
            camera_ids: list with the id of the camera of each row of ``points_2d``: List[int]

        """

        def __init__(self, points_2d: Tensor, camera_ids: List[int]) -> None:
            self._points_2d = points_2d  # (*, N, 2)
            self._camera_ids = camera_ids

        @property
        def points_2d(self) -> Tensor:
            return self._points_2d

        @property
        def camera_ids(self) -> List[int]:
            return self._camera_ids

    def _calc_ray_params(self, cameras: PinholeCamera, points_2d_camera: Dict[int, Points2D]) -> None:
        r"""Calculate ray parameters: origins, directions.

        Also stored are camera ids for each ray, and its pixel coordinates. Results are written to the
        ``_origins``, ``_directions``, ``_directions_cam``, ``_origins_cam``, ``_camera_ids`` and ``_points_2d``
        attributes; nothing is returned.

        Args:
            cameras: scene cameras: PinholeCamera
            points_2d_camera: a dictionary that groups Point2D objects by total number of casted rays

        """
        # Unproject 2d points in image plane to 3d world for two depths
        origins = []
        directions = []
        directions_cam = []
        origins_cam = []
        camera_ids = []
        points_2d = []
        for obj in points_2d_camera.values():
            # FIXME: Below both world and camera ray directions are calculated. It could be that world ray directions
            # will not be necessary and can be removed here
            group_points_2d = obj.points_2d
            num_cams_group, num_points_per_cam_group = group_points_2d.shape[:2]
            # First half of the point axis is unprojected at `min_depth`, second half at `max_depth`, so that a
            # single `unproject` call yields both end points of every ray.
            depths = (
                torch.ones(num_cams_group, 2 * num_points_per_cam_group, 3, device=self._device, dtype=self._dtype)
                * self._min_depth
            )
            depths[:, num_points_per_cam_group:] = self._max_depth
            cams = cameras_for_ids(cameras, obj.camera_ids)
            points_3d = cams.unproject(group_points_2d.repeat(1, 2, 1), depths)
            near_points = points_3d[..., :num_points_per_cam_group, :]
            far_points = points_3d[..., num_points_per_cam_group:, :]
            origins.append(near_points.reshape(-1, 3))
            # Directions are left unnormalized: they span the segment from the near to the far point.
            directions.append((far_points - near_points).reshape(-1, 3))
            group_directions_cam = self._calc_ray_directions_cam(cams, group_points_2d)
            directions_cam.append(group_directions_cam)
            origins_cam.append(group_directions_cam * self._min_depth)
            # Repeat every camera id once per point of that camera: [id0, id0, ..., id1, id1, ...]
            camera_ids.append(
                tensor(obj.camera_ids).repeat(num_points_per_cam_group, 1).permute(1, 0).reshape(1, -1).squeeze(0)
            )
            points_2d.append(group_points_2d.reshape(-1, 2).int())
        self._origins = torch.cat(origins)
        self._directions = torch.cat(directions)
        self._directions_cam = torch.cat(directions_cam)
        self._origins_cam = torch.cat(origins_cam)
        self._camera_ids = torch.cat(camera_ids)
        if self._ndc:  # Transform ray parameters to NDC, if defined
            self._origins, self._directions = self.transform_ray_params_world_to_ndc(cameras)
        self._points_2d = torch.cat(points_2d)

    def transform_ray_params_world_to_ndc(self, cameras: PinholeCamera) -> Tuple[Tensor, Tensor]:
        r"""Transform ray parameters to normalized coordinate device (camera) system (NDC).

        Reads ``_origins``, ``_directions_cam`` and ``_camera_ids``, so ray parameters must have been computed
        before this is called.

        Args:
            cameras: scene cameras: PinholeCamera

        Returns:
            tuple ``(origins_ndc, directions_ndc)`` with one row of three components per ray.

        """
        cams = cameras_for_ids(cameras, self._camera_ids)
        fx = cams.fx
        fy = cams.fy
        widths = cams.width
        heights = cams.height
        fx_widths = 2.0 * fx / (widths - 1.0)
        fy_heights = 2.0 * fy / (heights - 1.0)

        # NOTE: the world-frame origins are used here; the camera-frame `_origins_cam` is stored by
        # `_calc_ray_params` but is not consumed by this transform.
        origins_z = self._origins[..., 2]
        oxoz = self._origins[..., 0] / origins_z
        oyoz = self._origins[..., 1] / origins_z

        origins_ndc_x = fx_widths * oxoz
        origins_ndc_y = fy_heights * oyoz
        origins_ndc_z = 1 - 2 * self._min_depth / origins_z
        origins_ndc = torch.stack([origins_ndc_x, origins_ndc_y, origins_ndc_z], dim=-1)

        # Camera-frame directions are multiplied by the inverse of the per-ray camera rotation matrix before
        # the perspective division.
        Rt_inv = _torch_inverse_cast(cams.rotation_matrix)
        directions_rotated_world = (Rt_inv @ self._directions_cam[..., None]).squeeze(dim=-1)

        dxdz = directions_rotated_world[..., 0] / directions_rotated_world[..., 2]
        dydz = directions_rotated_world[..., 1] / directions_rotated_world[..., 2]

        directions_ndc_x = fx_widths * dxdz - origins_ndc_x
        directions_ndc_y = fy_heights * dydz - origins_ndc_y
        directions_ndc_z = 1 - origins_ndc_z
        directions_ndc = torch.stack([directions_ndc_x, directions_ndc_y, directions_ndc_z], dim=-1)

        # No further rotation is applied to the NDC parameters; they are returned as computed.
        return origins_ndc, directions_ndc

    class Points2D_FlatTensors:
        r"""Class to hold x/y pixel coordinates for each ray, and its scene camera id."""

        def __init__(self) -> None:
            # `_x` / `_y` are declared only; they are assigned when the first camera of a group is added.
            self._x: Tensor
            self._y: Tensor
            self._camera_ids: List[int] = []

    @staticmethod
    def _add_points2d_as_flat_tensors_to_num_ray_dict(
        n: int, x: Tensor, y: Tensor, camera_id: int, points2d_as_flat_tensors: Dict[int, Points2D_FlatTensors]
    ) -> None:
        r"""Add the x/y pixel coordinates of all rays casted by a scene camera to a per-ray-count dictionary.

        The dictionary is updated in place: coordinates are flattened and appended to the entry stored under
        ``n`` (created on first use), and ``camera_id`` is appended to that entry's camera id list.

        Args:
            n: dictionary key, the number of rays casted by the camera.
            x: x pixel coordinates of the rays.
            y: y pixel coordinates of the rays.
            camera_id: id of the scene camera that casts the rays.
            points2d_as_flat_tensors: dictionary of pixel coordinates grouped by total number of rays.

        """
        group = points2d_as_flat_tensors.get(n)
        if group is None:
            group = RaySampler.Points2D_FlatTensors()
            group._x = x.flatten()
            group._y = y.flatten()
            points2d_as_flat_tensors[n] = group
        else:
            group._x = torch.cat((group._x, x.flatten()))
            group._y = torch.cat((group._y, y.flatten()))
        group._camera_ids.append(camera_id)

    @staticmethod
    def _build_num_ray_dict_of_points2d(
        points2d_as_flat_tensors: Dict[int, Points2D_FlatTensors],
    ) -> Dict[int, Points2D]:
        r"""Build a dictionary of ray pixel points, by total number of rays as key.

        The dictionary groups rays by the total amount of rays, which allows the case of casting different number
        of rays from each scene camera.

        Args:
            points2d_as_flat_tensors: dictionary of pixel coordinates grouped by total number of rays:
              Dict[int, Points2D_FlatTensors]

        Returns:
            dictionary of Points2D objects that holds information on pixel 2d coordinates of each ray and the camera
              id it was casted by: Dict[int, Points2D]

        """
        num_ray_dict_of_points2d: Dict[int, RaySampler.Points2D] = {}
        for n, points2d_as_flat_tensor in points2d_as_flat_tensors.items():
            num_cams = len(points2d_as_flat_tensor._camera_ids)
            # Pair up (x, y) per ray, then split the flat ray axis evenly between the cameras of the group.
            points_2d = (
                torch.stack((points2d_as_flat_tensor._x, points2d_as_flat_tensor._y))
                .permute(1, 0)
                .reshape(num_cams, -1, 2)
            )
            num_ray_dict_of_points2d[n] = RaySampler.Points2D(points_2d, points2d_as_flat_tensor._camera_ids)
        return num_ray_dict_of_points2d


class RandomRaySampler(RaySampler):
    r"""Ray sampler that casts rays through randomly chosen pixels of every scene camera.

    Args:
        min_depth: sampled rays minimal depth from cameras: float
        max_depth: sampled rays maximal depth from cameras: float
        ndc: convert to normalized device coordinates: bool
        device: device for ray tensors: Union[str, torch.device]
        dtype: data type of the tensors created by the sampler: torch.dtype

    """

    def __init__(self, min_depth: float, max_depth: float, ndc: bool, device: Device, dtype: torch.dtype) -> None:
        super().__init__(min_depth, max_depth, ndc, device, dtype)

    def sample_points_2d(self, heights: Tensor, widths: Tensor, num_img_rays: Tensor) -> Dict[int, RaySampler.Points2D]:
        r"""Draw random pixel coordinates for every scene camera.

        Args:
            heights: tensor that holds scene camera image heights (can vary between cameras): math: `(B)`.
            widths: tensor that holds scene camera image widths (can vary between cameras): math: `(B)`.
            num_img_rays: tensor that holds the number of rays to randomly cast from each scene camera: math: `(B)`.

        Returns:
            dictionary of Points2D objects that holds information on pixel 2d coordinates of each ray and the camera
              id it was casted by: Dict[int, Points2D]

        """
        num_img_rays = num_img_rays.int()
        flat_points_by_count: Dict[int, RaySampler.Points2D_FlatTensors] = {}
        per_camera = zip(heights.tolist(), widths.tolist(), num_img_rays.tolist())
        for camera_id, (img_height, img_width, num_rays) in enumerate(per_camera):
            # Rows are drawn before columns; uniform samples in [0, 1) are scaled to the image extent and
            # truncated to whole pixel indices.
            row_fractions = torch.rand(num_rays, device=self._device, dtype=self._dtype)
            rows = torch.trunc(row_fractions * img_height)
            col_fractions = torch.rand(num_rays, device=self._device, dtype=self._dtype)
            cols = torch.trunc(col_fractions * img_width)
            RaySampler._add_points2d_as_flat_tensors_to_num_ray_dict(
                num_rays, cols, rows, camera_id, flat_points_by_count
            )
        return RaySampler._build_num_ray_dict_of_points2d(flat_points_by_count)

    def calc_ray_params(self, cameras: PinholeCamera, num_img_rays: Tensor) -> None:
        r"""Sample random pixels for every camera and compute the matching ray parameters.

        Origins, directions, per-ray camera ids and pixel coordinates are stored on the sampler.

        Args:
            cameras: scene cameras: PinholeCamera
            num_img_rays: tensor that holds the number of rays to randomly cast from each scene camera: int math: `(B)`.

        Raises:
            ValueError: if the number of cameras differs from the length of ``num_img_rays``.

        """
        num_cams = cameras.batch_size
        num_ray_entries = num_img_rays.shape[0]
        if num_cams == num_ray_entries:
            sampled_points = self.sample_points_2d(cameras.height, cameras.width, num_img_rays)
            self._calc_ray_params(cameras, sampled_points)
            return
        raise ValueError(
            f"Number of cameras {num_cams} does not match size of tensor to define number of rays to march from "
            f"each camera {num_img_rays.shape[0]}"
        )


class RandomGridRaySampler(RandomRaySampler):
    r"""Class to manage random ray spatial sampling.

    Sampling is done on a regular grid of pixels by randomizing
    column and row values, and casting rays for all pixels along the selected ones.

    Args:
        min_depth: sampled rays minimal depth from cameras: float
        max_depth: sampled rays maximal depth from cameras: float
        ndc: convert to normalized device coordinates: bool
        device: device for ray tensors: Union[str, torch.device]
        dtype: data type of the tensors created by the sampler: torch.dtype

    """

    def __init__(self, min_depth: float, max_depth: float, ndc: bool, device: Device, dtype: torch.dtype) -> None:
        super().__init__(min_depth, max_depth, ndc, device, dtype)

    def sample_points_2d(self, heights: Tensor, widths: Tensor, num_img_rays: Tensor) -> Dict[int, RaySampler.Points2D]:
        r"""Randomly sample pixel points in 2d over a regular row-column grid.

        For every camera, ``int(sqrt(n))`` distinct rows and columns are drawn (capped by the image height and
        width respectively) and rays are cast through every (row, column) combination.

        Args:
            heights: tensor that holds scene camera image heights (can vary between cameras): math: `(B)`.
            widths: tensor that holds scene camera image widths (can vary between cameras): math: `(B)`.
            num_img_rays: tensor that holds the number of rays to randomly cast from each scene camera. Number of rows
              and columns is the square root of this value: int math: `(B)`.

        Returns:
            dictionary of Points2D objects that holds information on pixel 2d coordinates of each ray and the camera
              id it was casted by: Dict[int, Points2D]. Keys are the number of rays actually sampled per camera.

        """
        num_img_rays = num_img_rays.int()
        points2d_as_flat_tensors: Dict[int, RaySampler.Points2D_FlatTensors] = {}
        for camera_id, (height, width, n) in enumerate(zip(heights.tolist(), widths.tolist(), num_img_rays.tolist())):
            n_sqrt = int(math.sqrt(n))
            # Rows are drawn before columns to keep the random number consumption order stable.
            y_rand = torch.randperm(int(height), device=self._device, dtype=self._dtype)[: min(int(height), n_sqrt)]
            x_rand = torch.randperm(int(width), device=self._device, dtype=self._dtype)[: min(int(width), n_sqrt)]
            y_grid, x_grid = torch_meshgrid([y_rand, x_rand], indexing="ij")
            # Group by the number of rays that were really produced. When the image has fewer rows or columns
            # than `n_sqrt` the grid is smaller than `n_sqrt * n_sqrt`; keying by the requested size would put
            # cameras with different ray counts in one group and corrupt the per-camera reshape done later.
            num_sampled_rays = x_grid.numel()
            RaySampler._add_points2d_as_flat_tensors_to_num_ray_dict(
                num_sampled_rays, x_grid, y_grid, camera_id, points2d_as_flat_tensors
            )
        return RaySampler._build_num_ray_dict_of_points2d(points2d_as_flat_tensors)


class UniformRaySampler(RaySampler):
    r"""Class to manage uniform ray spatial sampling for all camera scene pixels.

    Args:
        min_depth: sampled rays minimal depth from cameras: float
        max_depth: sampled rays maximal depth from cameras: float
        ndc: convert to normalized device coordinates: bool
        device: device for ray tensors: Union[str, torch.device]
        dtype: data type of the tensors created by the sampler: torch.dtype

    """

    def __init__(self, min_depth: float, max_depth: float, ndc: bool, device: Device, dtype: torch.dtype) -> None:
        super().__init__(min_depth, max_depth, ndc, device, dtype)

    def sample_points_2d(
        self, heights: Tensor, widths: Tensor, sampling_step: int = 1
    ) -> Dict[int, RaySampler.Points2D]:
        r"""Uniformly sample pixel points in 2d for all scene camera pixels.

        Args:
            heights: tensor that holds scene camera image heights (can vary between cameras): math: `(B)`.
            widths: tensor that holds scene camera image widths (can vary between cameras): math: `(B)`.
            sampling_step: defines uniform strides between rows and columns: int.

        Returns:
            dictionary of Points2D objects that holds information on pixel 2d coordinates of each ray and the camera
              id it was casted by: Dict[int, Points2D]. Keys are the number of rays sampled per camera, which
              equals ``height * width`` for the default ``sampling_step`` of 1.

        """
        heights = heights.int()
        widths = widths.int()
        points2d_as_flat_tensors: Dict[int, RaySampler.Points2D_FlatTensors] = {}
        for camera_id, (height, width) in enumerate(zip(heights.tolist(), widths.tolist())):
            y_grid, x_grid = torch_meshgrid(
                [
                    torch.arange(0, height, sampling_step, device=self._device, dtype=self._dtype),
                    torch.arange(0, width, sampling_step, device=self._device, dtype=self._dtype),
                ],
                indexing="ij",
            )
            # Group by the size of the strided grid rather than by `height * width`: with a step larger than 1,
            # cameras that share the same pixel count can still yield different numbers of rays, and mixing
            # them in one group would corrupt the per-camera reshape done later.
            n = x_grid.numel()
            RaySampler._add_points2d_as_flat_tensors_to_num_ray_dict(
                n, x_grid, y_grid, camera_id, points2d_as_flat_tensors
            )
        return RaySampler._build_num_ray_dict_of_points2d(points2d_as_flat_tensors)

    def calc_ray_params(self, cameras: PinholeCamera) -> None:
        r"""Calculate ray parameters for every pixel of every scene camera.

        Origins, directions, per-ray camera ids and pixel coordinates are stored on the sampler.

        Args:
            cameras: scene cameras: PinholeCamera

        """
        points_2d_camera = self.sample_points_2d(cameras.height, cameras.width)
        self._calc_ray_params(cameras, points_2d_camera)


def sample_lengths(
    num_rays: int, num_ray_points: int, device: Device, dtype: torch.dtype, irregular: bool = False
) -> Tensor:
    """Sample points along the length of rays.

    Args:
        num_rays: number of rays to sample lengths for.
        num_ray_points: number of samples along each ray; must be greater than 1.
        device: device of the returned tensor.
        dtype: data type of the returned tensor.
        irregular: if False, every ray gets the same evenly spaced values from 0 to 1 (both included). If True,
          the unit interval is split in ``num_ray_points`` equal bins and one uniformly random value is drawn
          inside each bin, independently for every ray.

    Returns:
        tensor of shape ``(num_rays, num_ray_points)`` with values in the unit interval.

    Raises:
        ValueError: if ``num_ray_points`` is not greater than 1.

    """
    if num_ray_points <= 1:
        raise ValueError("Number of ray points must be greater than 1")
    if not irregular:
        zero_to_one = torch.linspace(0.0, 1.0, num_ray_points, device=device, dtype=dtype)
        return zero_to_one.repeat(num_rays, 1)  # FIXME: Expand instead of repeat maybe?
    # Left edge of every bin; the right end point (1.0) is dropped.
    bin_starts = torch.linspace(0.0, 1.0, num_ray_points + 1, device=device, dtype=dtype)[:-1]
    # The jitter must be drawn in the requested dtype, otherwise type promotion makes the result follow the
    # default dtype instead of `dtype`.
    jitter = torch.rand(num_rays, num_ray_points, device=device, dtype=dtype) / num_ray_points
    return jitter + bin_starts


# TODO: Implement hierarchical ray sampling as described in Mildenhall (2020) Sec. 5.2


def sample_ray_points(origins: Tensor, directions: Tensor, lengths: Tensor) -> Tensor:
    r"""Compute 3d points along rays at the given distances.

    Every point is ``origin + length * direction`` of its ray.

    Args:
        origins: tensor containing ray origins in 3d world coordinates. Tensor shape :math:`(*, 3)`.
        directions: tensor containing ray directions in 3d world coordinates. Tensor shape :math:`(*, 3)`.
        lengths: tensor containing sampled distances along each ray. Tensor shape :math:`(*, num_ray_points)`.

    Returns:
        points_3d: Points along rays :math:`(*, num_ray_points, 3)`

    """
    # FIXME: Test by projecting to points_2d and compare with sampler 2d points
    # Insert a per-ray sample axis so that origins/directions broadcast against the sampled lengths.
    ray_starts = origins.unsqueeze(-2)
    ray_dirs = directions.unsqueeze(-2)
    step_sizes = lengths.unsqueeze(-1)
    offsets = step_sizes * ray_dirs
    return ray_starts + offsets


def calc_ray_t_vals(points_3d: Tensor) -> Tensor:
    r"""Compute, for every ray point, its Euclidean distance to the first point of the same ray.

    Args:
        points_3d: Points along rays :math:`(*, num_ray_points, 3)`

    Returns:
        t values along rays :math:`(*, num_ray_points)`

    """
    # First sample of every ray, with the sample axis restored so that it broadcasts against all samples.
    ray_start = torch.select(points_3d, -2, 0).unsqueeze(-2)
    offsets_from_start = points_3d - ray_start
    return torch.linalg.norm(offsets_from_start, dim=-1)
