import numpy as np
from umap import UMAP
from hdbscan import HDBSCAN
from sklearn.preprocessing import normalize
from datasets import Dataset


class SpeakerMakerUnsupervised:
    """
    Unsupervised speaker clustering system using UMAP dimensionality reduction
    and HDBSCAN clustering to assign speaker labels to audio samples.

    This class performs automatic speaker identification by:
    1. Reducing high-dimensional speaker embeddings (e.g., 128-dim) to lower
       dimensions using UMAP (NaN filtering is done upstream)
    2. Normalizing the reduced embeddings with L2 normalization
    3. Clustering the normalized embeddings using HDBSCAN to identify distinct speakers
    4. Assigning speaker_id labels to each sample
    5. Filtering out noise samples (speaker_id == -1)

    Attributes:
        cfg (dict): Configuration dictionary with 'UMAP' and 'HDBSCAN' parameters
        embedding_column (str): Name of the column containing speaker embeddings
        speaker_column (str): Name of the output column for cluster IDs
        num_proc (int): Number of processes used for the dataset filter step
        pre_reducer (UMAP): UMAP dimensionality reduction model
        clusterer (HDBSCAN): HDBSCAN clustering model

    Example:
        >>> cluster_config = {
        ...     'UMAP': {'n_components': 5, 'metric': 'cosine'},
        ...     'HDBSCAN': {'min_cluster_size': 10}
        ... }
        >>> speaker_maker = SpeakerMakerUnsupervised(cluster_config)
        >>> labeled_dataset = speaker_maker(dataset)
    """

    def __init__(self, cluster_config, embedding_column: str = 'wavlm_embedding',
                 speaker_column: str = 'speaker_id', num_proc: int = 10):
        """
        Initialize the unsupervised speaker clustering system.

        Args:
            cluster_config (dict): Configuration dictionary containing:
                - 'UMAP': dict of UMAP parameters (n_components, metric, etc.)
                - 'HDBSCAN': dict of HDBSCAN parameters (min_cluster_size, etc.)
            embedding_column (str): Name of the dataset column containing speaker
                embeddings. Defaults to 'wavlm_embedding'.
            speaker_column (str): Name of the output column for cluster IDs.
                Defaults to 'speaker_id'.
            num_proc (int): Number of worker processes for the noise-removal
                filter step. Defaults to 10 (the previous hard-coded value).
        """
        self.cfg = cluster_config
        self.embedding_column = embedding_column
        self.speaker_column = speaker_column
        self.num_proc = num_proc
        self.pre_reducer = UMAP(**self.cfg['UMAP'])
        self.clusterer = HDBSCAN(**self.cfg['HDBSCAN'])

    def proc_log(self, labels: np.ndarray) -> None:
        """
        Log clustering results and statistics.

        Prints information about:
        - Total number of samples processed
        - Number of unique speakers identified
        - Number of noise samples (speaker_id == -1)
        - Percentage of samples assigned to valid clusters
        - Per-cluster size statistics (min/max/mean/median)

        Args:
            labels (np.ndarray): Array of speaker labels from clustering.
                Labels >= 0 represent speaker clusters, -1 represents noise.
        """
        total_samples = len(labels)
        valid_labels = labels[labels >= 0]
        # np.unique with return_counts is robust even if cluster ids are
        # non-contiguous; np.bincount would report phantom zero-size clusters.
        unique_speakers, speaker_counts = np.unique(valid_labels, return_counts=True)
        num_speakers = len(unique_speakers)
        num_valid = len(valid_labels)
        num_noise = total_samples - num_valid
        # Compute both percentages directly so an empty input reports 0% noise
        # instead of the misleading 100% that `100 - valid_percentage` gives.
        if total_samples > 0:
            valid_percentage = (num_valid / total_samples) * 100
            noise_percentage = (num_noise / total_samples) * 100
        else:
            valid_percentage = 0.0
            noise_percentage = 0.0

        print("\n" + "="*60)
        print("Speaker Clustering Results")
        print("="*60)
        print(f"Total samples processed:      {total_samples}")
        print(f"Unique speakers identified:   {num_speakers}")
        print(f"Valid cluster assignments:    {num_valid} ({valid_percentage:.1f}%)")
        print(f"Noise samples (unassigned):   {num_noise} ({noise_percentage:.1f}%)")

        if num_speakers > 0:
            # Cluster size statistics over valid (non-noise) clusters only.
            print("\nCluster size statistics:")
            print(f"  Min samples per speaker:    {speaker_counts.min()}")
            print(f"  Max samples per speaker:    {speaker_counts.max()}")
            print(f"  Mean samples per speaker:   {speaker_counts.mean():.1f}")
            print(f"  Median samples per speaker: {np.median(speaker_counts):.0f}")

        print("="*60 + "\n")

    def __call__(self, dataset: Dataset) -> Dataset:
        """
        Perform unsupervised speaker clustering on the dataset.

        This method processes speaker embeddings through the following pipeline:
        1. Extract embeddings into a numpy array
        2. Apply UMAP dimensionality reduction
        3. Normalize reduced embeddings with L2 normalization
        4. Cluster normalized embeddings using HDBSCAN
        5. Add the speaker-id column to the dataset
        6. Remove noise samples (speaker_id == -1)

        Args:
            dataset (Dataset): HuggingFace Dataset containing speaker embeddings
                in the column specified by self.embedding_column

        Returns:
            Dataset: Dataset with added speaker-id column, noise samples removed.
                Only samples with valid speaker assignments (speaker_id >= 0) are retained.

        Note:
            - NaN-embedding filtering is expected to happen upstream
              (SpeakerEmbeddingProcessor), not here
            - HDBSCAN assigns label -1 to noise samples (outliers)
            - Noise samples are automatically removed from the final dataset
        """
        print("\n[Speaker Clustering] Starting pipeline...")
        print(f"[Speaker Clustering] Dataset size: {len(dataset)} samples")
        print("[Speaker Clustering] (NaN filtering done upstream in SpeakerEmbeddingProcessor)")

        # Extract embeddings into a (n_samples, dim) array.
        print(f"[Speaker Clustering] Extracting embeddings from '{self.embedding_column}' column...")
        emb = np.array(dataset[self.embedding_column])
        print(f"[Speaker Clustering] Embeddings shape: {emb.shape}")

        # Apply UMAP dimensionality reduction (target dim set by cfg['UMAP']).
        print("[Speaker Clustering] Applying UMAP dimensionality reduction...")
        embeddings_5d = self.pre_reducer.fit_transform(emb)
        print(f"[Speaker Clustering] Reduced embeddings shape: {embeddings_5d.shape}")

        # L2-normalize so HDBSCAN distances behave like cosine-style distances.
        print("[Speaker Clustering] Normalizing reduced embeddings (L2)...")
        emb_normalized = normalize(embeddings_5d, norm='l2')

        # Cluster; HDBSCAN labels outliers as -1.
        print("[Speaker Clustering] Running HDBSCAN clustering...")
        labels = self.clusterer.fit_predict(emb_normalized)

        # Attach per-sample cluster ids to the dataset.
        print(f"[Speaker Clustering] Adding '{self.speaker_column}' column to dataset...")
        dataset = dataset.add_column(name=self.speaker_column, column=labels)

        # Log clustering results
        self.proc_log(labels)

        # Drop samples HDBSCAN marked as noise.
        print(f"[Speaker Clustering] Removing noise samples ({self.speaker_column} == -1)...")
        desc = 'Removing noise samples'
        dataset = dataset.filter(lambda i: i[self.speaker_column] >= 0,
                                 num_proc=self.num_proc, desc=desc)
        print(f"[Speaker Clustering] Final dataset size: {len(dataset)} samples")
        print("[Speaker Clustering] Pipeline completed successfully!\n")

        return dataset


