# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import logging
import os
import random
from collections import Counter

import torch


class EM:
    """
    EM algorithm used to quantize the columns of W to minimize

                         ||W - W_hat||^2

    where W_hat is obtained by replacing every column of W with the centroid
    it is assigned to.

    Args:
        - W: weight matrix of size (in_features x out_features)
        - n_centroids: number of centroids (size of codebook)
        - n_iter: number of k-means iterations
        - eps: scale of the random perturbation used to split a populated
          cluster when an empty cluster is found
        - max_tentatives: maximum number of split attempts made by
          resolve_empty_clusters() before giving up
        - verbose: log the objective after each call to step()

    Remarks:
        - If one cluster is empty, the most populated cluster is split into
          two clusters
        - All the relevant dimensions are specified in the code
    """

    def __init__(
        self, W, n_centroids=256, n_iter=20, eps=1e-6, max_tentatives=30, verbose=True
    ):
        self.W = W
        self.n_centroids = n_centroids
        self.n_iter = n_iter
        self.eps = eps
        self.max_tentatives = max_tentatives
        self.verbose = verbose
        self.centroids = torch.Tensor()  # (n_centroids x in_features) once set
        self.assignments = torch.Tensor()  # (out_features) once set
        self.objective = []  # one entry appended per call to step()

    def initialize_centroids(self):
        """
        Initializes the centroids by sampling random columns from W.

        Remarks:
            - Columns are sampled with replacement (torch.randint), so the
              same column may be picked more than once; the resulting empty
              clusters are handled by resolve_empty_clusters()
        """

        in_features, out_features = self.W.size()
        indices = torch.randint(
            low=0, high=out_features, size=(self.n_centroids,)
        ).long()
        self.centroids = self.W[:, indices].t()  # (n_centroids x in_features)

    def step(self, i):
        """
        There are two standard steps for each iteration: expectation (E) and
        minimization (M). The E-step (assignment) is performed with an exhaustive
        search and the M-step (centroid computation) is performed with
        the exact solution.

        Args:
            - i: step number (only used for logging)

        Remarks:
            - The E-step heavily uses PyTorch broadcasting to speed up computations
              and reduce the memory overhead
            - The objective appended to self.objective is the Frobenius norm
              ||W - W_hat|| (not squared)
        """

        # assignments (E-step)
        self.assign()  # (out_features)
        n_empty_clusters = self.resolve_empty_clusters()

        # centroids (M-step): every cluster is non-empty here, otherwise
        # resolve_empty_clusters() would have raised
        for k in range(self.n_centroids):
            W_k = self.W[:, self.assignments == k]  # (in_features x size_of_cluster_k)
            self.centroids[k] = W_k.mean(dim=1)  # (in_features)

        # book-keeping
        obj = (self.centroids[self.assignments].t() - self.W).norm(p=2).item()
        self.objective.append(obj)
        if self.verbose:
            logging.info(
                f"Iteration: {i},\t"
                f"objective: {obj:.6f},\t"
                f"resolved empty clusters: {n_empty_clusters}"
            )

    def _cluster_sizes(self):
        """
        Counts how many columns of W are assigned to each centroid.

        Returns:
            - sizes: tensor whose entry k is the size of cluster k
              (length at least n_centroids)
            - empty: ascending list of the cluster indices in
              [0, n_centroids) that have no assigned column
        """

        sizes = torch.bincount(self.assignments, minlength=self.n_centroids)
        empty = (sizes[: self.n_centroids] == 0).nonzero(as_tuple=True)[0].tolist()
        return sizes, empty

    def resolve_empty_clusters(self):
        """
        If one cluster is empty, the most populated cluster is split into
        two clusters by shifting the respective centroids. This is done
        iteratively for at most self.max_tentatives attempts.

        Returns:
            - the number of empty clusters found before any split was made

        Raises:
            - EmptyClusterResolveError if some clusters are still empty after
              self.max_tentatives split attempts
        """

        # empty clusters
        sizes, empty_clusters = self._cluster_sizes()
        n_empty_clusters = len(empty_clusters)

        tentatives = 0
        while empty_clusters:
            # give up only while clusters are still empty, so a successful
            # final attempt is never discarded
            if tentatives >= self.max_tentatives:
                logging.info(
                    f"Could not resolve all empty clusters, {len(empty_clusters)} remaining"
                )
                raise EmptyClusterResolveError(
                    f"{len(empty_clusters)} empty clusters remaining after "
                    f"{tentatives} tentatives"
                )
            tentatives += 1

            # given an empty cluster, find most populated cluster and split it into two
            k = random.choice(empty_clusters)
            m = int(sizes.argmax())
            e = torch.randn_like(self.centroids[m]) * self.eps
            self.centroids[k] = self.centroids[m].clone()
            self.centroids[k] += e
            self.centroids[m] -= e

            # recompute assignments and check for empty clusters
            self.assign()
            sizes, empty_clusters = self._cluster_sizes()

        return n_empty_clusters

    def compute_distances(self):
        """
        For every centroid m, computes the distance to every column of W

                          ||W - m[:, None]||_2   (norm taken over in_features)

        Returns:
            - distances of size (n_centroids x out_features)

        Remarks:
            - We rely on PyTorch's broadcasting to speed up computations
              and reduce the memory overhead
            - Without chunking, the sizes in the broadcasting are modified as:
              (n_centroids x in_features x out_features) -> (n_centroids x out_features)
            - On a RuntimeError (e.g. out of memory) the centroids are split
              into twice as many chunks and the computation is retried; once
              every chunk holds a single centroid the error is re-raised
        """

        n_centroids_chunks = 1

        while True:
            try:
                return torch.cat(
                    [
                        (self.W[None, :, :] - centroids_c[:, :, None]).norm(p=2, dim=1)
                        for centroids_c in self.centroids.chunk(
                            n_centroids_chunks, dim=0
                        )
                    ],
                    dim=0,
                )
            except RuntimeError:
                # further chunking cannot shrink the computation: surface the
                # error instead of looping forever
                if n_centroids_chunks >= self.centroids.size(0):
                    raise
                n_centroids_chunks *= 2

    def assign(self):
        """
        Assigns each column of W to its closest centroid, thus essentially
        performing the E-step of step().

        Remarks:
            - The function must be called after the centroids are set (by
              initialize_centroids(), training, or self.load()); with the
              empty initial centroids the distance computation fails
        """

        distances = self.compute_distances()  # (n_centroids x out_features)
        self.assignments = torch.argmin(distances, dim=0)  # (out_features)

    def save(self, path, layer):
        """
        Saves centroids, assignments and objective history.

        Args:
            - path: folder used to save centroids and assignments
            - layer: prefix of the file names ("{layer}_centroids.pth", ...)
        """

        torch.save(self.centroids, os.path.join(path, "{}_centroids.pth".format(layer)))
        torch.save(
            self.assignments, os.path.join(path, "{}_assignments.pth".format(layer))
        )
        torch.save(self.objective, os.path.join(path, "{}_objective.pth".format(layer)))

    def load(self, path, layer):
        """
        Loads centroids, assignments and objective history from a given path.

        Args:
            - path: folder used to load centroids and assignments
            - layer: prefix of the file names, as passed to save()

        Remarks:
            - torch.load unpickles the files: only load files from a trusted
              source
        """

        self.centroids = torch.load(
            os.path.join(path, "{}_centroids.pth".format(layer))
        )
        self.assignments = torch.load(
            os.path.join(path, "{}_assignments.pth".format(layer))
        )
        self.objective = torch.load(
            os.path.join(path, "{}_objective.pth".format(layer))
        )


class EmptyClusterResolveError(Exception):
    """Raised when EM cannot get rid of its empty clusters within the allowed number of split attempts."""
