#!/usr/bin/env python3
"""
Speaker embedding extraction with OOM protection and compute monitoring.

=== OPTIMIZATION (v6.2) ===
Key improvements:
- Accept audio buffer: No file re-read when buffer provided
- GPU memory-aware batching: Query torch.cuda.mem_get_info() for dynamic batch sizing
- Adaptive batching based on segment length
- OOM protection with fallback to individual processing
- Memory-efficient processing order (short segments first)
"""

import time
import logging
from typing import List, Dict, Optional
import numpy as np
import torch
import torchaudio
from src.models import MODELS

# Module logger, registered as a child of the "FastPipelineV6" logger namespace.
logger = logging.getLogger("FastPipelineV6.Embeddings")


def _extract_embedding_chunked(
    audio_chunk: np.ndarray,
    sr: int,
    max_window_samples: int,
    device: torch.device,
    embedding_model,
    window_overlap: float = 0.5
) -> Optional[np.ndarray]:
    """
    Extract embedding from a long segment by chunking and averaging.
    
    === OPTIMIZATION (v6.1) ===
    Instead of discarding long segments, we:
    1. Split into overlapping windows
    2. Extract embedding from each window
    3. Average the embeddings
    
    This preserves ALL your data - longer samples are valuable for TTS!

    Windows start every ``max_window_samples * (1 - window_overlap)`` samples
    (at least 1). Generation stops at the first window that reaches the end of
    the audio. Later starts would only produce windows fully contained in that
    one, which would over-weight the tail in the average. Windows shorter than
    ``sr`` samples (1 s) are skipped. Windows whose encoding raises are logged
    at DEBUG level and skipped.
    
    Args:
        audio_chunk: Audio samples (numpy array)
        sr: Sample rate
        max_window_samples: Maximum samples per window
        device: Torch device
        embedding_model: Embedding extraction model (must provide
            ``encode_batch(wavs, wav_lens)``)
        window_overlap: Overlap ratio between windows (0.5 = 50%)
    
    Returns:
        Mean of the per-window embeddings, each flattened to 1-D, or None if
        no window produced an embedding.
    """
    total_samples = len(audio_chunk)

    # Window step (with overlap). Clamp to >= 1 so an overlap >= 1.0 or a tiny
    # window cannot hand range() a zero step (which raises ValueError).
    step_samples = max(1, int(max_window_samples * (1 - window_overlap)))

    window_embeddings = []

    for start_idx in range(0, total_samples, step_samples):
        end_idx = min(start_idx + max_window_samples, total_samples)
        window = audio_chunk[start_idx:end_idx]

        # Skip windows that are too short (less than 1s)
        if len(window) >= sr:
            try:
                padded = torch.tensor(window).unsqueeze(0).to(device)
                wav_lens = torch.tensor([1.0]).to(device)

                with torch.no_grad():
                    emb = embedding_model.encode_batch(padded, wav_lens).cpu().numpy()
                # Flatten (1, 1, D) / (1, D) outputs to a single (D,) vector so
                # the average has the same shape as rows from the batched path.
                window_embeddings.append(emb.reshape(-1))

                del padded, wav_lens
            except Exception as e:
                logger.debug(f"Window embedding failed: {e}")

        # This window already covers the end of the audio. Any later start
        # would only yield a window contained in this one.
        if end_idx >= total_samples:
            break

    if not window_embeddings:
        return None

    # Average all window embeddings
    return np.mean(window_embeddings, axis=0)


