#!/usr/bin/env python3
"""Phase 2: Benchmark decode-only throughput on converted 16kHz corpus.

Measures decode throughput (seconds of audio decoded per wall-clock second)
for each worker count in a sweep.

Usage:
  python3 tools/phase2_bench_decode_only.py --workers 4,8,16,24,32 [--shards 20]
"""

import argparse
import io
import json
import os
import sys
import tarfile
import time
from concurrent.futures import ProcessPoolExecutor
from pathlib import Path

import pyarrow.parquet as pq
import soundfile

# Request single-threaded OpenMP / MKL (presumably so that each pool worker
# stays on one core). Child processes inherit os.environ.
# NOTE(review): this runs after soundfile/pyarrow are imported above; any
# native library already loaded in this process may have read its thread
# settings at import time — confirm the limits take effect where intended.
os.environ.update({"OMP_NUM_THREADS": "1", "MKL_NUM_THREADS": "1"})

# Phase-2 artifacts directory: holds the shard inventory this script reads.
ARTIFACTS_DIR = Path("/workspace/maya-asr/artifacts/phase2")
INVENTORY_PATH = ARTIFACTS_DIR.joinpath("shard_inventory.parquet")


def bench_shard(tar_path_str: str) -> dict:
    """Decode every regular file in one shard tar and report timing.

    Each regular member of the archive is read fully into memory and decoded
    with ``soundfile.read``; non-file members (directories, links, ...) are
    skipped.

    Args:
        tar_path_str: Filesystem path of the shard's tar archive.

    Returns:
        A dict with ``audio_secs`` (total decoded duration, frames / sample
        rate summed over members), ``elapsed`` (wall-clock seconds spent
        opening the archive and decoding) and ``count`` (members decoded).
        A missing archive yields all-zero values instead of an error.

    Raises:
        Whatever ``tarfile`` or ``soundfile`` raise for an unreadable archive
        or a member that cannot be decoded.
    """
    tar_path = Path(tar_path_str)
    if not tar_path.exists():
        return {"audio_secs": 0.0, "elapsed": 0.0, "count": 0}

    # perf_counter is monotonic and high-resolution; time.time() can jump if
    # the system clock is adjusted during the run, skewing the measurement.
    t0 = time.perf_counter()
    audio_secs = 0.0
    count = 0
    with tarfile.open(tar_path, "r") as tf:
        for member in tf:
            if not member.isfile():
                continue
            # Close each member stream promptly rather than leaving it to GC.
            with tf.extractfile(member) as fh:
                payload = fh.read()
            data, sr = soundfile.read(io.BytesIO(payload))
            # len() is the frame count (first axis) for mono and multi-channel.
            audio_secs += len(data) / sr
            count += 1
    elapsed = time.perf_counter() - t0
    return {"audio_secs": audio_secs, "elapsed": elapsed, "count": count}


def _parse_worker_counts(spec: str) -> list:
    """Parse a comma-separated worker sweep such as ``"4,8,16"``.

    Blank entries (e.g. from a trailing comma) are ignored.

    Raises:
        ValueError: if an entry is not an integer, an entry is < 1, or no
            worker counts remain.
    """
    counts = [int(tok) for tok in spec.split(",") if tok.strip()]
    if not counts:
        raise ValueError("no worker counts given")
    bad = [c for c in counts if c < 1]
    if bad:
        raise ValueError(f"worker counts must be >= 1, got {bad}")
    return counts


def _select_tar_paths(limit: int) -> list:
    """Return up to ``limit`` existing ``<shard_dir>/audio.tar`` paths.

    Shards are taken in inventory order and rows whose tar is missing on disk
    are skipped. Scanning stops as soon as ``limit`` paths are found.
    """
    # Only shard_dir is needed: read that single column and skip the pandas
    # conversion of the whole inventory.
    shard_dirs = (
        pq.read_table(INVENTORY_PATH, columns=["shard_dir"])
        .column("shard_dir")
        .to_pylist()
    )
    selected = []
    for shard_dir in shard_dirs:
        if len(selected) >= limit:
            break
        tar_path = Path(shard_dir) / "audio.tar"
        if tar_path.exists():
            selected.append(str(tar_path))
    return selected


