#!/usr/bin/env python3
"""Debug inference: test different configurations to find what works."""
import os, sys, time, torch
sys.path.insert(0, "/home/ubuntu/vibevoice-community")

from vibevoice.modular.modeling_vibevoice_inference import VibeVoiceForConditionalGenerationInference
from vibevoice.modular.lora_loading import load_lora_assets
from vibevoice.processor.vibevoice_processor import VibeVoiceProcessor

# --- Configuration -----------------------------------------------------------
# Repo id / path of the base model; the processor is loaded from the same place.
MODEL_PATH = "tarun7r/vibevoice-hindi-1.5B"
# Directory passed to load_lora_assets() after the base-model tests.
LORA_PATH = "/home/ubuntu/modi_processed/lora_output/lora"
# Reference clip supplied as the voice sample in the "with_ref"/"ref_loaded" tests.
REF_AUDIO = "/home/ubuntu/modi_processed/segments/001_ERW9i1lwnBw/seg_0004.wav"
# Destination folder for the generated debug_<test name>.wav files.
OUTPUT_DIR = "/home/ubuntu/modi_processed/test_outputs"
# Default single-speaker Hindi script used by most of the tests.
TEXT = "Speaker 0: मेरे प्यारे देशवासियों, आप सभी को मेरा नमस्कार।"

# --- One-time environment and model setup -----------------------------------
# Fix the global torch RNG before anything that might consume it.
torch.manual_seed(42)
os.makedirs(OUTPUT_DIR, exist_ok=True)

print("=" * 60)
print("Loading processor...")
processor = VibeVoiceProcessor.from_pretrained(MODEL_PATH)

print("Loading base model...")
# bf16 weights placed on the GPU, using PyTorch's SDPA attention kernels.
model = VibeVoiceForConditionalGenerationInference.from_pretrained(
    MODEL_PATH,
    attn_implementation="sdpa",
    device_map="cuda",
    torch_dtype=torch.bfloat16,
).eval()
# Number of diffusion (DDPM) inference steps used during generation.
model.set_ddpm_inference_steps(num_steps=10)


def run_test(name, script, voice_samples, is_prefill, cfg_scale=1.3, seed=42,
             sample_rate=24000):
    """Generate speech for one configuration, save it, and report whether audio came out.

    Args:
        name: Short test label. It is printed in the banner and used for the
            output file ``debug_<name>.wav`` inside ``OUTPUT_DIR``.
        script: Text handed to the processor (e.g. ``"Speaker 0: ..."``).
        voice_samples: List of reference-audio paths. When it is empty (or
            None), ``voice_samples=None`` is passed to the processor, meaning
            no reference audio.
        is_prefill: Forwarded unchanged to ``model.generate``.
        cfg_scale: Guidance scale forwarded to ``model.generate``.
        seed: If not None, ``torch.manual_seed(seed)`` is called right before
            generation. Every test then starts from the same RNG state no
            matter which tests ran earlier. Pass None to keep the current state.
        sample_rate: Samples per second, used only to turn the output length
            into the printed duration. 24000 keeps the original hard-coded
            value. Presumably this is the model's output rate; confirm.

    Returns:
        True if ``model.generate`` produced a non-None first speech output
        (which is then saved to disk), otherwise False.
    """
    # Only add an ellipsis when the script was actually truncated.
    preview = script if len(script) <= 80 else script[:80] + "..."
    print(f"\n{'='*60}")
    print(f"TEST: {name}")
    print(f"  script     = {preview}")
    print(f"  voice      = {voice_samples}")
    print(f"  is_prefill = {is_prefill}")
    print(f"  cfg_scale  = {cfg_scale}")
    print("=" * 60)

    inputs = processor(
        text=[script],
        # One voice-sample list per script in the batch; None means no reference audio.
        voice_samples=[voice_samples] if voice_samples else None,
        padding=True,
        return_tensors="pt",
        return_attention_mask=True,
    )
    # Move tensor entries to the GPU the model was loaded on (device_map="cuda").
    # Iterate over a snapshot so that reassigning values cannot disturb iteration.
    for key, value in list(inputs.items()):
        if torch.is_tensor(value):
            inputs[key] = value.to("cuda")

    print(f"  input_ids shape: {inputs['input_ids'].shape}")

    if seed is not None:
        # Reseed per test. Generation presumably draws noise from the global
        # torch RNG (e.g. in the DDPM diffusion steps). Without this, each
        # test's output would depend on how many random numbers earlier tests
        # consumed, which skews side-by-side comparisons.
        torch.manual_seed(seed)

    t0 = time.perf_counter()
    outputs = model.generate(
        **inputs,
        max_new_tokens=None,
        cfg_scale=cfg_scale,
        tokenizer=processor.tokenizer,
        generation_config={"do_sample": False},
        verbose=True,
        is_prefill=is_prefill,
    )
    elapsed = time.perf_counter() - t0

    # Tokens produced beyond the prompt length.
    gen_tokens = outputs.sequences.shape[1] - inputs["input_ids"].shape[1]
    # bool() guarantees a real True/False rather than an empty list or None.
    has_audio = bool(outputs.speech_outputs) and outputs.speech_outputs[0] is not None

    if has_audio:
        samples = outputs.speech_outputs[0].shape[-1]
        dur = samples / sample_rate
        out_path = os.path.join(OUTPUT_DIR, f"debug_{name}.wav")
        processor.save_audio(outputs.speech_outputs[0], output_path=out_path)
        print(f"  RESULT: {dur:.2f}s audio, {gen_tokens} tokens, {elapsed:.1f}s wall")
        print(f"  Saved: {out_path}")
    else:
        print(f"  RESULT: NO AUDIO, {gen_tokens} tokens, {elapsed:.1f}s wall")

    return has_audio


