"""
regroup_embeddings.py

Standalone script to recompute grouped (per-speaker averaged) embeddings
after manually filtering a dataset (e.g. by cosine similarity threshold).

Usage:
    venv/bin/python regroup_embeddings.py \
        --dataset path/to/filtered_dataset \
        --embedding-col wavlm_embedding \
        --group-col label \
        --output-col grouped_sp_emb \
        --output path/to/output_dataset \
        [--hub your-username/dataset-name]

The input dataset must already be saved to disk (Arrow format, as produced
by datasets.save_to_disk). It must contain both the embedding column and
the group-by column.

The script rewrites the grouped_sp_emb column in-place and saves the result
to --output (local) and/or pushes to the Hub (--hub).
"""

import argparse
from collections import defaultdict

import numpy as np
from datasets import load_from_disk
from tqdm import tqdm


def regroup(dataset, embedding_col: str, group_col: str, output_col: str, num_proc: int = 10):
    """Recompute per-speaker averaged embeddings and write them back as a column.

    Three passes over the dataset:
      1. Collect every valid per-utterance embedding, keyed by the group column
         (rows with a None embedding or any NaN component are skipped).
      2. Average the stacked embeddings per speaker.
      3. Map the per-speaker average back onto every row as ``output_col``
         (rows whose speaker had no valid embedding get None).

    Args:
        dataset: A ``datasets.Dataset`` containing ``embedding_col`` and ``group_col``.
        embedding_col: Name of the per-utterance embedding column.
        group_col: Column to group by (speaker ID / cluster label).
        output_col: Name of the grouped-embedding column to (re)write.
        num_proc: Worker processes for the final ``map`` step (default 10,
            matching the previous hard-coded value).

    Returns:
        A new ``datasets.Dataset`` with ``output_col`` rewritten.
    """
    print(f"\n[Regroup] Dataset size:   {len(dataset)} samples")
    print(f"[Regroup] Embedding col:  '{embedding_col}'")
    print(f"[Regroup] Group by:       '{group_col}'")
    print(f"[Regroup] Output col:     '{output_col}'")

    # Step 1: collect embeddings per speaker, dropping missing/NaN entries so
    # a single corrupted row cannot poison the whole speaker average.
    speaker_embeddings: dict = defaultdict(list)

    for sample in tqdm(dataset, desc="[Regroup] Collecting per speaker"):
        speaker = sample[group_col]
        embedding = sample[embedding_col]
        if embedding is None:
            continue
        emb_arr = np.array(embedding, dtype=np.float32)
        if np.any(np.isnan(emb_arr)):
            continue
        speaker_embeddings[speaker].append(emb_arr)

    n_speakers = len(speaker_embeddings)
    print(f"[Regroup] Unique speakers with valid embeddings: {n_speakers}")

    # Step 2: average per speaker (stored as plain lists so Arrow can serialize them)
    averaged_embeddings: dict = {}
    for speaker, embeddings in tqdm(speaker_embeddings.items(), desc="[Regroup] Averaging"):
        avg_emb = np.mean(np.stack(embeddings, axis=0), axis=0)
        averaged_embeddings[speaker] = avg_emb.tolist()

    # Step 3: map averages back onto each row; speakers with no valid
    # embeddings get None rather than a stale value.
    def _add_grouped_emb(sample):
        return {output_col: averaged_embeddings.get(sample[group_col], None)}

    # Remove the old column first so map() writes a fresh one instead of
    # colliding with an existing feature of a possibly different dtype.
    if output_col in dataset.column_names:
        dataset = dataset.remove_columns([output_col])

    dataset = dataset.map(
        _add_grouped_emb,
        num_proc=num_proc,
        desc=f"[Regroup] Writing '{output_col}'",
    )

    print("[Regroup] Done.\n")
    return dataset


def main():
    """Parse CLI arguments, regroup embeddings, then save locally and optionally push to the Hub."""
    parser = argparse.ArgumentParser(
        description="Recompute grouped speaker embeddings after dataset filtering."
    )
    parser.add_argument(
        "--dataset", required=True,
        help="Path to the filtered dataset saved with save_to_disk()"
    )
    parser.add_argument(
        "--embedding-col", default="wavlm_embedding",
        help="Name of the per-utterance embedding column (default: wavlm_embedding)"
    )
    parser.add_argument(
        "--group-col", default="label",
        help="Column to group by (speaker ID / cluster label, default: label)"
    )
    parser.add_argument(
        "--output-col", default="grouped_sp_emb",
        help="Name of the output grouped embedding column (default: grouped_sp_emb)"
    )
    parser.add_argument(
        "--output", required=True,
        help="Path to save the resulting dataset locally (save_to_disk)"
    )
    parser.add_argument(
        "--hub", default=None,
        help="HuggingFace repo to push the result to (optional, e.g. your-org/dataset)"
    )
    args = parser.parse_args()

    print(f"[Regroup] Loading dataset from: {args.dataset}")
    dataset = load_from_disk(args.dataset)
    print(f"[Regroup] Loaded: {len(dataset)} samples, columns: {dataset.column_names}")

    dataset = regroup(
        dataset,
        embedding_col=args.embedding_col,
        group_col=args.group_col,
        output_col=args.output_col,
    )

    # Always save locally; the Hub push is opt-in and marked private so a
    # typo'd repo name cannot leak data publicly.
    print(f"[Regroup] Saving to disk: {args.output}")
    dataset.save_to_disk(args.output)
    print("[Regroup] Saved.")

    if args.hub:
        print(f"[Regroup] Pushing to Hub: {args.hub}")
        dataset.push_to_hub(args.hub, private=True)
        print("[Regroup] Pushed.")

    print("\n[Regroup] All done.")

if __name__ == "__main__":
    main()