"""Concurrency + TTFB benchmark with CUDA graph diffusion."""
import torch, time, threading, subprocess
from vibevoice.modular.modeling_vibevoice_inference import VibeVoiceForConditionalGenerationInference
from vibevoice.modular.streamer import AudioStreamer
from vibevoice.processor.vibevoice_processor import VibeVoiceProcessor
from cuda_graph_v2 import setup_cuda_graph_diffusion, patched_sample_speech_tokens

MODEL_ID = "microsoft/VibeVoice-1.5B"

processor = VibeVoiceProcessor.from_pretrained(MODEL_ID)

# Prefer FlashAttention-2 and fall back to PyTorch SDPA if loading with it
# fails (presumably flash_attn missing or unsupported on this GPU).
# `except Exception` instead of a bare `except:` so Ctrl-C / SystemExit during
# the slow first load still abort the script instead of triggering a reload.
try:
    model = VibeVoiceForConditionalGenerationInference.from_pretrained(
        MODEL_ID, torch_dtype=torch.bfloat16, device_map="cuda",
        attn_implementation="flash_attention_2")
except Exception as exc:
    print(f"flash_attention_2 load failed ({exc!r}); falling back to sdpa", flush=True)
    model = VibeVoiceForConditionalGenerationInference.from_pretrained(
        MODEL_ID, torch_dtype=torch.bfloat16, device_map="cuda",
        attn_implementation="sdpa")
model.eval()

# Load finetuned LoRA adapter
from vibevoice.modular.lora_loading import load_lora_assets
load_lora_assets(model, '/home/ubuntu/vibevoice_finetune_output/lora')
print("LoRA loaded", flush=True)

# Capture the diffusion sampler as a CUDA graph for batch_size=1.
# NOTE(review): the returned tuple is presumably the graph plus its static
# input/output buffers -- see cuda_graph_v2 to confirm.
graph, sx, st_buf, sc, so = setup_cuda_graph_diffusion(model, batch_size=1)
# Keep both sampler variants on the model so run() can pick one per batch size:
# _orig_sample is the eager sampler captured before patching, and
# _cuda_graph_sample is whatever patched_sample_speech_tokens installs as
# model.sample_speech_tokens (presumably the graph-replay version).
model._orig_sample = model.sample_speech_tokens
patched_sample_speech_tokens(model, graph, sx, st_buf, sc, so)
model._cuda_graph_sample = model.sample_speech_tokens
print("CUDA graph ready", flush=True)

# Benchmark input: reference voice clip plus a single-speaker Hindi prompt.
voice = "demo/voices/modi.wav"
text = "Speaker 1: मेरे प्यारे देशवासियों, आज मैं आपके साथ कुछ बहुत ज़रूरी बातें करना चाहता हूँ. हमारा देश एक नये दौर में प्रवेश कर रहा है."

