"""CLI for the encoding pipeline.

Commands:
  ingest        — Upload video metadata from CSVs to Supabase
  run           — Start the pipeline worker (pretraining)
  setup-db      — Create Supabase tables + RPC functions only
  stats         — Show pipeline progress stats
  bench         — Profile encoding on this GPU for optimal config
  sft-snapshot  — Snapshot R2 SFT shard paths into Supabase
  sft-run       — Start SFT shard processing worker
  sft-stats     — Show SFT processing stats
"""

from __future__ import annotations

import argparse
import csv
import gc
import logging
import os
import sys
import time
from pathlib import Path

import torch


def setup_logging(verbose: bool = False) -> None:
    """Configure root logging: DEBUG when *verbose*, INFO otherwise."""
    logging.basicConfig(
        level=logging.DEBUG if verbose else logging.INFO,
        format="%(asctime)s [%(levelname)s] %(name)s: %(message)s",
        datefmt="%Y-%m-%d %H:%M:%S",
    )


def cmd_ingest(args: argparse.Namespace) -> None:
    """Ingest video metadata from local CSV files into Supabase.

    Ensures tables and RPC functions exist, then reads each CSV in
    ``args.csv_files`` and upserts rows in batches of ``args.batch_size``.
    Missing files are logged and skipped rather than aborting the run.
    """
    from codecbench.pipeline.config import PipelineConfig
    from codecbench.pipeline.supabase_client import SupabaseOrchestrator

    cfg = PipelineConfig.from_env()
    orch = SupabaseOrchestrator(cfg.supabase)
    orch.ensure_tables()
    orch.create_claim_rpc()
    orch.create_release_stale_rpc()

    for csv_path in args.csv_files:
        csv_path = Path(csv_path)
        if not csv_path.exists():
            logging.error("CSV not found: %s", csv_path)
            continue

        logging.info("Reading %s...", csv_path)
        rows = []
        # newline="" is required by the csv module docs; explicit encoding
        # avoids platform-dependent defaults.
        with open(csv_path, "r", encoding="utf-8", newline="") as f:
            reader = csv.DictReader(f)
            for row in reader:
                lang = row.get("language", "unknown")
                bucket = "pt-english" if lang == "english" else "pt-indic"
                rows.append({
                    "video_id": row["video_id"],
                    "title": row.get("title", ""),
                    # "or 0" guards against an empty-string cell, which
                    # would make float() raise ValueError.
                    "duration_min": float(row.get("duration_min") or 0),
                    "classification": row.get("classification", ""),
                    "channel": row.get("channel", ""),
                    "language": lang,
                    "source_bucket": bucket,
                    "status": "PENDING",
                })

        logging.info("Ingesting %d videos from %s", len(rows), csv_path.name)
        orch.ingest_videos(rows, batch_size=args.batch_size)
        logging.info("Done: %d videos ingested from %s", len(rows), csv_path.name)


def cmd_sync(args: argparse.Namespace) -> None:
    """Pull CSVs from the R2 metafiles bucket and ingest them into Supabase.

    With ``--list``, only prints the bucket contents. Otherwise downloads
    each requested key, parses it, de-duplicates by ``video_id`` (last row
    wins) and upserts the rows in batches of ``args.batch_size``.
    """
    from codecbench.pipeline.config import PipelineConfig
    from codecbench.pipeline.supabase_client import SupabaseOrchestrator
    from codecbench.pipeline.r2_client import R2Client
    import io

    cfg = PipelineConfig.from_env()
    r2 = R2Client(cfg.r2)
    orch = SupabaseOrchestrator(cfg.supabase)
    orch.ensure_tables()
    orch.create_claim_rpc()
    orch.create_release_stale_rpc()

    bucket = cfg.r2.metafiles_bucket

    if args.list_csvs:
        # NOTE(review): reaches into R2Client's private boto3 handle; a
        # public list method on R2Client would be cleaner.
        resp = r2._client.list_objects_v2(Bucket=bucket, MaxKeys=100)
        for obj in resp.get("Contents", []):
            print(f"  {obj['Key']:50s}  {obj['Size']/1e6:.1f} MB")
        return

    csv_keys = args.csv_keys
    if not csv_keys:
        logging.error("No CSV keys specified. Use --list to see available files.")
        sys.exit(1)

    total_ingested = 0
    for key in csv_keys:
        logging.info("Downloading %s/%s ...", bucket, key)
        resp = r2._client.get_object(Bucket=bucket, Key=key)
        content = resp["Body"].read().decode("utf-8")
        reader = csv.DictReader(io.StringIO(content))

        rows = []
        for row in reader:
            lang = row.get("language", "unknown")
            bucket_name = "pt-english" if lang == "english" else "pt-indic"
            rows.append({
                "video_id": row["video_id"],
                "title": row.get("title", ""),
                # "or 0" guards against an empty-string cell, which would
                # make float() raise ValueError.
                "duration_min": float(row.get("duration_min") or 0),
                "classification": row.get("classification", ""),
                "channel": row.get("channel", ""),
                "language": lang,
                "source_bucket": bucket_name,
                "status": "PENDING",
            })

        # Deduplicate by video_id (dict insertion keeps the last occurrence)
        seen = {}
        for r in rows:
            seen[r["video_id"]] = r
        rows = list(seen.values())

        logging.info("Ingesting %d unique videos from %s", len(rows), key)
        count = orch.ingest_videos(rows, batch_size=args.batch_size)
        total_ingested += count
        logging.info("Done: %d videos from %s", count, key)

    logging.info("Total ingested: %d videos from %d CSVs", total_ingested, len(csv_keys))


