"""3-stage async pipeline: Download → ExtractVAD → GPU Encode.

Keeps CPU and GPU both busy by overlapping stages across different videos:
  - Download workers: prefetch videos from R2 (2 threads)
  - ExtractVAD workers: ffmpeg pipe + peak normalize + Silero-VAD (N threads)
  - GPU encoder: drain ready queue, encode segments, pack shards (main thread)

The GPU should never wait for input once the pipeline is primed.

Throughput math (A100, 16 cores):
  FFmpeg (no loudnorm, pipe): ~15-20s/video
  VAD: ~17s/video
  GPU encode: 2.7s/video
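  With 6 extract workers: ~1 PreparedVideo every ~5.5s ((15-20s + 17s) / 6).
  That is slower than the 2.7s/video GPU encode, so at this setting extraction
  is the bottleneck and the GPU idles between videos.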
"""

from __future__ import annotations

import logging
import time
import traceback
from concurrent.futures import ThreadPoolExecutor, Future
from dataclasses import dataclass, field
from pathlib import Path
from queue import Queue, Empty, Full
from threading import Event, Lock, Thread

import torch

from codecbench.pipeline.config import PipelineConfig, VADConfig
from codecbench.pipeline.r2_client import (
    R2Client, extract_audio_pipe, normalize_audio_peak,
)
from codecbench.pipeline.vad import segment_tensor, Segment
from codecbench.pipeline.encoder import HotEncoder, EncodedSegment
from codecbench.pipeline.shard_packer import ShardBuffer, VideoTokens, pack_shard

# Module-level logger, named after this module.
logger = logging.getLogger(name=__name__)

# End-of-stream marker passed through the pipeline queues. It is always
# compared by identity (`is`), so it cannot collide with a real work item.
_SENTINEL = object()


@dataclass
class PreparedVideo:
    """Output of stage 2: one video whose audio is extracted and VAD-segmented.

    Everything the GPU stage needs to encode the video is carried here; the
    field order below is also the positional constructor order.
    """

    video_id: str
    language: str
    bucket: str
    # Speech segments produced by VAD.
    segments: list[Segment]
    # Total duration of `segments`, in seconds.
    usable_audio_s: float
    # Number of entries in `segments`.
    num_segments: int
    # Wall-clock seconds spent on ffmpeg extraction + peak normalization.
    extract_time_s: float
    # Wall-clock seconds spent on VAD segmentation.
    vad_time_s: float
    # Local downloaded file. The pipeline does not delete it on success;
    # presumably the consumer cleans it up (TODO confirm).
    video_path: Path | None = field(default=None)


@dataclass
class PipelineStats:
    """Cumulative counters and wall-clock timings (seconds) for a pipeline run.

    Every field starts at zero; the declaration order is the positional
    constructor order.
    """

    # Not written by AsyncPipeline in this module; presumably the caller's (TODO confirm).
    videos_processed: int = field(default=0)
    total_audio_s: float = field(default=0.0)
    # Added by AsyncPipeline for each video that is successfully prepared.
    total_extract_s: float = field(default=0.0)
    total_vad_s: float = field(default=0.0)
    # Not written by AsyncPipeline in this module; presumably the caller's (TODO confirm).
    total_encode_s: float = field(default=0.0)
    # Added by AsyncPipeline for each video that is successfully prepared.
    total_download_s: float = field(default=0.0)
    # Not written by AsyncPipeline in this module; presumably the caller's (TODO confirm).
    total_pack_s: float = field(default=0.0)
    total_upload_s: float = field(default=0.0)
    # Incremented by AsyncPipeline whenever a video is dropped as failed.
    videos_failed: int = field(default=0)


class DownloadItem:
    """Stage-1 work item: identifies one claimed video to fetch from R2."""

    __slots__ = ("video_id", "language", "bucket")

    def __init__(self, video_id: str, language: str, bucket: str):
        self.video_id, self.language, self.bucket = video_id, language, bucket


class DownloadedItem:
    """Stage-1 result handed to stage 2 (extract+VAD).

    `path` is None when the download failed; `download_time_s` is the
    wall-clock time spent on the attempt either way.
    """

    __slots__ = ("video_id", "language", "bucket", "path", "download_time_s")

    def __init__(self, video_id: str, language: str, bucket: str,
                 path: Path | None, download_time_s: float):
        (
            self.video_id,
            self.language,
            self.bucket,
            self.path,
            self.download_time_s,
        ) = (video_id, language, bucket, path, download_time_s)


class AsyncPipeline:
    """3-stage pipeline that keeps CPU and GPU maximally busy.

    Stage 1 (Download):      2 threads pull video files from R2
    Stage 2 (Extract+VAD):   N threads run ffmpeg pipe + peak norm + Silero-VAD
    Stage 3 (GPU Encode):    the caller drains the ready queue via get_ready(),
                             then encodes and packs shards (not done here)

    Failures (download error, extract/VAD error, or no speech found) are counted
    in ``stats.videos_failed``, and their video IDs are returned by get_failed().

    NOTE: each pump thread hands items to a ThreadPoolExecutor as soon as they
    arrive, and executor work queues are unbounded. The bounded hand-off queues
    therefore do not limit how far downloads run ahead of extraction. The number
    of downloaded-but-not-yet-extracted files on local disk is bounded only by
    how many videos have been submitted.

    Usage:
        pipeline = AsyncPipeline(cfg, r2)
        pipeline.start()

        # Feed videos (e.g., from Supabase claims or R2 listing)
        for video in videos:
            pipeline.submit(video_id, language, bucket)
        pipeline.drain()  # end of input: get_ready() returns None once flushed

        # Drain prepared videos and encode them
        while True:
            prepared = pipeline.get_ready(timeout=120)
            if prepared is None:
                break
            encoded = encoder.encode_segments(prepared.segments)
            ...

        pipeline.stop()
    """

    # Number of R2 download threads in stage 1.
    _DOWNLOAD_WORKERS = 2

    def __init__(self, cfg: PipelineConfig, r2: R2Client):
        self._cfg = cfg
        self._r2 = r2
        self._stop = Event()
        self.stats = PipelineStats()
        # Serializes this pipeline's own read-modify-write updates to `stats`
        # from its pump/worker threads (`obj.attr += x` is not atomic across
        # threads). Fields updated only by the caller are not covered.
        self._stats_lock = Lock()

        # Stage 1: Download queue → download workers
        self._download_q: Queue[DownloadItem | object] = Queue(
            maxsize=cfg.worker.prefetch_videos + 2
        )
        self._download_pool = ThreadPoolExecutor(
            max_workers=self._DOWNLOAD_WORKERS, thread_name_prefix="dl"
        )

        # Stage 1→2: downloaded items waiting for extract+VAD
        self._downloaded_q: Queue[DownloadedItem | object] = Queue(
            maxsize=cfg.worker.extract_workers + 2
        )

        # Stage 2: Extract+VAD workers
        self._extract_pool = ThreadPoolExecutor(
            max_workers=cfg.worker.extract_workers, thread_name_prefix="ext"
        )

        # Stage 2→3: prepared videos ready for GPU
        self._ready_q: Queue[PreparedVideo | object] = Queue(
            maxsize=cfg.worker.ready_queue_depth
        )

        self._tmp_dir = Path(cfg.worker.local_tmp_dir)
        self._tmp_dir.mkdir(parents=True, exist_ok=True)
        self._video_dir = self._tmp_dir / "videos"
        self._video_dir.mkdir(parents=True, exist_ok=True)

        # Submitted jobs per stage. Finished ones are pruned on each submit.
        # Each list is touched only by its own pump thread.
        self._download_futures: list[Future] = []
        self._extract_futures: list[Future] = []

        # Failed videos — surfaced to worker for replacement claims
        self._failed_q: Queue[str] = Queue()

        # Pump threads: move items between queues via worker pools
        self._dl_pump: Thread | None = None
        self._ext_pump: Thread | None = None

    def start(self) -> None:
        """Start background pump threads that feed items through the pipeline."""
        self._stop.clear()
        self._dl_pump = Thread(target=self._download_pump, daemon=True, name="dl-pump")
        self._ext_pump = Thread(target=self._extract_pump, daemon=True, name="ext-pump")
        self._dl_pump.start()
        self._ext_pump.start()
        logger.info(
            "AsyncPipeline started: %d download workers, %d extract+VAD workers, "
            "ready queue depth %d",
            self._DOWNLOAD_WORKERS,
            self._cfg.worker.extract_workers, self._cfg.worker.ready_queue_depth,
        )

    def submit(self, video_id: str, language: str, bucket: str) -> None:
        """Submit a video for processing.

        Blocks while the download queue is full. Returns without enqueuing if
        the pipeline is stopped, either before the call or while waiting.
        """
        if self._stop.is_set():
            return
        self._put_unless_stopped(self._download_q, DownloadItem(video_id, language, bucket))

    def get_ready(self, timeout: float = 120.0) -> PreparedVideo | None:
        """Block until a PreparedVideo is ready for GPU encoding.

        Returns None when the end-of-stream sentinel arrives (after drain() has
        flushed every stage, or after stop()). Also returns None when the wait
        times out after stop() has been called.

        Raises:
            queue.Empty: if nothing arrives within `timeout` seconds and the
                pipeline has not been stopped.
        """
        try:
            item = self._ready_q.get(timeout=timeout)
        except Empty:
            if self._stop.is_set():
                return None
            raise
        if item is _SENTINEL:
            return None
        return item

    def stop(self) -> None:
        """Stop all pipeline stages without waiting for in-flight work.

        Sets the stop flag, wakes queue consumers with a sentinel, shuts the
        executors down without waiting, and joins each pump for up to 5 s.
        """
        self._stop.set()
        # Wake consumers blocked in get(). A blocking put() could deadlock
        # here: e.g. ready_q is drained only by the get_ready() caller, which is
        # usually the thread calling stop(). If a queue is full, its consumer
        # is not blocked waiting for an item anyway. The pumps also re-check the
        # stop flag every 5 s.
        for q in (self._download_q, self._downloaded_q, self._ready_q):
            try:
                q.put_nowait(_SENTINEL)
            except Full:
                pass
        self._download_pool.shutdown(wait=False)
        self._extract_pool.shutdown(wait=False)
        if self._dl_pump:
            self._dl_pump.join(timeout=5)
        if self._ext_pump:
            self._ext_pump.join(timeout=5)
        logger.info("AsyncPipeline stopped")

    def drain(self) -> None:
        """Signal that no more videos will be submitted.

        Enqueues an end-of-stream sentinel behind the already-submitted videos.
        It may block while the download queue is full, but it does not wait for
        processing to finish. Each pump forwards the sentinel only after waiting
        for its in-flight jobs. get_ready() therefore normally returns None only
        after every earlier video has been delivered or reported as failed.
        """
        self._put_unless_stopped(self._download_q, _SENTINEL)

    def get_failed(self) -> list[str]:
        """Non-blocking: drain all failed video IDs so worker can mark FAILED + claim replacements."""
        failed = []
        while True:
            try:
                failed.append(self._failed_q.get_nowait())
            except Empty:
                break
        return failed

    @property
    def ready_count(self) -> int:
        """Approximate number of PreparedVideos waiting in the ready queue."""
        return self._ready_q.qsize()

    @property
    def pending_count(self) -> int:
        """Approximate number of items in the two hand-off queues.

        Does not include jobs queued inside the executors or currently running.
        """
        return self._download_q.qsize() + self._downloaded_q.qsize()

    # ── Internal helpers ────────────────────────────────────────────────

    def _put_unless_stopped(self, q: Queue, item: object, poll_s: float = 1.0) -> bool:
        """Put `item` on `q`, blocking while it is full, until stop() is called.

        Returns True if the item was enqueued. Returns False if the pipeline was
        stopped while the queue stayed full; the item is then dropped.
        """
        while True:
            try:
                q.put(item, timeout=poll_s)
                return True
            except Full:
                if self._stop.is_set():
                    return False

    def _record_failure(self, video_id: str) -> None:
        """Count a failed video and surface its ID via get_failed()."""
        with self._stats_lock:
            self.stats.videos_failed += 1
        self._failed_q.put(video_id)  # unbounded queue: never blocks

    # ── Stage 1: Download pump ──────────────────────────────────────────

    def _download_pump(self) -> None:
        """Take items from download_q, dispatch to download workers, put results in downloaded_q.

        On the end-of-stream sentinel, first waits (up to 600 s per job) for
        in-flight downloads so their results precede the sentinel in
        downloaded_q, then forwards the sentinel to stage 2.
        """
        while not self._stop.is_set():
            try:
                item = self._download_q.get(timeout=5)
                if item is _SENTINEL:
                    for fut in self._download_futures:
                        try:
                            fut.result(timeout=600)
                        except Exception:
                            # _do_download logs its own failures; a timeout here
                            # only means we stop waiting on that job.
                            pass
                    self._put_unless_stopped(self._downloaded_q, _SENTINEL)
                    break
                # Drop finished futures so the list doesn't grow for the whole run.
                self._download_futures = [f for f in self._download_futures if not f.done()]
                self._download_futures.append(
                    self._download_pool.submit(self._do_download, item)
                )
            except Empty:
                continue
            except Exception:
                if not self._stop.is_set():
                    logger.error("Download pump error: %s", traceback.format_exc())

    def _do_download(self, item: DownloadItem) -> None:
        """Download one video from R2 and hand the result to stage 2.

        A failed download is still forwarded (with ``path=None``) so the
        extract pump can record it as failed. Download errors are logged,
        not raised.
        """
        if self._stop.is_set():
            return
        t0 = time.perf_counter()
        try:
            path = self._r2.download_video(
                item.video_id, item.bucket, self._video_dir, language=item.language,
            )
        except Exception as e:
            logger.error("Download failed for %s: %s", item.video_id, e)
            path = None
        downloaded = DownloadedItem(
            item.video_id, item.language, item.bucket, path, time.perf_counter() - t0,
        )
        if not self._put_unless_stopped(self._downloaded_q, downloaded):
            # Stopped while stage 2 was backed up: nobody will extract this file.
            self._cleanup_video(path)

    # ── Stage 2: Extract+VAD pump ───────────────────────────────────────

    def _extract_pump(self) -> None:
        """Take items from downloaded_q, dispatch to extract+VAD workers, put results in ready_q.

        Failed downloads (``path is None``) are recorded as failures here. On
        the end-of-stream sentinel, first waits (up to 300 s per job) for
        in-flight extract jobs, then forwards the sentinel to ready_q.
        """
        while not self._stop.is_set():
            try:
                item = self._downloaded_q.get(timeout=5)
                if item is _SENTINEL:
                    for fut in self._extract_futures:
                        try:
                            fut.result(timeout=300)
                        except Exception:
                            # _do_extract_and_vad handles its own errors; a
                            # timeout here only means we stop waiting on that job.
                            pass
                    self._put_unless_stopped(self._ready_q, _SENTINEL)
                    break
                if item.path is None:
                    self._record_failure(item.video_id)
                    continue
                # Drop finished futures so the list doesn't grow for the whole run.
                self._extract_futures = [f for f in self._extract_futures if not f.done()]
                self._extract_futures.append(
                    self._extract_pool.submit(self._do_extract_and_vad, item)
                )
            except Empty:
                continue
            except Exception:
                if not self._stop.is_set():
                    logger.error("Extract pump error: %s", traceback.format_exc())

    def _do_extract_and_vad(self, item: DownloadedItem) -> None:
        """Extract audio via ffmpeg pipe, peak normalize, run VAD, produce PreparedVideo.

        On success the PreparedVideo is put on ready_q. It keeps
        ``video_path``, and the local file is not deleted here. On any error,
        or when VAD finds no speech, the video is recorded as failed and its
        local file is removed.
        """
        if self._stop.is_set():
            return

        video_id = item.video_id
        video_path = item.path

        try:
            # FFmpeg pipe extraction (single-pass, no loudnorm, no disk I/O)
            t0 = time.perf_counter()
            wav = extract_audio_pipe(
                video_path,
                target_sr=self._cfg.codec.target_sr,
                ffmpeg_threads=self._cfg.worker.ffmpeg_threads,
            )
            wav = normalize_audio_peak(wav, target_peak=0.95)
            extract_time = time.perf_counter() - t0

            duration_s = wav.shape[-1] / self._cfg.codec.target_sr
            logger.info(
                "Extracted %s: %.1fs audio in %.1fs (%.1fx RT, pipe+peak_norm)",
                video_id, duration_s, extract_time, duration_s / max(extract_time, 0.001),
            )

            # VAD segmentation (in-memory, no disk)
            t1 = time.perf_counter()
            segments = segment_tensor(wav, self._cfg.codec.target_sr, self._cfg.vad)
            vad_time = time.perf_counter() - t1

            if not segments:
                logger.warning("No speech segments in %s, skipping", video_id)
                # Report through the failed queue like every other failure path,
                # so the caller learns this video produced nothing.
                self._record_failure(video_id)
                self._cleanup_video(video_path)
                return

            usable_audio = sum(s.duration_s for s in segments)
            logger.info(
                "VAD %s: %d segments, %.1fs usable in %.1fs",
                video_id, len(segments), usable_audio, vad_time,
            )

            prepared = PreparedVideo(
                video_id=video_id,
                language=item.language,
                bucket=item.bucket,
                segments=segments,
                usable_audio_s=usable_audio,
                num_segments=len(segments),
                extract_time_s=extract_time,
                vad_time_s=vad_time,
                video_path=video_path,
            )

            with self._stats_lock:
                self.stats.total_extract_s += extract_time
                self.stats.total_vad_s += vad_time
                self.stats.total_download_s += item.download_time_s

            if not self._put_unless_stopped(self._ready_q, prepared):
                # Stopped while the ready queue was full: nobody will encode
                # this video, so release its local file.
                self._cleanup_video(video_path)

        except Exception as e:
            logger.error("Extract+VAD failed for %s: %s\n%s",
                         video_id, e, traceback.format_exc())
            self._record_failure(video_id)
            self._cleanup_video(video_path)

    def _cleanup_video(self, video_path: Path | None) -> None:
        """Best-effort delete of a local video file; errors are logged, never raised."""
        if video_path:
            try:
                video_path.unlink(missing_ok=True)
            except Exception as e:
                logger.debug("Could not remove %s: %s", video_path, e)
