"""
Recovery validation worker.

Workflow:
  - claim one transcribed video from `validation_recover_queue`
  - download raw tar from `1-cleaned-data`
  - replay `audio_polish` deterministically
  - match replayed child IDs to historical `transcription_results`
  - validate all historical transcription rows for that video
  - log replay-only extras separately for later transcription salvage

This intentionally does NOT try to skip already-validated segments in-worker.
The fast operational path is:
  - regenerate the full historical transcribed corpus cleanly
  - publish recover shards to a separate bucket
  - perform one final sync/merge job after the fleet completes
"""
from __future__ import annotations

import asyncio
import json
import logging
import shutil
import signal
import tempfile
import time
from dataclasses import dataclass, field
from datetime import datetime, timezone
from pathlib import Path
from typing import Optional

from src.config import EnvConfig
from src.r2_client import R2Client

from .config import HEARTBEAT_INTERVAL_S, ValidationConfig
from .packer import ParquetPacker
from .pipeline import ValidationPipeline
from .recover_loader import load_recover_segments

logger = logging.getLogger(__name__)


class RecoverQuarantineError(RuntimeError):
    """Raised when replay cannot faithfully recover the historical tx set."""


@dataclass
class RecoverStats:
    current_video_id: Optional[str] = None
    videos_processed: int = 0
    videos_failed: int = 0
    videos_quarantined: int = 0
    segments_processed: int = 0
    avg_segs_per_second: float = 0.0
    shards_written: int = 0
    total_parquet_mb: float = 0.0
    last_video_completed_at: Optional[datetime] = None
    last_error: Optional[str] = None
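    # Rolling window of the last 10 per-video speeds; avg_segs_per_second is
    # recomputed from it after every video.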
    _recent_speeds: list[float] = field(default_factory=list)


class RecoverValidationWorker:
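    """One-video-at-a-time recovery validation worker.

    Claims videos from the recover queue, replays `audio_polish` against the
    raw tar, validates the historical transcription rows, and packs results
    into parquet shards via ParquetPacker.
    """
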
    def __init__(self, config: ValidationConfig):
        self.config = config
        self.pipeline = ValidationPipeline(config)
        self.packer: Optional[ParquetPacker] = None
        self.stats = RecoverStats()
        self._db = None
        self._shutdown_event = asyncio.Event()
        self._heartbeat_task: Optional[asyncio.Task] = None
        self._work_dir = Path(tempfile.mkdtemp(prefix="recover_validation_"))
        self._start_time = 0.0
        self._raw_r2 = R2Client(EnvConfig())

    async def start(self):
        self._start_time = time.time()
        loop = asyncio.get_running_loop()
        for sig in (signal.SIGTERM, signal.SIGINT):
            loop.add_signal_handler(sig, lambda s=sig: asyncio.create_task(self._handle_shutdown(s)))

        try:
            await self._connect_db()
            await self._recover_dead_videos()
            await self._register()

            logger.info("Loading validation models for recover worker...")
            self.pipeline.load_models()

            output_dir = self._work_dir / "recover_parquet_shards"
            self.packer = ParquetPacker(self.config, output_dir)

            self._heartbeat_task = asyncio.create_task(self._heartbeat_loop())
            await self._main_loop()
        except Exception as e:
            logger.error(f"Recover worker fatal error: {e}", exc_info=True)
            self.stats.last_error = str(e)[:500]
            await self._update_worker_status("error", str(e)[:500])
        finally:
            await self._cleanup()

    async def _main_loop(self):
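        """Claim and process videos until shutdown, MAX_VIDEOS, or an empty queue.

        Exits after `max_empty` consecutive empty claims so idle fleet nodes
        terminate instead of polling forever.
        """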
        consecutive_empty = 0
        max_empty = 5

        max_videos = self.config.max_videos
        if max_videos > 0:
            logger.info(f"RECOVER MAX_VIDEOS={max_videos} — will stop after {max_videos} video(s)")

        while not self._shutdown_event.is_set():
            if max_videos > 0 and self.stats.videos_processed >= max_videos:
                logger.info(f"Reached RECOVER MAX_VIDEOS={max_videos}, shutting down.")
                break

            claimed = await self._claim_video()
            if not claimed:
                consecutive_empty += 1
                if consecutive_empty >= max_empty:
                    logger.info(f"{max_empty} consecutive empty recover claims, exiting")
                    break
                wait = min(2.0 * consecutive_empty, 30.0)
                logger.info(f"No pending recover videos, waiting {wait:.0f}s...")
                try:
                    await asyncio.wait_for(self._shutdown_event.wait(), timeout=wait)
                    break
                except asyncio.TimeoutError:
                    continue

            consecutive_empty = 0
            video_id = claimed["video_id"]
            self.stats.current_video_id = video_id

            try:
                await self._process_one_video(video_id)
                self.stats.videos_processed += 1
                self.stats.last_video_completed_at = datetime.now(timezone.utc)
                try:
                    await self._send_heartbeat()
                except Exception as e:
                    logger.warning(f"Post-video heartbeat failed: {e}")
            except RecoverQuarantineError as e:
                self.stats.videos_quarantined += 1
                self.stats.last_error = str(e)[:500]
                logger.warning(f"[{video_id}] Quarantined: {e}")
                await self._mark_video_quarantined(video_id, str(e))
            except Exception as e:
                self.stats.videos_failed += 1
                self.stats.last_error = str(e)[:500]
                logger.error(f"[{video_id}] Failed: {e}", exc_info=True)
                await self._mark_video_failed(video_id, str(e))
            finally:
                self.stats.current_video_id = None

        logger.info(
            f"Recover main loop ended: {self.stats.videos_processed} processed, "
            f"{self.stats.videos_quarantined} quarantined, {self.stats.videos_failed} failed"
        )

    async def _process_one_video(self, video_id: str):
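        """Recover a single video end to end.

        Steps: fetch historical tx rows, download and extract the raw tar,
        deterministically replay `audio_polish`, quarantine on any mismatch,
        validate the replayed segments, and record per-phase timings.
        """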
        t0 = time.time()
        video_work = self._work_dir / video_id
        video_work.mkdir(exist_ok=True)
        loop = asyncio.get_running_loop()

        try:
            tx_t0 = time.time()
            tx_rows = await self._fetch_tx_rows(video_id)
            if not tx_rows:
                raise RuntimeError("No transcription rows found for recover video")
            tx_elapsed = time.time() - tx_t0

            raw_t0 = time.time()
            tar_path = await loop.run_in_executor(None, self._raw_r2.download_tar, video_id, video_work)
            extracted = await loop.run_in_executor(None, self._raw_r2.extract_tar, tar_path, video_id)
            raw_elapsed = time.time() - raw_t0

            replay_t0 = time.time()
            recover = await loop.run_in_executor(
                None,
                lambda: load_recover_segments(extracted.work_dir, video_id, tx_rows),
            )
            replay_elapsed = time.time() - replay_t0

            if recover.missing_tx_ids:
                raise RecoverQuarantineError(
                    f"missing historical IDs after replay: {len(recover.missing_tx_ids)}"
                )
            if recover.missing_parent_files:
                raise RecoverQuarantineError(
                    f"missing raw parent files: {len(recover.missing_parent_files)}"
                )
            if not recover.segments:
                raise RuntimeError("Replay produced zero validation segments")

            flag_summary = await self._fetch_flag_summary(recover.extra_regen_ids)

            val_t0 = time.time()
            # Run the blocking GPU validation in an executor, matching the
            # download/extract/replay steps above, so heartbeats and signal
            # handling stay responsive during long videos.
            results = await loop.run_in_executor(
                None, lambda: self.pipeline.process_video(video_id, recover.segments)
            )
            val_elapsed = time.time() - val_t0

            self.packer.add_video_results(video_id, results)
            packer_stats = self.packer.stats
            self.stats.segments_processed += len(results)
            self.stats.shards_written = packer_stats["shards_written"]
            self.stats.total_parquet_mb = packer_stats.get("total_mb", 0.0)

            elapsed = time.time() - t0
            segs_per_sec = len(results) / elapsed if elapsed > 0 else 0.0
            self.stats._recent_speeds.append(segs_per_sec)
            self.stats._recent_speeds = self.stats._recent_speeds[-10:]
            self.stats.avg_segs_per_second = sum(self.stats._recent_speeds) / len(self.stats._recent_speeds)

            await self._mark_video_done(
                video_id,
                validated_segments=len(results),
                replayed_segments=len(recover.segments),
                extras_count=len(recover.extra_regen_ids),
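                # Zero by construction: nonzero missing counts raise
                # RecoverQuarantineError before reaching this point.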
                missing_tx_segments=0,
                missing_parent_files=0,
                extra_timeout_segments=flag_summary["timeout"],
                extra_error_segments=flag_summary["error"],
                extra_flagged_segments=flag_summary["flagged_total"],
                extra_unflagged_segments=max(len(recover.extra_regen_ids) - flag_summary["flagged_total"], 0),
            )

            logger.info(
                f"[{video_id}] Recover complete: tx={len(tx_rows)}, validated={len(results)}, "
                f"extras={len(recover.extra_regen_ids)} (flagged={flag_summary['flagged_total']}, "
                f"unflagged={max(len(recover.extra_regen_ids) - flag_summary['flagged_total'], 0)}), "
                f"timings tx={tx_elapsed:.1f}s raw={raw_elapsed:.1f}s replay={replay_elapsed:.1f}s "
                f"validate={val_elapsed:.1f}s total={elapsed:.1f}s"
            )
        finally:
            shutil.rmtree(video_work, ignore_errors=True)

    async def _connect_db(self):
        if not self.config.database_url:
            raise RuntimeError("DATABASE_URL required for recover worker")
        import asyncpg

        is_pooler = ":6543" in self.config.database_url
        self._db = await asyncpg.create_pool(
            dsn=self.config.database_url,
            min_size=1,
            max_size=4,
            command_timeout=60,
            statement_cache_size=0 if is_pooler else 100,
            ssl="require",
        )
        logger.info(f"Recover DB pool connected (pooler={is_pooler})")

    async def _recover_dead_videos(self):
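        """Requeue videos stuck in 'recovering' for over 15 minutes.

        Covers workers that died mid-video without releasing their claim.
        """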
        async with self._db.acquire() as conn:
            await conn.execute(f"""
                UPDATE {self.config.recover_queue_table}
                SET status = 'pending', claimed_by = NULL, claimed_at = NULL, updated_at = now()
                WHERE status = 'recovering'
                  AND claimed_at < now() - interval '15 minutes'
            """)

    async def _register(self):
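        """Upsert this worker's row in worker_validators, resetting all
        counters and clearing the last error for a fresh run."""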
        config_json = {
            "mode": "recover",
            "models": {
                "mms_lid": self.config.enable_mms_lid,
                "voxlingua": self.config.enable_voxlingua,
                "conformer_multi": self.config.enable_conformer_multi,
                "wav2vec_lang": self.config.enable_wav2vec_lang,
            },
            "model_bucket": self.config.r2_model_bucket,
            "output_bucket": self.config.r2_bucket_output,
            "recover_queue_table": self.config.recover_queue_table,
        }
        async with self._db.acquire() as conn:
            await conn.execute("""
                INSERT INTO worker_validators (
                    worker_id, status, gpu_type, config_json,
                    started_at, last_heartbeat_at,
                    videos_processed, videos_failed, segments_processed,
                    shards_written, shards_uploaded, total_parquet_mb
                ) VALUES ($1, 'online', $2, $3::jsonb, now(), now(), 0, 0, 0, 0, 0, 0)
                ON CONFLICT (worker_id) DO UPDATE SET
                    status = 'online',
                    gpu_type = $2,
                    config_json = $3::jsonb,
                    started_at = now(),
                    last_heartbeat_at = now(),
                    videos_processed = 0,
                    videos_failed = 0,
                    segments_processed = 0,
                    shards_written = 0,
                    shards_uploaded = 0,
                    total_parquet_mb = 0,
                    last_error = NULL,
                    current_video_id = NULL
            """, self.config.worker_id, self.config.gpu_type, json.dumps(config_json))

    async def _claim_video(self) -> Optional[dict]:
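        """Atomically claim one pending video, largest tx_segments first.

        FOR UPDATE SKIP LOCKED ensures concurrent workers never race for the
        same row: each worker gets an unclaimed video or None.
        """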
        async with self._db.acquire() as conn:
            row = await conn.fetchrow(f"""
                UPDATE {self.config.recover_queue_table}
                SET status = 'recovering',
                    claimed_by = $1,
                    claimed_at = now(),
                    updated_at = now()
                WHERE video_id = (
                    SELECT video_id
                    FROM {self.config.recover_queue_table}
                    WHERE status = 'pending'
                    ORDER BY tx_segments DESC, video_id
                    LIMIT 1
                    FOR UPDATE SKIP LOCKED
                )
                RETURNING video_id, tx_segments
            """, self.config.worker_id)
        return dict(row) if row else None

    async def _fetch_tx_rows(self, video_id: str) -> list[dict]:
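        """Fetch all historical transcription rows for the video, ordered by
        segment_file."""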
        async with self._db.acquire() as conn:
            rows = await conn.fetch("""
                SELECT
                    segment_file,
                    expected_language_hint,
                    detected_language,
                    transcription,
                    tagged,
                    quality_score,
                    speaker_emotion,
                    speaker_style,
                    speaker_pace,
                    speaker_accent
                FROM transcription_results
                WHERE video_id = $1
                ORDER BY segment_file
            """, video_id)
        return [dict(row) for row in rows]

    async def _fetch_flag_summary(self, segment_ids: list[str]) -> dict:
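        """Summarize transcription_flags for replay-only extra segments.

        `flagged_total` counts distinct flagged segments; per-type counts can
        overlap when one segment carries multiple flag types.
        """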
        if not segment_ids:
            return {"timeout": 0, "error": 0, "rate_limited": 0, "flagged_total": 0}

        async with self._db.acquire() as conn:
            flagged_total = await conn.fetchval("""
                SELECT count(DISTINCT segment_id)
                FROM transcription_flags
                WHERE segment_id = ANY($1::text[])
            """, segment_ids)
            rows = await conn.fetch("""
                SELECT flag_type, count(DISTINCT segment_id) AS cnt
                FROM transcription_flags
                WHERE segment_id = ANY($1::text[])
                GROUP BY flag_type
            """, segment_ids)

        summary = {"timeout": 0, "error": 0, "rate_limited": 0, "flagged_total": int(flagged_total or 0)}
        for row in rows:
            flag_type = row["flag_type"]
            if flag_type in summary:
                summary[flag_type] = int(row["cnt"])
        return summary

    async def _mark_video_done(
        self,
        video_id: str,
        *,
        validated_segments: int,
        replayed_segments: int,
        extras_count: int,
        missing_tx_segments: int,
        missing_parent_files: int,
        extra_timeout_segments: int,
        extra_error_segments: int,
        extra_flagged_segments: int,
        extra_unflagged_segments: int,
    ):
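        """Mark the video 'recovered' and persist the full recovery breakdown."""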
        async with self._db.acquire() as conn:
            await conn.execute(f"""
                UPDATE {self.config.recover_queue_table}
                SET status = 'recovered',
                    recovered_segments = $2,
                    replayed_segments = $3,
                    extra_regen_segments = $4,
                    missing_tx_segments = $5,
                    missing_parent_files = $6,
                    extra_timeout_segments = $7,
                    extra_error_segments = $8,
                    extra_flagged_segments = $9,
                    extra_unflagged_segments = $10,
                    completed_at = now(),
                    updated_at = now(),
                    error_message = NULL
                WHERE video_id = $1
            """, video_id, validated_segments, replayed_segments, extras_count, missing_tx_segments,
                 missing_parent_files, extra_timeout_segments, extra_error_segments,
                 extra_flagged_segments, extra_unflagged_segments)

    async def _mark_video_quarantined(self, video_id: str, error: str):
        async with self._db.acquire() as conn:
            await conn.execute(f"""
                UPDATE {self.config.recover_queue_table}
                SET status = 'quarantined',
                    error_message = $2,
                    updated_at = now()
                WHERE video_id = $1
            """, video_id, error[:500])

    async def _mark_video_failed(self, video_id: str, error: str):
        async with self._db.acquire() as conn:
            await conn.execute(f"""
                UPDATE {self.config.recover_queue_table}
                SET status = 'failed',
                    error_message = $2,
                    updated_at = now()
                WHERE video_id = $1
            """, video_id, error[:500])

    async def _heartbeat_loop(self):
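        """Send heartbeats every HEARTBEAT_INTERVAL_S until shutdown is signalled."""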
        while not self._shutdown_event.is_set():
            try:
                await self._send_heartbeat()
            except Exception as e:
                logger.warning(f"Recover heartbeat failed: {e}")
            try:
                await asyncio.wait_for(self._shutdown_event.wait(), timeout=HEARTBEAT_INTERVAL_S)
                break
            except asyncio.TimeoutError:
                pass

    async def _send_heartbeat(self):
        async with self._db.acquire() as conn:
            await conn.execute("""
                UPDATE worker_validators SET
                    current_video_id = $2,
                    videos_processed = $3,
                    videos_failed = $4,
                    segments_processed = $5,
                    avg_segs_per_second = $6,
                    shards_written = $7,
                    total_parquet_mb = $8,
                    last_heartbeat_at = now(),
                    last_video_completed_at = $9,
                    last_error = $10
                WHERE worker_id = $1
            """,
                self.config.worker_id,
                self.stats.current_video_id,
                self.stats.videos_processed,
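                # worker_validators has no quarantined column, so quarantined
                # videos are folded into videos_failed here.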
                self.stats.videos_failed + self.stats.videos_quarantined,
                self.stats.segments_processed,
                round(self.stats.avg_segs_per_second, 2),
                self.stats.shards_written,
                round(self.stats.total_parquet_mb, 2),
                self.stats.last_video_completed_at,
                self.stats.last_error,
            )

    async def _update_worker_status(self, status: str, error: str = ""):
        if not self._db:
            return
        async with self._db.acquire() as conn:
            await conn.execute("""
                UPDATE worker_validators
                SET status = $2, last_error = $3, last_heartbeat_at = now()
                WHERE worker_id = $1
            """, self.config.worker_id, status, error[:500] if error else None)

    async def _handle_shutdown(self, sig):
        logger.info(f"Received {sig.name}, initiating recover worker shutdown...")
        self._shutdown_event.set()

    async def _cleanup(self):
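        """Best-effort teardown: stop heartbeats, flush shards, unload models,
        report final stats, close the DB pool, and remove the work dir."""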
        if self._heartbeat_task:
            self._heartbeat_task.cancel()
            try:
                await self._heartbeat_task
            except asyncio.CancelledError:
                pass

        if self.packer:
            self.packer.flush()
            packer_stats = self.packer.stats
            self.stats.shards_written = packer_stats["shards_written"]
            self.stats.total_parquet_mb = packer_stats.get("total_mb", 0.0)

        try:
            self.pipeline.unload_models()
        except Exception:
            pass

        try:
            await self._send_heartbeat()
        except Exception:
            pass
        await self._update_worker_status("offline")

        if self._db:
            await self._db.close()

        shutil.rmtree(self._work_dir, ignore_errors=True)
        logger.info(
            f"Recover worker shutdown complete. processed={self.stats.videos_processed} "
            f"quarantined={self.stats.videos_quarantined} failed={self.stats.videos_failed} "
            f"segments={self.stats.segments_processed}"
        )


