#!/usr/bin/env python3
"""
WavLM Audio Embedding Generator for HuggingFace Datasets
=========================================================

Clean OOP implementation for generating speaker embeddings using the WavLM model
and adding them as new columns to HuggingFace datasets with multi-GPU support.

Model: Orange/Speaker-wavLM-tbr
"""

import os
import warnings
from typing import List, Dict, Any, Optional, Tuple
from dataclasses import dataclass
import threading
import queue

import torch
import torch.nn.functional as F
import numpy as np
from datasets import Dataset, DatasetDict, load_dataset
from tqdm import tqdm
import torch.multiprocessing as mp

# Import WavLM classes from spk_wavLM
from .spk_wavLM import EmbeddingsModel as WavLMEmbeddingsModel, compute_embedding

warnings.filterwarnings("ignore")


@dataclass
class EmbedderConfig:
    """
    Configuration for the WavLM audio embedder.

    One instance is shared by the embedder, its workers and (pickled) by the
    spawned multi-GPU worker processes.
    """

    # HuggingFace model id handed to WavLMEmbeddingsModel.from_pretrained().
    model_name: str = "Orange/Speaker-wavLM-tbr"
    # Sample rate (Hz) the audio arrays are expected to be at when embedded;
    # combined with max_audio_sec it gives the truncation length in samples.
    target_sr: int = 16000
    max_audio_sec: float = 30.0  # Longer clips are truncated. NOTE(review): 30 s is stated as the WavLM limit — confirm against the model card.
    # Number of clips per model forward pass (per worker process in multi-GPU mode).
    batch_size: int = 32
    # Name of the column added to the dataset to hold the embeddings.
    embedding_column_name: str = "wavlm_embedding"
    use_multiprocessing: bool = True  # When True and more than one CUDA device is visible, one process is spawned per GPU


class AudioProcessor:
    """
    Turn dataset audio entries into fixed-length float32 waveforms.

    Every successfully processed clip has exactly ``int(max_audio_sec * target_sr)``
    samples: longer clips are cut at the end, shorter ones are right-padded with zeros.
    """

    def __init__(self, target_sr: int = 16000, max_audio_sec: float = 25.0):
        """
        Record the sampling parameters and derive the fixed output length.

        Args:
            target_sr: Sample rate (Hz) the incoming arrays are assumed to use
            max_audio_sec: Output duration in seconds
        """
        self.target_sr = target_sr
        self.max_audio_sec = max_audio_sec
        self.target_samples = int(max_audio_sec * target_sr)

    def process_audio(self, audio_dict: Dict[str, Any]) -> Optional[np.ndarray]:
        """
        Convert one HuggingFace audio entry to a fixed-length numpy waveform.

        No resampling happens here: the entry's ``'array'`` is expected to be at
        ``target_sr`` already (e.g. via
        ``dataset.cast_column('audio', Audio(sampling_rate=target_sr))``).

        Args:
            audio_dict: Mapping with at least an ``'array'`` entry

        Returns:
            float32 array of length ``target_samples``, or None when the entry
            cannot be converted or holds no samples
        """
        try:
            raw = audio_dict['array']
            wav = (
                torch.from_numpy(raw).float()
                if isinstance(raw, np.ndarray)
                else torch.tensor(raw, dtype=torch.float32)
            )

            # Average over the first axis for multi-dimensional input
            # (assumes a channels-first layout — TODO confirm for multi-channel data).
            if wav.dim() > 1:
                wav = wav.mean(dim=0)

            return self._finalize_waveform(wav)
        except Exception as e:
            print(f"Audio processing error: {e}")
            return None

    def _finalize_waveform(self, wav: torch.Tensor) -> Optional[np.ndarray]:
        """
        Cut or zero-pad ``wav`` along its first axis to ``target_samples``.

        Args:
            wav: Waveform tensor

        Returns:
            The adjusted waveform as numpy, or None if ``wav`` has no elements
        """
        if wav.numel() == 0:
            return None

        shortfall = self.target_samples - wav.shape[0]
        if shortfall < 0:
            wav = wav[:self.target_samples]
        elif shortfall > 0:
            wav = F.pad(wav, (0, shortfall), "constant", 0.0)

        return wav.numpy()


