# Copyright The Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import List, Optional, Union

import torch
from torch import Tensor
from typing_extensions import Literal

from torchmetrics.functional.regression.utils import _check_data_shape_to_num_outputs
from torchmetrics.utilities.checks import _check_same_shape
from torchmetrics.utilities.data import _bincount, _cumsum, dim_zero_cat
from torchmetrics.utilities.enums import EnumStr


class _MetricVariant(EnumStr):
    """Enumerate the supported variants of Kendall's tau.

    The selected member decides which normalisation ``_calculate_tau`` applies to the difference between the
    number of concordant and discordant pairs, and whether tie statistics are collected at all.

    """

    # tau-a: normalised by the number of concordant plus discordant pairs; no tie statistics are computed
    A = "a"
    # tau-b: normalised with the tie-corrected number of pairs of both sequences
    B = "b"
    # tau-c: normalised with the number of observations and the smaller count of unique values
    C = "c"

    @staticmethod
    def _name() -> str:
        """Return the label of this option (presumably consumed by ``EnumStr`` for messages -- TODO confirm)."""
        return "variant"


class _TestAlternative(EnumStr):
    """Enumerate the alternative hypotheses available for the significance test of Kendall's tau.

    The selected member decides in ``_calculate_p_value`` how the test statistic is mapped to a p-value.

    """

    # The rank correlation is nonzero
    TWO_SIDED = "two-sided"
    # The rank correlation is negative (less than zero)
    LESS = "less"
    # The rank correlation is positive (greater than zero)
    GREATER = "greater"

    @staticmethod
    def _name() -> str:
        """Return the label of this option (presumably consumed by ``EnumStr`` for messages -- TODO confirm)."""
        return "alternative"


def _sort_on_first_sequence(x: Tensor, y: Tensor) -> tuple[Tensor, Tensor]:
    """Sort sequences in an ascent order according to the sequence ``x``.

    Every column of ``x`` is sorted independently and the permutation found for a column is applied to the
    same column of ``y``, so that paired observations stay aligned.

    Args:
        x: Two-dimensional tensor holding one sequence per column
        y: Tensor of the same shape as ``x``; it is not modified

    Returns:
        The column-wise sorted ``x`` and the accordingly reordered ``y``, both with the shape of the inputs.

    """
    # Sorting the transpose along its last dimension sorts every column of ``x``
    x_sorted, perm = x.T.sort()
    # ``gather`` applies each column's permutation in one call and builds a new tensor, so ``y`` stays
    # untouched without an explicit clone and without a Python loop over the columns
    y_sorted = y.T.gather(1, perm)
    return x_sorted.T, y_sorted.T


def _concordant_element_sum(x: Tensor, y: Tensor, i: int) -> Tensor:
    """Count the concordant pairs formed by observation ``i`` and every later observation.

    A pair counts as concordant when both ``x`` and ``y`` are strictly smaller at position ``i`` than at the
    later position, so tied values never contribute.

    Args:
        x: First sequence(s), indexed along the first dimension
        y: Second sequence(s), same shape as ``x``
        i: Index of the observation compared against all observations after it

    Returns:
        The per-sequence counts with an extra leading dimension of size 1.

    """
    later_x, later_y = x[i + 1 :], y[i + 1 :]
    both_increase = (x[i] < later_x) & (y[i] < later_y)
    return both_increase.sum(dim=0).unsqueeze(0)


def _count_concordant_pairs(preds: Tensor, target: Tensor) -> Tensor:
    """Count the total number of concordant pairs in the given sequences.

    The counts of ``_concordant_element_sum`` are collected for every position along the first dimension of
    ``preds`` and added up.

    """
    per_position = []
    for position in range(preds.size(0)):
        per_position.append(_concordant_element_sum(preds, target, position))
    return torch.cat(per_position).sum(0)


