# LICENSE HEADER MANAGED BY add-license-header
#
# Copyright 2018 Kornia Team
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

from typing import Any, ClassVar, Dict, List, Optional, Tuple

import torch

from kornia.core import Module, Tensor, concatenate
from kornia.core.check import KORNIA_CHECK_DM_DESC, KORNIA_CHECK_SHAPE
from kornia.feature.laf import get_laf_center
from kornia.feature.steerers import DiscreteSteerer
from kornia.utils.helpers import is_mps_tensor_safe

from .adalam import get_adalam_default_config, match_adalam


def _cdist(d1: Tensor, d2: Tensor) -> Tensor:
    r"""Compute pairwise L2 distances, with a manual fallback to `torch.cdist` for M1.

    `torch.cdist` is used only when `is_mps_tensor_safe` returns False for both inputs;
    otherwise the distances are expanded as ``sqrt(|a|^2 + |b|^2 - 2 a.b)``.

    Args:
        d1: first set of row vectors, indexed as a 2D tensor.
        d2: second set of row vectors, indexed as a 2D tensor.

    Returns:
        Matrix whose entry ``(i, j)`` is the distance between ``d1[i]`` and ``d2[j]``.

    """
    if not (is_mps_tensor_safe(d1) or is_mps_tensor_safe(d2)):
        return torch.cdist(d1, d2)
    # Squared norms as column vectors; broadcasting expands them to the full matrix.
    sq_norm1 = (d1**2).sum(dim=1, keepdim=True)
    sq_norm2 = (d2**2).sum(dim=1, keepdim=True)
    sq_dists = sq_norm1 + sq_norm2.t() - 2.0 * d1 @ d2.t()
    # Rounding can push tiny squared distances below zero, so clamp before the root.
    return sq_dists.clamp(min=0.0).sqrt()


def _get_default_fginn_params() -> Dict[str, Any]:
    config = {"th": 0.85, "mutual": False, "spatial_th": 10.0}
    return config


def _get_lazy_distance_matrix(desc1: Tensor, desc2: Tensor, dm_: Optional[Tensor] = None) -> Tensor:
    """Return a usable distance matrix, computing it only when none was supplied.

    A provided matrix is validated against the descriptors with
    ``KORNIA_CHECK_DM_DESC`` and returned as is; otherwise the L2-distance
    matrix between ``desc1`` and ``desc2`` is computed.

    Args:
        desc1: Batch of descriptors of a shape :math:`(B1, D)`.
        desc2: Batch of descriptors of a shape :math:`(B2, D)`.
        dm_: Tensor containing the distances from each descriptor in desc1
          to each descriptor in desc2, shape of :math:`(B1, B2)`.

    Returns:
        The distance matrix between desc1 and desc2.

    """
    if dm_ is not None:
        KORNIA_CHECK_DM_DESC(desc1, desc2, dm_)
        return dm_
    return _cdist(desc1, desc2)


def _no_match(dm: Tensor) -> Tuple[Tensor, Tensor]:
    """Output empty tensors.

    Returns:
            - Descriptor distance of matching descriptors, shape of :math:`(0, 1)`.
            - Long tensor indexes of matching descriptors in desc1 and desc2, shape of :math:`(0, 2)`.

    """
    dists = torch.empty(0, 1, device=dm.device, dtype=dm.dtype)
    idxs = torch.empty(0, 2, device=dm.device, dtype=torch.long)
    return dists, idxs


