"""
Benchmark VibeVoice inference speed (real-time factor) for Hindi TTS.
Loads VibeVoiceForConditionalGenerationInference and applies LoRA adapter weights via load_lora_assets.
"""

import os
import time
import torch
import numpy as np
import soundfile as sf

# Expose only GPU 0 to this process. Kept above the vibevoice imports so it is
# set before any CUDA initialization those imports might trigger.
os.environ["CUDA_VISIBLE_DEVICES"] = "0"

from vibevoice.modular.modeling_vibevoice_inference import VibeVoiceForConditionalGenerationInference
from vibevoice.modular.lora_loading import load_lora_assets
from vibevoice.processor.vibevoice_processor import VibeVoiceProcessor

# Base checkpoint passed to from_pretrained() for both the processor and the model.
MODEL_PATH = "tarun7r/vibevoice-hindi-1.5B"
# LoRA adapter directory applied on top of MODEL_PATH by load_model() when it exists.
LORA_PATH = "output/hindi_lora_v2_from_tarun/checkpoint-12500/lora"
# Device that processor output tensors are moved to before generation.
DEVICE = "cuda"


def load_model(lora_path=None):
    """Load the VibeVoice processor and inference model, optionally applying a LoRA.

    Args:
        lora_path: Directory with LoRA assets to apply on top of ``MODEL_PATH``.
            ``None`` or empty means "base model only". A path that does not
            exist is reported with a warning rather than silently ignored, so a
            benchmark labelled "with LoRA" cannot quietly measure the base model.

    Returns:
        ``(model, processor)``, with the model switched to eval mode.
    """
    print("Loading processor...")
    processor = VibeVoiceProcessor.from_pretrained(MODEL_PATH)

    print("Loading model...")
    model = VibeVoiceForConditionalGenerationInference.from_pretrained(
        MODEL_PATH,
        torch_dtype=torch.bfloat16,
        # Same device generate() moves the inputs to, so the two cannot drift apart.
        device_map=DEVICE,
        attn_implementation="sdpa",
    )

    if lora_path:
        if os.path.exists(lora_path):
            print(f"Loading LoRA from {lora_path}...")
            report = load_lora_assets(model, lora_path)
            print(f"  LoRA loaded: {report}")
        else:
            print(f"WARNING: LoRA path not found, using base model only: {lora_path}")

    model.eval()
    return model, processor


def generate(model, processor, text, speaker_wav=None, cfg_scale=1.3):
    """Synthesize ``text`` as a single-speaker script and return the audio.

    Args:
        model: VibeVoice inference model (see ``load_model``).
        processor: Matching VibeVoice processor.
        text: Text to speak; it is wrapped as ``"Speaker 1: <text>"``.
        speaker_wav: Optional path to a reference voice file used for voice
            cloning. If it is falsy or the file does not exist, generation runs
            without a reference voice. A missing file triggers a warning.
        cfg_scale: Classifier-free guidance scale forwarded to ``model.generate``.

    Returns:
        A float32 numpy array with singleton dimensions squeezed out, or
        ``None`` if the model produced no speech output.
    """
    script = f"Speaker 1: {text}"

    # Decide once whether a usable reference voice exists, so the processor
    # inputs and the ``is_prefill`` flag always agree. Keying ``is_prefill`` on
    # ``speaker_wav is not None`` requested prefill even when the voice sample
    # had been dropped (missing file, or an empty-string path).
    use_voice = bool(speaker_wav) and os.path.exists(speaker_wav)
    if speaker_wav and not use_voice:
        print(f"WARNING: speaker wav not found, generating without voice cloning: {speaker_wav}")
    kwargs = {"voice_samples": [speaker_wav]} if use_voice else {}

    inputs = processor(
        text=[script], padding=False, truncation=False,
        return_tensors="pt", **kwargs,
    )
    # Only tensor entries are moved. Any non-tensor entries are passed through as-is.
    for key, value in inputs.items():
        if torch.is_tensor(value):
            inputs[key] = value.to(DEVICE)

    with torch.inference_mode():
        outputs = model.generate(
            **inputs,
            max_new_tokens=None,
            cfg_scale=cfg_scale,
            tokenizer=processor.tokenizer,
            generation_config={"do_sample": False},
            verbose=False,
            is_prefill=use_voice,
        )

    speech = outputs.speech_outputs
    if not speech or speech[0] is None:
        return None
    audio = speech[0].float().cpu().numpy()
    if audio.ndim > 1:
        # Drop size-1 dimensions (e.g. a leading batch/channel axis).
        audio = audio.squeeze()
    return audio


