#!/usr/bin/env python3
"""
Voice Activity Detection with TRUE parallelism + Compute Monitoring.

=== OPTIMIZATION (v6.2) ===
Key improvements:
- Persistent workers: Load Silero model ONCE per worker via initializer
- Memory-efficient: Accept numpy array directly (no file re-read)
- Adaptive worker count based on detected CPU cores
- Optimized for 0.4s event detection (min_speech_ms = 200)
"""

import time
import logging
import warnings
import os
import sys
import contextlib
from typing import List, Dict, Tuple, Optional, Union
from concurrent.futures import ProcessPoolExecutor, as_completed
import numpy as np

# === SUPPRESS NNPACK WARNINGS (v6.9) ===
# NNPACK is for mobile CPU optimization, not needed on GPU machines
# These warnings spam logs and slow down processing
os.environ['NNPACK_DISABLE'] = '1'
warnings.filterwarnings('ignore', message='.*NNPACK.*')
warnings.filterwarnings('ignore', category=UserWarning, module='torch')

import torch
# Suppress PyTorch C++ warnings
torch.set_warn_always(False)


@contextlib.contextmanager
def suppress_stderr():
    """
    Context manager to suppress stderr (C++ warnings like NNPACK).

    === WHY THIS IS NEEDED (v6.9) ===
    PyTorch's NNPACK warnings come from C++ code and can't be suppressed
    via Python's warnings module. Redirecting stderr is the only way.

    Yields:
        None. While the context is active, ``sys.stderr`` points at
        ``os.devnull``; the original stream is restored on exit.
    """
    old_stderr = sys.stderr
    devnull = open(os.devnull, 'w')
    sys.stderr = devnull
    try:
        yield
    finally:
        # BUGFIX: restore first, then close OUR devnull handle.
        # The previous version closed whatever sys.stderr pointed to at
        # exit time — if the body rebound sys.stderr, that closed the
        # wrong stream and leaked the devnull file object.
        sys.stderr = old_stderr
        devnull.close()

# Module-level logger; child of the pipeline's "FastPipelineV6" logger.
logger = logging.getLogger("FastPipelineV6.VAD")

# === GLOBAL VAD MODEL (loaded once per worker via initializer) ===
# Populated inside each worker process by _init_vad_worker() and read by
# _vad_worker(). They remain None in the parent process.
_vad_model = None
_vad_utils = None


def _init_vad_worker():
    """
    Worker initializer - loads Silero VAD model ONCE per worker process.

    === OPTIMIZATION (v6.2) ===
    Previously: Each worker loaded model for EVERY chunk (~200MB x 32 workers = 6.4GB!)
    Now: Model loaded ONCE per worker at startup, reused for all chunks

    === EDGE CASE HANDLING (v6.7) ===
    - Retry logic for network errors when multiple workers start
    - Graceful fallback to cached model
    - Exponential backoff with jitter on failures

    Benefit: ~50% VAD speedup, 80% less memory

    Raises:
        RuntimeError: If the model cannot be loaded after all retries and
            the local cache fallback also fails.
    """
    global _vad_model, _vad_utils

    # === SUPPRESS NNPACK IN WORKER (v6.9) ===
    # Must be set in each subprocess before torch operations.
    # (os/warnings/time are module-level imports and are available in the
    # worker because the module is re-imported on spawn.)
    os.environ['NNPACK_DISABLE'] = '1'
    warnings.filterwarnings('ignore', message='.*NNPACK.*')
    warnings.filterwarnings('ignore', category=UserWarning)
    torch.set_warn_always(False)

    import random  # only needed here, for backoff jitter

    max_retries = 5
    last_error = None
    for attempt in range(max_retries):
        try:
            _vad_model, _vad_utils = torch.hub.load(
                repo_or_dir='snakers4/silero-vad',
                model='silero_vad',
                force_reload=False,
                onnx=False,
                trust_repo=True,  # Avoid validation check
                verbose=False  # Suppress "Using cache found" messages
            )
            return  # Success!

        except Exception as e:
            last_error = e
            if attempt < max_retries - 1:
                # Exponential backoff: 0.5s, 1s, 2s, 4s
                wait_time = 0.5 * (2 ** attempt)
                # Add jitter to avoid thundering herd when many workers
                # start downloading at once.
                jitter = random.uniform(0, 0.5)
                time.sleep(wait_time + jitter)

    # All network attempts failed - try loading from the cache directory
    # directly (source='local' skips any network access).
    cache_dir = os.path.expanduser('~/.cache/torch/hub/snakers4_silero-vad_master')
    if os.path.exists(cache_dir):
        try:
            _vad_model, _vad_utils = torch.hub.load(
                repo_or_dir=cache_dir,
                model='silero_vad',
                source='local',
                force_reload=False,
                onnx=False,
                verbose=False
            )
            return
        except Exception:
            # BUGFIX: was a bare `except:` which also swallowed
            # SystemExit/KeyboardInterrupt; narrowed to Exception.
            pass

    # Final fallback failed
    raise RuntimeError(
        f"Failed to load VAD model after {max_retries} attempts. "
        f"Last error: {last_error}. Try running the main script first to download the model."
    )

