import os
from typing import List

import numpy as np
from datasets import Dataset, Audio

from .embedder import EmbedderConfig, WavLMEmbedder
from .speaker_clustering import SpeakerMakerUnsupervised
from .group_embeddings import SpeakerEmbeddingGrouper


class SpeakerEmbeddingProcessor:
    """
    Orchestrates speaker embedding generation and optional downstream steps
    (unsupervised clustering, per-speaker averaged embeddings) for a
    HuggingFace dataset.

    Receives a dataset (already loaded at 22050 Hz by DatasetProcessor),
    generates WavLM speaker embeddings using a separate 16 kHz audio channel,
    then optionally runs:
      - UMAP + HDBSCAN clustering to assign cluster-based speaker IDs
      - Per-speaker embedding averaging (group_sp_emb)

    Returns the dataset with all produced columns added, ready for the
    tokenization step.
    """

    # Samples shorter than this many seconds are dropped before embedding:
    # they are too short to yield a reliable speaker embedding.
    MIN_AUDIO_SEC = 0.7
    # Worker counts for the raw-bytes copy map and the filter passes.
    MAP_NUM_PROC = 10
    FILTER_NUM_PROC = 5

    def __init__(self, settings):
        """
        Args:
            settings: SpeakerEmbeddingSettings instance from ConfigManager.
        """
        self.settings = settings
        embedder_config = EmbedderConfig(
            model_name=settings.model_name,
            target_sr=settings.target_sr,
            max_audio_sec=settings.max_audio_sec,
            batch_size=settings.batch_size,
            embedding_column_name=settings.embedding_column,
            use_multiprocessing=settings.use_multiprocessing,
        )
        self.embedder = WavLMEmbedder(embedder_config)

    def process(self, dataset: "Dataset", audio_column: str) -> "Dataset":
        """
        Add embedding column and optionally speaker/cluster column and grouped
        embedding column to the dataset.

        Pipeline (each stage delegated to a private helper):
          1-4. Re-decode audio without quality loss: a temporary "audio_spk"
               column at target_sr (16 kHz) for embedding, the original audio
               column restored at its original rate for tokenization.
          5.   Filter samples shorter than MIN_AUDIO_SEC.
          6.   Generate embeddings; remove "audio_spk".
          7.   Filter samples with NaN embeddings (mandatory, before all
               downstream steps).
          8.   Optionally run UMAP + HDBSCAN clustering.
          9.   Optionally run per-speaker embedding averaging (group_sp_emb).

        Args:
            dataset:      HuggingFace Dataset loaded at 22050 Hz.
            audio_column: Name of the audio column (from DatasetConfig).

        Returns:
            Dataset with added embedding (and optional clustering / grouping)
            columns. The audio column is still at 22050 Hz — ready for
            tokenization.
        """
        # NOTE(review): presumably forces the torchaudio decoding backend so
        # both decode passes resample consistently — confirm this env var is
        # honored by the installed `datasets` version.
        os.environ['DATASETS_AUDIO_BACKEND'] = 'torchaudio'

        # Steps 1-4 — build 16 kHz "audio_spk", restore original column.
        dataset = self._prepare_audio_columns(dataset, audio_column)

        # Step 5 — drop clips too short for a reliable embedding.
        dataset = self._filter_short_audio(dataset)

        # Step 6 — generate embeddings using audio_spk, then drop the
        # temporary column.
        dataset = self.embedder.embed_dataset(
            dataset=dataset,
            audio_column="audio_spk",
            split_name="speaker_embedding",
        )
        dataset = dataset.remove_columns("audio_spk")

        # Step 7 — mandatory NaN filtering before clustering / grouping.
        dataset = self._drop_nan_embeddings(dataset)

        # Step 8 — optional UMAP + HDBSCAN clustering.
        if self.settings.do_clusters:
            dataset = self._cluster_speakers(dataset)

        # Step 9 — optional per-speaker embedding averaging.
        if self.settings.group_emb and self.settings.group_emb.do_this:
            dataset = self._group_embeddings(dataset)

        return dataset

    def _prepare_audio_columns(
        self, dataset: "Dataset", audio_column: str
    ) -> "Dataset":
        """
        Create the temporary "audio_spk" column at target_sr (16 kHz) while
        keeping the original audio column at its original sampling rate.

        Copies raw bytes (decode=False) so each decode happens exactly once —
        avoids double decode/resample quality loss.
        """
        # Step 1 — remember SR and switch to raw bytes mode
        original_sr = dataset.features[audio_column].sampling_rate
        dataset = dataset.cast_column(audio_column, Audio(decode=False))

        # Step 2 — copy raw bytes to audio_spk
        dataset = dataset.map(
            lambda i: {"audio_spk": i[audio_column]},
            num_proc=self.MAP_NUM_PROC,
            desc="Copying audio for speaker embedding",
        )

        # Step 3 — decode audio_spk at target_sr (16 kHz)
        dataset = dataset.cast_column(
            "audio_spk",
            Audio(sampling_rate=self.settings.target_sr, decode=True),
        )

        # Step 4 — restore original audio column at its original rate
        dataset = dataset.cast_column(
            audio_column,
            Audio(sampling_rate=original_sr, decode=True),
        )
        return dataset

    def _filter_short_audio(self, dataset: "Dataset") -> "Dataset":
        """Step 5 — drop samples shorter than MIN_AUDIO_SEC seconds."""
        # Bind to a local so the multiprocessed lambda does not capture self.
        min_sec = self.MIN_AUDIO_SEC
        return dataset.filter(
            lambda i: (
                len(i["audio_spk"]["array"]) / i["audio_spk"]["sampling_rate"]
                >= min_sec
            ),
            num_proc=self.FILTER_NUM_PROC,
            desc="Filtering short audio",
        )

    def _drop_nan_embeddings(self, dataset: "Dataset") -> "Dataset":
        """
        Step 7 — remove samples whose embedding contains any NaN value.

        Mandatory before clustering / grouping, which require finite vectors.
        Logs how many samples were removed, if any.
        """
        emb_col = self.settings.embedding_column
        initial_size = len(dataset)
        dataset = dataset.filter(
            lambda i: np.isnan(np.array(i[emb_col])).sum() == 0,
            num_proc=self.FILTER_NUM_PROC,
            desc="Filtering NaN embeddings",
        )
        removed = initial_size - len(dataset)
        if removed:
            print(
                f"[SpeakerEmbeddingProcessor] Removed {removed} samples "
                f"with NaN embeddings. Remaining: {len(dataset)}"
            )
        return dataset

    def _cluster_speakers(self, dataset: "Dataset") -> "Dataset":
        """Step 8 — assign cluster-based speaker IDs via UMAP + HDBSCAN."""
        speaker_maker = SpeakerMakerUnsupervised(
            cluster_config={
                "UMAP": self.settings.umap_params,
                "HDBSCAN": self.settings.hdbscan_params,
            },
            embedding_column=self.settings.embedding_column,
            speaker_column=self.settings.clustering_speaker_column,
        )
        return speaker_maker(dataset)

    def _group_embeddings(self, dataset: "Dataset") -> "Dataset":
        """Step 9 — add per-speaker averaged embeddings (group_sp_emb)."""
        grouper = SpeakerEmbeddingGrouper(
            settings=self.settings.group_emb,
            embedding_column=self.settings.embedding_column,
        )
        return grouper.group(dataset)

    def get_preserved_columns(self) -> List[str]:
        """
        Return the list of column names added by this processor.
        These must be passed through the tokenization step (preserve_columns).
        """
        cols = [self.settings.embedding_column]
        if self.settings.do_clusters:
            cols.append(self.settings.clustering_speaker_column)
        if self.settings.group_emb and self.settings.group_emb.do_this:
            cols.append(self.settings.group_emb.grouped_embedding_column)
        return cols
