#!/usr/bin/env python3
"""GPU Encoding Benchmark: detailed per-stage timing, VRAM, utilization metrics.

Runs N real videos through the full pipeline and captures:
  - Per-stage timing: download, extract, VAD, encode, pack, upload
  - GPU metrics: VRAM peak/avg, utilization%, idle time
  - CPU usage during each stage
  - Encoding RTF at different batch sizes and parallel modes
  - Per-codec breakdown (XCodec2 vs BiCodec timing)

Results written to stdout as a report + JSON for cross-GPU comparison.
"""

from __future__ import annotations

import gc
import json
import logging
import os
import subprocess
import sys
import time
from dataclasses import dataclass, field, asdict
from pathlib import Path
from threading import Thread, Event

import psutil
import torch

sys.path.insert(0, str(Path(__file__).resolve().parent.parent))

from codecbench.pipeline.config import PipelineConfig
from codecbench.pipeline.r2_client import R2Client, extract_audio_from_video
from codecbench.pipeline.vad import segment_audio, Segment
from codecbench.pipeline.encoder import HotEncoder, EncodedSegment
from codecbench.pipeline.shard_packer import ShardBuffer, VideoTokens, pack_shard
from codecbench.pipeline.supabase_client import SupabaseOrchestrator

# Configure root logging once at import time; silence chatty third-party
# loggers so the benchmark report stays readable on stdout.
logging.basicConfig(level=logging.INFO, format="%(asctime)s [%(levelname)s] %(name)s: %(message)s")
logging.getLogger("transformers").setLevel(logging.ERROR)
logging.getLogger("botocore").setLevel(logging.WARNING)
logging.getLogger("httpx").setLevel(logging.WARNING)
logger = logging.getLogger("benchmark")


@dataclass
class StageTimings:
    """Wall-clock seconds spent in each pipeline stage for one video."""

    download_s: float = 0.0
    extract_s: float = 0.0
    vad_s: float = 0.0
    encode_s: float = 0.0
    pack_s: float = 0.0
    upload_s: float = 0.0

    @property
    def total_s(self) -> float:
        """Sum of all six stage durations."""
        return (
            self.download_s + self.extract_s + self.vad_s
            + self.encode_s + self.pack_s + self.upload_s
        )

    @property
    def gpu_active_s(self) -> float:
        """Seconds the GPU was doing useful work (the encode stage only)."""
        return self.encode_s

    @property
    def cpu_active_s(self) -> float:
        """Seconds spent in CPU/IO-bound stages (everything except encode)."""
        return (
            self.download_s + self.extract_s + self.vad_s
            + self.pack_s + self.upload_s
        )

    @property
    def gpu_idle_pct(self) -> float:
        """Percentage of total time the GPU sat idle (0 when no time recorded)."""
        total = self.total_s
        if not total:
            return 0
        return 100 * self.cpu_active_s / total

@dataclass
class VideoResult:
    """Benchmark outcome for one processed video."""

    video_id: str          # source video identifier in R2 / Supabase
    language: str          # e.g. "english" / "indic"; drives bucket selection
    duration_min: float    # reported video length (0 when unknown, e.g. standalone mode)
    num_segments: int      # number of VAD segments that were encoded
    usable_audio_s: float  # total speech audio (seconds) kept after VAD
    timings: StageTimings  # per-stage wall-clock timings for this video
    encode_rtf: float = 0.0   # usable audio / GPU encode time (real-time factor)
    overall_rtf: float = 0.0  # usable audio / total pipeline time incl. I/O

@dataclass
class GPUMetrics:
    """Static GPU hardware/driver info plus observed VRAM usage."""

    name: str = ""             # device name from torch, or "CPU" when CUDA is absent
    vram_total_mb: float = 0   # total device memory (MB)
    vram_model_mb: float = 0   # VRAM allocated right after the encoder models load
    vram_peak_mb: float = 0    # maximum VRAM observed across all encodes
    tflops_fp32: float = 0     # rough estimate (assumes 128 FP32 cores/SM, Ampere/Ada)
    cuda_version: str = ""     # from torch.version.cuda; empty when unavailable
    driver_version: str = ""   # from nvidia-smi; empty when unavailable