def _vad_worker(args: Tuple[np.ndarray, int, dict, float]) -> List[Dict]:
    """
    Run Silero VAD on a single audio chunk inside a worker process.

    Relies on the module-level model/utils populated by
    ``_init_vad_worker``. Takes a single tuple argument so it can be
    dispatched via ``map()``-style APIs.

    Args:
        args: (audio_chunk, sample_rate, config_dict, chunk_start_seconds)

    Returns:
        Speech segments with start/end/duration in GLOBAL (whole-audio) time.
    """
    chunk, sr, cfg, offset = args

    # First element of the Silero utils tuple is get_speech_timestamps.
    get_speech_timestamps = _vad_utils[0]

    tensor = torch.from_numpy(chunk).float()

    # C++-level NNPACK warnings bypass Python's warnings module, so the
    # inference call is wrapped to silence stderr entirely.
    with suppress_stderr():
        raw_timestamps = get_speech_timestamps(
            tensor,
            _vad_model,
            threshold=cfg.get('vad_threshold', 0.5),
            min_speech_duration_ms=cfg.get('vad_min_speech_ms', 200),  # For 0.4s events
            min_silence_duration_ms=cfg.get('vad_min_silence_ms', 200),
            window_size_samples=cfg.get('vad_window_size_samples', 512),
            speech_pad_ms=cfg.get('vad_speech_pad_ms', 30),
            return_seconds=True,
            sampling_rate=sr
        )

    # Shift chunk-local timestamps onto the global timeline.
    return [
        {
            'start': float(offset + ts['start']),
            'end': float(offset + ts['end']),
            'duration': float(ts['end'] - ts['start'])
        }
        for ts in raw_timestamps
    ]


