# LICENSE HEADER MANAGED BY add-license-header
#
# Copyright 2018 Kornia Team
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

"""Module containing RANSAC modules."""

import math
from functools import partial
from typing import Callable, Optional, Tuple

import torch

from kornia.core import Device, Module, Tensor, zeros
from kornia.core.check import KORNIA_CHECK_SHAPE
from kornia.geometry import (
    find_fundamental,
    find_homography_dlt,
    find_homography_dlt_iterated,
    find_homography_lines_dlt,
    find_homography_lines_dlt_iterated,
    symmetrical_epipolar_distance,
)
from kornia.geometry.homography import (
    line_segment_transfer_error_one_way,
    oneway_transfer_error,
    sample_is_valid_for_homography,
)

__all__ = ["RANSAC"]


class RANSAC(Module):
    """Module for robust geometry estimation with RANSAC. https://en.wikipedia.org/wiki/Random_sample_consensus.

    Args:
        model_type: type of model to estimate: "homography", "fundamental", "fundamental_7pt",
            "homography_from_linesegments".
        inl_th: threshold for the correspondence to be an inlier.
        batch_size: number of generated samples at once.
        max_iter: maximum batches to generate. Actual number of models to try is ``batch_size * max_iter``.
        confidence: desired confidence of the result, used for the early stopping.
        max_lo_iters: number of local optimization (polishing) iterations.

    """

    def __init__(
        self,
        model_type: str = "homography",
        inl_th: float = 2.0,
        batch_size: int = 2048,
        max_iter: int = 10,
        confidence: float = 0.99,
        max_lo_iters: int = 5,
    ) -> None:
        """Initialize the RANSAC estimator.

        Args:
            model_type: type of model to estimate: "homography", "fundamental", "fundamental_7pt",
                "homography_from_linesegments".
            inl_th: threshold for the correspondence to be an inlier.
            batch_size: number of generated samples at once.
            max_iter: maximum batches to generate. Actual number of models to try is ``batch_size * max_iter``.
            confidence: desired confidence of the result, used for the early stopping.
            max_lo_iters: number of local optimization (polishing) iterations.

        Raises:
            NotImplementedError: if ``model_type`` is not one of the supported models.

        """
        super().__init__()
        self.supported_models = ["homography", "fundamental", "fundamental_7pt", "homography_from_linesegments"]
        self.inl_th = inl_th
        self.max_iter = max_iter
        self.batch_size = batch_size
        self.model_type = model_type
        self.confidence = confidence
        self.max_lo_iters = max_lo_iters

        self.error_fn: Callable[..., Tensor]
        self.minimal_solver: Callable[..., Tensor]
        self.polisher_solver: Callable[..., Tensor]

        # Each model type comes with a residual function, a minimal solver used on
        # the sampled subsets, and a (possibly iterative) solver used for polishing.
        if model_type == "homography":
            self.error_fn = oneway_transfer_error
            self.minimal_solver = find_homography_dlt
            self.polisher_solver = find_homography_dlt_iterated
            self.minimal_sample_size = 4
        elif model_type == "homography_from_linesegments":
            self.error_fn = line_segment_transfer_error_one_way
            self.minimal_solver = find_homography_lines_dlt
            self.polisher_solver = find_homography_lines_dlt_iterated
            self.minimal_sample_size = 4
        elif model_type == "fundamental":
            self.error_fn = symmetrical_epipolar_distance
            self.minimal_solver = find_fundamental
            self.minimal_sample_size = 8
            self.polisher_solver = find_fundamental
        elif model_type == "fundamental_7pt":
            self.error_fn = symmetrical_epipolar_distance
            self.minimal_solver = partial(find_fundamental, method="7POINT")
            self.minimal_sample_size = 7
            self.polisher_solver = find_fundamental
        else:
            raise NotImplementedError(f"{model_type} is unknown. Try one of {self.supported_models}")

    def sample(self, sample_size: int, pop_size: int, batch_size: int, device: Optional[Device] = None) -> Tensor:
        """Minimal sampler, but unlike traditional RANSAC we sample in batches.

        Yields the benefit of the parallel processing, esp. on GPU.

        Args:
            sample_size: number of samples to draw from the population.
            pop_size: size of the population to sample from.
            batch_size: number of sample sets to generate.
            device: device to place the samples on.

        Returns:
            Tensor of sampled indices with shape :math:`(batch_size, sample_size)`.

        """
        if device is None:
            device = torch.device("cpu")
        # topk over uniform noise yields `sample_size` distinct indices per row
        # (sampling without replacement), vectorized over the whole batch.
        rand = torch.rand(batch_size, pop_size, device=device)
        _, out = rand.topk(k=sample_size, dim=1)
        return out

    @staticmethod
    def max_samples_by_conf(n_inl: int, num_tc: int, sample_size: int, conf: float) -> float:
        """Update max_iter to stop iterations earlier https://en.wikipedia.org/wiki/Random_sample_consensus.

        Args:
            n_inl: number of inliers.
            num_tc: total number of correspondences.
            sample_size: size of minimal sample.
            conf: desired confidence level.

        Returns:
            Maximum number of samples needed to achieve the desired confidence.

        """
        eps = 1e-9
        # Degenerate populations: a single sample is already conclusive.
        if num_tc <= sample_size:
            return 1.0
        if n_inl == num_tc:
            return 1.0
        # Standard RANSAC stopping criterion: log(1 - conf) / log(1 - w^s),
        # guarded by eps against log(0) and division by zero.
        return math.log(1.0 - conf) / min(-eps, math.log(max(eps, 1.0 - math.pow(n_inl / num_tc, sample_size))))

    def estimate_model_from_minsample(self, kp1: Tensor, kp2: Tensor) -> Tensor:
        """Estimate models from minimal samples.

        Args:
            kp1: source keypoints with shape :math:`(batch_size, sample_size, 2)`.
            kp2: target keypoints with shape :math:`(batch_size, sample_size, 2)`.

        Returns:
            Estimated models tensor.

        """
        batch_size, sample_size = kp1.shape[:2]
        # Uniform weights: every correspondence in a minimal sample counts equally.
        H = self.minimal_solver(kp1, kp2, torch.ones(batch_size, sample_size, dtype=kp1.dtype, device=kp1.device))
        return H

    def verify(self, kp1: Tensor, kp2: Tensor, models: Tensor, inl_th: float) -> Tuple[Tensor, Tensor, float]:
        """Verify models by computing inliers and selecting the best model.

        Args:
            kp1: source keypoints.
            kp2: target keypoints.
            models: candidate models to verify.
            inl_th: inlier threshold.

        Returns:
            Tuple containing:
                - Best model
                - Inlier mask for the best model
                - Score of the best model

        """
        if len(kp1.shape) == 2:
            kp1 = kp1[None]
        if len(kp2.shape) == 2:
            kp2 = kp2[None]
        batch_size = models.shape[0]
        # Line segments are (N, 2, 2) per correspondence; points are (N, 2).
        if self.model_type == "homography_from_linesegments":
            errors = self.error_fn(kp1.expand(batch_size, -1, 2, 2), kp2.expand(batch_size, -1, 2, 2), models)
        else:
            errors = self.error_fn(kp1.expand(batch_size, -1, 2), kp2.expand(batch_size, -1, 2), models)
        inl = errors <= inl_th
        # Score = inlier count per model, computed in kp1's dtype/device.
        models_score = inl.to(kp1).sum(dim=1)
        best_model_idx = models_score.argmax()
        best_model_score = models_score[best_model_idx].item()
        model_best = models[best_model_idx].clone()
        inliers_best = inl[best_model_idx]
        return model_best, inliers_best, best_model_score

    def remove_bad_samples(self, kp1: Tensor, kp2: Tensor) -> Tuple[Tensor, Tensor]:
        """Remove degenerate samples based on model-specific constraints.

        Args:
            kp1: source keypoints.
            kp2: target keypoints.

        Returns:
            Tuple of filtered keypoints (kp1, kp2).

        """
        # ToDo: add (model-specific) verification of the samples,
        # E.g. constraints on not to be a degenerate sample
        if self.model_type == "homography":
            mask = sample_is_valid_for_homography(kp1, kp2)
            return kp1[mask], kp2[mask]
        return kp1, kp2

    def remove_bad_models(self, models: Tensor) -> Tensor:
        """Remove degenerate models based on simple heuristics.

        Args:
            models: candidate models to filter.

        Returns:
            Filtered models tensor.

        """
        # ToDo: add more and better degenerate model rejection
        # For now it is simple and hardcoded: reject models whose main diagonal
        # contains a near-zero entry.
        main_diagonal = torch.diagonal(models, dim1=1, dim2=2)
        mask = main_diagonal.abs().min(dim=1)[0] > 1e-4
        return models[mask]

    def polish_model(self, kp1: Tensor, kp2: Tensor, inliers: Tensor) -> Tensor:
        """Polish the model using inliers through local optimization.

        Args:
            kp1: source keypoints.
            kp2: target keypoints.
            inliers: boolean mask indicating inlier correspondences.

        Returns:
            Polished model tensor.

        """
        # TODO: Replace this with MAGSAC++ polisher
        kp1_inl = kp1[inliers][None]
        kp2_inl = kp2[inliers][None]
        num_inl = kp1_inl.size(1)
        model = self.polisher_solver(
            kp1_inl, kp2_inl, torch.ones(1, num_inl, dtype=kp1_inl.dtype, device=kp1_inl.device)
        )
        return model

    def validate_inputs(self, kp1: Tensor, kp2: Tensor, weights: Optional[Tensor] = None) -> None:
        """Validate input tensors for shape and size requirements.

        Args:
            kp1: source keypoints.
            kp2: target keypoints.
            weights: optional correspondence weights (not used currently).

        Raises:
            ValueError: if input shapes are invalid or insufficient correspondences.

        """
        if self.model_type in ["homography", "fundamental"]:
            KORNIA_CHECK_SHAPE(kp1, ["N", "2"])
            KORNIA_CHECK_SHAPE(kp2, ["N", "2"])
            if not (kp1.shape[0] == kp2.shape[0]) or (kp1.shape[0] < self.minimal_sample_size):
                raise ValueError(
                    f"kp1 and kp2 should be equal shape at least [{self.minimal_sample_size}, 2], "
                    f"got {kp1.shape}, {kp2.shape}"
                )
        if self.model_type == "homography_from_linesegments":
            KORNIA_CHECK_SHAPE(kp1, ["N", "2", "2"])
            KORNIA_CHECK_SHAPE(kp2, ["N", "2", "2"])
            if not (kp1.shape[0] == kp2.shape[0]) or (kp1.shape[0] < self.minimal_sample_size):
                raise ValueError(
                    f"kp1 and kp2 should be equal shape at least [{self.minimal_sample_size}, 2, 2], "
                    f"got {kp1.shape}, {kp2.shape}"
                )

    def forward(self, kp1: Tensor, kp2: Tensor, weights: Optional[Tensor] = None) -> Tuple[Tensor, Tensor]:
        r"""Call main forward method to execute the RANSAC algorithm.

        Args:
            kp1: source image keypoints :math:`(N, 2)`.
            kp2: distance image keypoints :math:`(N, 2)`.
            weights: optional correspondences weights. Not used now.

        Returns:
            - Estimated model, shape of :math:`(1, 3, 3)`.
            - The inlier/outlier mask, shape of :math:`(1, N)`, where N is number of input correspondences.

        """
        self.validate_inputs(kp1, kp2, weights)
        best_score_total: float = float(self.minimal_sample_size)
        num_tc: int = len(kp1)
        best_model_total = zeros(3, 3, dtype=kp1.dtype, device=kp1.device)
        # NOTE(review): the initial mask is (N, 1) while the mask produced by
        # `verify` is (N,) — the returned shape differs when no model beats the
        # initial score. Confirm intended shape with callers before changing.
        inliers_best_total: Tensor = zeros(num_tc, 1, device=kp1.device, dtype=torch.bool)
        for i in range(self.max_iter):
            # Sample minimal samples in batch to estimate models
            idxs = self.sample(self.minimal_sample_size, num_tc, self.batch_size, kp1.device)
            kp1_sampled = kp1[idxs]
            kp2_sampled = kp2[idxs]

            kp1_sampled, kp2_sampled = self.remove_bad_samples(kp1_sampled, kp2_sampled)
            if len(kp1_sampled) == 0:
                continue
            # Estimate models
            models = self.estimate_model_from_minsample(kp1_sampled, kp2_sampled)
            models = self.remove_bad_models(models)
            if (models is None) or (len(models) == 0):
                continue
            # Score the models and select the best one
            model, inliers, model_score = self.verify(kp1, kp2, models, self.inl_th)
            # Store far-the-best model and (optionally) do a local optimization
            if model_score > best_score_total:
                # Local optimization: re-fit on inliers and keep improvements only.
                for _ in range(self.max_lo_iters):
                    model_lo = self.polish_model(kp1, kp2, inliers)
                    if (model_lo is None) or (len(model_lo) == 0):
                        continue
                    _, inliers_lo, score_lo = self.verify(kp1, kp2, model_lo, self.inl_th)
                    if score_lo > model_score:
                        model = model_lo.clone()[0]
                        inliers = inliers_lo.clone()
                        model_score = score_lo
                    else:
                        break
                # Now storing the best model
                best_model_total = model.clone()
                inliers_best_total = inliers.clone()
                best_score_total = model_score

                # Early stopping: shrink the iteration budget given the current
                # inlier ratio and the desired confidence.
                new_max_iter = int(
                    self.max_samples_by_conf(int(best_score_total), num_tc, self.minimal_sample_size, self.confidence)
                )
                # Stop estimation, if the model is very good
                if (i + 1) * self.batch_size >= new_max_iter:
                    break
        # local optimization with all inliers for better precision
        return best_model_total, inliers_best_total