def _bench_workers(num_workers: int, tar_paths: list) -> dict:
    """Decode all ``tar_paths`` with a fresh pool of ``num_workers`` processes.

    The timed region includes pool start-up and shutdown, so the figure is
    end-to-end throughput rather than steady-state decode speed.

    Returns:
        One results row for this worker count (keys as written to the JSON).
    """
    t0 = time.perf_counter()
    total_audio = 0.0
    total_count = 0
    with ProcessPoolExecutor(max_workers=num_workers) as executor:
        for result in executor.map(bench_shard, tar_paths):
            total_audio += result["audio_secs"]
            total_count += result["count"]
    elapsed = time.perf_counter() - t0
    # Floor the denominator so a near-instant run cannot divide by ~0.
    rate = total_audio / max(elapsed, 0.01)
    return {
        "workers": num_workers,
        "shards": len(tar_paths),
        "samples": total_count,
        "audio_hours": round(total_audio / 3600, 2),
        "elapsed_s": round(elapsed, 1),
        "audio_sec_per_sec": round(rate, 0),
    }


def main():
    """Run the decode-only worker sweep and write the results as JSON.

    Reads the shard inventory and takes the first ``--shards`` shards whose
    ``audio.tar`` exists. It decodes them once per entry of ``--workers`` and
    prints one line per worker count. All rows are then written to
    ``ARTIFACTS_DIR / "decode_only_bench_results.json"``.

    Exits with status 1 if the inventory is missing or no shard tar exists.
    Invalid arguments are reported via ``argparse`` (exit status 2).
    """
    parser = argparse.ArgumentParser(description="Decode-only benchmark")
    parser.add_argument("--workers", type=str, default="4,8,16,24,32")
    parser.add_argument("--shards", type=int, default=20, help="Shards to bench")
    args = parser.parse_args()

    # Validate arguments before touching the inventory so bad input fails fast.
    try:
        worker_counts = _parse_worker_counts(args.workers)
    except ValueError as err:
        parser.error(f"--workers: {err}")
    if args.shards < 1:
        parser.error("--shards must be >= 1")

    if not INVENTORY_PATH.exists():
        print("ERROR: Run inventory first", file=sys.stderr)
        sys.exit(1)

    # First N shards (inventory order) whose audio.tar is present on disk.
    tar_paths = _select_tar_paths(args.shards)
    if not tar_paths:
        print(
            f"ERROR: No audio.tar found for any shard in {INVENTORY_PATH}",
            file=sys.stderr,
        )
        sys.exit(1)

    print(f"Benchmarking decode-only on {len(tar_paths)} shards")
    print(f"Worker sweep: {worker_counts}")
    print()

    # NOTE(review): every sweep point re-reads the same tars, so points after
    # the first may be served from the OS page cache. Consider a warm-up pass
    # (or dropping caches) if cold-read throughput matters.
    results = []
    for nw in worker_counts:
        row = _bench_workers(nw, tar_paths)
        results.append(row)
        print(
            f"  workers={nw:>3}: {row['audio_sec_per_sec']:>6.0f}x realtime, "
            f"{row['elapsed_s']:.1f}s"
        )

    # Write results
    out_path = ARTIFACTS_DIR / "decode_only_bench_results.json"
    with open(out_path, "w") as f:
        json.dump(results, f, indent=2)
    print(f"\nResults: {out_path}")


# Run the sweep only when executed as a script. The guard also prevents
# ProcessPoolExecutor workers that re-import this module (spawn/forkserver
# start methods) from starting a sweep of their own.
if __name__ == "__main__":
    main()
