#!/usr/bin/env python3
"""Download ~1hr audio, segment into 6s chunks, benchmark encode/decode,
and produce scaling estimates for production processing.

Usage:
    python scripts/download_and_bench.py [--url URL] [--codecs xcodec2 snac]
"""
from __future__ import annotations
import argparse, json, logging, math, os, subprocess, sys, time
from pathlib import Path
import numpy as np
import soundfile as sf
import torch
import torchaudio

# Console logging with compact timestamps; the whole script logs through `log`.
logging.basicConfig(level=logging.INFO, format="%(asctime)s %(levelname)s: %(message)s", datefmt="%H:%M:%S")
log = logging.getLogger("bench_1hr")

# Repo-relative layout: <root>/data/raw holds source audio, <root>/results the outputs.
ROOT = Path(__file__).resolve().parent.parent
DATA_DIR = ROOT / "data"
RAW_DIR = DATA_DIR / "raw"
RESULTS_DIR = ROOT / "results"

def generate_long_audio(output_path: Path, duration_s: int = 3600) -> Path:
    """Synthesize a speech-like test signal and write it to ``<output_path>.wav``.

    The signal is a vibrato-modulated harmonic tone plus fixed formant-like
    sines and noise, amplitude-modulated at a syllabic (~4.5 Hz) rate.  Audio
    is written incrementally in one-minute chunks so memory stays bounded.

    Args:
        output_path: Destination path; the ``.wav`` suffix is forced.
        duration_s: Total length in seconds (default one hour).

    Returns:
        Path to the written (or pre-existing) wav file.
    """
    wav_path = output_path.with_suffix(".wav")
    if wav_path.exists():
        log.info("Audio already exists: %s", wav_path)
        return wav_path
    sr = 48000
    log.info("Generating %ds synthetic audio at %dHz", duration_s, sr)
    chunk_duration = 60
    n_chunks = math.ceil(duration_s / chunk_duration)
    # Draw the formant phases once (not per chunk) so these tones stay
    # continuous across chunk boundaries instead of clicking every minute.
    formant_freqs = [500.0, 1500.0, 2500.0, 3500.0, 5000.0, 7000.0]
    formant_phases = [np.random.uniform(0, 2 * np.pi) for _ in formant_freqs]
    phase_acc = 0.0  # carries the fundamental's phase across chunks
    with sf.SoundFile(str(wav_path), mode='w', samplerate=sr, channels=1) as f:
        for i in range(n_chunks):
            chunk_len = min(chunk_duration, duration_s - i * chunk_duration)
            n_samples = int(chunk_len * sr)
            # endpoint=False keeps the sample spacing at exactly 1/sr (the
            # previous endpoint-inclusive linspace made every frequency
            # slightly sharp and duplicated the boundary sample).  Offsetting
            # by the chunk start time keeps slow modulations (vibrato,
            # envelope) continuous across chunks.
            t = (i * chunk_duration
                 + np.linspace(0, chunk_len, n_samples, endpoint=False)).astype(np.float32)
            signal = np.zeros(n_samples, dtype=np.float32)
            f0 = 120 + 80 * np.sin(2 * np.pi * 0.3 * t)
            phase = phase_acc + np.cumsum(2 * np.pi * f0 / sr)
            phase_acc = float(phase[-1])  # resume the oscillator next chunk
            signal += 0.3 * np.sin(phase)
            for h in [2, 3, 5, 7]:
                signal += (0.15 / h) * np.sin(h * phase)
            for freq, ph in zip(formant_freqs, formant_phases):
                amp = 0.08 / (1 + freq / 2000)
                signal += amp * np.sin(2 * np.pi * freq * t + ph)
            signal += 0.02 * np.random.randn(n_samples).astype(np.float32)
            envelope = 0.5 + 0.5 * np.sin(2 * np.pi * 4.5 * t)
            signal *= envelope
            peak = np.abs(signal).max()
            if peak > 0:
                # Per-chunk normalization to a 0.9 peak; levels may step
                # slightly at minute boundaries, acceptable for a benchmark.
                signal = signal / peak * 0.9
            f.write(signal)
            if (i + 1) % 10 == 0:
                log.info("  Generated %d/%d minutes", i + 1, n_chunks)
    log.info("Generated: %s (%.1f MB)", wav_path, wav_path.stat().st_size / 1e6)
    return wav_path

