#!/usr/bin/env python3
"""
Persistent Worker Pools for High-Throughput Pipeline

Stage 1 optimization: Eliminate per-video worker pool creation overhead.

Contains:
- GlobalVADPool: Persistent VAD workers with pre-loaded Silero models
- DownloadPool: Aggressive prefetcher for hiding download latency

Usage:
    from src.worker_pools import GlobalVADPool, DownloadPool

    # Initialize once at startup
    vad_pool = GlobalVADPool.get_instance(num_workers=32)

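    # Use for all videos (worker_fn runs inside the worker processes, so it
    # must be a picklable, module-level function)
    result = vad_pool.process_chunks(chunk_args, worker_fn=vad_worker)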

    # Shutdown on exit
    vad_pool.shutdown()
"""

import os
import time
import logging
import threading
import hashlib
from queue import Queue, Empty
from typing import List, Dict, Tuple, Optional, Any, Callable
from concurrent.futures import ProcessPoolExecutor, Future, as_completed
from dataclasses import dataclass
from pathlib import Path

import numpy as np

logger = logging.getLogger("WorkerPools")


class GlobalVADPool:
    """
    Singleton persistent VAD worker pool.

    Worker processes are created once, each running
    ``src.vad._init_vad_worker`` as its initializer, and are then reused
    across all videos. This avoids paying the model-loading cost per video.

    ``get_instance()`` guards creation with a class-level lock so concurrent
    callers share a single pool.
    """

    _instance: Optional['GlobalVADPool'] = None
    _lock = threading.Lock()

    def __init__(self, num_workers: int = 32):
        """
        Initialize the VAD pool. Use get_instance() instead.

        Args:
            num_workers: Number of parallel VAD workers
        """
        self.num_workers = num_workers
        self._executor: Optional[ProcessPoolExecutor] = None
        self._ready = False
        self._shutdown = False

    @classmethod
    def get_instance(cls, num_workers: int = 32) -> 'GlobalVADPool':
        """
        Get or create the singleton VAD pool instance.

        Thread-safe lazy initialization. ``num_workers`` only takes effect
        on the call that actually creates the pool; later calls return the
        existing instance unchanged.

        If initialization raises, no instance is published (and any
        executor that was already created is shut down), so a later call
        can retry.
        """
        if cls._instance is None:
            with cls._lock:
                if cls._instance is None:
                    instance = cls(num_workers)
                    try:
                        instance._initialize()
                    except BaseException:
                        # Don't leak worker processes from a half-built pool.
                        instance.shutdown(wait=False)
                        raise
                    # Publish only after a successful init, so a failure
                    # cannot leave a permanently unusable singleton behind.
                    cls._instance = instance
        return cls._instance

    def _initialize(self):
        """Create the process pool and warm up a few of its workers."""
        if self._ready:
            return

        logger.info(f"Initializing GlobalVADPool with {self.num_workers} workers...")
        start = time.time()

        # Pre-download the model in the parent process so that workers do not
        # race each other fetching it.
        try:
            import torch
            torch.hub.load(
                repo_or_dir='snakers4/silero-vad',
                model='silero_vad',
                force_reload=False,
                onnx=False,
                trust_repo=True,
                verbose=False
            )
        except Exception as e:
            logger.warning(f"Pre-download attempt failed (will retry in workers): {e}")

        # Import the initializer from vad module
        from src.vad import _init_vad_worker

        # Create persistent executor
        self._executor = ProcessPoolExecutor(
            max_workers=self.num_workers,
            initializer=_init_vad_worker
        )

        # Warm up workers with dummy tasks. Work items are pickled to reach
        # the worker processes, and a lambda is not picklable, so every
        # warm-up task previously failed before running. os.getpid is a
        # picklable builtin that does no real work.
        warmup_futures = [
            self._executor.submit(os.getpid)
            for _ in range(min(self.num_workers, 4))
        ]

        # Wait for warmup
        for f in warmup_futures:
            try:
                f.result(timeout=60)
            except Exception as e:
                logger.warning(f"Warmup task failed: {e}")

        elapsed = time.time() - start
        self._ready = True
        logger.info(f"GlobalVADPool ready in {elapsed:.1f}s ({self.num_workers} workers)")

    def process_chunks(
        self,
        chunk_args: List[Tuple[np.ndarray, int, dict, float]],
        worker_fn: Callable,
        timeout_per_chunk: int = 120,
        progress_callback: Optional[Callable[[int, int], None]] = None
    ) -> List[Any]:
        """
        Process VAD chunks using the persistent pool.

        Args:
            chunk_args: List of (audio_chunk, sample_rate, config_dict, chunk_start)
            worker_fn: Worker function to execute; called as ``worker_fn(args)``
                for each element. It must return an iterable, whose items are
                appended to the result. It runs in a worker process, so it
                must be picklable.
            timeout_per_chunk: Stall timeout in seconds. If no outstanding
                chunk finishes within this many seconds, chunks that have not
                started are cancelled and the results gathered so far are
                returned. Already-running chunks cannot be interrupted.
            progress_callback: Optional callback(completed, total), invoked
                after each successfully processed chunk

        Returns:
            List of results from all chunks (unordered). Chunks that raised
            (or whose callback raised) are logged and skipped.

        Raises:
            RuntimeError: If the pool was shut down or never initialized.
        """
        from concurrent.futures import FIRST_COMPLETED
        from concurrent.futures import wait as wait_futures

        # Check shutdown first: shutdown() also clears _ready, so testing
        # readiness first would hide the more specific shut-down error.
        if self._shutdown:
            raise RuntimeError("GlobalVADPool has been shut down.")

        if not self._ready or self._executor is None:
            raise RuntimeError("GlobalVADPool not initialized. Call get_instance() first.")

        results: List[Any] = []
        completed = 0
        total = len(chunk_args)

        # Submit all chunks
        pending = {self._executor.submit(worker_fn, args) for args in chunk_args}

        # Collect results as they finish. Calling future.result(timeout=...) on
        # an already-finished future never times out, so instead the stall
        # timeout is enforced on the wait for the next completion.
        while pending:
            done, pending = wait_futures(
                pending, timeout=timeout_per_chunk, return_when=FIRST_COMPLETED
            )
            if not done:
                logger.error(
                    f"VAD chunks stalled: none finished within {timeout_per_chunk}s; "
                    f"abandoning {len(pending)} of {total} chunk(s)"
                )
                for future in pending:
                    future.cancel()
                break

            for future in done:
                try:
                    result = future.result()
                    results.extend(result)
                    completed += 1

                    if progress_callback:
                        progress_callback(completed, total)

                except Exception as e:
                    logger.error(f"VAD chunk failed: {e}")

        return results

    def is_ready(self) -> bool:
        """Check if pool is ready to accept work."""
        return self._ready and not self._shutdown

    def shutdown(self, wait: bool = True):
        """
        Shutdown the worker pool.

        Always marks the pool as shut down, even if it was never initialized.

        Args:
            wait: Wait for pending tasks to complete
        """
        if self._executor is not None:
            logger.info("Shutting down GlobalVADPool...")
            self._executor.shutdown(wait=wait)
        self._shutdown = True
        self._ready = False

    @classmethod
    def reset(cls):
        """Shut down (without waiting) and discard the singleton (mainly for testing)."""
        with cls._lock:
            if cls._instance is not None:
                cls._instance.shutdown(wait=False)
            cls._instance = None


@dataclass
class PrefetchedVideo:
    """One downloaded video as handed out by ``DownloadPool.get_next()``."""

    url: str  # URL the audio was downloaded from
    audio_path: str  # Local filesystem path of the downloaded audio
    metadata: Dict[str, Any]  # Value returned by download_audio ({} when it returned a falsy value)
    timestamp: float  # time.time() at the moment the download finished


class DownloadPool:
    """
    Aggressive download prefetcher.

    Background threads download queued URLs and park the results in a
    bounded ready queue (``prefetch_depth`` entries), so processing usually
    finds a video already downloaded. When the ready queue is full, download
    threads wait (backpressure) until the consumer takes a video or the pool
    is stopped.

    Usage:
        pool = DownloadPool(prefetch_depth=50)
        pool.start(video_urls)

        # Processing loop
        for _ in range(len(video_urls)):
            video = pool.get_next()  # Blocks until a video is ready; None on timeout
            if video is None:
                break
            process(video)
            pool.mark_complete(video.url)

        pool.stop()

    Caveats:
        - Failed downloads are logged and skipped. Nothing is queued for them,
          so a consumer expecting one video per URL will see ``get_next()``
          time out and return ``None``.
        - ``max_storage_gb`` is recorded but not currently enforced.
        - Audio paths are derived from a hash of the URL, so duplicate URLs
          share one file.
        - ``cleanup_all()`` (also run by ``stop()``) deletes every
          ``prefetch_*.wav`` in ``storage_path``, including files belonging
          to other pools that share the directory.
    """

    def __init__(
        self,
        prefetch_depth: int = 50,
        download_workers: int = 4,
        storage_path: str = "/dev/shm",
        max_storage_gb: float = 50.0
    ):
        """
        Initialize download pool.

        Args:
            prefetch_depth: Maximum number of downloaded videos waiting to be consumed
            download_workers: Parallel download threads
            storage_path: Where to store downloaded audio (tmpfs recommended);
                created if missing
            max_storage_gb: Intended storage cap (stored, not currently enforced)
        """
        self.prefetch_depth = prefetch_depth
        self.download_workers = download_workers
        self.storage_path = Path(storage_path)
        self.max_storage_gb = max_storage_gb

        # Queues
        self._url_queue: Queue = Queue()  # URLs to download
        self._ready_queue: Queue = Queue(maxsize=prefetch_depth)  # Ready videos
        self._completed: set = set()  # URLs that are done

        # Threading
        self._threads: List[threading.Thread] = []
        self._running = False
        self._lock = threading.Lock()

        # Ensure storage exists
        self.storage_path.mkdir(parents=True, exist_ok=True)

    def _prefetch_path(self, url: str) -> Path:
        """Return the on-disk location used for ``url``'s prefetched audio."""
        url_hash = hashlib.md5(url.encode()).hexdigest()[:12]
        return self.storage_path / f"prefetch_{url_hash}.wav"

    @staticmethod
    def _remove_file(path: Path):
        """Best-effort delete of ``path``; a missing file is not an error."""
        try:
            path.unlink()
        except FileNotFoundError:
            pass
        except OSError as e:
            logger.warning(f"Cleanup failed: {e}")

    def _enqueue_ready(self, video: 'PrefetchedVideo') -> bool:
        """
        Put ``video`` on the ready queue, waiting while it is full.

        Returns:
            True once queued; False if the pool was stopped first.
        """
        from queue import Full

        # Short put timeouts let the thread notice stop() promptly instead of
        # blocking indefinitely on a full queue.
        while self._running:
            try:
                self._ready_queue.put(video, timeout=1.0)
                return True
            except Full:
                continue
        return False

    def start(self, video_urls: List[str]):
        """
        Start prefetching videos.

        Args:
            video_urls: List of video URLs to process
        """
        self._running = True

        # Add all URLs to queue
        for url in video_urls:
            self._url_queue.put(url)

        # Start download threads
        for i in range(self.download_workers):
            thread = threading.Thread(
                target=self._download_loop,
                name=f"Downloader-{i}",
                daemon=True
            )
            thread.start()
            self._threads.append(thread)

        logger.info(f"DownloadPool started: {self.download_workers} workers, "
                   f"prefetch depth {self.prefetch_depth}")

    def _download_loop(self):
        """Background thread body: download queued URLs until the pool is stopped."""
        from src.download import download_audio  # Import here to avoid circular

        while self._running:
            try:
                # Get next URL (block with timeout so stop() is noticed)
                url = self._url_queue.get(timeout=1.0)

                # Skip URLs the consumer has already finished with
                with self._lock:
                    if url in self._completed:
                        continue

                audio_path = self._prefetch_path(url)

                try:
                    metadata = download_audio(url, str(audio_path))
                except Exception as e:
                    logger.error(f"Download failed for {url}: {e}")
                    # Failed downloads are skipped; don't leave any partial
                    # file behind in (possibly RAM-backed) storage.
                    self._remove_file(audio_path)
                    continue

                video = PrefetchedVideo(
                    url=url,
                    audio_path=str(audio_path),
                    metadata=metadata or {},
                    timestamp=time.time()
                )

                # Previously a full queue for 60 s raised queue.Full, which
                # dropped the video and leaked its file. Now wait while running.
                if self._enqueue_ready(video):
                    logger.debug(f"Prefetched: {url[:50]}...")
                else:
                    # Stopped before the consumer could take it.
                    self._remove_file(audio_path)

            except Empty:
                continue
            except Exception as e:
                logger.error(f"Download loop error: {e}")

    def get_next(self, timeout: float = 300) -> 'Optional[PrefetchedVideo]':
        """
        Get next prefetched video.

        Blocks until a video is ready or timeout.

        Args:
            timeout: Maximum wait time in seconds

        Returns:
            PrefetchedVideo or None if timeout
        """
        try:
            return self._ready_queue.get(timeout=timeout)
        except Empty:
            return None

    def mark_complete(self, url: str, cleanup: bool = True):
        """
        Mark a video as completed and optionally cleanup.

        Args:
            url: Video URL
            cleanup: Delete the audio file (a missing file is ignored)
        """
        with self._lock:
            self._completed.add(url)

        if cleanup:
            self._remove_file(self._prefetch_path(url))

    def stop(self, wait: bool = True):
        """
        Stop the download pool and delete prefetched files.

        Args:
            wait: Join each download thread (up to 5 s per thread)
        """
        self._running = False

        if wait:
            for thread in self._threads:
                thread.join(timeout=5.0)

        # Cleanup remaining files
        self.cleanup_all()
        logger.info("DownloadPool stopped")

    def cleanup_all(self):
        """Remove all ``prefetch_*.wav`` files in ``storage_path``, ignoring errors."""
        for path in self.storage_path.glob("prefetch_*.wav"):
            try:
                path.unlink()
            except Exception:
                pass

    @property
    def ready_count(self) -> int:
        """Approximate number of videos ready for processing."""
        return self._ready_queue.qsize()

    @property
    def pending_count(self) -> int:
        """Approximate number of URLs not yet picked up by a download thread."""
        return self._url_queue.qsize()