def match_nn(desc1: Tensor, desc2: Tensor, dm: Optional[Tensor] = None) -> Tuple[Tensor, Tensor]:
    r"""Match every descriptor of desc1 to its nearest neighbor in desc2.

    If the distance matrix dm is not provided, :py:func:`torch.cdist` is used.

    Args:
        desc1: Batch of descriptors of a shape :math:`(B1, D)`.
        desc2: Batch of descriptors of a shape :math:`(B2, D)`.
        dm: Tensor containing the distances from each descriptor in desc1
          to each descriptor in desc2, shape of :math:`(B1, B2)`.

    Returns:
        - Descriptor distance of matching descriptors, shape of :math:`(B1, 1)`.
        - Long tensor indexes of matching descriptors in desc1 and desc2, shape of :math:`(B1, 2)`.
          Both outputs are empty (zero rows) when either descriptor set is empty.

    """
    KORNIA_CHECK_SHAPE(desc1, ["B", "DIM"])
    KORNIA_CHECK_SHAPE(desc2, ["B", "DIM"])
    # Nothing to match when one of the sets has no descriptors.
    if len(desc1) == 0 or len(desc2) == 0:
        return _no_match(desc1)

    dmat = _get_lazy_distance_matrix(desc1, desc2, dm)
    nn_dists, nn_in_2 = dmat.min(dim=1)
    # One match per desc1 row, so the desc1 indexes are simply 0..B1-1.
    rows_in_1 = torch.arange(0, nn_in_2.size(0), device=nn_in_2.device)
    pairs = torch.stack([rows_in_1, nn_in_2], dim=1)
    return nn_dists.view(-1, 1), pairs.view(-1, 2)


def match_mnn(desc1: Tensor, desc2: Tensor, dm: Optional[Tensor] = None) -> Tuple[Tensor, Tensor]:
    """Find mutual nearest neighbors between desc1 and desc2.

    A pair is kept only when each descriptor is the nearest neighbor of the other.

    If the distance matrix dm is not provided, :py:func:`torch.cdist` is used.

    Args:
        desc1: Batch of descriptors of a shape :math:`(B1, D)`.
        desc2: Batch of descriptors of a shape :math:`(B2, D)`.
        dm: Tensor containing the distances from each descriptor in desc1
          to each descriptor in desc2, shape of :math:`(B1, B2)`.

    Return:
        - Descriptor distance of matching descriptors, shape of :math:`(B3, 1)`.
        - Long tensor indexes of matching descriptors in desc1 and desc2, shape of :math:`(B3, 2)`,
          where 0 <= B3 <= min(B1, B2)

    """
    KORNIA_CHECK_SHAPE(desc1, ["B", "DIM"])
    KORNIA_CHECK_SHAPE(desc2, ["B", "DIM"])
    if len(desc1) == 0 or len(desc2) == 0:
        return _no_match(desc1)

    dmat = _get_lazy_distance_matrix(desc1, desc2, dm)
    num1 = dmat.size(0)
    num2 = dmat.size(1)
    dists_1to2, nn_1to2 = dmat.min(dim=1)  # best desc2 index for every desc1 row
    dists_2to1, nn_2to1 = dmat.min(dim=0)  # best desc1 index for every desc2 column

    # Walk over the smaller side: follow its nearest neighbor and check that
    # the neighbor points back to where we started.
    if num1 <= num2:
        anchors = torch.arange(num1, device=dmat.device)
        is_mutual = nn_2to1[nn_1to2] == anchors
        pairs = torch.stack([anchors, nn_1to2], dim=1)[is_mutual]
        kept_dists = dists_1to2[is_mutual]
    else:
        anchors = torch.arange(num2, device=dmat.device)
        is_mutual = nn_1to2[nn_2to1] == anchors
        pairs = torch.stack([nn_2to1, anchors], dim=1)[is_mutual]
        kept_dists = dists_2to1[is_mutual]
    return kept_dists.view(-1, 1), pairs.view(-1, 2)


