"""
Validation worker: claims videos from queue, downloads transcribed tars,
runs validation pipeline, packs results into parquet shards.
Intended to prefetch the next video while the current one is processing
(prefetch overlap is not yet implemented in this worker).
"""
from __future__ import annotations

import asyncio
import json
import logging
import shutil
import signal
import tarfile
import tempfile
import time
from pathlib import Path
from typing import Optional

from .config import ValidationConfig, PREFETCH_QUEUE_SIZE, HEARTBEAT_INTERVAL_S
from .audio_loader import load_video_segments
from .pipeline import ValidationPipeline, SegmentResult
from .packer import ParquetPacker

logger = logging.getLogger(__name__)


class ValidationWorker:
    """
    Main worker loop. Claims videos, validates, packs parquet.
    
    Flow:
      1. Claim video from Supabase queue
      2. Download {videoID}_transcribed.tar from R2
      3. Extract + load audio + transcriptions
      4. Run pipeline (LID + CTC)
      5. Add results to parquet packer
      6. Update Supabase validation status
      7. Repeat (with prefetch overlap)
    """

    def __init__(self, config: ValidationConfig):
        self.config = config
        self.pipeline = ValidationPipeline(config)
        self.packer: Optional[ParquetPacker] = None
        self._db = None
        self._s3 = None
        self._work_dir = Path(tempfile.mkdtemp(prefix="validation_"))
        self._running = True
        self._videos_processed = 0
        self._total_segments = 0
        self._start_time = 0.0

    async def start(self):
        """Main entry point — load models, connect DB, run loop."""
        self._start_time = time.time()

        # Signal handlers for graceful shutdown
        for sig in (signal.SIGTERM, signal.SIGINT):
            signal.signal(sig, self._handle_signal)

        # Connect DB
        await self._connect_db()

        # Initialize R2
        self._init_s3()

        # Load ML models
        logger.info("Loading validation models...")
        self.pipeline.load_models()

        # Initialize packer
        output_dir = self._work_dir / "parquet_shards"
        self.packer = ParquetPacker(self.config, output_dir)

        # Main loop
        logger.info(f"Worker {self.config.worker_id} starting validation loop")
        await self._main_loop()

        # Flush remaining results
        self.packer.flush()

        # Cleanup
        self.pipeline.unload_models()
        if self._db:
            await self._db.close()

        elapsed = time.time() - self._start_time
        logger.info(
            f"Worker done: {self._videos_processed} videos, "
            f"{self._total_segments} segments, {elapsed:.0f}s total"
        )

    async def _main_loop(self):
        empty_claims = 0
        max_empty = 5

        while self._running:
            if self.config.max_videos and self._videos_processed >= self.config.max_videos:
                logger.info(f"Reached max_videos={self.config.max_videos}, stopping")
                break

            # Claim a video
            video_id, language = await self._claim_video()
            if not video_id:
                empty_claims += 1
                if empty_claims >= max_empty:
                    logger.info(f"{max_empty} consecutive empty claims, exiting")
                    break
                await asyncio.sleep(2.0 * empty_claims)
                continue
            empty_claims = 0

            try:
                await self._process_one_video(video_id, language)
                self._videos_processed += 1
            except Exception as e:
                logger.error(f"Video {video_id} failed: {e}", exc_info=True)
                await self._mark_video_failed(video_id, str(e))

    async def _process_one_video(self, video_id: str, language: str):
        """Download, validate, pack one video."""
        t0 = time.time()
        video_work = self._work_dir / video_id
        video_work.mkdir(exist_ok=True)

        try:
            # Download tar
            tar_path = self._download_tar(video_id, video_work)
            if not tar_path:
                await self._mark_video_failed(video_id, "tar download failed")
                return

            # Extract
            self._extract_tar(tar_path, video_work)

            # Load audio + transcriptions
            metadata, segments = load_video_segments(video_work, video_id)
            if not segments:
                logger.warning(f"[{video_id}] No valid segments")
                await self._mark_video_done(video_id, 0)
                return

            # Run validation pipeline
            results = self.pipeline.process_video(video_id, segments)

            # Pack into parquet buffer
            self.packer.add_video_results(video_id, results)
            self._total_segments += len(results)

            # Update Supabase
            await self._mark_video_done(video_id, len(results))

            elapsed = time.time() - t0
            logger.info(
                f"[{video_id}] Complete: {len(results)} segments in {elapsed:.1f}s "
                f"({len(results)/elapsed:.0f} segs/s)"
            )

        finally:
            # Cleanup temp files for this video
            if video_work.exists():
                shutil.rmtree(video_work, ignore_errors=True)

    def _download_tar(self, video_id: str, work_dir: Path) -> Optional[Path]:
        """Download transcribed tar from R2."""
        tar_path = work_dir / f"{video_id}_transcribed.tar"

        if self.config.mock_mode:
            logger.info(f"[MOCK] Would download {video_id}_transcribed.tar")
            return None

        key = f"{video_id}_transcribed.tar"
        try:
            logger.info(f"Downloading s3://{self.config.r2_bucket_source}/{key}")
            self._s3.download_file(self.config.r2_bucket_source, key, str(tar_path))
            size_mb = tar_path.stat().st_size / 1e6
            logger.info(f"Downloaded {key}: {size_mb:.1f}MB")
            return tar_path
        except Exception as e:
            logger.error(f"Download failed for {key}: {e}")
            return None

    def _extract_tar(self, tar_path: Path, work_dir: Path):
        """Extract tar to work directory."""
        with tarfile.open(tar_path, "r:*") as tf:
            tf.extractall(work_dir, filter="data")
        # Remove tar to free disk
        tar_path.unlink(missing_ok=True)

    def _init_s3(self):
        if self.config.mock_mode:
            return
        import boto3
        self._s3 = boto3.client(
            "s3",
            endpoint_url=self.config.r2_endpoint_url,
            aws_access_key_id=self.config.r2_access_key_id,
            aws_secret_access_key=self.config.r2_secret_access_key,
            region_name="auto",
        )

    async def _connect_db(self):
        if self.config.mock_mode or not self.config.database_url:
            logger.info("Running without DB (mock mode or no DATABASE_URL)")
            return

        try:
            import asyncpg
            self._db = await asyncpg.create_pool(
                dsn=self.config.database_url,
                min_size=1, max_size=4,
                command_timeout=30,
                statement_cache_size=0,
                ssl="require",
            )
            logger.info("DB pool connected")
        except Exception as e:
            logger.warning(f"DB connection failed (continuing without DB): {e}")
            self._db = None

    async def _claim_video(self) -> tuple[Optional[str], str]:
        """Claim a video from the validation queue."""
        if not self._db:
            return None, ""

        try:
            async with self._db.acquire() as conn:
                row = await conn.fetchrow("""
                    UPDATE video_queue
                    SET status = 'validating',
                        claimed_by = $1,
                        claimed_at = now()
                    WHERE video_id = (
                        SELECT video_id FROM video_queue
                        WHERE status = 'done'
                        LIMIT 1
                        FOR UPDATE SKIP LOCKED
                    )
                    RETURNING video_id, language
                """, self.config.worker_id)

            if row:
                logger.info(f"Claimed {row['video_id']} (lang={row['language']})")
                return row["video_id"], row.get("language", "")
        except Exception as e:
            logger.error(f"Claim failed: {e}")

        return None, ""

    async def _mark_video_done(self, video_id: str, segment_count: int):
        if not self._db:
            return
        try:
            async with self._db.acquire() as conn:
                await conn.execute("""
                    UPDATE video_queue
                    SET status = 'validated',
                        validation_segments = $2,
                        validation_at = now()
                    WHERE video_id = $1
                """, video_id, segment_count)
        except Exception as e:
            logger.warning(f"mark_done failed for {video_id}: {e}")

    async def _mark_video_failed(self, video_id: str, error: str):
        if not self._db:
            return
        try:
            async with self._db.acquire() as conn:
                await conn.execute("""
                    UPDATE video_queue
                    SET status = 'validation_failed',
                        error_message = $2
                    WHERE video_id = $1
                """, video_id, error[:500])
        except Exception as e:
            logger.warning(f"mark_failed failed for {video_id}: {e}")

    def _handle_signal(self, signum, frame):
        logger.info(f"Received signal {signum}, shutting down gracefully...")
        self._running = False