def benchmark(model, processor, text, speaker_wav=None, runs=3, cfg_scale=1.3, label="",
              sample_rate=24000):
    """Time repeated generations of ``text`` and report the real-time factor.

    RTF is generation wall time divided by generated audio duration. Values
    below 1.0 mean faster than real time. One short warm-up generation runs
    first and is not timed.

    Args:
        model, processor: As returned by ``load_model``.
        text: Text to synthesize on every timed run.
        speaker_wav: Optional reference voice path forwarded to ``generate``.
        runs: Number of timed generations.
        cfg_scale: Guidance scale forwarded to ``generate``.
        label: Heading printed above the results.
        sample_rate: Samples per second used to convert audio length to seconds.

    Returns:
        ``(avg_time, avg_dur, avg_rtf)`` over the runs that produced audio, or
        ``(0, 0, inf)`` if none did.
    """
    banner = "=" * 60
    print(f"\n{banner}")
    print(f"BENCHMARK: {label}")
    print(banner)

    # Warm up kernels/caches so the first timed run is not penalized.
    print("  Warming up...")
    generate(model, processor, "टेस्ट", speaker_wav, cfg_scale)
    torch.cuda.synchronize()

    times = []
    durations = []
    for run in range(1, runs + 1):
        # Synchronize around the timed region so queued GPU work is fully counted.
        torch.cuda.synchronize()
        t0 = time.perf_counter()  # monotonic, high-resolution; time.time() is not
        audio = generate(model, processor, text, speaker_wav, cfg_scale)
        torch.cuda.synchronize()
        elapsed = time.perf_counter() - t0

        if audio is None or len(audio) == 0:
            # Excluded from the averages; an empty clip would also divide by zero.
            print(f"  Run {run}: {elapsed:.2f}s gen → no audio produced (excluded)")
            continue

        dur = len(audio) / sample_rate
        times.append(elapsed)
        durations.append(dur)
        print(f"  Run {run}: {elapsed:.2f}s gen → {dur:.1f}s audio (RTF={elapsed/dur:.2f}x)")

    if not times:
        return 0, 0, float('inf')

    avg_time = np.mean(times)
    avg_dur = np.mean(durations)
    # Ratio of means (total gen time / total audio), not a mean of per-run ratios.
    avg_rtf = avg_time / avg_dur
    print(f"  → AVG: {avg_time:.2f}s gen, {avg_dur:.1f}s audio, RTF={avg_rtf:.2f}x")
    return avg_time, avg_dur, avg_rtf


if __name__ == "__main__":
    # Hindi test prompts of increasing length plus a reference voice for cloning.
    text_short = "नमस्ते, कैसे हैं आप?"
    text_medium = "शिक्षा हमारे जीवन का सबसे महत्वपूर्ण हिस्सा है। हमें अपने सभी कार्य सही समय पर पूरे करने की आदत डालनी चाहिए।"
    text_long = "समय दुनिया की सबसे मूल्यवान चीज़ है जिसे कभी वापस नहीं पाया जा सकता। जो व्यक्ति समय का सम्मान करता है, वह जीवन में हमेशा सफलता प्राप्त करता है।"
    speaker_wav = "demo/voices/hi-Priya_woman.wav"

    print("Loading model with LoRA...")
    model, processor = load_model(LORA_PATH)

    n_params = sum(param.numel() for param in model.parameters())
    print(f"Model: {n_params/1e9:.2f}B params")
    allocated_gb = torch.cuda.memory_allocated() / 1e9
    print(f"GPU memory: {allocated_gb:.1f}GB allocated")

    # (label, text, reference voice); executed in this order, 3 timed runs each.
    benchmark_cases = (
        ("Short text, NO voice cloning", text_short, None),
        ("Short text, Priya voice cloning", text_short, speaker_wav),
        ("Medium text, Priya voice cloning", text_medium, speaker_wav),
        ("Long text, Priya voice cloning", text_long, speaker_wav),
    )
    for case_label, case_text, case_voice in benchmark_cases:
        benchmark(model, processor, case_text, case_voice, runs=3, label=case_label)

    # Write one example clip per configuration for listening checks.
    banner = "=" * 60
    print("\n" + banner)
    print("SAVING SAMPLES")
    print(banner)
    samples = {
        "bench_short_priya": (text_short, speaker_wav),
        "bench_medium_priya": (text_medium, speaker_wav),
        "bench_long_priya": (text_long, speaker_wav),
        "bench_medium_noclone": (text_medium, None),
    }
    for stem, (sample_text, sample_voice) in samples.items():
        audio = generate(model, processor, sample_text, sample_voice)
        if audio is None:
            continue
        out_name = f"{stem}.wav"
        sf.write(out_name, audio, 24000)
        print(f"  {out_name}: {len(audio)/24000:.1f}s")

    print("\nDone!")