def match_snn(desc1: Tensor, desc2: Tensor, th: float = 0.8, dm: Optional[Tensor] = None) -> Tuple[Tensor, Tensor]:
    """Find nearest neighbors in desc2 for each vector in desc1, filtered by the ratio test.

    A match is kept when the ratio of the first to the second nearest neighbor distance is <= th.

    If the distance matrix dm is not provided, :py:func:`torch.cdist` is used.

    Args:
        desc1: Batch of descriptors of a shape :math:`(B1, D)`.
        desc2: Batch of descriptors of a shape :math:`(B2, D)`.
        th: distance ratio threshold.
        dm: Tensor containing the distances from each descriptor in desc1
          to each descriptor in desc2, shape of :math:`(B1, B2)`.

    Return:
        - First to second nearest neighbor distance ratio of the kept matches, shape of :math:`(B3, 1)`.
        - Long tensor indexes of matching descriptors in desc1 and desc2. Shape: :math:`(B3, 2)`,
          where 0 <= B3 <= B1.

    """
    KORNIA_CHECK_SHAPE(desc1, ["B", "DIM"])
    KORNIA_CHECK_SHAPE(desc2, ["B", "DIM"])

    # The ratio test needs a second neighbor; without one there is nothing to output.
    if desc2.shape[0] < 2:
        return _no_match(desc1)

    dmat = _get_lazy_distance_matrix(desc1, desc2, dm)
    two_smallest, two_nearest = dmat.topk(2, dim=1, largest=False)
    ratio = two_smallest[:, 0] / two_smallest[:, 1]
    passes = ratio <= th
    kept_ratio = ratio[passes]
    if len(kept_ratio) == 0:
        return _no_match(dmat)

    kept_in_1 = torch.arange(0, two_nearest.size(0), device=dmat.device)[passes]
    kept_in_2 = two_nearest[:, 0][passes]
    pairs = torch.stack([kept_in_1, kept_in_2], dim=1)
    return kept_ratio.view(-1, 1), pairs.view(-1, 2)


def match_smnn(desc1: Tensor, desc2: Tensor, th: float = 0.95, dm: Optional[Tensor] = None) -> Tuple[Tensor, Tensor]:
    """Find mutual nearest neighbors in desc2 for each vector in desc1.

    The ratio test of :func:`match_snn` (first to second nearest neighbor
    distance <= th) is run in both directions, and only the pairs found by
    both directions are kept.

    If the distance matrix dm is not provided, :py:func:`torch.cdist` is used.

    Args:
        desc1: Batch of descriptors of a shape :math:`(B1, D)`.
        desc2: Batch of descriptors of a shape :math:`(B2, D)`.
        th: distance ratio threshold.
        dm: Tensor containing the distances from each descriptor in desc1
          to each descriptor in desc2, shape of :math:`(B1, B2)`.

    Return:
        - For each kept pair, the larger of the two directional ratio-test values,
          shape of :math:`(B3, 1)`.
        - Long tensor indexes of matching descriptors in desc1 and desc2,
          shape of :math:`(B3, 2)` where 0 <= B3 <= B1. Rows are ordered by
          increasing desc1 index.

    """
    KORNIA_CHECK_SHAPE(desc1, ["B", "DIM"])
    KORNIA_CHECK_SHAPE(desc2, ["B", "DIM"])

    # The ratio test needs at least two candidates in each direction.
    if (desc1.shape[0] < 2) or (desc2.shape[0] < 2):
        return _no_match(desc1)
    distance_matrix = _get_lazy_distance_matrix(desc1, desc2, dm)

    # idx1 rows are (index in desc1, index in desc2); idx2 rows are (index in desc2, index in desc1).
    dists1, idx1 = match_snn(desc1, desc2, th, distance_matrix)
    dists2, idx2 = match_snn(desc2, desc1, th, distance_matrix.t())

    if len(dists1) == 0 or len(dists2) == 0:
        return _no_match(distance_matrix)

    device = distance_matrix.device
    # Every desc2 index appears at most once in idx2, so a dense lookup table gives the
    # row of idx2 that holds it (-1 when that descriptor failed the reverse ratio test).
    # This keeps the mutual check linear in the number of matches and exact for any
    # index magnitude, instead of building a pairwise distance matrix over index pairs.
    row_of_desc2 = torch.full((distance_matrix.size(1),), -1, dtype=torch.long, device=device)
    row_of_desc2[idx2[:, 0]] = torch.arange(idx2.size(0), device=device)

    reverse_rows = row_of_desc2[idx1[:, 1]]
    has_reverse = reverse_rows >= 0
    # clamp only keeps the gather in range; rows without a reverse match are masked out anyway.
    points_back = idx2[reverse_rows.clamp(min=0), 1] == idx1[:, 0]
    mutual = has_reverse & points_back

    # idx1 is already ordered by desc1 index, so the filtered rows keep that order.
    matches_idxs = idx1[mutual]
    match_dists = torch.max(dists1[mutual], dists2[reverse_rows[mutual]])
    return match_dists.view(-1, 1), matches_idxs.view(-1, 2)


