#!/usr/bin/env python3
"""Sync R2 bucket contents → Supabase encoding_videos table.

Modes:
  (default)   Purge: delete every row not backed by an R2 file and insert any
              missing R2-backed rows, so the table reflects R2 truth only.
  --dry-run   Report what would happen, don't write.

Caches R2 listing to /tmp/pipeline/r2_cache.json (~35 min to build, reusable).
Use --refresh to force re-listing R2 buckets.

Usage:
  python scripts/sync_r2_to_supabase.py                       # purge mode, use cache if fresh
  python scripts/sync_r2_to_supabase.py --refresh              # force re-list R2
  python scripts/sync_r2_to_supabase.py --dry-run              # report only
  python scripts/sync_r2_to_supabase.py --bucket pt-english    # single bucket
"""

from __future__ import annotations

import argparse
import json
import logging
import re
import sys
import time
from concurrent.futures import ThreadPoolExecutor, as_completed
from pathlib import Path

sys.path.insert(0, str(Path(__file__).parent.parent))

from dotenv import load_dotenv
load_dotenv()

from codecbench.pipeline.r2_client import R2Client
from codecbench.pipeline.config import R2Config
from codecbench.pipeline.supabase_client import _get_pg_conn

logging.basicConfig(level=logging.INFO, format="%(asctime)s [%(levelname)s] %(message)s")
logger = logging.getLogger(__name__)

# YouTube video IDs: exactly 11 chars of [A-Za-z0-9_-]
_YT_ID = r"[A-Za-z0-9_-]{11}"
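# e.g. "AbC123xyZ_-" (illustrative ID, not a real video)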

KNOWN_INDIC_LANGS = {
    "telugu", "tamil", "hindi", "malayalam", "kannada", "marathi", "bengali",
    "gujarati", "punjabi", "assamese", "odia", "indic", "unk",
}


def _extract_english(key: str) -> tuple[str, str] | None:
    """Extract (video_id, 'english') from any pt-english or pretraindata-english key."""
    # english_pretrain/{vid}.ext  (dominant, ~400K)
    m = re.match(r"^english_pretrain/(.+?)\.(webm|m4a)$", key)
    if m:
        return (m.group(1), "english")
    # audio/{vid}.ext
    m = re.match(r"^audio/(.+?)\.(webm|m4a)$", key)
    if m:
        return (m.group(1), "english")
    # {vid}_english_pretrain.ext  (flat, ~10K)
    m = re.match(r"^(.+)_english_pretrain\.(webm|m4a)$", key)
    if m:
        return (m.group(1), "english")
    return None


def _norm_indic_lang(lang: str) -> str:
    return lang if lang in KNOWN_INDIC_LANGS else "indic"


def _extract_indic(key: str) -> tuple[str, str] | None:
    """Extract (video_id, language) from any pt-indic key."""
    # {language}/{vid}.ext  (directory: bengali/, telugu/, hindi/, indic/, etc.)
    m = re.match(r"^(\w+)/(.+?)\.(webm|m4a)$", key)
    if m:
        lang = _norm_indic_lang(m.group(1))
        return (m.group(2), lang)
    # Flat: {11-char-ytid}_{anything}_pretrain.ext
    # Use 11-char YouTube ID anchor to cleanly separate vid from description
    m = re.match(rf"^({_YT_ID})_(.+?)_pretrain\.(webm|m4a)$", key)
    if m:
        vid = m.group(1)
        desc = m.group(2)
        # Keep desc if it's a known language name (common case), else "indic"
        lang = _norm_indic_lang(desc)
        return (vid, lang)
    # Fallback: any key ending _pretrain.ext, extract first 11 chars as vid
    m = re.match(r"^(.+)_pretrain\.(webm|m4a)$", key)
    if m:
        prefix = m.group(1)
        # Take the first 11 chars as the vid if they look like a YouTube ID
        if len(prefix) >= 11 and re.match(_YT_ID, prefix):  # re.match anchors at start
            vid = prefix[:11]
            remainder = prefix[12:]  # skip the separating underscore, if present
            return (vid, _norm_indic_lang(remainder))
        return (prefix, "indic")
    return None


BUCKET_CONFIG = {
    "pretraindata-english": {"extractor": _extract_english, "source_bucket": "pretraindata-english"},
    "pt-english":           {"extractor": _extract_english, "source_bucket": "pt-english"},
    "pt-indic":             {"extractor": _extract_indic,   "source_bucket": "pt-indic"},
}


def list_bucket_video_ids(r2: R2Client, bucket: str) -> dict[str, dict]:
    """List ALL objects in a bucket, extract video IDs.

    Returns {video_id: {"key": r2_key, "size_bytes": int, "language": str, "source_bucket": str}}
    """
    cfg = BUCKET_CONFIG[bucket]
    extractor = cfg["extractor"]
    source_bucket = cfg["source_bucket"]

    videos = {}
    continuation_token = None
    page = 0
    unmatched = 0

    while True:
        kwargs = {"Bucket": bucket, "MaxKeys": 1000}
        if continuation_token:
            kwargs["ContinuationToken"] = continuation_token

        resp = r2._client.list_objects_v2(**kwargs)
        page += 1

        for obj in resp.get("Contents", []):
            key = obj["Key"]
            result = extractor(key)
            if result:
                vid, lang = result
                videos[vid] = {
                    "key": key,
                    "size_bytes": obj["Size"],
                    "language": lang,
                    "source_bucket": source_bucket,
                }
            else:
                unmatched += 1

        if page % 200 == 0:
            logger.info("  %s: page %d, %d video IDs, %d unmatched keys",
                        bucket, page, len(videos), unmatched)

        if not resp.get("IsTruncated", False):
            break
        continuation_token = resp.get("NextContinuationToken")

    logger.info("  %s: DONE — %d video IDs from %d pages (%d unmatched keys)",
                bucket, len(videos), page, unmatched)
    return videos


def get_supabase_video_ids(conn) -> dict[str, str]:
    """Fetch all (video_id, status) from encoding_videos. Returns {video_id: status}."""
    with conn.cursor() as cur:
        cur.execute("SELECT video_id, status FROM encoding_videos")
        rows = cur.fetchall()
    return {r[0]: r[1] for r in rows}


def sync_to_supabase(
    conn,
    r2_videos: dict[str, dict],
    supabase_videos: dict[str, str],
    bucket: str,
    dry_run: bool = False,
) -> dict:
    """Sync R2 video IDs into Supabase. Returns stats dict."""
    new_inserts = []
    no_file_fixes = []
    already_ok = 0
    already_done = 0

    for vid, info in r2_videos.items():
        status = supabase_videos.get(vid)
        if status is None:
            new_inserts.append((vid, info))
        elif status == "NO_FILE":
            no_file_fixes.append((vid, info))
        elif status in ("DONE", "PACKED", "ENCODED", "PROCESSING"):
            already_done += 1
        else:
            already_ok += 1

    stats = {
        "bucket": bucket,
        "r2_total": len(r2_videos),
        "new_inserts": len(new_inserts),
        "no_file_fixes": len(no_file_fixes),
        "already_ok": already_ok,
        "already_done": already_done,
    }

    logger.info(
        "%s: %d in R2 | %d new | %d NO_FILE→PENDING | %d already pending | %d done",
        bucket, stats["r2_total"], stats["new_inserts"],
        stats["no_file_fixes"], stats["already_ok"], stats["already_done"],
    )

    if dry_run:
        logger.info("  DRY RUN — no changes made")
        return stats

    batch_size = 500

    if new_inserts:
        logger.info("  Inserting %d new videos...", len(new_inserts))
        with conn.cursor() as cur:
            values = []
            for vid, info in new_inserts:
                values.append(cur.mogrify(
                    "(%s, %s, %s, 'PENDING')",
                    (vid, info["language"], info["source_bucket"]),
                ).decode())

                if len(values) >= batch_size:
                    sql = (
                        "INSERT INTO encoding_videos (video_id, language, source_bucket, status) "
                        "VALUES " + ", ".join(values) +
                        " ON CONFLICT (video_id) DO NOTHING"
                    )
                    cur.execute(sql)
                    conn.commit()
                    values = []

            if values:
                sql = (
                    "INSERT INTO encoding_videos (video_id, language, source_bucket, status) "
                    "VALUES " + ", ".join(values) +
                    " ON CONFLICT (video_id) DO NOTHING"
                )
                cur.execute(sql)
                conn.commit()

        logger.info("  Inserted %d new videos", len(new_inserts))

    if no_file_fixes:
        logger.info("  Fixing %d NO_FILE → PENDING...", len(no_file_fixes))
        with conn.cursor() as cur:
            fix_ids = [vid for vid, _ in no_file_fixes]
            for i in range(0, len(fix_ids), batch_size):
                batch = fix_ids[i:i+batch_size]
                cur.execute(
                    "UPDATE encoding_videos SET status = 'PENDING', updated_at = now() "
                    "WHERE video_id = ANY(%s) AND status = 'NO_FILE'",
                    (batch,),
                )
            conn.commit()
        logger.info("  Fixed %d NO_FILE → PENDING", len(no_file_fixes))

    return stats


CACHE_PATH = Path("/tmp/pipeline/r2_cache.json")


def save_r2_cache(all_r2: dict[str, dict[str, dict]]) -> None:
    """Save R2 listing to disk for reuse."""
    CACHE_PATH.parent.mkdir(parents=True, exist_ok=True)
    # Flatten to {video_id: {language, source_bucket}} — drop key/size to save space
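    # On-disk shape (illustrative values):
    #   {"ts": 1712345678.9, "count": 1,
    #    "videos": {"AbC123xyZ_-": {"language": "telugu", "source_bucket": "pt-indic"}}}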
    flat = {}
    for bucket, videos in all_r2.items():
        for vid, info in videos.items():
            flat[vid] = {"language": info["language"], "source_bucket": info["source_bucket"]}
    with open(CACHE_PATH, "w") as f:
        json.dump({"ts": time.time(), "count": len(flat), "videos": flat}, f)
    logger.info("Saved R2 cache: %d video IDs → %s", len(flat), CACHE_PATH)


def load_r2_cache(max_age_h: float = 24.0) -> dict[str, dict] | None:
    """Load cached R2 listing if fresh enough."""
    if not CACHE_PATH.exists():
        return None
    with open(CACHE_PATH) as f:
        data = json.load(f)
    age_h = (time.time() - data["ts"]) / 3600
    if age_h > max_age_h:
        logger.info("R2 cache too old (%.1f h), will re-list", age_h)
        return None
    logger.info("Loaded R2 cache: %d video IDs (%.1f h old)", data["count"], age_h)
    return data["videos"]


def list_all_r2(r2: R2Client, buckets: list[str], refresh: bool = False) -> dict[str, dict]:
    """List all R2 video IDs, with caching. Returns {video_id: {language, source_bucket}}."""
    if not refresh:
        cached = load_r2_cache()
        if cached:
            return cached

    logger.info("Listing R2 buckets: %s (this takes ~35 min for pt-indic)", buckets)
    all_r2: dict[str, dict[str, dict]] = {}
    with ThreadPoolExecutor(max_workers=3) as pool:
        futures = {pool.submit(list_bucket_video_ids, r2, b): b for b in buckets}
        for fut in as_completed(futures):
            bucket = futures[fut]
            try:
                all_r2[bucket] = fut.result()
            except Exception:
                # A missing/partial listing would make purge_supabase delete every
                # row from this bucket, so abort instead of continuing.
                logger.exception("Failed listing %s; aborting to avoid a mass delete", bucket)
                raise

    # Flatten and cache
    flat = {}
    for bucket, videos in all_r2.items():
        for vid, info in videos.items():
            flat[vid] = {"language": info["language"], "source_bucket": info["source_bucket"]}
    save_r2_cache(all_r2)
    return flat


def purge_supabase(r2_videos: dict[str, dict], dry_run: bool = False) -> None:
    """Make encoding_videos match R2 exactly: delete non-R2 rows, insert missing R2 rows."""
    conn = _get_pg_conn()
    conn.autocommit = True

    # 1. Load current Supabase state
    logger.info("Fetching Supabase state...")
    sb_rows = get_supabase_video_ids(conn)
    logger.info("Supabase: %d rows", len(sb_rows))

    # Partition
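    # e.g. (illustrative): R2={a,b,c}, SB={b,c,d} -> delete {d}, insert {a}, keep {b,c}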
    r2_set = set(r2_videos.keys())
    sb_set = set(sb_rows.keys())

    to_delete = sb_set - r2_set          # in Supabase but not in R2
    to_insert = r2_set - sb_set          # in R2 but not in Supabase
    already_ok = r2_set & sb_set         # in both

    done_kept = sum(1 for v in already_ok if sb_rows[v] == "DONE")
    pending_kept = sum(1 for v in already_ok if sb_rows[v] == "PENDING")

    logger.info("R2 total:      %10s", f"{len(r2_set):,}")
    logger.info("Delete (no R2): %9s", f"{len(to_delete):,}")
    logger.info("Insert (new):   %9s", f"{len(to_insert):,}")
    logger.info("Keep (overlap): %9s (DONE=%d, PENDING=%d)", f"{len(already_ok):,}", done_kept, pending_kept)

    if dry_run:
        logger.info("DRY RUN — no changes")
        conn.close()
        return

    # 2. Delete rows not backed by R2
    if to_delete:
        logger.info("Deleting %d rows without R2 files...", len(to_delete))
        batch = 1000
        deleted_total = 0
        to_del_list = list(to_delete)
        with conn.cursor() as cur:
            for i in range(0, len(to_del_list), batch):
                chunk = to_del_list[i:i+batch]
                cur.execute("DELETE FROM encoding_videos WHERE video_id = ANY(%s)", (chunk,))
                deleted_total += cur.rowcount
                # Log whenever a 50k boundary is crossed (rowcount can be < batch)
                if deleted_total // 50000 != (deleted_total - cur.rowcount) // 50000:
                    logger.info("  deleted %d / %d", deleted_total, len(to_delete))
        logger.info("Deleted %d rows", deleted_total)

    # 3. Insert R2 videos not yet in Supabase
    if to_insert:
        logger.info("Inserting %d new R2-backed videos...", len(to_insert))
        batch = 500
        inserted = 0
        with conn.cursor() as cur:
            values = []
            for vid in to_insert:
                info = r2_videos[vid]
                values.append(cur.mogrify(
                    "(%s, %s, %s, 'PENDING')",
                    (vid, info["language"], info["source_bucket"]),
                ).decode())
                if len(values) >= batch:
                    cur.execute(
                        "INSERT INTO encoding_videos (video_id, language, source_bucket, status) "
                        "VALUES " + ", ".join(values) +
                        " ON CONFLICT (video_id) DO NOTHING"
                    )
                    inserted += cur.rowcount  # rows actually inserted (conflicts skipped)
                    values = []
            if values:
                cur.execute(
                    "INSERT INTO encoding_videos (video_id, language, source_bucket, status) "
                    "VALUES " + ", ".join(values) +
                    " ON CONFLICT (video_id) DO NOTHING"
                )
                inserted += cur.rowcount  # rows actually inserted (conflicts skipped)
        logger.info("Inserted %d rows", inserted)

    # 4. Reset any stuck states back to PENDING
    with conn.cursor() as cur:
        cur.execute("""
            UPDATE encoding_videos SET status = 'PENDING', claimed_by = NULL, claimed_at = NULL
            WHERE status IN ('DOWNLOADING', 'CLAIMED', 'PROCESSING', 'ENCODED', 'FAILED', 'TIMEOUT', 'NO_FILE')
        """)
        reset = cur.rowcount
        if reset:
            logger.info("Reset %d stuck/failed rows → PENDING", reset)

    conn.close()

    # Final count
    conn2 = _get_pg_conn()
    with conn2.cursor() as cur:
        cur.execute("SELECT status, count(*) FROM encoding_videos GROUP BY status ORDER BY status")
        rows = cur.fetchall()
    conn2.close()

    print("\n" + "=" * 50)
    print("encoding_videos — FINAL STATE")
    print("=" * 50)
    total = 0
    for status, cnt in rows:
        print(f"  {status:<15} {cnt:>10,}")
        total += cnt
    print(f"  {'TOTAL':<15} {total:>10,}")
    print(f"\n  Every PENDING row has a verified R2 file.")
    print()


def main():
    parser = argparse.ArgumentParser(description="Sync R2 buckets → Supabase")
    parser.add_argument("--bucket", nargs="*", default=None,
                        help="Specific buckets (default: all)")
    parser.add_argument("--dry-run", action="store_true",
                        help="Report only, don't modify Supabase")
    parser.add_argument("--refresh", action="store_true",
                        help="Force re-listing R2 buckets (ignores cache)")
    args = parser.parse_args()

    buckets = args.bucket or list(BUCKET_CONFIG.keys())
    for b in buckets:
        if b not in BUCKET_CONFIG:
            logger.error("Unknown bucket: %s (known: %s)", b, list(BUCKET_CONFIG.keys()))
            sys.exit(1)

    r2 = R2Client(R2Config())
    r2_videos = list_all_r2(r2, buckets, refresh=args.refresh)
    purge_supabase(r2_videos, dry_run=args.dry_run)


if __name__ == "__main__":
    main()