@dataclass
class BenchmarkConfig:
    """Knobs controlling one benchmark run (mirrors the CLI flags in main())."""

    xcodec_batch_size: int = 2       # XCodec2 encode batch size
    parallel: bool = True            # run codecs in parallel inside the encoder
    num_videos: int = 10             # number of videos to process
    custom_ckpt: str | None = None   # optional custom XCodec2 checkpoint path
    standalone: bool = False         # list videos straight from R2 (no Supabase)
    standalone_bucket: str | None = None  # R2 bucket to list in standalone mode


@dataclass
class BenchmarkResult:
    """Aggregated benchmark output: GPU/system info plus per-video results."""

    gpu: GPUMetrics = field(default_factory=GPUMetrics)
    config: BenchmarkConfig = field(default_factory=BenchmarkConfig)
    videos: list[VideoResult] = field(default_factory=list)
    cpu_count: int = 0
    ram_gb: float = 0
    avg_cpu_pct: float = 0

    @staticmethod
    def _mean(values: list[float]) -> float:
        """Arithmetic mean of *values*, or 0 for an empty list."""
        if not values:
            return 0
        return sum(values) / len(values)

    @property
    def total_audio_s(self) -> float:
        """Total usable audio seconds across all benchmarked videos."""
        return sum(v.usable_audio_s for v in self.videos)

    @property
    def total_wall_s(self) -> float:
        """Sum of per-video stage totals (sequential wall time)."""
        return sum(v.timings.total_s for v in self.videos)

    @property
    def avg_encode_rtf(self) -> float:
        """Mean encode-only RTF over videos with a positive measurement."""
        return self._mean([v.encode_rtf for v in self.videos if v.encode_rtf > 0])

    @property
    def avg_overall_rtf(self) -> float:
        """Mean full-pipeline RTF over videos with a positive measurement."""
        return self._mean([v.overall_rtf for v in self.videos if v.overall_rtf > 0])

    @property
    def avg_gpu_idle_pct(self) -> float:
        """Mean per-video GPU idle percentage."""
        return self._mean([v.timings.gpu_idle_pct for v in self.videos])

    def stage_totals(self) -> StageTimings:
        """Sum every stage duration across all videos into one StageTimings."""
        stage_fields = ("download_s", "extract_s", "vad_s",
                        "encode_s", "pack_s", "upload_s")
        totals = StageTimings()
        for v in self.videos:
            for name in stage_fields:
                setattr(totals, name, getattr(totals, name) + getattr(v.timings, name))
        return totals


class GPUMonitor:
    """Background thread sampling GPU utilization + VRAM during encoding.

    Usage: ``start()`` spawns a daemon sampler thread; ``stop()`` joins it
    and returns aggregate statistics. Each sample records allocated/reserved
    VRAM (torch), GPU utilization (nvidia-smi) and instantaneous CPU usage.

    Fix over the previous version: ``start()`` now clears ``samples`` so a
    reused monitor does not mix statistics across start/stop cycles.
    """

    def __init__(self, interval_s: float = 0.5):
        self._interval = interval_s
        self._stop = Event()
        self._thread: Thread | None = None
        # Raw per-tick samples; aggregated by stop().
        self.samples: list[dict] = []

    def start(self):
        """Begin sampling in a background daemon thread.

        Resets any samples left over from a previous run first, so repeated
        start/stop cycles each report only their own window.
        """
        self.samples = []
        self._stop.clear()
        self._thread = Thread(target=self._run, daemon=True)
        self._thread.start()

    def _run(self):
        # Best-effort sampling loop: any failure (no CUDA device, missing
        # nvidia-smi, psutil hiccup) skips that sample instead of killing
        # the thread.
        while not self._stop.is_set():
            try:
                self.samples.append({
                    "ts": time.time(),
                    "vram_mb": torch.cuda.memory_allocated() / 1e6,
                    "vram_reserved_mb": torch.cuda.memory_reserved() / 1e6,
                    "gpu_util": self._get_gpu_util(),
                    "cpu_pct": psutil.cpu_percent(interval=None),
                })
            except Exception:
                pass
            self._stop.wait(self._interval)

    @staticmethod
    def _get_gpu_util() -> float:
        """Return GPU-0 utilization in percent, or -1 if nvidia-smi fails."""
        try:
            out = subprocess.check_output(
                ["nvidia-smi", "--query-gpu=utilization.gpu", "--format=csv,noheader,nounits"],
                text=True, timeout=2,
            )
            # First line corresponds to GPU 0; multi-GPU hosts are not split out.
            return float(out.strip().split("\n")[0])
        except Exception:
            return -1

    def stop(self) -> dict:
        """Stop the sampler thread and return aggregate stats.

        Returns an empty dict when no samples were collected. Samples whose
        utilization is the -1 sentinel (nvidia-smi unavailable) are excluded
        from the utilization averages but still count toward VRAM/CPU stats.
        """
        self._stop.set()
        if self._thread:
            self._thread.join(timeout=3)
        if not self.samples:
            return {}
        vrams = [s["vram_mb"] for s in self.samples]
        utils = [s["gpu_util"] for s in self.samples if s["gpu_util"] >= 0]
        cpus = [s["cpu_pct"] for s in self.samples]
        return {
            "vram_avg_mb": sum(vrams) / len(vrams),
            "vram_peak_mb": max(vrams),
            "gpu_util_avg": sum(utils) / len(utils) if utils else 0,
            "gpu_util_peak": max(utils) if utils else 0,
            "cpu_avg_pct": sum(cpus) / len(cpus) if cpus else 0,
            "n_samples": len(self.samples),
        }