# Outcome of every configuration, keyed by test name: True (audio produced),
# False (no audio) or "ERROR" (run_test raised). Summarised at the end.
results = {}


def _run(name, script, voice_samples, **kwargs):
    """Run one configuration through run_test() and record its outcome in ``results``.

    This script exists to compare configurations. An exception in one of them
    is therefore printed (with its traceback) and recorded as "ERROR" instead
    of aborting the remaining tests. KeyboardInterrupt still propagates
    because only ``Exception`` is caught.
    """
    import traceback

    try:
        results[name] = run_test(name, script, voice_samples, **kwargs)
    except Exception:
        traceback.print_exc()
        results[name] = "ERROR"


# ====== TEST 1: Base model + reference audio (voice cloning) ======
print("\n\n>>> PHASE 1: BASE MODEL (no LoRA)")
_run("base_with_ref", TEXT, [REF_AUDIO], is_prefill=True)

# ====== TEST 2: Base model + disable_prefill (no ref audio) ======
_run("base_no_ref", TEXT, [], is_prefill=False)

# ====== Now load LoRA ======
# load_lora_assets() is called on the same `model` object used by all later
# tests (presumably it patches the weights in place; its return value is only
# a report). It is deliberately not wrapped in _run: if loading fails, the
# "lora_*" tests would silently exercise the base model instead.
print("\n\n>>> LOADING LoRA WEIGHTS...")
report = load_lora_assets(model, LORA_PATH)
print(f"  Loaded: language_model={report.language_model}, "
      f"diffusion_head_full={report.diffusion_head_full}, "
      f"acoustic_connector={report.acoustic_connector}, "
      f"semantic_connector={report.semantic_connector}")

# ====== TEST 3: LoRA + reference audio (voice cloning) ======
_run("lora_with_ref", TEXT, [REF_AUDIO], is_prefill=True)

# ====== TEST 4: LoRA + disable_prefill ======
_run("lora_no_ref", TEXT, [], is_prefill=False)

# ====== TEST 5: LoRA + disable_prefill + higher CFG ======
_run("lora_no_ref_cfg2", TEXT, [], is_prefill=False, cfg_scale=2.0)

# ====== TEST 6: LoRA + disable_prefill + longer text ======
long_text = "Speaker 0: भाइयों और बहनों, Digital India का सपना आज हक़ीक़त बन रहा है। हमारे देश ने technology के क्षेत्र में जो प्रगति की है, वह पूरी दुनिया देख रही है। आज हर गाँव में internet पहुँच रहा है, हर हाथ में smartphone है।"
_run("lora_no_ref_long", long_text, [], is_prefill=False)

# ====== TEST 7: LoRA + ref audio + disable_prefill (both loaded but prefill off) ======
_run("lora_ref_loaded_prefill_off", TEXT, [REF_AUDIO], is_prefill=False)

print("\n\n>>> ALL TESTS COMPLETE")
# One line per configuration, so the working setups are visible at a glance.
print("Summary:")
for test_name, outcome in results.items():
    if outcome == "ERROR":
        label = "ERROR"
    else:
        label = "AUDIO" if outcome else "NO AUDIO"
    print(f"  {label:<8}  {test_name}")
