"""
CUDA graph accelerated VibeVoice diffusion head.
Captures prediction_head as a CUDA graph, replays for each DDPM step.
"""
import torch
import time
import threading
import subprocess

from vibevoice.modular.modeling_vibevoice_inference import VibeVoiceForConditionalGenerationInference
from vibevoice.modular.streamer import AudioStreamer
from vibevoice.processor.vibevoice_processor import VibeVoiceProcessor


def setup_cuda_graph_diffusion(model, batch_size=1, warmup_iters=5):
    """Capture ``model.model.prediction_head`` as a CUDA graph for a fixed batch.

    The graph is captured for ``2 * batch_size`` rows because classifier-free
    guidance evaluates the positive and negative conditions in one call. The
    inputs are zero-filled placeholders: callers copy real data into the
    returned static buffers and then call ``graph.replay()``.

    Args:
        model: VibeVoice inference model whose prediction head lives on a CUDA device.
        batch_size: sequences generated per step, before the CFG doubling.
        warmup_iters: eager forward passes run before capture, so one-time lazy
            initialisation happens outside the captured graph.

    Returns:
        ``(graph, static_x, static_t, static_cond, static_out)``: the captured
        graph, its input buffers (noisy latent, timestep, condition) and the
        output buffer that every replay overwrites in place.
    """
    ph = model.model.prediction_head
    first_param = next(ph.parameters())
    device, dtype = first_param.device, first_param.dtype
    cfg_batch = batch_size * 2  # positive + negative rows for CFG

    static_x = torch.zeros(cfg_batch, model.config.acoustic_vae_dim, device=device, dtype=dtype)
    static_t = torch.zeros(cfg_batch, device=device, dtype=dtype)
    static_cond = torch.zeros(cfg_batch, model.config.decoder_config.hidden_size, device=device, dtype=dtype)

    # Capture on the head's own device, and without autograd: the graph is only
    # replayed under no_grad, so recording backward state would just hold memory
    # in the graph pool and leave static_out requiring grad.
    with torch.cuda.device(device), torch.no_grad():
        # PyTorch's CUDA-graph guide requires warmup on a side stream, not on
        # the stream that will later be captured.
        side_stream = torch.cuda.Stream()
        side_stream.wait_stream(torch.cuda.current_stream())
        with torch.cuda.stream(side_stream):
            for _ in range(warmup_iters):
                ph(static_x, static_t, condition=static_cond)
        torch.cuda.current_stream().wait_stream(side_stream)
        torch.cuda.synchronize()

        graph = torch.cuda.CUDAGraph()
        with torch.cuda.graph(graph):
            static_out = ph(static_x, static_t, condition=static_cond)
        torch.cuda.synchronize()

    return graph, static_x, static_t, static_cond, static_out


