"""
Production validation worker: claims done videos, runs LID+CTC pipeline,
packs parquet shards. Modeled on the battle-tested src/worker.py.

Key features vs the basic version:
  - Registers in worker_validators table for fleet monitoring
  - Async heartbeat loop (30s + jitter) with live throughput stats
  - Prefetch: claim + download next video while current one processes
  - Graceful shutdown: SIGTERM → finish current → flush parquet → release → offline
  - Dead video recovery: on startup, reset stale validating claims
  - Rolling throughput tracking (last 10 videos)
"""
from __future__ import annotations

import asyncio
import json
import logging
import random
import shutil
import signal
import tarfile
import tempfile
import time
from collections import deque
from dataclasses import dataclass, field
from datetime import datetime, timezone
from pathlib import Path
from typing import Optional

from .config import ValidationConfig, PARQUET_SHARD_SIZE, HEARTBEAT_INTERVAL_S
from .audio_loader import load_video_segments
from .pipeline import ValidationPipeline, SegmentResult
from .packer import ParquetPacker

logger = logging.getLogger(__name__)

# DB operation counters
# Module-level tallies mutated in place by the worker's DB helpers and
# printed by _log_db_stats(). "*_ok" / "*_fail" are bumped once per call
# outcome; they are never reset inside this module.
_db_stats = {
    "claims_ok": 0, "claims_fail": 0,  # claims that returned a row / claim queries that raised
    "heartbeats_ok": 0, "heartbeats_fail": 0,  # worker_validators heartbeat updates
    "marks_ok": 0, "marks_fail": 0,  # 'validated' and 'validation_failed' status writes
    "recovery_reset": 0,  # rows reset to 'pending' by startup recovery
}


def _log_db_stats():
    """Log a one-line summary of the module-level DB operation counters.

    Each claims/heartbeats/marks pair is rendered as ``succeeded/attempted``.
    The ``marks`` entry used to print ``succeeded/failed``, which made it read
    differently from the other two fields on the same line.
    """
    s = _db_stats

    def _ratio(name: str) -> str:
        # "<ok>/<ok + fail>" for one counter pair
        ok = s[name + "_ok"]
        fail = s[name + "_fail"]
        return f"{ok}/{ok + fail}"

    logger.info(
        f"[DB-STATS] claims={_ratio('claims')} "
        f"heartbeats={_ratio('heartbeats')} "
        f"marks={_ratio('marks')} "
        f"recovery_reset={s['recovery_reset']}"
    )


@dataclass
class ValidationStats:
    """Mutable stats updated during worker lifecycle.

    A single instance lives on the worker. The main loop and per-video
    processing mutate it in place, and the heartbeat writes most of these
    fields into the worker's ``worker_validators`` row.
    """
    current_video_id: Optional[str] = None  # video being processed right now; None between videos
    videos_processed: int = 0  # videos whose processing call returned without raising
    videos_failed: int = 0  # videos whose processing call raised
    segments_processed: int = 0  # running total of pipeline results handed to the packer
    avg_segs_per_second: float = 0.0  # mean of _recent_speeds
    shards_written: int = 0  # copied from the packer's stats after each video
    shards_uploaded: int = 0  # NOTE(review): never assigned in this module; always reported as 0 from here
    total_parquet_mb: float = 0.0  # copied from the packer's stats ("total_mb")
    last_video_completed_at: Optional[datetime] = None  # UTC timestamp of the last successful video
    last_error: Optional[str] = None  # most recent error text, truncated to 500 chars by the writers
    # Per-video segments/second for the most recent 10 videos (rolling window).
    _recent_speeds: deque = field(default_factory=lambda: deque(maxlen=10))


@dataclass
class PrefetchResult:
    """Prefetched video ready for processing.

    Produced by the prefetch task once a video has been claimed and its tar
    downloaded; consumed by the main loop on the following iteration.
    """
    video_id: str  # claimed video's id
    language: str  # language value returned by the claim query ("" when NULL)
    work_dir: Path  # temp directory that already holds the downloaded tar


