"""
Database layer for transcript-variant shard workers.

Uses direct PostgreSQL via asyncpg, mirroring the main transcription pipeline's
claim/heartbeat style but with dedicated tables.
"""
from __future__ import annotations

import asyncio
import json
import logging
import random
from dataclasses import dataclass, field
from typing import Any, Optional


logger = logging.getLogger(__name__)


@dataclass
class VariantJob:
    """One claimed shard of transcript rows to turn into variants.

    Field-for-field mirror of the RETURNING list of the claim query in
    ``VariantPostgresDB.claim_job`` (table ``transcript_variant_job_queue``).
    """

    shard_id: str       # primary key of the job row
    input_bucket: str   # R2/S3 bucket holding the input shard
    input_r2_key: str   # object key of the input shard within input_bucket
    input_format: str   # input encoding, e.g. 'parquet' (table default)
    output_bucket: str  # bucket where result packs are written
    output_prefix: str  # key prefix for result packs in output_bucket
    total_rows: int     # expected row count of the shard (0 when unknown)
    metadata_json: dict[str, Any] = field(default_factory=dict)  # free-form per-job metadata (JSONB column)


@dataclass
class VariantWorkerStats:
    """Mutable per-worker counters reported via ``update_heartbeat``.

    Each field maps 1:1 onto a column of ``transcript_variant_workers``
    (``cache_hits`` maps to ``total_cache_hits``). Heartbeats write the
    absolute values, not deltas.
    """

    jobs_claimed: int = 0
    jobs_completed: int = 0
    jobs_failed: int = 0
    rows_processed: int = 0
    rows_skipped: int = 0
    rows_gemini: int = 0          # rows that required a Gemini request
    requests_sent: int = 0
    requests_succeeded: int = 0
    requests_failed: int = 0
    cache_hits: int = 0           # stored as total_cache_hits in the DB
    packs_uploaded: int = 0
    current_shard_id: Optional[str] = None  # shard currently being worked, if any
    rows_remaining: int = 0
    active_rpm: float = 0.0       # current requests-per-minute rate
    total_input_tokens: int = 0
    total_output_tokens: int = 0
    total_cached_tokens: int = 0