def get_gpu_metrics() -> GPUMetrics:
    """Probe the local GPU via torch + nvidia-smi and return a GPUMetrics.

    Falls back to name="CPU" when CUDA is unavailable; driver version and
    the TFLOPS estimate stay at their defaults when nvidia-smi fails.
    """
    metrics = GPUMetrics()
    if not torch.cuda.is_available():
        metrics.name = "CPU"
        return metrics

    props = torch.cuda.get_device_properties(0)
    metrics.name = props.name
    metrics.vram_total_mb = props.total_memory / 1e6
    metrics.cuda_version = torch.version.cuda or ""

    # Driver version + a peak-FP32 estimate derived from SM count and clock.
    try:
        query = ["nvidia-smi", "--query-gpu=driver_version,clocks.max.sm",
                 "--format=csv,noheader,nounits"]
        raw = subprocess.check_output(query, text=True, timeout=5)
        parts = raw.strip().split(", ")
        metrics.driver_version = parts[0]
        sm_clock_mhz = float(parts[1])
        # FP32 TFLOPS = SMs * FP32 cores/SM * 2 FLOP/cycle * clock (GHz) / 1000
        # with 128 FP32 cores per SM assumed (Ampere/Ada).
        cores_per_sm = 128
        metrics.tflops_fp32 = props.multi_processor_count * cores_per_sm * 2 * sm_clock_mhz / 1e6
    except Exception:
        pass  # best-effort: keep dataclass defaults if nvidia-smi is absent

    return metrics


def list_r2_videos(r2_client, bucket: str, max_videos: int) -> list[dict]:
    """List videos directly from R2 for standalone benchmarking (no Supabase needed)."""
    # Over-fetch keys (3x) since the listing mixes non-video files and
    # multiple objects can share one video id.
    listing = r2_client._client.list_objects_v2(Bucket=bucket, MaxKeys=max_videos * 3)
    # Bucket name encodes the language; invariant for the whole listing.
    language = "english" if "english" in bucket else "indic"
    results: list[dict] = []
    seen: set[str] = set()
    for entry in listing.get("Contents", []):
        key = entry["Key"]
        if not key.endswith((".webm", ".mp4")):
            continue
        # Key convention: "<video_id>_<suffix>.<ext>" or plain "<video_id>.<ext>".
        vid = key.split("_")[0] if "_" in key else key.rsplit(".", 1)[0]
        if vid in seen:
            continue
        seen.add(vid)
        results.append({"video_id": vid, "language": language, "duration_min": 0, "_r2_key": key})
        if len(results) >= max_videos:
            break
    return results


