"""
Database operations: video_queue claim/release, worker heartbeat, results insert, flags.
Uses asyncpg for direct PostgreSQL (bypasses Supabase REST/PostgREST bottleneck).
Supports mock mode for local testing.
"""
from __future__ import annotations

import asyncio
import json
import logging
import random
import time
import uuid
from dataclasses import dataclass, field
from datetime import datetime, timezone
from pathlib import Path
from typing import Any, Optional

logger = logging.getLogger(__name__)

# Running counters for DB operations, all starting at zero.
# They are summarised periodically by log_db_stats().
_db_stats = dict.fromkeys(
    (
        "claims_ok", "claims_fail",
        "inserts_ok", "inserts_fail", "inserts_rows",
        "heartbeats_ok", "heartbeats_fail",
        "retries",
    ),
    0,
)


def log_db_stats():
    """Log a one-line INFO summary of the module-level ``_db_stats`` counters.

    Claims and heartbeats are rendered as ``ok/total`` (total = ok + fail).
    Inserts are rendered as ok/fail chunk counts plus the running row count.
    """
    stats = _db_stats
    claims_ok, claims_fail = stats["claims_ok"], stats["claims_fail"]
    hb_ok, hb_fail = stats["heartbeats_ok"], stats["heartbeats_fail"]
    message = (
        f"[DB-STATS] claims={claims_ok}/{claims_ok + claims_fail} "
        f"inserts={stats['inserts_ok']}ok/{stats['inserts_fail']}fail ({stats['inserts_rows']} rows) "
        f"heartbeats={hb_ok}/{hb_ok + hb_fail} "
        f"retries={stats['retries']}"
    )
    logger.info(message)


@dataclass
class VideoTask:
    """A single video_queue entry handed to a worker.

    Only ``video_id`` and ``language`` are required; the remaining fields
    carry defaults.
    """

    video_id: str
    language: str
    segment_count: int = field(default=0)
    # Lifecycle state; the backends in this module use "pending", "claimed",
    # "done" and "failed".
    status: str = field(default="pending")
    # Optional local directory; presumably where prefetched media lives —
    # confirm with the code that populates it.
    prefetch_dir: Optional[Path] = field(default=None)


@dataclass
class WorkerStats:
    """Mutable per-worker counters that the DB backends copy into heartbeats.

    Every numeric field starts at zero. ``errors`` gets a fresh list for each
    instance.
    """

    # --- segment throughput ---
    segments_sent: int = 0
    segments_completed: int = 0
    segments_failed: int = 0
    segments_429: int = 0
    cache_hits: int = 0

    # --- batch timing ---
    batches_completed: int = 0
    avg_batch_latency_ms: float = 0.0

    # --- current work ---
    current_video_id: Optional[str] = None
    segments_remaining: int = 0

    # --- rate / token accounting ---
    active_rpm: float = 0.0
    active_tpm: float = 0.0
    total_input_tokens: int = 0
    total_output_tokens: int = 0
    total_cached_tokens: int = 0

    # --- diagnostics (per-instance list, never shared) ---
    errors: list[str] = field(default_factory=list)


