"""Supabase orchestrator for SFT shard encoding pipeline.

Table: sft_encoding_shards — one row per audio shard to process.
Uses the same atomic FOR UPDATE SKIP LOCKED claim pattern as the
pretraining pipeline to allow hundreds of GPU workers.
"""

from __future__ import annotations

import logging
import os
from datetime import datetime, timezone
from typing import Any

import psycopg2
from dotenv import load_dotenv

load_dotenv()

logger = logging.getLogger(__name__)

TABLE = "sft_encoding_shards"

SHARD_STATES = [
    "PENDING", "CLAIMED", "PROCESSING", "DONE", "FAILED",
]


def _get_pg_conn():
    db_url = os.getenv("DATABASE_URL", "")
    if not db_url:
        raise RuntimeError("DATABASE_URL not set")
    return psycopg2.connect(db_url)


class SFTOrchestrator:
    """Manages sft_encoding_shards table for distributed shard processing."""

    def _exec_sql(self, sql: str, params: tuple = ()) -> None:
        conn = _get_pg_conn()
        try:
            conn.autocommit = True
            with conn.cursor() as cur:
                cur.execute(sql, params)
        finally:
            conn.close()

    def _query(self, sql: str, params: tuple = ()) -> list[dict]:
        conn = _get_pg_conn()
        try:
            with conn.cursor() as cur:
                cur.execute(sql, params)
                cols = [d[0] for d in cur.description]
                return [dict(zip(cols, row)) for row in cur.fetchall()]
        finally:
            conn.close()

    def ensure_tables(self) -> None:
        self._exec_sql(f"""
        CREATE TABLE IF NOT EXISTS {TABLE} (
            shard_key text PRIMARY KEY,
            dataset text NOT NULL,
            language text NOT NULL,
            segment_count integer,
            status text DEFAULT 'PENDING',
            claimed_by text,
            claimed_at timestamptz,
            segments_encoded integer,
            segments_failed integer,
            total_audio_s real,
            total_encode_ms real,
            output_r2_key text,
            error_detail text,
            started_at timestamptz,
            finished_at timestamptz,
            updated_at timestamptz DEFAULT now()
        );

        CREATE INDEX IF NOT EXISTS idx_{TABLE}_status ON {TABLE} (status);
        CREATE INDEX IF NOT EXISTS idx_{TABLE}_dataset ON {TABLE} (dataset);
        CREATE INDEX IF NOT EXISTS idx_{TABLE}_language ON {TABLE} (language);
        """)
        logger.info("SFT table %s verified/created", TABLE)

    def create_claim_rpc(self) -> None:
        """Atomic claim using FOR UPDATE SKIP LOCKED."""
        self._exec_sql(f"""
        CREATE OR REPLACE FUNCTION claim_next_sft_shard(
            p_worker_id text,
            p_dataset text DEFAULT NULL,
            p_language text DEFAULT NULL
        )
        RETURNS SETOF {TABLE} AS $$
        DECLARE
            v_row {TABLE};
        BEGIN
            IF p_dataset IS NOT NULL AND p_language IS NOT NULL THEN
                SELECT * INTO v_row FROM {TABLE}
                WHERE status = 'PENDING' AND dataset = p_dataset AND language = p_language
                LIMIT 1 FOR UPDATE SKIP LOCKED;
            ELSIF p_dataset IS NOT NULL THEN
                SELECT * INTO v_row FROM {TABLE}
                WHERE status = 'PENDING' AND dataset = p_dataset
                LIMIT 1 FOR UPDATE SKIP LOCKED;
            ELSIF p_language IS NOT NULL THEN
                SELECT * INTO v_row FROM {TABLE}
                WHERE status = 'PENDING' AND language = p_language
                LIMIT 1 FOR UPDATE SKIP LOCKED;
            ELSE
                SELECT * INTO v_row FROM {TABLE}
                WHERE status = 'PENDING'
                LIMIT 1 FOR UPDATE SKIP LOCKED;
            END IF;

            IF NOT FOUND THEN RETURN; END IF;

            UPDATE {TABLE} SET
                status = 'CLAIMED',
                claimed_by = p_worker_id,
                claimed_at = now(),
                updated_at = now()
            WHERE shard_key = v_row.shard_key;

            v_row.status := 'CLAIMED';
            v_row.claimed_by := p_worker_id;
            v_row.claimed_at := now();
            v_row.updated_at := now();
            RETURN NEXT v_row;
        END;
        $$ LANGUAGE plpgsql;
        """)
        logger.info("Created claim_next_sft_shard RPC")

    def claim_shard(
        self, worker_id: str, dataset: str | None = None, language: str | None = None,
    ) -> dict | None:
        """Atomically claim the next PENDING shard."""
        conn = _get_pg_conn()
        try:
            conn.autocommit = False
            with conn.cursor() as cur:
                cur.execute(
                    "SELECT * FROM claim_next_sft_shard(%s, %s, %s)",
                    (worker_id, dataset, language),
                )
                cols = [d[0] for d in cur.description]
                row = cur.fetchone()
                conn.commit()
                if row:
                    shard = dict(zip(cols, row))
                    logger.info("Claimed shard %s (%s/%s)", shard["shard_key"], shard["dataset"], shard["language"])
                    return shard
                return None
        except Exception as e:
            conn.rollback()
            logger.error("claim_shard failed: %s", e)
            return None
        finally:
            conn.close()
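
    # Typical worker loop (a sketch; encode_shard and WORKER_ID are
    # hypothetical placeholders for the caller's own code):
    #
    #   orch = SFTOrchestrator()
    #   while (shard := orch.claim_shard(WORKER_ID)) is not None:
    #       orch.update_shard_status(shard["shard_key"], "PROCESSING")
    #       try:
    #           result = encode_shard(shard)  # returns an `extra` dict
    #           orch.update_shard_status(shard["shard_key"], "DONE", extra=result)
    #       except Exception as e:
    #           orch.update_shard_status(
    #               shard["shard_key"], "FAILED",
    #               extra={"error_detail": str(e)},
    #           )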

    # Columns callers may set through `extra`. Column names cannot be bound
    # as SQL parameters, so this whitelist guards the f-string interpolation
    # below against injection.
    _UPDATABLE_COLS = frozenset({
        "segments_encoded", "segments_failed", "total_audio_s",
        "total_encode_ms", "output_r2_key", "error_detail",
    })

    def update_shard_status(self, shard_key: str, status: str, extra: dict[str, Any] | None = None) -> None:
        if status not in SHARD_STATES:
            raise ValueError(f"unknown shard status: {status!r}")

        now = datetime.now(timezone.utc)
        sets = ["status = %s", "updated_at = %s"]
        vals: list[Any] = [status, now]

        if status == "PROCESSING":
            sets.append("started_at = %s")
            vals.append(now)
        elif status in ("DONE", "FAILED"):
            sets.append("finished_at = %s")
            vals.append(now)

        if extra:
            for k, v in extra.items():
                if k not in self._UPDATABLE_COLS:
                    raise ValueError(f"refusing to set unknown column: {k!r}")
                sets.append(f"{k} = %s")
                vals.append(v)

        vals.append(shard_key)
        self._exec_sql(
            f"UPDATE {TABLE} SET {', '.join(sets)} WHERE shard_key = %s",
            tuple(vals),
        )
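
    # Example success transition (a sketch; values are illustrative only):
    #
    #   orch.update_shard_status("cv17/en/shard-000123", "DONE", extra={
    #       "segments_encoded": 1021,
    #       "segments_failed": 3,
    #       "total_audio_s": 5481.2,
    #       "total_encode_ms": 92731.0,
    #       "output_r2_key": "sft/encoded/cv17/en/shard-000123.tar",
    #   })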

    def ingest_shards(self, rows: list[dict], batch_size: int = 500) -> int:
        """Bulk insert shard rows. Skip conflicts (already ingested)."""
        conn = _get_pg_conn()
        total = 0
        try:
            conn.autocommit = True
            with conn.cursor() as cur:
                for i in range(0, len(rows), batch_size):
                    batch = rows[i : i + batch_size]
                    # mogrify() pre-escapes each VALUES tuple so the whole
                    # batch ships as a single multi-row INSERT.
                    args = []
                    for r in batch:
                        args.append(cur.mogrify(
                            "(%s, %s, %s, %s, 'PENDING')",
                            (r["shard_key"], r["dataset"], r["language"],
                             r.get("segment_count")),
                        ).decode())
                    sql = (
                        f"INSERT INTO {TABLE} (shard_key, dataset, language, segment_count, status) "
                        f"VALUES {','.join(args)} ON CONFLICT (shard_key) DO NOTHING"
                    )
                    cur.execute(sql)
                    total += len(batch)
                    if total % 2000 == 0:
                        logger.info("Ingested %d / %d shards", total, len(rows))
        finally:
            conn.close()
        logger.info("Ingested %d shards total", total)
        return total
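
    # Expected row shape for ingestion (a sketch; shard_key and dataset
    # values are illustrative):
    #
    #   orch.ingest_shards([
    #       {"shard_key": "cv17/en/shard-000123", "dataset": "cv17",
    #        "language": "en", "segment_count": 1024},
    #   ])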

    def reset_pending(
        self,
        statuses: list[str] | None = None,
        clear_outputs: bool = True,
    ) -> int:
        """Reset selected statuses back to PENDING for a clean rerun.

        Keeps static shard metadata (dataset/language/segment_count) but clears
        worker/runtime fields so workers can reclaim from scratch.
        """
        if statuses is None:
            statuses = ["CLAIMED", "PROCESSING", "DONE", "FAILED"]

        sets = [
            "status = 'PENDING'",
            "claimed_by = NULL",
            "claimed_at = NULL",
            "started_at = NULL",
            "finished_at = NULL",
            "updated_at = now()",
            "error_detail = NULL",
        ]
        if clear_outputs:
            sets.extend([
                "segments_encoded = NULL",
                "segments_failed = NULL",
                "total_audio_s = NULL",
                "total_encode_ms = NULL",
                "output_r2_key = NULL",
            ])

        conn = _get_pg_conn()
        try:
            conn.autocommit = True
            with conn.cursor() as cur:
                cur.execute(
                    f"""
                    UPDATE {TABLE}
                    SET {', '.join(sets)}
                    WHERE status = ANY(%s)
                    """,
                    (statuses,),
                )
                count = cur.rowcount or 0
        finally:
            conn.close()

        logger.info("Reset %d shards back to PENDING (statuses=%s)", count, ",".join(statuses))
        return count
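
    # Requeue only failures while leaving DONE shards untouched (a sketch):
    #
    #   orch.reset_pending(statuses=["FAILED"])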

    def get_stats(self) -> list[dict[str, Any]]:
        """Per-(dataset, status) shard counts and summed audio seconds."""
        return self._query(f"""
            SELECT status, dataset, count(*) AS cnt,
                   coalesce(sum(total_audio_s), 0) AS audio_s
            FROM {TABLE}
            GROUP BY status, dataset
            ORDER BY dataset, status
        """)

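
if __name__ == "__main__":
    # Minimal smoke test (a sketch; assumes DATABASE_URL points at a
    # Postgres instance where creating tables and functions is safe).
    logging.basicConfig(level=logging.INFO)
    orch = SFTOrchestrator()
    orch.ensure_tables()
    orch.create_claim_rpc()
    for row in orch.get_stats():
        logger.info(
            "%s / %s: %s shards, %.1f s audio",
            row["dataset"], row["status"], row["cnt"], row["audio_s"],
        )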