def run_benchmark(cfg: PipelineConfig, bench_cfg: BenchmarkConfig) -> BenchmarkResult:
    """Run the sequential benchmark: each video is processed stage by stage.

    Per video: download -> extract audio -> VAD -> GPU encode -> pack shard
    -> upload shard, timing every stage individually so the report can show
    where wall-clock time goes and how long the GPU sits idle.

    In standalone mode videos are listed directly from R2 and the upload
    stage is skipped; otherwise videos are claimed from (and statuses
    reported back to) the Supabase orchestrator.

    Returns a BenchmarkResult with one VideoResult per successfully
    processed video (partial on early failures).
    """
    result = BenchmarkResult()
    result.gpu = get_gpu_metrics()
    result.config = bench_cfg
    result.cpu_count = psutil.cpu_count()
    result.ram_gb = psutil.virtual_memory().total / 1e9

    # Apply benchmark overrides onto the pipeline config before loading.
    cfg.codec.xcodec_batch_size = bench_cfg.xcodec_batch_size
    if bench_cfg.custom_ckpt:
        cfg.codec.xcodec2_custom_ckpt = bench_cfg.custom_ckpt

    r2 = R2Client(cfg.r2)

    standalone = bench_cfg.standalone
    orch = None
    if not standalone:
        orch = SupabaseOrchestrator(cfg.supabase)
        orch.ensure_tables()

    # Load encoder
    logger.info("Loading encoder (BS=%d, parallel=%s)...", bench_cfg.xcodec_batch_size, bench_cfg.parallel)
    encoder = HotEncoder(cfg.codec)
    encoder.load()
    # VRAM footprint of the loaded models alone (before any encoding).
    result.gpu.vram_model_mb = torch.cuda.memory_allocated() / 1e6

    # Get videos: either claim from Supabase or list from R2
    if standalone:
        bucket = bench_cfg.standalone_bucket or "pt-english"
        logger.info("Standalone mode: listing %d videos from R2 bucket '%s'",
                     bench_cfg.num_videos, bucket)
        videos_to_process = list_r2_videos(r2, bucket, bench_cfg.num_videos)
    else:
        videos_to_process = []
        for _ in range(bench_cfg.num_videos):
            video = orch.claim_video("benchmark_worker")
            if video is None:
                logger.warning("No more PENDING videos to claim")
                break
            videos_to_process.append(video)

    if not videos_to_process:
        logger.error("No videos available for benchmarking!")
        return result

    logger.info("Got %d videos for benchmarking", len(videos_to_process))

    tmp_dir = Path("/tmp/benchmark")
    tmp_dir.mkdir(parents=True, exist_ok=True)
    # Sample GPU/CPU in the background for the whole run; aggregated at the end.
    monitor = GPUMonitor(interval_s=0.5)
    monitor.start()

    shard_buffer = ShardBuffer()

    for vi, video in enumerate(videos_to_process):
        video_id = video["video_id"]
        language = video.get("language", "unknown")
        bucket = "pt-english" if language == "english" else "pt-indic"
        timings = StageTimings()

        logger.info("[%d/%d] Processing %s (%s, %.1f min)...",
                    vi + 1, len(videos_to_process), video_id, language, video.get("duration_min", 0))

        # Stage 1: Download
        t = time.perf_counter()
        video_path = r2.download_video(video_id, bucket, tmp_dir / "videos")
        timings.download_s = time.perf_counter() - t

        if not video_path:
            if orch:
                orch.mark_video_failed(video_id, "Download failed")
            continue

        # Stage 2: Extract audio
        t = time.perf_counter()
        audio_path = tmp_dir / "audio" / f"{video_id}.wav"
        extract_audio_from_video(video_path, audio_path, 16_000)
        timings.extract_s = time.perf_counter() - t

        # Stage 3: VAD
        t = time.perf_counter()
        segments = segment_audio(audio_path, cfg.vad)
        timings.vad_s = time.perf_counter() - t

        if not segments:
            # No speech found: mark done with zero segments (when orchestrated).
            if orch:
                orch.update_video_status(video_id, "DONE", {"num_segments": 0, "usable_audio_s": 0})
            continue

        usable_audio = sum(s.duration_s for s in segments)

        # Stage 4: Encode (GPU)
        # Reset peak stats so vram_peak reflects this video only; synchronize
        # before reading the clock so async CUDA work is included in encode_s.
        torch.cuda.reset_peak_memory_stats()
        t = time.perf_counter()
        encoded = encoder.encode_segments(segments, parallel=bench_cfg.parallel)
        torch.cuda.synchronize()
        timings.encode_s = time.perf_counter() - t

        encode_rtf = usable_audio / max(timings.encode_s, 0.001)
        vram_peak = torch.cuda.max_memory_allocated() / 1e6

        # Stage 5: Pack shard
        vt = VideoTokens(video_id=video_id, language=language,
                         duration_s=usable_audio, segments=encoded, usable_audio_s=usable_audio)
        shard_buffer.add(vt)

        t = time.perf_counter()
        shard_id, shard_path, shard_size = pack_shard(shard_buffer, tmp_dir / "shards")
        timings.pack_s = time.perf_counter() - t

        # Stage 6: Upload shard (skip in standalone to avoid polluting R2)
        if not standalone:
            t = time.perf_counter()
            r2_key = f"benchmark-shards/{shard_id}.tar.zst"
            r2.upload_shard(shard_path, r2_key)
            timings.upload_s = time.perf_counter() - t
        shard_buffer.clear()

        overall_rtf = usable_audio / max(timings.total_s, 0.001)
        # Track the worst-case VRAM across all videos in the run.
        if vram_peak > result.gpu.vram_peak_mb:
            result.gpu.vram_peak_mb = vram_peak

        vr = VideoResult(
            video_id=video_id,
            language=language,
            duration_min=video.get("duration_min", 0),
            num_segments=len(segments),
            usable_audio_s=usable_audio,
            timings=timings,
            encode_rtf=encode_rtf,
            overall_rtf=overall_rtf,
        )
        result.videos.append(vr)

        logger.info(
            "  %s: %.0fs audio, %d segs | dl=%.1f ext=%.1f vad=%.1f enc=%.1f pk=%.1f up=%.1f | "
            "enc_RTF=%.1fx overall_RTF=%.1fx VRAM_peak=%.0fMB",
            video_id, usable_audio, len(segments),
            timings.download_s, timings.extract_s, timings.vad_s,
            timings.encode_s, timings.pack_s, timings.upload_s,
            encode_rtf, overall_rtf, vram_peak,
        )

        if orch:
            orch.mark_video_done(video_id, shard_id, ["xcodec2_fast", "bicodec_fast"])

        # Cleanup
        for p in [video_path, audio_path, shard_path]:
            try:
                p.unlink(missing_ok=True)
            except Exception:
                pass

    gpu_stats = monitor.stop()
    result.avg_cpu_pct = gpu_stats.get("cpu_avg_pct", 0)

    return result


