"""Pipeline worker: the main loop that orchestrates everything.

Lifecycle per video:
  1. Claim video from Supabase (atomic, no duplicates)
  2. Download from R2 (prefetcher already has it ready)
  3. Extract audio with ffmpeg (PCM 16kHz mono, loudness normalized)
  4. VAD segmentation → 2-30s speech segments
  5. GPU encode all segments (XCodec2 + BiCodec, parallel streams)
  6. Accumulate tokens in shard buffer
  7. When buffer full → pack shard → upload to R2 → report to Supabase
  8. Clean up local files, claim next video

Recovery: on restart, check for owned in-progress videos, release or resume.
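
Typical usage (a sketch; CLI wiring and config loading live outside this
module, so the bare PipelineConfig() below is illustrative):

    cfg = PipelineConfig()            # or however the deployment builds it
    worker = PipelineWorker(cfg)
    worker.setup()                    # schema, models, heartbeat, prefetcher
    worker.run(language="english")    # blocks until drained or signaled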
"""

from __future__ import annotations

import gc
import logging
import os
import signal
import time
import traceback
from pathlib import Path
from queue import Empty

import torch

from codecbench.pipeline.config import PipelineConfig
from codecbench.pipeline.r2_client import R2Client, VideoPrefetcher, PrefetchItem, extract_audio_from_video
from codecbench.pipeline.supabase_client import SupabaseOrchestrator
from codecbench.pipeline.vad import segment_audio, Segment
from codecbench.pipeline.encoder import HotEncoder
from codecbench.pipeline.shard_packer import ShardBuffer, VideoTokens, pack_shard
from codecbench.pipeline.heartbeat import HeartbeatThread
from codecbench.pipeline.async_pipeline import AsyncPipeline, PreparedVideo

logger = logging.getLogger(__name__)


class PipelineWorker:
    def __init__(self, cfg: PipelineConfig):
        self.cfg = cfg
        self.r2 = R2Client(cfg.r2)
        self.orch = SupabaseOrchestrator(cfg.supabase)
        self.encoder = HotEncoder(cfg.codec)
        self.prefetcher: VideoPrefetcher | None = None
        self.heartbeat: HeartbeatThread | None = None
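        # Tokens accumulate here across videos; once the buffer holds
        # shard_pack_count videos, _pack_and_upload_shard() ships them as one shard.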
        self.shard_buffer = ShardBuffer()

        self._running = False
        self._total_audio_s = 0.0
        self._total_videos = 0
        self._consecutive_errors = 0

        # Resolve worker ID
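        # Combines offer ID, GPU model, and PID so parallel workers on one
        # host (and restarts) register under distinct IDs.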
        if not cfg.worker.worker_id:
            gpu_name = "unknown"
            if torch.cuda.is_available():
                gpu_name = torch.cuda.get_device_name(0).replace(" ", "_")
            cfg.worker.gpu_name = gpu_name
            cfg.worker.worker_id = (
                f"{cfg.worker.offer_id}_{gpu_name}_{os.getpid()}"
            )

        self._tmp_dir = Path(cfg.worker.local_tmp_dir)
        self._tmp_dir.mkdir(parents=True, exist_ok=True)

    def setup(self) -> None:
        """One-time setup: Supabase tables, RPC functions, model loading."""
        logger.info("=== Pipeline Worker Setup ===")
        logger.info("Worker ID: %s", self.cfg.worker.worker_id)

        # Supabase schema
        self.orch.ensure_tables()
        self.orch.create_claim_rpc()
        self.orch.create_release_stale_rpc()

        # R2 output bucket
        self.r2.ensure_bucket(self.cfg.r2.output_bucket)

        # Register worker
        self.orch.register_worker(
            self.cfg.worker.worker_id,
            self.cfg.worker.offer_id,
            self.cfg.worker.gpu_name,
        )

        # Start heartbeat
        self.heartbeat = HeartbeatThread(
            self.orch, self.cfg.worker.worker_id,
            self.cfg.supabase.heartbeat_interval_s,
        )
        self.heartbeat.start()

        # Download custom XCodec2 checkpoint from R2 if configured and not already local
        if self.cfg.r2.xcodec_ckpt_key and not self.cfg.codec.xcodec2_custom_ckpt:
            ckpt_local = Path(self.cfg.worker.local_tmp_dir) / "xcodec2_custom.ckpt"
            if not ckpt_local.exists():
                logger.info("Downloading custom XCodec2 checkpoint from R2: %s/%s",
                            self.cfg.r2.xcodec_bucket, self.cfg.r2.xcodec_ckpt_key)
                self.r2.download_file(
                    self.cfg.r2.xcodec_bucket, self.cfg.r2.xcodec_ckpt_key, ckpt_local,
                )
                logger.info("Custom checkpoint downloaded: %.1f MB", ckpt_local.stat().st_size / 1e6)
            else:
                logger.info("Custom checkpoint already cached at %s", ckpt_local)
            self.cfg.codec.xcodec2_custom_ckpt = str(ckpt_local)

        # Load codec models (this takes ~30s, models stay warm for entire lifetime)
        logger.info("Loading codec models...")
        self.encoder.load()

        # Transition: LOADING_MODELS → ALIVE
        self.orch.mark_worker_alive(self.cfg.worker.worker_id)

        # Start prefetcher
        self.prefetcher = VideoPrefetcher(
            self.r2,
            max_prefetch=self.cfg.worker.prefetch_videos,
            download_workers=self.cfg.r2.max_download_workers,
        )

        # Recovery: release stale claims from dead workers
        self.orch.release_stale_claims()

        logger.info("=== Setup Complete ===")

    def _resolve_bucket(self, language: str) -> str:
        """Map language to source R2 bucket."""
        if language == "english":
            return self.cfg.r2.source_buckets["english"]
        return self.cfg.r2.source_buckets["indic"]

    def _prefetch_next(self, language: str | None = None) -> None:
        """Claim a video from Supabase and queue it for prefetch download."""
        video = self.orch.claim_video(self.cfg.worker.worker_id, language)
        if video is None:
            return

        video_id = video["video_id"]
        lang = video.get("language", "unknown")
        bucket = self._resolve_bucket(lang)

        self.orch.update_video_status(video_id, "DOWNLOADING")
        dest_dir = self._tmp_dir / "videos"
        assert self.prefetcher is not None, "setup() must run before _prefetch_next()"
        self.prefetcher.submit(video_id, lang, bucket, dest_dir)

    @staticmethod
    def _is_cuda_oom(exc: Exception) -> bool:
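        # Catch both the typed CUDA OOM exception and OOMs surfaced as generic
        # RuntimeErrors whose message mentions "out of memory".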
        msg = str(exc).lower()
        return isinstance(exc, torch.cuda.OutOfMemoryError) or "out of memory" in msg

    @staticmethod
    def _clear_cuda_after_oom() -> None:
        # Drop Python references to the failed batch first, then return the
        # freed blocks from PyTorch's caching allocator to the driver.
        gc.collect()
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

    def _encode_segments_resilient(
        self,
        video_id: str,
        segments: list[Segment],
    ):
        """Encode with progressive OOM fallbacks; keep video processing alive."""
        default_bs = max(self.cfg.codec.xcodec_batch_size, 1)
        seg_threshold = max(self.cfg.worker.oom_segment_threshold, 1)

        # Attempt ladder: videos with unusually many segments try batch size 1
        # first (the default batch would very likely OOM); otherwise use the
        # configured batch, then fall back to batch size 1.
        attempts: list[tuple[int, str]] = []
        if len(segments) > seg_threshold:
            attempts.append((1, "high_segment_safe_mode"))
        attempts.extend([
            (default_bs, "default"),
            (1, "xcodec_batch_size_1"),
        ])

        # Deduplicate batch sizes, keeping the first label for each size.
        deduped: list[tuple[int, str]] = []
        seen: set[int] = set()
        for bs, label in attempts:
            if bs in seen:
                continue
            seen.add(bs)
            deduped.append((bs, label))

        for idx, (bs, label) in enumerate(deduped, start=1):
            try:
                if label != "default":
                    logger.warning(
                        "OOM-safe encode mode for %s: %s (xcodec_bs=%d, segs=%d)",
                        video_id, label, bs, len(segments),
                    )
                return self.encoder.encode_segments(
                    segments,
                    xcodec_batch_size_override=bs,
                )
            except Exception as e:
                if not self._is_cuda_oom(e):
                    raise
                logger.warning(
                    "OOM encoding %s on attempt %d/%d (%s, xcodec_bs=%d): %s",
                    video_id, idx, len(deduped), label, bs, e,
                )
                self._clear_cuda_after_oom()

        logger.warning(
            "Final OOM fallback for %s: per-segment sequential encoding (%d segments)",
            video_id, len(segments),
        )
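        # Last resort: one segment at a time, clearing the CUDA cache between
        # failures. If even a lone segment OOMs at batch size 1, re-raise and
        # let the caller mark this video failed.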
        encoded = []
        for seg_idx, seg in enumerate(segments):
            try:
                one = self.encoder.encode_segments(
                    [seg], xcodec_batch_size_override=1
                )
            except Exception as e:
                if self._is_cuda_oom(e):
                    self._clear_cuda_after_oom()
                raise
            if not one:
                continue
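            # Re-stamp the video-level index: encode_segments numbers segments
            # relative to the (single-element) list it was given.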
            one[0].segment_idx = seg_idx
            encoded.append(one[0])
        return encoded

    def _process_one_video(self, video_id: str, video_path: Path, language: str) -> None:
        """Full pipeline for one video: extract audio → VAD → encode → buffer."""
        t0 = time.perf_counter()
        self.orch.update_video_status(video_id, "PROCESSING", {"started_at": time.strftime("%Y-%m-%dT%H:%M:%SZ")})

        if self.heartbeat:
            self.heartbeat.current_video = video_id

        # Extract audio from video
        audio_dir = self._tmp_dir / "audio"
        audio_path = audio_dir / f"{video_id}.wav"
        extract_audio_from_video(video_path, audio_path, self.cfg.codec.target_sr)

        # VAD segmentation
        segments = segment_audio(audio_path, self.cfg.vad)
        if not segments:
            logger.warning("No speech segments found in %s, marking done", video_id)
            self.orch.update_video_status(video_id, "DONE", {
                "num_segments": 0,
                "usable_audio_s": 0,
                "finished_at": time.strftime("%Y-%m-%dT%H:%M:%SZ"),
            })
            self._cleanup_video(video_id, video_path, audio_path)
            return

        usable_audio = sum(s.duration_s for s in segments)
        self.orch.update_video_status(video_id, "PROCESSING", {
            "num_segments": len(segments),
            "usable_audio_s": round(usable_audio, 1),
            "audio_duration_s": round(audio_path.stat().st_size / (2 * self.cfg.codec.target_sr), 1),
        })

        # Encode all segments
        encoded = self._encode_segments_resilient(video_id, segments)

        # t0 predates audio extraction, so this is whole-video wall time
        # (extract + VAD + encode), not GPU encode alone.
        elapsed = time.perf_counter() - t0
        rtf = usable_audio / max(elapsed, 0.001)

        self.orch.update_video_status(video_id, "ENCODED", {
            "codecs_used": ["xcodec2_fast"],
        })

        # Add to shard buffer
        vt = VideoTokens(
            video_id=video_id,
            language=language,
            duration_s=usable_audio,
            segments=encoded,
            usable_audio_s=usable_audio,
        )
        self.shard_buffer.add(vt)

        # Update stats
        self._total_audio_s += usable_audio
        self._total_videos += 1
        self._consecutive_errors = 0

        if self.heartbeat:
            self.heartbeat.rtf = rtf
            self.heartbeat.total_audio_s = self._total_audio_s
            self.heartbeat.total_videos = self._total_videos

        logger.info(
            "Encoded %s: %d segments, %.1f s audio, RTF=%.1fx, took %.1f s",
            video_id, len(segments), usable_audio, rtf, elapsed,
        )

        # Check if shard buffer is full → pack and upload
        if self.shard_buffer.video_count >= self.cfg.worker.shard_pack_count:
            self._pack_and_upload_shard()

        # Cleanup
        self._cleanup_video(video_id, video_path, audio_path)

    def _pack_and_upload_shard(self) -> None:
        """Pack current buffer into a shard, upload to R2, report to Supabase."""
        if self.shard_buffer.video_count == 0:
            return

        shard_dir = self._tmp_dir / "shards"
        shard_id, shard_path, size_bytes = pack_shard(
            self.shard_buffer, shard_dir,
        )

        # Upload to R2
        r2_key = f"shards/{shard_id}.tar.zst"
        self.r2.upload_shard(shard_path, r2_key)

        # Report to Supabase
        video_ids = [v.video_id for v in self.shard_buffer.videos]
        languages = sorted({v.language for v in self.shard_buffer.videos})

        self.orch.register_shard(
            shard_id=shard_id,
            worker_id=self.cfg.worker.worker_id,
            video_ids=video_ids,
            languages=languages,
            total_segments=self.shard_buffer.total_segments,
            total_audio_s=self.shard_buffer.total_audio_s,
            codecs=["xcodec2_fast"],
            r2_key=r2_key,
            size_bytes=size_bytes,
        )

        # Mark all videos as PACKED → DONE
        for vid in video_ids:
            self.orch.mark_video_done(vid, shard_id, ["xcodec2_fast"])

        # Cleanup
        shard_path.unlink(missing_ok=True)
        self.shard_buffer.clear()
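        # Encourage prompt release of the large token arrays just cleared.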
        gc.collect()

        if self.heartbeat:
            self.heartbeat.total_shards += 1
            self.heartbeat.shard_buffer_count = 0

        logger.info("Shard %s uploaded and reported (%d videos)", shard_id, len(video_ids))

    def _cleanup_video(self, video_id: str, video_path: Path, audio_path: Path) -> None:
        """Remove local video + audio files to free disk."""
        for p in (video_path, audio_path):
            try:
                p.unlink(missing_ok=True)
            except OSError:
                logger.debug("Could not remove %s for video %s", p, video_id)

    def run(self, language: str | None = None, max_videos: int | None = None) -> None:
        """Main loop: dispatches to async or serial pipeline based on config."""
        if self.cfg.worker.use_async_pipeline:
            self.run_async(language, max_videos)
        else:
            self.run_serial(language, max_videos)

    def run_async(self, language: str | None = None, max_videos: int | None = None) -> None:
        """Async 3-stage pipeline: Download→ExtractVAD→GPUEncode all overlap.

        CPU and GPU stay maximally busy. N extract workers feed the GPU.
        """
        self._running = True
        videos_done = 0
        videos_submitted = 0

        def _shutdown(sig, frame):
            logger.info("Received signal %s, shutting down gracefully...", sig)
            self._running = False
        signal.signal(signal.SIGTERM, _shutdown)
        signal.signal(signal.SIGINT, _shutdown)

        pipeline = AsyncPipeline(self.cfg, self.r2)
        pipeline.start()

        logger.info(
            "=== Async Pipeline Starting (language=%s, max=%s, extract_workers=%d) ===",
            language, max_videos, self.cfg.worker.extract_workers,
        )

        # Pre-claim and submit initial batch to fill the pipeline
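        # (one claim per extract worker plus the prefetch depth, so every
        # stage has work from the first iteration)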
        prefill = self.cfg.worker.extract_workers + self.cfg.worker.prefetch_videos
        target = min(prefill, max_videos) if max_videos else prefill
        for _ in range(target):
            video = self.orch.claim_video(self.cfg.worker.worker_id, language)
            if video is None:
                break
            vid = video["video_id"]
            lang = video.get("language", "unknown")
            bucket = self._resolve_bucket(lang)
            self.orch.update_video_status(vid, "DOWNLOADING")
            pipeline.submit(vid, lang, bucket)
            videos_submitted += 1

        # Main GPU loop: drain ready queue, encode, pack
        while self._running:
            prepared: PreparedVideo | None = None
            if max_videos and videos_done >= max_videos:
                logger.info("Reached max_videos=%d, stopping", max_videos)
                break

            try:
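                # get_ready() raises queue.Empty on timeout (handled below) and
                # returns None once the pipeline has fully drained.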
                prepared = pipeline.get_ready(timeout=120.0)
                if prepared is None:
                    logger.info("Pipeline drained, no more videos")
                    break

                if self.heartbeat:
                    self.heartbeat.current_video = prepared.video_id

                self.orch.update_video_status(prepared.video_id, "PROCESSING")

                # GPU encode with progressive OOM recovery
                t0 = time.perf_counter()
                encoded = self._encode_segments_resilient(
                    prepared.video_id, prepared.segments
                )
                if torch.cuda.is_available():
                    # Wait for queued GPU work so encode_time is not undercounted.
                    torch.cuda.synchronize()
                encode_time = time.perf_counter() - t0

                rtf = prepared.usable_audio_s / max(encode_time, 0.001)
                self.orch.update_video_status(prepared.video_id, "ENCODED", {
                    "num_segments": prepared.num_segments,
                    "usable_audio_s": round(prepared.usable_audio_s, 1),
                    "codecs_used": ["xcodec2_fast"],
                })

                # Shard buffer
                vt = VideoTokens(
                    video_id=prepared.video_id,
                    language=prepared.language,
                    duration_s=prepared.usable_audio_s,
                    segments=encoded,
                    usable_audio_s=prepared.usable_audio_s,
                )
                self.shard_buffer.add(vt)

                self._total_audio_s += prepared.usable_audio_s
                self._total_videos += 1
                videos_done += 1
                self._consecutive_errors = 0

                if self.heartbeat:
                    self.heartbeat.rtf = rtf
                    self.heartbeat.record_rtf(rtf)
                    self.heartbeat.total_audio_s = self._total_audio_s
                    self.heartbeat.total_videos = self._total_videos
                    self.heartbeat.current_stage = "encoding"
                    self.heartbeat.shard_buffer_count = self.shard_buffer.video_count

                logger.info(
                    "Encoded %s: %d segs, %.0fs audio | ext=%.1f vad=%.1f enc=%.1f | "
                    "enc_RTF=%.0fx | ready_q=%d | done=%d/%s",
                    prepared.video_id, prepared.num_segments, prepared.usable_audio_s,
                    prepared.extract_time_s, prepared.vad_time_s, encode_time,
                    rtf, pipeline.ready_count, videos_done, max_videos or "∞",
                )

                # Cleanup video file + reclaim GPU memory
                if prepared.video_path:
                    prepared.video_path.unlink(missing_ok=True)
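                # Periodic cache release keeps allocator fragmentation in check
                # without paying empty_cache() on every video.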
                if videos_done % 5 == 0:
                    torch.cuda.empty_cache()

                # Pack shard if buffer full
                if self.shard_buffer.video_count >= self.cfg.worker.shard_pack_count:
                    self._pack_and_upload_shard()

                # Handle failed videos: mark FAILED, claim replacements
                for failed_id in pipeline.get_failed():
                    self.orch.mark_video_failed(failed_id, "Download failed or extract error")
                    videos_submitted -= 1
                    if not max_videos or videos_submitted < max_videos:
                        video = self.orch.claim_video(self.cfg.worker.worker_id, language)
                        if video:
                            vid = video["video_id"]
                            lang = video.get("language", "unknown")
                            bucket = self._resolve_bucket(lang)
                            self.orch.update_video_status(vid, "DOWNLOADING")
                            pipeline.submit(vid, lang, bucket)
                            videos_submitted += 1

                # Claim + submit next video to keep pipeline full
                if not max_videos or videos_submitted < max_videos:
                    video = self.orch.claim_video(self.cfg.worker.worker_id, language)
                    if video:
                        vid = video["video_id"]
                        lang = video.get("language", "unknown")
                        bucket = self._resolve_bucket(lang)
                        self.orch.update_video_status(vid, "DOWNLOADING")
                        pipeline.submit(vid, lang, bucket)
                        videos_submitted += 1
                    elif videos_done >= videos_submitted:
                        # Nothing left to claim and nothing in flight: tell the
                        # pipeline no more work is coming so get_ready() can
                        # return None once its queues empty out.
                        pipeline.drain()

            except Empty:
                # Timeout waiting for ready video — check for failures and claim replacements
                for failed_id in pipeline.get_failed():
                    self.orch.mark_video_failed(failed_id, "Download failed or extract error")
                    videos_submitted -= 1
                    if not max_videos or videos_submitted < max_videos:
                        video = self.orch.claim_video(self.cfg.worker.worker_id, language)
                        if video:
                            vid = video["video_id"]
                            lang = video.get("language", "unknown")
                            bucket = self._resolve_bucket(lang)
                            self.orch.update_video_status(vid, "DOWNLOADING")
                            pipeline.submit(vid, lang, bucket)
                            videos_submitted += 1
                logger.warning("Ready queue timeout — %d submitted, %d done, pending=%d",
                               videos_submitted, videos_done, pipeline.pending_count)
                if pipeline.pending_count == 0 and pipeline.ready_count == 0:
                    logger.info("No pending or ready videos, stopping")
                    break

            except Exception as e:
                self._consecutive_errors += 1
                error_msg = f"{type(e).__name__}: {e}"
                logger.error("Async pipeline error: %s\n%s", error_msg, traceback.format_exc())

                # Keep the worker alive: fail only this video, then continue.
                if prepared is not None:
                    try:
                        self.orch.mark_video_failed(prepared.video_id, error_msg)
                        # Drop the failed claim from the submitted count, flooring
                        # at videos_done so submitted never falls below done.
                        videos_submitted = max(videos_done, videos_submitted - 1)
                    except Exception:
                        pass
                    if prepared.video_path:
                        prepared.video_path.unlink(missing_ok=True)

                    # Submit a replacement claim so throughput stays stable.
                    if not max_videos or videos_submitted < max_videos:
                        try:
                            video = self.orch.claim_video(self.cfg.worker.worker_id, language)
                            if video:
                                vid = video["video_id"]
                                lang = video.get("language", "unknown")
                                bucket = self._resolve_bucket(lang)
                                self.orch.update_video_status(vid, "DOWNLOADING")
                                pipeline.submit(vid, lang, bucket)
                                videos_submitted += 1
                        except Exception:
                            pass

                if self._consecutive_errors >= self.cfg.worker.max_retries:
                    logger.error("Too many errors (%d), pausing 60s", self._consecutive_errors)
                    time.sleep(60)
                    self._consecutive_errors = 0

        # Shutdown
        pipeline.stop()
        if self.shard_buffer.video_count > 0:
            logger.info("Flushing %d videos in shard buffer", self.shard_buffer.video_count)
            self._pack_and_upload_shard()
        if self.heartbeat:
            self.heartbeat.stop()

        reason = "max_videos_reached" if (max_videos and videos_done >= max_videos) else "no_more_videos"
        self.orch.mark_worker_stopped(self.cfg.worker.worker_id, reason)

        stats = pipeline.stats
        logger.info(
            "=== Async Pipeline Complete. %d videos, %.0fs audio | "
            "extract=%.1fs vad=%.1fs download=%.1fs ===",
            videos_done, self._total_audio_s,
            stats.total_extract_s, stats.total_vad_s, stats.total_download_s,
        )

    def run_serial(self, language: str | None = None, max_videos: int | None = None) -> None:
        """Legacy serial pipeline: claim → download → process → repeat."""
        self._running = True
        videos_done = 0

        def _shutdown(sig, frame):
            logger.info("Received signal %s, shutting down gracefully...", sig)
            self._running = False
        signal.signal(signal.SIGTERM, _shutdown)
        signal.signal(signal.SIGINT, _shutdown)

        logger.info("=== Worker Main Loop Starting (language=%s, max=%s, mode=serial) ===",
                     language, max_videos)

        assert self.prefetcher is not None, "setup() must run before run_serial()"
        for _ in range(self.cfg.worker.prefetch_videos):
            self._prefetch_next(language)

        while self._running:
            if max_videos and videos_done >= max_videos:
                logger.info("Reached max_videos=%d, stopping", max_videos)
                break

            try:
                item: PrefetchItem = self.prefetcher.get(timeout=60.0)

                if item.path is None:
                    logger.warning("Video %s download failed, skipping", item.video_id)
                    self.orch.mark_video_failed(item.video_id, "Download failed: file not found in R2")
                    self._prefetch_next(language)
                    continue

                # Kick off the next claim + download before encoding so the
                # transfer overlaps GPU work on the current video.
                self._prefetch_next(language)
                self._process_one_video(item.video_id, item.path, item.language)
                videos_done += 1

            except Exception as e:
                self._consecutive_errors += 1
                error_msg = f"{type(e).__name__}: {e}"
                logger.error("Error processing video: %s\n%s", error_msg, traceback.format_exc())

                if self.heartbeat and self.heartbeat.current_video:
                    try:
                        self.orch.mark_video_failed(self.heartbeat.current_video, error_msg)
                        self.orch.report_worker_error(self.cfg.worker.worker_id, error_msg)
                    except Exception:
                        pass

                if self._consecutive_errors >= self.cfg.worker.max_retries:
                    logger.error("Too many consecutive errors (%d), pausing 60s",
                                self._consecutive_errors)
                    time.sleep(60)
                    self._consecutive_errors = 0

                try:
                    self._prefetch_next(language)
                except Exception:
                    pass

        if self.shard_buffer.video_count > 0:
            logger.info("Flushing remaining %d videos in shard buffer", self.shard_buffer.video_count)
            self._pack_and_upload_shard()

        if self.heartbeat:
            self.heartbeat.stop()
        if self.prefetcher:
            self.prefetcher.stop()

        self.orch.mark_worker_stopped(self.cfg.worker.worker_id, "graceful_shutdown")
        logger.info("=== Worker shutdown complete. Processed %d videos, %.0f s audio ===",
                     videos_done, self._total_audio_s)