class ValidationWorker:
    """
    Production worker with heartbeats, prefetch, graceful shutdown.
    Each Docker container runs one instance.
    """

    def __init__(self, config: ValidationConfig):
        self.config = config
        self.pipeline = ValidationPipeline(config)
        self.packer: Optional[ParquetPacker] = None
        self.stats = ValidationStats()
        self._db = None
        self._s3 = None
        self._shutdown_event = asyncio.Event()
        self._heartbeat_task: Optional[asyncio.Task] = None
        self._work_dir = Path(tempfile.mkdtemp(prefix="validation_"))
        self._start_time = 0.0

    async def start(self):
        """Main entry: connect → recover → register → load models → heartbeat → loop."""
        self._start_time = time.time()

        loop = asyncio.get_running_loop()
        for sig in (signal.SIGTERM, signal.SIGINT):
            loop.add_signal_handler(sig, lambda s=sig: asyncio.create_task(self._handle_shutdown(s)))

        try:
            await self._connect_db()
            self._init_s3()

            await self._recover_dead_videos()
            await self._register()

            self._hf_login()
            logger.info("Loading validation models (downloading if needed)...")
            self.pipeline.load_models()

            output_dir = self._work_dir / "parquet_shards"
            self.packer = ParquetPacker(self.config, output_dir)

            self._heartbeat_task = asyncio.create_task(self._heartbeat_loop())

            logger.info(f"Worker {self.config.worker_id} entering main loop")
            await self._main_loop()

        except Exception as e:
            logger.error(f"Worker fatal error: {e}", exc_info=True)
            self.stats.last_error = str(e)[:500]
            if self._db:
                await self._update_worker_status("error", error=str(e)[:500])
        finally:
            await self._cleanup()

    # ── Main Loop ──────────────────────────────────────────────────────

    async def _main_loop(self):
        consecutive_empty = 0
        max_empty = 5
        consecutive_failures = 0
        max_consecutive_failures = 3  # Exit if GPU is poisoned (device-side assert cascade)
        prefetch_task: Optional[asyncio.Task] = None
        prefetched: Optional[PrefetchResult] = None

        max_videos = self.config.max_videos
        if max_videos > 0:
            logger.info(f"MAX_VIDEOS={max_videos} — will stop after {max_videos} video(s)")

        while not self._shutdown_event.is_set():
            if max_videos > 0 and self.stats.videos_processed >= max_videos:
                logger.info(f"Reached MAX_VIDEOS={max_videos}, shutting down.")
                break

            # Use prefetched video or claim a new one
            if prefetched:
                video_id, language, video_work = prefetched.video_id, prefetched.language, prefetched.work_dir
                prefetched = None
                logger.info(f"Using prefetched video: {video_id}")
            else:
                video_id, language = await self._claim_video()
                video_work = None

            if not video_id:
                consecutive_empty += 1
                if consecutive_empty >= max_empty:
                    logger.info(f"{max_empty} consecutive empty claims, exiting")
                    break
                wait = min(2.0 * consecutive_empty, 30.0)
                logger.info(f"No pending videos, waiting {wait:.0f}s (attempt {consecutive_empty}/{max_empty})...")
                try:
                    await asyncio.wait_for(self._shutdown_event.wait(), timeout=wait)
                    break
                except asyncio.TimeoutError:
                    continue

            consecutive_empty = 0
            self.stats.current_video_id = video_id

            # Kick off prefetch for next video while we process this one
            remaining = (max_videos - self.stats.videos_processed - 1) if max_videos > 0 else 1
            if remaining > 0 and not self._shutdown_event.is_set():
                prefetch_task = asyncio.create_task(self._prefetch_next())

            try:
                await self._process_one_video(video_id, language, video_work)
                self.stats.videos_processed += 1
                self.stats.last_video_completed_at = datetime.now(timezone.utc)
                consecutive_failures = 0
            except Exception as e:
                logger.error(f"Video {video_id} failed: {e}", exc_info=True)
                self.stats.videos_failed += 1
                self.stats.last_error = str(e)[:500]
                await self._mark_video_failed(video_id, str(e))
                err_str = str(e)
                is_gpu_fatal = ("CUDA" in err_str or "device-side" in err_str
                                or "cuFFT" in err_str)
                if is_gpu_fatal:
                    consecutive_failures += 1
                    if consecutive_failures >= max_consecutive_failures:
                        logger.error(
                            f"{consecutive_failures} consecutive GPU failures "
                            "(GPU likely poisoned), exiting"
                        )
                        break
                else:
                    consecutive_failures = 0

            self.stats.current_video_id = None

            # Collect prefetch result
            if prefetch_task and not prefetch_task.done():
                try:
                    prefetched = await asyncio.wait_for(prefetch_task, timeout=10)
                except (asyncio.TimeoutError, Exception):
                    prefetched = None
            elif prefetch_task and prefetch_task.done():
                try:
                    prefetched = prefetch_task.result()
                except Exception:
                    prefetched = None
            prefetch_task = None

        logger.info(f"Main loop ended: {self.stats.videos_processed} processed, {self.stats.videos_failed} failed")

    # ── Video Processing ───────────────────────────────────────────────

    async def _process_one_video(self, video_id: str, language: str, prefetch_dir: Optional[Path] = None):
        t0 = time.time()
        video_work = prefetch_dir or (self._work_dir / video_id)
        video_work.mkdir(exist_ok=True)
        loop = asyncio.get_running_loop()

        try:
            # Download tar (skip if prefetched)
            tar_path = video_work / f"{video_id}_transcribed.tar"
            if not tar_path.exists():
                tar_path = await loop.run_in_executor(None, self._download_tar, video_id, video_work)
                if not tar_path:
                    await self._mark_video_failed(video_id, "tar download failed")
                    return

            # Extract
            await loop.run_in_executor(None, self._extract_tar, tar_path, video_work)

            # Load audio + transcriptions
            metadata, segments = load_video_segments(video_work, video_id)
            if not segments:
                logger.warning(f"[{video_id}] No valid segments")
                await self._mark_video_done(video_id, 0)
                return

            # Run validation pipeline (GPU-bound, runs in current thread)
            results = self.pipeline.process_video(video_id, segments)

            # Pack into parquet buffer
            self.packer.add_video_results(video_id, results)
            self.stats.segments_processed += len(results)

            # Sync packer stats
            packer_stats = self.packer.stats
            self.stats.shards_written = packer_stats["shards_written"]
            self.stats.total_parquet_mb = packer_stats.get("total_mb", 0.0)

            # Update Supabase
            await self._mark_video_done(video_id, len(results))

            elapsed = time.time() - t0
            speed = len(results) / elapsed if elapsed > 0 else 0
            self.stats._recent_speeds.append(speed)
            self.stats.avg_segs_per_second = sum(self.stats._recent_speeds) / len(self.stats._recent_speeds)

            logger.info(
                f"[{video_id}] Complete: {len(results)} segments in {elapsed:.1f}s "
                f"({speed:.1f} segs/s, rolling avg {self.stats.avg_segs_per_second:.1f})"
            )

        finally:
            if video_work.exists() and video_work != prefetch_dir:
                shutil.rmtree(video_work, ignore_errors=True)
            elif prefetch_dir and prefetch_dir.exists():
                shutil.rmtree(prefetch_dir, ignore_errors=True)

    # ── Prefetch ───────────────────────────────────────────────────────

    async def _prefetch_next(self) -> Optional[PrefetchResult]:
        """Claim and pre-download the next video while current one processes."""
        video_id = None
        try:
            video_id, language = await self._claim_video()
            if not video_id:
                return None

            work_dir = Path(tempfile.mkdtemp(prefix=f"prefetch_{self.config.worker_id}_"))
            loop = asyncio.get_running_loop()
            tar_path = await loop.run_in_executor(None, self._download_tar, video_id, work_dir)
            if not tar_path:
                shutil.rmtree(work_dir, ignore_errors=True)
                await self._mark_video_failed(video_id, "tar download failed (prefetch)")
                return None

            logger.info(f"[prefetch] {video_id} tar ready at {work_dir}")
            return PrefetchResult(video_id=video_id, language=language, work_dir=work_dir)
        except Exception as e:
            logger.warning(f"[prefetch] Failed: {e}")
            if video_id:
                await self._mark_video_failed(video_id, f"prefetch error: {str(e)[:200]}")
            return None

    # ── R2 Operations ──────────────────────────────────────────────────

    def _download_tar(self, video_id: str, work_dir: Path) -> Optional[Path]:
        """Download transcribed tar from R2. Checks both bucket locations
        (tars split across 'transcribed' bucket and '1-cleaned-data/transcribed/' prefix)."""
        tar_path = work_dir / f"{video_id}_transcribed.tar"
        if self.config.mock_mode:
            return None

        key = f"{video_id}_transcribed.tar"
        locations = [
            (self.config.r2_bucket_source, key),
            ("1-cleaned-data", f"transcribed/{key}"),
        ]

        for bucket, obj_key in locations:
            try:
                self._s3.head_object(Bucket=bucket, Key=obj_key)
                logger.info(f"Downloading s3://{bucket}/{obj_key}")
                self._s3.download_file(bucket, obj_key, str(tar_path))
                size_mb = tar_path.stat().st_size / 1e6
                logger.info(f"Downloaded {obj_key}: {size_mb:.1f}MB")
                return tar_path
            except Exception:
                continue

        logger.error(f"Download failed for {video_id}: tar not found in any R2 location")
        return None

    def _extract_tar(self, tar_path: Path, work_dir: Path):
        with tarfile.open(tar_path, "r:*") as tf:
            tf.extractall(work_dir, filter="data")
        tar_path.unlink(missing_ok=True)

    def _init_s3(self):
        if self.config.mock_mode:
            return
        import boto3
        self._s3 = boto3.client(
            "s3",
            endpoint_url=self.config.r2_endpoint_url,
            aws_access_key_id=self.config.r2_access_key_id,
            aws_secret_access_key=self.config.r2_secret_access_key,
            region_name="auto",
        )

    def _hf_login(self):
        """Remove HF_TOKEN from env — all 4 models load from R2, no HF auth needed."""
        import os
        os.environ.pop("HF_TOKEN", None)
        logger.info("HF_TOKEN cleared from env (all models load from R2, no HF auth needed)")

    # ── Database Operations ────────────────────────────────────────────

    async def _connect_db(self):
        if self.config.mock_mode or not self.config.database_url:
            logger.info("Running without DB (mock mode or no DATABASE_URL)")
            return

        import asyncpg
        is_pooler = ":6543" in self.config.database_url
        self._db = await asyncpg.create_pool(
            dsn=self.config.database_url,
            min_size=1, max_size=3,
            command_timeout=60,
            statement_cache_size=0 if is_pooler else 100,
            ssl="require",
        )
        logger.info(f"DB pool connected (min=1, max=3, pooler={is_pooler})")

    async def _register(self):
        """Insert/upsert into worker_validators table."""
        if not self._db:
            return

        config_json = {
            "models": {
                "mms_lid": self.config.enable_mms_lid,
                "voxlingua": self.config.enable_voxlingua,
                "conformer_multi": self.config.enable_conformer_multi,
                "wav2vec_lang": self.config.enable_wav2vec_lang,
            },
            "model_bucket": self.config.r2_model_bucket,
            "output_bucket": self.config.r2_bucket_output,
            "shard_size": PARQUET_SHARD_SIZE,
            "gpu_type": self.config.gpu_type,
        }

        for attempt in range(5):
            try:
                async with self._db.acquire() as conn:
                    await conn.execute("""
                        INSERT INTO worker_validators (
                            worker_id, status, gpu_type, config_json,
                            started_at, last_heartbeat_at,
                            videos_processed, videos_failed, segments_processed,
                            shards_written, shards_uploaded, total_parquet_mb
                        ) VALUES ($1, 'online', $2, $3::jsonb, now(), now(), 0, 0, 0, 0, 0, 0)
                        ON CONFLICT (worker_id) DO UPDATE SET
                            status = 'online', gpu_type = $2, config_json = $3::jsonb,
                            started_at = now(), last_heartbeat_at = now(),
                            videos_processed = 0, videos_failed = 0, segments_processed = 0,
                            shards_written = 0, shards_uploaded = 0, total_parquet_mb = 0,
                            last_error = NULL, current_video_id = NULL
                    """, self.config.worker_id, self.config.gpu_type, json.dumps(config_json))
                logger.info(f"Registered worker {self.config.worker_id} in worker_validators")
                return
            except Exception as e:
                wait = 3 * (attempt + 1) + random.uniform(0, 2)
                logger.warning(f"Register failed (attempt {attempt+1}/5): {str(e)[:100]}, retry in {wait:.0f}s")
                await asyncio.sleep(wait)
        logger.error("Register failed after 5 attempts, continuing anyway")

    async def _recover_dead_videos(self):
        """On startup, reset stale validating claims (handles crash recovery)."""
        if not self._db:
            return

        try:
            async with self._db.acquire() as conn:
                # Reset videos claimed by THIS worker that crashed mid-processing
                result = await conn.execute("""
                    UPDATE video_queue
                    SET validation_status = 'pending', claimed_by = NULL, claimed_at = NULL
                    WHERE validation_status = 'validating' AND claimed_by = $1
                """, self.config.worker_id)
                own_count = int(result.split()[-1]) if result else 0

                # Reset globally stale claims (>15 min old, any worker)
                result = await conn.execute("""
                    UPDATE video_queue
                    SET validation_status = 'pending', claimed_by = NULL, claimed_at = NULL
                    WHERE validation_status = 'validating'
                      AND claimed_at < now() - interval '15 minutes'
                """)
                stale_count = int(result.split()[-1]) if result else 0

            total = own_count + stale_count
            if total > 0:
                _db_stats["recovery_reset"] = total
                logger.info(f"Recovered {total} dead videos (own={own_count}, stale={stale_count})")
        except Exception as e:
            logger.warning(f"Dead video recovery failed (non-fatal): {e}")

    async def _claim_video(self) -> tuple[Optional[str], str]:
        """Atomic claim using validation_status column. Leaves status='done' untouched.
        ORDER BY video_id forces planner to use idx_vq_val_claim partial index
        instead of seq-scanning 500K rows."""
        if not self._db:
            return None, ""

        for attempt in range(3):
            try:
                async with self._db.acquire() as conn:
                    row = await conn.fetchrow("""
                        UPDATE video_queue
                        SET validation_status = 'validating',
                            claimed_by = $1,
                            claimed_at = now()
                        WHERE video_id = (
                            SELECT video_id FROM video_queue
                            WHERE status = 'done' AND validation_status = 'pending'
                            ORDER BY video_id
                            LIMIT 1
                            FOR UPDATE SKIP LOCKED
                        )
                        RETURNING video_id, language
                    """, self.config.worker_id)

                if row:
                    _db_stats["claims_ok"] += 1
                    logger.info(f"Claimed {row['video_id']} (lang={row['language']})")
                    return row["video_id"], row.get("language", "") or ""
                return None, ""

            except Exception as e:
                _db_stats["claims_fail"] += 1
                if attempt < 2:
                    wait = (2 ** attempt) + random.uniform(0.5, 1.5)
                    logger.warning(f"Claim failed (attempt {attempt+1}/3): {str(e)[:100]}, retry in {wait:.1f}s")
                    await asyncio.sleep(wait)
                else:
                    logger.error(f"Claim failed after 3 attempts: {str(e)[:100]}")
                    return None, ""

    async def _mark_video_done(self, video_id: str, segment_count: int):
        if not self._db:
            return
        try:
            async with self._db.acquire() as conn:
                await conn.execute("""
                    UPDATE video_queue
                    SET validation_status = 'validated',
                        validation_segments = $2,
                        validation_at = now()
                    WHERE video_id = $1
                """, video_id, segment_count)
            _db_stats["marks_ok"] += 1
        except Exception as e:
            _db_stats["marks_fail"] += 1
            logger.warning(f"mark_done failed for {video_id}: {e}")

    async def _mark_video_failed(self, video_id: str, error: str):
        if not self._db:
            return
        try:
            async with self._db.acquire() as conn:
                await conn.execute("""
                    UPDATE video_queue
                    SET validation_status = 'validation_failed',
                        error_message = $2
                    WHERE video_id = $1
                """, video_id, error[:500])
            _db_stats["marks_ok"] += 1
        except Exception as e:
            _db_stats["marks_fail"] += 1
            logger.warning(f"mark_failed failed for {video_id}: {e}")

    async def _release_video(self, video_id: str):
        """Release a claimed video back to pending (used during shutdown)."""
        if not self._db:
            return
        try:
            async with self._db.acquire() as conn:
                await conn.execute("""
                    UPDATE video_queue
                    SET validation_status = 'pending', claimed_by = NULL, claimed_at = NULL
                    WHERE video_id = $1 AND validation_status = 'validating'
                """, video_id)
            logger.info(f"Released {video_id} back to pending")
        except Exception as e:
            logger.warning(f"Release failed for {video_id}: {e}")

    # ── Heartbeat ──────────────────────────────────────────────────────

    async def _heartbeat_loop(self):
        hb_count = 0
        while not self._shutdown_event.is_set():
            try:
                await self._send_heartbeat()
            except Exception as e:
                logger.warning(f"Heartbeat failed: {e}")

            hb_count += 1
            if hb_count % 10 == 0:
                _log_db_stats()

            jitter = random.uniform(0, 15)
            try:
                await asyncio.wait_for(
                    self._shutdown_event.wait(),
                    timeout=HEARTBEAT_INTERVAL_S + jitter,
                )
                break
            except asyncio.TimeoutError:
                pass

    async def _send_heartbeat(self):
        if not self._db:
            return
        try:
            s = self.stats
            async with self._db.acquire() as conn:
                await conn.execute("""
                    UPDATE worker_validators SET
                        current_video_id = $2,
                        videos_processed = $3,
                        videos_failed = $4,
                        segments_processed = $5,
                        avg_segs_per_second = $6,
                        shards_written = $7,
                        shards_uploaded = $8,
                        total_parquet_mb = $9,
                        last_heartbeat_at = now(),
                        last_video_completed_at = $10,
                        last_error = $11
                    WHERE worker_id = $1
                """,
                    self.config.worker_id,
                    s.current_video_id,
                    s.videos_processed,
                    s.videos_failed,
                    s.segments_processed,
                    round(s.avg_segs_per_second, 2),
                    s.shards_written,
                    s.shards_uploaded,
                    round(s.total_parquet_mb, 2),
                    s.last_video_completed_at,
                    s.last_error,
                )
            _db_stats["heartbeats_ok"] += 1
        except Exception as e:
            _db_stats["heartbeats_fail"] += 1

    async def _update_worker_status(self, status: str, error: str = ""):
        if not self._db:
            return
        try:
            async with self._db.acquire() as conn:
                await conn.execute("""
                    UPDATE worker_validators
                    SET status = $2, last_error = $3, last_heartbeat_at = now()
                    WHERE worker_id = $1
                """, self.config.worker_id, status, error[:500] if error else None)
        except Exception as e:
            logger.warning(f"Status update failed: {e}")

    # ── Shutdown ───────────────────────────────────────────────────────

    async def _handle_shutdown(self, sig):
        logger.info(f"Received {sig.name}, initiating graceful shutdown...")
        self._shutdown_event.set()

    async def _cleanup(self):
        logger.info("Cleaning up worker...")

        # Cancel heartbeat
        if self._heartbeat_task:
            self._heartbeat_task.cancel()
            try:
                await self._heartbeat_task
            except asyncio.CancelledError:
                pass

        # Release in-progress video back to pending
        if self.stats.current_video_id:
            await self._release_video(self.stats.current_video_id)

        # Flush remaining parquet buffer
        if self.packer:
            shard = self.packer.flush()
            if shard:
                logger.info(f"Flushed final shard: {shard}")

        # Unload models
        try:
            self.pipeline.unload_models()
        except Exception:
            pass

        # Final heartbeat + set offline
        try:
            await self._send_heartbeat()
        except Exception:
            pass
        await self._update_worker_status("offline")

        _log_db_stats()

        # Close DB pool
        if self._db:
            await self._db.close()

        # Cleanup temp dir
        if self._work_dir.exists():
            shutil.rmtree(self._work_dir, ignore_errors=True)

        elapsed = time.time() - self._start_time
        logger.info(
            f"Worker {self.config.worker_id} shutdown complete. "
            f"{self.stats.videos_processed} processed, {self.stats.videos_failed} failed, "
            f"{self.stats.segments_processed} segments, {elapsed:.0f}s total"
        )
