"""Finetuned model: no reference + CUDA graph diffusion + streaming TTFB + concurrency."""
import sys, torch, time, threading, subprocess
sys.path.insert(0, '/home/ubuntu/vibevoice')

from vibevoice.modular.modeling_vibevoice_inference import VibeVoiceForConditionalGenerationInference
from vibevoice.modular.lora_loading import load_lora_assets
from vibevoice.modular.streamer import AudioStreamer
from vibevoice.processor.vibevoice_processor import VibeVoiceProcessor

# Load model + finetuned adapter
processor = VibeVoiceProcessor.from_pretrained('microsoft/VibeVoice-1.5B')
model = VibeVoiceForConditionalGenerationInference.from_pretrained(
    'microsoft/VibeVoice-1.5B', torch_dtype=torch.bfloat16, device_map='cuda', attn_implementation='sdpa')
model.eval()
print('Loading finetuned adapter...', flush=True)
# Apply the LoRA weights in place on top of the base model.
load_lora_assets(model, '/home/ubuntu/vibevoice_finetune_output/lora')

# Setup CUDA graph for diffusion
from vibevoice.cuda_graph_v2 import setup_cuda_graph_diffusion, patched_sample_speech_tokens
# Graph is captured for batch_size=1 only; sx/st_buf/sc/so are the static
# input/output buffers the captured graph replays against.
graph, sx, st_buf, sc, so = setup_cuda_graph_diffusion(model, batch_size=1)
# Keep both samplers around so run() can switch between the CUDA-graph
# path (batch==1) and the original eager path (batch>1).
orig_sample = model.sample_speech_tokens
patched_sample_speech_tokens(model, graph, sx, st_buf, sc, so)
cg_sample = model.sample_speech_tokens
print('CUDA graph ready.', flush=True)

# Benchmark prompt (Hindi, single speaker).
text = 'Speaker 1: मेरे प्यारे देशवासियों, आज मैं आपके साथ कुछ बहुत ज़रूरी बातें करना चाहता हूँ. हमारा देश एक नये दौर में प्रवेश कर रहा है.'