def cmd_setup_db(args: argparse.Namespace) -> None:
    """Idempotently create the Supabase tables and RPC functions."""
    from codecbench.pipeline.config import PipelineConfig
    from codecbench.pipeline.supabase_client import SupabaseOrchestrator

    orchestrator = SupabaseOrchestrator(PipelineConfig.from_env().supabase)
    # Run each setup step in order; all are safe to re-run.
    for step in (
        orchestrator.ensure_tables,
        orchestrator.create_claim_rpc,
        orchestrator.create_release_stale_rpc,
    ):
        step()
    logging.info("Database setup complete")


def cmd_run(args: argparse.Namespace) -> None:
    """Start the pretraining pipeline worker.

    Builds the config from the environment, applies CLI overrides, then
    runs the worker until ``--max-videos`` (if given) is reached.
    ``--language`` is not stored on the config; it is passed straight
    through to ``worker.run()``.
    """
    from codecbench.pipeline.config import PipelineConfig
    from codecbench.pipeline.worker import PipelineWorker

    cfg = PipelineConfig.from_env()

    # CLI overrides (truthy checks: unset/0/empty keeps the config default).
    if args.batch_size:
        cfg.codec.xcodec_batch_size = args.batch_size
    if args.no_parallel:
        cfg.worker.parallel_encode = False
    if args.shard_count:
        cfg.worker.shard_pack_count = args.shard_count
    if args.prefetch:
        cfg.worker.prefetch_videos = args.prefetch
    if args.offer_id:
        cfg.worker.offer_id = args.offer_id
    if args.custom_ckpt:
        cfg.codec.xcodec2_custom_ckpt = args.custom_ckpt
    if args.tmp_dir:
        cfg.worker.local_tmp_dir = args.tmp_dir
    if args.oom_segment_threshold is not None:
        # Clamp to at least 1 so the threshold stays meaningful.
        cfg.worker.oom_segment_threshold = max(args.oom_segment_threshold, 1)
    if args.use_async:
        cfg.worker.use_async_pipeline = True
    if args.extract_workers:
        cfg.worker.extract_workers = args.extract_workers

    worker = PipelineWorker(cfg)
    worker.setup()
    worker.run(language=args.language, max_videos=args.max_videos)


def cmd_stats(args: argparse.Namespace) -> None:
    """Print a per-status breakdown of pipeline progress."""
    from codecbench.pipeline.config import PipelineConfig
    from codecbench.pipeline.supabase_client import SupabaseOrchestrator

    cfg = PipelineConfig.from_env()
    stats = SupabaseOrchestrator(cfg.supabase).get_stats()

    # Negative counts (if any sentinel values appear) are excluded from the total.
    total = sum(count for count in stats.values() if count >= 0)
    divider = "-" * 35
    print(f"\n{'Status':<15} {'Count':>10} {'%':>7}")
    print(divider)
    for status in sorted(stats):
        count = stats[status]
        share = 100 * count / max(total, 1)
        print(f"{status:<15} {count:>10,} {share:>6.1f}%")
    print(divider)
    print(f"{'TOTAL':<15} {total:>10,}")
    print()


