import torch, time, os
import sys
sys.path.insert(0, '/home/ubuntu/vibevoice')
from vibevoice.modular.modeling_vibevoice_inference import VibeVoiceForConditionalGenerationInference
from vibevoice.modular.lora_loading import load_lora_assets
from vibevoice.processor.vibevoice_processor import VibeVoiceProcessor

processor = VibeVoiceProcessor.from_pretrained('microsoft/VibeVoice-1.5B')
model = VibeVoiceForConditionalGenerationInference.from_pretrained(
    'microsoft/VibeVoice-1.5B', torch_dtype=torch.bfloat16, device_map='cuda', attn_implementation='sdpa')
model.eval()

print('Loading finetuned adapter (10 epochs)...', flush=True)
report = load_lora_assets(model, '/home/ubuntu/vibevoice_finetune_output/lora')
print(f'Loaded: lm={report.language_model}, diff={report.diffusion_head_full or report.diffusion_head_lora}', flush=True)
model.set_ddpm_inference_steps(num_steps=10)

os.makedirs('/home/ubuntu/finetuned_samples_v2', exist_ok=True)
voice = '/home/ubuntu/vibevoice/demo/voices/modi.wav'

tests = [
    ('no_ref_medium', 'Speaker 1: मेरे प्यारे देशवासियों, आज मैं आपके साथ कुछ बहुत ज़रूरी बातें करना चाहता हूँ. हमारा देश एक नये दौर में प्रवेश कर रहा है.', None, 1.0),
    ('no_ref_speech', 'Speaker 1: भारत आज दुनिया की पाँचवीं सबसे बड़ी अर्थव्यवस्था है. हमारे युवाओं की ऊर्जा, हमारे वैज्ञानिकों की प्रतिभा, और हमारे किसानों की मेहनत, यही हमारी असली ताकत है.', None, 1.0),
    ('ref_medium', 'Speaker 1: मेरे प्यारे देशवासियों, आज मैं आपके साथ कुछ बहुत ज़रूरी बातें करना चाहता हूँ. हमारा देश एक नये दौर में प्रवेश कर रहा है.', voice, 1.3),
    ('ref_speech', 'Speaker 1: भारत आज दुनिया की पाँचवीं सबसे बड़ी अर्थव्यवस्था है. हमारे युवाओं की ऊर्जा, हमारे वैज्ञानिकों की प्रतिभा, और हमारे किसानों की मेहनत, यही हमारी असली ताकत है.', voice, 1.3),
]

for label, text, ref, cfg in tests:
    vs = [[ref]] if ref else None
    inp = processor(text=[text], voice_samples=vs, padding=True, return_tensors='pt', return_attention_mask=True)
    for k,v in inp.items():
        if torch.is_tensor(v): inp[k]=v.to('cuda')
    torch.manual_seed(42); torch.cuda.synchronize(); t0=time.perf_counter()
    out = model.generate(**inp, max_new_tokens=None, cfg_scale=cfg, tokenizer=processor.tokenizer,
        generation_config={'do_sample': False}, verbose=False, is_prefill=True, show_progress_bar=False)
    torch.cuda.synchronize(); gen=time.perf_counter()-t0
    if out.speech_outputs[0] is not None:
        path = f'/home/ubuntu/finetuned_samples_v2/{label}.wav'
        processor.save_audio(out.speech_outputs[0], output_path=path)
        dur = out.speech_outputs[0].shape[-1]/24000
        print(f'{label}: {dur:.2f}s audio | {gen:.2f}s gen | RTF={gen/dur:.3f}x', flush=True)
    else:
        print(f'{label}: no audio', flush=True)

print('Done.', flush=True)