def download_audio(url, output_path):
    """Fetch audio via yt-dlp and normalize it to a 48kHz mono wav.

    Falls back to a one-hour synthetic signal when no URL is given, the
    download fails, or conversion produces no wav file.  Returns the path
    to the resulting wav.
    """
    wav_path = output_path.with_suffix(".wav")
    if wav_path.exists():
        log.info("Audio already exists: %s", wav_path)
        return wav_path
    output_path.parent.mkdir(parents=True, exist_ok=True)
    if url:
        log.info("Downloading from YouTube: %s", url)
        dl_cmd = [sys.executable, "-m", "yt_dlp", "-x", "--audio-format", "wav",
                  "--audio-quality", "0", "-o", str(output_path) + ".%(ext)s", url]
        try:
            proc = subprocess.run(dl_cmd, capture_output=True, text=True, timeout=600)
            if proc.returncode == 0:
                # yt-dlp may leave any of several container formats behind.
                for suffix in [".wav", ".opus", ".m4a", ".webm", ".mp3"]:
                    found = Path(str(output_path) + suffix)
                    if not found.exists():
                        continue
                    if suffix == ".wav":
                        if found != wav_path:
                            found.rename(wav_path)
                    else:
                        # Transcode to 48kHz mono wav, then drop the source file.
                        subprocess.run(["ffmpeg", "-y", "-i", str(found),
                            "-ar", "48000", "-ac", "1", str(wav_path)],
                            capture_output=True, timeout=300)
                        found.unlink(missing_ok=True)
                    if wav_path.exists():
                        return wav_path
            log.warning("yt-dlp failed, falling back to synthetic")
        except Exception as e:
            log.warning("Download error: %s, using synthetic", e)
    return generate_long_audio(output_path, duration_s=3600)

def get_audio_info(path):
    """Return duration, sample rate, channel count, and on-disk size for *path*."""
    meta = sf.info(str(path))
    return {
        "duration_s": meta.duration,
        "sr": meta.samplerate,
        "channels": meta.channels,
        "size_mb": path.stat().st_size / 1e6,
    }

def segment_audio(source, output_dir, target_sr, chunk_seconds=6.0):
    """Split *source* into fixed-length mono wav chunks at *target_sr*.

    Idempotent: existing ``chunk_*.wav`` files in *output_dir* are reused.
    The trailing partial chunk is discarded.  Returns the chunk paths.
    """
    output_dir.mkdir(parents=True, exist_ok=True)
    cached = sorted(output_dir.glob("chunk_*.wav"))
    if cached:
        log.info("Chunks already exist in %s (%d files)", output_dir, len(cached))
        return cached
    log.info("Loading and resampling %s to %dHz...", source.name, target_sr)
    audio, src_sr = torchaudio.load(str(source))
    # Downmix multi-channel audio to a single channel.
    if audio.shape[0] > 1:
        audio = audio.mean(dim=0, keepdim=True)
    if src_sr != target_sr:
        log.info("Resampling %d -> %d Hz...", src_sr, target_sr)
        audio = torchaudio.functional.resample(audio, src_sr, target_sr)
    samples_per_chunk = int(chunk_seconds * target_sr)
    total = audio.shape[-1] // samples_per_chunk
    log.info("Segmenting into %d chunks of %.1fs", total, chunk_seconds)
    written = []
    for idx in range(total):
        start = idx * samples_per_chunk
        piece = audio[:, start:start + samples_per_chunk]
        dest = output_dir / f"chunk_{idx:05d}.wav"
        torchaudio.save(str(dest), piece, target_sr)
        written.append(dest)
    log.info("Saved %d chunks to %s", len(written), output_dir)
    return written