class VariantPostgresDB:
    """Async PostgreSQL layer for the transcript-variant shard pipeline.

    Wraps an asyncpg connection pool and provides:
      * job-queue claim / progress / completion updates
        (``transcript_variant_job_queue``),
      * worker registration and heartbeat stats
        (``transcript_variant_workers``),
      * pack-manifest inserts (``transcript_variant_pack_manifests``).

    ``connect()`` must be awaited before any other method is used.
    """

    def __init__(self, database_url: str):
        self._dsn = database_url
        # Lazily created asyncpg pool; stays None until connect() succeeds.
        self._pool: Optional[Any] = None

    @staticmethod
    def _safe_identifiers(columns: list[str]) -> list[str]:
        """Validate names that are interpolated into SQL text.

        ``seed_jobs`` and ``insert_pack_manifest`` build SQL from dict
        keys via f-strings; reject anything that is not a plain
        identifier so a hostile or malformed key cannot inject SQL.

        Raises:
            ValueError: if any name is not a simple identifier.
        """
        for col in columns:
            if not col.isidentifier():
                raise ValueError(f"unsafe SQL column name: {col!r}")
        return list(columns)

    async def connect(self):
        """Create the asyncpg pool with JSON/JSONB codecs installed."""
        import asyncpg

        async def _init_conn(conn):
            # Transparently encode/decode json(b) columns as Python objects.
            await conn.set_type_codec(
                "jsonb", encoder=json.dumps, decoder=json.loads, schema="pg_catalog"
            )
            await conn.set_type_codec(
                "json", encoder=json.dumps, decoder=json.loads, schema="pg_catalog"
            )

        # Port 6543 conventionally indicates a transaction-mode pooler
        # (pgbouncer-style), where server-side prepared statements must
        # be disabled via statement_cache_size=0.
        is_pooler = ":6543" in self._dsn
        self._pool = await asyncpg.create_pool(
            dsn=self._dsn,
            min_size=1,
            max_size=6,
            command_timeout=30,
            statement_cache_size=0 if is_pooler else 100,
            ssl="require",
            init=_init_conn,
        )
        logger.info("variant asyncpg pool ready")

    async def close(self):
        """Close the pool; safe when never connected or already closed."""
        if self._pool:
            await self._pool.close()
            self._pool = None  # make repeated close() calls a no-op
            logger.info("variant asyncpg pool closed")

    async def init_schema(self):
        """Create the variant tables and indexes if they do not exist."""
        ddl = """
        CREATE TABLE IF NOT EXISTS transcript_variant_job_queue (
            shard_id TEXT PRIMARY KEY,
            status TEXT NOT NULL DEFAULT 'pending',
            input_bucket TEXT NOT NULL,
            input_r2_key TEXT NOT NULL,
            input_format TEXT NOT NULL DEFAULT 'parquet',
            output_bucket TEXT NOT NULL,
            output_prefix TEXT NOT NULL,
            total_rows INTEGER NOT NULL DEFAULT 0,
            rows_processed INTEGER NOT NULL DEFAULT 0,
            rows_skipped INTEGER NOT NULL DEFAULT 0,
            rows_gemini INTEGER NOT NULL DEFAULT 0,
            packs_uploaded INTEGER NOT NULL DEFAULT 0,
            last_pack_key TEXT,
            claimed_by TEXT,
            claimed_at TIMESTAMPTZ,
            completed_at TIMESTAMPTZ,
            error_message TEXT,
            attempt_count INTEGER NOT NULL DEFAULT 0,
            metadata_json JSONB NOT NULL DEFAULT '{}'::jsonb,
            created_at TIMESTAMPTZ NOT NULL DEFAULT now(),
            updated_at TIMESTAMPTZ NOT NULL DEFAULT now()
        );

        CREATE INDEX IF NOT EXISTS idx_tvjq_claim
            ON transcript_variant_job_queue (status, shard_id);

        CREATE TABLE IF NOT EXISTS transcript_variant_workers (
            worker_id TEXT PRIMARY KEY,
            status TEXT NOT NULL DEFAULT 'online',
            provider TEXT NOT NULL,
            gpu_type TEXT NOT NULL DEFAULT 'unknown',
            config_json JSONB NOT NULL DEFAULT '{}'::jsonb,
            started_at TIMESTAMPTZ NOT NULL DEFAULT now(),
            last_heartbeat_at TIMESTAMPTZ NOT NULL DEFAULT now(),
            jobs_claimed INTEGER NOT NULL DEFAULT 0,
            jobs_completed INTEGER NOT NULL DEFAULT 0,
            jobs_failed INTEGER NOT NULL DEFAULT 0,
            rows_processed BIGINT NOT NULL DEFAULT 0,
            rows_skipped BIGINT NOT NULL DEFAULT 0,
            rows_gemini BIGINT NOT NULL DEFAULT 0,
            requests_sent BIGINT NOT NULL DEFAULT 0,
            requests_succeeded BIGINT NOT NULL DEFAULT 0,
            requests_failed BIGINT NOT NULL DEFAULT 0,
            total_cache_hits BIGINT NOT NULL DEFAULT 0,
            packs_uploaded BIGINT NOT NULL DEFAULT 0,
            current_shard_id TEXT,
            rows_remaining BIGINT NOT NULL DEFAULT 0,
            active_rpm DOUBLE PRECISION NOT NULL DEFAULT 0,
            total_input_tokens BIGINT NOT NULL DEFAULT 0,
            total_output_tokens BIGINT NOT NULL DEFAULT 0,
            total_cached_tokens BIGINT NOT NULL DEFAULT 0,
            last_error TEXT
        );

        CREATE TABLE IF NOT EXISTS transcript_variant_pack_manifests (
            pack_id TEXT PRIMARY KEY,
            shard_id TEXT NOT NULL,
            worker_id TEXT NOT NULL,
            output_bucket TEXT NOT NULL,
            output_key TEXT NOT NULL,
            row_count INTEGER NOT NULL DEFAULT 0,
            gemini_row_count INTEGER NOT NULL DEFAULT 0,
            skipped_row_count INTEGER NOT NULL DEFAULT 0,
            distinct_video_count INTEGER NOT NULL DEFAULT 0,
            byte_size BIGINT NOT NULL DEFAULT 0,
            metadata_json JSONB NOT NULL DEFAULT '{}'::jsonb,
            created_at TIMESTAMPTZ NOT NULL DEFAULT now()
        );

        CREATE INDEX IF NOT EXISTS idx_tvpm_shard
            ON transcript_variant_pack_manifests (shard_id, created_at);
        """
        async with self._pool.acquire() as conn:
            await conn.execute(ddl)
        logger.info("transcript variant schema ready")

    async def seed_jobs(self, jobs: list[dict[str, Any]]):
        """Bulk-upsert queue rows, keyed on ``shard_id``.

        All dicts are written using the key set of ``jobs[0]`` (missing
        keys become NULL). On conflict every non-key column is overwritten
        and ``updated_at`` is refreshed.

        Raises:
            ValueError: if a dict key is not a safe SQL identifier.
        """
        if not jobs:
            return
        # Column names go straight into SQL text — validate them first.
        columns = self._safe_identifiers(list(jobs[0].keys()))
        placeholders = ", ".join(f"${i + 1}" for i in range(len(columns)))
        col_names = ", ".join(columns)
        update_cols = ", ".join(
            f"{col}=EXCLUDED.{col}" for col in columns if col != "shard_id"
        )
        # When shard_id is the only column the original join is empty,
        # which would yield invalid SQL ("SET , updated_at = now()").
        if update_cols:
            conflict_action = f"DO UPDATE SET {update_cols}, updated_at = now()"
        else:
            conflict_action = "DO UPDATE SET updated_at = now()"
        sql = (
            f"INSERT INTO transcript_variant_job_queue ({col_names}) VALUES ({placeholders}) "
            f"ON CONFLICT (shard_id) {conflict_action}"
        )
        values = [tuple(job.get(col) for col in columns) for job in jobs]
        async with self._pool.acquire() as conn:
            await conn.executemany(sql, values)
        logger.info("seeded %s transcript variant jobs", len(jobs))

    async def claim_job(self, worker_id: str) -> Optional[VariantJob]:
        """Atomically claim the next pending shard for this worker.

        Uses SELECT ... FOR UPDATE SKIP LOCKED so concurrent workers never
        claim the same shard. Retries transient DB errors up to 3 times
        with jittered exponential backoff.

        Returns:
            The claimed job, or None when the queue is empty or all
            retries failed.
        """
        for attempt in range(3):
            try:
                async with self._pool.acquire() as conn:
                    row = await conn.fetchrow(
                        """
                        UPDATE transcript_variant_job_queue
                        SET status = 'claimed',
                            claimed_by = $1,
                            claimed_at = now(),
                            attempt_count = attempt_count + 1,
                            updated_at = now()
                        WHERE shard_id = (
                            SELECT shard_id
                            FROM transcript_variant_job_queue
                            WHERE status = 'pending'
                            ORDER BY shard_id
                            LIMIT 1
                            FOR UPDATE SKIP LOCKED
                        )
                        RETURNING shard_id, input_bucket, input_r2_key, input_format,
                                  output_bucket, output_prefix, total_rows, metadata_json
                        """,
                        worker_id,
                    )
                if row:
                    return VariantJob(
                        shard_id=row["shard_id"],
                        input_bucket=row["input_bucket"],
                        input_r2_key=row["input_r2_key"],
                        input_format=row["input_format"],
                        output_bucket=row["output_bucket"],
                        output_prefix=row["output_prefix"],
                        total_rows=row["total_rows"] or 0,
                        metadata_json=row["metadata_json"] or {},
                    )
                return None  # queue is empty
            except Exception as e:
                if attempt < 2:
                    # Exponential backoff with jitter to de-synchronize workers.
                    wait = (2 ** attempt) + random.uniform(0.5, 1.5)
                    logger.warning(
                        "variant claim failed (attempt %s): %s, retry in %.1fs",
                        attempt + 1,
                        str(e)[:120],
                        wait,
                    )
                    await asyncio.sleep(wait)
                else:
                    logger.error("variant claim failed after retries: %s", str(e)[:120])
                    return None
        return None  # defensive: loop always returns above, but be explicit

    async def update_job_progress(
        self,
        shard_id: str,
        *,
        rows_processed: int,
        rows_skipped: int,
        rows_gemini: int,
        packs_uploaded: int,
        last_pack_key: str = "",
    ):
        """Write absolute progress counters for a claimed shard.

        An empty ``last_pack_key`` preserves the previously stored key.
        """
        await self._exec(
            """
            UPDATE transcript_variant_job_queue
            SET rows_processed = $2,
                rows_skipped = $3,
                rows_gemini = $4,
                packs_uploaded = $5,
                last_pack_key = CASE WHEN $6 = '' THEN last_pack_key ELSE $6 END,
                updated_at = now()
            WHERE shard_id = $1
            """,
            shard_id,
            rows_processed,
            rows_skipped,
            rows_gemini,
            packs_uploaded,
            last_pack_key,
        )

    async def complete_job(
        self,
        shard_id: str,
        *,
        rows_processed: int,
        rows_skipped: int,
        rows_gemini: int,
        packs_uploaded: int,
        last_pack_key: str = "",
    ):
        """Mark a shard done, recording final counters and completion time.

        An empty ``last_pack_key`` preserves the previously stored key.
        """
        await self._exec(
            """
            UPDATE transcript_variant_job_queue
            SET status = 'done',
                rows_processed = $2,
                rows_skipped = $3,
                rows_gemini = $4,
                packs_uploaded = $5,
                last_pack_key = CASE WHEN $6 = '' THEN last_pack_key ELSE $6 END,
                completed_at = now(),
                updated_at = now()
            WHERE shard_id = $1
            """,
            shard_id,
            rows_processed,
            rows_skipped,
            rows_gemini,
            packs_uploaded,
            last_pack_key,
        )

    async def fail_job(self, shard_id: str, error_message: str):
        """Mark a shard failed, storing the truncated error text."""
        await self._exec(
            """
            UPDATE transcript_variant_job_queue
            SET status = 'failed',
                error_message = $2,
                updated_at = now()
            WHERE shard_id = $1
            """,
            shard_id,
            # Guard against None and cap length; the original slice would
            # raise TypeError on a None message.
            (error_message or "")[:1000],
        )

    async def register_worker(self, worker_id: str, provider: str, gpu_type: str, config_json: dict):
        """Insert or reset this worker's row, zeroing all counters.

        Re-registration (e.g. after restart) resets stats because
        heartbeats report absolute values for the current process.
        """
        await self._exec(
            """
            INSERT INTO transcript_variant_workers (
                worker_id, status, provider, gpu_type, config_json, started_at, last_heartbeat_at
            ) VALUES ($1, 'online', $2, $3, $4, now(), now())
            ON CONFLICT (worker_id) DO UPDATE SET
                status = 'online',
                provider = $2,
                gpu_type = $3,
                config_json = $4,
                started_at = now(),
                last_heartbeat_at = now(),
                jobs_claimed = 0,
                jobs_completed = 0,
                jobs_failed = 0,
                rows_processed = 0,
                rows_skipped = 0,
                rows_gemini = 0,
                requests_sent = 0,
                requests_succeeded = 0,
                requests_failed = 0,
                total_cache_hits = 0,
                packs_uploaded = 0,
                current_shard_id = NULL,
                rows_remaining = 0,
                active_rpm = 0,
                total_input_tokens = 0,
                total_output_tokens = 0,
                total_cached_tokens = 0,
                last_error = NULL
            """,
            worker_id,
            provider,
            gpu_type,
            config_json,
        )

    async def update_heartbeat(self, worker_id: str, stats: VariantWorkerStats):
        """Publish absolute worker stats and refresh the heartbeat time."""
        await self._exec(
            """
            UPDATE transcript_variant_workers
            SET jobs_claimed = $2,
                jobs_completed = $3,
                jobs_failed = $4,
                rows_processed = $5,
                rows_skipped = $6,
                rows_gemini = $7,
                requests_sent = $8,
                requests_succeeded = $9,
                requests_failed = $10,
                total_cache_hits = $11,
                packs_uploaded = $12,
                current_shard_id = $13,
                rows_remaining = $14,
                active_rpm = $15,
                total_input_tokens = $16,
                total_output_tokens = $17,
                total_cached_tokens = $18,
                last_heartbeat_at = now()
            WHERE worker_id = $1
            """,
            worker_id,
            stats.jobs_claimed,
            stats.jobs_completed,
            stats.jobs_failed,
            stats.rows_processed,
            stats.rows_skipped,
            stats.rows_gemini,
            stats.requests_sent,
            stats.requests_succeeded,
            stats.requests_failed,
            stats.cache_hits,  # column is named total_cache_hits
            stats.packs_uploaded,
            stats.current_shard_id,
            stats.rows_remaining,
            stats.active_rpm,
            stats.total_input_tokens,
            stats.total_output_tokens,
            stats.total_cached_tokens,
        )

    async def set_worker_offline(self, worker_id: str):
        """Mark the worker offline (graceful shutdown path)."""
        await self._exec(
            """
            UPDATE transcript_variant_workers
            SET status = 'offline', last_heartbeat_at = now()
            WHERE worker_id = $1
            """,
            worker_id,
        )

    async def set_worker_error(self, worker_id: str, error_message: str):
        """Mark the worker errored and store the truncated error text."""
        await self._exec(
            """
            UPDATE transcript_variant_workers
            SET status = 'error',
                last_error = $2,
                last_heartbeat_at = now()
            WHERE worker_id = $1
            """,
            worker_id,
            (error_message or "")[:1000],  # None-safe truncation
        )

    async def insert_pack_manifest(self, manifest: dict[str, Any]):
        """Insert one pack manifest row; duplicate pack_ids are ignored.

        Raises:
            ValueError: if a dict key is not a safe SQL identifier.
        """
        columns = self._safe_identifiers(list(manifest.keys()))
        placeholders = ", ".join(f"${i + 1}" for i in range(len(columns)))
        col_names = ", ".join(columns)
        sql = (
            f"INSERT INTO transcript_variant_pack_manifests ({col_names}) VALUES ({placeholders}) "
            "ON CONFLICT (pack_id) DO NOTHING"
        )
        await self._exec(sql, *[manifest.get(col) for col in columns])

    async def _exec(self, sql: str, *args):
        """Run one statement on a pooled connection; return the status tag.

        Raises:
            RuntimeError: if connect() has not been awaited yet (clearer
                than the AttributeError a None pool would produce).
        """
        if self._pool is None:
            raise RuntimeError("VariantPostgresDB.connect() has not been called")
        async with self._pool.acquire() as conn:
            return await conn.execute(sql, *args)