from collections import defaultdict
from typing import List

import numpy as np
from datasets import Dataset
from tqdm import tqdm


class SpeakerEmbeddingGrouper:
    """
    Computes per-speaker averaged embeddings and adds them as a new column.

    Groups all embedding vectors by a speaker column, averages them, then
    maps the averaged vector back to every sample of that speaker.

    Can be used:
    - Inside SpeakerEmbeddingProcessor (after embedding + clustering)
    - Independently when embeddings already exist in the dataset
      (add_speaker_emb: false, group_sp_emb.do_this: true)
    """

    def __init__(self, settings, embedding_column: str):
        """
        Args:
            settings:         GroupEmbeddingSettings instance. Read for its
                              ``group_by_column_name`` and
                              ``grouped_embedding_column`` attributes.
            embedding_column: Name of the source embedding column to average
                              (e.g. 'wavlm_embedding').
        """
        self.settings = settings
        self.embedding_column = embedding_column

    def group(self, dataset: Dataset, *, num_proc: int = 10) -> Dataset:
        """
        Add a grouped (per-speaker averaged) embedding column to the dataset.

        Steps:
        1. Sum embeddings per unique speaker value, skipping None and any
           embedding containing a non-finite value (NaN or +/-inf).
        2. Divide each per-speaker sum by its count to get the average.
        3. Map the averaged vector back to every sample via dataset.map().

        Samples whose speaker has no valid embeddings at all receive None
        in the output column (they are not filtered out).

        Args:
            dataset:  HuggingFace Dataset containing both the embedding column
                      and the group-by column.
            num_proc: Maximum number of worker processes for the final
                      ``dataset.map``. It is capped at ``len(dataset)``. A value
                      <= 1 (or None) runs in a single process.

        Returns:
            Dataset with the new grouped embedding column added.

        Raises:
            KeyError:   The group-by or embedding column is not in the dataset.
            ValueError: Two valid embeddings of the same speaker have
                        different shapes.
        """
        group_col = self.settings.group_by_column_name
        emb_col = self.embedding_column
        out_col = self.settings.grouped_embedding_column

        # Fail fast with a clear message instead of a bare KeyError
        # raised from deep inside the collection loop.
        missing = [c for c in (group_col, emb_col) if c not in dataset.column_names]
        if missing:
            raise KeyError(
                f"[Group Embeddings] Column(s) {missing} not found in dataset; "
                f"available columns: {dataset.column_names}"
            )

        print("\n[Group Embeddings] Starting grouping...")
        print(f"[Group Embeddings] Group by:    '{group_col}'")
        print(f"[Group Embeddings] Source emb:  '{emb_col}'")
        print(f"[Group Embeddings] Output col:  '{out_col}'")
        print(f"[Group Embeddings] Dataset size: {len(dataset)} samples")

        # ── Step 1: running sum + count per speaker ──────────────────────────
        sums, counts = self._sum_per_speaker(dataset, group_col, emb_col)

        n_speakers = len(sums)
        print(f"[Group Embeddings] Unique speakers with valid embeddings: {n_speakers}")

        # ── Step 2: average per speaker ───────────────────────────────────────
        # The sums are float64 for accuracy. The result is cast back to float32,
        # the precision the embeddings were read at.
        averaged_embeddings: dict = {
            speaker: (total / counts[speaker]).astype(np.float32).tolist()
            for speaker, total in tqdm(
                sums.items(), desc="[Group Embeddings] Averaging"
            )
        }

        # ── Step 3: map back to every sample ─────────────────────────────────
        def _add_grouped_emb(speakers):
            # Receives only the group-by column (via input_columns), so no other
            # column is decoded. Unknown speakers map to None.
            return {out_col: [averaged_embeddings.get(s, None) for s in speakers]}

        # More workers than rows only spawns idle processes.
        effective_proc = min(num_proc or 1, len(dataset))
        dataset = dataset.map(
            _add_grouped_emb,
            batched=True,
            input_columns=group_col,
            num_proc=effective_proc if effective_proc > 1 else None,
            desc=f"[Group Embeddings] Adding '{out_col}' column",
        )

        print(f"[Group Embeddings] Done — '{out_col}' column added.\n")
        return dataset

    @staticmethod
    def _sum_per_speaker(dataset, group_col: str, emb_col: str):
        """Return ``(sums, counts)`` dicts keyed by speaker for valid embeddings."""
        # Drop every column this pass does not read, so iteration does not
        # decode unrelated (possibly heavy) features.
        unused = [c for c in dataset.column_names if c not in (group_col, emb_col)]
        light = dataset.remove_columns(unused) if unused else dataset

        sums: dict = {}
        counts: dict = defaultdict(int)

        # NOTE(review): a None speaker value is treated as its own group,
        # as in the original implementation. Confirm this is intended.
        for sample in tqdm(light, desc="[Group Embeddings] Collecting per speaker"):
            embedding = sample[emb_col]
            if embedding is None:
                continue

            emb_arr = np.array(embedding, dtype=np.float32)
            # A single NaN or inf would poison the whole speaker's mean.
            if not np.all(np.isfinite(emb_arr)):
                continue

            speaker = sample[group_col]
            total = sums.get(speaker)
            if total is None:
                # The float64 accumulator (a fresh copy) limits round-off over
                # many samples.
                sums[speaker] = emb_arr.astype(np.float64)
            elif total.shape != emb_arr.shape:
                # Guard explicitly: in-place += could silently broadcast a
                # mismatched shape instead of failing.
                raise ValueError(
                    f"[Group Embeddings] Embedding shape mismatch for speaker "
                    f"{speaker!r}: {emb_arr.shape} vs {total.shape}"
                )
            else:
                total += emb_arr
            counts[speaker] += 1

        return sums, counts

    def get_preserved_columns(self) -> List[str]:
        """Return the list of columns produced by this processor."""
        return [self.settings.grouped_embedding_column]