def run_benchmark_on_chunks(codec_name, chunk_paths, target_sr, device="cuda", batch_size=16, chunk_seconds=6.0):
    """Encode/decode every chunk through one codec and collect timing,
    peak-VRAM, and quality metrics.

    Args:
        codec_name: Registry name passed to ``codecbench.codecs.get_codec``.
        chunk_paths: Wav files; each is padded/truncated to ``chunk_seconds``.
        target_sr: Sample rate passed to the codec's ``encode``.
        device: Torch device string; the timing/VRAM helpers and the explicit
            ``torch.cuda.synchronize()`` assume a CUDA device.
        batch_size: Chunks stacked per encode/decode call.
        chunk_seconds: Nominal chunk length used to size/pad each chunk.

    Returns:
        Dict of aggregate encode/decode times, realtime factors, peak VRAM,
        and quality metrics (quality is averaged over the first 10 batches
        only — see the loop below).
    """
    # Project-local imports are deferred so the module can be imported/inspected
    # without the codecbench package installed.
    from codecbench.codecs import get_codec
    from codecbench.bench.timer import CUDATimer, reset_vram_stats, measure_peak_vram
    from codecbench.metrics import log_mel_l1, multi_resolution_stft_loss, si_sdr, hf_energy_delta_db

    codec = get_codec(codec_name)
    codec.load(device=device)
    chunk_len = int(chunk_seconds * target_sr)
    # NOTE(review): assumes chunk_paths is non-empty; the per-chunk averages
    # and realtime factors below divide by total_chunks and the summed times.
    total_chunks = len(chunk_paths)
    total_audio_s = total_chunks * chunk_seconds
    log.info("Benchmarking %s: %d chunks (%.0fmin), bs=%d", codec_name, total_chunks, total_audio_s/60, batch_size)

    # Preload every chunk into CPU memory, padded/trimmed to exactly chunk_len
    # samples, so disk IO never lands inside the timed region.
    all_wavs = []
    for p in chunk_paths:
        w, _ = torchaudio.load(str(p))
        if w.shape[-1] < chunk_len:
            w = torch.nn.functional.pad(w, (0, chunk_len - w.shape[-1]))
        else:
            w = w[:, :chunk_len]
        all_wavs.append(w)

    # Warmup: a few untimed passes so lazy CUDA initialization / kernel
    # compilation doesn't pollute the first measured batch.
    log.info("Warming up %s...", codec_name)
    wb = torch.stack(all_wavs[:min(batch_size, len(all_wavs))], dim=0).to(device)
    for _ in range(5):
        with torch.inference_mode():
            tb = codec.encode(wb, target_sr)
            _ = codec.decode(tb)
    torch.cuda.synchronize()

    timer = CUDATimer(device)
    reset_vram_stats(device)  # so peak VRAM reflects only the measured run
    encode_times, decode_times = [], []
    # Quality-metric accumulators: log-mel L1, multi-resolution STFT loss,
    # SI-SDR, and high-frequency energy delta.
    macc = {"mel_l1": [], "mrstft": [], "sisdr": [], "hf": []}
    n_batches = math.ceil(total_chunks / batch_size)
    log.info("Processing %d batches...", n_batches)
    wall_start = time.time()

    for bi in range(n_batches):
        s, e = bi * batch_size, min((bi+1) * batch_size, total_chunks)
        bw = torch.stack(all_wavs[s:e], dim=0).to(device)
        with torch.inference_mode():
            timer.record_start()
            tb = codec.encode(bw, target_sr)
            encode_times.append(timer.record_end())
            timer.record_start()
            recon = codec.decode(tb)
            decode_times.append(timer.record_end())
        # Quality metrics only on the first 10 batches: enough samples for a
        # stable average without dominating the wall-clock of the timing run.
        if bi < 10:
            macc["mel_l1"].append(log_mel_l1(bw, recon, target_sr))
            macc["mrstft"].append(multi_resolution_stft_loss(bw, recon))
            macc["sisdr"].append(si_sdr(bw, recon))
            macc["hf"].append(hf_energy_delta_db(bw, recon, target_sr))
        if (bi+1) % 20 == 0 or bi == n_batches - 1:
            log.info("  [%s] %d/%d (%.0f%%) %.1fs", codec_name, bi+1, n_batches,
                     (bi+1)/n_batches*100, time.time()-wall_start)

    wall_total = time.time() - wall_start
    peak_vram = measure_peak_vram(device)
    # NOTE(review): the realtime-factor math below treats timer results as
    # milliseconds — confirm against CUDATimer's contract.
    te, td = sum(encode_times), sum(decode_times)

    result = {
        "codec": codec_name, "sr": target_sr,
        "total_chunks": total_chunks, "total_audio_seconds": round(total_audio_s, 1),
        "total_audio_minutes": round(total_audio_s / 60, 1),
        "batch_size": batch_size, "n_batches": n_batches,
        "total_encode_ms": round(te, 1), "total_decode_ms": round(td, 1),
        "total_e2e_ms": round(te + td, 1), "wall_seconds": round(wall_total, 2),
        "encode_per_chunk_ms": round(te / total_chunks, 2),
        "decode_per_chunk_ms": round(td / total_chunks, 2),
        "realtime_factor_encode": round(total_audio_s / (te / 1000), 1),
        "realtime_factor_decode": round(total_audio_s / (td / 1000), 1),
        "realtime_factor_e2e": round(total_audio_s / ((te + td) / 1000), 1),
        "peak_vram_mb": round(peak_vram / (1024**2), 1),
        "mel_l1": round(float(np.mean(macc["mel_l1"])), 6),
        "mrstft": round(float(np.mean(macc["mrstft"])), 6),
        "sisdr": round(float(np.mean(macc["sisdr"])), 2),
        "hf_energy_delta_db": round(float(np.mean(macc["hf"])), 2),
        "gpu_name": torch.cuda.get_device_name(0) if torch.cuda.is_available() else "cpu",
        "torch": torch.__version__,
    }
    # Release model weights before the next codec is benchmarked.
    del codec; torch.cuda.empty_cache()
    return result

