#!/usr/bin/env python3
"""Quick test of v2 model with multiple seeds and do_sample=True."""
import os, sys, time, torch
sys.path.insert(0, "/home/ubuntu/vibevoice-community")

from vibevoice.modular.modeling_vibevoice_inference import VibeVoiceForConditionalGenerationInference
from vibevoice.modular.lora_loading import load_lora_assets
from vibevoice.processor.vibevoice_processor import VibeVoiceProcessor

# Base checkpoint (Hindi 1.5B VibeVoice) and the v2 LoRA adapter under test.
MODEL_PATH = "tarun7r/vibevoice-hindi-1.5B"
LORA_PATH = "/home/ubuntu/modi_processed/lora_output_v2/lora"
# All generated .wav files land here.
OUTPUT_DIR = "/home/ubuntu/modi_processed/v2_tests"
os.makedirs(OUTPUT_DIR, exist_ok=True)

# Hindi test scripts keyed by a short tag used in output filenames.
# NOTE(review): the "Speaker 0:" prefix appears to be the processor's
# speaker-tag convention — confirm against VibeVoiceProcessor.
TEXTS = {
    "greeting": "Speaker 0: मेरे प्यारे देशवासियों, आप सभी को मेरा नमस्कार।",
    "digital": "Speaker 0: भाइयों और बहनों, Digital India का सपना आज हक़ीक़त बन रहा है। हमारे देश ने technology के क्षेत्र में जो प्रगति की है, वह पूरी दुनिया देख रही है।",
    "emotional": "Speaker 0: साथियों, जब मैं सोचता हूँ कि एक चाय बेचने वाला लड़का आज देश का प्रधानमंत्री है, तो मुझे लगता है कि यह सिर्फ मेरी कहानी नहीं है। यह भारत की कहानी है।",
}

# Reference audio segment used as the voice sample for the cloning runs.
REF_AUDIO = "/home/ubuntu/modi_processed/segments/023_Ytzlp3Umct8/seg_0056.wav"

print("Loading model...")
processor = VibeVoiceProcessor.from_pretrained(MODEL_PATH)
# bf16 + SDPA attention, placed directly on the GPU.
model = VibeVoiceForConditionalGenerationInference.from_pretrained(
    MODEL_PATH, torch_dtype=torch.bfloat16, device_map="cuda", attn_implementation="sdpa",
)
# Merge/attach the LoRA weights; the report says which submodules were covered.
report = load_lora_assets(model, LORA_PATH)
print(f"LoRA loaded: lm={report.language_model}, diff={report.diffusion_head_full}")
model.eval()
# Few diffusion steps for quick turnaround — quality/speed trade-off for smoke tests.
model.set_ddpm_inference_steps(num_steps=10)


def gen(name, script, voice_samples, is_prefill, do_sample, seed, cfg=1.3):
    """Synthesize one clip, save it to OUTPUT_DIR, and return its duration.

    Seeds both CPU and CUDA RNGs for reproducibility, runs the processor on a
    single-item batch, and invokes model.generate. Returns the clip length in
    seconds, or 0 when no audio was produced.
    """
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)

    batch = processor(
        text=[script],
        # Empty voice_samples means unconditional (no reference-audio) generation.
        voice_samples=[voice_samples] if voice_samples else None,
        padding=True,
        return_tensors="pt",
        return_attention_mask=True,
    )
    # Move every tensor in the processor output onto the GPU.
    batch = {k: (v.to("cuda") if torch.is_tensor(v) else v) for k, v in batch.items()}

    result = model.generate(
        **batch,
        max_new_tokens=None,
        cfg_scale=cfg,
        tokenizer=processor.tokenizer,
        # Greedy decoding keeps temperature at 1.0; sampling uses 0.8.
        generation_config={"do_sample": do_sample, "temperature": 0.8 if do_sample else 1.0},
        verbose=False,
        is_prefill=is_prefill,
    )

    speech = result.speech_outputs[0] if result.speech_outputs else None
    if speech is None:
        print(f"  {name}: FAILED (no audio)")
        return 0

    # assumes 24 kHz output sample rate — TODO confirm against the processor config
    dur = speech.shape[-1] / 24000
    processor.save_audio(speech, output_path=os.path.join(OUTPUT_DIR, f"{name}.wav"))
    print(f"  {name}: {dur:.2f}s")
    return dur


# Sweep plan: (section header, filename template, voice samples, is_prefill,
# do_sample, seeds). Iterating this table reproduces the original run order.
SWEEPS = [
    ("\n=== NOPREFILL + do_sample=True + multiple seeds ===",
     "noprefill_{t}_s{s}", [], False, True, [42, 123, 777, 2024]),
    ("\n=== NOPREFILL + do_sample=False + multiple seeds ===",
     "noprefill_greedy_{t}_s{s}", [], False, False, [42, 123, 777, 2024]),
    ("\n=== CLONE (ref audio) + do_sample=True ===",
     "clone_{t}_s{s}", [REF_AUDIO], True, True, [42, 123]),
]

for header, name_fmt, voices, prefill, sample, seeds in SWEEPS:
    print(header)
    for tname, text in TEXTS.items():
        for seed in seeds:
            gen(name_fmt.format(t=tname, s=seed), text, voices, prefill, sample, seed)

print("\n=== DONE ===")
