# LICENSE HEADER MANAGED BY add-license-header
#
# Copyright 2018 Kornia Team
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

"""Module containing the functionalities for computing the real roots of polynomial equation."""

import math

import torch

from kornia.core import Tensor, cos, ones_like, stack, zeros, zeros_like
from kornia.core.check import KORNIA_CHECK_SHAPE


# Reference : https://github.com/opencv/opencv/blob/4.x/modules/calib3d/src/polynom_solver.cpp
def solve_quadratic(coeffs: Tensor) -> Tensor:
    r"""Solve given quadratic equation.

    The function takes the coefficients of quadratic equation and returns the real roots.

    .. math:: coeffs[0]x^2 + coeffs[1]x + coeffs[2] = 0

    Args:
        coeffs : The coefficients of quadratic equation :`(B, 3)`

    Returns:
        A tensor of shape `(B, 2)` containing the real roots to the quadratic equation.

    Example:
        >>> coeffs = torch.tensor([[1., 4., 4.]])
        >>> roots = solve_quadratic(coeffs)

    .. note::
       In cases where a quadratic polynomial has only one real root, the output will be in the format
       [real_root, 0]. And for the complex roots should be represented as 0. This is done to maintain
       a consistent output shape for all cases.

    """
    KORNIA_CHECK_SHAPE(coeffs, ["B", "3"])

    # Coefficients of quadratic equation
    a = coeffs[:, 0]  # coefficient of x^2
    b = coeffs[:, 1]  # coefficient of x
    c = coeffs[:, 2]  # constant term

    # Calculate discriminant
    delta = b * b - 4 * a * c

    # Create masks for negative and zero discriminant
    mask_negative = delta < 0
    mask_zero = delta == 0

    # Calculate 1/(2*a) for efficient computation
    inv_2a = 0.5 / a

    # Initialize solutions tensor
    solutions = zeros((coeffs.shape[0], 2), device=coeffs.device, dtype=coeffs.dtype)

    # Handle cases with zero discriminant
    if torch.any(mask_zero):
        solutions[mask_zero, 0] = -b[mask_zero] * inv_2a[mask_zero]
        solutions[mask_zero, 1] = solutions[mask_zero, 0]

    # Negative discriminant cases are automatically handled since solutions is initialized with zeros.

    sqrt_delta = torch.sqrt(delta)

    # Handle cases with non-negative discriminant
    mask = torch.bitwise_and(~mask_negative, ~mask_zero)
    if torch.any(mask):
        solutions[mask, 0] = (-b[mask] + sqrt_delta[mask]) * inv_2a[mask]
        solutions[mask, 1] = (-b[mask] - sqrt_delta[mask]) * inv_2a[mask]

    return solutions