def scaling_estimates(results):
    """Render a plain-text report estimating wall-clock time to tokenize
    datasets of various sizes across GPU fleets, from measured encode speed.

    Args:
        results: Benchmark result dicts as produced by
            ``run_benchmark_on_chunks``.

    Returns:
        The formatted multi-line report as a single string.
    """
    def _fmt_duration(seconds):
        # Compact human-readable duration: minutes, hours, or days.
        if seconds < 3600:
            return f"{seconds/60:.0f}min"
        if seconds < 86400:
            return f"{seconds/3600:.1f}hr"
        return f"{seconds/86400:.1f}d"

    dataset_hours = [100, 1000, 10000, 50000, 100000]
    gpu_counts = [1, 4, 8, 32, 100]
    bar = "=" * 78
    out = [bar, "SCALING ESTIMATES FOR PRODUCTION TOKENIZATION", bar]

    for res in results:
        # Seconds of encode compute per hour of input audio.
        sec_per_audio_hr = (res["total_encode_ms"] / 1000) / (res["total_audio_seconds"] / 3600)
        rtf = res["realtime_factor_encode"]
        out.extend([
            "",
            f"--- {res['codec'].upper()} (SR={res['sr']}) ---",
            f"  Encode RTF: {rtf}x (1 GPU processes {rtf} hrs of audio per wall-hour)",
            f"  Encode time for 1hr audio: {sec_per_audio_hr:.1f}s",
            f"  VRAM per worker: {res['peak_vram_mb']:.0f} MB",
            f"  Quality: mel_l1={res['mel_l1']:.4f} sisdr={res['sisdr']:.1f}dB hf={res['hf_energy_delta_db']:.1f}dB",
            "",
        ])
        header = f"  {'Dataset':>10s}"
        for n_gpus in gpu_counts:
            header += f" | {n_gpus:>3d} GPU{'s' if n_gpus > 1 else ' ':4s}"
        out.append(header)
        out.append("  " + "-" * (len(header) - 2))
        for hrs in dataset_hours:
            total_seconds = hrs * sec_per_audio_hr
            row = f"  {hrs:>8d}hr"
            for n_gpus in gpu_counts:
                row += f" | {_fmt_duration(total_seconds / n_gpus):>9s}"
            out.append(row)

    out.extend([
        "",
        bar,
        "NOTES:",
        "- Estimates are encode-only (typical production workflow).",
        "- Assumes linear GPU scaling (no IO/network bottleneck).",
        "- Add ~20-30% overhead for data loading, upload, coordination.",
        "- XCodec2 does not batch efficiently; per-GPU throughput is limited.",
        bar,
    ])
    return "\n".join(out)

