"""
Test real RTF optimizations based on profiling data.
1. cfg_scale=1.0 (skip negative LM forward - saves 29ms/token)
2. torch.compile on LM backbone
3. torch.compile on diffusion head
4. Combined
"""

import gc
import time
import threading
import torch
from vibevoice.modular.modeling_vibevoice_inference import VibeVoiceForConditionalGenerationInference
from vibevoice.modular.streamer import AudioStreamer
from vibevoice.processor.vibevoice_processor import VibeVoiceProcessor


def measure(model, processor, voice_path, text, ddpm_steps=20, cfg_scale=1.3):
    """Run one streamed generation and return timing stats.

    Args:
        model: VibeVoice inference model (already on CUDA, eval mode).
        processor: VibeVoiceProcessor used to build model inputs.
        voice_path: path to the reference voice sample.
        text: script to synthesize; "Speaker N:" prefix added if missing.
        ddpm_steps: diffusion inference steps for the acoustic head.
        cfg_scale: classifier-free-guidance scale (1.0 skips the negative pass).

    Returns:
        dict with:
            ttfb_ms: time to first audio chunk in ms (-1 if none arrived),
            gen_s: wall-clock generation time in seconds,
            audio_s: duration of produced audio (assumes 24 kHz output),
            rtf: real-time factor gen_s / audio_s (inf if no audio).
    """
    if not text.startswith("Speaker"):
        text = f"Speaker 1: {text}"
    inputs = processor(
        text=[text], voice_samples=[[voice_path]],
        padding=True, return_tensors="pt", return_attention_mask=True,
    )
    for k, v in inputs.items():
        if torch.is_tensor(v):
            inputs[k] = v.to("cuda")

    streamer = AudioStreamer(batch_size=1, stop_signal=None)
    # One-element lists act as mutable cells shared with the consumer thread.
    first_chunk = [None]
    total_samples = [0]
    start = [None]

    def consumer():
        # Drain audio chunks as they stream out; record first-chunk arrival.
        for chunk in streamer.get_stream(0):
            t = time.perf_counter()
            if first_chunk[0] is None:
                first_chunk[0] = t
            total_samples[0] += chunk.shape[-1]

    t = threading.Thread(target=consumer, daemon=True)
    t.start()

    # Fixed seeds so every config synthesizes comparable audio.
    torch.manual_seed(42)
    torch.cuda.manual_seed_all(42)
    model.set_ddpm_inference_steps(num_steps=ddpm_steps)

    # Synchronize so the timer does not include queued prior GPU work.
    torch.cuda.synchronize()
    start[0] = time.perf_counter()
    model.generate(
        **inputs, max_new_tokens=None, cfg_scale=cfg_scale,
        tokenizer=processor.tokenizer, generation_config={"do_sample": False},
        verbose=False, is_prefill=True, audio_streamer=streamer, show_progress_bar=False,
    )
    torch.cuda.synchronize()
    end = time.perf_counter()
    # Bounded join: don't hang the benchmark if the stream never closes.
    t.join(timeout=5)

    gen = end - start[0]
    # Compare against None, not truthiness: first_chunk[0] is a float timestamp.
    ttfb = (first_chunk[0] - start[0]) * 1000 if first_chunk[0] is not None else -1
    dur = total_samples[0] / 24000.0  # output sample rate assumed 24 kHz
    rtf = gen / dur if dur > 0 else float("inf")
    return {"ttfb_ms": ttfb, "gen_s": gen, "audio_s": dur, "rtf": rtf}


