"""Postgres-backed queue, lease, and bookkeeping layer for the final export pipeline."""

from __future__ import annotations

import asyncio
import json
import logging
import random
from dataclasses import dataclass, field
from typing import Any, Optional


logger = logging.getLogger(__name__)


@dataclass
class FinalExportVideoJob:
    video_id: str
    priority: int = 0
    metadata_json: dict[str, Any] = field(default_factory=dict)


@dataclass
class FinalExportMicroshardJob:
    microshard_id: str
    video_id: str
    language: str
    chunk_index: int
    row_count: int
    consumed_rows: int
    output_bucket: str
    metadata_key: str
    audio_key: str
    audio_index_key: str
    manifest_key: str
    metadata_json: dict[str, Any] = field(default_factory=dict)

    @property
    def remaining_rows(self) -> int:
        return max(self.row_count - self.consumed_rows, 0)


@dataclass
class FinalExportWorkerStats:
    jobs_claimed: int = 0
    jobs_completed: int = 0
    jobs_failed: int = 0
    rows_buffered: int = 0
    rows_uploaded: int = 0
    packs_uploaded: int = 0
    current_item: Optional[str] = None


class FinalExportPostgresDB:
    def __init__(self, database_url: str):
        self._dsn = database_url
        self._pool = None  # asyncpg.Pool, created in connect()

    async def connect(self):
        import asyncpg

        async def _init_conn(conn):
            await conn.set_type_codec(
                "jsonb", encoder=json.dumps, decoder=json.loads, schema="pg_catalog"
            )
            await conn.set_type_codec(
                "json", encoder=json.dumps, decoder=json.loads, schema="pg_catalog"
            )

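        # Port 6543 conventionally points at a transaction-mode pooler (e.g.
        # PgBouncer / Supabase); prepared statements cannot be cached safely
        # behind one, so the statement cache is disabled in that case.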
        is_pooler = ":6543" in self._dsn
        self._pool = await asyncpg.create_pool(
            dsn=self._dsn,
            min_size=1,
            max_size=8,
            command_timeout=60,
            statement_cache_size=0 if is_pooler else 100,
            ssl="require",
            init=_init_conn,
        )
        logger.info("final export asyncpg pool ready")

    async def close(self):
        if self._pool:
            await self._pool.close()
            logger.info("final export asyncpg pool closed")

    async def init_schema(self):
        ddl = """
        CREATE TABLE IF NOT EXISTS final_export_queue (
            video_id TEXT PRIMARY KEY,
            status TEXT NOT NULL DEFAULT 'pending',
            priority INTEGER NOT NULL DEFAULT 0,
            total_segments INTEGER NOT NULL DEFAULT 0,
            claimed_by TEXT,
            claimed_at TIMESTAMPTZ,
            spooled_at TIMESTAMPTZ,
            completed_at TIMESTAMPTZ,
            attempt_count INTEGER NOT NULL DEFAULT 0,
            error_message TEXT,
            metadata_json JSONB NOT NULL DEFAULT '{}'::jsonb,
            created_at TIMESTAMPTZ NOT NULL DEFAULT now(),
            updated_at TIMESTAMPTZ NOT NULL DEFAULT now()
        );

        CREATE INDEX IF NOT EXISTS idx_feq_claim
            ON final_export_queue (status, priority DESC, video_id);

        CREATE TABLE IF NOT EXISTS final_export_workers (
            worker_id TEXT PRIMARY KEY,
            stage TEXT NOT NULL,
            status TEXT NOT NULL DEFAULT 'online',
            gpu_type TEXT NOT NULL DEFAULT 'unknown',
            config_json JSONB NOT NULL DEFAULT '{}'::jsonb,
            started_at TIMESTAMPTZ NOT NULL DEFAULT now(),
            last_heartbeat_at TIMESTAMPTZ NOT NULL DEFAULT now(),
            jobs_claimed BIGINT NOT NULL DEFAULT 0,
            jobs_completed BIGINT NOT NULL DEFAULT 0,
            jobs_failed BIGINT NOT NULL DEFAULT 0,
            rows_buffered BIGINT NOT NULL DEFAULT 0,
            rows_uploaded BIGINT NOT NULL DEFAULT 0,
            packs_uploaded BIGINT NOT NULL DEFAULT 0,
            current_item TEXT,
            last_error TEXT
        );

        CREATE TABLE IF NOT EXISTS final_export_video_outputs (
            video_id TEXT PRIMARY KEY,
            run_id TEXT NOT NULL,
            status TEXT NOT NULL,
            raw_parent_count INTEGER NOT NULL DEFAULT 0,
            replay_valid_count INTEGER NOT NULL DEFAULT 0,
            kept_count INTEGER NOT NULL DEFAULT 0,
            dropped_count INTEGER NOT NULL DEFAULT 0,
            microshard_count INTEGER NOT NULL DEFAULT 0,
            total_flac_bytes BIGINT NOT NULL DEFAULT 0,
            drop_counts_json JSONB NOT NULL DEFAULT '{}'::jsonb,
            metadata_json JSONB NOT NULL DEFAULT '{}'::jsonb,
            created_at TIMESTAMPTZ NOT NULL DEFAULT now(),
            updated_at TIMESTAMPTZ NOT NULL DEFAULT now()
        );

        CREATE TABLE IF NOT EXISTS final_export_microshards (
            microshard_id TEXT PRIMARY KEY,
            run_id TEXT NOT NULL,
            video_id TEXT NOT NULL,
            language TEXT NOT NULL,
            chunk_index INTEGER NOT NULL DEFAULT 0,
            status TEXT NOT NULL DEFAULT 'pending',
            row_count INTEGER NOT NULL DEFAULT 0,
            consumed_rows INTEGER NOT NULL DEFAULT 0,
            output_bucket TEXT NOT NULL,
            metadata_key TEXT NOT NULL,
            audio_key TEXT NOT NULL,
            audio_index_key TEXT NOT NULL,
            manifest_key TEXT NOT NULL,
            metadata_size_bytes BIGINT NOT NULL DEFAULT 0,
            audio_size_bytes BIGINT NOT NULL DEFAULT 0,
            audio_index_size_bytes BIGINT NOT NULL DEFAULT 0,
            manifest_size_bytes BIGINT NOT NULL DEFAULT 0,
            metadata_sha256 TEXT NOT NULL DEFAULT '',
            audio_sha256 TEXT NOT NULL DEFAULT '',
            audio_index_sha256 TEXT NOT NULL DEFAULT '',
            segment_id_set_sha256 TEXT NOT NULL DEFAULT '',
            claimed_by TEXT,
            claimed_at TIMESTAMPTZ,
            compacted_at TIMESTAMPTZ,
            error_message TEXT,
            attempt_count INTEGER NOT NULL DEFAULT 0,
            metadata_json JSONB NOT NULL DEFAULT '{}'::jsonb,
            created_at TIMESTAMPTZ NOT NULL DEFAULT now(),
            updated_at TIMESTAMPTZ NOT NULL DEFAULT now()
        );

        CREATE INDEX IF NOT EXISTS idx_fem_claim
            ON final_export_microshards (status, language, created_at, microshard_id);

        CREATE TABLE IF NOT EXISTS final_export_shards (
            shard_id TEXT PRIMARY KEY,
            run_id TEXT NOT NULL,
            language TEXT NOT NULL,
            output_bucket TEXT NOT NULL,
            metadata_key TEXT NOT NULL,
            audio_key TEXT NOT NULL,
            audio_index_key TEXT NOT NULL,
            manifest_key TEXT NOT NULL,
            segment_count INTEGER NOT NULL DEFAULT 0,
            video_count INTEGER NOT NULL DEFAULT 0,
            source_microshard_count INTEGER NOT NULL DEFAULT 0,
            metadata_size_bytes BIGINT NOT NULL DEFAULT 0,
            audio_size_bytes BIGINT NOT NULL DEFAULT 0,
            audio_index_size_bytes BIGINT NOT NULL DEFAULT 0,
            manifest_size_bytes BIGINT NOT NULL DEFAULT 0,
            metadata_sha256 TEXT NOT NULL DEFAULT '',
            audio_sha256 TEXT NOT NULL DEFAULT '',
            audio_index_sha256 TEXT NOT NULL DEFAULT '',
            segment_id_set_sha256 TEXT NOT NULL DEFAULT '',
            metadata_json JSONB NOT NULL DEFAULT '{}'::jsonb,
            created_at TIMESTAMPTZ NOT NULL DEFAULT now()
        );

        CREATE INDEX IF NOT EXISTS idx_fes_language
            ON final_export_shards (language, created_at);

        CREATE TABLE IF NOT EXISTS final_export_language_leases (
            language TEXT PRIMARY KEY,
            claimed_by TEXT NOT NULL,
            claimed_at TIMESTAMPTZ NOT NULL DEFAULT now(),
            heartbeat_at TIMESTAMPTZ NOT NULL DEFAULT now(),
            lease_expires_at TIMESTAMPTZ NOT NULL,
            updated_at TIMESTAMPTZ NOT NULL DEFAULT now()
        );
        """
        async with self._pool.acquire() as conn:
            await conn.execute(ddl)
        logger.info("final export schema ready")

    async def seed_video_jobs(self, jobs: list[dict[str, Any]]):
        if not jobs:
            return
        columns = list(jobs[0].keys())
        placeholders = ", ".join(f"${i+1}" for i in range(len(columns)))
        col_names = ", ".join(columns)
        update_cols = ", ".join(
            f"{col}=EXCLUDED.{col}" for col in columns if col != "video_id"
        )
        sql = (
            f"INSERT INTO final_export_queue ({col_names}) VALUES ({placeholders}) "
            f"ON CONFLICT (video_id) DO UPDATE SET {update_cols}, updated_at = now()"
        )
        values = [tuple(job.get(col) for col in columns) for job in jobs]
        async with self._pool.acquire() as conn:
            await conn.executemany(sql, values)
        logger.info("seeded %s final export video jobs", len(jobs))

    async def claim_video_job(self, worker_id: str) -> Optional[FinalExportVideoJob]:
        for attempt in range(3):
            try:
                async with self._pool.acquire() as conn:
                    row = await conn.fetchrow(
                        """
                        UPDATE final_export_queue
                        SET status = 'claimed',
                            claimed_by = $1,
                            claimed_at = now(),
                            attempt_count = attempt_count + 1,
                            updated_at = now()
                        WHERE video_id = (
                            SELECT video_id
                            FROM final_export_queue
                            WHERE status = 'pending'
                            ORDER BY priority DESC, video_id
                            LIMIT 1
                            FOR UPDATE SKIP LOCKED
                        )
                        RETURNING video_id, priority, metadata_json
                        """,
                        worker_id,
                    )
                if row:
                    return FinalExportVideoJob(
                        video_id=row["video_id"],
                        priority=row["priority"] or 0,
                        metadata_json=row["metadata_json"] or {},
                    )
                return None
            except Exception as exc:
                if attempt < 2:
                    wait = (2 ** attempt) + random.uniform(0.5, 1.5)
                    logger.warning("final export claim failed: %s; retry %.1fs", str(exc)[:160], wait)
                    await asyncio.sleep(wait)
                else:
                    logger.error("final export claim failed after retries: %s", str(exc)[:160])
                    return None
        return None

    async def mark_video_processing(self, video_id: str, worker_id: str):
        await self._exec(
            """
            UPDATE final_export_queue
            SET status = 'processing',
                claimed_by = $2,
                updated_at = now()
            WHERE video_id = $1
            """,
            video_id,
            worker_id,
        )

    async def complete_video_spooled(self, video_id: str):
        await self._exec(
            """
            UPDATE final_export_queue
            SET status = 'spooled',
                spooled_at = now(),
                completed_at = now(),
                updated_at = now()
            WHERE video_id = $1
            """,
            video_id,
        )

    async def release_video_job(self, video_id: str):
        """Release a claimed/processing job back to pending (used by prefetch cancellation)."""
        await self._exec(
            """
            UPDATE final_export_queue
            SET status = 'pending',
                claimed_by = NULL,
                claimed_at = NULL,
                updated_at = now()
            WHERE video_id = $1
              AND status IN ('claimed', 'processing')
            """,
            video_id,
        )

    async def fail_video(self, video_id: str, error_message: str):
        await self._exec(
            """
            UPDATE final_export_queue
            SET status = 'failed',
                error_message = $2,
                updated_at = now()
            WHERE video_id = $1
            """,
            video_id,
            error_message[:1000],
        )

    async def insert_video_output(self, payload: dict[str, Any]):
        columns = list(payload.keys())
        placeholders = ", ".join(f"${i+1}" for i in range(len(columns)))
        col_names = ", ".join(columns)
        updates = ", ".join(
            f"{col}=EXCLUDED.{col}" for col in columns if col != "video_id"
        )
        sql = (
            f"INSERT INTO final_export_video_outputs ({col_names}) VALUES ({placeholders}) "
            f"ON CONFLICT (video_id) DO UPDATE SET {updates}, updated_at = now()"
        )
        await self._exec(sql, *[payload.get(col) for col in columns])

    async def insert_microshard(self, payload: dict[str, Any]):
        columns = list(payload.keys())
        placeholders = ", ".join(f"${i+1}" for i in range(len(columns)))
        col_names = ", ".join(columns)
        updates = ", ".join(
            f"{col}=EXCLUDED.{col}" for col in columns if col != "microshard_id"
        )
        sql = (
            f"INSERT INTO final_export_microshards ({col_names}) VALUES ({placeholders}) "
            f"ON CONFLICT (microshard_id) DO UPDATE SET {updates}, updated_at = now()"
        )
        await self._exec(sql, *[payload.get(col) for col in columns])

    async def acquire_language_lease(
        self,
        *,
        worker_id: str,
        run_id: str,
        lease_seconds: int,
        languages: list[str] | None = None,
    ) -> Optional[str]:
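        """Lease one language that still has unconsumed pending microshards for this run.

        Candidate languages are ordered by their oldest pending microshard;
        the first whose lease is free, expired, or already held by this worker
        is returned, else None.
        """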
        args: list[Any] = []
        language_filter_sql = "AND run_id = $1"
        args.append(run_id)
        if languages:
            language_filter_sql += " AND language = ANY($2::text[])"
            args.append(languages)
        async with self._pool.acquire() as conn:
            rows = await conn.fetch(
                f"""
                WITH candidates AS (
                    SELECT language, min(created_at) AS oldest_created_at
                    FROM final_export_microshards
                    WHERE status = 'pending'
                      AND consumed_rows < row_count
                      {language_filter_sql}
                    GROUP BY language
                    ORDER BY oldest_created_at, language
                    LIMIT 20
                )
                SELECT language
                FROM candidates
                """,
                *args,
            )
            if not rows:
                return None
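            # The conditional upsert below only succeeds when the existing
            # lease has expired or is already ours; otherwise fetchrow returns
            # None and we try the next candidate language.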
            for row in rows:
                language = row["language"]
                leased = await conn.fetchrow(
                    """
                    INSERT INTO final_export_language_leases (
                        language, claimed_by, claimed_at, heartbeat_at, lease_expires_at, updated_at
                    )
                    VALUES ($1, $2, now(), now(), now() + make_interval(secs => $3), now())
                    ON CONFLICT (language) DO UPDATE SET
                        claimed_by = EXCLUDED.claimed_by,
                        claimed_at = EXCLUDED.claimed_at,
                        heartbeat_at = EXCLUDED.heartbeat_at,
                        lease_expires_at = EXCLUDED.lease_expires_at,
                        updated_at = now()
                    WHERE final_export_language_leases.lease_expires_at < now()
                       OR final_export_language_leases.claimed_by = EXCLUDED.claimed_by
                    RETURNING language, claimed_by
                    """,
                    language,
                    worker_id,
                    lease_seconds,
                )
                if leased is not None:
                    return leased["language"]
            return None

    async def heartbeat_language_lease(self, language: str, worker_id: str, lease_seconds: int):
        await self._exec(
            """
            UPDATE final_export_language_leases
            SET heartbeat_at = now(),
                lease_expires_at = now() + make_interval(secs => $3),
                updated_at = now()
            WHERE language = $1
              AND claimed_by = $2
            """,
            language,
            worker_id,
            lease_seconds,
        )

    async def release_language_lease(self, language: str, worker_id: str):
        await self._exec(
            """
            DELETE FROM final_export_language_leases
            WHERE language = $1
              AND claimed_by = $2
            """,
            language,
            worker_id,
        )

    async def claim_microshards_for_language(
        self,
        *,
        worker_id: str,
        run_id: str,
        language: str,
        limit: int,
    ) -> list[FinalExportMicroshardJob]:
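        """Claim up to ``limit`` pending microshards for a leased language, oldest first."""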
        async with self._pool.acquire() as conn:
            rows = await conn.fetch(
                """
                UPDATE final_export_microshards
                SET status = 'claimed',
                    claimed_by = $1,
                    claimed_at = now(),
                    attempt_count = attempt_count + 1,
                    updated_at = now()
                WHERE microshard_id IN (
                    SELECT microshard_id
                    FROM final_export_microshards
                    WHERE status = 'pending'
                      AND run_id = $2
                      AND language = $3
                      AND consumed_rows < row_count
                    ORDER BY created_at, microshard_id
                    LIMIT $4
                    FOR UPDATE SKIP LOCKED
                )
                RETURNING microshard_id, video_id, language, chunk_index, row_count, consumed_rows,
                          output_bucket, metadata_key, audio_key, audio_index_key, manifest_key, metadata_json
                """,
                worker_id,
                run_id,
                language,
                limit,
            )
        return [
            FinalExportMicroshardJob(
                microshard_id=row["microshard_id"],
                video_id=row["video_id"],
                language=row["language"],
                chunk_index=row["chunk_index"] or 0,
                row_count=row["row_count"] or 0,
                consumed_rows=row["consumed_rows"] or 0,
                output_bucket=row["output_bucket"],
                metadata_key=row["metadata_key"],
                audio_key=row["audio_key"],
                audio_index_key=row["audio_index_key"],
                manifest_key=row["manifest_key"],
                metadata_json=row["metadata_json"] or {},
            )
            for row in rows
        ]

    async def release_microshards(self, microshard_ids: list[str], worker_id: str):
        if not microshard_ids:
            return
        await self._exec(
            """
            UPDATE final_export_microshards
            SET status = 'pending',
                claimed_by = NULL,
                claimed_at = NULL,
                updated_at = now()
            WHERE microshard_id = ANY($1::text[])
              AND claimed_by = $2
            """,
            microshard_ids,
            worker_id,
        )

    async def commit_microshard_consumption(
        self,
        *,
        worker_id: str,
        consumption: dict[str, int],
    ):
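        """Advance consumed_rows per the given ``{microshard_id: rows_used}`` map.

        Each row is locked FOR UPDATE inside a single transaction; a fully
        consumed microshard flips to 'compacted' and is released, otherwise it
        stays 'claimed' by this worker.
        """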
        if not consumption:
            return
        async with self._pool.acquire() as conn:
            async with conn.transaction():
                for microshard_id, rows_used in consumption.items():
                    row = await conn.fetchrow(
                        """
                        SELECT row_count, consumed_rows
                        FROM final_export_microshards
                        WHERE microshard_id = $1
                          AND claimed_by = $2
                        FOR UPDATE
                        """,
                        microshard_id,
                        worker_id,
                    )
                    if row is None:
                        raise RuntimeError(f"Microshard not claimed by worker: {microshard_id}")
                    row_count = int(row["row_count"] or 0)
                    consumed_rows = int(row["consumed_rows"] or 0)
                    next_consumed = min(consumed_rows + rows_used, row_count)
                    next_status = "compacted" if next_consumed >= row_count else "claimed"
                    await conn.execute(
                        """
                        UPDATE final_export_microshards
                        SET consumed_rows = $3,
                            status = $4,
                            compacted_at = CASE WHEN $4 = 'compacted' THEN now() ELSE compacted_at END,
                            claimed_by = CASE WHEN $4 = 'compacted' THEN NULL ELSE claimed_by END,
                            claimed_at = CASE WHEN $4 = 'compacted' THEN NULL ELSE claimed_at END,
                            updated_at = now()
                        WHERE microshard_id = $1
                          AND claimed_by = $2
                        """,
                        microshard_id,
                        worker_id,
                        next_consumed,
                        next_status,
                    )

    async def fail_microshard(self, microshard_id: str, error_message: str):
        await self._exec(
            """
            UPDATE final_export_microshards
            SET status = 'failed',
                error_message = $2,
                updated_at = now()
            WHERE microshard_id = $1
            """,
            microshard_id,
            error_message[:1000],
        )

    async def insert_final_shard(self, payload: dict[str, Any]):
        columns = list(payload.keys())
        placeholders = ", ".join(f"${i+1}" for i in range(len(columns)))
        col_names = ", ".join(columns)
        sql = (
            f"INSERT INTO final_export_shards ({col_names}) VALUES ({placeholders}) "
            "ON CONFLICT (shard_id) DO NOTHING"
        )
        await self._exec(sql, *[payload.get(col) for col in columns])

    async def register_worker(self, worker_id: str, stage: str, gpu_type: str, config_json: dict):
        await self._exec(
            """
            INSERT INTO final_export_workers (
                worker_id, stage, status, gpu_type, config_json, started_at, last_heartbeat_at
            ) VALUES ($1, $2, 'online', $3, $4, now(), now())
            ON CONFLICT (worker_id) DO UPDATE SET
                stage = $2,
                status = 'online',
                gpu_type = $3,
                config_json = $4,
                started_at = now(),
                last_heartbeat_at = now(),
                jobs_claimed = 0,
                jobs_completed = 0,
                jobs_failed = 0,
                rows_buffered = 0,
                rows_uploaded = 0,
                packs_uploaded = 0,
                current_item = NULL,
                last_error = NULL
            """,
            worker_id,
            stage,
            gpu_type,
            config_json,
        )

    async def update_heartbeat(self, worker_id: str, stats: FinalExportWorkerStats):
        await self._exec(
            """
            UPDATE final_export_workers
            SET jobs_claimed = $2,
                jobs_completed = $3,
                jobs_failed = $4,
                rows_buffered = $5,
                rows_uploaded = $6,
                packs_uploaded = $7,
                current_item = $8,
                last_heartbeat_at = now()
            WHERE worker_id = $1
            """,
            worker_id,
            stats.jobs_claimed,
            stats.jobs_completed,
            stats.jobs_failed,
            stats.rows_buffered,
            stats.rows_uploaded,
            stats.packs_uploaded,
            stats.current_item,
        )

    async def set_worker_offline(self, worker_id: str):
        await self._exec(
            """
            UPDATE final_export_workers
            SET status = 'offline', last_heartbeat_at = now()
            WHERE worker_id = $1
            """,
            worker_id,
        )

    async def set_worker_error(self, worker_id: str, error_message: str):
        await self._exec(
            """
            UPDATE final_export_workers
            SET status = 'error',
                last_error = $2,
                last_heartbeat_at = now()
            WHERE worker_id = $1
            """,
            worker_id,
            error_message[:1000],
        )

    async def reset_stale_claims(self, stale_after_s: int):
        await self._exec(
            """
            UPDATE final_export_queue
            SET status = 'pending',
                claimed_by = NULL,
                claimed_at = NULL,
                updated_at = now()
            WHERE status IN ('claimed', 'processing')
              AND claimed_at < now() - make_interval(secs => $1)
            """,
            stale_after_s,
        )
        await self._exec(
            """
            UPDATE final_export_microshards
            SET status = 'pending',
                claimed_by = NULL,
                claimed_at = NULL,
                updated_at = now()
            WHERE status = 'claimed'
              AND claimed_at < now() - make_interval(secs => $1)
            """,
            stale_after_s,
        )
        await self._exec(
            """
            DELETE FROM final_export_language_leases
            WHERE lease_expires_at < now()
            """,
        )

    async def _exec(self, sql: str, *args):
        if self._pool is None:
            raise RuntimeError("FinalExportPostgresDB.connect() must be called before use")
        async with self._pool.acquire() as conn:
            return await conn.execute(sql, *args)

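
# ---------------------------------------------------------------------------
# Minimal usage sketch (illustrative only, not part of the production workers).
# It assumes a reachable Postgres DSN; the "spool" stage name and the 5 s idle
# backoff are example values rather than settings defined in this module.
async def _example_spool_loop(database_url: str, worker_id: str) -> None:
    db = FinalExportPostgresDB(database_url)
    await db.connect()
    await db.init_schema()
    await db.register_worker(worker_id, stage="spool", gpu_type="unknown", config_json={})
    try:
        while True:
            job = await db.claim_video_job(worker_id)
            if job is None:
                await asyncio.sleep(5.0)  # queue drained; back off before polling again
                continue
            await db.mark_video_processing(job.video_id, worker_id)
            try:
                ...  # spool the video's segments here, then record completion:
                await db.complete_video_spooled(job.video_id)
            except Exception as exc:
                await db.fail_video(job.video_id, str(exc))
    finally:
        await db.set_worker_offline(worker_id)
        await db.close()


# A matching sketch for the compaction side: lease a language, drain its
# microshards, and commit consumption. The 300 s lease window and batch size
# of 16 are assumed example values; a real compactor would derive rows_used
# from what it actually packed rather than consuming every remaining row.
async def _example_compact_once(db: FinalExportPostgresDB, worker_id: str, run_id: str) -> None:
    language = await db.acquire_language_lease(
        worker_id=worker_id, run_id=run_id, lease_seconds=300
    )
    if language is None:
        return  # every eligible language is leased elsewhere
    try:
        shards = await db.claim_microshards_for_language(
            worker_id=worker_id, run_id=run_id, language=language, limit=16
        )
        ...  # pack rows from the claimed microshards, then:
        await db.commit_microshard_consumption(
            worker_id=worker_id,
            consumption={s.microshard_id: s.remaining_rows for s in shards},
        )
    finally:
        await db.release_language_lease(language, worker_id)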