def extract_embeddings_batched(
    audio_path: str,
    segments: List[Dict],
    config,
    audio_buffer=None
) -> Dict[int, np.ndarray]:
    """
    Extract speaker embeddings with adaptive batching to prevent OOM.
    
    === OPTIMIZATION (v6.2) ===
    - Accept audio_buffer: No file re-read when buffer provided
    - GPU memory-aware batching: dynamic batch sizing (see _compute_adaptive_batch_size)
    - Long segments (> config.max_embedding_length samples) are chunked and
      averaged - preserving valuable data
    
    Strategy to prevent OOM while preserving ALL data:
    1. Size batches once from available VRAM + average segment length
    2. Short segments: batch processing for efficiency
    3. Long segments: chunk into windows -> extract -> average
    4. Oversized or failed batches fall back to one-at-a-time processing
    5. Aggressive cache clearing between batches
    
    Args:
        audio_path: Path to audio file (only read when ``audio_buffer`` is None)
        segments: Segment dicts with 'start'/'end' in seconds and an optional
            'speaker' label
        config: Pipeline configuration; reads ``max_embedding_length`` (in
            samples) and ``embedding_batch_size``
        audio_buffer: Optional AudioBuffer (avoids file re-read); its
            ``waveform_np`` and ``sample_rate`` are used as-is
    
    Returns:
        Dict mapping segment index (position in ``segments``) to embedding
        array. Some segments are absent from the result:
        - segments labelled OVERLAP / NON_SPEECH
        - segments shorter than 0.3 s
        - segments whose extraction failed
    """
    logger.info(f"🔢 Extracting embeddings ({len(segments)} segments)...")
    start = time.time()

    # Clear cache before starting
    MODELS.clear_cache(aggressive=True)

    device = MODELS.get_device()
    embedding_model = MODELS.embedding_model

    # === OPTIMIZATION: Use buffer if provided ===
    if audio_buffer is not None:
        waveform_np = audio_buffer.waveform_np
        sr = audio_buffer.sample_rate
        logger.info("   Using pre-loaded audio buffer (no file I/O)")
    else:
        # Load audio (legacy path). torchaudio.load returns (channels, frames).
        # Downmix to mono so the 1-D slicing below indexes time. With
        # squeeze(0) alone, stereo input stayed 2-D and segments sliced
        # channels. For mono input the mean is the identity.
        waveform, sr = torchaudio.load(audio_path)
        waveform_np = waveform.mean(dim=0).numpy()

    embeddings: Dict[int, np.ndarray] = {}
    min_samples = int(0.3 * sr)  # 0.3s minimum
    max_samples = config.max_embedding_length  # Max samples for single-pass embedding

    # Prepare valid segments - separate short (batchable) from long (chunked)
    short_items = []  # (index, torch.Tensor) - can be batched
    long_items = []   # (index, np.ndarray) - need chunking
    skipped_too_short = 0

    for i, seg in enumerate(segments):
        # Skip overlap and non-speech segments
        if seg.get('speaker') in ['OVERLAP', 'NON_SPEECH']:
            continue

        # Clamp at 0: a negative start would otherwise slice from the array's end.
        start_sample = max(0, int(seg['start'] * sr))
        end_sample = max(0, int(seg['end'] * sr))
        chunk = waveform_np[start_sample:end_sample]

        if len(chunk) < min_samples:
            skipped_too_short += 1
            continue

        # === CHANGED (v6.1): Don't skip long segments - chunk them instead! ===
        if len(chunk) > max_samples:
            long_items.append((i, chunk))
        else:
            short_items.append((i, torch.tensor(chunk)))

    if skipped_too_short > 0:
        logger.info(f"   Skipped {skipped_too_short} segments (too short <0.3s)")

    # === PROCESS LONG SEGMENTS (chunked embedding extraction) ===
    if long_items:
        logger.info(f"   Processing {len(long_items)} long segments via chunking (no data loss!)")
        for i, chunk in long_items:
            emb = _extract_embedding_chunked(chunk, sr, max_samples, device, embedding_model)
            if emb is not None:
                embeddings[i] = emb
        MODELS.clear_cache(aggressive=True)

    # === PROCESS SHORT SEGMENTS (batched) ===
    if short_items:
        _embed_short_segments(short_items, config.embedding_batch_size, device, embedding_model, embeddings)
    elif embeddings:
        logger.info("   All segments were long - processed via chunking")
    else:
        logger.warning("   No valid segments for embedding extraction")

    # Final cleanup (single exit path so the summary is always logged)
    MODELS.clear_cache(aggressive=True)

    elapsed = time.time() - start
    logger.info(f"✅ Embeddings: {elapsed:.1f}s | {len(embeddings)}/{len(segments)} extracted")

    return embeddings