def run_benchmark_async(cfg: PipelineConfig, bench_cfg: BenchmarkConfig) -> BenchmarkResult:
    """Async benchmark: measures wall-clock throughput with overlapping stages.

    Download + Extract + VAD all run in parallel background workers.
    GPU drains the ready queue. Measures how fast the full pipeline actually runs.

    Always runs standalone (videos listed from R2, shards packed locally and
    deleted, nothing uploaded). Wall-clock metrics are attached to the result
    as dynamic attributes (_async_wall_s, _async_rtf, ...) which save_results
    picks up when present.

    Fix over the previous version: dropped unused imports of
    extract_audio_pipe, normalize_audio_peak and segment_tensor.
    """
    from codecbench.pipeline.async_pipeline import AsyncPipeline

    result = BenchmarkResult()
    result.gpu = get_gpu_metrics()
    result.config = bench_cfg
    result.cpu_count = psutil.cpu_count()
    result.ram_gb = psutil.virtual_memory().total / 1e9

    # Apply benchmark overrides onto the pipeline config before loading.
    cfg.codec.xcodec_batch_size = bench_cfg.xcodec_batch_size
    if bench_cfg.custom_ckpt:
        cfg.codec.xcodec2_custom_ckpt = bench_cfg.custom_ckpt

    r2 = R2Client(cfg.r2)

    logger.info("Loading encoder (BS=%d, parallel=%s)...", bench_cfg.xcodec_batch_size, bench_cfg.parallel)
    encoder = HotEncoder(cfg.codec)
    encoder.load()
    # VRAM footprint of the loaded models alone (before any encoding).
    result.gpu.vram_model_mb = torch.cuda.memory_allocated() / 1e6

    # List videos from R2
    bucket = bench_cfg.standalone_bucket or "pt-english"
    logger.info("Async benchmark: listing %d videos from R2 '%s' (extract_workers=%d)",
                bench_cfg.num_videos, bucket, cfg.worker.extract_workers)
    videos = list_r2_videos(r2, bucket, bench_cfg.num_videos)
    if not videos:
        logger.error("No videos available!")
        return result

    # Start async pipeline
    pipeline = AsyncPipeline(cfg, r2)
    pipeline.start()

    # Submit all videos up front so background workers stay saturated.
    for v in videos:
        pipeline.submit(v["video_id"], v["language"], bucket)
    pipeline.drain()

    monitor = GPUMonitor(interval_s=0.5)
    monitor.start()
    shard_buffer = ShardBuffer()
    wall_start = time.perf_counter()

    videos_done = 0
    while True:
        # Blocks until a prepared (downloaded+extracted+VAD'd) video is ready;
        # None signals the pipeline is drained.
        prepared = pipeline.get_ready(timeout=300)
        if prepared is None:
            break

        # GPU encode
        torch.cuda.reset_peak_memory_stats()
        t0 = time.perf_counter()
        encoded = encoder.encode_segments(prepared.segments, parallel=bench_cfg.parallel)
        torch.cuda.synchronize()
        encode_time = time.perf_counter() - t0

        vram_peak = torch.cuda.max_memory_allocated() / 1e6

        # Pack shard (local only, no upload)
        vt = VideoTokens(video_id=prepared.video_id, language=prepared.language,
                         duration_s=prepared.usable_audio_s, segments=encoded,
                         usable_audio_s=prepared.usable_audio_s)
        shard_buffer.add(vt)
        t_pk = time.perf_counter()
        shard_id, shard_path, _ = pack_shard(shard_buffer, Path("/tmp/benchmark/shards"))
        pack_time = time.perf_counter() - t_pk
        shard_buffer.clear()
        shard_path.unlink(missing_ok=True)

        encode_rtf = prepared.usable_audio_s / max(encode_time, 0.001)
        wall_so_far = time.perf_counter() - wall_start
        # Per-video RTF uses this video's own stage times (stages of different
        # videos overlap, so these don't sum to wall clock).
        overall_rtf = prepared.usable_audio_s / max(
            prepared.extract_time_s + prepared.vad_time_s + encode_time, 0.001
        )

        if vram_peak > result.gpu.vram_peak_mb:
            result.gpu.vram_peak_mb = vram_peak

        timings = StageTimings(
            download_s=0,  # overlapped
            extract_s=prepared.extract_time_s,
            vad_s=prepared.vad_time_s,
            encode_s=encode_time,
            pack_s=pack_time,
            upload_s=0,
        )

        vr = VideoResult(
            video_id=prepared.video_id,
            language=prepared.language,
            duration_min=0,
            num_segments=prepared.num_segments,
            usable_audio_s=prepared.usable_audio_s,
            timings=timings,
            encode_rtf=encode_rtf,
            overall_rtf=overall_rtf,
        )
        result.videos.append(vr)
        videos_done += 1

        logger.info(
            "  [%d/%d] %s: %.0fs audio, %d segs | ext=%.1f vad=%.1f enc=%.1f | "
            "enc_RTF=%.0fx | ready_q=%d | wall=%.1fs",
            videos_done, len(videos), prepared.video_id,
            prepared.usable_audio_s, prepared.num_segments,
            prepared.extract_time_s, prepared.vad_time_s, encode_time,
            encode_rtf, pipeline.ready_count, wall_so_far,
        )

        if prepared.video_path:
            prepared.video_path.unlink(missing_ok=True)

    wall_total = time.perf_counter() - wall_start
    gpu_stats = monitor.stop()
    result.avg_cpu_pct = gpu_stats.get("cpu_avg_pct", 0)

    pipeline.stop()

    # Attach true wall-clock metrics (stage sums are misleading for async)
    total_audio = result.total_audio_s
    async_rtf = total_audio / max(wall_total, 0.001)
    result._async_wall_s = wall_total
    result._async_rtf = async_rtf
    result._async_gpu_util_avg = gpu_stats.get("gpu_util_avg", 0)
    result._async_gpu_util_peak = gpu_stats.get("gpu_util_peak", 0)

    print(f"\n  ASYNC PIPELINE SUMMARY")
    print(f"  Wall clock:       {wall_total:.1f}s for {total_audio:.0f}s audio ({total_audio/3600:.2f} hrs)")
    print(f"  Pipeline RTF:     {async_rtf:.1f}x (wall-clock, all stages overlapping)")
    print(f"  Extract workers:  {cfg.worker.extract_workers}")
    print(f"  GPU util avg:     {gpu_stats.get('gpu_util_avg', 0):.1f}%")
    print(f"  Pipeline stats:   extract={pipeline.stats.total_extract_s:.1f}s "
          f"vad={pipeline.stats.total_vad_s:.1f}s failed={pipeline.stats.videos_failed}")

    return result