class MockDB:
    """In-memory stand-in for the real database, for testing without Supabase.

    It exposes the same async methods as PostgresDB and keeps all state in
    plain Python containers; nothing is persisted.
    """

    def __init__(self):
        self.video_queue: list[VideoTask] = []
        self.workers: dict[str, dict] = {}
        self.results: list[dict] = []
        self.flags: list[dict] = []

    def seed_videos(self, videos: list[VideoTask]):
        # Keeps a reference to the caller's list (no copy).
        self.video_queue = videos

    async def connect(self):
        """No-op: there is no connection to open."""

    async def close(self):
        """No-op: there is no connection to close."""

    def _set_status(self, video_id: str, status: str) -> None:
        """Set ``status`` on every queued task whose id equals ``video_id``."""
        for task in self.video_queue:
            if task.video_id == video_id:
                task.status = status

    async def claim_video(self, worker_id: str) -> Optional[VideoTask]:
        """Mark the first 'pending' task as 'claimed' and return it, else None."""
        # worker_id is accepted for interface parity but is not recorded.
        task = next((t for t in self.video_queue if t.status == "pending"), None)
        if task is not None:
            task.status = "claimed"
        return task

    async def release_video(self, video_id: str):
        self._set_status(video_id, "pending")

    async def mark_video_done(self, video_id: str):
        self._set_status(video_id, "done")

    async def mark_video_failed(self, video_id: str, error: str):
        # The error text is not stored by the mock.
        self._set_status(video_id, "failed")

    async def register_worker(self, worker_id: str, provider: str, gpu_type: str, config_json: dict):
        """(Re)create the worker record as 'online', replacing any previous one."""
        self.workers[worker_id] = dict(
            worker_id=worker_id,
            status="online",
            provider=provider,
            gpu_type=gpu_type,
            config_json=config_json,
            started_at=datetime.now(timezone.utc).isoformat(),
        )

    async def update_heartbeat(self, worker_id: str, stats: WorkerStats):
        """Copy a stats snapshot into a registered worker; unknown ids are ignored."""
        if worker_id not in self.workers:
            return
        self.workers[worker_id].update(
            total_segments_sent=stats.segments_sent,
            total_segments_completed=stats.segments_completed,
            total_segments_failed=stats.segments_failed,
            total_segments_429=stats.segments_429,
            total_cache_hits=stats.cache_hits,
            batches_completed=stats.batches_completed,
            avg_batch_latency_ms=stats.avg_batch_latency_ms,
            current_video_id=stats.current_video_id,
            segments_remaining=stats.segments_remaining,
            active_rpm=stats.active_rpm,
            active_tpm=stats.active_tpm,
            total_input_tokens=stats.total_input_tokens,
            total_output_tokens=stats.total_output_tokens,
            total_cached_tokens=stats.total_cached_tokens,
            last_heartbeat_at=datetime.now(timezone.utc).isoformat(),
        )

    async def set_worker_offline(self, worker_id: str):
        if worker_id not in self.workers:
            return
        self.workers[worker_id]["status"] = "offline"

    async def set_worker_error(self, worker_id: str, error: str):
        if worker_id not in self.workers:
            return
        record = self.workers[worker_id]
        record["status"] = "error"
        record["last_error"] = error

    async def insert_results(self, results: list[dict]):
        self.results += results

    async def insert_flags(self, flags: list[dict]):
        self.flags += flags


