"""
Recover reference store.

This lets recover workers use R2-hosted parquet snapshots instead of hitting
Supabase for `transcription_results` and `transcription_flags` on every video.

The snapshots are downloaded once per worker with multipart S3 transfers, then
queried locally through DuckDB.
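
Typical lifecycle inside a worker (a sketch; `segment_ids` here stands for
whatever IDs the caller tracks for the video):

    store = RecoverReferenceStore(config, cache_dir)
    await store.start()                  # download snapshots, open DuckDB
    rows = await store.fetch_tx_rows(video_id)
    summary = await store.fetch_flag_summary(segment_ids)
    await store.close()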
"""
from __future__ import annotations

import asyncio
import json
import logging
import os
import time
from dataclasses import dataclass
from pathlib import Path
from typing import Optional

import duckdb
from boto3.s3.transfer import TransferConfig
from botocore.config import Config as BotoConfig

from .config import ValidationConfig

logger = logging.getLogger(__name__)


@dataclass
class RecoverReferenceMetadata:
    tx_rows: int = 0
    flags_rows: int = 0
    tx_size_bytes: int = 0
    flags_size_bytes: int = 0
    source_bucket: str = ""
    tx_key: str = ""
    flags_key: str = ""
    manifest_key: str = ""


class RecoverReferenceStore:
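    """Worker-local mirror of the recover reference tables.

    Downloads the transcription/flags parquet snapshots (plus the optional
    validated-segments snapshot and manifest) from R2 once, then answers
    per-video lookups through DuckDB views over the cached files.
    """
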
    def __init__(self, config: ValidationConfig, cache_dir: Path):
        self.config = config
        self.cache_dir = cache_dir / "recover_reference"
        self.tx_path = self.cache_dir / _cache_name(config.recover_tx_parquet_key)
        self.flags_path = self.cache_dir / _cache_name(config.recover_flags_parquet_key)
        self.validated_path = self.cache_dir / _cache_name(config.recover_validated_parquet_key)
        self.manifest_path = self.cache_dir / _cache_name(config.recover_reference_manifest_key)
        self.duckdb_path = self.cache_dir / "recover_reference.duckdb"
        self.metadata = RecoverReferenceMetadata(
            source_bucket=config.r2_reference_bucket,
            tx_key=config.recover_tx_parquet_key,
            flags_key=config.recover_flags_parquet_key,
            manifest_key=config.recover_reference_manifest_key,
        )
        self._con: Optional[duckdb.DuckDBPyConnection] = None
        self._lock = asyncio.Lock()

    @property
    def enabled(self) -> bool:
        return self.config.recover_reference_mode == "parquet"

    async def start(self):
        """Download snapshots and open DuckDB; a no-op unless parquet mode is on."""
        if not self.enabled:
            return
        loop = asyncio.get_running_loop()
        await loop.run_in_executor(None, self._start_sync)

    async def close(self):
        """Close the local DuckDB connection if it was opened."""
        if self._con is None:
            return
        loop = asyncio.get_running_loop()
        await loop.run_in_executor(None, self._close_sync)

    async def fetch_tx_rows(self, video_id: str) -> list[dict]:
        """Return transcription rows for a video, ordered by segment_file."""
        # The lock serializes queries: a single DuckDB connection is shared and
        # must not execute from multiple threads at once.
        async with self._lock:
            loop = asyncio.get_running_loop()
            return await loop.run_in_executor(None, lambda: self._fetch_tx_rows_sync(video_id))

    async def fetch_validated_segment_ids(self, video_id: str) -> set[str]:
        """Return segment files already validated for a video (empty without a snapshot)."""
        async with self._lock:
            loop = asyncio.get_running_loop()
            return await loop.run_in_executor(None, lambda: self._fetch_validated_ids_sync(video_id))

    async def fetch_flag_summary(self, segment_ids: list[str]) -> dict:
        """Count flagged segments among segment_ids, in total and per flag type."""
        if not segment_ids:
            return {"timeout": 0, "error": 0, "rate_limited": 0, "flagged_total": 0}
        async with self._lock:
            loop = asyncio.get_running_loop()
            return await loop.run_in_executor(None, lambda: self._fetch_flag_summary_sync(segment_ids))

    def _start_sync(self):
        self.cache_dir.mkdir(parents=True, exist_ok=True)
        self._download_reference_files()
        self._open_duckdb()

    def _close_sync(self):
        if self._con is not None:
            self._con.close()
            self._con = None

    def _download_reference_files(self):
        import boto3

        s3 = boto3.client(
            "s3",
            endpoint_url=self.config.r2_endpoint_url,
            aws_access_key_id=self.config.r2_access_key_id,
            aws_secret_access_key=self.config.r2_secret_access_key,
            region_name="auto",
            config=BotoConfig(max_pool_connections=_s3_pool_size(
                self.config.recover_reference_download_concurrency
            )),
        )
        transfer = TransferConfig(
            multipart_threshold=64 * 1024 * 1024,
            multipart_chunksize=64 * 1024 * 1024,
            max_concurrency=max(self.config.recover_reference_download_concurrency, 1),
            use_threads=True,
        )

        self._try_download_manifest(s3)
        self.metadata.tx_size_bytes = self._download_object(
            s3,
            self.config.r2_reference_bucket,
            self.config.recover_tx_parquet_key,
            self.tx_path,
            transfer,
        )
        self.metadata.flags_size_bytes = self._download_object(
            s3,
            self.config.r2_reference_bucket,
            self.config.recover_flags_parquet_key,
            self.flags_path,
            transfer,
        )
        self._try_download_validated(s3, transfer)

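        # Only these two row counts are read from the (optional) manifest:
        #   {"tx_rows": <int>, "flags_rows": <int>, ...}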
        if self.manifest_path.exists():
            try:
                payload = json.loads(self.manifest_path.read_text())
                self.metadata.tx_rows = int(payload.get("tx_rows", 0) or 0)
                self.metadata.flags_rows = int(payload.get("flags_rows", 0) or 0)
            except Exception as e:  # pragma: no cover - defensive logging path
                logger.warning(f"Failed to parse reference manifest {self.manifest_path}: {e}")

        logger.info(
            "Recover reference snapshots ready: tx=%s (%s), flags=%s (%s), rows tx=%s flags=%s",
            self.tx_path.name,
            _format_bytes(self.metadata.tx_size_bytes),
            self.flags_path.name,
            _format_bytes(self.metadata.flags_size_bytes),
            f"{self.metadata.tx_rows:,}" if self.metadata.tx_rows else "?",
            f"{self.metadata.flags_rows:,}" if self.metadata.flags_rows else "?",
        )

    def _try_download_manifest(self, s3):
        try:
            self._download_object(
                s3,
                self.config.r2_reference_bucket,
                self.config.recover_reference_manifest_key,
                self.manifest_path,
                None,
                force=True,
            )
        except Exception as e:  # pragma: no cover - optional metadata path
            logger.info(f"Recover reference manifest not available yet: {e}")

    def _try_download_validated(self, s3, transfer):
        try:
            self._download_object(
                s3,
                self.config.r2_reference_bucket,
                self.config.recover_validated_parquet_key,
                self.validated_path,
                transfer,
            )
            logger.info(f"Validated segment IDs snapshot ready: {self.validated_path.name}")
        except Exception as e:
            logger.info(f"Validated segment IDs snapshot not available (will validate all): {e}")

    def _download_object(
        self,
        s3,
        bucket: str,
        key: str,
        dest: Path,
        transfer: Optional[TransferConfig],
        *,
        force: bool = False,
    ) -> int:
        """HEAD the object, reuse a same-size cached copy, else download; return byte size."""
        head = s3.head_object(Bucket=bucket, Key=key)
        size_bytes = int(head.get("ContentLength", 0) or 0)
        if not force and dest.exists() and dest.stat().st_size == size_bytes:
            logger.info(f"Reusing cached snapshot {dest.name} ({_format_bytes(size_bytes)})")
            return size_bytes

        logger.info(
            f"Downloading recover reference s3://{bucket}/{key} -> {dest} "
            f"({_format_bytes(size_bytes)})"
        )
        t0 = time.time()
        extra = {"Config": transfer} if transfer is not None else {}
        s3.download_file(bucket, key, str(dest), **extra)
        logger.info(f"Downloaded {dest.name} in {time.time() - t0:.1f}s")
        return size_bytes

    def _open_duckdb(self):
        """Open the cache database and define views over the parquet snapshots."""
        self._con = duckdb.connect(str(self.duckdb_path))
        # Cap query parallelism at 16 threads regardless of core count.
        threads = min(os.cpu_count() or 4, 16)
        self._con.execute(f"SET threads = {threads}")
        tx_path = _sql_path(self.tx_path)
        flags_path = _sql_path(self.flags_path)
        self._con.execute(f"""
            CREATE OR REPLACE VIEW recover_tx_rows AS
            SELECT
                video_id,
                segment_file,
                expected_language_hint,
                detected_language,
                transcription,
                tagged,
                quality_score,
                speaker_emotion,
                speaker_style,
                speaker_pace,
                speaker_accent
            FROM read_parquet('{tx_path}')
        """)
        self._con.execute(f"""
            CREATE OR REPLACE VIEW recover_flags AS
            SELECT segment_id, flag_type
            FROM read_parquet('{flags_path}')
        """)
        if self.validated_path.exists() and self.validated_path.stat().st_size > 0:
            val_path = _sql_path(self.validated_path)
            self._con.execute(f"""
                CREATE OR REPLACE VIEW recover_validated AS
                SELECT video_id, segment_file
                FROM read_parquet('{val_path}')
            """)

    def _fetch_tx_rows_sync(self, video_id: str) -> list[dict]:
        if self._con is None:
            raise RuntimeError("Recover reference store not initialized")
        rel = self._con.execute(
            """
            SELECT
                segment_file,
                expected_language_hint,
                detected_language,
                transcription,
                tagged,
                quality_score,
                speaker_emotion,
                speaker_style,
                speaker_pace,
                speaker_accent
            FROM recover_tx_rows
            WHERE video_id = ?
            ORDER BY segment_file
            """,
            [video_id],
        )
        cols = [d[0] for d in rel.description]
        return [dict(zip(cols, row)) for row in rel.fetchall()]

    def _fetch_validated_ids_sync(self, video_id: str) -> set[str]:
        if self._con is None:
            raise RuntimeError("Recover reference store not initialized")
        try:
            self._con.execute("SELECT 1 FROM recover_validated LIMIT 0")
        except duckdb.CatalogException:
            return set()
        rows = self._con.execute(
            "SELECT segment_file FROM recover_validated WHERE video_id = ?",
            [video_id],
        ).fetchall()
        return {row[0] for row in rows}

    def _fetch_flag_summary_sync(self, segment_ids: list[str]) -> dict:
        if self._con is None:
            raise RuntimeError("Recover reference store not initialized")
        placeholders = ", ".join("?" for _ in segment_ids)
        # A segment can carry multiple flag types, so the distinct total is
        # computed separately rather than summing the per-type counts.
        flagged_total = self._con.execute(
            f"""
            SELECT count(DISTINCT segment_id)
            FROM recover_flags
            WHERE segment_id IN ({placeholders})
            """,
            segment_ids,
        ).fetchone()[0]
        rows = self._con.execute(
            f"""
            SELECT flag_type, count(DISTINCT segment_id) AS cnt
            FROM recover_flags
            WHERE segment_id IN ({placeholders})
            GROUP BY flag_type
            """,
            segment_ids,
        ).fetchall()
        summary = {"timeout": 0, "error": 0, "rate_limited": 0, "flagged_total": int(flagged_total or 0)}
        for flag_type, cnt in rows:
            if flag_type in summary:
                summary[flag_type] = int(cnt)
        return summary


def _format_bytes(size_bytes: int) -> str:
    if size_bytes <= 0:
        return "0B"
    units = ["B", "KB", "MB", "GB", "TB"]
    value = float(size_bytes)
    for unit in units:
        if value < 1024.0 or unit == units[-1]:
            return f"{value:.1f}{unit}"
        value /= 1024.0
    return f"{size_bytes}B"


def _sql_path(path: Path) -> str:
    """Render a path as a POSIX string safe to embed in a single-quoted SQL literal."""
    return path.as_posix().replace("'", "''")


def _cache_name(key: str) -> str:
    """Flatten an object key into a single cache filename."""
    return key.strip("/").replace("/", "__")


def _s3_pool_size(download_concurrency: int) -> int:
    concurrency = max(download_concurrency, 1)
    return max(32, concurrency * 2)
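

# Illustrative smoke test (a sketch, not part of the worker runtime). It
# writes tiny parquet fixtures with DuckDB, points the store at them through
# a SimpleNamespace stand-in whose fields mirror the ValidationConfig
# attributes read above, and exercises the sync query paths directly, so no
# R2 credentials or network access are required.
if __name__ == "__main__":  # pragma: no cover
    import tempfile
    from types import SimpleNamespace

    logging.basicConfig(level=logging.INFO)
    with tempfile.TemporaryDirectory() as tmp:
        cfg = SimpleNamespace(
            recover_reference_mode="parquet",
            r2_reference_bucket="local-fixture",
            recover_tx_parquet_key="tx.parquet",
            recover_flags_parquet_key="flags.parquet",
            recover_validated_parquet_key="validated.parquet",
            recover_reference_manifest_key="manifest.json",
            r2_endpoint_url="",
            r2_access_key_id="",
            r2_secret_access_key="",
            recover_reference_download_concurrency=1,
        )
        store = RecoverReferenceStore(cfg, Path(tmp))  # type: ignore[arg-type]
        store.cache_dir.mkdir(parents=True, exist_ok=True)

        # Build one-row fixtures in place of the downloaded snapshots.
        fixture = duckdb.connect()
        fixture.execute(f"""
            COPY (
                SELECT 'vid1' AS video_id, 'seg_000' AS segment_file,
                       'en' AS expected_language_hint, 'en' AS detected_language,
                       'hello' AS transcription, 'hello' AS tagged,
                       0.9 AS quality_score, 'neutral' AS speaker_emotion,
                       'read' AS speaker_style, 'normal' AS speaker_pace,
                       'us' AS speaker_accent
            ) TO '{_sql_path(store.tx_path)}' (FORMAT PARQUET)
        """)
        fixture.execute(f"""
            COPY (SELECT 'seg_000' AS segment_id, 'timeout' AS flag_type)
            TO '{_sql_path(store.flags_path)}' (FORMAT PARQUET)
        """)
        fixture.close()

        store._open_duckdb()  # bypass start() to skip the R2 download path
        print(store._fetch_tx_rows_sync("vid1"))
        print(store._fetch_flag_summary_sync(["seg_000"]))
        store._close_sync()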