def main():
    """CLI entry point: obtain the source audio, segment it per codec sample
    rate, benchmark each requested codec, and write JSONL results plus a
    scaling-estimate report under ``results/``.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--url", default=None)
    parser.add_argument("--codecs", nargs="+", default=["xcodec2", "snac"])
    parser.add_argument("--batch-size", type=int, default=16)
    parser.add_argument("--device", default="cuda")
    parser.add_argument("--max-chunks", type=int, default=None)
    args = parser.parse_args()

    # Fall back to CPU when CUDA was requested but is unavailable.
    device = args.device
    if device == "cuda" and not torch.cuda.is_available():
        device = "cpu"

    audio_path = download_audio(args.url, RAW_DIR / "podcast")
    info = get_audio_info(audio_path)
    log.info("Audio: %.0fs (%.1fmin), sr=%d, %.1fMB",
             info["duration_s"], info["duration_s"]/60, info["sr"], info["size_mb"])

    # Native sample rate per codec; unknown codec names default to 16kHz.
    codec_sr = {"xcodec2": 16000, "bicodec": 16000, "snac": 24000, "wavtokenizer": 24000}
    all_results = []
    for codec_name in args.codecs:
        sr = codec_sr.get(codec_name, 16000)
        chunks = segment_audio(audio_path, DATA_DIR / f"chunks_{sr//1000}k", sr)
        if args.max_chunks:
            chunks = chunks[:args.max_chunks]
        try:
            result = run_benchmark_on_chunks(
                codec_name, chunks, sr, device=device, batch_size=args.batch_size)
            all_results.append(result)
            log.info("\n%s: %.1fmin in %.1fs | enc %dx RT | dec %dx RT | VRAM %.0fMB",
                     codec_name.upper(), result["total_audio_minutes"], result["wall_seconds"],
                     result["realtime_factor_encode"], result["realtime_factor_decode"],
                     result["peak_vram_mb"])
        except Exception as e:
            # A codec failure shouldn't abort the remaining codecs.
            log.error("Failed %s: %s", codec_name, e)
            import traceback; traceback.print_exc()

    if not all_results:
        sys.exit(1)

    RESULTS_DIR.mkdir(parents=True, exist_ok=True)
    with open(RESULTS_DIR / "1hr_bench_results.jsonl", "w") as f:
        for result in all_results:
            f.write(json.dumps(result) + "\n")

    report = scaling_estimates(all_results)
    print("\n" + report)
    with open(RESULTS_DIR / "scaling_estimates.txt", "w") as f:
        f.write(report)
    log.info("Done. Results in %s", RESULTS_DIR)

if __name__ == "__main__":  # run only when executed as a script, not on import
    main()
