"""Supabase orchestration client: video lifecycle, worker heartbeats, shard tracking.

Single source of truth for which worker owns which video, encoding status,
and analytics. Uses the admin key for REST and direct Postgres for DDL.
"""

from __future__ import annotations

import logging
import os
import time
from datetime import datetime, timezone
from typing import Any
from urllib.parse import urlparse

import psycopg2
from supabase import create_client, Client

from codecbench.pipeline.config import SupabaseConfig

logger = logging.getLogger(__name__)


# ── Video lifecycle states ──────────────────────────────────────────────
# PENDING → CLAIMED → DOWNLOADING → PROCESSING → ENCODED → PACKED → DONE
#   └─ FAILED (from any state, with error detail)
#   └─ TIMEOUT (a CLAIMED/DOWNLOADING/PROCESSING video whose row has not been
#      updated within the claim timeout; the release_stale_claims RPC marks such
#      rows TIMEOUT and resets them to PENDING in the same call, so the state is
#      normally not observed from outside)
VIDEO_STATES = [
    "PENDING", "CLAIMED", "DOWNLOADING", "PROCESSING",
    "ENCODED", "PACKED", "DONE", "FAILED", "TIMEOUT",
]


def _get_pg_conn():
    """Open a new psycopg2 connection from the ``DATABASE_URL`` environment variable.

    Used for statements the REST client cannot issue (table, index and
    function DDL). The caller is responsible for closing the connection.

    Raises:
        RuntimeError: if ``DATABASE_URL`` is unset or empty.
    """
    dsn = os.environ.get("DATABASE_URL")
    if not dsn:
        raise RuntimeError("DATABASE_URL not set — needed for table creation")
    return psycopg2.connect(dsn)


