"""
Recovery validation worker.

Workflow per video:
  - claim from `validation_recover_queue`
  - download raw tar from `1-cleaned-data`
  - replay `audio_polish` deterministically
  - match replayed child IDs to historical `transcription_results`
  - validate all historical transcription rows
  - flush+upload shard BEFORE marking video recovered (crash-safe)
  - prefetch next video's raw tar + replay during GPU inference
"""
from __future__ import annotations

import asyncio
import json
import logging
import shutil
import signal
import tempfile
import time
from dataclasses import dataclass, field
from datetime import datetime, timezone
from pathlib import Path
from typing import Optional

from src.config import EnvConfig
from src.r2_client import R2Client

from .config import HEARTBEAT_INTERVAL_S, ValidationConfig
from .packer import ParquetPacker
from .pipeline import ValidationPipeline
from .recover_loader import RecoveryLoadResult, load_recover_segments
from .recover_replay_ledger import ReplayLedgerWriter, build_replay_ledger_payload
from .recover_reference_store import RecoverReferenceStore

logger = logging.getLogger(__name__)


@dataclass
class RecoverStats:
    current_video_id: Optional[str] = None
    videos_processed: int = 0
    videos_failed: int = 0
    videos_quarantined: int = 0
    segments_processed: int = 0
    avg_segs_per_second: float = 0.0
    shards_written: int = 0
    total_parquet_mb: float = 0.0
    last_video_completed_at: Optional[datetime] = None
    last_error: Optional[str] = None
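    # Rolling window of the last 10 per-video speeds, used to smooth
    # avg_segs_per_second.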
    _recent_speeds: list[float] = field(default_factory=list)


@dataclass
class PrefetchedVideo:
    video_id: str
    tx_rows: list[dict]
    validated_segment_ids: set[str]
    recover: RecoveryLoadResult
    flag_summary: dict
    work_dir: Path
    prefetch_elapsed: float


class RecoverValidationWorker:
    def __init__(self, config: ValidationConfig):
        self.config = config
        self.pipeline = ValidationPipeline(config)
        self.packer: Optional[ParquetPacker] = None
        self.stats = RecoverStats()
        self._db = None
        self._shutdown_event = asyncio.Event()
        self._heartbeat_task: Optional[asyncio.Task] = None
        self._work_dir = Path(tempfile.mkdtemp(prefix="recover_validation_"))
        self._start_time = 0.0
        self._raw_r2 = R2Client(EnvConfig())
        self._reference_store = RecoverReferenceStore(config, self._work_dir)
        self._replay_ledger_writer = ReplayLedgerWriter(config)

    async def start(self):
        self._start_time = time.time()
        loop = asyncio.get_running_loop()
        for sig in (signal.SIGTERM, signal.SIGINT):
            loop.add_signal_handler(sig, lambda s=sig: asyncio.create_task(self._handle_shutdown(s)))

        try:
            await self._connect_db()
            await self._recover_dead_videos()
            await self._reference_store.start()
            await self._register()

            logger.info("Loading validation models for recover worker...")
            self.pipeline.load_models()

            output_dir = self._work_dir / "recover_parquet_shards"
            self.packer = ParquetPacker(self.config, output_dir)

            self._heartbeat_task = asyncio.create_task(self._heartbeat_loop())
            await self._main_loop()
        except Exception as e:
            logger.error(f"Recover worker fatal error: {e}", exc_info=True)
            self.stats.last_error = str(e)[:500]
            await self._update_worker_status("error", str(e)[:500])
        finally:
            await self._cleanup()

    async def _main_loop(self):
        consecutive_empty = 0
        max_empty = 5
        consecutive_gpu_failures = 0
        max_gpu_failures = 3

        max_videos = self.config.max_videos
        if max_videos > 0:
            logger.info(f"RECOVER MAX_VIDEOS={max_videos} — will stop after {max_videos} video(s)")

        prefetch_task: Optional[asyncio.Task] = None
        prefetched: Optional[PrefetchedVideo] = None

        while not self._shutdown_event.is_set():
            if max_videos > 0 and self.stats.videos_processed >= max_videos:
                logger.info(f"Reached RECOVER MAX_VIDEOS={max_videos}, shutting down.")
                break

            if prefetched:
                video_id = prefetched.video_id
                logger.info(f"Using prefetched video: {video_id}")
            else:
                claimed = await self._claim_video()
                if not claimed:
                    consecutive_empty += 1
                    if consecutive_empty >= max_empty:
                        logger.info(f"{max_empty} consecutive empty recover claims, exiting")
                        break
                    wait = min(2.0 * consecutive_empty, 30.0)
                    logger.info(f"No pending recover videos, waiting {wait:.0f}s...")
                    try:
                        await asyncio.wait_for(self._shutdown_event.wait(), timeout=wait)
                        break
                    except asyncio.TimeoutError:
                        continue
                video_id = claimed["video_id"]

            consecutive_empty = 0
            self.stats.current_video_id = video_id

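            # Decide whether to prefetch: with a video cap, only if more videos
            # remain after this one; with no cap (max_videos == 0), always.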
            remaining = (max_videos - self.stats.videos_processed - 1) if max_videos > 0 else 1
            if remaining > 0 and not self._shutdown_event.is_set():
                prefetch_task = asyncio.create_task(self._prefetch_next())

            try:
                await self._process_one_video(video_id, prefetched)
                self.stats.videos_processed += 1
                self.stats.last_video_completed_at = datetime.now(timezone.utc)
                consecutive_gpu_failures = 0
                try:
                    await self._send_heartbeat()
                except Exception as e:
                    logger.warning(f"Post-video heartbeat failed: {e}")
            except RecoverQuarantineError as e:
                self.stats.videos_quarantined += 1
                self.stats.last_error = str(e)[:500]
                logger.warning(f"[{video_id}] Quarantined: {e}")
                await self._mark_video_quarantined(video_id, str(e))
            except Exception as e:
                self.stats.videos_failed += 1
                self.stats.last_error = str(e)[:500]
                logger.error(f"[{video_id}] Failed: {e}", exc_info=True)
                await self._mark_video_failed(video_id, str(e))
                err_str = str(e)
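                # Heuristic match on common fatal CUDA error strings; repeated
                # hits usually mean the device context is corrupted.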
                is_gpu_fatal = "CUDA" in err_str or "device-side" in err_str or "cuFFT" in err_str
                if is_gpu_fatal:
                    import torch
                    torch.cuda.empty_cache()
                    consecutive_gpu_failures += 1
                    logger.warning(
                        f"GPU error ({consecutive_gpu_failures}/{max_gpu_failures}), "
                        f"cleared CUDA cache"
                    )
                    if consecutive_gpu_failures >= max_gpu_failures:
                        logger.error(
                            f"{consecutive_gpu_failures} consecutive GPU failures "
                            "(GPU likely poisoned), exiting"
                        )
                        break
                else:
                    consecutive_gpu_failures = 0
            finally:
                self.stats.current_video_id = None
                prefetched = None

            if prefetch_task:
                try:
                    # wait_for cancels a still-running prefetch on timeout; the
                    # stale 'recovering' claim is later reclaimed by the
                    # dead-video sweep.
                    prefetched = await asyncio.wait_for(prefetch_task, timeout=30)
                except Exception:
                    prefetched = None
                prefetch_task = None

        logger.info(
            f"Recover main loop ended: {self.stats.videos_processed} processed, "
            f"{self.stats.videos_quarantined} quarantined, {self.stats.videos_failed} failed"
        )

    async def _prefetch_next(self) -> Optional[PrefetchedVideo]:
        """Claim + download + replay the next video while current one runs on GPU."""
        video_id = None
        try:
            claimed = await self._claim_video()
            if not claimed:
                return None

            video_id = claimed["video_id"]
            t0 = time.time()
            video_work = self._work_dir / video_id
            video_work.mkdir(exist_ok=True)
            loop = asyncio.get_running_loop()

            tx_rows = await self._fetch_tx_rows(video_id)
            if not tx_rows:
                await self._mark_video_failed(video_id, "No transcription rows found")
                shutil.rmtree(video_work, ignore_errors=True)
                return None

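            # Only replay what a previous run hasn't validated; when nothing is
            # validated yet the filter collapses to None (replay everything).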
            validated_ids = await self._fetch_validated_segment_ids(video_id)
            tx_ids = {row["segment_file"] for row in tx_rows}
            target_ids = tx_ids - validated_ids

            tar_path = await loop.run_in_executor(None, self._raw_r2.download_tar, video_id, video_work)
            extracted = await loop.run_in_executor(None, self._raw_r2.extract_tar, tar_path, video_id)

            recover = await loop.run_in_executor(
                None,
                lambda: load_recover_segments(
                    extracted.work_dir, video_id, tx_rows,
                    target_segment_ids=target_ids if target_ids != tx_ids else None,
                    replay_all_tx_parents=True,
                ),
            )
            flag_summary = await self._fetch_flag_summary(recover.extra_regen_ids)

            elapsed = time.time() - t0
            logger.info(
                f"[prefetch] {video_id}: {len(recover.segments)} to validate "
                f"({len(validated_ids)} already validated, {len(recover.extra_regen_ids)} extras) "
                f"ready in {elapsed:.1f}s"
            )
            return PrefetchedVideo(
                video_id=video_id,
                tx_rows=tx_rows,
                validated_segment_ids=validated_ids,
                recover=recover,
                flag_summary=flag_summary,
                work_dir=video_work,
                prefetch_elapsed=elapsed,
            )
        except Exception as e:
            logger.warning(f"[prefetch] Failed: {e}")
            if video_id:
                await self._mark_video_failed(video_id, f"prefetch error: {str(e)[:200]}")
                shutil.rmtree(self._work_dir / video_id, ignore_errors=True)
            return None

    async def _process_one_video(self, video_id: str, prefetched: Optional[PrefetchedVideo] = None):
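        """Validate one video end-to-end, reusing prefetched artifacts when present."""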
        t0 = time.time()
        loop = asyncio.get_running_loop()

        if prefetched and prefetched.video_id == video_id:
            tx_rows = prefetched.tx_rows
            validated_ids = prefetched.validated_segment_ids
            recover = prefetched.recover
            flag_summary = prefetched.flag_summary
            video_work = prefetched.work_dir
            tx_elapsed = 0.0
            raw_elapsed = 0.0
            replay_elapsed = prefetched.prefetch_elapsed
        else:
            video_work = self._work_dir / video_id
            video_work.mkdir(exist_ok=True)

            tx_t0 = time.time()
            tx_rows = await self._fetch_tx_rows(video_id)
            if not tx_rows:
                raise RuntimeError("No transcription rows found for recover video")
            tx_elapsed = time.time() - tx_t0

            raw_t0 = time.time()
            tar_path = await loop.run_in_executor(None, self._raw_r2.download_tar, video_id, video_work)
            extracted = await loop.run_in_executor(None, self._raw_r2.extract_tar, tar_path, video_id)
            raw_elapsed = time.time() - raw_t0

            validated_ids = await self._fetch_validated_segment_ids(video_id)
            tx_ids = {row["segment_file"] for row in tx_rows}
            target_ids = tx_ids - validated_ids

            replay_t0 = time.time()
            recover = await loop.run_in_executor(
                None,
                lambda: load_recover_segments(
                    extracted.work_dir, video_id, tx_rows,
                    target_segment_ids=target_ids if target_ids != tx_ids else None,
                    replay_all_tx_parents=True,
                ),
            )
            replay_elapsed = time.time() - replay_t0
            flag_summary = await self._fetch_flag_summary(recover.extra_regen_ids)

        try:
            if recover.missing_tx_ids:
                raise RecoverQuarantineError(
                    f"missing historical IDs after replay: {len(recover.missing_tx_ids)}"
                )
            if recover.missing_parent_files:
                raise RecoverQuarantineError(
                    f"missing raw parent files: {len(recover.missing_parent_files)}"
                )

            if recover.segments:
                val_t0 = time.time()
                # Run GPU inference in a worker thread so the event loop stays
                # free for the prefetch task and heartbeats.
                results = await loop.run_in_executor(
                    None, self.pipeline.process_video, video_id, recover.segments
                )
                val_elapsed = time.time() - val_t0

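                # Crash-safety: flush the shard before the video is marked
                # 'recovered', so a crash can only repeat work, never lose it.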
                self.packer.add_video_results(video_id, results)
                self.packer.flush()

                packer_stats = self.packer.stats
                self.stats.segments_processed += len(results)
                self.stats.shards_written = packer_stats["shards_written"]
                self.stats.total_parquet_mb = packer_stats.get("total_mb", 0.0)
            else:
                results = []
                val_elapsed = 0.0
                logger.info(f"[{video_id}] All segments already validated, replay-only for extras")

            elapsed = time.time() - t0
            segs_per_sec = len(results) / elapsed if elapsed > 0 else 0.0
            self.stats._recent_speeds.append(segs_per_sec)
            self.stats._recent_speeds = self.stats._recent_speeds[-10:]
            self.stats.avg_segs_per_second = sum(self.stats._recent_speeds) / len(self.stats._recent_speeds)

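            # Cap the stored ID list at 5000 entries to bound the row size.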
            extras_json = json.dumps(sorted(recover.extra_regen_ids)[:5000]) if recover.extra_regen_ids else None
            ledger_payload = build_replay_ledger_payload(
                video_id=video_id,
                tx_rows=tx_rows,
                replayed_segment_ids=recover.replayed_regen_ids,
                matched_tx_ids=recover.matched_tx_ids,
                validated_segment_ids=validated_ids,
                extra_regen_ids=recover.extra_regen_ids,
                flag_summary=flag_summary,
                worker_id=self.config.worker_id,
                missing_tx_ids=recover.missing_tx_ids,
                missing_parent_files=recover.missing_parent_files,
            )
            ledger_artifact = self._replay_ledger_writer.upload(video_id, ledger_payload)

            await self._mark_video_done(
                video_id,
                validated_segments=len(results),
                replayed_segments=len(recover.replayed_regen_ids),
                extras_count=len(recover.extra_regen_ids),
                missing_tx_segments=0,
                missing_parent_files=0,
                extra_timeout_segments=flag_summary["timeout"],
                extra_error_segments=flag_summary["error"],
                extra_flagged_segments=flag_summary["flagged_total"],
                extra_unflagged_segments=max(len(recover.extra_regen_ids) - flag_summary["flagged_total"], 0),
                extra_regen_ids_json=extras_json,
            )

            logger.info(
                f"[{video_id}] Recover complete: tx={len(tx_rows)}, validated={len(results)}, "
                f"extras={len(recover.extra_regen_ids)} (flagged={flag_summary['flagged_total']}, "
                f"unflagged={max(len(recover.extra_regen_ids) - flag_summary['flagged_total'], 0)}), "
                f"ledger={ledger_artifact.key}, "
                f"timings tx={tx_elapsed:.1f}s raw={raw_elapsed:.1f}s replay={replay_elapsed:.1f}s "
                f"validate={val_elapsed:.1f}s total={elapsed:.1f}s"
            )
        finally:
            shutil.rmtree(video_work, ignore_errors=True)

    # ── Database Operations ────────────────────────────────────────────

    async def _connect_db(self):
        if not self.config.database_url:
            raise RuntimeError("DATABASE_URL required for recover worker")
        import asyncpg

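        # Heuristic: port 6543 is the conventional transaction-pooler port
        # (e.g., Supabase's pgbouncer); prepared-statement caching must be
        # disabled behind such a pooler.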
        is_pooler = ":6543" in self.config.database_url
        self._db = await asyncpg.create_pool(
            dsn=self.config.database_url,
            min_size=1,
            max_size=4,
            command_timeout=60,
            statement_cache_size=0 if is_pooler else 100,
            ssl="require",
        )
        logger.info(f"Recover DB pool connected (pooler={is_pooler})")

    async def _recover_dead_videos(self):
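        """Requeue videos left in 'recovering' by a dead worker (>15 min stale)."""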
        async with self._db.acquire() as conn:
            await conn.execute(f"""
                UPDATE {self.config.recover_queue_table}
                SET status = 'pending', claimed_by = NULL, claimed_at = NULL, updated_at = now()
                WHERE status = 'recovering'
                  AND claimed_at < now() - interval '15 minutes'
            """)

    async def _register(self):
        config_json = {
            "mode": "recover",
            "models": {
                "mms_lid": self.config.enable_mms_lid,
                "voxlingua": self.config.enable_voxlingua,
                "conformer_multi": self.config.enable_conformer_multi,
                "wav2vec_lang": self.config.enable_wav2vec_lang,
            },
            "reference_mode": self.config.recover_reference_mode,
            "reference_bucket": self.config.r2_reference_bucket,
            "tx_parquet_key": self.config.recover_tx_parquet_key,
            "flags_parquet_key": self.config.recover_flags_parquet_key,
            "replay_ledger_prefix": self.config.recover_replay_ledger_prefix,
            "model_bucket": self.config.r2_model_bucket,
            "output_bucket": self.config.r2_bucket_output,
            "recover_queue_table": self.config.recover_queue_table,
        }
        async with self._db.acquire() as conn:
            await conn.execute("""
                INSERT INTO worker_validators (
                    worker_id, status, gpu_type, config_json,
                    started_at, last_heartbeat_at,
                    videos_processed, videos_failed, segments_processed,
                    shards_written, shards_uploaded, total_parquet_mb
                ) VALUES ($1, 'online', $2, $3::jsonb, now(), now(), 0, 0, 0, 0, 0, 0)
                ON CONFLICT (worker_id) DO UPDATE SET
                    status = 'online',
                    gpu_type = $2,
                    config_json = $3::jsonb,
                    started_at = now(),
                    last_heartbeat_at = now(),
                    videos_processed = 0,
                    videos_failed = 0,
                    segments_processed = 0,
                    shards_written = 0,
                    shards_uploaded = 0,
                    total_parquet_mb = 0,
                    last_error = NULL,
                    current_video_id = NULL
            """, self.config.worker_id, self.config.gpu_type, json.dumps(config_json))

    async def _claim_video(self) -> Optional[dict]:
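        """Atomically claim one pending video, largest (tx_segments) first.

        FOR UPDATE SKIP LOCKED lets concurrent workers claim distinct rows
        without blocking one another.
        """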
        async with self._db.acquire() as conn:
            row = await conn.fetchrow(f"""
                UPDATE {self.config.recover_queue_table}
                SET status = 'recovering',
                    claimed_by = $1,
                    claimed_at = now(),
                    updated_at = now()
                WHERE video_id = (
                    SELECT video_id
                    FROM {self.config.recover_queue_table}
                    WHERE status = 'pending'
                    ORDER BY tx_segments DESC, video_id
                    LIMIT 1
                    FOR UPDATE SKIP LOCKED
                )
                RETURNING video_id, tx_segments
            """, self.config.worker_id)
        return dict(row) if row else None

    async def _fetch_validated_segment_ids(self, video_id: str) -> set[str]:
        if self._reference_store.enabled:
            return await self._reference_store.fetch_validated_segment_ids(video_id)
        return set()

    async def _fetch_tx_rows(self, video_id: str) -> list[dict]:
        if self._reference_store.enabled:
            return await self._reference_store.fetch_tx_rows(video_id)
        async with self._db.acquire() as conn:
            rows = await conn.fetch("""
                SELECT
                    segment_file,
                    expected_language_hint,
                    detected_language,
                    transcription,
                    tagged,
                    quality_score,
                    speaker_emotion,
                    speaker_style,
                    speaker_pace,
                    speaker_accent
                FROM transcription_results
                WHERE video_id = $1
                ORDER BY segment_file
            """, video_id)
        return [dict(row) for row in rows]

    async def _fetch_flag_summary(self, segment_ids: list[str]) -> dict:
        if not segment_ids:
            return {"timeout": 0, "error": 0, "rate_limited": 0, "flagged_total": 0}
        if self._reference_store.enabled:
            return await self._reference_store.fetch_flag_summary(segment_ids)

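        # A segment may carry several flag types, so flagged_total is a
        # DISTINCT count rather than the sum of the per-type counts below.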
        async with self._db.acquire() as conn:
            flagged_total = await conn.fetchval("""
                SELECT count(DISTINCT segment_id)
                FROM transcription_flags
                WHERE segment_id = ANY($1::text[])
            """, segment_ids)
            rows = await conn.fetch("""
                SELECT flag_type, count(DISTINCT segment_id) AS cnt
                FROM transcription_flags
                WHERE segment_id = ANY($1::text[])
                GROUP BY flag_type
            """, segment_ids)

        summary = {"timeout": 0, "error": 0, "rate_limited": 0, "flagged_total": int(flagged_total or 0)}
        for row in rows:
            flag_type = row["flag_type"]
            if flag_type in summary:
                summary[flag_type] = int(row["cnt"])
        return summary

    async def _mark_video_done(
        self,
        video_id: str,
        *,
        validated_segments: int,
        replayed_segments: int,
        extras_count: int,
        missing_tx_segments: int,
        missing_parent_files: int,
        extra_timeout_segments: int,
        extra_error_segments: int,
        extra_flagged_segments: int,
        extra_unflagged_segments: int,
        extra_regen_ids_json: Optional[str] = None,
    ):
        async with self._db.acquire() as conn:
            await conn.execute(f"""
                UPDATE {self.config.recover_queue_table}
                SET status = 'recovered',
                    recovered_segments = $2,
                    replayed_segments = $3,
                    extra_regen_segments = $4,
                    missing_tx_segments = $5,
                    missing_parent_files = $6,
                    extra_timeout_segments = $7,
                    extra_error_segments = $8,
                    extra_flagged_segments = $9,
                    extra_unflagged_segments = $10,
                    extra_regen_ids_json = $11,
                    completed_at = now(),
                    updated_at = now(),
                    error_message = NULL
                WHERE video_id = $1
            """, video_id, validated_segments, replayed_segments, extras_count, missing_tx_segments,
                 missing_parent_files, extra_timeout_segments, extra_error_segments,
                 extra_flagged_segments, extra_unflagged_segments, extra_regen_ids_json)

    async def _mark_video_quarantined(self, video_id: str, error: str):
        async with self._db.acquire() as conn:
            await conn.execute(f"""
                UPDATE {self.config.recover_queue_table}
                SET status = 'quarantined',
                    error_message = $2,
                    updated_at = now()
                WHERE video_id = $1
            """, video_id, error[:500])

    async def _mark_video_failed(self, video_id: str, error: str):
        async with self._db.acquire() as conn:
            await conn.execute(f"""
                UPDATE {self.config.recover_queue_table}
                SET status = 'failed',
                    error_message = $2,
                    updated_at = now()
                WHERE video_id = $1
            """, video_id, error[:500])

    # ── Heartbeat ──────────────────────────────────────────────────────

    async def _heartbeat_loop(self):
        while not self._shutdown_event.is_set():
            try:
                await self._send_heartbeat()
            except Exception as e:
                logger.warning(f"Recover heartbeat failed: {e}")
            try:
                await asyncio.wait_for(self._shutdown_event.wait(), timeout=HEARTBEAT_INTERVAL_S)
                break
            except asyncio.TimeoutError:
                pass

    async def _send_heartbeat(self):
        async with self._db.acquire() as conn:
            await conn.execute("""
                UPDATE worker_validators SET
                    current_video_id = $2,
                    videos_processed = $3,
                    videos_failed = $4,
                    segments_processed = $5,
                    avg_segs_per_second = $6,
                    shards_written = $7,
                    total_parquet_mb = $8,
                    last_heartbeat_at = now(),
                    last_video_completed_at = $9,
                    last_error = $10
                WHERE worker_id = $1
            """,
                self.config.worker_id,
                self.stats.current_video_id,
                self.stats.videos_processed,
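                # Quarantined videos are folded into the failed count here.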
                self.stats.videos_failed + self.stats.videos_quarantined,
                self.stats.segments_processed,
                round(self.stats.avg_segs_per_second, 2),
                self.stats.shards_written,
                round(self.stats.total_parquet_mb, 2),
                self.stats.last_video_completed_at,
                self.stats.last_error,
            )

    async def _update_worker_status(self, status: str, error: str = ""):
        if not self._db:
            return
        async with self._db.acquire() as conn:
            await conn.execute("""
                UPDATE worker_validators
                SET status = $2, last_error = $3, last_heartbeat_at = now()
                WHERE worker_id = $1
            """, self.config.worker_id, status, error[:500] if error else None)

    # ── Shutdown ───────────────────────────────────────────────────────

    async def _handle_shutdown(self, sig):
        logger.info(f"Received {sig.name}, initiating recover worker shutdown...")
        self._shutdown_event.set()

    async def _cleanup(self):
        if self._heartbeat_task:
            self._heartbeat_task.cancel()
            try:
                await self._heartbeat_task
            except asyncio.CancelledError:
                pass

        if self.packer:
            self.packer.flush()
            packer_stats = self.packer.stats
            self.stats.shards_written = packer_stats["shards_written"]
            self.stats.total_parquet_mb = packer_stats.get("total_mb", 0.0)

        try:
            self.pipeline.unload_models()
        except Exception:
            pass

        try:
            await self._reference_store.close()
        except Exception:
            pass

        try:
            await self._send_heartbeat()
        except Exception:
            pass
        await self._update_worker_status("offline")

        if self._db:
            await self._db.close()

        shutil.rmtree(self._work_dir, ignore_errors=True)
        logger.info(
            f"Recover worker shutdown complete. processed={self.stats.videos_processed} "
            f"quarantined={self.stats.videos_quarantined} failed={self.stats.videos_failed} "
            f"segments={self.stats.segments_processed}"
        )


class RecoverQuarantineError(RuntimeError):
    """Raised when replay cannot faithfully recover the historical tx set."""
