import itertools
from collections import defaultdict

import torch
from torch.utils.data.sampler import Sampler

from ..utils import common_functions as c_f


# Inspired by
# https://github.com/kunhe/Deep-Metric-Learning-Baselines/blob/master/datasets.py
class HierarchicalSampler(Sampler):
    """
    Batch sampler that fills each batch from a small group of super classes.

    Every combination of ``super_classes_per_batch`` distinct outer labels
    produces ``batches_per_super_tuple`` batches. Inside a batch, each outer
    label of the combination gets an equal share of
    ``batch_size // super_classes_per_batch`` slots, which is filled class by
    class. A class whose samples would overflow the share is skipped, so a
    batch can contain fewer than ``batch_size`` indices.
    """

    def __init__(
        self,
        labels,
        batch_size,
        samples_per_class,
        batches_per_super_tuple=4,
        super_classes_per_batch=2,
        inner_label=0,
        outer_label=1,
    ):
        """
        labels: 2D array (or tensor, or nested sequence), where rows correspond to elements,
            and columns correspond to the hierarchical labels
        batch_size: because this is a BatchSampler the batch size must be specified.
            Must be a multiple of super_classes_per_batch
        samples_per_class: number of instances to sample for a specific class.
            Set to "all" to take as many samples as the class has elements
        batches_per_super_tuple: number of batches to create for each combination
            of super labels
        super_classes_per_batch: number of distinct super labels combined in one batch.
            If there are fewer distinct super labels than this, no batches are produced
        inner_label: columns index corresponding to classes
        outer_label: columns index corresponding to the level of hierarchy for the pairs
        """
        if torch.is_tensor(labels):
            labels = labels.cpu().numpy()

        self.batch_size = batch_size
        self.batches_per_super_tuple = batches_per_super_tuple
        self.samples_per_class = samples_per_class
        self.super_classes_per_batch = super_classes_per_batch

        # checks (kept as asserts so callers still see AssertionError)
        assert (
            self.batch_size % super_classes_per_batch == 0
        ), f"batch_size should be a multiple of {super_classes_per_batch}"
        # number of slots each super class gets inside one batch
        self.sub_batch_len = self.batch_size // super_classes_per_batch

        if self.samples_per_class != "all":
            assert self.samples_per_class > 0, "samples_per_class must be positive"
            assert (
                self.sub_batch_len % self.samples_per_class == 0
            ), "batch_size not a multiple of samples_per_class"

        # {super label: {class label: [element indices]}}
        self.super_image_lists = self._group_by_hierarchy(
            labels, inner_label, outer_label
        )

        # Iterating the dict yields super labels in first-appearance order.
        # A set was used here before; its order depends on hashing (randomised
        # per process for strings), which made seeded runs non-reproducible.
        self.super_pairs = list(
            itertools.combinations(self.super_image_lists, super_classes_per_batch)
        )
        self.reshuffle()

    @staticmethod
    def _group_by_hierarchy(labels, inner_label, outer_label):
        """
        Group row indices of ``labels`` by super label, then by class label.

        Returns a dict ``{super_label: defaultdict(list)}`` whose keys are in
        order of first appearance. Rows are read one at a time, so any
        sequence of indexable rows is accepted, not only 2D arrays.
        """
        groups = {}
        for idx, instance in enumerate(labels):
            slb, lb = instance[outer_label], instance[inner_label]
            groups.setdefault(slb, defaultdict(list))[lb].append(idx)
        return groups

    def __iter__(
        self,
    ):
        """Draw a fresh set of batches, then yield them one by one."""
        self.reshuffle()
        for batch in self.batches:
            yield batch

    def __len__(
        self,
    ):
        """Number of batches produced by the most recent reshuffle."""
        return len(self.batches)

    def reshuffle(self):
        """
        Rebuild ``self.batches``.

        For every super-label combination, ``batches_per_super_tuple`` batches
        are assembled from one sub-batch per super label. Each batch is
        shuffled internally, and finally the list of batches is shuffled.
        """
        batches = []
        for combination in self.super_pairs:
            for _ in range(self.batches_per_super_tuple):
                batch = []
                for slb in combination:
                    batch.extend(self._sample_sub_batch(slb))

                c_f.NUMPY_RANDOM.shuffle(batch)
                batches.append(batch)

        c_f.NUMPY_RANDOM.shuffle(batches)
        self.batches = batches

    def _sample_sub_batch(self, slb):
        """
        Sample up to ``self.sub_batch_len`` indices from super label ``slb``.

        Classes are visited in random order. A class is skipped when adding
        its samples would exceed the sub-batch size, so the result may be
        shorter than ``self.sub_batch_len`` (possibly empty).
        """
        sub_batch = []
        all_classes = list(self.super_image_lists[slb])
        c_f.NUMPY_RANDOM.shuffle(all_classes)
        for cl in all_classes:
            if len(sub_batch) >= self.sub_batch_len:
                break
            instances = self.super_image_lists[slb][cl]
            samples_per_class = (
                self.samples_per_class
                if self.samples_per_class != "all"
                else len(instances)
            )
            if len(sub_batch) + samples_per_class > self.sub_batch_len:
                # does not fit in the remaining slots; try the next class
                continue
            # NOTE(review): presumably safe_random_choice samples with
            # replacement when the class has fewer than `size` instances —
            # confirm in utils.common_functions.
            sub_batch.extend(
                c_f.safe_random_choice(instances, size=samples_per_class)
            )
        return sub_batch
