"""Benchmark NanoCodec (Magpie TTS codec): encode → decode round-trip on Modi audio."""

import time
import torch
import torchaudio
import numpy as np
from pathlib import Path

# Directory whose seg_*.wav files are benchmarked (see main()).
SEGMENTS_DIR: Path = Path("/home/ubuntu/modi_processed/segments/001_ERW9i1lwnBw")
# Device the codec and its input tensors are placed on.
DEVICE: str = "cuda"
# Pretrained checkpoint name passed to AudioCodecModel.from_pretrained.
NANOCODEC_ID: str = "nvidia/nemo-nano-codec-22khz-1.89kbps-21.5fps"  # The one Magpie TTS uses


def load_model(*, warmup_iters: int = 3, warmup_seconds: float = 6.0):
    """Load the NanoCodec checkpoint onto DEVICE, warm it up, and return it.

    The warm-up runs ``warmup_iters`` encode -> decode round trips on
    ``warmup_seconds`` of random noise. The noise is generated at the model's
    own sample rate, so warm-up inputs match the shape of real inputs. This
    ensures one-time start-up work on the device happens before any
    measurement is taken.

    Args:
        warmup_iters: Number of warm-up round trips (default 3).
        warmup_seconds: Length of the dummy warm-up signal in seconds
            (default 6.0).

    Returns:
        The codec model in eval mode on DEVICE.
    """
    # Imported lazily so this module can be imported without NeMo installed.
    from nemo.collections.tts.models import AudioCodecModel

    print(f"Loading NanoCodec: {NANOCODEC_ID}")
    model = AudioCodecModel.from_pretrained(NANOCODEC_ID)
    model = model.eval().to(DEVICE)

    print("Warming up...")
    # Use the model's sample rate rather than a hard-coded 22050.
    n_samples = int(model.sample_rate * warmup_seconds)
    dummy = torch.randn(1, n_samples, device=DEVICE)
    dummy_len = torch.tensor([n_samples], device=DEVICE)
    with torch.no_grad():
        for _ in range(warmup_iters):
            codes, codes_len = model.encode(audio=dummy, audio_len=dummy_len)
            model.decode(tokens=codes, tokens_len=codes_len)
    # Make sure all queued warm-up work has finished before returning.
    torch.cuda.synchronize()

    vq = model.vector_quantizer
    print(f"NanoCodec ready. SR={model.sample_rate}, num_groups(codebooks)={vq.num_groups}, "
          f"samples_per_frame={model.samples_per_frame}, FPS={model.sample_rate/model.samples_per_frame:.1f}\n")
    return model


def _load_mono(wav_path: Path, target_sr: int) -> torch.Tensor:
    """Load *wav_path* as a 1-D float waveform at *target_sr* Hz.

    Multi-channel files are downmixed by averaging channels. Squeezing
    dim 0 alone would leave a stereo file as ``[2, T]``, and the codec would
    then receive a bogus length of 2.
    """
    wav, sr = torchaudio.load(str(wav_path))  # [channels, T]
    if wav.shape[0] > 1:
        wav = wav.mean(dim=0, keepdim=True)
    if sr != target_sr:
        wav = torchaudio.functional.resample(wav, sr, target_sr)
    return wav.squeeze(0)  # [T]


def _timed_ms(fn):
    """Call *fn* under ``torch.no_grad()`` and return ``(result, elapsed_ms)``.

    The CUDA device is synchronized before and after the call. GPU work is
    launched asynchronously, so without the syncs the wall-clock time would
    miss queued kernels.
    """
    torch.cuda.synchronize()
    t0 = time.perf_counter()
    with torch.no_grad():
        result = fn()
    torch.cuda.synchronize()
    return result, (time.perf_counter() - t0) * 1000


def _si_sdr(reference, estimate) -> float:
    """Return the scale-invariant SDR (dB) of *estimate* against *reference*.

    Both inputs are equal-length 1-D arrays. The computation is done in
    float64. The reference is projected onto without any mean removal, and a
    1e-9 epsilon guards both denominators.
    """
    ref = np.asarray(reference, dtype=np.float64)
    est = np.asarray(estimate, dtype=np.float64)
    s_target = np.dot(ref, est) * ref / (np.dot(ref, ref) + 1e-9)
    e_noise = est - s_target
    return float(10 * np.log10(np.dot(s_target, s_target) / (np.dot(e_noise, e_noise) + 1e-9)))