class PostgresDB:
    """Direct PostgreSQL via asyncpg — bypasses Supabase REST/PostgREST entirely.

    ``connect()`` must be awaited before any query method; every method
    borrows a connection from the asyncpg pool it creates. Most write paths
    are best-effort: failures are logged and counted in the module-level
    ``_db_stats`` dict rather than propagated to the caller.
    """

    def __init__(self, database_url: str):
        self._dsn = database_url
        self._pool = None  # created by connect(), cleared by close()

    async def connect(self):
        """Create the asyncpg pool (2-8 connections, 30 s command timeout, SSL required)."""
        import asyncpg

        async def _init_conn(conn):
            # Encode/decode json and jsonb with the stdlib json module so Python
            # dicts (e.g. config_json) can be passed and returned directly.
            await conn.set_type_codec(
                'jsonb', encoder=json.dumps, decoder=json.loads, schema='pg_catalog'
            )
            await conn.set_type_codec(
                'json', encoder=json.dumps, decoder=json.loads, schema='pg_catalog'
            )

        # A DSN on port 6543 is treated as going through a connection pooler
        # (presumably Supabase's transaction-mode pooler — confirm). Such
        # poolers typically cannot keep server-side prepared statements, so
        # asyncpg's statement cache is disabled there.
        is_pooler = ":6543" in self._dsn
        self._pool = await asyncpg.create_pool(
            dsn=self._dsn,
            min_size=2,
            max_size=8,
            command_timeout=30,
            statement_cache_size=0 if is_pooler else 100,
            ssl="require",
            init=_init_conn,
        )
        logger.info(f"asyncpg pool ready (min=2, max=8, pooler={is_pooler})")

    async def close(self):
        """Close the pool. Safe to call before connect() or more than once."""
        pool, self._pool = self._pool, None
        if pool is not None:
            await pool.close()
            logger.info("asyncpg pool closed")

    async def _retry(self, coro_fn, max_retries=3, op_name="db_op"):
        """Await ``coro_fn()`` with exponential backoff + jitter between attempts.

        Args:
            coro_fn: zero-argument callable that returns a fresh awaitable on
                each call.
            max_retries: total number of attempts. Values below 1 are treated
                as 1, so the operation always runs at least once.
            op_name: label used in the warning log lines.

        Returns:
            The result of the first successful attempt.

        Raises:
            The exception from the last attempt if every attempt fails.
        """
        attempts = max(1, max_retries)
        for attempt in range(attempts):
            try:
                return await coro_fn()
            except Exception as e:
                if attempt == attempts - 1:
                    raise
                # Count only failures that are actually followed by a retry.
                _db_stats["retries"] += 1
                wait = (2 ** attempt) + random.uniform(0.5, 2.0)
                logger.warning(f"[{op_name}] attempt {attempt+1}/{attempts} failed: {str(e)[:100]}, retry in {wait:.1f}s")
                await asyncio.sleep(wait)

    async def claim_video(self, worker_id: str) -> Optional[VideoTask]:
        """Atomically claim one 'pending' video for ``worker_id``.

        The inner ``FOR UPDATE SKIP LOCKED`` select skips rows another
        transaction has locked, so concurrent workers do not claim the same
        row. There is no ORDER BY, so which pending row is picked is
        unspecified.

        Returns:
            A VideoTask for the claimed row, or None if no pending row exists
            or all 3 attempts failed. ``claims_fail`` counts each failed attempt.
        """
        for attempt in range(3):
            try:
                async with self._pool.acquire() as conn:
                    row = await conn.fetchrow("""
                        UPDATE video_queue
                        SET status = 'claimed',
                            claimed_by = $1,
                            claimed_at = now()
                        WHERE video_id = (
                            SELECT video_id FROM video_queue
                            WHERE status = 'pending'
                            LIMIT 1
                            FOR UPDATE SKIP LOCKED
                        )
                        RETURNING video_id, language, segment_count
                    """, worker_id)

                if row:
                    _db_stats["claims_ok"] += 1
                    return VideoTask(
                        video_id=row["video_id"],
                        language=row["language"] or "en",
                        segment_count=row["segment_count"] or 0,
                        status="claimed",
                    )
                return None
            except Exception as e:
                _db_stats["claims_fail"] += 1
                if attempt < 2:
                    wait = (2 ** attempt) + random.uniform(0.5, 1.5)
                    logger.warning(f"Claim failed (attempt {attempt+1}): {str(e)[:100]}, retry in {wait:.1f}s")
                    await asyncio.sleep(wait)
                else:
                    logger.error(f"Claim failed after 3 attempts: {str(e)[:100]}")
                    return None

    async def release_video(self, video_id: str):
        """Put a video back to 'pending' and clear its claim (best-effort).

        Retried like mark_video_done: a single transient failure would
        otherwise leave the row in 'claimed'.
        """
        try:
            await self._retry(
                lambda: self._exec(
                    """
                    UPDATE video_queue
                    SET status = 'pending', claimed_by = NULL, claimed_at = NULL
                    WHERE video_id = $1
                    """,
                    video_id,
                ),
                op_name="release_video",
            )
        except Exception as e:
            logger.error(f"release_video failed for {video_id}: {str(e)[:80]}")

    async def mark_video_done(self, video_id: str):
        """Set the video to 'done' with a completion timestamp (best-effort, retried)."""
        try:
            await self._retry(
                lambda: self._exec(
                    "UPDATE video_queue SET status = 'done', completed_at = now() WHERE video_id = $1",
                    video_id
                ),
                op_name="mark_done",
            )
        except Exception as e:
            logger.error(f"mark_video_done failed for {video_id} (non-fatal): {str(e)[:80]}")

    async def mark_video_failed(self, video_id: str, error: str):
        """Set the video to 'failed' and store ``error``, truncated to 500 chars (best-effort, retried)."""
        try:
            await self._retry(
                lambda: self._exec(
                    "UPDATE video_queue SET status = 'failed', error_message = $2 WHERE video_id = $1",
                    video_id, error[:500]
                ),
                op_name="mark_failed",
            )
        except Exception as e:
            logger.error(f"mark_video_failed failed for {video_id} (non-fatal): {str(e)[:80]}")

    async def register_worker(self, worker_id: str, provider: str, gpu_type: str, config_json: dict):
        """Upsert this worker's row as 'online' and reset its counters to zero.

        Makes up to 5 attempts with linear backoff. If all of them fail, it
        logs the failure and returns, so the worker can continue without a
        registration row.
        """
        max_attempts = 5
        for attempt in range(max_attempts):
            try:
                async with self._pool.acquire() as conn:
                    await conn.execute("""
                        INSERT INTO workers (
                            worker_id, status, provider, gpu_type, config_json,
                            started_at, last_heartbeat_at,
                            total_segments_sent, total_segments_completed,
                            total_segments_failed, total_segments_429,
                            total_cache_hits, batches_completed
                        ) VALUES ($1, 'online', $2, $3, $4, now(), now(), 0, 0, 0, 0, 0, 0)
                        ON CONFLICT (worker_id) DO UPDATE SET
                            status = 'online', provider = $2, gpu_type = $3,
                            config_json = $4, started_at = now(), last_heartbeat_at = now(),
                            total_segments_sent = 0, total_segments_completed = 0,
                            total_segments_failed = 0, total_segments_429 = 0,
                            total_cache_hits = 0, batches_completed = 0
                    """, worker_id, provider, gpu_type, config_json)
                logger.info(f"Worker {worker_id} registered via direct PG")
                return
            except Exception as e:
                if attempt == max_attempts - 1:
                    # Last attempt: do not sleep before giving up.
                    logger.warning(f"Register worker failed (attempt {attempt+1}/{max_attempts}): {str(e)[:80]}")
                    break
                wait = 3 * (attempt + 1) + random.uniform(0, 2)
                logger.warning(f"Register worker failed (attempt {attempt+1}/{max_attempts}): {str(e)[:80]}, retry in {wait:.0f}s")
                await asyncio.sleep(wait)
        logger.error("Register worker failed after 5 attempts, continuing anyway")

    async def update_heartbeat(self, worker_id: str, stats: WorkerStats):
        """Overwrite the worker's counters with ``stats`` and bump last_heartbeat_at.

        Best-effort: failures are counted and logged, never raised.
        """
        try:
            async with self._pool.acquire() as conn:
                await conn.execute("""
                    UPDATE workers SET
                        total_segments_sent = $2, total_segments_completed = $3,
                        total_segments_failed = $4, total_segments_429 = $5,
                        total_cache_hits = $6, batches_completed = $7,
                        avg_batch_latency_ms = $8, current_video_id = $9,
                        segments_remaining = $10, active_rpm = $11, active_tpm = $12,
                        total_input_tokens = $13, total_output_tokens = $14,
                        total_cached_tokens = $15, last_heartbeat_at = now()
                    WHERE worker_id = $1
                """,
                    worker_id,
                    stats.segments_sent, stats.segments_completed,
                    stats.segments_failed, stats.segments_429,
                    stats.cache_hits, stats.batches_completed,
                    stats.avg_batch_latency_ms, stats.current_video_id,
                    stats.segments_remaining, stats.active_rpm, stats.active_tpm,
                    stats.total_input_tokens, stats.total_output_tokens,
                    stats.total_cached_tokens,
                )
            _db_stats["heartbeats_ok"] += 1
        except Exception as e:
            _db_stats["heartbeats_fail"] += 1
            logger.warning(f"Heartbeat update failed (non-fatal): {str(e)[:80]}")

    async def set_worker_offline(self, worker_id: str):
        """Mark the worker 'offline' (single attempt, failures only logged)."""
        try:
            await self._exec(
                "UPDATE workers SET status = 'offline', last_heartbeat_at = now() WHERE worker_id = $1",
                worker_id
            )
        except Exception as e:
            logger.warning(f"set_worker_offline failed: {str(e)[:80]}")

    async def set_worker_error(self, worker_id: str, error: str):
        """Mark the worker 'error' and store ``error``, truncated to 500 chars (failures only logged)."""
        try:
            await self._exec(
                "UPDATE workers SET status = 'error', last_error = $2, last_heartbeat_at = now() WHERE worker_id = $1",
                worker_id, error[:500]
            )
        except Exception as e:
            logger.warning(f"set_worker_error failed: {str(e)[:80]}")

    async def insert_results(self, results: list[dict]):
        """Best-effort batch insert into transcription_results (no-op when empty)."""
        if not results:
            return
        await self._batch_insert("transcription_results", results)

    async def insert_flags(self, flags: list[dict]):
        """Best-effort batch insert into transcription_flags (no-op when empty)."""
        if not flags:
            return
        await self._batch_insert("transcription_flags", flags)

    async def _batch_insert(self, table: str, rows: list[dict], chunk_size: int = 50):
        """Insert ``rows`` into ``table`` in chunks via ``executemany``.

        The column list comes from the FIRST row's keys, and every row is
        projected onto those columns. A missing key therefore inserts NULL,
        and a key absent from the first row is silently ignored. Conflicting
        rows are skipped (``ON CONFLICT DO NOTHING``). Each chunk gets up to 4
        attempts; a chunk that still fails is logged and dropped, and nothing
        is raised.

        NOTE: ``table`` and the column names are interpolated into the SQL
        text (only values are parameterized), so they must come from trusted
        code, never from external input.
        """
        if not rows:
            return
        # A non-positive chunk_size would make range() raise ValueError.
        chunk_size = max(1, chunk_size)
        max_attempts = 4

        columns = list(rows[0].keys())
        placeholders = ", ".join(f"${i+1}" for i in range(len(columns)))
        col_names = ", ".join(columns)
        sql = f"INSERT INTO {table} ({col_names}) VALUES ({placeholders}) ON CONFLICT DO NOTHING"

        for i in range(0, len(rows), chunk_size):
            chunk = rows[i:i + chunk_size]
            values = [tuple(row.get(col) for col in columns) for row in chunk]

            for attempt in range(max_attempts):
                try:
                    async with self._pool.acquire() as conn:
                        await conn.executemany(sql, values)
                    _db_stats["inserts_ok"] += 1
                    # Rows submitted; rows skipped by ON CONFLICT are included.
                    _db_stats["inserts_rows"] += len(chunk)
                    break
                except Exception as e:
                    if attempt < max_attempts - 1:
                        # Count only failures that are actually followed by a retry.
                        _db_stats["retries"] += 1
                        wait = (2 ** attempt) + random.uniform(0.5, 2.0)
                        logger.warning(
                            f"Insert {table} chunk {i//chunk_size} failed "
                            f"(attempt {attempt+1}/{max_attempts}): {str(e)[:100]}, retry in {wait:.1f}s"
                        )
                        await asyncio.sleep(wait)
                    else:
                        _db_stats["inserts_fail"] += 1
                        logger.error(
                            f"Insert {table} chunk {i//chunk_size} failed after {max_attempts} attempts, "
                            f"skipping {len(chunk)} rows: {str(e)[:100]}"
                        )

    async def _exec(self, sql: str, *args):
        """Run one statement on a pooled connection and return ``conn.execute``'s result."""
        async with self._pool.acquire() as conn:
            return await conn.execute(sql, *args)


def get_db(config) -> MockDB | PostgresDB:
    """Build the database backend selected by ``config``.

    Returns a MockDB when ``config.mock_mode`` is truthy. Otherwise returns a
    PostgresDB bound to ``config.database_url``.

    Raises:
        RuntimeError: if not in mock mode and ``config.database_url`` is empty.
    """
    if config.mock_mode:
        return MockDB()
    url = config.database_url
    if url:
        return PostgresDB(url)
    raise RuntimeError("DATABASE_URL required — direct PostgreSQL replaces Supabase REST to avoid PostgREST bottleneck")