def solve_cubic(coeffs: Tensor) -> Tensor:
    r"""Solve given cubic equation.

    The function takes the coefficients of cubic equation and returns
    the real roots.

    .. math:: coeffs[0]x^3 + coeffs[1]x^2 + coeffs[2]x + coeffs[3] = 0

    Args:
        coeffs : The coefficients cubic equation : `(B, 4)`

    Returns:
        A tensor of shape `(B, 3)` containing the real roots to the cubic equation.

    Example:
        >>> coeffs = torch.tensor([[32., 3., -11., -6.]])
        >>> roots = solve_cubic(coeffs)

    .. note::
       In cases where a cubic polynomial has only one or two real roots, the output for the non-real
       roots should be represented as 0. Thus, the output for a single real root should be in the
       format [real_root, 0, 0], and for two real roots, it should be [real_root_1, real_root_2, 0].

    """
    KORNIA_CHECK_SHAPE(coeffs, ["B", "4"])

    _PI = torch.tensor(math.pi, device=coeffs.device, dtype=coeffs.dtype)

    # Coefficients of cubic equation
    a = coeffs[:, 0]  # coefficient of x^3
    b = coeffs[:, 1]  # coefficient of x^2
    c = coeffs[:, 2]  # coefficient of x
    d = coeffs[:, 3]  # constant term

    solutions = zeros((len(coeffs), 3), device=a.device, dtype=a.dtype)

    mask_a_zero = a == 0
    mask_b_zero = b == 0
    mask_c_zero = c == 0

    # Zero order cases are automatically handled since solutions is initialized with zeros.
    # No need for explicit handling of mask_zero_order as solutions already contains zeros by default.

    mask_first_order = mask_a_zero & mask_b_zero & ~mask_c_zero
    mask_second_order = mask_a_zero & ~mask_b_zero & ~mask_c_zero

    if torch.any(mask_second_order):
        solutions[mask_second_order, 0:2] = solve_quadratic(coeffs[mask_second_order, 1:])

    if torch.any(mask_first_order):
        solutions[mask_first_order, 0] = torch.tensor(1.0, device=a.device, dtype=a.dtype)

    # Normalized form x^3 + a2 * x^2 + a1 * x + a0 = 0
    inv_a = 1.0 / a[~mask_a_zero]
    b_a = inv_a * b[~mask_a_zero]
    b_a2 = b_a * b_a

    c_a = inv_a * c[~mask_a_zero]
    d_a = inv_a * d[~mask_a_zero]

    # Solve the cubic equation
    Q = (3 * c_a - b_a2) / 9
    R = (9 * b_a * c_a - 27 * d_a - 2 * b_a * b_a2) / 54
    Q3 = Q * Q * Q
    D = Q3 + R * R
    b_a_3 = (1.0 / 3.0) * b_a

    a_Q_zero = ones_like(a)
    a_R_zero = ones_like(a)
    a_D_zero = ones_like(a)

    a_Q_zero[~mask_a_zero] = Q
    a_R_zero[~mask_a_zero] = R
    a_D_zero[~mask_a_zero] = D

    # Q == 0
    mask_Q_zero = (Q == 0) & (R != 0)
    mask_Q_zero_solutions = (a_Q_zero == 0) & (a_R_zero != 0)

    if torch.any(mask_Q_zero):
        x0_Q_zero = torch.pow(2 * R[mask_Q_zero], 1 / 3) - b_a_3[mask_Q_zero]
        solutions[mask_Q_zero_solutions, 0] = x0_Q_zero

    mask_QR_zero = (Q == 0) & (R == 0)
    mask_QR_zero_solutions = (a_Q_zero == 0) & (a_R_zero == 0)

    if torch.any(mask_QR_zero):
        solutions[mask_QR_zero_solutions] = stack(
            [-b_a_3[mask_QR_zero], -b_a_3[mask_QR_zero], -b_a_3[mask_QR_zero]], dim=1
        )

    # D <= 0
    mask_D_zero = (D <= 0) & (Q != 0)
    mask_D_zero_solutions = (a_D_zero <= 0) & (a_Q_zero != 0)

    if torch.any(mask_D_zero):
        theta_D_zero = torch.acos(R[mask_D_zero] / torch.sqrt(-Q3[mask_D_zero]))
        sqrt_Q_D_zero = torch.sqrt(-Q[mask_D_zero])
        x0_D_zero = 2 * sqrt_Q_D_zero * cos(theta_D_zero / 3.0) - b_a_3[mask_D_zero]
        x1_D_zero = 2 * sqrt_Q_D_zero * cos((theta_D_zero + 2 * _PI) / 3.0) - b_a_3[mask_D_zero]
        x2_D_zero = 2 * sqrt_Q_D_zero * cos((theta_D_zero + 4 * _PI) / 3.0) - b_a_3[mask_D_zero]
        solutions[mask_D_zero_solutions] = stack([x0_D_zero, x1_D_zero, x2_D_zero], dim=1)

    a_D_positive = zeros_like(a)
    a_D_positive[~mask_a_zero] = D
    # D > 0
    mask_D_positive_solution = (a_D_positive > 0) & (a_Q_zero != 0)
    mask_D_positive = (D > 0) & (Q != 0)
    if torch.any(mask_D_positive):
        AD = zeros_like(R)
        BD = zeros_like(R)
        R_abs = torch.abs(R)
        mask_R_positive = R_abs > 1e-16
        if torch.any(mask_R_positive):
            AD[mask_R_positive] = torch.pow(R_abs[mask_R_positive] + torch.sqrt(D[mask_R_positive]), 1 / 3)
            mask_R_positive_ = R < 0

            if torch.any(mask_R_positive_):
                AD[mask_R_positive_] = -AD[mask_R_positive_]

            BD[mask_R_positive] = -Q[mask_R_positive] / AD[mask_R_positive]
        x0_D_positive = AD[mask_D_positive] + BD[mask_D_positive] - b_a_3[mask_D_positive]
        solutions[mask_D_positive_solution, 0] = x0_D_positive

    return solutions


# def solve_quartic(coeffs: Tensor) -> Tensor:
#    TODO: Quartic equation solver
#     return solutions


# Reference
# https://github.com/danini/graph-cut-ransac/blob/master/src/pygcransac/include/
# estimators/solver_essential_matrix_five_point_nister.h#L108


T_deg1 = torch.zeros(16, 10)
T_deg1[0, 0] = 1  # x * x → x^2
T_deg1[1, 1] = 1  # x * y
T_deg1[4, 1] = 1  # y * x
T_deg1[2, 2] = 1  # x * z
T_deg1[8, 2] = 1  # z * x
T_deg1[3, 3] = 1  # x * 1
T_deg1[12, 3] = 1  # 1 * x
T_deg1[5, 4] = 1  # y * y
T_deg1[6, 5] = 1  # y * z
T_deg1[9, 5] = 1  # z * y
T_deg1[7, 6] = 1  # y * 1
T_deg1[13, 6] = 1  # 1 * y
T_deg1[10, 7] = 1  # z * z
T_deg1[11, 8] = 1  # z * 1
T_deg1[14, 8] = 1  # 1 * z
T_deg1[15, 9] = 1  # 1 * 1