def match_fginn(
    desc1: Tensor,
    desc2: Tensor,
    lafs1: Tensor,
    lafs2: Tensor,
    th: float = 0.8,
    spatial_th: float = 10.0,
    mutual: bool = False,
    dm: Optional[Tensor] = None,
) -> Tuple[Tensor, Tensor]:
    """Find nearest neighbors in desc2 for each vector in desc1.

    The method satisfies first to second nearest neighbor distance <= th,
    and assures 2nd nearest neighbor is geometrically inconsistent with the 1st one
    (see :cite:`MODS2015` for more details)

    If the distance matrix dm is not provided, :py:func:`torch.cdist` is used.

    Args:
        desc1: Batch of descriptors of a shape :math:`(B1, D)`.
        desc2: Batch of descriptors of a shape :math:`(B2, D)`.
        lafs1: LAFs of a shape :math:`(1, B1, 2, 3)`. Kept for interface symmetry;
          it is not read by the current implementation.
        lafs2: LAFs of a shape :math:`(1, B2, 2, 3)`. Their centers provide the
          positions used for the spatial consistency test.
        th: distance ratio threshold.
        spatial_th: minimal distance in pixels to 2nd nearest neighbor.
        mutual: also perform mutual nearest neighbor check
        dm: Tensor containing the distances from each descriptor in desc1
          to each descriptor in desc2, shape of :math:`(B1, B2)`.

    Return:
        - Ratio of the first nearest neighbor distance to the distance of the first
          geometrically inconsistent neighbor, for the kept matches, shape of :math:`(B3, 1)`.
        - Long tensor indexes of matching descriptors in desc1 and desc2. Shape: :math:`(B3, 2)`,
          where 0 <= B3 <= B1.

    """
    KORNIA_CHECK_SHAPE(desc1, ["B", "DIM"])
    KORNIA_CHECK_SHAPE(desc2, ["B", "DIM"])
    # Penalty added to candidates that lie too close to the first neighbor, so that
    # they can never win the "second nearest" slot.
    BIG_NUMBER = 1000000.0

    distance_matrix = _get_lazy_distance_matrix(desc1, desc2, dm)
    dtype = distance_matrix.dtype

    if desc2.shape[0] < 2:  # We cannot perform snn check, so output empty matches
        return _no_match(distance_matrix)

    # Consider up to 10 closest descriptors as candidates for the second neighbor.
    num_candidates = max(2, min(10, desc2.shape[0]))
    vals_cand, idxs_in_2 = torch.topk(distance_matrix, num_candidates, dim=1, largest=False)
    vals = vals_cand[:, 0]
    xy2 = get_laf_center(lafs2).view(-1, 2)
    candidates_xy = xy2[idxs_in_2]  # (B1, num_candidates, 2)
    # Spatial distance from every candidate to the first nearest neighbor of the SAME
    # desc1 row: index along dim 1, not dim 0, so that each row is compared with its
    # own best match rather than with the candidates of descriptor 0.
    kdist = torch.norm(candidates_xy - candidates_xy[:, 0:1], p=2, dim=2)
    fginn_vals = vals_cand[:, 1:] + (kdist[:, 1:] < spatial_th).to(dtype) * BIG_NUMBER
    # Smallest distance among the candidates that are far enough from the first neighbor.
    vals_2nd = fginn_vals.min(dim=1)[0]
    idxs_in_2 = idxs_in_2[:, 0]

    ratio = vals / vals_2nd
    mask = ratio <= th
    match_dists = ratio[mask]
    if len(match_dists) == 0:
        return _no_match(distance_matrix)
    idxs_in1 = torch.arange(0, idxs_in_2.size(0), device=distance_matrix.device)[mask]
    idxs_in_2 = idxs_in_2[mask]
    matches_idxs = concatenate([idxs_in1.view(-1, 1), idxs_in_2.view(-1, 1)], 1)
    match_dists, matches_idxs = match_dists.view(-1, 1), matches_idxs.view(-1, 2)

    if not mutual:  # returning 1-way matches
        return match_dists, matches_idxs
    # Keep only the pairs whose desc2 member also has the desc1 member as nearest neighbor.
    _, idxs_in_1_mut = torch.min(distance_matrix, dim=0)
    good_mask = matches_idxs[:, 0] == idxs_in_1_mut[matches_idxs[:, 1]]
    return match_dists[good_mask], matches_idxs[good_mask]