def cmd_fleet(args: argparse.Namespace) -> None:
    """Show all workers and their current status, plus fleet totals and ETA.

    Reads the workers table, prints one row per worker with throughput
    metrics, then fleet-wide totals and (when any throughput is reported)
    an ETA for the remaining corpus.
    """
    from codecbench.pipeline.config import PipelineConfig
    from codecbench.pipeline.supabase_client import SupabaseOrchestrator

    # Corpus size used for the ETA estimate.
    # NOTE(review): hard-coded snapshot of the ingested-video count — update
    # (or query Supabase for the true total) if the corpus changes.
    TARGET_TOTAL_VIDEOS = 2_965_716

    cfg = PipelineConfig.from_env()
    orch = SupabaseOrchestrator(cfg.supabase)

    # NOTE(review): uses the orchestrator's private supabase client; a public
    # accessor on SupabaseOrchestrator would be cleaner.
    result = orch._client.table(cfg.supabase.workers_table).select("*").order(
        "status", desc=False
    ).execute()
    workers = result.data or []

    if not workers:
        print("\nNo workers registered.\n")
        return

    alive = [w for w in workers if w.get("status") == "ALIVE"]
    loading = [w for w in workers if w.get("status") == "LOADING_MODELS"]
    stopped = [w for w in workers if w.get("status") in ("STOPPED", "DEAD")]

    print(f"\n{'='*120}")
    print(f"  WORKER FLEET — {len(alive)} alive, {len(loading)} loading, {len(stopped)} stopped  |  {len(workers)} total")
    print(f"{'='*120}")
    print(f"  {'Worker ID':<40} {'Status':<12} {'GPU':<22} {'Videos':>7} {'Shards':>7} "
          f"{'RTF':>6} {'V/hr':>6} {'Audio':>8} {'Uptime':>8} {'Current'}")
    print(f"  {'-'*118}")

    for w in workers:
        # Every field is read defensively: rows may carry NULLs.
        wid = w.get("worker_id", "?")[:38]
        status = w.get("status", "?")
        gpu = (w.get("gpu_name") or "?")[:20]
        vids = w.get("total_videos_done") or 0
        shards = w.get("total_shards_produced") or 0
        rtf = w.get("avg_encode_rtf") or w.get("rtf") or 0
        vph = w.get("videos_per_hour") or 0
        audio_h = (w.get("total_audio_processed_s") or 0) / 3600
        uptime_h = (w.get("uptime_s") or 0) / 3600
        cur_vid = (w.get("current_video_id") or "—")[:11]
        stage = w.get("current_stage") or ""

        status_icon = {"ALIVE": "●", "LOADING_MODELS": "◐", "STOPPED": "○", "DEAD": "✗"}.get(status, "?")

        print(f"  {wid:<40} {status_icon} {status:<10} {gpu:<22} {vids:>7} {shards:>7} "
              f"{rtf:>5.0f}x {vph:>5.0f} {audio_h:>7.1f}h {uptime_h:>7.1f}h {cur_vid} {stage}")

    # Fleet-wide totals; videos-per-hour is summed over ALIVE workers only.
    total_vids = sum(w.get("total_videos_done", 0) or 0 for w in workers)
    total_shards = sum(w.get("total_shards_produced", 0) or 0 for w in workers)
    total_audio_h = sum((w.get("total_audio_processed_s", 0) or 0) for w in workers) / 3600
    total_vph = sum((w.get("videos_per_hour", 0) or 0) for w in alive)
    print(f"  {'-'*118}")
    print(f"  {'TOTAL':<40} {'':12} {'':<22} {total_vids:>7} {total_shards:>7} "
          f"{'':>6} {total_vph:>5.0f} {total_audio_h:>7.1f}h")

    if total_vph > 0:
        remaining = TARGET_TOTAL_VIDEOS - total_vids
        eta_h = remaining / total_vph
        print(f"\n  ETA at current rate: {eta_h:,.0f}h ({eta_h/24:,.1f} days) for {remaining:,} remaining videos")
    print()