def benchmark_segment(model, wav_path: Path):
    """Round-trip one WAV file through the codec and measure speed and fidelity.

    The file is loaded as mono at ``model.sample_rate``, encoded and decoded on
    DEVICE with synchronized timing, and scored with SI-SDR. The
    reconstruction is written to ``recon_nanocodec_<stem>.wav`` three directory
    levels above *wav_path*.

    Args:
        model: Codec exposing ``sample_rate``, ``encode(audio=, audio_len=)``
            and ``decode(tokens=, tokens_len=)``.
        wav_path: Path to the input WAV file.

    Returns:
        Dict with keys ``file``, ``duration_s``, ``n_tokens``,
        ``n_codebooks``, ``encode_ms``, ``decode_ms``, ``encode_rtf``,
        ``decode_rtf``, ``si_sdr_db`` and ``recon_path``.

    Raises:
        ValueError: If the file contains no samples. Without this check the
            real-time factors would divide by zero.
    """
    target_sr = model.sample_rate
    wav = _load_mono(wav_path, target_sr)  # [T], CPU
    if wav.shape[0] == 0:
        raise ValueError(f"{wav_path} contains no audio samples")
    duration_s = wav.shape[0] / target_sr

    wav_gpu = wav.to(DEVICE)
    wav_len = torch.tensor([wav_gpu.shape[0]], device=DEVICE)

    (codes, codes_len), encode_ms = _timed_ms(
        lambda: model.encode(audio=wav_gpu.unsqueeze(0), audio_len=wav_len)
    )
    # Presumably codes is [batch, codebooks, frames] -- consistent with the indexing below.
    n_tokens = codes.shape[-1]
    n_codebooks = codes.shape[1]

    (recon_audio, recon_len), decode_ms = _timed_ms(
        lambda: model.decode(tokens=codes, tokens_len=codes_len)
    )

    # Drop anything past the decoder's reported length, then align both
    # signals to the shorter of the two.
    recon_wav = recon_audio.squeeze().cpu()[: int(recon_len[0])]
    min_len = min(wav.shape[0], recon_wav.shape[0])
    orig_wav = wav[:min_len]
    recon_wav = recon_wav[:min_len]

    si_sdr = _si_sdr(orig_wav.numpy(), recon_wav.numpy())

    out_path = wav_path.parent.parent.parent / f"recon_nanocodec_{wav_path.stem}.wav"
    torchaudio.save(str(out_path), recon_wav.unsqueeze(0), target_sr)

    return {
        "file": wav_path.name,
        "duration_s": duration_s,
        "n_tokens": n_tokens,
        "n_codebooks": n_codebooks,
        "encode_ms": encode_ms,
        "decode_ms": decode_ms,
        "encode_rtf": encode_ms / (duration_s * 1000),
        "decode_rtf": decode_ms / (duration_s * 1000),
        "si_sdr_db": si_sdr,
        "recon_path": str(out_path),
    }


def main():
    """Benchmark up to five segments from SEGMENTS_DIR and print a summary.

    The first five ``seg_*.wav`` files (in sorted order) are run through
    ``benchmark_segment``. The function prints one row per file, an AVERAGE
    row, the directory of the reconstructions, and a comparison line against
    hard-coded XCodec figures.

    If no segments are found, a message is printed and the function returns
    without loading the model. Previously this case crashed on ``results[0]``.
    """
    segments = sorted(SEGMENTS_DIR.glob("seg_*.wav"))[:5]
    if not segments:
        print(f"No seg_*.wav files found in {SEGMENTS_DIR}")
        return

    model = load_model()

    print(f"Benchmarking {len(segments)} segments from {SEGMENTS_DIR.name}")
    print(f"{'File':<16} {'Dur(s)':>7} {'Tokens':>7} {'CB':>3} {'Enc(ms)':>9} {'Dec(ms)':>9} {'Enc RTF':>8} {'Dec RTF':>8} {'SI-SDR':>8}")
    print("-" * 95)

    results = []
    for seg_path in segments:
        r = benchmark_segment(model, seg_path)
        results.append(r)
        print(f"{r['file']:<16} {r['duration_s']:>7.2f} {r['n_tokens']:>7} {r['n_codebooks']:>3} {r['encode_ms']:>9.1f} {r['decode_ms']:>9.1f} {r['encode_rtf']:>8.4f} {r['decode_rtf']:>8.4f} {r['si_sdr_db']:>8.2f}")

    print("-" * 95)
    avg = {
        key: np.mean([r[key] for r in results])
        for key in ("duration_s", "encode_ms", "decode_ms", "encode_rtf", "decode_rtf", "si_sdr_db")
    }
    print(f"{'AVERAGE':<16} {avg['duration_s']:>7.2f} {'':>7} {'':>3} {avg['encode_ms']:>9.1f} {avg['decode_ms']:>9.1f} {avg['encode_rtf']:>8.4f} {avg['decode_rtf']:>8.4f} {avg['si_sdr_db']:>8.2f}")

    first = results[0]
    recon_dir = Path(first["recon_path"]).parent
    print(f"\nReconstructed audio saved to: {recon_dir}/")
    print("\n=== COMPARISON SUMMARY ===")
    # The FPS figure comes from the first segment's token count and duration.
    print(f"NanoCodec:  SR=22kHz, {first['n_codebooks']} codebooks, {first['n_tokens']/first['duration_s']:.1f} FPS, SI-SDR={avg['si_sdr_db']:.2f} dB")
    print("XCodec:     SR=16kHz, 1 codebook,  50.0 FPS, SI-SDR=3.00 dB (from previous run)")


# Run the benchmark only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()