def run(batch, steps=10):
    """Run one streaming generation benchmark and return its metrics.

    Builds ``batch`` identical requests (module-level ``text`` and ``voice``),
    streams the generated audio through an ``AudioStreamer`` with one consumer
    thread per stream, and samples GPU utilization / memory with ``nvidia-smi``
    in a background thread while ``model.generate`` runs.

    Args:
        batch: Number of concurrent streams to generate.
        steps: DDPM inference steps passed to ``model.set_ddpm_inference_steps``.

    Returns:
        Tuple ``(avg_ttfb, rtf, thru, avg_gpu, avg_vram)``:
        avg_ttfb -- mean ms from generation start to each stream's first chunk
                    (-1 if no stream produced audio);
        rtf      -- wall-clock generation time / mean per-stream audio duration
                    (999 if no audio was produced);
        thru     -- seconds of audio (summed over streams) per wall-clock second;
        avg_gpu  -- mean nvidia-smi ``utilization.gpu`` sample (0 if none);
        avg_vram -- mean nvidia-smi ``memory.used`` sample / 1024 (0 if none).
    """
    sample_rate = 24000  # assumed output sample rate of the streamed audio (Hz)

    # Identical requests, one per concurrent stream.
    inp = processor(text=[text] * batch, voice_samples=[[voice]] * batch, padding=True,
                    return_tensors="pt", return_attention_mask=True)
    for k, v in inp.items():
        if torch.is_tensor(v):
            inp[k] = v.to("cuda")

    # One consumer thread per stream records when its first chunk arrived and
    # how many samples it received in total.
    streamer = AudioStreamer(batch_size=batch, stop_signal=None)
    first_chunk_at = [None] * batch
    n_samples = [0] * batch

    def consumer(idx):
        for chunk in streamer.get_stream(idx):
            now = time.perf_counter()
            if first_chunk_at[idx] is None:
                first_chunk_at[idx] = now
            n_samples[idx] += chunk.shape[-1]  # assumes last dim is the sample axis

    consumers = [threading.Thread(target=consumer, args=(i,), daemon=True) for i in range(batch)]
    for c in consumers:
        c.start()

    # Background sampler of (GPU utilization %, memory used / 1024).
    gpu_samples = []
    stop = threading.Event()

    def monitor():
        while not stop.is_set():
            r = subprocess.run(
                ['nvidia-smi', '--query-gpu=utilization.gpu,memory.used', '--format=csv,noheader,nounits'],
                capture_output=True, text=True)
            # nvidia-smi prints one line per GPU; only the first line is parsed.
            # NOTE(review): assumes the benchmark GPU is the first one listed --
            # confirm on multi-GPU hosts.
            lines = r.stdout.strip().splitlines()
            parts = lines[0].split(', ') if lines else []
            if len(parts) >= 2:
                try:
                    gpu_samples.append((float(parts[0]), float(parts[1]) / 1024))
                except ValueError:
                    pass  # non-numeric field (e.g. "[N/A]"): skip this sample only
            stop.wait(0.15)  # returns early as soon as stop is set

    mon_thread = threading.Thread(target=monitor, daemon=True)
    mon_thread.start()

    # CUDA graph only works for batch=1 (cfg_batch=2). Restore original for larger batches.
    # Done before the timer starts so the swap is not part of the measurement.
    if batch > 1 and hasattr(model, '_orig_sample'):
        model.sample_speech_tokens = model._orig_sample
    elif batch == 1 and hasattr(model, '_cuda_graph_sample'):
        model.sample_speech_tokens = model._cuda_graph_sample

    torch.manual_seed(42)
    torch.cuda.manual_seed_all(42)
    model.set_ddpm_inference_steps(num_steps=steps)
    try:
        torch.cuda.synchronize()
        start = time.perf_counter()
        model.generate(**inp, max_new_tokens=None, cfg_scale=1.3, tokenizer=processor.tokenizer,
                       generation_config={"do_sample": False}, verbose=False, is_prefill=True,
                       audio_streamer=streamer, show_progress_bar=False)
        torch.cuda.synchronize()
        end = time.perf_counter()
    finally:
        # Always stop the sampler, even if generate() raises; otherwise it keeps
        # spawning nvidia-smi for the rest of the process.
        stop.set()
        mon_thread.join(timeout=5)
    for c in consumers:
        c.join(timeout=5)

    gen = end - start
    ttfbs = [(t - start) * 1000 for t in first_chunk_at if t is not None]
    avg_ttfb = sum(ttfbs) / len(ttfbs) if ttfbs else -1
    total_audio = sum(n_samples) / float(sample_rate)
    thru = total_audio / gen if gen > 0 else 0
    avg_dur = total_audio / batch
    rtf = gen / avg_dur if avg_dur > 0 else 999

    samples = list(gpu_samples)  # snapshot in case the monitor join timed out
    if samples:
        avg_gpu = sum(util for util, _ in samples) / len(samples)
        avg_vram = sum(mem for _, mem in samples) / len(samples)
    else:
        avg_gpu = avg_vram = 0
    return avg_ttfb, rtf, thru, avg_gpu, avg_vram

# Warmup: one batch=1 pass (CUDA-graph sampler path) so one-time costs are not
# counted in the measured rows.
run(1, 10)

# Column widths match the row format in _report exactly (57 chars total), so
# the table stays aligned.
_HEADER = (f"{'Batch':>6} {'TTFB':>9} {'RTF':>7} {'Thru':>10} "
           f"{'GPU%':>6} {'VRAM':>7} {'Stream':>6}")


def _report(batches, steps):
    """Print one results table: a header, then one row per batch size run at *steps*."""
    print(f"\n{steps} steps:", flush=True)
    print(_HEADER, flush=True)
    print("-" * len(_HEADER), flush=True)
    for batch in batches:
        ttfb, rtf, thru, gpu, vram = run(batch, steps)
        # "Stream" = generation outpaces playback (real-time factor below 1).
        can = "YES" if rtf < 1.0 else "NO"
        print(f"{batch:>6} {ttfb:>7.0f}ms {rtf:>6.3f}x {thru:>6.1f} s/s "
              f"{gpu:>5.0f}% {vram:>5.1f}GB {can:>6}", flush=True)


_report([1, 4, 8, 16, 32], 10)
_report([1, 8, 16], 20)