def cmd_bench(args: argparse.Namespace) -> None:
    """Profile XCodec2 encoding to find the optimal batch size.

    Sweeps batch sizes 1/2/4/8 over synthetic 6-second segments and reports
    ms/segment, real-time factor (RTF) and VRAM use for each, then prints
    the best configuration. Fix vs. the original: all ``torch.cuda`` calls
    are guarded so the benchmark also runs (slowly) on CPU-only machines
    instead of crashing on ``torch.cuda.synchronize()``.
    """
    from codecbench.pipeline.config import PipelineConfig
    from codecbench.pipeline.encoder import HotEncoder
    from codecbench.pipeline.vad import Segment

    cfg = PipelineConfig.from_env()
    if args.custom_ckpt:
        cfg.codec.xcodec2_custom_ckpt = args.custom_ckpt

    has_cuda = torch.cuda.is_available()
    device = "cuda" if has_cuda else "cpu"
    gpu_name = torch.cuda.get_device_name(0) if has_cuda else "CPU"
    vram_total = torch.cuda.get_device_properties(0).total_memory / 1e9 if has_cuda else 0

    print("\n=== GPU Encoding Benchmark (XCodec2 only) ===")
    print(f"GPU: {gpu_name}")
    print(f"VRAM: {vram_total:.1f} GB")
    print()

    results = []
    for xbs in [1, 2, 4, 8]:
        cfg.codec.xcodec_batch_size = xbs
        encoder = HotEncoder(cfg.codec, device=device)
        encoder.load()

        # Enough segments to fill at least two batches at this batch size.
        n_segments = max(xbs * 2, 8)
        segments = []
        for i in range(n_segments):
            # 96 000 random samples per segment; end_s=6.0 implies a 6 s
            # clip (presumably 16 kHz — TODO confirm against Segment/encoder).
            wav = torch.randn(1, 96_000)
            segments.append(Segment(start_s=0, end_s=6.0, audio=wav))

        # Warm-up pass so the timed run excludes startup/compile cost.
        encoder.encode_segments(segments[:2])
        if has_cuda:
            torch.cuda.synchronize()

        t0 = time.perf_counter()
        encoder.encode_segments(segments)
        if has_cuda:
            torch.cuda.synchronize()
        elapsed = time.perf_counter() - t0

        ms_per_seg = (elapsed * 1000) / len(segments)
        rtf = (6.0 * len(segments)) / elapsed
        vram_used = torch.cuda.memory_allocated() / 1e6 if has_cuda else 0.0

        results.append({
            "xcodec_bs": xbs,
            "ms_per_seg": ms_per_seg,
            "rtf": rtf,
            "vram_mb": vram_used,
        })

        print(f"XCodec BS={xbs}: {ms_per_seg:>7.1f} ms/seg, RTF={rtf:>6.1f}x, VRAM={vram_used:>6.0f} MB")

        # Tear down the encoder between configs so VRAM readings don't leak
        # across batch sizes.
        del encoder
        if has_cuda:
            torch.cuda.empty_cache()
        gc.collect()

    best = min(results, key=lambda r: r["ms_per_seg"])
    print(f"\n{'='*60}")
    print(f"Best config: XCodec BS={best['xcodec_bs']}")
    print(f"  {best['ms_per_seg']:.1f} ms/seg, RTF={best['rtf']:.1f}x, VRAM={best['vram_mb']:.0f} MB")
    print(f"  Estimated 1hr audio in: {3600/best['rtf']:.0f} s")
    print()


# ── SFT commands ─────────────────────────────────────────────────────

def cmd_sft_snapshot(args: argparse.Namespace) -> None:
    """Discover SFT shard prefixes in R2 and ingest them into Supabase."""
    import boto3
    from dotenv import load_dotenv
    from codecbench.pipeline.sft_supabase import SFTOrchestrator

    load_dotenv()
    s3 = boto3.client(
        "s3",
        endpoint_url=os.environ["R2_ENDPOINT_URL"],
        aws_access_key_id=os.environ["R2_ACCESS_KEY_ID"],
        aws_secret_access_key=os.environ["R2_SECRET_ACCESS_KEY"],
        region_name="auto",
    )
    bucket = os.environ.get("R2_BUCKET_DESTINATION", "finalsftdata")

    # Per-dataset listing prefixes, one (prefix, language) pair per language.
    datasets_config = {
        "hifitts2": [("hifitts2/lang=en/", "en")],
        "indicvoices-r": [
            (f"indicvoices-r/lang={lang}/", lang)
            for lang in ("as", "bn", "gu", "hi", "kn", "ml", "mr", "or", "pa", "ta", "te")
        ],
        "josh": [
            (f"josh/lang={lang}/", lang)
            for lang in ("bn", "en", "gu", "hi", "mr", "ta", "te")
        ],
        "joshdelivery": [
            (f"joshdelivery/lang={lang}/", lang)
            for lang in ("bn", "en", "gu", "hi", "te")
        ],
        "final-export": [
            (f"final-export/production/shards/lang={lang}/", lang)
            for lang in ("as", "bn", "en", "gu", "hi", "kn", "ml", "mr", "or", "pa", "ta", "te")
        ],
    }

    all_shards = []
    paginator = s3.get_paginator("list_objects_v2")
    for ds_name, prefixes in datasets_config.items():
        for prefix, lang in prefixes:
            # Delimiter="/" makes each shard directory appear as a CommonPrefix.
            for page in paginator.paginate(Bucket=bucket, Prefix=prefix, Delimiter="/"):
                for cp in page.get("CommonPrefixes", []):
                    all_shards.append({
                        "shard_key": cp["Prefix"],
                        "dataset": ds_name,
                        "language": lang,
                    })
            found = sum(
                1 for s in all_shards
                if s["dataset"] == ds_name and s["language"] == lang
            )
            logging.info("  %s %s: found %d shards so far", ds_name, lang, found)

    logging.info("Total shards discovered: %d", len(all_shards))

    orch = SFTOrchestrator()
    orch.ensure_tables()
    orch.create_claim_rpc()
    ingested = orch.ingest_shards(all_shards, batch_size=args.batch_size)
    logging.info("Ingested %d shards into Supabase", ingested)