class WavLMWorker:
    """
    Worker class for processing dataset chunks with WavLM model on a specific GPU.

    Owns one model instance on one device; call ``initialize()`` once before
    ``process_batch()``.
    """

    # Embedding width used for an all-failed batch before any successful forward
    # pass has revealed the real width. 128 is the width this module previously
    # hard-coded for this model. NOTE(review): confirm against the model card.
    _embedding_dim: int = 128

    def __init__(
        self,
        rank: int,
        config: "EmbedderConfig",
        audio_processor: "AudioProcessor"
    ):
        """
        Initialize WavLM GPU worker.

        Args:
            rank: GPU device index
            config: Embedder configuration
            audio_processor: Audio preprocessing instance (stored; not used by
                process_batch, which does its own dynamic-length preprocessing)
        """
        self.rank = rank
        self.config = config
        self.audio_processor = audio_processor
        self.device = None
        self.model = None

    def initialize(self):
        """
        Load the WavLM model and move it to this worker's device.

        Falls back to CPU when selecting ``cuda:{rank}`` raises (for example when
        CUDA is unavailable).
        """
        try:
            torch.cuda.set_device(self.rank)
            self.device = torch.device(f"cuda:{self.rank}")
        except Exception as e:
            print(f"[GPU {self.rank}] Failed to set device: {e}")
            self.device = torch.device("cpu")

        print(f"[GPU {self.rank}] Loading WavLM model...")
        self.model = WavLMEmbeddingsModel.from_pretrained(self.config.model_name)
        self.model.to(self.device).eval()
        print(f"[GPU {self.rank}] WavLM model ready")

    @staticmethod
    def _to_waveform(audio_dict: Dict, max_length: int) -> torch.Tensor:
        """
        Convert one dataset audio entry to a float32 tensor of at most ``max_length`` samples.

        No resampling is done; the array is assumed to already be at ``config.target_sr``.

        Raises:
            KeyError / TypeError: if the entry has no usable ``'array'``.
            ValueError: if the waveform contains no samples.
        """
        audio_array = audio_dict['array']
        if isinstance(audio_array, np.ndarray):
            sig = torch.from_numpy(audio_array).float()
        else:
            sig = torch.tensor(audio_array, dtype=torch.float32)

        # Average over the first axis for multi-dimensional input
        # (assumes a channels-first layout — TODO confirm for multi-channel data).
        if sig.dim() > 1:
            sig = sig.mean(dim=0)

        # An empty clip would otherwise be embedded as pure padding.
        if sig.numel() == 0:
            raise ValueError("empty audio")

        return sig[:max_length]

    def process_batch(self, audio_dicts: List[Dict]) -> np.ndarray:
        """
        Embed a batch of audio samples with one dynamically padded forward pass.

        Valid waveforms are truncated to ``max_audio_sec`` and zero-padded only up
        to the longest valid waveform in this batch, rather than always to
        ``max_audio_sec``. Samples that cannot be loaded (missing/malformed
        ``'array'`` or empty audio) are left out of the forward pass and get an
        all-zero embedding row.

        NOTE(review): the model is called on the padded batch without lengths or
        an attention mask; unless it masks padding internally, a clip's embedding
        may depend on the longest clip it was batched with. Confirm against
        spk_wavLM.EmbeddingsModel.

        Args:
            audio_dicts: List of audio dictionaries from the dataset

        Returns:
            Array of shape (len(audio_dicts), embedding_dim); row i belongs to
            audio_dicts[i]. An empty input yields zero rows.
        """
        max_allowed_length = int(self.config.max_audio_sec * self.config.target_sr)

        # Original batch index -> prepared waveform, in batch order.
        valid: Dict[int, torch.Tensor] = {}
        for idx, audio_dict in enumerate(audio_dicts):
            try:
                valid[idx] = self._to_waveform(audio_dict, max_allowed_length)
            except Exception as e:
                print(f"[GPU {self.rank}] Error processing audio at index {idx}: {e}")

        if not valid:
            # Empty or fully-failed batch: nothing sensible to feed the model.
            return np.zeros((len(audio_dicts), self._embedding_dim), dtype=np.float32)

        batch_max_length = max(sig.shape[0] for sig in valid.values())
        batch = torch.stack([
            F.pad(sig, (0, batch_max_length - sig.shape[0]), "constant", 0.0)
            for sig in valid.values()
        ]).to(self.device)

        with torch.no_grad():
            valid_embeddings = self.model(batch).cpu().numpy()

        # Remember the real width so later all-failed batches stay consistent.
        self._embedding_dim = valid_embeddings.shape[1]

        embeddings_np = np.zeros(
            (len(audio_dicts), valid_embeddings.shape[1]),
            dtype=valid_embeddings.dtype,
        )
        embeddings_np[list(valid)] = valid_embeddings
        return embeddings_np