def print_report(r: BenchmarkResult) -> None:
    """Print the human-readable benchmark report to stdout.

    Sections: GPU/system header, per-stage timing breakdown, GPU idle time,
    throughput (RTF and scale-out projections), and a per-video table.
    """
    totals = r.stage_totals()
    print(f"\n{'='*80}")
    print(f"  GPU ENCODING BENCHMARK REPORT")
    print(f"{'='*80}")
    print(f"  GPU:           {r.gpu.name}")
    print(f"  VRAM:          {r.gpu.vram_total_mb/1024:.1f} GB total, {r.gpu.vram_model_mb:.0f} MB models, {r.gpu.vram_peak_mb:.0f} MB peak")
    print(f"  FP32 TFLOPS:   {r.gpu.tflops_fp32:.1f}")
    print(f"  CUDA:          {r.gpu.cuda_version}  Driver: {r.gpu.driver_version}")
    print(f"  CPU:           {r.cpu_count} cores, {r.ram_gb:.0f} GB RAM")
    print(f"  Config:        XCodec BS={r.config.xcodec_batch_size}, parallel={r.config.parallel}")
    print(f"{'='*80}")

    print(f"\n  STAGE TIMING BREAKDOWN (totals across {len(r.videos)} videos)")
    print(f"  {'Stage':<15} {'Total (s)':>10} {'Avg/video':>10} {'% of total':>10}")
    print(f"  {'-'*50}")
    stages = [
        ("Download", totals.download_s),
        ("Extract", totals.extract_s),
        ("VAD", totals.vad_s),
        ("Encode (GPU)", totals.encode_s),
        ("Pack", totals.pack_s),
        ("Upload", totals.upload_s),
    ]
    for name, val in stages:
        # max(..., 0.001) guards division by zero when no videos ran.
        pct = 100 * val / max(totals.total_s, 0.001)
        avg = val / max(len(r.videos), 1)
        print(f"  {name:<15} {val:>10.1f} {avg:>10.1f} {pct:>9.1f}%")
    print(f"  {'-'*50}")
    print(f"  {'TOTAL':<15} {totals.total_s:>10.1f} {totals.total_s/max(len(r.videos),1):>10.1f}")

    print(f"\n  GPU IDLE TIME (waiting for CPU/IO stages)")
    print(f"  GPU active:   {totals.encode_s:.1f}s ({100*totals.encode_s/max(totals.total_s,0.001):.1f}%)")
    print(f"  GPU idle:     {totals.cpu_active_s:.1f}s ({100*totals.cpu_active_s/max(totals.total_s,0.001):.1f}%)")
    print(f"  Avg CPU:      {r.avg_cpu_pct:.1f}%")

    print(f"\n  THROUGHPUT")
    print(f"  Total audio:      {r.total_audio_s:.0f}s ({r.total_audio_s/3600:.2f} hrs)")
    print(f"  Wall time:        {r.total_wall_s:.0f}s")
    print(f"  Encode-only RTF:  {r.avg_encode_rtf:.1f}x (GPU encoding only)")
    print(f"  Overall RTF:      {r.avg_overall_rtf:.1f}x (full pipeline incl I/O)")
    print(f"  1hr audio in:     {3600/max(r.avg_overall_rtf,0.001):.0f}s")
    # Projection: days to process 1M hours of audio at the measured overall RTF.
    est_1m = (1_000_000 * 3600) / max(r.avg_overall_rtf, 0.001) / 86400
    print(f"  1M hrs on 1 GPU:  {est_1m:.0f} days")
    print(f"  1M hrs on 100 GPU: {est_1m/100:.1f} days")

    print(f"\n  PER-VIDEO DETAILS")
    print(f"  {'video_id':<15} {'lang':>8} {'segs':>5} {'audio_s':>8} {'enc_RTF':>8} {'total_RTF':>9} {'dl':>5} {'ext':>5} {'vad':>5} {'enc':>5} {'pk':>5} {'up':>5}")
    print(f"  {'-'*95}")
    for v in r.videos:
        t = v.timings
        print(f"  {v.video_id:<15} {v.language:>8} {v.num_segments:>5} {v.usable_audio_s:>8.0f} {v.encode_rtf:>8.1f} {v.overall_rtf:>9.1f} "
              f"{t.download_s:>5.1f} {t.extract_s:>5.1f} {t.vad_s:>5.1f} {t.encode_s:>5.1f} {t.pack_s:>5.1f} {t.upload_s:>5.1f}")
    print(f"{'='*80}\n")