def cmd_sft_reset(args: argparse.Namespace) -> None:
    """Move SFT shard rows back to PENDING so they can be reprocessed."""
    from codecbench.pipeline.sft_supabase import SFTOrchestrator

    # Parse the comma-separated --statuses list, dropping empty tokens.
    wanted = [token.strip().upper() for token in args.statuses.split(",") if token.strip()]
    reset_count = SFTOrchestrator().reset_pending(
        statuses=wanted,
        clear_outputs=not args.keep_outputs,
    )
    logging.info("Reset %d shard rows to PENDING", reset_count)


def cmd_sft_run(args: argparse.Namespace) -> None:
    """Configure and launch the SFT shard processing worker."""
    from codecbench.pipeline.config import PipelineConfig
    from codecbench.pipeline.sft_worker import SFTWorker

    cfg = PipelineConfig.from_env()
    # Layer CLI overrides on top of the env-derived config.
    if args.batch_size:
        cfg.codec.xcodec_batch_size = args.batch_size
    if args.offer_id:
        cfg.worker.offer_id = args.offer_id
    if args.custom_ckpt:
        cfg.codec.xcodec2_custom_ckpt = args.custom_ckpt

    sft_worker = SFTWorker(cfg)
    sft_worker.setup()
    sft_worker.run(
        max_shards=args.max_shards,
        dataset=args.dataset,
        language=args.language,
        # getattr guards against namespaces built without the --benchmark flag.
        benchmark=getattr(args, "benchmark", False),
    )


def cmd_sft_stats(args: argparse.Namespace) -> None:
    """Print per-dataset/per-status SFT shard counts and audio hours."""
    from codecbench.pipeline.sft_supabase import SFTOrchestrator

    rows = SFTOrchestrator().get_stats()

    rule = "-" * 55
    print(f"\n{'Dataset':<20} {'Status':<12} {'Count':>8} {'Audio (h)':>10}")
    print(rule)
    status_totals: dict = {}
    for row in rows:
        hours = (row["audio_s"] or 0) / 3600
        print(f"{row['dataset']:<20} {row['status']:<12} {row['cnt']:>8} {hours:>9.1f}")
        status_totals[row["status"]] = status_totals.get(row["status"], 0) + row["cnt"]
    print(rule)
    for status, count in sorted(status_totals.items()):
        print(f"{'TOTAL':<20} {status:<12} {count:>8}")
    print()