def _embed_short_segments(
    valid_items: List,
    base_batch_size: int,
    device: torch.device,
    embedding_model,
    embeddings: Dict[int, np.ndarray],
) -> None:
    """Embed (index, tensor) items in padded batches, shortest first, with OOM fallbacks."""
    # Sort by length (process shorter segments first - more memory efficient)
    valid_items.sort(key=lambda item: len(item[1]))

    # Compute adaptive batch size. Clamp to >= 1 so the ceil-division and
    # range() step below can never see zero.
    batch_size = max(1, _compute_adaptive_batch_size(valid_items, base_batch_size))
    total_batches = (len(valid_items) + batch_size - 1) // batch_size

    logger.info(f"   Processing {len(valid_items)} segments in {total_batches} batches (size={batch_size})")

    for batch_num, batch_start in enumerate(range(0, len(valid_items), batch_size)):
        batch_items = valid_items[batch_start:batch_start + batch_size]

        indices = [idx for idx, _ in batch_items]
        audio_chunks = [chunk for _, chunk in batch_items]

        # Pad to same length
        max_len = max(len(c) for c in audio_chunks)

        # Check estimated memory usage of the padded batch
        estimated_size = len(audio_chunks) * max_len * 4  # 4 bytes per float32
        if estimated_size > 500_000_000:  # 500MB threshold
            logger.warning(f"   Batch {batch_num+1} too large ({estimated_size/1e9:.2f}GB), processing individually")
            _process_individually(indices, audio_chunks, device, embedding_model, embeddings)
            MODELS.clear_cache(aggressive=True)
            continue

        # Normal batch processing
        try:
            _process_batch(indices, audio_chunks, max_len, device, embedding_model, embeddings)
        except Exception as e:
            logger.error(f"   Batch {batch_num+1} failed: {e}, trying individual processing")
            # Release cached memory from the failed (possibly OOM) attempt
            # before retrying item by item.
            MODELS.clear_cache(aggressive=True)
            _process_individually(indices, audio_chunks, device, embedding_model, embeddings)

        # Clear cache every few batches
        if (batch_num + 1) % 5 == 0:
            MODELS.clear_cache(aggressive=True)