def gpu_worker_process(rank: int, config: EmbedderConfig, dataset_shard, audio_column: str, result_queue):
    """
    Embed one dataset shard on GPU ``rank`` and report the results on ``result_queue``.

    A background thread slices batches out of ``dataset_shard`` (up to three
    buffered) while the model processes the current one. Only batch index ranges
    are computed up front, so just a few loaded batches are held in memory at a
    time. If loading fails, the error is raised in this process (which then exits
    with a non-zero code) rather than blocking forever on the prefetch queue.

    Args:
        rank: GPU device index
        config: Embedder configuration
        dataset_shard: Subset of dataset to process
        audio_column: Name of audio column
        result_queue: Queue that receives one ``(rank, embeddings)`` tuple, where
            ``embeddings`` is a list of per-sample numpy rows in shard order
    """
    audio_processor = AudioProcessor(config.target_sr, config.max_audio_sec)
    worker = WavLMWorker(rank, config, audio_processor)
    worker.initialize()

    num_samples = len(dataset_shard)
    batch_size = config.batch_size
    # Index ranges only; the loader thread slices each batch just before it is needed.
    batch_starts = range(0, num_samples, batch_size)

    prefetch_queue = queue.Queue(maxsize=3)
    stop_loading = threading.Event()

    def offer(item) -> bool:
        """Put ``item`` on the prefetch queue unless loading has been cancelled."""
        # Timed puts let the loader notice `stop_loading` instead of blocking
        # forever on a full queue after the consumer has stopped.
        while not stop_loading.is_set():
            try:
                prefetch_queue.put(item, timeout=0.1)
                return True
            except queue.Full:
                continue
        return False

    def data_loader_thread():
        """Load batches in order; forward any loading error to the consumer."""
        try:
            for start in batch_starts:
                end = min(start + batch_size, num_samples)
                if not offer((None, dataset_shard[start:end][audio_column])):
                    return
        except Exception as e:
            print(f"[GPU {rank}] Data loader thread error: {e}")
            offer((e, None))

    loader_thread = threading.Thread(target=data_loader_thread, daemon=True)
    loader_thread.start()

    results = []
    pbar = tqdm(
        total=num_samples,
        desc=f"GPU {rank}",
        position=rank,
        leave=True
    )
    try:
        for _ in batch_starts:
            error, audio_dicts = prefetch_queue.get()
            if error is not None:
                raise RuntimeError(f"[GPU {rank}] Failed to load data") from error

            embeddings = worker.process_batch(audio_dicts)
            results.extend(embeddings)
            pbar.update(len(embeddings))
    finally:
        pbar.close()
        stop_loading.set()
        loader_thread.join(timeout=5.0)

    result_queue.put((rank, results))
    print(f"[GPU {rank}] Completed processing {len(results)} samples")