class DescriptorMatcher(Module):
    """Module wrapper that selects one of the descriptor matching functions.

    See :func:`~kornia.feature.match_nn`, :func:`~kornia.feature.match_snn`,
        :func:`~kornia.feature.match_mnn` or :func:`~kornia.feature.match_smnn` for more details.

    Args:
        match_mode: type of matching, can be `nn`, `snn`, `mnn`, `smnn` (case-insensitive).
        th: threshold on distance ratio, or other quality measure. Only passed to
            the `snn` and `smnn` modes.

    Raises:
        NotImplementedError: if `match_mode` is not one of the known modes.

    """

    def __init__(self, match_mode: str = "snn", th: float = 0.8) -> None:
        super().__init__()
        self.known_modes = ["nn", "mnn", "snn", "smnn"]
        mode_lc: str = match_mode.lower()
        if mode_lc not in self.known_modes:
            raise NotImplementedError(f"{match_mode} is not supported. Try one of {self.known_modes}")
        self.match_mode = mode_lc
        self.th = th

    def forward(self, desc1: Tensor, desc2: Tensor) -> Tuple[Tensor, Tensor]:
        """Match the two descriptor sets with the configured mode.

        Args:
            desc1: Batch of descriptors of a shape :math:`(B1, D)`.
            desc2: Batch of descriptors of a shape :math:`(B2, D)`.

        Returns:
            - Descriptor distance of matching descriptors, shape of :math:`(B3, 1)`.
            - Long tensor indexes of matching descriptors in desc1 and desc2,
                shape of :math:`(B3, 2)` where :math:`0 <= B3 <= B1`.

        """
        mode = self.match_mode
        if mode == "nn":
            return match_nn(desc1, desc2)
        if mode == "mnn":
            return match_mnn(desc1, desc2)
        if mode == "snn":
            return match_snn(desc1, desc2, self.th)
        if mode == "smnn":
            return match_smnn(desc1, desc2, self.th)
        raise NotImplementedError


