#!/usr/bin/env python3
"""Generate multiple Modi voice samples with different configs."""
import os, sys, time, torch
sys.path.insert(0, "/home/ubuntu/vibevoice-community")

from vibevoice.modular.modeling_vibevoice_inference import VibeVoiceForConditionalGenerationInference
from vibevoice.modular.lora_loading import load_lora_assets
from vibevoice.processor.vibevoice_processor import VibeVoiceProcessor

MODEL_PATH = "tarun7r/vibevoice-hindi-1.5B"
LORA_PATH = "/home/ubuntu/modi_processed/lora_output_v2/lora"
OUTPUT_DIR = "/home/ubuntu/modi_processed/final_samples_v2"
os.makedirs(OUTPUT_DIR, exist_ok=True)

REF_CLIPS = [
    "/home/ubuntu/modi_processed/segments/023_Ytzlp3Umct8/seg_0056.wav",
    "/home/ubuntu/modi_processed/segments/009_9p75x0UZUCg/seg_0055.wav",
    "/home/ubuntu/modi_processed/segments/018_VH2ZIZx29j8/seg_0048.wav",
]

TEXTS = [
    ("short_greeting", "Speaker 0: मेरे प्यारे देशवासियों, आप सभी को मेरा नमस्कार।"),
    ("digital_india", "Speaker 0: भाइयों और बहनों, Digital India का सपना आज हक़ीक़त बन रहा है। हमारे देश ने technology के क्षेत्र में जो प्रगति की है, वह पूरी दुनिया देख रही है। आज हर गाँव में internet पहुँच रहा है, हर हाथ में smartphone है।"),
    ("mann_ki_baat", "Speaker 0: मेरे प्यारे देशवासियों, नमस्कार। मन की बात में आज मैं आपसे कुछ खास बातें करना चाहता हूँ। पिछले महीने मुझे बहुत सारी चिट्ठियाँ मिलीं, बहुत सारे messages आए। एक छोटे से गाँव से एक बच्चे ने मुझे letter लिखा, उसने लिखा कि मोदी जी, मैं बड़ा होकर scientist बनना चाहता हूँ। मुझे उस बच्चे पर बहुत गर्व है।"),
    ("emotional", "Speaker 0: साथियों, जब मैं सोचता हूँ कि एक चाय बेचने वाला लड़का आज देश का प्रधानमंत्री है, तो मुझे लगता है कि यह सिर्फ मेरी कहानी नहीं है। यह भारत की कहानी है, यह लोकतंत्र की ताकत है, यह सवा सौ करोड़ भारतीयों के विश्वास की जीत है।"),
    ("development", "Speaker 0: मित्रों, आज जब हम अपने देश की विकास यात्रा को देखते हैं, तो हमें गर्व होता है। पिछले कुछ वर्षों में भारत ने जो उपलब्धियाँ हासिल की हैं, वह अद्भुत हैं।"),
]

torch.manual_seed(42)

print("Loading processor...")
processor = VibeVoiceProcessor.from_pretrained(MODEL_PATH)

print("Loading model...")
model = VibeVoiceForConditionalGenerationInference.from_pretrained(
    MODEL_PATH, torch_dtype=torch.bfloat16, device_map="cuda", attn_implementation="sdpa",
)

print("Loading LoRA...")
report = load_lora_assets(model, LORA_PATH)
print(f"  Loaded: lm={report.language_model}, diff={report.diffusion_head_full}, "
      f"ac={report.acoustic_connector}, sc={report.semantic_connector}")

model.eval()
model.set_ddpm_inference_steps(num_steps=10)


def generate(name, script, voice_samples, is_prefill, cfg_scale=1.3):
    print(f"\n--- Generating: {name} ---")
    print(f"    text: {script[:80]}...")
    print(f"    ref: {os.path.basename(voice_samples[0]) if voice_samples else 'NONE'}")
    print(f"    prefill: {is_prefill}, cfg: {cfg_scale}")

    inputs = processor(
        text=[script],
        voice_samples=[voice_samples] if voice_samples else None,
        padding=True, return_tensors="pt", return_attention_mask=True,
    )
    for k, v in inputs.items():
        if torch.is_tensor(v):
            inputs[k] = v.to("cuda")

    t0 = time.time()
    outputs = model.generate(
        **inputs, max_new_tokens=None, cfg_scale=cfg_scale,
        tokenizer=processor.tokenizer, generation_config={"do_sample": False},
        verbose=False, is_prefill=is_prefill,
    )
    elapsed = time.time() - t0

    has_audio = outputs.speech_outputs and outputs.speech_outputs[0] is not None
    if has_audio:
        samples = outputs.speech_outputs[0].shape[-1]
        dur = samples / 24000
        out_path = os.path.join(OUTPUT_DIR, f"{name}.wav")
        processor.save_audio(outputs.speech_outputs[0], output_path=out_path)
        print(f"    OK: {dur:.2f}s audio in {elapsed:.1f}s -> {out_path}")
        return dur
    else:
        print(f"    FAILED: no audio generated")
        return 0


print("\n" + "=" * 60)
print("PHASE 1: LoRA + Voice Cloning (with reference audio)")
print("=" * 60)

for ref_idx, ref_path in enumerate(REF_CLIPS):
    ref_name = os.path.basename(os.path.dirname(ref_path)) + "_" + os.path.basename(ref_path).replace(".wav","")
    for text_name, text in TEXTS:
        name = f"clone_ref{ref_idx}_{text_name}"
        generate(name, text, [ref_path], is_prefill=True)

print("\n" + "=" * 60)
print("PHASE 2: LoRA + disable_prefill (no reference)")
print("=" * 60)

for text_name, text in TEXTS:
    generate(f"noprefill_{text_name}", text, [], is_prefill=False)

print("\n" + "=" * 60)
print("ALL DONE")
print("=" * 60)

wavs = sorted([f for f in os.listdir(OUTPUT_DIR) if f.endswith('.wav')])
for w in wavs:
    import wave
    with wave.open(os.path.join(OUTPUT_DIR, w), 'r') as wf:
        dur = wf.getnframes() / wf.getframerate()
    print(f"  {w}: {dur:.2f}s")