class WavLMEmbedder:
    """
    Main class for generating WavLM speaker embeddings on HuggingFace datasets.

    Supports multi-GPU parallel processing with automatic batching and progress tracking.
    """

    def __init__(self, config: EmbedderConfig):
        """
        Initialize the embedder.

        Args:
            config: Configuration object with model and processing parameters
        """
        self.config = config
        self.audio_processor = AudioProcessor(
            target_sr=config.target_sr,
            max_audio_sec=config.max_audio_sec
        )

        # Number of visible CUDA devices; 1 when CUDA is unavailable, so the
        # single-worker path (whose worker falls back to CPU) is used.
        self.num_gpus = torch.cuda.device_count() if torch.cuda.is_available() else 1

    def embed_dataset(
        self,
        dataset: Dataset,
        audio_column: str = "audio",
        output_path: Optional[str] = None,
        split_name: str = "train"
    ) -> Dataset:
        """
        Add embedding column to a HuggingFace dataset.

        If ``audio_column`` is a ``datasets.Audio`` feature whose sampling rate
        differs from ``config.target_sr``, embeddings are computed from a
        resampled view; the returned dataset keeps its original audio column.

        Args:
            dataset: HuggingFace Dataset object
            audio_column: Name of the column containing audio data
            output_path: If given, the resulting dataset is written there with
                ``Dataset.save_to_disk``
            split_name: Name of the split being processed (for logging)

        Returns:
            Dataset with added embedding column

        Raises:
            ValueError: if ``audio_column`` is not a column of ``dataset``
        """
        print(f"\n{'='*60}")
        print(f"Processing {split_name} split with {len(dataset)} examples")
        print(f"Using {self.num_gpus} GPU(s)")
        print(f"{'='*60}\n")

        if audio_column not in dataset.column_names:
            raise ValueError(f"Column '{audio_column}' not found in dataset. "
                           f"Available columns: {dataset.column_names}")

        # The workers do not resample, so decode at the model's rate up front.
        source = self._resampled_for_model(dataset, audio_column)

        # Use multi-GPU processing if enabled and multiple GPUs available
        if self.config.use_multiprocessing and self.num_gpus > 1:
            embeddings = self._multi_gpu_embed(source, audio_column)
        else:
            embeddings = self._single_gpu_embed(source, audio_column)

        dataset = dataset.add_column(
            self.config.embedding_column_name,
            embeddings
        )

        if output_path:
            print(f"Saving {split_name} split to {output_path}")
            dataset.save_to_disk(output_path)

        return dataset

    def _resampled_for_model(self, dataset: Dataset, audio_column: str) -> Dataset:
        """
        Return ``dataset`` with ``audio_column`` decoded at ``config.target_sr``.

        Only a ``datasets.Audio`` feature with a different (or unset) sampling
        rate is cast; any other column type is returned unchanged.
        """
        from datasets import Audio  # local import: only this helper needs the feature type

        feature = dataset.features.get(audio_column)
        if isinstance(feature, Audio) and feature.sampling_rate != self.config.target_sr:
            print(f"Resampling '{audio_column}' from {feature.sampling_rate} Hz "
                  f"to {self.config.target_sr} Hz for embedding")
            return dataset.cast_column(audio_column, Audio(sampling_rate=self.config.target_sr))
        return dataset

    @staticmethod
    def _prefetch(load_batch, num_batches: int, depth: int = 3):
        """
        Yield ``load_batch(i)`` for ``i`` in ``range(num_batches)``, loading ahead on
        a background thread with at most ``depth`` batches buffered.

        An exception raised by ``load_batch`` is re-raised in the consuming thread
        as RuntimeError instead of leaving the consumer blocked on the queue.
        Closing the generator early stops the loader thread.
        """
        buffer = queue.Queue(maxsize=depth)
        stop = threading.Event()

        def offer(item) -> bool:
            # Timed puts let the loader notice `stop` instead of blocking forever
            # on a full queue once the consumer has gone away.
            while not stop.is_set():
                try:
                    buffer.put(item, timeout=0.1)
                    return True
                except queue.Full:
                    continue
            return False

        def loader():
            try:
                for i in range(num_batches):
                    if not offer((None, load_batch(i))):
                        return
            except Exception as e:
                print(f"Data loader thread error: {e}")
                offer((e, None))

        thread = threading.Thread(target=loader, daemon=True)
        thread.start()
        try:
            for _ in range(num_batches):
                error, batch = buffer.get()
                if error is not None:
                    raise RuntimeError("Data loading failed") from error
                yield batch
        finally:
            stop.set()
            thread.join(timeout=5.0)

    def _single_gpu_embed(self, dataset: Dataset, audio_column: str) -> List[List[float]]:
        """
        Generate embeddings using single GPU with data prefetching.

        Args:
            dataset: Dataset to process
            audio_column: Name of audio column

        Returns:
            List of embeddings, one per row of ``dataset``

        Raises:
            RuntimeError: if loading a batch from the dataset fails
        """
        worker = WavLMWorker(0, self.config, self.audio_processor)
        worker.initialize()

        num_samples = len(dataset)
        batch_size = self.config.batch_size
        batch_starts = range(0, num_samples, batch_size)

        def load_batch(i: int):
            start = batch_starts[i]
            return dataset[start:min(start + batch_size, num_samples)][audio_column]

        embeddings: List[List[float]] = []
        batches = self._prefetch(load_batch, len(batch_starts))
        try:
            with tqdm(total=num_samples, desc="Generating embeddings") as pbar:
                for audio_dicts in batches:
                    embeddings.extend(worker.process_batch(audio_dicts).tolist())
                    pbar.update(len(audio_dicts))
        finally:
            # Stop the loader thread promptly even if processing failed.
            batches.close()

        return embeddings

    def _multi_gpu_embed(self, dataset: Dataset, audio_column: str) -> List[List[float]]:
        """
        Generate embeddings using multiple GPUs in parallel.

        The dataset is split into contiguous shards, one spawned worker process
        per GPU. If a worker exits with a non-zero code before delivering its
        results, all workers are terminated and RuntimeError is raised instead of
        waiting forever.

        Args:
            dataset: Dataset to process
            audio_column: Name of audio column

        Returns:
            List of embeddings in original order
        """
        num_samples = len(dataset)
        shard_size = (num_samples + self.num_gpus - 1) // self.num_gpus

        shards = []
        for i in range(self.num_gpus):
            start_idx = i * shard_size
            end_idx = min(start_idx + shard_size, num_samples)
            if start_idx < num_samples:
                shards.append(dataset.select(range(start_idx, end_idx)))

        # PyTorch recommends the 'spawn' start method when children use CUDA.
        mp_context = mp.get_context('spawn')
        result_queue = mp_context.Queue()

        processes = []
        for rank, shard in enumerate(shards):
            p = mp_context.Process(
                target=gpu_worker_process,
                args=(rank, self.config, shard, audio_column, result_queue)
            )
            p.start()
            processes.append(p)

        results_dict = {}
        try:
            # Drain the queue before join(): a child that has put data waits for
            # it to be flushed before exiting, so joining first could deadlock.
            while len(results_dict) < len(processes):
                try:
                    rank, embeddings = result_queue.get(timeout=10.0)
                except queue.Empty:
                    failed = [
                        r for r, p in enumerate(processes)
                        if r not in results_dict and p.exitcode not in (None, 0)
                    ]
                    if failed:
                        raise RuntimeError(
                            f"GPU worker(s) {failed} exited without returning embeddings"
                        )
                    continue
                results_dict[rank] = embeddings
        except BaseException:
            for p in processes:
                if p.is_alive():
                    p.terminate()
            raise
        finally:
            for p in processes:
                p.join()

        # Combine results in shard (= original row) order
        all_embeddings = []
        for rank in range(len(shards)):
            all_embeddings.extend(results_dict[rank])

        return [emb.tolist() if isinstance(emb, np.ndarray) else emb
                for emb in all_embeddings]