class DescriptorMatcherWithSteerer(Module):
    """Rotation-invariant descriptor matching built on Steerers.

    Args:
        steerer: An instance of :func:`kornia.feature.steerers.DiscreteSteerer`.
        steerer_order: order of discretisation of rotation angles, e.g. 4 leads to quarter rotations.
        steer_mode: can be `global`, `local`.
            `global` means that the we output matches from the global rotation with most matches.
            `local` means that we output matches from a distance matrix
            where the distance between each descriptor pair is the minimal over rotations.
        match_mode: type of matching, can be `nn`, `snn`, `mnn`, `smnn`.
            WARNING: using steer_mode `global` with match_mode `nn` will lead to bad results
            since `nn` doesn't generate different amount of matches depending on goodness of fit.
        th: threshold on distance ratio, or other quality measure.

    Example:
        >>> import kornia as K
        >>> import kornia.feature as KF
        >>> device = K.utils.get_cuda_or_mps_device_if_available()
        >>> img1 = torch.randn([1, 3, 768, 768], device=device)
        >>> img2 = torch.randn([1, 3, 768, 768], device=device)
        >>> dedode = KF.DeDoDe.from_pretrained(detector_weights="L-C4-v2", descriptor_weights="B-SO2").to(device)
        >>> steerer_order = 8  # discretisation order of rotation angles
        >>> steerer = KF.steerers.DiscreteSteerer.create_dedode_default(
        ... generator_type="SO2", steerer_order=steerer_order
        ... )
        >>> steerer = steerer.to(device)
        >>> matcher = KF.matching.DescriptorMatcherWithSteerer(
        ... steerer=steerer, steerer_order=steerer_order, steer_mode="global", match_mode="smnn", th=0.98
        ... )
        >>> with torch.inference_mode():
        ...     kps1, scores1, descs1 = dedode(img1, n=20_000)
        ...     kps2, scores2, descs2 = dedode(img2, n=20_000)
        ...     kps1, kps2, descs1, descs2 = kps1[0], kps2[0], descs1[0], descs2[0]
        ...     dists, idxs, num_rot = matcher(
        ...         descs1, descs2, normalize=True, subset_size=1000,
        ...     )
        >>> # print(f"{idxs.shape[0]} tentative matches with steered DeDoDe")
        >>> # print(f"at rotation of {num_rot * 360 / steerer_order} degrees")

    """

    def __init__(
        self,
        steerer: DiscreteSteerer,
        steerer_order: int,
        steer_mode: str = "global",
        match_mode: str = "snn",
        th: float = 0.8,
    ) -> None:
        super().__init__()
        self.steerer = steerer
        self.steerer_order = steerer_order

        # Validate the steering strategy (case-insensitive).
        self.known_steer_modes = ["global", "local"]
        steer_mode_lc: str = steer_mode.lower()
        if steer_mode_lc not in self.known_steer_modes:
            raise NotImplementedError(f"{steer_mode} is not supported. Try one of {self.known_steer_modes}")
        self.steer_mode = steer_mode_lc

        # Validate the matching strategy (case-insensitive).
        self.known_modes = ["nn", "mnn", "snn", "smnn"]
        match_mode_lc: str = match_mode.lower()
        if match_mode_lc not in self.known_modes:
            raise NotImplementedError(f"{match_mode} is not supported. Try one of {self.known_modes}")
        self.match_mode = match_mode_lc
        self.th = th

    def matching_function(self, d1: Tensor, d2: Tensor, dm: Optional[Tensor] = None) -> Tuple[Tensor, Tensor]:
        """Run the matcher selected by `self.match_mode`.

        Args:
            d1: first batch of descriptors.
            d2: second batch of descriptors.
            dm: optional precomputed distance matrix forwarded to the matcher.

        Returns:
            The distances and the index pairs produced by the selected matching function.

        """
        mode = self.match_mode
        if mode == "nn":
            return match_nn(d1, d2, dm=dm)
        if mode == "mnn":
            return match_mnn(d1, d2, dm=dm)
        if mode == "snn":
            return match_snn(d1, d2, self.th, dm=dm)
        if mode == "smnn":
            return match_smnn(d1, d2, self.th, dm=dm)
        raise NotImplementedError

    def forward(
        self,
        desc1: Tensor,
        desc2: Tensor,
        normalize: bool = False,
        subset_size: Optional[int] = None,
    ) -> Tuple[Tensor, Tensor, Optional[int]]:
        """Match descriptors while searching over the discrete rotations of the steerer.

        Args:
            desc1: Batch of descriptors of a shape :math:`(B1, D)`.
            desc2: Batch of descriptors of a shape :math:`(B2, D)`.
            normalize: bool to decide whether to normalize descriptors to unit norm.
            subset_size: If set, the subset size to use for determining optimal
                number of rotations. Smaller subset size leads to faster but less
                accurate matching. Only used when `self.steer_mode` is `"global"`.

        Returns:
            - Descriptor distance of matching descriptors, shape of :math:`(B3, 1)`.
            - Long tensor indexes of matching descriptors in desc1 and desc2,
                shape of :math:`(B3, 2)` where :math:`0 <= B3 <= B1`.
            - Number of global rotations from desc1 to desc2, in terms of `self.steerer_order`
                (will be `None` if `self.steer_mode` is `local`).

        """
        if normalize:
            desc1 = torch.nn.functional.normalize(desc1, dim=-1)
            desc2 = torch.nn.functional.normalize(desc2, dim=-1)

        if self.steer_mode == "local":
            # Per-pair minimum of the distance over all rotations of desc1.
            min_dm = _cdist(desc1, desc2)
            for _ in range(1, self.steerer_order):
                desc1 = self.steerer.steer_descriptions(desc1, normalize=normalize)
                min_dm = torch.minimum(min_dm, _cdist(desc1, desc2))
            local_dist, local_idx = self.matching_function(desc1, desc2, min_dm)
            return local_dist, local_idx, None

        if self.steer_mode != "global":
            raise NotImplementedError

        if subset_size is not None:
            # Estimate the rotation on random subsets, then match everything once
            # with desc1 steered by that rotation.
            picks1 = torch.randperm(desc1.shape[0])[:subset_size]
            picks2 = torch.randperm(desc2.shape[0])[:subset_size]
            _, _, subset_rot = self(
                desc1[picks1],
                desc2[picks2],
                normalize=normalize,
            )
            desc1 = self.steerer.steer_descriptions(
                desc1,
                steerer_power=subset_rot,
                normalize=normalize,
            )
            sub_dist, sub_idx = self.matching_function(desc1, desc2, None)
            return sub_dist, sub_idx, subset_rot

        # Exhaustive search: keep the rotation that yields the most matches.
        best_dist, best_idx = self.matching_function(desc1, desc2, None)
        best_rot = 0
        for rot in range(1, self.steerer_order):
            desc1 = self.steerer.steer_descriptions(desc1, normalize=normalize)
            cand_dist, cand_idx = self.matching_function(desc1, desc2, None)
            if cand_idx.shape[0] > best_idx.shape[0]:
                best_dist, best_idx, best_rot = cand_dist, cand_idx, rot
        return best_dist, best_idx, best_rot