def _discordant_element_sum(x: Tensor, y: Tensor, i: int) -> Tensor:
    """Count the discordant pairs formed by observation ``i`` and every later observation.

    A pair counts as discordant when ``x`` and ``y`` move in strictly opposite directions between position
    ``i`` and the later position, so tied values never contribute.

    Args:
        x: First sequence(s), indexed along the first dimension
        y: Second sequence(s), same shape as ``x``
        i: Index of the observation compared against all observations after it

    Returns:
        The per-sequence counts with an extra leading dimension of size 1.

    """
    later_x, later_y = x[i + 1 :], y[i + 1 :]
    x_falls_y_rises = (x[i] > later_x) & (y[i] < later_y)
    x_rises_y_falls = (x[i] < later_x) & (y[i] > later_y)
    return (x_falls_y_rises | x_rises_y_falls).sum(dim=0).unsqueeze(0)


def _count_discordant_pairs(preds: Tensor, target: Tensor) -> Tensor:
    """Count the total number of discordant pairs in the given sequences.

    The counts of ``_discordant_element_sum`` are collected for every position along the first dimension of
    ``preds`` and added up.

    """
    per_position = []
    for position in range(preds.size(0)):
        per_position.append(_discordant_element_sum(preds, target, position))
    return torch.cat(per_position).sum(0)


def _convert_sequence_to_dense_rank(x: Tensor, sort: bool = False) -> Tensor:
    """Convert column-wise sequences to a rank tensor.

    Args:
        x: Two-dimensional tensor holding one sequence per column
        sort: Whether the columns still have to be sorted; pass ``True`` if ``x`` has not been sorted before

    Returns:
        Tensor of the shape of ``x`` where a column's entry grows by one each time a value differs from the
        value in the previous row (``_cumsum`` is assumed to be a cumulative sum along ``dim``).

    """
    ordered = x.sort(dim=0).values if sort else x
    # The first row has no predecessor, hence it contributes no rank increment
    first_row = torch.zeros(1, ordered.shape[1], dtype=torch.int32, device=ordered.device)
    value_changed = (ordered[1:] != ordered[:-1]).int()
    return _cumsum(torch.cat([first_row, value_changed], dim=0), dim=0)