def run_vad_from_buffer(
    waveform_np: np.ndarray,
    sample_rate: int,
    config,
    pool=None
) -> List[Dict[str, float]]:
    """
    Run Silero VAD on pre-loaded audio buffer with persistent workers.

    === OPTIMIZATION (v6.2) ===
    - Accepts numpy array directly (no file re-read)
    - Uses initializer to load model ONCE per worker

    === OPTIMIZATION (v7.9 - Stage 1) ===
    - Accepts optional GlobalVADPool for cross-video persistence
    - When pool provided: reuses workers across videos (eliminates 2-3s overhead)
    - When pool is None: creates local pool (backward compatible)

    Args:
        waveform_np: Audio samples as numpy array (mono)
        sample_rate: Sample rate
        config: Pipeline configuration; reads vad_chunk_size, vad_workers
            and the vad_* tuning attributes copied into config_dict below
        pool: Optional GlobalVADPool for persistent workers; assumed to
            expose process_chunks(chunks, worker_fn, timeout_per_chunk,
            progress_callback) returning a flat segment list —
            TODO confirm against GlobalVADPool's definition

    Returns:
        List of speech segments with start/end/duration, sorted by start
        and merged across chunk boundaries
    """
    total_duration = len(waveform_np) / sample_rate
    pool_mode = "global" if pool is not None else "local"
    logger.info(f"🎤 Running Silero VAD ({pool_mode} pool, {config.vad_workers} workers)...")
    start = time.time()

    # Create chunks for parallel processing
    chunk_size_samples = int(config.vad_chunk_size * sample_rate)
    chunk_args: List[Tuple[np.ndarray, int, dict, float]] = []

    # Convert config to dict for pickling (config object itself may not
    # pickle cleanly across process boundaries)
    config_dict = {
        'vad_threshold': config.vad_threshold,
        'vad_min_speech_ms': config.vad_min_speech_ms,
        'vad_min_silence_ms': config.vad_min_silence_ms,
        'vad_window_size_samples': config.vad_window_size_samples,
        'vad_speech_pad_ms': config.vad_speech_pad_ms,
    }

    # Slice the waveform into fixed-size chunks; each arg tuple carries the
    # chunk's start time so workers can map timestamps back to global time.
    for i in range(0, len(waveform_np), chunk_size_samples):
        chunk_data = waveform_np[i:i + chunk_size_samples]
        chunk_start_time = i / sample_rate
        chunk_args.append((chunk_data, sample_rate, config_dict, chunk_start_time))

    logger.info(f"   Processing {len(chunk_args)} chunks ({config.vad_chunk_size}s each)...")

    all_segments = []
    completed = 0

    # === USE GLOBAL POOL IF PROVIDED (Stage 1 optimization) ===
    if pool is not None:
        # Progress callback invoked by the pool; logs roughly every 10%.
        # NOTE(review): `completed` is updated here but only read in the
        # local-pool branch — harmless, kept for symmetry.
        def progress_cb(done, total):
            nonlocal completed
            completed = done
            if done % max(1, total // 10) == 0:
                logger.info(f"   VAD progress: {done}/{total} chunks")

        all_segments = pool.process_chunks(
            chunk_args,
            worker_fn=_vad_worker,
            timeout_per_chunk=120,
            progress_callback=progress_cb
        )
    else:
        # === LEGACY: Create local pool (backward compatible) ===
        # Pre-download model in the parent to avoid a download race when
        # many workers initialize simultaneously.
        try:
            torch.hub.load(
                repo_or_dir='snakers4/silero-vad',
                model='silero_vad',
                force_reload=False,
                onnx=False,
                trust_repo=True,
                verbose=False
            )
        except Exception as e:
            logger.warning(f"   Pre-download attempt failed (will retry in workers): {e}")

        with ProcessPoolExecutor(
            max_workers=config.vad_workers,
            initializer=_init_vad_worker
        ) as executor:
            futures = {executor.submit(_vad_worker, args): i for i, args in enumerate(chunk_args)}

            for future in as_completed(futures):
                try:
                    segments = future.result(timeout=120)
                    all_segments.extend(segments)
                    completed += 1

                    # Log roughly every 10% of chunks
                    if completed % max(1, len(chunk_args) // 10) == 0:
                        logger.info(f"   VAD progress: {completed}/{len(chunk_args)} chunks")
                except Exception as e:
                    # Failed chunks are logged and dropped; the pipeline
                    # continues with the remaining chunks' segments.
                    logger.error(f"VAD chunk failed: {e}")

    # Sort by start time (results arrive in completion order, not input order)
    all_segments.sort(key=lambda x: x['start'])

    # Merge adjacent segments (fix boundary splits from chunking)
    merged = _merge_adjacent_vad_segments(all_segments, merge_gap=0.1)

    total_speech = sum(s['duration'] for s in merged)
    elapsed = time.time() - start

    logger.info(f"✅ VAD: {elapsed:.1f}s | {len(merged)} segments | {total_speech:.1f}s speech")
    logger.info(f"   Speedup: {total_duration/elapsed:.1f}x realtime | Mode: {pool_mode}")

    return merged


def run_vad_parallel(audio_path: str, config, audio_buffer=None, pool=None) -> List[Dict[str, float]]:
    """
    Run Silero VAD with TRUE parallelism and persistent workers.

    Thin dispatcher over :func:`run_vad_from_buffer`.

    === OPTIMIZATION (v6.2) ===
    - If audio_buffer provided, uses pre-loaded numpy array (no file re-read)
    - Uses initializer to load Silero model ONCE per worker (not per chunk!)

    === OPTIMIZATION (v7.9 - Stage 1) ===
    - Accepts optional GlobalVADPool for cross-video persistence
    - Eliminates ~2-3s pool creation overhead per video

    Args:
        audio_path: Path to audio file (used if audio_buffer is None)
        config: Pipeline configuration
        audio_buffer: Optional AudioBuffer with pre-loaded waveform
        pool: Optional GlobalVADPool for persistent workers

    Returns:
        List of speech segments with start/end/duration
    """
    if audio_buffer is None:
        # Legacy path: decode the file ourselves (still uses persistent workers).
        import torchaudio
        wav, rate = torchaudio.load(audio_path)
        samples = wav.squeeze(0).numpy()
        return run_vad_from_buffer(samples, rate, config, pool=pool)

    # Fast path: audio already decoded into memory by the caller.
    return run_vad_from_buffer(
        audio_buffer.waveform_np, audio_buffer.sample_rate, config, pool=pool
    )


def _merge_adjacent_vad_segments(segments: List[Dict], merge_gap: float = 0.1) -> List[Dict]:
    """
    Merge VAD segments that were split at chunk boundaries.
    
    Args:
        segments: Sorted list of VAD segments
        merge_gap: Maximum gap (seconds) to merge across
    """
    if not segments:
        return []
    
    merged = []
    current = segments[0].copy()
    
    for seg in segments[1:]:
        gap = seg['start'] - current['end']
        
        # Merge if very close (likely chunk boundary artifact)
        if gap < merge_gap:
            current['end'] = seg['end']
            current['duration'] = current['end'] - current['start']
        else:
            merged.append(current)
            current = seg.copy()
    
    merged.append(current)
    return merged


def find_silence_boundaries(vad_segments: List[Dict], 
                           total_duration: float,
                           min_silence: float = 0.3) -> List[Dict]:
    """
    Find silence regions (gaps between VAD segments).

    Useful for:
    - Finding optimal chunk boundaries
    - Marking non-speech regions

    Args:
        vad_segments: List of speech segments from VAD, sorted by start time
        total_duration: Total audio duration (seconds)
        min_silence: Minimum silence duration to report (seconds)

    Returns:
        List of silence segments with start/end/duration
    """
    silences: List[Dict] = []

    # === EDGE CASE FIX ===
    # No speech at all: the entire audio is one silence region. The
    # previous version returned [] here, hiding total silence from callers.
    if not vad_segments:
        if total_duration >= min_silence:
            silences.append({
                'start': 0.0,
                'end': total_duration,
                'duration': total_duration
            })
        return silences

    # Leading silence before the first speech segment
    if vad_segments[0]['start'] >= min_silence:
        silences.append({
            'start': 0.0,
            'end': vad_segments[0]['start'],
            'duration': vad_segments[0]['start']
        })

    # Gaps between consecutive speech segments
    for i in range(len(vad_segments) - 1):
        gap_start = vad_segments[i]['end']
        gap_end = vad_segments[i + 1]['start']
        gap_duration = gap_end - gap_start

        if gap_duration >= min_silence:
            silences.append({
                'start': gap_start,
                'end': gap_end,
                'duration': gap_duration
            })

    # Trailing silence after the last speech segment
    if total_duration - vad_segments[-1]['end'] >= min_silence:
        silences.append({
            'start': vad_segments[-1]['end'],
            'end': total_duration,
            'duration': total_duration - vad_segments[-1]['end']
        })

    return silences


def find_optimal_cut_point(vad_segments: List[Dict], 
                          target_time: float,
                          window: float = 30.0) -> float:
    """
    Find the best silence point near target_time for chunking.

    Strategy: among all gaps between consecutive speech segments whose
    center falls within +/-window of target_time, pick the widest one and
    return its center. If no (positive-width) gap qualifies, fall back to
    target_time so we still return a usable cut point.

    Args:
        vad_segments: List of speech segments from VAD
        target_time: Ideal cut point
        window: Search window (seconds) around target

    Returns:
        Best cut point (center of largest gap, or target if no gaps)
    """
    # Collect (size, center) for every inter-segment gap near the target.
    candidates = []
    for left, right in zip(vad_segments, vad_segments[1:]):
        size = right['start'] - left['end']
        center = (left['end'] + right['start']) / 2
        if abs(center - target_time) <= window:
            candidates.append((size, center))

    best = max(candidates, key=lambda sc: sc[0], default=None)
    # Only a strictly positive gap beats the default target_time
    # (mirrors the original strict `>` comparison against 0.0).
    if best is None or best[0] <= 0.0:
        return target_time
    return best[1]