def main():
    """Benchmark a matrix of RTF optimizations and print a summary table.

    Progression: baseline -> cfg=1.0 (skips the negative CFG forward) ->
    fewer DDPM steps -> torch.compile on the LM backbone -> torch.compile
    on the diffusion head. Each config reports RTF, TTFB and audio length.
    """
    voice_path = "demo/voices/modi.wav"
    model_path = "microsoft/VibeVoice-1.5B"
    text = "Speaker 1: मेरे प्यारे देशवासियों, आज मैं आपके साथ कुछ बहुत ज़रूरी बातें करना चाहता हूँ. हमारा देश एक नये दौर में प्रवेश कर रहा है."

    processor = VibeVoiceProcessor.from_pretrained(model_path)
    # Prefer flash-attention-2; fall back to SDPA if it is unavailable.
    # Catch Exception, not everything: a bare except would also swallow
    # KeyboardInterrupt/SystemExit.
    try:
        model = VibeVoiceForConditionalGenerationInference.from_pretrained(
            model_path, torch_dtype=torch.bfloat16, device_map="cuda",
            attn_implementation="flash_attention_2")
    except Exception:
        model = VibeVoiceForConditionalGenerationInference.from_pretrained(
            model_path, torch_dtype=torch.bfloat16, device_map="cuda",
            attn_implementation="sdpa")
    model.eval()

    # Warmup so CUDA init / kernel autotuning doesn't pollute the first run.
    print("Warming up baseline...")
    _ = measure(model, processor, voice_path, "Speaker 1: test.", ddpm_steps=20, cfg_scale=1.3)

    configs = []

    # ---- Test 1: Baseline (cfg=1.3, 20 steps) ----
    print("\n[1] Baseline: cfg=1.3, 20 steps")
    r = measure(model, processor, voice_path, text, ddpm_steps=20, cfg_scale=1.3)
    configs.append(("Baseline (cfg=1.3, 20 steps)", r))
    print(f"    RTF={r['rtf']:.3f}x  TTFB={r['ttfb_ms']:.0f}ms  Audio={r['audio_s']:.2f}s")

    # ---- Test 2: cfg=1.0 (should skip or simplify negative) ----
    print("\n[2] cfg_scale=1.0, 20 steps")
    r = measure(model, processor, voice_path, text, ddpm_steps=20, cfg_scale=1.0)
    configs.append(("cfg=1.0, 20 steps", r))
    print(f"    RTF={r['rtf']:.3f}x  TTFB={r['ttfb_ms']:.0f}ms  Audio={r['audio_s']:.2f}s")

    # ---- Test 3: cfg=1.3, 10 steps ----
    print("\n[3] cfg=1.3, 10 steps")
    r = measure(model, processor, voice_path, text, ddpm_steps=10, cfg_scale=1.3)
    configs.append(("cfg=1.3, 10 steps", r))
    print(f"    RTF={r['rtf']:.3f}x  TTFB={r['ttfb_ms']:.0f}ms  Audio={r['audio_s']:.2f}s")

    # ---- Test 4: cfg=1.0, 10 steps ----
    print("\n[4] cfg=1.0, 10 steps")
    r = measure(model, processor, voice_path, text, ddpm_steps=10, cfg_scale=1.0)
    configs.append(("cfg=1.0, 10 steps", r))
    print(f"    RTF={r['rtf']:.3f}x  TTFB={r['ttfb_ms']:.0f}ms  Audio={r['audio_s']:.2f}s")

    # ---- Test 5: torch.compile LM ----
    print("\n[5] torch.compile LM + cfg=1.3, 20 steps")
    model.model.language_model = torch.compile(model.model.language_model, mode="reduce-overhead")
    # Warmup run triggers compilation so it isn't counted in the measurement.
    _ = measure(model, processor, voice_path, "Speaker 1: warmup test.", ddpm_steps=20, cfg_scale=1.3)
    r = measure(model, processor, voice_path, text, ddpm_steps=20, cfg_scale=1.3)
    configs.append(("torch.compile LM, cfg=1.3, 20 steps", r))
    print(f"    RTF={r['rtf']:.3f}x  TTFB={r['ttfb_ms']:.0f}ms  Audio={r['audio_s']:.2f}s")

    # ---- Test 6: torch.compile LM + cfg=1.0 ----
    print("\n[6] torch.compile LM + cfg=1.0, 20 steps")
    r = measure(model, processor, voice_path, text, ddpm_steps=20, cfg_scale=1.0)
    configs.append(("torch.compile LM, cfg=1.0, 20 steps", r))
    print(f"    RTF={r['rtf']:.3f}x  TTFB={r['ttfb_ms']:.0f}ms  Audio={r['audio_s']:.2f}s")

    # ---- Test 7: torch.compile LM + diffusion head + cfg=1.0, 20 steps ----
    print("\n[7] torch.compile LM + diffusion + cfg=1.0, 20 steps")
    model.model.prediction_head = torch.compile(model.model.prediction_head, mode="reduce-overhead")
    # Warmup run triggers diffusion-head compilation.
    _ = measure(model, processor, voice_path, "Speaker 1: warmup.", ddpm_steps=20, cfg_scale=1.0)
    r = measure(model, processor, voice_path, text, ddpm_steps=20, cfg_scale=1.0)
    configs.append(("torch.compile LM+diff, cfg=1.0, 20 steps", r))
    print(f"    RTF={r['rtf']:.3f}x  TTFB={r['ttfb_ms']:.0f}ms  Audio={r['audio_s']:.2f}s")

    # ---- Test 8: torch.compile everything + cfg=1.0, 10 steps ----
    print("\n[8] torch.compile LM + diffusion + cfg=1.0, 10 steps")
    r = measure(model, processor, voice_path, text, ddpm_steps=10, cfg_scale=1.0)
    configs.append(("torch.compile all, cfg=1.0, 10 steps", r))
    print(f"    RTF={r['rtf']:.3f}x  TTFB={r['ttfb_ms']:.0f}ms  Audio={r['audio_s']:.2f}s")

    # Summary table: RTF < 1.0 means faster-than-real-time (streamable).
    print(f"\n{'='*90}")
    print(f"{'Config':<45} {'RTF':>7} {'TTFB':>8} {'Audio':>8} {'Stream?':>8}")
    print(f"{'-'*90}")
    for name, r in configs:
        can = "YES" if r["rtf"] < 1.0 else "NO"
        print(f"{name:<45} {r['rtf']:>6.3f}x {r['ttfb_ms']:>7.0f}ms {r['audio_s']:>7.2f}s {can:>7}")
    print(f"{'='*90}")


# Script entry point: run the full benchmark matrix when executed directly.
if __name__ == "__main__":
    main()