def _get_ties(x: Tensor) -> tuple[Tensor, Tensor, Tensor]:
    """Get the total number of ties and the tie statistics for the p-value calculation of given sequences.

    For every column of ``x`` the sizes ``t`` of all groups of equal values with more than one member are
    taken from ``_bincount`` and reduced to three sums.

    Args:
        x: Two-dimensional tensor holding one sequence per column (presumably dense ranks -- confirm in caller)

    Returns:
        Three tensors with one entry per column and the dtype of ``x``:
        the sum of ``t * (t - 1) // 2``, the sum of ``t * (t - 1) * (t - 2)`` and
        the sum of ``t * (t - 1) * (2 * t + 5)``.

    """
    num_sequences = x.shape[1]
    ties = torch.zeros(num_sequences, dtype=x.dtype, device=x.device)
    ties_p1 = torch.zeros(num_sequences, dtype=x.dtype, device=x.device)
    ties_p2 = torch.zeros(num_sequences, dtype=x.dtype, device=x.device)

    for col in range(num_sequences):
        group_sizes = _bincount(x[:, col])
        # Only values occurring more than once form a tie
        tied = group_sizes[group_sizes > 1]
        ties[col] = (tied * (tied - 1) // 2).sum()
        # ``1.0`` switches the products below to floating point arithmetic
        ordered_pairs = tied * (tied - 1.0)
        ties_p1[col] = (ordered_pairs * (tied - 2)).sum()
        ties_p2[col] = (ordered_pairs * (2 * tied + 5)).sum()

    return ties, ties_p1, ties_p2


def _get_metric_metadata(
    preds: Tensor, target: Tensor, variant: _MetricVariant
) -> tuple[
    Tensor,
    Tensor,
    Optional[Tensor],
    Optional[Tensor],
    Optional[Tensor],
    Optional[Tensor],
    Optional[Tensor],
    Optional[Tensor],
    Tensor,
]:
    """Obtain the statistics needed to calculate the metric value.

    Args:
        preds: Sequence of data, one sequence per column
        target: Sequence of data, same shape as ``preds``
        variant: Variant of Kendall's tau; tie statistics are only computed for variants other than ``A``

    Returns:
        A tuple of the concordant pairs, the discordant pairs, the three tie statistics of ``preds``, the three
        tie statistics of ``target`` and the number of observations. The six tie statistics are ``None`` for
        variant ``A``.

    """
    # Pair counting works on data ordered by ``preds``
    sorted_preds, sorted_target = _sort_on_first_sequence(preds, target)

    concordant_pairs = _count_concordant_pairs(sorted_preds, sorted_target)
    discordant_pairs = _count_discordant_pairs(sorted_preds, sorted_target)
    n_total = torch.tensor(sorted_preds.shape[0], device=sorted_preds.device)

    no_ties: tuple[Optional[Tensor], Optional[Tensor], Optional[Tensor]] = (None, None, None)
    preds_tie_stats = target_tie_stats = no_ties
    if variant != _MetricVariant.A:
        # ``sorted_preds`` is already ordered, ``sorted_target`` only follows the order of ``preds`` and
        # therefore has to be sorted on its own before ranking
        preds_ranks = _convert_sequence_to_dense_rank(sorted_preds)
        target_ranks = _convert_sequence_to_dense_rank(sorted_target, sort=True)
        preds_tie_stats = _get_ties(preds_ranks)
        target_tie_stats = _get_ties(target_ranks)

    return (concordant_pairs, discordant_pairs, *preds_tie_stats, *target_tie_stats, n_total)


def _calculate_tau(
    preds: Tensor,
    target: Tensor,
    concordant_pairs: Tensor,
    discordant_pairs: Tensor,
    con_min_dis_pairs: Tensor,
    n_total: Tensor,
    preds_ties: Optional[Tensor],
    target_ties: Optional[Tensor],
    variant: _MetricVariant,
) -> Tensor:
    """Calculate Kendall's tau from metric metadata.

    Args:
        preds: Sequence of data, one sequence per column; only read for variant ``C``
        target: Sequence of data, one sequence per column; only read for variant ``C``
        concordant_pairs: Number of concordant pairs; only read for variant ``A``
        discordant_pairs: Number of discordant pairs; only read for variant ``A``
        con_min_dis_pairs: Number of concordant pairs minus number of discordant pairs
        n_total: Number of observations
        preds_ties: Number of ties in ``preds``; only read for variant ``B``, ``None`` counts as zero
        target_ties: Number of ties in ``target``; only read for variant ``B``, ``None`` counts as zero
        variant: Variant of Kendall's tau to calculate; anything other than ``A`` or ``B`` is handled as ``C``

    Returns:
        The value of tau; it is not clamped here.

    """
    if variant == _MetricVariant.A:
        return con_min_dis_pairs / (concordant_pairs + discordant_pairs)

    if variant == _MetricVariant.B:
        total_combinations: Tensor = n_total * (n_total - 1) // 2
        zero_ties = torch.tensor(0.0, dtype=total_combinations.dtype, device=total_combinations.device)
        untied_in_preds = total_combinations - (zero_ties if preds_ties is None else preds_ties)
        untied_in_target = total_combinations - (zero_ties if target_ties is None else target_ties)
        return con_min_dis_pairs / torch.sqrt(untied_in_preds * untied_in_target)

    # Variant C: normalise with the smaller number of distinct values found per column
    preds_unique = torch.tensor(
        [column.unique().numel() for column in preds.T], dtype=preds.dtype, device=preds.device
    )
    target_unique = torch.tensor(
        [column.unique().numel() for column in target.T], dtype=target.dtype, device=target.device
    )
    min_classes = torch.minimum(preds_unique, target_unique)
    normaliser = (min_classes - 1) / min_classes * n_total**2
    return 2 * con_min_dis_pairs / normaliser


def _get_p_value_for_t_value_from_dist(t_value: Tensor) -> Tensor:
    """Obtain p-value for a given Tensor of t-values. Handle ``nan`` which cannot be passed into torch distributions.

    The p-value is the cumulative distribution function of a standard normal distribution evaluated at the
    t-value. When a t-value is ``nan``, the resulting p-value is also ``nan``.

    Args:
        t_value: Tensor of t-values; it is not modified

    Returns:
        Tensor of p-values, ``nan`` wherever ``t_value`` is ``nan``.

    """
    nan_mask = torch.isnan(t_value)

    # The distribution parameters take over dtype and device of the t-values
    loc = torch.zeros(1, dtype=t_value.dtype, device=t_value.device)
    scale = torch.ones(1, dtype=t_value.dtype, device=t_value.device)
    standard_normal = torch.distributions.normal.Normal(loc, scale)

    # ``nan`` entries are replaced by a number for the evaluation and restored afterwards
    p_value = standard_normal.cdf(torch.nan_to_num(t_value))
    nan_fill = torch.tensor(float("nan"), dtype=p_value.dtype, device=p_value.device)
    return torch.where(nan_mask, nan_fill, p_value)


def _calculate_p_value(
    con_min_dis_pairs: Tensor,
    n_total: Tensor,
    preds_ties: Optional[Tensor],
    preds_ties_p1: Optional[Tensor],
    preds_ties_p2: Optional[Tensor],
    target_ties: Optional[Tensor],
    target_ties_p1: Optional[Tensor],
    target_ties_p2: Optional[Tensor],
    variant: _MetricVariant,
    alternative: Optional[_TestAlternative],
) -> Tensor:
    """Calculate p-value for Kendall's tau from metric metadata.

    Args:
        con_min_dis_pairs: Number of concordant pairs minus number of discordant pairs
        n_total: Number of observations
        preds_ties: Number of ties in ``preds``
        preds_ties_p1: Sum of ``t * (t - 1) * (t - 2)`` over the tie groups of ``preds``
        preds_ties_p2: Sum of ``t * (t - 1) * (2 * t + 5)`` over the tie groups of ``preds``
        target_ties: Number of ties in ``target``
        target_ties_p1: Sum of ``t * (t - 1) * (t - 2)`` over the tie groups of ``target``
        target_ties_p2: Sum of ``t * (t - 1) * (2 * t + 5)`` over the tie groups of ``target``
        variant: Variant of Kendall's tau; the tie statistics are only used for variants other than ``A``
        alternative: Alternative hypothesis of the test

    Returns:
        The p-value of the test; it is ``nan`` wherever the test statistic is ``nan``.

    """

    def _or_zero(statistic: Optional[Tensor]) -> Union[Tensor, int]:
        # A missing tie statistic does not change the variance
        return statistic if statistic is not None else 0

    t_value_denominator_base = n_total * (n_total - 1) * (2 * n_total + 5)
    if variant == _MetricVariant.A:
        t_value = 3 * con_min_dis_pairs / torch.sqrt(t_value_denominator_base / 2)
    else:
        m = n_total * (n_total - 1)
        variance: Tensor = (t_value_denominator_base - _or_zero(preds_ties_p2) - _or_zero(target_ties_p2)) / 18
        variance += (2 * _or_zero(preds_ties) * _or_zero(target_ties)) / m
        variance += _or_zero(preds_ties_p1) * _or_zero(target_ties_p1) / (9 * m * (n_total - 2))
        t_value = con_min_dis_pairs / torch.sqrt(variance)

    two_sided = alternative == _TestAlternative.TWO_SIDED
    if two_sided:
        t_value = torch.abs(t_value)
    # The distribution function yields the lower tail, hence the sign flip to obtain the upper tail
    if two_sided or alternative == _TestAlternative.GREATER:
        t_value = -t_value

    p_value = _get_p_value_for_t_value_from_dist(t_value)
    # Both tails count for the two-sided test
    return p_value * 2 if two_sided else p_value


def _kendall_corrcoef_update(
    preds: Tensor,
    target: Tensor,
    concat_preds: Optional[List[Tensor]] = None,
    concat_target: Optional[List[Tensor]] = None,
    num_outputs: int = 1,
) -> tuple[List[Tensor], List[Tensor]]:
    """Update variables required to compute Kendall rank correlation coefficient.

    Args:
        preds: Sequence of data
        target: Sequence of data
        concat_preds: List of batches of preds sequence to be concatenated
        concat_target: List of batches of target sequence to be concatenated
        num_outputs: Number of outputs in multioutput setting

    Returns:
        The lists of ``preds`` and ``target`` batches, each extended by the new batch. A non-empty list passed
        in is extended in place, whereas ``None`` or an empty list is replaced by a new list.

    Raises:
        RuntimeError: If ``preds`` and ``target`` do not have the same shape

    """
    if not concat_preds:
        concat_preds = []
    if not concat_target:
        concat_target = []

    # Data checking
    _check_same_shape(preds, target)
    _check_data_shape_to_num_outputs(preds, target, num_outputs)

    # A single output gets an explicit output dimension so that all batches are stored as two-dimensional
    single_output = num_outputs == 1
    concat_preds.append(preds.unsqueeze(1) if single_output else preds)
    concat_target.append(target.unsqueeze(1) if single_output else target)

    return concat_preds, concat_target


def _kendall_corrcoef_compute(
    preds: Tensor,
    target: Tensor,
    variant: _MetricVariant,
    alternative: Optional[_TestAlternative] = None,
) -> tuple[Tensor, Optional[Tensor]]:
    """Compute Kendall rank correlation coefficient, and optionally p-value of corresponding statistical test.

    Args:
        preds: Sequence of data
        target: Sequence of data
        variant: Indication of which variant of Kendall's tau to be used
        alternative: Alternative hypothesis for t-test. Possible values:
            - 'two-sided': the rank correlation is nonzero
            - 'less': the rank correlation is negative (less than zero)
            - 'greater':  the rank correlation is positive (greater than zero)

    Returns:
        The correlation clamped to ``[-1, 1]`` and the p-value, which is ``None`` if no ``alternative`` is given.

    """
    metadata = _get_metric_metadata(preds, target, variant)
    concordant_pairs, discordant_pairs = metadata[0], metadata[1]
    preds_ties, preds_ties_p1, preds_ties_p2 = metadata[2:5]
    target_ties, target_ties_p1, target_ties_p2 = metadata[5:8]
    n_total = metadata[8]
    con_min_dis_pairs = concordant_pairs - discordant_pairs

    tau = _calculate_tau(
        preds, target, concordant_pairs, discordant_pairs, con_min_dis_pairs, n_total, preds_ties, target_ties, variant
    )

    p_value: Optional[Tensor] = None
    if alternative:
        p_value = _calculate_p_value(
            con_min_dis_pairs,
            n_total,
            preds_ties,
            preds_ties_p1,
            preds_ties_p2,
            target_ties,
            target_ties_p1,
            target_ties_p2,
            variant,
            alternative,
        )

    # Squeeze tensor if num_outputs=1
    if tau.shape[0] == 1:
        tau = tau.squeeze()
        if p_value is not None:
            p_value = p_value.squeeze()

    return tau.clamp(-1, 1), p_value


def kendall_rank_corrcoef(
    preds: Tensor,
    target: Tensor,
    variant: Literal["a", "b", "c"] = "b",
    t_test: bool = False,
    alternative: Optional[Literal["two-sided", "less", "greater"]] = "two-sided",
) -> Union[Tensor, tuple[Tensor, Tensor]]:
    r"""Compute `Kendall Rank Correlation Coefficient`_.

    .. math::
        \tau_a = \frac{C - D}{C + D}

    where :math:`C` represents concordant pairs, :math:`D` stands for discordant pairs.

    .. math::
        \tau_b = \frac{C - D}{\sqrt{(C + D + T_{preds}) * (C + D + T_{target})}}

    where :math:`C` represents concordant pairs, :math:`D` stands for discordant pairs and :math:`T` represents
    a total number of ties.

    .. math::
        \tau_c = 2 * \frac{C - D}{n^2 * \frac{m - 1}{m}}

    where :math:`C` represents concordant pairs, :math:`D` stands for discordant pairs, :math:`n` is a total number
    of observations and :math:`m` is a ``min`` of unique values in ``preds`` and ``target`` sequence.

    Definitions according to `The Treatment of Ties in Ranking Problems`_.

    Args:
        preds: Sequence of data of either shape ``(N,)`` or ``(N,d)``
        target: Sequence of data of either shape ``(N,)`` or ``(N,d)``
        variant: Indication of which variant of Kendall's tau to be used
        t_test: Indication whether to run t-test
        alternative: Alternative hypothesis for t-test. It is only used if ``t_test=True``. Possible values:
            - 'two-sided': the rank correlation is nonzero
            - 'less': the rank correlation is negative (less than zero)
            - 'greater':  the rank correlation is positive (greater than zero)

    Return:
        Correlation tau statistic
        (Optional) p-value of corresponding statistical test (asymptotic)

    Raises:
        ValueError: If ``t_test`` is not of a type bool
        ValueError: If ``t_test=True`` and ``alternative=None``

    Example (single output regression):
        >>> from torchmetrics.functional.regression import kendall_rank_corrcoef
        >>> preds = torch.tensor([2.5, 0.0, 2, 8])
        >>> target = torch.tensor([3, -0.5, 2, 1])
        >>> kendall_rank_corrcoef(preds, target)
        tensor(0.3333)

    Example (multi output regression):
        >>> from torchmetrics.functional.regression import kendall_rank_corrcoef
        >>> preds = torch.tensor([[2.5, 0.0], [2, 8]])
        >>> target = torch.tensor([[3, -0.5], [2, 1]])
        >>> kendall_rank_corrcoef(preds, target)
        tensor([1., 1.])

    Example (single output regression with t-test):
        >>> from torchmetrics.functional.regression import kendall_rank_corrcoef
        >>> preds = torch.tensor([2.5, 0.0, 2, 8])
        >>> target = torch.tensor([3, -0.5, 2, 1])
        >>> kendall_rank_corrcoef(preds, target, t_test=True, alternative='two-sided')
        (tensor(0.3333), tensor(0.4969))

    Example (multi output regression with t-test):
        >>> from torchmetrics.functional.regression import kendall_rank_corrcoef
        >>> preds = torch.tensor([[2.5, 0.0], [2, 8]])
        >>> target = torch.tensor([[3, -0.5], [2, 1]])
        >>> kendall_rank_corrcoef(preds, target, t_test=True, alternative='two-sided')
        (tensor([1., 1.]), tensor([nan, nan]))

    """
    if not isinstance(t_test, bool):
        raise ValueError(f"Argument `t_test` is expected to be of a type `bool`, but got {type(t_test)}.")
    if t_test and alternative is None:
        raise ValueError("Argument `alternative` is required if `t_test=True` but got `None`.")

    _variant = _MetricVariant.from_str(str(variant))
    # Without a t-test no alternative is handed over, which makes the compute step skip the p-value
    _alternative = _TestAlternative.from_str(str(alternative)) if t_test else None

    num_outputs = 1 if preds.ndim == 1 else preds.shape[-1]
    preds_batches, target_batches = _kendall_corrcoef_update(preds, target, [], [], num_outputs=num_outputs)
    tau, p_value = _kendall_corrcoef_compute(
        dim_zero_cat(preds_batches),
        dim_zero_cat(target_batches),
        _variant,  # type: ignore[arg-type]  # todo
        _alternative,  # type: ignore[arg-type]  # todo
    )

    return tau if p_value is None else (tau, p_value)