def _compute_adaptive_batch_size(valid_items: List, base_batch: int, *, sample_rate: int = 16000) -> int:
    """
    Compute optimal batch size based on segment lengths AND available VRAM.
    
    === OPTIMIZATION (v6.2) ===
    Combines two estimates and returns the more conservative one.

    VRAM-based estimate:
    - Comes from ``src.audio_buffer.compute_optimal_batch_size`` when it is
      importable; otherwise ``base_batch`` is used.
    - If that helper raises, the error is logged at DEBUG level and
      ``base_batch`` is used.

    Length-based estimate (from the average item length in samples):
    - Under 1 s: double ``base_batch``, capped at 64.
    - Under 3 s: ``base_batch`` unchanged.
    - Longer: halve ``base_batch``, but not below 4 and never above
      ``base_batch``.

    Args:
        valid_items: (index, samples) pairs; only ``len(samples)`` is used.
        base_batch: Configured batch size.
        sample_rate: Samples per second used for the 1 s / 3 s thresholds
            (defaults to 16 kHz, the value the thresholds originally assumed).

    Returns:
        Batch size, always >= 1.
    """
    if not valid_items:
        return max(1, base_batch)

    # Get average segment length
    avg_len = sum(len(item[1]) for item in valid_items) / len(valid_items)

    # === Factor in available VRAM ===
    try:
        from src.audio_buffer import compute_optimal_batch_size
    except ImportError:
        # Optional helper not available: rely on the length heuristic only.
        vram_batch = base_batch
    else:
        try:
            vram_batch = compute_optimal_batch_size(
                num_items=len(valid_items),
                avg_item_samples=int(avg_len),
                base_batch=base_batch,
                min_batch=2,
                max_batch=64,
                vram_headroom_gb=2.0
            )
        except Exception as e:
            logger.debug(f"VRAM-aware batch sizing failed, using base batch: {e}")
            vram_batch = base_batch

    # Also apply length-based heuristic
    if avg_len < sample_rate:  # < 1s
        length_batch = min(base_batch * 2, 64)
    elif avg_len < 3 * sample_rate:  # < 3s
        length_batch = base_batch
    else:  # >= 3s
        # Halve with a floor of 4, but never go above the configured size
        # (max(base // 2, 4) alone returned 4 for base_batch < 4).
        length_batch = min(base_batch, max(base_batch // 2, 4))

    # Take the more conservative of the two; never return 0 (callers use it
    # as a range() step and divisor).
    return max(1, min(vram_batch, length_batch))


def _process_batch(
    indices: List[int],
    audio_chunks: List[torch.Tensor],
    max_len: int,
    device: torch.device,
    embedding_model,
    embeddings: Dict[int, np.ndarray]
):
    """
    Embed one zero-padded batch and write each result into ``embeddings``.

    Each chunk is right-padded with zeros to ``max_len`` samples. The
    companion length tensor holds ``len(chunk) / max_len`` per row.
    Presumably ``encode_batch`` uses these relative lengths to ignore the
    padding — depends on the embedding model.

    The model output is normalized before storing:
    - 3-D outputs of shape (B, 1, D) are squeezed to (B, D).
    - Other 3-D outputs are reshaped to (B, -1).
    - Row ``j`` is then stored under ``indices[j]``.
    - If the output is 1-D, the whole array is stored for every index.
    """
    n_rows = len(audio_chunks)
    batch = torch.zeros(n_rows, max_len)
    rel_lens = torch.zeros(n_rows)

    for row, piece in enumerate(audio_chunks):
        n = len(piece)
        batch[row, :n] = piece
        rel_lens[row] = n / max_len

    batch = batch.to(device)
    rel_lens = rel_lens.to(device)

    with torch.no_grad():
        out = embedding_model.encode_batch(batch, rel_lens).cpu().numpy()
        if out.ndim == 3:
            out = out.squeeze(1) if out.shape[1] == 1 else out.reshape(out.shape[0], -1)

    per_row = out.ndim > 1
    for row, seg_idx in enumerate(indices):
        if per_row:
            embeddings[seg_idx] = out[row]
        else:
            embeddings[seg_idx] = out

    # Drop references promptly so device memory can be reclaimed
    del batch, rel_lens, out


def _process_individually(
    indices: List[int],
    audio_chunks: List[torch.Tensor],
    device: torch.device,
    embedding_model,
    embeddings: Dict[int, np.ndarray]
):
    """
    Process segments one at a time (fallback).

    In this file it is called when a batch is estimated to be too large, or
    when batched encoding raised. Each chunk is encoded alone with a relative
    length of 1.0.

    The model output is flattened to a 1-D vector. This matches the per-row
    vectors ``_process_batch`` stores for both 2-D (B, D) and 3-D (B, K, D)
    outputs. A plain ``squeeze()`` applied only to 3-D outputs would leave a
    2-D (1, D) output un-flattened.

    A failing segment is logged at ERROR level and left out of ``embeddings``.
    """
    for idx, chunk in zip(indices, audio_chunks):
        try:
            padded_single = chunk.unsqueeze(0).to(device)
            wav_lens_single = torch.tensor([1.0]).to(device)

            with torch.no_grad():
                emb = embedding_model.encode_batch(padded_single, wav_lens_single).cpu().numpy()
            embeddings[idx] = emb.reshape(-1)

            del padded_single, wav_lens_single
        except Exception as e:
            logger.error(f"   Individual embedding failed for {idx}: {e}")