def save_results(r: BenchmarkResult, path: Path) -> None:
    """Save benchmark results as JSON for cross-GPU comparison."""
    summary = {
        "num_videos": len(r.videos),
        "total_audio_s": r.total_audio_s,
        "total_wall_s": r.total_wall_s,
        "avg_encode_rtf": r.avg_encode_rtf,
        "avg_overall_rtf": r.avg_overall_rtf,
        "avg_gpu_idle_pct": r.avg_gpu_idle_pct,
        "vram_model_mb": r.gpu.vram_model_mb,
        "vram_peak_mb": r.gpu.vram_peak_mb,
    }
    data = {
        "gpu": asdict(r.gpu),
        "config": asdict(r.config),
        "system": {"cpu_count": r.cpu_count, "ram_gb": r.ram_gb, "avg_cpu_pct": r.avg_cpu_pct},
        "summary": summary,
        "stage_totals": asdict(r.stage_totals()),
        "videos": [
            {**asdict(v), "timings": asdict(v.timings)}
            for v in r.videos
        ],
    }
    # Async runs attach wall-clock attrs dynamically; include them when present.
    if hasattr(r, "_async_wall_s"):
        data["async_pipeline"] = {
            "wall_clock_s": r._async_wall_s,
            "pipeline_rtf": r._async_rtf,
            "gpu_util_avg_pct": r._async_gpu_util_avg,
            "gpu_util_peak_pct": r._async_gpu_util_peak,
        }
        summary["pipeline_wall_s"] = r._async_wall_s
        summary["pipeline_rtf"] = r._async_rtf
    path.parent.mkdir(parents=True, exist_ok=True)
    path.write_text(json.dumps(data, indent=2))
    logger.info("Results saved to %s", path)