def process_dataset(
    dataset_name: str,
    audio_column: str = "audio",
    output_path: Optional[str] = None,
    config: Optional[EmbedderConfig] = None,
    split: Optional[str] = None
) -> Dataset:
    """
    Load a dataset from the HuggingFace Hub and add WavLM embeddings to it.

    Args:
        dataset_name: Hub identifier passed to ``load_dataset``
        audio_column: Column holding the audio entries
        output_path: Optional destination forwarded to ``embed_dataset``; when
            several splits are loaded, each split uses ``{output_path}/{split}``
        config: Embedder settings; a default ``EmbedderConfig`` when omitted
        split: A single split to load (e.g. 'train'), or None to load every split

    Returns:
        The embedded Dataset, or a DatasetDict of embedded splits when
        ``load_dataset`` returned one
    """
    config = EmbedderConfig() if config is None else config

    print(f"Loading dataset: {dataset_name}...")
    loaded = load_dataset(dataset_name, split=split)

    embedder = WavLMEmbedder(config)

    # Single split: embed it directly.
    if not isinstance(loaded, DatasetDict):
        return embedder.embed_dataset(
            loaded,
            audio_column=audio_column,
            output_path=output_path,
            split_name=split or "dataset"
        )

    # Several splits: embed each one in turn, keeping the split names.
    return DatasetDict({
        name: embedder.embed_dataset(
            split_ds,
            audio_column=audio_column,
            output_path=f"{output_path}/{name}" if output_path else None,
            split_name=name
        )
        for name, split_ds in loaded.items()
    })


# Removed __main__ block - use run_embedder.py instead
