import torch
import torch.nn.functional as F

from ..distances import CosineSimilarity
from ..utils import common_functions as c_f
from ..utils import loss_and_miner_utils as lmu
from .base_metric_loss_function import BaseMetricLossFunction


class SmoothAPLoss(BaseMetricLossFunction):
    """
    Implementation of the SmoothAP loss: https://arxiv.org/abs/2007.12163

    Every element of the batch is used as a query. Its smoothed Average
    Precision is computed from sigmoid-relaxed ranks of the same-class elements
    (the query itself included), and the per-element loss is ``1 - AP``.

    Args:
        temperature: Value the pairwise similarity differences are divided by
            before the sigmoid is applied.
        **kwargs: Forwarded unchanged to ``BaseMetricLossFunction.__init__``.
    """

    def __init__(self, temperature=0.01, **kwargs):
        super().__init__(**kwargs)
        c_f.assert_distance_type(self, CosineSimilarity)
        self.temperature = temperature

    def get_default_distance(self):
        """Return a new ``CosineSimilarity`` instance as the default distance."""
        return CosineSimilarity()

    # Implementation is based on the original repository:
    # https://github.com/Andrew-Brown1/Smooth_AP/blob/master/src/Smooth_AP_loss.py#L87
    def compute_loss(self, embeddings, labels, indices_tuple, ref_emb, ref_labels):
        """
        Compute the per-element SmoothAP loss for one batch.

        Args:
            embeddings: Batch of embeddings; ``self.distance(embeddings)`` must
                yield a square ``(batch, batch)`` similarity matrix.
            labels: Class label of every embedding. Assumed to be a 1D tensor
                with one entry per embedding (TODO confirm against the caller).
                Every class present must have the same number of elements.
            indices_tuple: Miner output, converted to per-element weights via
                ``lmu.convert_to_weights``.
            ref_emb: Passed to ``c_f.ref_not_supported`` together with
                ``ref_labels``; not used in the computation.
            ref_labels: See ``ref_emb``.

        Returns:
            A dict with the single key ``"ap_loss"`` holding the weighted
            ``1 - AP`` value of every element, the matching indices and the
            reduction type ``"element"``.

        Raises:
            ValueError: If the classes in ``labels`` do not all have the same
                number of elements.
        """
        # The loss expects labels such that there is the same number of elements for each class.
        # The number of classes is not important, nor their order, eg.
        #
        # The following label is valid:
        # [ A,A,A, B,B,B, C,C,C ]
        # The following label is NOT valid:
        # [ B,B,B  A,A,A,A,  C,C,C ]
        #
        # Elements of a class do not have to be adjacent: positives are found
        # through a label-equality mask, not through positional grouping.
        c_f.labels_required(labels)
        c_f.ref_not_supported(embeddings, labels, ref_emb, ref_labels)

        # torch.unique (unlike torch.bincount) accepts negative labels and does
        # not allocate max(label) + 1 entries. Counts come back ordered by label
        # value, which is the order the bincount-based check reported them in.
        _, class_counts = torch.unique(labels, return_counts=True)
        if class_counts.unique().size(0) != 1:
            raise ValueError(
                "All classes must have the same number of elements in the labels.\n"
                "The given labels have the following number of elements: {}.\n"
                "You can achieve this using the samplers.MPerClassSampler class and setting the batch_size and m.".format(
                    class_counts.cpu().tolist()
                )
            )

        sims = self.distance(embeddings)
        batch_size = sims.size(0)

        # Masks are at least float32 so that half-precision similarities are
        # promoted when multiplied by them, while float64 inputs stay float64.
        work_dtype = torch.promote_types(sims.dtype, torch.float32)

        # not_same_item[a, b] == 0 iff a == b: an item never outranks itself.
        not_same_item = 1.0 - torch.eye(
            batch_size, dtype=work_dtype, device=sims.device
        )
        # same_class[i, b] == 1 iff element b is a positive for query i
        # (this includes b == i, as in the reference implementation).
        same_class = (labels.unsqueeze(1) == labels.unsqueeze(0)).to(
            device=sims.device, dtype=work_dtype
        )

        # sims_diff[i, a, b] = sims[i, b] - sims[i, a]; built by broadcasting
        # instead of materialising repeated copies of the similarity matrix.
        sims_diff = sims.unsqueeze(1) - sims.unsqueeze(2)
        sims_sigm = torch.sigmoid(sims_diff / self.temperature) * not_same_item

        # Smoothed rank of item a for query i among all items of the batch ...
        all_ranks = torch.sum(sims_sigm, dim=-1) + 1
        # ... and among the positives of query i only.
        pos_ranks = torch.sum(sims_sigm * same_class.unsqueeze(1), dim=-1) + 1

        # AP of query i: mean over its positives of pos_rank / all_rank. The
        # row sum of same_class is the (common) number of elements per class.
        ap = torch.sum(same_class * pos_ranks / all_ranks, dim=-1) / torch.sum(
            same_class, dim=-1
        )

        miner_weights = lmu.convert_to_weights(indices_tuple, labels, dtype=ap.dtype)
        loss = (1 - ap) * miner_weights

        return {
            "ap_loss": {
                "losses": loss,
                "indices": c_f.torch_arange_from_size(loss),
                "reduction_type": "element",
            }
        }