class SupabaseOrchestrator:
    """Video/worker/shard bookkeeping backed by Supabase.

    Row-level reads and writes go through the Supabase REST client created
    with the admin key. DDL (tables, indexes, RPC functions) runs over a
    direct Postgres connection obtained from ``_get_pg_conn``.
    """

    # Error text is truncated to this many characters before being stored.
    _MAX_ERROR_LEN = 2000

    def __init__(self, cfg: SupabaseConfig):
        # cfg provides url, admin_key, the three table names and claim_timeout_s.
        self._cfg = cfg
        self._client: Client = create_client(cfg.url, cfg.admin_key)
        # Set by ensure_tables() so repeated calls skip the DDL round-trips.
        self._tables_ready = False

    @staticmethod
    def _utc_now_iso() -> str:
        """Return the current UTC time as an ISO-8601 string for timestamptz columns."""
        return datetime.now(timezone.utc).isoformat()

    def _exec_sql(self, sql: str) -> None:
        """Run ``sql`` on a fresh autocommit Postgres connection, then close it."""
        conn = _get_pg_conn()
        try:
            # Autocommit so the DDL is committed without an explicit COMMIT.
            conn.autocommit = True
            with conn.cursor() as cur:
                cur.execute(sql)
        finally:
            conn.close()

    # ── Schema management ───────────────────────────────────────────────

    def ensure_tables(self) -> None:
        """Create the videos, workers and shards tables (and indexes) if missing.

        Uses ``CREATE ... IF NOT EXISTS`` so it is safe on an existing schema.
        The RPC functions are created separately by ``create_claim_rpc`` and
        ``create_release_stale_rpc``.
        """
        if self._tables_ready:
            return

        self._exec_sql(f"""
        CREATE TABLE IF NOT EXISTS {self._cfg.videos_table} (
            video_id text PRIMARY KEY,
            title text,
            duration_min real,
            classification text,
            channel text,
            language text,
            source_bucket text,
            status text DEFAULT 'PENDING',
            claimed_by text,
            claimed_at timestamptz,
            audio_duration_s real,
            num_segments integer,
            usable_audio_s real,
            codecs_used text[],
            shard_id text,
            error_detail text,
            started_at timestamptz,
            finished_at timestamptz,
            updated_at timestamptz DEFAULT now()
        );

        CREATE INDEX IF NOT EXISTS idx_{self._cfg.videos_table}_status
            ON {self._cfg.videos_table} (status);
        CREATE INDEX IF NOT EXISTS idx_{self._cfg.videos_table}_language
            ON {self._cfg.videos_table} (language);
        CREATE INDEX IF NOT EXISTS idx_{self._cfg.videos_table}_claimed
            ON {self._cfg.videos_table} (claimed_by);
        CREATE INDEX IF NOT EXISTS idx_{self._cfg.videos_table}_shard
            ON {self._cfg.videos_table} (shard_id);
        """)

        self._exec_sql(f"""
        CREATE TABLE IF NOT EXISTS {self._cfg.workers_table} (
            worker_id text PRIMARY KEY,
            offer_id text,
            gpu_name text,
            ip_address text,
            status text DEFAULT 'STARTING',
            current_video_id text,
            current_stage text,
            -- Performance
            rtf real,
            avg_encode_rtf real,
            videos_per_hour real,
            -- Cumulative progress
            total_audio_processed_s real DEFAULT 0,
            total_videos_done integer DEFAULT 0,
            total_videos_failed integer DEFAULT 0,
            total_shards_produced integer DEFAULT 0,
            shard_buffer_count integer DEFAULT 0,
            -- Lifecycle
            last_heartbeat timestamptz DEFAULT now(),
            started_at timestamptz DEFAULT now(),
            stopped_at timestamptz,
            stop_reason text,
            uptime_s real DEFAULT 0,
            -- Errors
            error_detail text,
            error_count integer DEFAULT 0
        );

        CREATE INDEX IF NOT EXISTS idx_{self._cfg.workers_table}_status
            ON {self._cfg.workers_table} (status);
        """)

        self._exec_sql(f"""
        CREATE TABLE IF NOT EXISTS {self._cfg.shards_table} (
            shard_id text PRIMARY KEY,
            worker_id text,
            video_ids text[],
            languages text[],
            total_segments integer,
            total_audio_s real,
            codecs text[],
            r2_key text,
            size_bytes bigint,
            created_at timestamptz DEFAULT now()
        );
        """)

        self._tables_ready = True
        logger.info("Supabase tables verified/created")

    # ── Bulk video ingestion ────────────────────────────────────────────

    def ingest_videos(self, rows: list[dict[str, Any]], batch_size: int = 500) -> int:
        """Bulk upsert video metadata rows into the videos table.

        Rows are sent in chunks of ``batch_size`` and upserted on ``video_id``,
        so an existing row with the same id is updated with the supplied
        columns instead of causing a conflict error.

        Returns:
            Number of rows sent (``len(rows)``).

        Raises:
            ValueError: if ``batch_size`` is not positive (a negative value
                previously produced an empty loop and silently ingested nothing).
        """
        if batch_size <= 0:
            raise ValueError(f"batch_size must be positive, got {batch_size}")

        log_every = 5000
        table = self._cfg.videos_table
        total = 0
        for start in range(0, len(rows), batch_size):
            batch = rows[start : start + batch_size]
            self._client.table(table).upsert(batch, on_conflict="video_id").execute()
            previous = total
            total += len(batch)
            # Log whenever a 5000-row boundary is crossed; `total % 5000 == 0`
            # only ever fired for batch sizes that evenly divide 5000.
            if total // log_every > previous // log_every:
                logger.info("Ingested %d / %d videos", total, len(rows))
        logger.info("Ingested %d videos total", total)
        return total

    # ── Video claiming (atomic) ─────────────────────────────────────────

    def claim_video(self, worker_id: str, language: str | None = None) -> dict | None:
        """Atomically claim the next PENDING video for this worker.

        Calls the ``claim_next_video`` RPC (see ``create_claim_rpc``). If the
        RPC call itself fails, falls back to ``_claim_video_fallback``.

        Args:
            worker_id: Written to the video's ``claimed_by`` column.
            language: If given, only videos with this ``language`` are considered.

        Returns:
            The claimed video row as a dict, or None if nothing is available.
        """
        try:
            result = self._client.rpc("claim_next_video", {
                "p_worker_id": worker_id,
                "p_language": language,
            }).execute()
        except Exception as e:
            logger.error("claim_video RPC failed: %s", e)
            return self._claim_video_fallback(worker_id, language)

        # A SETOF-returning RPC normally yields a list; tolerate a bare object too.
        data = result.data
        video = data[0] if isinstance(data, list) and data else data
        if isinstance(video, dict) and video.get("video_id"):
            logger.info("Claimed video %s (%s)", video["video_id"], video.get("language"))
            return video
        return None

    def _claim_video_fallback(
        self,
        worker_id: str,
        language: str | None = None,
        max_attempts: int = 3,
    ) -> dict | None:
        """Claim a PENDING video without the RPC, using a select then a guarded update.

        The update only matches while the row is still PENDING. A worker that
        loses the race to another worker sees an empty update result. It then
        retries with a fresh candidate, up to ``max_attempts`` times, instead of
        processing a video someone else owns.

        Returns:
            The row as returned by the update (already CLAIMED), or None if no
            PENDING video exists or every attempt lost a race.
        """
        table = self._cfg.videos_table
        for _ in range(max(1, max_attempts)):
            query = self._client.table(table).select("*").eq("status", "PENDING")
            if language:
                query = query.eq("language", language)
            found = query.limit(1).execute()
            if not found.data:
                return None

            candidate = found.data[0]
            now = self._utc_now_iso()
            updated = (
                self._client.table(table)
                .update({
                    "status": "CLAIMED",
                    "claimed_by": worker_id,
                    "claimed_at": now,
                    "updated_at": now,
                })
                .eq("video_id", candidate["video_id"])
                .eq("status", "PENDING")
                .execute()
            )
            # With the client's default returning=representation, the updated
            # rows come back in .data. Empty means the PENDING guard failed
            # because another worker claimed the row first.
            if updated.data:
                return updated.data[0]
            logger.info("Fallback claim of %s lost a race; retrying", candidate["video_id"])
        return None

    # ── Video status updates ────────────────────────────────────────────

    def update_video_status(
        self,
        video_id: str,
        status: str,
        extra: dict[str, Any] | None = None,
    ) -> None:
        """Set a video's status and ``updated_at``, plus any columns in ``extra``.

        Keys in ``extra`` override ``status``/``updated_at`` if they collide.
        """
        payload: dict[str, Any] = {
            "status": status,
            "updated_at": self._utc_now_iso(),
        }
        if extra:
            payload.update(extra)
        self._client.table(self._cfg.videos_table).update(payload).eq(
            "video_id", video_id
        ).execute()

    def mark_video_done(self, video_id: str, shard_id: str, codecs: list[str]) -> None:
        """Set a video to DONE, recording its shard, the codecs used and the finish time."""
        self.update_video_status(video_id, "DONE", {
            "shard_id": shard_id,
            "codecs_used": codecs,
            "finished_at": self._utc_now_iso(),
        })

    def mark_video_failed(self, video_id: str, error: str) -> None:
        """Set a video to FAILED, storing the truncated error text and the finish time."""
        self.update_video_status(video_id, "FAILED", {
            # str() so an exception object passed by mistake is still stored.
            "error_detail": str(error)[: self._MAX_ERROR_LEN],
            "finished_at": self._utc_now_iso(),
        })

    # ── Worker lifecycle ────────────────────────────────────────────────

    def register_worker(self, worker_id: str, offer_id: str, gpu_name: str) -> None:
        """Upsert the worker row for a new run with status LOADING_MODELS.

        Counters are zeroed. The other per-run fields (last error, uptime,
        shard buffer, current video/stage) are also cleared, so a restarted
        worker does not show leftovers from its previous run.
        """
        now = self._utc_now_iso()
        ip = self._get_public_ip()
        self._client.table(self._cfg.workers_table).upsert({
            "worker_id": worker_id,
            "offer_id": offer_id,
            "gpu_name": gpu_name,
            "ip_address": ip,
            "status": "LOADING_MODELS",
            "last_heartbeat": now,
            "started_at": now,
            "stopped_at": None,
            "stop_reason": None,
            "current_video_id": None,
            "current_stage": None,
            "total_audio_processed_s": 0,
            "total_videos_done": 0,
            "total_videos_failed": 0,
            "total_shards_produced": 0,
            "shard_buffer_count": 0,
            "uptime_s": 0,
            "error_detail": None,
            "error_count": 0,
        }, on_conflict="worker_id").execute()
        logger.info("Registered worker %s (offer=%s, gpu=%s, ip=%s)", worker_id, offer_id, gpu_name, ip)

    def mark_worker_alive(self, worker_id: str) -> None:
        """Transition from LOADING_MODELS → ALIVE once codecs are loaded."""
        self._client.table(self._cfg.workers_table).update({
            "status": "ALIVE",
            "last_heartbeat": self._utc_now_iso(),
        }).eq("worker_id", worker_id).execute()

    def heartbeat(
        self,
        worker_id: str,
        current_video: str | None = None,
        current_stage: str | None = None,
        rtf: float | None = None,
        avg_encode_rtf: float | None = None,
        total_audio_s: float = 0,
        total_videos: int = 0,
        total_failed: int = 0,
        total_shards: int = 0,
        shard_buffer_count: int = 0,
        uptime_s: float = 0,
    ) -> None:
        """Write the worker's current progress and set its status to ALIVE.

        The ``total_*`` arguments are absolute running totals (written as-is,
        not added to the stored values). ``videos_per_hour`` is derived from
        ``total_videos`` and ``uptime_s``. The optional fields are only written
        when not None, so previously stored values are kept otherwise.
        """
        now = self._utc_now_iso()
        videos_per_hour = (total_videos / max(uptime_s, 1)) * 3600 if uptime_s > 0 else 0.0
        payload: dict[str, Any] = {
            "status": "ALIVE",
            "last_heartbeat": now,
            "total_audio_processed_s": total_audio_s,
            "total_videos_done": total_videos,
            "total_videos_failed": total_failed,
            "total_shards_produced": total_shards,
            "shard_buffer_count": shard_buffer_count,
            "videos_per_hour": round(videos_per_hour, 1),
            "uptime_s": round(uptime_s, 0),
        }
        if current_video is not None:
            payload["current_video_id"] = current_video
        if current_stage is not None:
            payload["current_stage"] = current_stage
        if rtf is not None:
            payload["rtf"] = round(rtf, 1)
        if avg_encode_rtf is not None:
            payload["avg_encode_rtf"] = round(avg_encode_rtf, 1)
        self._client.table(self._cfg.workers_table).update(payload).eq(
            "worker_id", worker_id
        ).execute()

    def report_worker_error(self, worker_id: str, error: str) -> None:
        """Store the latest error for a worker and bump its ``error_count``.

        The count is read and then written back rather than incremented in
        SQL, so concurrent reports for the same worker can lose an increment.
        """
        next_count = self._get_error_count(worker_id) + 1
        self._client.table(self._cfg.workers_table).update({
            "error_detail": str(error)[: self._MAX_ERROR_LEN],
            "error_count": next_count,
            "last_heartbeat": self._utc_now_iso(),
        }).eq("worker_id", worker_id).execute()

    def mark_worker_stopped(self, worker_id: str, reason: str = "graceful_shutdown") -> None:
        """Mark worker as cleanly stopped and clear its current video/stage."""
        now = self._utc_now_iso()
        self._client.table(self._cfg.workers_table).update({
            "status": "STOPPED",
            "stopped_at": now,
            "stop_reason": reason,
            "current_video_id": None,
            "current_stage": None,
            "last_heartbeat": now,
        }).eq("worker_id", worker_id).execute()

    def mark_worker_dead(self, worker_id: str) -> None:
        """Mark worker as dead (crash or forced kill)."""
        # One timestamp so stopped_at and last_heartbeat agree exactly.
        now = self._utc_now_iso()
        self._client.table(self._cfg.workers_table).update({
            "status": "DEAD",
            "stopped_at": now,
            "stop_reason": "crash_or_kill",
            "last_heartbeat": now,
        }).eq("worker_id", worker_id).execute()

    def _get_error_count(self, worker_id: str) -> int:
        """Return the stored ``error_count`` for a worker, or 0 if missing/unreadable."""
        try:
            result = (
                self._client.table(self._cfg.workers_table)
                .select("error_count")
                .eq("worker_id", worker_id)
                .limit(1)
                .execute()
            )
        except Exception as e:
            # Best effort: report_worker_error still records the error (count 1).
            logger.warning("Could not read error_count for worker %s: %s", worker_id, e)
            return 0
        if result.data:
            return result.data[0].get("error_count", 0) or 0
        return 0

    @staticmethod
    def _get_public_ip() -> str:
        """Best-effort lookup of this host's public IP via api.ipify.org; "unknown" on failure."""
        import urllib.request
        try:
            # Context manager so the HTTP response/socket is always closed.
            with urllib.request.urlopen("https://api.ipify.org", timeout=5) as resp:
                return resp.read().decode().strip()
        except Exception:
            return "unknown"

    # ── Shard tracking ──────────────────────────────────────────────────

    def register_shard(
        self,
        shard_id: str,
        worker_id: str,
        video_ids: list[str],
        languages: list[str],
        total_segments: int,
        total_audio_s: float,
        codecs: list[str],
        r2_key: str,
        size_bytes: int,
    ) -> None:
        """Upsert a shard record (keyed on ``shard_id``) describing an uploaded shard."""
        self._client.table(self._cfg.shards_table).upsert({
            "shard_id": shard_id,
            "worker_id": worker_id,
            "video_ids": video_ids,
            "languages": languages,
            "total_segments": total_segments,
            "total_audio_s": total_audio_s,
            "codecs": codecs,
            "r2_key": r2_key,
            "size_bytes": size_bytes,
        }, on_conflict="shard_id").execute()
        logger.info("Registered shard %s (%d videos, %.0f s audio)", shard_id, len(video_ids), total_audio_s)

    # ── Recovery: release stale claims ──────────────────────────────────

    def release_stale_claims(self, timeout_s: int | None = None) -> int:
        """Release videos claimed by dead/timed-out workers back to PENDING.

        Calls the ``release_stale_claims`` RPC (see ``create_release_stale_rpc``).
        ``timeout_s`` defaults to ``cfg.claim_timeout_s``.

        Returns:
            Number of rows released, or 0 if the RPC failed or returned a
            non-integer.
        """
        if timeout_s is None:
            timeout_s = self._cfg.claim_timeout_s
        try:
            result = self._client.rpc("release_stale_claims", {
                "p_timeout_s": timeout_s,
            }).execute()
        except Exception as e:
            logger.warning("release_stale_claims RPC failed: %s", e)
            return 0
        released = result.data if isinstance(result.data, int) else 0
        if released:
            logger.info("Released %d stale video claims", released)
        return released

    def recover_worker_videos(self, worker_id: str) -> list[str]:
        """Return IDs of videos this worker claimed that are still in progress.

        "In progress" means status CLAIMED, DOWNLOADING, PROCESSING or ENCODED.
        """
        result = (
            self._client.table(self._cfg.videos_table)
            .select("video_id, status")
            .eq("claimed_by", worker_id)
            .in_("status", ["CLAIMED", "DOWNLOADING", "PROCESSING", "ENCODED"])
            .execute()
        )
        video_ids = [r["video_id"] for r in (result.data or [])]
        if video_ids:
            logger.info("Recovery: found %d in-progress videos for worker %s", len(video_ids), worker_id)
        return video_ids

    # ── Stats ───────────────────────────────────────────────────────────

    def get_stats(self) -> dict[str, Any]:
        """Return a ``{status: row_count}`` map over ``VIDEO_STATES``.

        A status whose count query fails is reported as -1.
        """
        stats: dict[str, Any] = {}
        for status in VIDEO_STATES:
            try:
                result = (
                    self._client.table(self._cfg.videos_table)
                    .select("video_id", count="exact")
                    .eq("status", status)
                    .execute()
                )
            except Exception as e:
                logger.warning("Count query for status %s failed: %s", status, e)
                stats[status] = -1
                continue
            stats[status] = result.count or 0
        return stats

    def create_claim_rpc(self) -> None:
        """Create the atomic claim_next_video RPC function.

        Uses FOR UPDATE SKIP LOCKED to guarantee no two workers
        get the same video — essential for multi-worker deployments.
        The UPDATE ... RETURNING gives the caller the row as it is after
        the claim, including the new claimed_at/updated_at values.
        """
        table = self._cfg.videos_table
        sql = f"""
        CREATE OR REPLACE FUNCTION claim_next_video(p_worker_id text, p_language text DEFAULT NULL)
        RETURNS SETOF {table} AS $$
        DECLARE
            v_id text;
            v_row {table};
        BEGIN
            IF p_language IS NOT NULL THEN
                SELECT video_id INTO v_id FROM {table}
                WHERE status = 'PENDING' AND language = p_language
                LIMIT 1
                FOR UPDATE SKIP LOCKED;
            ELSE
                SELECT video_id INTO v_id FROM {table}
                WHERE status = 'PENDING'
                LIMIT 1
                FOR UPDATE SKIP LOCKED;
            END IF;

            -- FOUND is false when no unlocked PENDING row matched.
            IF NOT FOUND THEN
                RETURN;
            END IF;

            UPDATE {table} SET
                status = 'CLAIMED',
                claimed_by = p_worker_id,
                claimed_at = now(),
                updated_at = now()
            WHERE video_id = v_id
            RETURNING * INTO v_row;

            RETURN NEXT v_row;
        END;
        $$ LANGUAGE plpgsql;
        """
        try:
            self._exec_sql(sql)
            logger.info("Created claim_next_video RPC")
        except Exception as e:
            logger.warning("Could not create claim RPC (may already exist): %s", e)

    def create_release_stale_rpc(self) -> None:
        """Create RPC to release stale claims back to PENDING.

        Rows in CLAIMED/DOWNLOADING/PROCESSING whose ``updated_at`` is older
        than ``p_timeout_s`` seconds are marked TIMEOUT (and counted). Every
        TIMEOUT row is then reset to PENDING with ``claimed_by``/``claimed_at``
        cleared, all within the one function call.
        """
        table = self._cfg.videos_table
        sql = f"""
        CREATE OR REPLACE FUNCTION release_stale_claims(p_timeout_s integer DEFAULT 600)
        RETURNS integer AS $$
        DECLARE
            released integer;
        BEGIN
            WITH stale AS (
                SELECT video_id FROM {table}
                WHERE status IN ('CLAIMED', 'DOWNLOADING', 'PROCESSING')
                AND updated_at < now() - (p_timeout_s || ' seconds')::interval
            )
            UPDATE {table} SET
                status = 'TIMEOUT',
                updated_at = now()
            FROM stale WHERE {table}.video_id = stale.video_id;
            GET DIAGNOSTICS released = ROW_COUNT;

            UPDATE {table} SET status = 'PENDING', claimed_by = NULL, claimed_at = NULL
            WHERE status = 'TIMEOUT';

            RETURN released;
        END;
        $$ LANGUAGE plpgsql;
        """
        try:
            self._exec_sql(sql)
            logger.info("Created release_stale_claims RPC")
        except Exception as e:
            logger.warning("Could not create stale release RPC: %s", e)