#!/usr/bin/env python3
"""Run throughput benchmarks for neural audio codecs.

Usage:
    python scripts/run_bench.py --codecs xcodec2 bicodec snac wavtokenizer \
        --bs 1 8 32 --seconds 6 --dtype bf16
"""

from __future__ import annotations

import argparse
import logging
import sys
from pathlib import Path

import torch

# Root logging configuration: INFO and above, time-of-day timestamps only.
_LOG_FORMAT = "%(asctime)s %(levelname)s %(name)s: %(message)s"
logging.basicConfig(format=_LOG_FORMAT, datefmt="%H:%M:%S", level=logging.INFO)

# The logger name is fixed rather than taken from __name__, which would be
# "__main__" when this file is executed directly as a script.
logger = logging.getLogger("codecbench.bench")


def _summary_row(result) -> dict:
    """Flatten one ``RunResult`` into the row dict consumed by the CSV/Markdown reporters.

    Encode and decode latencies contribute mean, p50 and p95. End-to-end
    latency contributes only its mean. Key order is kept stable because it
    may determine column order in the generated reports.
    """
    enc = result.encode_stats
    dec = result.decode_stats
    return {
        "codec": result.codec,
        "sr": result.sr,
        "batch_size": result.batch_size,
        "clip_seconds": result.clip_seconds,
        "dtype": result.dtype,
        "encode_ms": enc["mean_ms"],
        "encode_p50_ms": enc["p50_ms"],
        "encode_p95_ms": enc["p95_ms"],
        "decode_ms": dec["mean_ms"],
        "decode_p50_ms": dec["p50_ms"],
        "decode_p95_ms": dec["p95_ms"],
        "e2e_ms": result.e2e_stats["mean_ms"],
        "tokens_per_sec": result.tokens_per_sec,
        "tokens_shape": result.tokens_shape,
        "peak_vram_mb": result.peak_vram_mb,
        "gpu_name": result.gpu_name,
        "torch": result.torch_version,
    }


def run_bench(args: argparse.Namespace) -> None:
    """Benchmark every requested codec and write JSONL, CSV and Markdown outputs.

    Codecs are processed one at a time. Each codec is:

    1. looked up in the registry,
    2. loaded on the target device,
    3. passed to ``run_codec_suite`` with a shared ``BenchConfig``,
    4. released before the next codec is loaded.

    A codec that is unknown, fails to load, or raises during its suite is
    logged and skipped. Results from the other codecs are still saved.

    Args:
        args: Parsed CLI namespace providing ``codecs``, ``bs``, ``seconds``,
            ``warmup``, ``iters``, ``dtype``, ``output_dir`` and ``device``.

    Raises:
        SystemExit: With status 1 when no codec produced any results.
    """
    # Project imports are deferred so ``--help`` works without loading them.
    from codecbench.codecs import get_codec, CODEC_REGISTRY
    from codecbench.bench.runner import (
        BenchConfig,
        run_codec_suite,
        save_results_jsonl,
        RunResult,
    )
    from codecbench.reporting import generate_summary_csv, generate_report_md

    device = args.device
    # Match plain "cuda" as well as indexed forms such as "cuda:1".
    if str(device).startswith("cuda") and not torch.cuda.is_available():
        logger.warning("CUDA not available, falling back to CPU (timings will use wall clock)")
        device = "cpu"

    output_dir = Path(args.output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)

    config = BenchConfig(
        batch_sizes=args.bs,
        clip_seconds=args.seconds,
        warmup_iters=args.warmup,
        measure_iters=args.iters,
        dtypes=args.dtype,
    )

    all_results: list[RunResult] = []

    for codec_name in args.codecs:
        try:
            codec = get_codec(codec_name)
        except KeyError as e:
            logger.warning("Skipping %s: %s", codec_name, e)
            continue

        logger.info("=" * 60)
        logger.info("Loading %s", codec_name)
        results: list[RunResult] = []
        try:
            # Loaded in fp32; presumably run_codec_suite handles the
            # per-dtype casting listed in config.dtypes. TODO confirm.
            codec.load(device=device, dtype=torch.float32)
        except Exception as e:
            # logger.exception keeps the traceback that logger.error dropped.
            logger.exception("Failed to load %s: %s", codec_name, e)
        else:
            logger.info("Running benchmark suite for %s", codec_name)
            try:
                results = run_codec_suite(codec, config, device=device)
            except Exception as e:
                # One failing codec (e.g. OOM at a large batch) must not discard
                # the results already collected for the other codecs.
                logger.exception("Benchmark suite failed for %s: %s", codec_name, e)
        all_results.extend(results)

        # Free model memory between codecs; this runs on the failure paths too.
        del codec
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

    if not all_results:
        logger.error("No benchmark results produced")
        sys.exit(1)

    # Save outputs
    jsonl_path = output_dir / "bench_results.jsonl"
    save_results_jsonl(all_results, jsonl_path)

    summary_rows = [_summary_row(r) for r in all_results]

    csv_path = output_dir / "summary.csv"
    generate_summary_csv(summary_rows, csv_path)

    report_path = output_dir / "report.md"
    generate_report_md(summary_rows, report_path)

    logger.info("=" * 60)
    logger.info("Benchmark complete. Outputs:")
    logger.info("  JSONL:  %s", jsonl_path)
    logger.info("  CSV:    %s", csv_path)
    logger.info("  Report: %s", report_path)


def main() -> None:
    parser = argparse.ArgumentParser(description="CodecBench throughput benchmark")
    parser.add_argument("--codecs", nargs="+", default=["xcodec2", "snac", "wavtokenizer", "bicodec"])
    parser.add_argument("--bs", nargs="+", type=int, default=[1, 8, 32])
    parser.add_argument("--seconds", type=float, default=6.0)
    parser.add_argument("--dtype", nargs="+", default=["fp32", "bf16"])
    parser.add_argument("--warmup", type=int, default=10)
    parser.add_argument("--iters", type=int, default=50)
    parser.add_argument("--output-dir", default="results")
    parser.add_argument("--device", default="cuda")
    args = parser.parse_args()
    run_bench(args)


if __name__ == "__main__":
    main()
