import numpy as np
import torch

from ..distances import LpDistance
from ..utils import common_functions as c_f
from ..utils import loss_and_miner_utils as lmu
from .base_metric_loss_function import BaseMetricLossFunction


class AngularLoss(BaseMetricLossFunction):
    """
    Implementation of https://arxiv.org/abs/1708.01682

    For every positive pair (anchor, positive), every reference embedding
    whose label differs from the anchor's label is used as a negative. The
    per-pair exponent is

        4 * tan(alpha)^2 * (a + p) . n  -  2 * (1 + tan(alpha)^2) * a . p

    computed on L2-normalized vectors, and the exponents of all negatives of
    a pair are combined with ``lmu.logsumexp(..., add_one=True)``.

    Args:
        alpha: The angle (as described in the paper), specified in degrees.
    """

    def __init__(self, alpha=40, **kwargs):
        super().__init__(**kwargs)
        # The dot-product form used in compute_loss only matches the paper's
        # formulation for unit-length embeddings under plain L2 distance.
        c_f.assert_distance_type(
            self, LpDistance, p=2, power=1, normalize_embeddings=True
        )
        # Stored in radians because torch.tan is applied to it directly.
        self.alpha = torch.tensor(np.radians(alpha))
        self.add_to_recordable_attributes(list_of_names=["alpha"], is_stat=False)
        self.add_to_recordable_attributes(list_of_names=["average_angle"], is_stat=True)

    def compute_loss(self, embeddings, labels, indices_tuple, ref_emb, ref_labels):
        """
        Compute one angular loss value per positive pair.

        Returns:
            ``self.zero_losses()`` when ``get_pairs`` finds no usable pairs,
            otherwise a dict with a single ``"loss"`` entry holding the
            losses, the ``(anchor_idx, positive_idx)`` index tuple and the
            ``"pos_pair"`` reduction type.
        """
        c_f.labels_required(labels)
        anchors, positives, keep_mask, anchor_idx, positive_idx = self.get_pairs(
            embeddings, labels, indices_tuple, ref_emb, ref_labels
        )
        if anchors is None:
            return self.zero_losses()

        # Anchors and positives were normalized in get_pairs; the negatives
        # must live on the same unit sphere, otherwise the (a + p) . n term
        # is scaled by the raw norm of each reference embedding.
        negatives = self.distance.normalize(ref_emb)

        sq_tan_alpha = torch.tan(self.alpha) ** 2
        # (num_pos_pairs, 1): a . p for every positive pair.
        ap_dot = torch.sum(anchors * positives, dim=1, keepdim=True)
        # (num_pos_pairs, dim) x (num_ref, dim, 1) broadcasts to
        # (num_ref, num_pos_pairs, 1); squeeze + transpose gives
        # (num_pos_pairs, num_ref) with entry [i, j] = (a_i + p_i) . n_j.
        ap_matmul_embeddings = torch.matmul(
            (anchors + positives), (negatives.unsqueeze(2))
        )
        ap_matmul_embeddings = ap_matmul_embeddings.squeeze(2).t()

        final_form = (4 * sq_tan_alpha * ap_matmul_embeddings) - (
            2 * (1 + sq_tan_alpha) * ap_dot
        )
        # keep_mask selects, per pair, the reference entries that are true
        # negatives. NOTE(review): presumably add_one=True yields
        # log(1 + sum(exp(.))) as in the paper -- confirm against lmu.logsumexp.
        losses = lmu.logsumexp(final_form, keep_mask=keep_mask, add_one=True)
        return {
            "loss": {
                "losses": losses,
                "indices": (anchor_idx, positive_idx),
                "reduction_type": "pos_pair",
            }
        }

    def get_pairs(self, embeddings, labels, indices_tuple, ref_emb, ref_labels):
        """
        Gather normalized anchors/positives and the negative mask.

        Returns:
            A 5-element sequence ``(anchors, positives, keep_mask, a1, p)``.
            All five entries are ``None`` when either the first or the third
            index tensor returned by ``lmu.convert_to_pairs`` is empty.
            ``keep_mask[i, j]`` is True when the label of the i-th anchor
            differs from ``ref_labels[j]``.
        """
        # NOTE(review): presumably (a1, p) index positive pairs and a2 holds
        # the anchors of negative pairs -- confirm against lmu.convert_to_pairs.
        a1, p, a2, _ = lmu.convert_to_pairs(indices_tuple, labels, ref_labels)
        if len(a1) == 0 or len(a2) == 0:
            return [None] * 5
        anchors = self.distance.normalize(embeddings[a1])
        positives = self.distance.normalize(ref_emb[p])
        keep_mask = labels[a1].unsqueeze(1) != ref_labels.unsqueeze(0)
        self.set_stats(anchors, positives, embeddings, ref_emb, keep_mask)
        return anchors, positives, keep_mask, a1, p

    def set_stats(self, anchors, positives, embeddings, ref_emb, keep_mask):
        """
        Record ``self.average_angle`` (in degrees) when ``self.collect_stats``
        is truthy; otherwise do nothing.

        The angle for pair i and reference j is
        ``atan(dist(a_i, p_i) / (2 * ||center_i - n_j||))`` where ``center_i``
        is the midpoint of the normalized anchor and positive. Only entries
        selected by ``keep_mask`` are averaged. ``embeddings`` is accepted
        but not used here.
        """
        if self.collect_stats:
            with torch.no_grad():
                centers = (anchors + positives) / 2
                ap_dist = self.distance.pairwise_distance(anchors, positives)
                # Use unit-length negatives so the statistic is measured in
                # the same space as the normalized anchors and positives.
                negatives = self.distance.normalize(ref_emb)
                # (num_ref, num_pos_pairs, dim) -> norm over dim -> transpose
                # to (num_pos_pairs, num_ref), matching keep_mask.
                nc_dist = self.distance.get_norm(
                    centers - negatives.unsqueeze(1), dim=2
                ).t()
                angles = torch.atan(ap_dist.unsqueeze(1) / (2 * nc_dist))
                # Evaluates to NaN if keep_mask selects no entries (0 / 0).
                average_angle = torch.sum(angles[keep_mask]) / torch.sum(keep_mask)
                self.average_angle = np.degrees(average_angle.item())
