#!/usr/bin/env python3
"""Focused noprefill debug - test multiple seeds, sampling, cfg scales."""
import os, sys, time, torch
sys.path.insert(0, "/home/ubuntu/vibevoice-community")

from vibevoice.modular.modeling_vibevoice_inference import VibeVoiceForConditionalGenerationInference
from vibevoice.modular.lora_loading import load_lora_assets
from vibevoice.processor.vibevoice_processor import VibeVoiceProcessor

# Hugging Face checkpoint id for the Hindi-finetuned 1.5B VibeVoice model.
MODEL_PATH = "tarun7r/vibevoice-hindi-1.5B"
# All generated WAV files from this debug run land here.
OUTPUT_DIR = "/home/ubuntu/modi_processed/noprefill_debug"
os.makedirs(OUTPUT_DIR, exist_ok=True)

# Two single-speaker test scripts: a longer mixed Hindi/English passage and a short greeting.
TEXT = "Speaker 0: भाइयों और बहनों, Digital India का सपना आज हक़ीक़त बन रहा है। हमारे देश ने technology के क्षेत्र में जो प्रगति की है, वह पूरी दुनिया देख रही है।"
TEXT2 = "Speaker 0: मेरे प्यारे देशवासियों, आप सभी को मेरा नमस्कार।"

print("Loading model...")
processor = VibeVoiceProcessor.from_pretrained(MODEL_PATH)
# bf16 + SDPA attention on a single CUDA device; this global `model` is what gen() uses.
model = VibeVoiceForConditionalGenerationInference.from_pretrained(
    MODEL_PATH, torch_dtype=torch.bfloat16, device_map="cuda", attn_implementation="sdpa",
)
model.eval()
# Few diffusion steps: fast, lower-quality audio — adequate for debugging.
model.set_ddpm_inference_steps(num_steps=10)


def gen(name, script, lora_path, do_sample, seed, cfg=1.3, temp=0.8):
    """Run one seeded generation through the global model and save the audio.

    Args:
        name: Basename (no extension) for the output WAV inside OUTPUT_DIR.
        script: Full "Speaker N: ..." text to synthesize.
        lora_path: UNUSED — kept only for call-site compatibility; the LoRA
            must already be loaded into the global ``model`` by the caller.
        do_sample: True for sampling, False for greedy decoding.
        seed: Seed applied to both CPU and all CUDA RNGs for reproducibility.
        cfg: Classifier-free-guidance scale forwarded to ``model.generate``.
        temp: Sampling temperature; only applied when ``do_sample`` is True.

    Returns:
        Duration of the generated audio in seconds, or 0 on failure.
    """
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    inputs = processor(
        text=[script], voice_samples=None,
        padding=True, return_tensors="pt", return_attention_mask=True,
    )
    # Move every tensor to the GPU; pass non-tensor entries through unchanged.
    inputs = {k: v.to("cuda") if torch.is_tensor(v) else v for k, v in inputs.items()}

    gen_config = {"do_sample": do_sample}
    if do_sample:
        gen_config["temperature"] = temp

    outputs = model.generate(
        **inputs, max_new_tokens=None, cfg_scale=cfg,
        tokenizer=processor.tokenizer, generation_config=gen_config,
        verbose=False, is_prefill=False,
    )
    # speech_outputs may be None/empty when generation collapses; force a bool.
    has_audio = bool(outputs.speech_outputs) and outputs.speech_outputs[0] is not None
    if has_audio:
        # 24000 assumed to be the model's output sample rate — TODO confirm
        # against the processor/feature-extractor config.
        dur = outputs.speech_outputs[0].shape[-1] / 24000
        out_path = os.path.join(OUTPUT_DIR, f"{name}.wav")
        processor.save_audio(outputs.speech_outputs[0], output_path=out_path)
        print(f"  {name}: {dur:.2f}s  OK")
        return dur
    print(f"  {name}: 0s  FAILED")
    return 0


# --- Test v1 LoRA (rank16, 2 epochs, drop_rate=1.0) ---
print("\n=== v1 LoRA (rank16, 2ep, drop_rate=1.0) ===")
v1_path = "/home/ubuntu/modi_processed/lora_output/lora"
report = load_lora_assets(model, v1_path)
print(f"  Loaded: lm={report.language_model}, diff={report.diffusion_head_full}")

# Greedy decoding across five seeds, then sampling, then the short script.
for s in (42, 123, 777, 2024, 9999):
    gen(f"v1_greedy_s{s}", TEXT, v1_path, False, s)
for s in (42, 123, 777):
    gen(f"v1_sample_s{s}", TEXT, v1_path, True, s)
for s in (42, 123):
    gen(f"v1_short_s{s}", TEXT2, v1_path, True, s)

# === Now reload a clean base model and apply the v2 LoRA ===
print("\n=== Reloading base model for v2 ===")
# Free the v1-patched model BEFORE loading the second checkpoint, otherwise
# two full bf16 models are resident on the GPU at the same time.
del model
torch.cuda.empty_cache()
model = VibeVoiceForConditionalGenerationInference.from_pretrained(
    MODEL_PATH, torch_dtype=torch.bfloat16, device_map="cuda", attn_implementation="sdpa",
)
model.eval()
model.set_ddpm_inference_steps(num_steps=10)

v2_path = "/home/ubuntu/modi_processed/lora_output_v2/lora"
report = load_lora_assets(model, v2_path)
print(f"  Loaded: lm={report.language_model}, diff={report.diffusion_head_full}")

# v2 sweep: greedy seeds, sampled seeds, CFG-scale grid, temperature grid, short text.
print("\n=== v2 LoRA (rank64, 10ep, drop_rate=0.5) - greedy ===")
for s in (42, 123, 777, 2024, 9999):
    gen(f"v2_greedy_s{s}", TEXT, v2_path, False, s)

print("\n=== v2 LoRA - do_sample=True ===")
for s in (42, 123, 777, 2024, 9999):
    gen(f"v2_sample_s{s}", TEXT, v2_path, True, s)

print("\n=== v2 LoRA - different CFG scales ===")
for scale in (1.0, 1.1, 1.2, 1.3, 1.5):
    gen(f"v2_cfg{scale}_s123", TEXT, v2_path, True, 123, cfg=scale)

print("\n=== v2 LoRA - different temperatures ===")
for t in (0.6, 0.8, 1.0, 1.2):
    gen(f"v2_temp{t}_s123", TEXT, v2_path, True, 123, temp=t)

print("\n=== v2 LoRA - short text ===")
for s in (42, 123, 777):
    gen(f"v2_short_s{s}", TEXT2, v2_path, True, s)

print("\n=== DONE ===")