def main() -> None:
    """Build the argument parser, parse argv, and dispatch to a command.

    One subparser per command; dispatch happens through ``cmd_map`` at the
    bottom. (Fix vs. the original: the unused ``p_sft_stats`` local was
    dropped.)
    """
    parser = argparse.ArgumentParser(
        description="Audio codec encoding pipeline",
        formatter_class=argparse.RawDescriptionHelpFormatter,
    )
    parser.add_argument("-v", "--verbose", action="store_true")
    sub = parser.add_subparsers(dest="command", required=True)

    # ingest
    p_ingest = sub.add_parser("ingest", help="Upload video metadata from CSVs to Supabase")
    p_ingest.add_argument("csv_files", nargs="+", help="CSV file paths")
    p_ingest.add_argument("--batch-size", type=int, default=500)

    # sync
    p_sync = sub.add_parser("sync", help="Pull CSVs from R2 metafiles bucket → Supabase")
    p_sync.add_argument("csv_keys", nargs="*", help="R2 keys (e.g. indic_podcasts.csv english_podcasts.csv)")
    p_sync.add_argument("--list", dest="list_csvs", action="store_true",
                        help="List available CSVs in metafiles bucket")
    p_sync.add_argument("--batch-size", type=int, default=500)

    # setup-db
    sub.add_parser("setup-db", help="Create Supabase tables + RPC functions")

    # run
    p_run = sub.add_parser("run", help="Start pipeline worker")
    p_run.add_argument("--language", type=str, default=None,
                       help="Filter videos by language (None=any)")
    p_run.add_argument("--max-videos", type=int, default=None,
                       help="Stop after N videos")
    p_run.add_argument("--batch-size", type=int, default=None,
                       help="XCodec2 batch size override")
    p_run.add_argument("--no-parallel", action="store_true",
                       help="Disable parallel dual-stream encoding")
    p_run.add_argument("--shard-count", type=int, default=None,
                       help="Videos per shard pack")
    p_run.add_argument("--prefetch", type=int, default=None,
                       help="Number of videos to prefetch")
    p_run.add_argument("--offer-id", type=str, default=None,
                       help="Vast.ai offer ID for this instance")
    p_run.add_argument("--custom-ckpt", type=str, default=None,
                       help="Path to custom XCodec2 checkpoint")
    p_run.add_argument("--tmp-dir", type=str, default=None,
                       help="Local temp directory for downloads/processing")
    p_run.add_argument("--oom-segment-threshold", type=int, default=None,
                       help="Force safer OOM-resistant encode mode above this segment count")
    # "async" is a keyword, so the flag is stored as args.use_async.
    p_run.add_argument("--async", dest="use_async", action="store_true",
                       help="Use async 3-stage pipeline (overlapping download+extract+encode)")
    p_run.add_argument("--extract-workers", type=int, default=None,
                       help="Number of parallel ffmpeg+VAD workers for async mode")

    # stats
    sub.add_parser("stats", help="Show pipeline progress stats")

    # fleet
    sub.add_parser("fleet", help="Show worker fleet status and throughput")

    # bench
    p_bench = sub.add_parser("bench", help="Profile encoding for optimal config")
    p_bench.add_argument("--custom-ckpt", type=str, default=None)

    # ── SFT commands ──
    p_sft_snap = sub.add_parser("sft-snapshot", help="Snapshot R2 SFT shards → Supabase")
    p_sft_snap.add_argument("--batch-size", type=int, default=500)

    p_sft_reset = sub.add_parser("sft-reset", help="Reset SFT shard statuses back to PENDING")
    p_sft_reset.add_argument(
        "--statuses",
        type=str,
        default="CLAIMED,PROCESSING,DONE,FAILED",
        help="Comma-separated statuses to reset",
    )
    p_sft_reset.add_argument(
        "--keep-outputs",
        action="store_true",
        help="Keep output bookkeeping fields instead of clearing them",
    )

    p_sft_run = sub.add_parser("sft-run", help="Start SFT shard processing worker")
    p_sft_run.add_argument("--max-shards", type=int, default=None, help="Stop after N shards")
    p_sft_run.add_argument("--dataset", type=str, default=None, help="Filter by dataset name")
    p_sft_run.add_argument("--language", type=str, default=None, help="Filter by language")
    p_sft_run.add_argument("--batch-size", type=int, default=None, help="XCodec2 batch size")
    p_sft_run.add_argument("--offer-id", type=str, default=None, help="Vast.ai offer ID")
    p_sft_run.add_argument("--custom-ckpt", type=str, default=None, help="Custom XCodec2 checkpoint")
    p_sft_run.add_argument("--benchmark", action="store_true",
                           help="Benchmark mode: process 1 shard, confirm next download, report ETA")

    sub.add_parser("sft-stats", help="Show SFT processing stats")

    args = parser.parse_args()
    setup_logging(args.verbose)

    # Dispatch table: command name → handler function.
    cmd_map = {
        "ingest": cmd_ingest,
        "sync": cmd_sync,
        "setup-db": cmd_setup_db,
        "run": cmd_run,
        "stats": cmd_stats,
        "fleet": cmd_fleet,
        "bench": cmd_bench,
        "sft-snapshot": cmd_sft_snapshot,
        "sft-reset": cmd_sft_reset,
        "sft-run": cmd_sft_run,
        "sft-stats": cmd_sft_stats,
    }
    cmd_map[args.command](args)


if __name__ == "__main__":
    main()