def run(batch, steps=10):
    """Benchmark one streaming generation pass at the given batch size.

    Spawns one consumer thread per stream to record time-to-first-byte and
    total audio samples, plus a monitor thread polling nvidia-smi for GPU
    utilization / VRAM.

    Returns a 5-tuple:
        (avg_ttfb_ms, rtf, throughput_audio_s_per_wall_s, avg_gpu_pct, avg_vram_gb)
    avg_ttfb_ms is -1 if no stream ever produced a chunk; rtf is 999 if no
    audio was produced at all.
    """
    SAMPLE_RATE = 24000.0  # output sample rate used to convert samples -> seconds -- TODO confirm against model config

    # CUDA graph was captured for batch=1 only; fall back to the eager
    # sampler for larger batches.
    if batch == 1:
        model.sample_speech_tokens = cg_sample
    else:
        model.sample_speech_tokens = orig_sample

    texts = [text] * batch
    inp = processor(text=texts, voice_samples=None, padding=True,
                    return_tensors='pt', return_attention_mask=True)
    for k, v in inp.items():
        if torch.is_tensor(v):
            inp[k] = v.to('cuda')

    streamer = AudioStreamer(batch_size=batch, stop_signal=None)
    # Shared state between the main thread and the worker threads (all
    # accessed via closure; single-element lists let workers mutate them).
    fcs = [None] * batch   # wall-clock time of first chunk, per stream
    ts = [0] * batch       # total samples received, per stream
    st = [None]            # generation start time (set after CUDA sync)
    gpu_s = []             # (gpu_util_pct, vram_gb) samples from the monitor
    stop = [False]         # signal for the monitor thread to exit

    def consumer(idx):
        # Drain stream idx: record TTFB on the first chunk, accumulate samples.
        for ch in streamer.get_stream(idx):
            t = time.perf_counter()
            if fcs[idx] is None:
                fcs[idx] = t
            ts[idx] += ch.shape[-1]

    def mon():
        # Poll GPU utilization / memory every 150 ms until told to stop.
        while not stop[0]:
            r = subprocess.run(['nvidia-smi','--query-gpu=utilization.gpu,memory.used','--format=csv,noheader,nounits'],
                               capture_output=True, text=True)
            parts = r.stdout.strip().split(', ')
            if len(parts) >= 2:
                gpu_s.append((float(parts[0]), float(parts[1]) / 1024))
            time.sleep(0.15)

    cons = [threading.Thread(target=consumer, args=(i,), daemon=True) for i in range(batch)]
    for c in cons:
        c.start()
    mt = threading.Thread(target=mon, daemon=True)
    mt.start()

    torch.manual_seed(42); torch.cuda.manual_seed_all(42)
    model.set_ddpm_inference_steps(num_steps=steps)
    torch.cuda.synchronize(); st[0] = time.perf_counter()
    # Finetuned model runs without a reference voice, so no CFG (scale 1.0).
    model.generate(**inp, max_new_tokens=None, cfg_scale=1.0, tokenizer=processor.tokenizer,
        generation_config={'do_sample': False}, verbose=False, is_prefill=True,
        audio_streamer=streamer, show_progress_bar=False)
    torch.cuda.synchronize(); end = time.perf_counter()
    stop[0] = True
    for c in cons:
        c.join(timeout=5)
    # BUGFIX: join the monitor thread before reading gpu_s below, so mon()
    # is not appending readings while we average them.
    mt.join(timeout=5)
    time.sleep(0.3)

    gen = end - st[0]
    # Explicit None check: fcs holds float timestamps with None as sentinel.
    ttfbs = [(fcs[i] - st[0]) * 1000 for i in range(batch) if fcs[i] is not None]
    avg_ttfb = sum(ttfbs) / len(ttfbs) if ttfbs else -1
    total_audio = sum(s / SAMPLE_RATE for s in ts)
    thru = total_audio / gen if gen > 0 else 0
    avg_dur = total_audio / batch
    rtf = gen / avg_dur if avg_dur > 0 else 999  # real-time factor; 999 = no audio produced
    avg_gpu = sum(g[0] for g in gpu_s) / len(gpu_s) if gpu_s else 0
    avg_vram = sum(g[1] for g in gpu_s) / len(gpu_s) if gpu_s else 0
    return avg_ttfb, rtf, thru, avg_gpu, avg_vram

# One throwaway pass to warm kernels and the allocator before timing.
run(1, 10)

print(f'\nFINETUNED MODEL: No reference, no CFG, CUDA graph diffusion', flush=True)
print(f'{"Batch":>6} {"TTFB":>8} {"RTF":>7} {"Thru":>9} {"GPU%":>6} {"VRAM":>7} {"Stream":>7}', flush=True)
print('-'*55, flush=True)

# Sweep batch sizes at 10 diffusion steps and print one table row per size.
for bs in (1, 4, 8, 16, 32):
    ttfb, rtf, thru, gpu, vram = run(bs, 10)
    if rtf < 1.0:
        can = 'YES'
    else:
        can = 'NO'
    print(f'{bs:>6} {ttfb:>7.0f}ms {rtf:>6.3f}x {thru:>6.1f} s/s {gpu:>5.0f}% {vram:>5.1f}GB {can:>6}', flush=True)

print('\n20 steps:', flush=True)
# Sweep a subset of batch sizes at 20 diffusion steps.
# NOTE: run() already selects the CUDA-graph sampler for batch==1 and the
# eager sampler otherwise, so the per-iteration reassignment of
# model.sample_speech_tokens that used to live here was dead code and has
# been removed.
for batch in [1, 8, 16]:
    ttfb, rtf, thru, gpu, vram = run(batch, 20)
    can = 'YES' if rtf < 1.0 else 'NO'
    print(f'{batch:>6} {ttfb:>7.0f}ms {rtf:>6.3f}x {thru:>6.1f} s/s {gpu:>5.0f}% {vram:>5.1f}GB {can:>6}', flush=True)

print('\nDone.', flush=True)
