import time
import torch
import soundfile as sf
from voxcpm import VoxCPM

def get_gpu_mem():
    """Return (currently allocated, peak allocated) CUDA memory on the
    default device, both in GiB."""
    gib = 1024 ** 3
    current = torch.cuda.memory_allocated() / gib
    peak = torch.cuda.max_memory_allocated() / gib
    return current, peak

print("=" * 60)
print("VoxCPM 1.5 Benchmark on A100-80GB")
print("=" * 60)

torch.cuda.reset_peak_memory_stats()
mem_before = get_gpu_mem()
print(f"\nGPU memory before load: {mem_before[0]:.2f} GB")

print("\nLoading model...")
t0 = time.time()
model = VoxCPM.from_pretrained("openbmb/VoxCPM1.5")
load_time = time.time() - t0
mem_after_load = get_gpu_mem()
print(f"Load time: {load_time:.1f}s")
print(f"GPU memory after load: {mem_after_load[0]:.2f} GB (peak: {mem_after_load[1]:.2f} GB)")
model_mem = mem_after_load[0] - mem_before[0]
print(f"Model VRAM: {model_mem:.2f} GB")

print("\nWarmup run (includes torch.compile)...")
t0 = time.time()
_ = model.generate(text="Warmup test.", cfg_value=2.0, inference_timesteps=10, normalize=False, denoise=False)
warmup_time = time.time() - t0
print(f"Warmup time: {warmup_time:.1f}s")

torch.cuda.reset_peak_memory_stats()

# Benchmark cases as (label, text to synthesize, prompt wav path for voice
# cloning or None, transcript of the prompt wav or None).
tests = [
    ("Short English", "Hello, how are you doing today?", None, None),
    ("Medium English", "VoxCPM is an innovative text to speech model that generates highly expressive and natural sounding speech from text input.", None, None),
    ("Long English", "Artificial intelligence has transformed the way we interact with technology. From voice assistants to autonomous vehicles, AI systems are becoming increasingly sophisticated. The development of large language models has opened up new possibilities in natural language processing, enabling machines to understand and generate human-like text with remarkable accuracy.", None, None),
    ("Hindi (no clone)", "नमस्कार, मैं आज आपसे बात करना चाहता हूँ। भारत एक महान देश है और हम सब मिलकर इसे और भी आगे ले जा सकते हैं।", None, None),
    ("Hindi (voice clone)", "नमस्कार, मैं आज आपसे बात करना चाहता हूँ। भारत एक महान देश है।", "/home/ubuntu/modi_clip_5s.wav", "मैं पवित्र महीने के लिए सभी को शुभकामनाएं देता हूँ।"),
]

print("\n" + "=" * 60)
print(f"{'Test':<22} {'Gen(s)':>7} {'Audio(s)':>8} {'RTF':>7} {'Peak VRAM':>10}")
print("-" * 60)

for name, text, prompt_wav, prompt_text in tests:
    torch.cuda.reset_peak_memory_stats()
    
    t0 = time.time()
    wav = model.generate(
        text=text,
        prompt_wav_path=prompt_wav,
        prompt_text=prompt_text,
        cfg_value=2.0,
        inference_timesteps=10,
        normalize=False,
        denoise=False,
        retry_badcase=True,
        retry_badcase_max_times=3,
        retry_badcase_ratio_threshold=6.0,
    )
    gen_time = time.time() - t0
    
    audio_duration = len(wav) / model.tts_model.sample_rate
    rtf = gen_time / audio_duration
    peak_vram = torch.cuda.max_memory_allocated() / 1024**3
    
    print(f"{name:<22} {gen_time:>7.2f} {audio_duration:>8.2f} {rtf:>7.3f} {peak_vram:>9.2f}G")

print("=" * 60)

print(f"\n--- Concurrency Estimates (A100 80GB) ---")
peak_total = torch.cuda.max_memory_allocated() / 1024**3
print(f"Peak VRAM per inference: {peak_total:.2f} GB")
print(f"Model base VRAM: {model_mem:.2f} GB")
inference_overhead = peak_total - model_mem
print(f"Per-inference overhead: {inference_overhead:.2f} GB")
# Model is shared, each concurrent stream adds overhead
for target in [1, 2, 4, 8]:
    total_needed = model_mem + (inference_overhead * target)
    fits = "YES" if total_needed <= 78 else "NO"
    print(f"  {target} concurrent streams: ~{total_needed:.1f} GB needed [{fits}]")

print(f"\nNote: These are sequential benchmarks. True concurrent serving")
print(f"would need batched inference or multiple model instances.")