def multiply_deg_one_poly(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
    r"""Multiply two polynomials of the first order [@nister2004efficient].

    Args:
        a: a first order polynomial for variables :math:`(x,y,z,1)`.
        b: a first order polynomial for variables :math:`(x,y,z,1)`.

    Returns:
        degree 2 poly with the order :math:`(x^2, x*y, x*z, x, y^2, y*z, y, z^2, z, 1)`.

    """
    global T_deg1  # noqa: PLW0603
    if T_deg1.device != a.device or T_deg1.dtype != a.dtype:
        T_deg1 = T_deg1.to(device=a.device, dtype=a.dtype)
    return (a.unsqueeze(2) * b.unsqueeze(1)).flatten(start_dim=-2) @ T_deg1


# Reference
# https://github.com/danini/graph-cut-ransac/blob/aae1f40c2e10e31fd2191bac601c53a189673f60/src/pygcransac/
# include/estimators/solver_essential_matrix_five_point_nister.h#L156

T_deg2 = torch.zeros(40, 20)
T_deg2[0, 0] = 1  # (0*4+0)
T_deg2[17, 1] = 1  # (4*4+1)
T_deg2[1, 2] = 1  # (0*4+1)
T_deg2[4, 2] = 1  # (1*4+0)
T_deg2[5, 3] = 1  # (1*4+1)
T_deg2[16, 3] = 1  # (4*4+0)
T_deg2[2, 4] = 1  # (0*4+2)
T_deg2[8, 4] = 1  # (2*4+0)
T_deg2[3, 5] = 1  # (0*4+3)
T_deg2[12, 5] = 1  # (3*4+0)
T_deg2[18, 6] = 1  # (4*4+2)
T_deg2[21, 6] = 1  # (5*4+1)
T_deg2[19, 7] = 1  # (4*4+3)
T_deg2[25, 7] = 1  # (6*4+1)
T_deg2[6, 8] = 1  # (1*4+2)
T_deg2[9, 8] = 1  # (2*4+1)
T_deg2[20, 8] = 1  # (5*4+0)
T_deg2[7, 9] = 1  # (1*4+3)
T_deg2[13, 9] = 1  # (3*4+1)
T_deg2[24, 9] = 1  # (6*4+0)
T_deg2[10, 10] = 1  # (2*4+2)
T_deg2[28, 10] = 1  # (7*4+0)
T_deg2[11, 11] = 1  # (2*4+3)
T_deg2[14, 11] = 1  # (3*4+2)
T_deg2[32, 11] = 1  # (8*4+0)
T_deg2[15, 12] = 1  # (3*4+3)
T_deg2[36, 12] = 1  # (9*4+0)
T_deg2[22, 13] = 1  # (5*4+2)
T_deg2[29, 13] = 1  # (7*4+1)
T_deg2[23, 14] = 1  # (5*4+3)
T_deg2[26, 14] = 1  # (6*4+2)
T_deg2[33, 14] = 1  # (8*4+1)
T_deg2[27, 15] = 1  # (6*4+3)
T_deg2[37, 15] = 1  # (9*4+1)
T_deg2[30, 16] = 1  # (7*4+2)
T_deg2[31, 17] = 1  # (7*4+3)
T_deg2[34, 17] = 1  # (8*4+2)
T_deg2[35, 18] = 1  # (8*4+3)
T_deg2[38, 18] = 1  # (9*4+2)
T_deg2[39, 19] = 1  # (9*4+3)


def multiply_deg_two_one_poly(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
    r"""Multiply two polynomials a and b of degrees two and one [@nister2004efficient].

    Args:
        a: a second degree poly for variables :math:`(x^2, x*y, x*z, x, y^2, y*z, y, z^2, z, 1)`.
        b: a first degree poly for variables :math:`(x y z 1)`.

    Returns:
        a third degree poly for variables,
        :math:`(x^3, y^3, x^2*y, x*y^2, x^2*z, x^2, y^2*z, y^2,
        x*y*z, x*y, x*z^2, x*z, x, y*z^2, y*z, y, z^3, z^2, z, 1)`.

    """
    global T_deg2  # noqa: PLW0603
    if T_deg2.device != a.device or T_deg2.dtype != a.dtype:
        T_deg2 = T_deg2.to(device=a.device, dtype=a.dtype)
    product_basis = a.unsqueeze(2) * b.unsqueeze(1)
    product_vector = product_basis.flatten(start_dim=-2)
    return product_vector @ T_deg2


# Compute degree 10 poly representing determinant (equation 14 in the paper)
# https://github.com/danini/graph-cut-ransac/blob/aae1f40c2e10e31fd2191bac601c53a189673f60/src/pygcransac/
# include/estimators/solver_essential_matrix_five_point_nister.h#L368C5-L368C82

multiplication_indices = torch.tensor(
    [
        [12, 16, 33],
        [12, 20, 29],
        [3, 33, 25],
        [7, 29, 25],
        [3, 20, 38],
        [7, 16, 38],
        [11, 16, 33],
        [11, 20, 29],
        [12, 15, 33],
        [12, 16, 32],
        [12, 19, 29],
        [12, 20, 28],
        [2, 33, 25],
        [3, 32, 25],
        [3, 33, 24],
        [6, 29, 25],
        [7, 28, 25],
        [7, 29, 24],
        [2, 20, 38],
        [3, 19, 38],
        [3, 20, 37],
        [6, 16, 38],
        [7, 15, 38],
        [7, 16, 37],
        [10, 16, 33],
        [10, 20, 29],
        [11, 15, 33],
        [11, 16, 32],
        [11, 19, 29],
        [11, 20, 28],
        [14, 12, 33],
        [12, 15, 32],
        [12, 16, 31],
        [12, 18, 29],
        [12, 19, 28],
        [12, 20, 27],
        [1, 33, 25],
        [2, 32, 25],
        [2, 33, 24],
        [3, 31, 25],
        [3, 32, 24],
        [3, 33, 23],
        [5, 29, 25],
        [6, 28, 25],
        [6, 29, 24],
        [7, 27, 25],
        [7, 28, 24],
        [7, 29, 23],
        [1, 20, 38],
        [2, 19, 38],
        [2, 20, 37],
        [3, 18, 38],
        [3, 19, 37],
        [3, 20, 36],
        [5, 16, 38],
        [6, 15, 38],
        [6, 16, 37],
        [7, 14, 38],
        [7, 15, 37],
        [7, 16, 36],
        [3, 20, 35],
        [3, 22, 33],
        [7, 16, 35],
        [7, 22, 29],
        [9, 16, 33],
        [9, 20, 29],
        [10, 15, 33],
        [10, 16, 32],
        [10, 19, 29],
        [10, 20, 28],
        [13, 12, 33],
        [11, 14, 33],
        [11, 15, 32],
        [11, 16, 31],
        [11, 18, 29],
        [11, 19, 28],
        [11, 20, 27],
        [14, 12, 32],
        [12, 15, 31],
        [12, 16, 30],
        [12, 17, 29],
        [12, 18, 28],
        [12, 19, 27],
        [12, 20, 26],
        [0, 33, 25],
        [1, 32, 25],
        [1, 33, 24],
        [2, 31, 25],
        [2, 32, 24],
        [2, 33, 23],
        [3, 30, 25],
        [3, 31, 24],
        [3, 32, 23],
        [4, 29, 25],
        [5, 28, 25],
        [5, 29, 24],
        [6, 27, 25],
        [6, 28, 24],
        [6, 29, 23],
        [7, 26, 25],
        [7, 27, 24],
        [7, 28, 23],
        [0, 20, 38],
        [1, 19, 38],
        [1, 20, 37],
        [2, 18, 38],
        [2, 19, 37],
        [2, 20, 36],
        [3, 17, 38],
        [3, 18, 37],
        [3, 19, 36],
        [4, 16, 38],
        [5, 15, 38],
        [5, 16, 37],
        [6, 14, 38],
        [6, 15, 37],
        [6, 16, 36],
        [7, 13, 38],
        [7, 14, 37],
        [7, 15, 36],
        [2, 20, 35],
        [2, 22, 33],
        [3, 19, 35],
        [3, 20, 34],
        [3, 21, 33],
        [3, 22, 32],
        [6, 16, 35],
        [6, 22, 29],
        [7, 15, 35],
        [7, 16, 34],
        [7, 21, 29],
        [7, 22, 28],
        [8, 16, 33],
        [8, 20, 29],
        [9, 15, 33],
        [9, 16, 32],
        [9, 19, 29],
        [9, 20, 28],
        [10, 14, 33],
        [10, 15, 32],
        [10, 16, 31],
        [10, 18, 29],
        [10, 19, 28],
        [10, 20, 27],
        [13, 11, 33],
        [13, 12, 32],
        [11, 14, 32],
        [11, 15, 31],
        [11, 16, 30],
        [11, 17, 29],
        [11, 18, 28],
        [11, 19, 27],
        [11, 20, 26],
        [14, 12, 31],
        [12, 15, 30],
        [12, 17, 28],
        [12, 18, 27],
        [12, 19, 26],
        [0, 32, 25],
        [0, 33, 24],
        [1, 31, 25],
        [1, 32, 24],
        [1, 33, 23],
        [2, 30, 25],
        [2, 31, 24],
        [2, 32, 23],
        [3, 30, 24],
        [3, 31, 23],
        [4, 28, 25],
        [4, 29, 24],
        [5, 27, 25],
        [5, 28, 24],
        [5, 29, 23],
        [6, 26, 25],
        [6, 27, 24],
        [6, 28, 23],
        [7, 26, 24],
        [7, 27, 23],
        [0, 19, 38],
        [0, 20, 37],
        [1, 18, 38],
        [1, 19, 37],
        [1, 20, 36],
        [2, 17, 38],
        [2, 18, 37],
        [2, 19, 36],
        [3, 17, 37],
        [3, 18, 36],
        [4, 15, 38],
        [4, 16, 37],
        [5, 14, 38],
        [5, 15, 37],
        [5, 16, 36],
        [6, 13, 38],
        [6, 14, 37],
        [6, 15, 36],
        [7, 13, 37],
        [7, 14, 36],
        [1, 20, 35],
        [1, 22, 33],
        [2, 19, 35],
        [2, 20, 34],
        [2, 21, 33],
        [2, 22, 32],
        [3, 18, 35],
        [3, 19, 34],
        [3, 21, 32],
        [3, 22, 31],
        [5, 16, 35],
        [5, 22, 29],
        [6, 15, 35],
        [6, 16, 34],
        [6, 21, 29],
        [6, 22, 28],
        [7, 14, 35],
        [7, 15, 34],
        [7, 21, 28],
        [7, 22, 27],
        [8, 15, 33],
        [8, 16, 32],
        [8, 19, 29],
        [8, 20, 28],
        [9, 14, 33],
        [9, 15, 32],
        [9, 16, 31],
        [9, 18, 29],
        [9, 19, 28],
        [9, 20, 27],
        [10, 13, 33],
        [10, 14, 32],
        [10, 15, 31],
        [10, 16, 30],
        [10, 17, 29],
        [10, 18, 28],
        [10, 19, 27],
        [10, 20, 26],
        [13, 11, 32],
        [13, 12, 31],
        [11, 14, 31],
        [11, 15, 30],
        [11, 17, 28],
        [11, 18, 27],
        [11, 19, 26],
        [14, 12, 30],
        [12, 17, 27],
        [12, 18, 26],
        [0, 31, 25],
        [0, 32, 24],
        [0, 33, 23],
        [1, 30, 25],
        [1, 31, 24],
        [1, 32, 23],
        [2, 30, 24],
        [2, 31, 23],
        [3, 30, 23],
        [4, 27, 25],
        [4, 28, 24],
        [4, 29, 23],
        [5, 26, 25],
        [5, 27, 24],
        [5, 28, 23],
        [6, 26, 24],
        [6, 27, 23],
        [7, 26, 23],
        [0, 18, 38],
        [0, 19, 37],
        [0, 20, 36],
        [1, 17, 38],
        [1, 18, 37],
        [1, 19, 36],
        [2, 17, 37],
        [2, 18, 36],
        [3, 17, 36],
        [4, 14, 38],
        [4, 15, 37],
        [4, 16, 36],
        [5, 13, 38],
        [5, 14, 37],
        [5, 15, 36],
        [6, 13, 37],
        [6, 14, 36],
        [7, 13, 36],
        [0, 20, 35],
        [0, 22, 33],
        [1, 19, 35],
        [1, 20, 34],
        [1, 21, 33],
        [1, 22, 32],
        [2, 18, 35],
        [2, 19, 34],
        [2, 21, 32],
        [2, 22, 31],
        [3, 17, 35],
        [3, 18, 34],
        [3, 21, 31],
        [3, 22, 30],
        [4, 16, 35],
        [4, 22, 29],
        [5, 15, 35],
        [5, 16, 34],
        [5, 21, 29],
        [5, 22, 28],
        [6, 14, 35],
        [6, 15, 34],
        [6, 21, 28],
        [6, 22, 27],
        [7, 13, 35],
        [7, 14, 34],
        [7, 21, 27],
        [7, 22, 26],
        [8, 14, 33],
        [8, 15, 32],
        [8, 16, 31],
        [8, 18, 29],
        [8, 19, 28],
        [8, 20, 27],
        [9, 13, 33],
        [9, 14, 32],
        [9, 15, 31],
        [9, 16, 30],
        [9, 17, 29],
        [9, 18, 28],
        [9, 19, 27],
        [9, 20, 26],
        [10, 13, 32],
        [10, 14, 31],
        [10, 15, 30],
        [10, 17, 28],
        [10, 18, 27],
        [10, 19, 26],
        [13, 11, 31],
        [13, 12, 30],
        [11, 14, 30],
        [11, 17, 27],
        [11, 18, 26],
        [12, 17, 26],
        [0, 30, 25],
        [0, 31, 24],
        [0, 32, 23],
        [1, 30, 24],
        [1, 31, 23],
        [2, 30, 23],
        [4, 26, 25],
        [4, 27, 24],
        [4, 28, 23],
        [5, 26, 24],
        [5, 27, 23],
        [6, 26, 23],
        [0, 17, 38],
        [0, 18, 37],
        [0, 19, 36],
        [1, 17, 37],
        [1, 18, 36],
        [2, 17, 36],
        [4, 13, 38],
        [4, 14, 37],
        [4, 15, 36],
        [5, 13, 37],
        [5, 14, 36],
        [6, 13, 36],
        [0, 19, 35],
        [0, 20, 34],
        [0, 21, 33],
        [0, 22, 32],
        [1, 18, 35],
        [1, 19, 34],
        [1, 21, 32],
        [1, 22, 31],
        [2, 17, 35],
        [2, 18, 34],
        [2, 21, 31],
        [2, 22, 30],
        [3, 17, 34],
        [3, 21, 30],
        [4, 15, 35],
        [4, 16, 34],
        [4, 21, 29],
        [4, 22, 28],
        [5, 14, 35],
        [5, 15, 34],
        [5, 21, 28],
        [5, 22, 27],
        [6, 13, 35],
        [6, 14, 34],
        [6, 21, 27],
        [6, 22, 26],
        [7, 13, 34],
        [7, 21, 26],
        [8, 13, 33],
        [8, 14, 32],
        [8, 15, 31],
        [8, 16, 30],
        [8, 17, 29],
        [8, 18, 28],
        [8, 19, 27],
        [8, 20, 26],
        [9, 13, 32],
        [9, 14, 31],
        [9, 15, 30],
        [9, 17, 28],
        [9, 18, 27],
        [9, 19, 26],
        [10, 13, 31],
        [10, 14, 30],
        [10, 17, 27],
        [10, 18, 26],
        [13, 11, 30],
        [11, 17, 26],
        [0, 30, 24],
        [0, 31, 23],
        [1, 30, 23],
        [4, 26, 24],
        [4, 27, 23],
        [5, 26, 23],
        [0, 17, 37],
        [0, 18, 36],
        [1, 17, 36],
        [4, 13, 37],
        [4, 14, 36],
        [5, 13, 36],
        [0, 18, 35],
        [0, 19, 34],
        [0, 21, 32],
        [0, 22, 31],
        [1, 17, 35],
        [1, 18, 34],
        [1, 21, 31],
        [1, 22, 30],
        [2, 17, 34],
        [2, 21, 30],
        [4, 14, 35],
        [4, 15, 34],
        [4, 21, 28],
        [4, 22, 27],
        [5, 13, 35],
        [5, 14, 34],
        [5, 21, 27],
        [5, 22, 26],
        [6, 13, 34],
        [6, 21, 26],
        [8, 13, 32],
        [8, 14, 31],
        [8, 15, 30],
        [8, 17, 28],
        [8, 18, 27],
        [8, 19, 26],
        [9, 13, 31],
        [9, 14, 30],
        [9, 17, 27],
        [9, 18, 26],
        [10, 13, 30],
        [10, 17, 26],
        [0, 30, 23],
        [4, 26, 23],
        [0, 17, 36],
        [4, 13, 36],
        [0, 17, 35],
        [0, 18, 34],
        [0, 21, 31],
        [0, 22, 30],
        [1, 17, 34],
        [1, 21, 30],
        [4, 13, 35],
        [4, 14, 34],
        [4, 21, 27],
        [4, 22, 26],
        [5, 13, 34],
        [5, 21, 26],
        [8, 13, 31],
        [8, 14, 30],
        [8, 17, 27],
        [8, 18, 26],
        [9, 13, 30],
        [9, 17, 26],
        [0, 17, 34],
        [0, 21, 30],
        [4, 13, 34],
        [4, 21, 26],
        [8, 13, 30],
        [8, 17, 26],
    ],
    dtype=torch.int64,
)


signs = torch.tensor(
    [
        1.0,
        -1.0,
        -1.0,
        1.0,
        1.0,
        -1.0,
        1.0,
        -1.0,
        1.0,
        1.0,
        -1.0,
        -1.0,
        -1.0,
        -1.0,
        -1.0,
        1.0,
        1.0,
        1.0,
        1.0,
        1.0,
        1.0,
        -1.0,
        -1.0,
        -1.0,
        1.0,
        -1.0,
        1.0,
        1.0,
        -1.0,
        -1.0,
        1.0,
        1.0,
        1.0,
        -1.0,
        -1.0,
        -1.0,
        -1.0,
        -1.0,
        -1.0,
        -1.0,
        -1.0,
        -1.0,
        1.0,
        1.0,
        1.0,
        1.0,
        1.0,
        1.0,
        1.0,
        1.0,
        1.0,
        1.0,
        1.0,
        1.0,
        -1.0,
        -1.0,
        -1.0,
        -1.0,
        -1.0,
        -1.0,
        1.0,
        -1.0,
        -1.0,
        1.0,
        1.0,
        -1.0,
        1.0,
        1.0,
        -1.0,
        -1.0,
        1.0,
        1.0,
        1.0,
        1.0,
        -1.0,
        -1.0,
        -1.0,
        1.0,
        1.0,
        1.0,
        -1.0,
        -1.0,
        -1.0,
        -1.0,
        -1.0,
        -1.0,
        -1.0,
        -1.0,
        -1.0,
        -1.0,
        -1.0,
        -1.0,
        -1.0,
        1.0,
        1.0,
        1.0,
        1.0,
        1.0,
        1.0,
        1.0,
        1.0,
        1.0,
        1.0,
        1.0,
        1.0,
        1.0,
        1.0,
        1.0,
        1.0,
        1.0,
        1.0,
        -1.0,
        -1.0,
        -1.0,
        -1.0,
        -1.0,
        -1.0,
        -1.0,
        -1.0,
        -1.0,
        1.0,
        -1.0,
        1.0,
        1.0,
        -1.0,
        -1.0,
        -1.0,
        1.0,
        -1.0,
        -1.0,
        1.0,
        1.0,
        1.0,
        -1.0,
        1.0,
        1.0,
        -1.0,
        -1.0,
        1.0,
        1.0,
        1.0,
        -1.0,
        -1.0,
        -1.0,
        1.0,
        1.0,
        1.0,
        1.0,
        1.0,
        -1.0,
        -1.0,
        -1.0,
        -1.0,
        1.0,
        1.0,
        -1.0,
        -1.0,
        -1.0,
        -1.0,
        -1.0,
        -1.0,
        -1.0,
        -1.0,
        -1.0,
        -1.0,
        -1.0,
        -1.0,
        -1.0,
        1.0,
        1.0,
        1.0,
        1.0,
        1.0,
        1.0,
        1.0,
        1.0,
        1.0,
        1.0,
        1.0,
        1.0,
        1.0,
        1.0,
        1.0,
        1.0,
        1.0,
        1.0,
        1.0,
        1.0,
        -1.0,
        -1.0,
        -1.0,
        -1.0,
        -1.0,
        -1.0,
        -1.0,
        -1.0,
        -1.0,
        -1.0,
        1.0,
        -1.0,
        1.0,
        1.0,
        -1.0,
        -1.0,
        1.0,
        1.0,
        -1.0,
        -1.0,
        -1.0,
        1.0,
        -1.0,
        -1.0,
        1.0,
        1.0,
        -1.0,
        -1.0,
        1.0,
        1.0,
        1.0,
        1.0,
        -1.0,
        -1.0,
        1.0,
        1.0,
        1.0,
        -1.0,
        -1.0,
        -1.0,
        1.0,
        1.0,
        1.0,
        1.0,
        -1.0,
        -1.0,
        -1.0,
        -1.0,
        1.0,
        1.0,
        1.0,
        1.0,
        -1.0,
        -1.0,
        -1.0,
        1.0,
        -1.0,
        -1.0,
        -1.0,
        -1.0,
        -1.0,
        -1.0,
        -1.0,
        -1.0,
        -1.0,
        -1.0,
        -1.0,
        1.0,
        1.0,
        1.0,
        1.0,
        1.0,
        1.0,
        1.0,
        1.0,
        1.0,
        1.0,
        1.0,
        1.0,
        1.0,
        1.0,
        1.0,
        1.0,
        1.0,
        1.0,
        -1.0,
        -1.0,
        -1.0,
        -1.0,
        -1.0,
        -1.0,
        -1.0,
        -1.0,
        -1.0,
        1.0,
        -1.0,
        1.0,
        1.0,
        -1.0,
        -1.0,
        1.0,
        1.0,
        -1.0,
        -1.0,
        1.0,
        1.0,
        -1.0,
        -1.0,
        -1.0,
        1.0,
        -1.0,
        -1.0,
        1.0,
        1.0,
        -1.0,
        -1.0,
        1.0,
        1.0,
        -1.0,
        -1.0,
        1.0,
        1.0,
        1.0,
        1.0,
        1.0,
        -1.0,
        -1.0,
        -1.0,
        1.0,
        1.0,
        1.0,
        1.0,
        -1.0,
        -1.0,
        -1.0,
        -1.0,
        1.0,
        1.0,
        1.0,
        -1.0,
        -1.0,
        -1.0,
        1.0,
        1.0,
        1.0,
        -1.0,
        -1.0,
        -1.0,
        -1.0,
        -1.0,
        -1.0,
        -1.0,
        -1.0,
        -1.0,
        1.0,
        1.0,
        1.0,
        1.0,
        1.0,
        1.0,
        1.0,
        1.0,
        1.0,
        1.0,
        1.0,
        1.0,
        -1.0,
        -1.0,
        -1.0,
        -1.0,
        -1.0,
        -1.0,
        1.0,
        1.0,
        -1.0,
        -1.0,
        1.0,
        1.0,
        -1.0,
        -1.0,
        1.0,
        1.0,
        -1.0,
        -1.0,
        1.0,
        -1.0,
        -1.0,
        -1.0,
        1.0,
        1.0,
        -1.0,
        -1.0,
        1.0,
        1.0,
        -1.0,
        -1.0,
        1.0,
        1.0,
        -1.0,
        1.0,
        1.0,
        1.0,
        1.0,
        1.0,
        -1.0,
        -1.0,
        -1.0,
        -1.0,
        1.0,
        1.0,
        1.0,
        -1.0,
        -1.0,
        -1.0,
        1.0,
        1.0,
        -1.0,
        -1.0,
        1.0,
        -1.0,
        -1.0,
        -1.0,
        -1.0,
        1.0,
        1.0,
        1.0,
        1.0,
        1.0,
        1.0,
        -1.0,
        -1.0,
        -1.0,
        1.0,
        1.0,
        -1.0,
        -1.0,
        1.0,
        1.0,
        -1.0,
        -1.0,
        1.0,
        -1.0,
        -1.0,
        -1.0,
        1.0,
        1.0,
        -1.0,
        -1.0,
        1.0,
        1.0,
        -1.0,
        1.0,
        1.0,
        1.0,
        1.0,
        -1.0,
        -1.0,
        -1.0,
        1.0,
        1.0,
        -1.0,
        -1.0,
        1.0,
        -1.0,
        -1.0,
        1.0,
        1.0,
        -1.0,
        1.0,
        1.0,
        -1.0,
        -1.0,
        1.0,
        -1.0,
        -1.0,
        -1.0,
        1.0,
        1.0,
        -1.0,
        1.0,
        1.0,
        1.0,
        -1.0,
        -1.0,
        1.0,
        -1.0,
        1.0,
        -1.0,
        -1.0,
        1.0,
        1.0,
        -1.0,
    ],
    dtype=torch.float32,
)


coefficient_map = torch.tensor(
    [
        0,
        0,
        0,
        0,
        0,
        0,
        1,
        1,
        1,
        1,
        1,
        1,
        1,
        1,
        1,
        1,
        1,
        1,
        1,
        1,
        1,
        1,
        1,
        1,
        2,
        2,
        2,
        2,
        2,
        2,
        2,
        2,
        2,
        2,
        2,
        2,
        2,
        2,
        2,
        2,
        2,
        2,
        2,
        2,
        2,
        2,
        2,
        2,
        2,
        2,
        2,
        2,
        2,
        2,
        2,
        2,
        2,
        2,
        2,
        2,
        3,
        3,
        3,
        3,
        3,
        3,
        3,
        3,
        3,
        3,
        3,
        3,
        3,
        3,
        3,
        3,
        3,
        3,
        3,
        3,
        3,
        3,
        3,
        3,
        3,
        3,
        3,
        3,
        3,
        3,
        3,
        3,
        3,
        3,
        3,
        3,
        3,
        3,
        3,
        3,
        3,
        3,
        3,
        3,
        3,
        3,
        3,
        3,
        3,
        3,
        3,
        3,
        3,
        3,
        3,
        3,
        3,
        3,
        3,
        3,
        4,
        4,
        4,
        4,
        4,
        4,
        4,
        4,
        4,
        4,
        4,
        4,
        4,
        4,
        4,
        4,
        4,
        4,
        4,
        4,
        4,
        4,
        4,
        4,
        4,
        4,
        4,
        4,
        4,
        4,
        4,
        4,
        4,
        4,
        4,
        4,
        4,
        4,
        4,
        4,
        4,
        4,
        4,
        4,
        4,
        4,
        4,
        4,
        4,
        4,
        4,
        4,
        4,
        4,
        4,
        4,
        4,
        4,
        4,
        4,
        4,
        4,
        4,
        4,
        4,
        4,
        4,
        4,
        4,
        4,
        4,
        4,
        4,
        4,
        4,
        4,
        4,
        4,
        5,
        5,
        5,
        5,
        5,
        5,
        5,
        5,
        5,
        5,
        5,
        5,
        5,
        5,
        5,
        5,
        5,
        5,
        5,
        5,
        5,
        5,
        5,
        5,
        5,
        5,
        5,
        5,
        5,
        5,
        5,
        5,
        5,
        5,
        5,
        5,
        5,
        5,
        5,
        5,
        5,
        5,
        5,
        5,
        5,
        5,
        5,
        5,
        5,
        5,
        5,
        5,
        5,
        5,
        5,
        5,
        5,
        5,
        5,
        5,
        5,
        5,
        5,
        5,
        5,
        5,
        5,
        5,
        5,
        5,
        5,
        5,
        5,
        5,
        5,
        5,
        5,
        5,
        5,
        5,
        5,
        5,
        5,
        5,
        6,
        6,
        6,
        6,
        6,
        6,
        6,
        6,
        6,
        6,
        6,
        6,
        6,
        6,
        6,
        6,
        6,
        6,
        6,
        6,
        6,
        6,
        6,
        6,
        6,
        6,
        6,
        6,
        6,
        6,
        6,
        6,
        6,
        6,
        6,
        6,
        6,
        6,
        6,
        6,
        6,
        6,
        6,
        6,
        6,
        6,
        6,
        6,
        6,
        6,
        6,
        6,
        6,
        6,
        6,
        6,
        6,
        6,
        6,
        6,
        6,
        6,
        6,
        6,
        6,
        6,
        6,
        6,
        6,
        6,
        6,
        6,
        6,
        6,
        6,
        6,
        6,
        6,
        7,
        7,
        7,
        7,
        7,
        7,
        7,
        7,
        7,
        7,
        7,
        7,
        7,
        7,
        7,
        7,
        7,
        7,
        7,
        7,
        7,
        7,
        7,
        7,
        7,
        7,
        7,
        7,
        7,
        7,
        7,
        7,
        7,
        7,
        7,
        7,
        7,
        7,
        7,
        7,
        7,
        7,
        7,
        7,
        7,
        7,
        7,
        7,
        7,
        7,
        7,
        7,
        7,
        7,
        7,
        7,
        7,
        7,
        7,
        7,
        8,
        8,
        8,
        8,
        8,
        8,
        8,
        8,
        8,
        8,
        8,
        8,
        8,
        8,
        8,
        8,
        8,
        8,
        8,
        8,
        8,
        8,
        8,
        8,
        8,
        8,
        8,
        8,
        8,
        8,
        8,
        8,
        8,
        8,
        8,
        8,
        9,
        9,
        9,
        9,
        9,
        9,
        9,
        9,
        9,
        9,
        9,
        9,
        9,
        9,
        9,
        9,
        9,
        9,
        10,
        10,
        10,
        10,
        10,
        10,
    ],
    dtype=torch.int64,
)


def determinant_to_polynomial(
    A: Tensor,
) -> Tensor:
    r"""Represent the determinant by the 10th polynomial, used for 5PC solver [@nister2004efficient].

    Args:
        A: Tensor :math:`(*, 3, 13)`.

    Returns:
        a degree 10 poly, representing determinant (Eqn. 14 in the paper).

    """
    B, device, dtype = A.shape[0], A.device, A.dtype
    global multiplication_indices, signs, coefficient_map  # noqa: PLW0603

    multiplication_indices = multiplication_indices.to(device)
    signs = signs.to(device, dtype)
    coefficient_map = coefficient_map.to(device)

    A_flat = A.view(B, -1)
    gathered_values = A_flat[:, multiplication_indices]
    products = torch.prod(gathered_values, dim=-1)
    signed_products = products * signs

    cs = torch.zeros(B, 11, device=device, dtype=dtype)
    batch_coefficient_map = coefficient_map.repeat(B, 1)
    cs.scatter_add_(dim=1, index=batch_coefficient_map, src=signed_products)
    return cs