def main():
    """CLI entry point: parse flags, run the chosen benchmark, report + save."""
    import argparse
    ap = argparse.ArgumentParser(description="GPU encoding benchmark")
    ap.add_argument("--num-videos", type=int, default=10)
    ap.add_argument("--batch-size", type=int, default=2)
    ap.add_argument("--no-parallel", action="store_true")
    ap.add_argument("--custom-ckpt", type=str, default=None)
    ap.add_argument("--standalone", action="store_true",
                    help="List videos from R2 directly (no Supabase needed)")
    ap.add_argument("--bucket", type=str, default="pt-english",
                    help="R2 bucket for standalone mode (default: pt-english)")
    ap.add_argument("--async-pipeline", action="store_true",
                    help="Use async 3-stage pipeline (overlapping download+extract+encode)")
    ap.add_argument("--extract-workers", type=int, default=4,
                    help="Number of parallel ffmpeg+VAD workers for async mode")
    ap.add_argument("--output", type=str, default=None,
                    help="JSON output path (default: results/bench_<gpu>.json)")
    opts = ap.parse_args()

    cfg = PipelineConfig.from_env()
    use_async = opts.async_pipeline
    if use_async:
        cfg.worker.extract_workers = opts.extract_workers
        cfg.worker.use_async_pipeline = True

    # Async mode always implies standalone (no Supabase orchestration).
    bench_cfg = BenchmarkConfig(
        xcodec_batch_size=opts.batch_size,
        parallel=not opts.no_parallel,
        num_videos=opts.num_videos,
        custom_ckpt=opts.custom_ckpt,
        standalone=opts.standalone or use_async,
        standalone_bucket=opts.bucket,
    )

    runner = run_benchmark_async if use_async else run_benchmark
    result = runner(cfg, bench_cfg)
    print_report(result)

    gpu_slug = result.gpu.name.replace(" ", "_")
    default_out = Path(f"results/bench_{gpu_slug}_bs{opts.batch_size}.json")
    out_path = Path(opts.output) if opts.output else default_out
    save_results(result, out_path)


if __name__ == "__main__":
    main()