class GeometryAwareDescriptorMatcher(Module):
    """Module wrapper around the matching functions that also use keypoint geometry.

    Dispatches to `match_fginn` or `match_adalam` depending on `match_mode`.

    Args:
        match_mode: type of matching, can be `fginn` or `adalam` (case-insensitive).
        params: optional settings that override the defaults of the selected mode;
            they are merged on top of the default configuration at every forward call.

    Raises:
        NotImplementedError: if `match_mode` is not one of `known_modes`.

    """

    known_modes: ClassVar[List[str]] = ["fginn", "adalam"]

    def __init__(self, match_mode: str = "fginn", params: Optional[Dict[str, Tensor]] = None) -> None:
        super().__init__()
        mode_lc: str = match_mode.lower()
        if mode_lc not in self.known_modes:
            raise NotImplementedError(f"{match_mode} is not supported. Try one of {self.known_modes}")
        self.match_mode = mode_lc
        self.params = params if params else {}

    def forward(self, desc1: Tensor, desc2: Tensor, lafs1: Tensor, lafs2: Tensor) -> Tuple[Tensor, Tensor]:
        """Match the two descriptor sets using the configured geometry-aware mode.

        Args:
            desc1: Batch of descriptors of a shape :math:`(B1, D)`.
            desc2: Batch of descriptors of a shape :math:`(B2, D)`.
            lafs1: LAFs of a shape :math:`(1, B1, 2, 3)`.
            lafs2: LAFs of a shape :math:`(1, B2, 2, 3)`.

        Returns:
            - Descriptor distance of matching descriptors, shape of :math:`(B3, 1)`.
            - Long tensor indexes of matching descriptors in desc1 and desc2,
                shape of :math:`(B3, 2)` where :math:`0 <= B3 <= B1`.

        """
        if self.match_mode == "fginn":
            # Start from the defaults and let the user-provided settings win.
            fginn_cfg = _get_default_fginn_params()
            fginn_cfg.update(self.params)
            return match_fginn(
                desc1,
                desc2,
                lafs1,
                lafs2,
                th=fginn_cfg["th"],
                spatial_th=fginn_cfg["spatial_th"],
                mutual=fginn_cfg["mutual"],
            )
        if self.match_mode == "adalam":
            adalam_cfg = get_adalam_default_config()
            adalam_cfg.update(self.params)  # type: ignore[typeddict-item]
            return match_adalam(desc1, desc2, lafs1, lafs2, config=adalam_cfg)
        raise NotImplementedError