def patched_sample_speech_tokens(model, graph, static_x, static_t, static_cond, static_out):
    """Replace ``model.sample_speech_tokens`` with a version that replays ``graph``.

    The replacement runs the DDPM + classifier-free-guidance loop in Python.
    For each step it copies the inputs into the graph's static buffers,
    replays the graph, and reads the prediction back from ``static_out``,
    instead of calling the prediction head eagerly.

    The graph was captured for one fixed batch. A call whose combined
    (positive + negative) condition does not match ``static_cond``'s shape is
    handed to the original eager implementation rather than failing in ``copy_``.

    Args:
        model: model to patch in place.
        graph, static_x, static_t, static_cond, static_out: as returned by
            ``setup_cuda_graph_diffusion``.

    Returns:
        The original ``sample_speech_tokens`` callable, so callers can restore it.
    """
    original_sample = model.sample_speech_tokens

    @torch.no_grad()
    def fast_sample(condition, neg_condition, cfg_scale=3.0):
        combined_cond = torch.cat([condition, neg_condition], dim=0)
        if combined_cond.shape != static_cond.shape:
            # Batch differs from the captured one: use eager sampling.
            return original_sample(condition, neg_condition, cfg_scale)

        scheduler = model.model.noise_scheduler
        scheduler.set_timesteps(model.ddpm_inference_steps)
        static_cond.copy_(combined_cond)

        speech = torch.randn(combined_cond.shape[0], model.config.acoustic_vae_dim,
                             device=combined_cond.device, dtype=combined_cond.dtype)
        n_pos = len(speech) // 2  # first half = positive rows, second half = negative rows

        for t in scheduler.timesteps:
            # Both CFG halves are denoised from the same latent.
            half = speech[:n_pos]
            static_x.copy_(torch.cat([half, half], dim=0))
            static_t.fill_(t.item())

            graph.replay()
            # Clone: the next replay overwrites static_out in place.
            eps = static_out.clone()

            cond_eps, uncond_eps = torch.split(eps, len(eps) // 2, dim=0)
            half_eps = uncond_eps + cfg_scale * (cond_eps - uncond_eps)
            full_eps = torch.cat([half_eps, half_eps], dim=0)
            speech = scheduler.step(full_eps, t, speech).prev_sample

        return speech[:n_pos]

    model.sample_speech_tokens = fast_sample
    return original_sample


def run_benchmark(model, processor, voice_path, text, ddpm_steps=10, label="", sample_rate=24000):
    """Run one streamed generation and print real-time factor, TTFB, audio length and GPU load.

    A consumer thread drains the ``AudioStreamer`` to time the first audio chunk
    and count streamed samples. A second thread polls ``nvidia-smi`` for GPU
    utilisation while ``model.generate`` runs.

    Args:
        model: VibeVoice inference model, already on CUDA.
        processor: ``VibeVoiceProcessor`` used to build the model inputs.
        voice_path: path of the reference voice sample.
        text: script to synthesise.
        ddpm_steps: diffusion steps, passed to ``model.set_ddpm_inference_steps``.
        label: prefix of the printed result line.
        sample_rate: samples per second used to turn the streamed sample count
            into seconds (default 24000, the value previously hard-coded).

    Returns:
        The object returned by ``model.generate``.

    Printed metrics:
        RTF: wall time divided by audio duration (999 if no audio arrived).
        TTFB: milliseconds from generate start to the first chunk (-1 if none arrived).
        GPU: mean of the polled utilisation samples (0 if none were collected).
    """
    inputs = processor(text=[text], voice_samples=[[voice_path]],
                       padding=True, return_tensors="pt", return_attention_mask=True)
    for key, value in inputs.items():
        if torch.is_tensor(value):
            inputs[key] = value.to("cuda")

    streamer = AudioStreamer(batch_size=1, stop_signal=None)
    first_chunk_at = [None]  # perf_counter() when the first chunk arrived
    total_samples = [0]      # sum of chunk.shape[-1] over stream 0
    gpu_samples = []
    stop_polling = threading.Event()

    def consume():
        for chunk in streamer.get_stream(0):
            if first_chunk_at[0] is None:
                first_chunk_at[0] = time.perf_counter()
            total_samples[0] += chunk.shape[-1]

    def poll_gpu():
        # Best effort: a missing nvidia-smi or an unparsable reading only leaves
        # gpu_samples short; it never crashes the thread or hides other errors.
        cmd = ['nvidia-smi', '--query-gpu=utilization.gpu', '--format=csv,noheader,nounits']
        while not stop_polling.is_set():
            try:
                result = subprocess.run(cmd, capture_output=True, text=True)
            except OSError:
                return  # nvidia-smi is not installed / not executable
            lines = result.stdout.strip().splitlines()
            if lines:
                try:
                    # One line per GPU; assumes the model runs on the first one
                    # listed. TODO confirm on multi-GPU hosts.
                    gpu_samples.append(float(lines[0]))
                except ValueError:
                    pass  # non-numeric reading; skip this sample
            stop_polling.wait(0.15)

    consumer = threading.Thread(target=consume, daemon=True)
    poller = threading.Thread(target=poll_gpu, daemon=True)
    consumer.start()
    poller.start()

    torch.manual_seed(42)
    torch.cuda.manual_seed_all(42)
    model.set_ddpm_inference_steps(num_steps=ddpm_steps)
    try:
        torch.cuda.synchronize()
        start = time.perf_counter()
        out = model.generate(**inputs, max_new_tokens=None, cfg_scale=1.3, tokenizer=processor.tokenizer,
                             generation_config={"do_sample": False}, verbose=False, is_prefill=True,
                             audio_streamer=streamer, show_progress_bar=False)
        torch.cuda.synchronize()
        end = time.perf_counter()
    finally:
        # Stop the poller even if generate() raises, so it cannot spin forever.
        stop_polling.set()
        poller.join(timeout=5)
    # Presumably generate() closes the stream; give the consumer time to drain it.
    consumer.join(timeout=5)

    gen_seconds = end - start
    ttfb = (first_chunk_at[0] - start) * 1000 if first_chunk_at[0] is not None else -1
    dur = total_samples[0] / float(sample_rate)
    rtf = gen_seconds / dur if dur > 0 else 999
    avg_gpu = sum(gpu_samples) / len(gpu_samples) if gpu_samples else 0
    print(f"{label:<45} RTF={rtf:.3f}x TTFB={ttfb:.0f}ms Audio={dur:.2f}s GPU={avg_gpu:.0f}%", flush=True)
    return out


def main():
    """Benchmark baseline, CUDA-graph diffusion, and CUDA graph + compiled LM, then save a sample.

    Loads ``microsoft/VibeVoice-1.5B`` in bfloat16 on CUDA. It prefers
    FlashAttention-2 and falls back to SDPA if that load fails.
    """
    import os

    model_id = "microsoft/VibeVoice-1.5B"
    voice_path = "demo/voices/modi.wav"
    text = "Speaker 1: मेरे प्यारे देशवासियों, आज मैं आपके साथ कुछ बहुत ज़रूरी बातें करना चाहता हूँ. हमारा देश एक नये दौर में प्रवेश कर रहा है."
    sample_dir = "samples_cuda_graph_v2"
    sample_path = "samples_cuda_graph_v2/modi_best.wav"

    processor = VibeVoiceProcessor.from_pretrained(model_id)
    try:
        model = VibeVoiceForConditionalGenerationInference.from_pretrained(
            model_id, torch_dtype=torch.bfloat16, device_map="cuda",
            attn_implementation="flash_attention_2")
    except Exception as err:
        # Best-effort fallback (e.g. flash-attn not installed). Catching Exception
        # instead of using a bare except keeps Ctrl-C / SystemExit working.
        print(f"flash_attention_2 unavailable ({err!r}); falling back to sdpa", flush=True)
        model = VibeVoiceForConditionalGenerationInference.from_pretrained(
            model_id, torch_dtype=torch.bfloat16, device_map="cuda",
            attn_implementation="sdpa")
    model.eval()

    # Warmup baseline
    run_benchmark(model, processor, voice_path, "Speaker 1: warmup.", ddpm_steps=10, label="warmup")

    # 1. Baseline
    run_benchmark(model, processor, voice_path, text, ddpm_steps=10, label="[1] Baseline (10 steps)")
    run_benchmark(model, processor, voice_path, text, ddpm_steps=20, label="[1] Baseline (20 steps)")

    # 2. CUDA graph diffusion (batch_size=1 -> graph captured for 2 CFG rows)
    print("\nSetting up CUDA graph for diffusion head...", flush=True)
    graph, sx, st_buf, sc, so = setup_cuda_graph_diffusion(model, batch_size=1)
    patched_sample_speech_tokens(model, graph, sx, st_buf, sc, so)
    print("CUDA graph ready.", flush=True)

    run_benchmark(model, processor, voice_path, text, ddpm_steps=10, label="[2] CUDA graph diffusion (10 steps)")
    run_benchmark(model, processor, voice_path, text, ddpm_steps=20, label="[2] CUDA graph diffusion (20 steps)")

    # 3. + torch.compile LM; two throwaway runs absorb compilation time before measuring.
    print("\nCompiling LM backbone...", flush=True)
    model.model.language_model = torch.compile(model.model.language_model, mode="default")
    run_benchmark(model, processor, voice_path, "Speaker 1: compile warmup.", ddpm_steps=10, label="compile warmup 1")
    run_benchmark(model, processor, voice_path, "Speaker 1: compile warmup.", ddpm_steps=10, label="compile warmup 2")

    run_benchmark(model, processor, voice_path, text, ddpm_steps=10, label="[3] CUDA graph diff + compile LM (10 steps)")
    run_benchmark(model, processor, voice_path, text, ddpm_steps=20, label="[3] CUDA graph diff + compile LM (20 steps)")

    # Save sample
    os.makedirs(sample_dir, exist_ok=True)
    out = run_benchmark(model, processor, voice_path, text, ddpm_steps=20, label="[FINAL] Best config 20 steps")
    if out.speech_outputs[0] is not None:
        processor.save_audio(out.speech_outputs[0], output_path=sample_path)
        print(f"Saved: {sample_path}", flush=True)


if __name__ == "__main__":
    # Script entry point. main() returns None, so the process exits with status 0.
    raise SystemExit(main())
