from typing import Optional

from torch import Tensor

from torchmetrics.functional.retrieval.average_precision import retrieval_average_precision
from torchmetrics.functional.retrieval.fall_out import retrieval_fall_out
from torchmetrics.functional.retrieval.hit_rate import retrieval_hit_rate
from torchmetrics.functional.retrieval.ndcg import retrieval_normalized_dcg
from torchmetrics.functional.retrieval.precision import retrieval_precision
from torchmetrics.functional.retrieval.precision_recall_curve import retrieval_precision_recall_curve
from torchmetrics.functional.retrieval.r_precision import retrieval_r_precision
from torchmetrics.functional.retrieval.recall import retrieval_recall
from torchmetrics.functional.retrieval.reciprocal_rank import retrieval_reciprocal_rank
from torchmetrics.utilities.prints import _deprecated_root_import_func


def _retrieval_average_precision(preds: Tensor, target: Tensor, top_k: Optional[int] = None) -> Tensor:
    """Wrapper for deprecated import.

    >>> from torch import tensor
    >>> preds = tensor([0.2, 0.3, 0.5])
    >>> target = tensor([True, False, True])
    >>> _retrieval_average_precision(preds, target)
    tensor(0.8333)

    """
    _deprecated_root_import_func("retrieval_average_precision", "retrieval")
    return retrieval_average_precision(preds=preds, target=target, top_k=top_k)


def _retrieval_fall_out(preds: Tensor, target: Tensor, top_k: Optional[int] = None) -> Tensor:
    """Wrapper for deprecated import.

    >>> from torch import tensor
    >>> preds = tensor([0.2, 0.3, 0.5])
    >>> target = tensor([True, False, True])
    >>> _retrieval_fall_out(preds, target, top_k=2)
    tensor(1.)

    """
    _deprecated_root_import_func("retrieval_fall_out", "retrieval")
    return retrieval_fall_out(preds=preds, target=target, top_k=top_k)


def _retrieval_hit_rate(preds: Tensor, target: Tensor, top_k: Optional[int] = None) -> Tensor:
    """Wrapper for deprecated import.

    >>> from torch import tensor
    >>> preds = tensor([0.2, 0.3, 0.5])
    >>> target = tensor([True, False, True])
    >>> _retrieval_hit_rate(preds, target, top_k=2)
    tensor(1.)

    """
    _deprecated_root_import_func("retrieval_hit_rate", "retrieval")
    return retrieval_hit_rate(preds=preds, target=target, top_k=top_k)


def _retrieval_normalized_dcg(preds: Tensor, target: Tensor, top_k: Optional[int] = None) -> Tensor:
    """Wrapper for deprecated import.

    >>> from torch import tensor
    >>> preds = tensor([.1, .2, .3, 4, 70])
    >>> target = tensor([10, 0, 0, 1, 5])
    >>> _retrieval_normalized_dcg(preds, target)
    tensor(0.6957)

    """
    _deprecated_root_import_func("retrieval_normalized_dcg", "retrieval")
    return retrieval_normalized_dcg(preds=preds, target=target, top_k=top_k)


def _retrieval_precision(
    preds: Tensor, target: Tensor, top_k: Optional[int] = None, adaptive_k: bool = False
) -> Tensor:
    """Wrapper for deprecated import.

    >>> from torch import tensor
    >>> preds = tensor([0.2, 0.3, 0.5])
    >>> target = tensor([True, False, True])
    >>> _retrieval_precision(preds, target, top_k=2)
    tensor(0.5000)

    """
    _deprecated_root_import_func("retrieval_precision", "retrieval")
    return retrieval_precision(preds=preds, target=target, top_k=top_k, adaptive_k=adaptive_k)


def _retrieval_precision_recall_curve(
    preds: Tensor, target: Tensor, max_k: Optional[int] = None, adaptive_k: bool = False
) -> tuple[Tensor, Tensor, Tensor]:
    """Wrapper for deprecated import.

    >>> from torch import tensor
    >>> preds = tensor([0.2, 0.3, 0.5])
    >>> target = tensor([True, False, True])
    >>> precisions, recalls, top_k = _retrieval_precision_recall_curve(preds, target, max_k=2)
    >>> precisions
    tensor([1.0000, 0.5000])
    >>> recalls
    tensor([0.5000, 0.5000])
    >>> top_k
    tensor([1, 2])

    """
    _deprecated_root_import_func("retrieval_precision_recall_curve", "retrieval")
    return retrieval_precision_recall_curve(preds=preds, target=target, max_k=max_k, adaptive_k=adaptive_k)


def _retrieval_r_precision(preds: Tensor, target: Tensor) -> Tensor:
    """Wrapper for deprecated import.

    >>> from torch import tensor
    >>> preds = tensor([0.2, 0.3, 0.5])
    >>> target = tensor([True, False, True])
    >>> _retrieval_r_precision(preds, target)
    tensor(0.5000)

    """
    _deprecated_root_import_func("retrieval_r_precision", "retrieval")
    return retrieval_r_precision(preds=preds, target=target)


def _retrieval_recall(preds: Tensor, target: Tensor, top_k: Optional[int] = None) -> Tensor:
    """Wrapper for deprecated import.

    >>> from torch import tensor
    >>> preds = tensor([0.2, 0.3, 0.5])
    >>> target = tensor([True, False, True])
    >>> _retrieval_recall(preds, target, top_k=2)
    tensor(0.5000)

    """
    _deprecated_root_import_func("retrieval_recall", "retrieval")
    return retrieval_recall(preds=preds, target=target, top_k=top_k)


def _retrieval_reciprocal_rank(preds: Tensor, target: Tensor) -> Tensor:
    """Wrapper for deprecated import.

    >>> from torch import tensor
    >>> preds = tensor([0.2, 0.3, 0.5])
    >>> target = tensor([False, True, False])
    >>> _retrieval_reciprocal_rank(preds, target)
    tensor(0.5000)

    """
    _deprecated_root_import_func("retrieval_reciprocal_rank", "retrieval")
    return retrieval_reciprocal_rank(preds=preds, target=target)
