"""
CUDA Graph accelerated VibeVoice inference.

Captures three CUDA graphs:
1. LM single-token decode (positive)
2. LM single-token decode (negative/CFG)  
3. Diffusion head full denoising loop (20 steps)

Also captures acoustic decode and semantic encode.
"""

import time
import types
import threading
import subprocess
import torch
import torch.nn as nn

from vibevoice.modular.modeling_vibevoice_inference import (
    VibeVoiceForConditionalGenerationInference,
    VibeVoiceGenerationOutput,
    VibeVoiceTokenConstraintProcessor,
)
from vibevoice.modular.modular_vibevoice_tokenizer import VibeVoiceTokenizerStreamingCache
from vibevoice.modular.streamer import AudioStreamer
from vibevoice.processor.vibevoice_processor import VibeVoiceProcessor
from transformers.generation import LogitsProcessorList


class CUDAGraphDiffusion:
    """CUDA-graph-captured diffusion sampling for fixed batch_size=1 with CFG (effective batch=2).

    One CUDA graph is captured per scheduler timestep, so a full denoising
    pass is ``num_steps`` graph replays plus the (uncaptured) scheduler steps.

    Args:
        prediction_head: Noise-prediction module, called as
            ``prediction_head(speech, t, condition=...)``.
        noise_scheduler: Diffusers-style scheduler exposing ``set_timesteps``
            and ``step``.
        acoustic_vae_dim: Latent dimension of the acoustic VAE.
        num_steps: Number of denoising steps to capture.
        condition_dim: Width of the conditioning vector. Default 1536 matches
            the 1.5B model's hidden size; pass explicitly for other models.
    """

    def __init__(self, prediction_head, noise_scheduler, acoustic_vae_dim,
                 num_steps=10, condition_dim=1536):
        self.prediction_head = prediction_head
        self.noise_scheduler = noise_scheduler
        self.num_steps = num_steps
        self.device = next(prediction_head.parameters()).device
        self.dtype = next(prediction_head.parameters()).dtype

        noise_scheduler.set_timesteps(num_steps)
        self.timesteps = noise_scheduler.timesteps.to(self.device)

        # Static input buffers for batch=2 (positive + negative conditions
        # concatenated). Graph replays read these addresses directly.
        self.static_speech = torch.zeros(2, acoustic_vae_dim, device=self.device, dtype=self.dtype)
        self.static_condition = torch.zeros(2, condition_dim, device=self.device, dtype=self.dtype)
        self.static_t = torch.zeros(2, device=self.device, dtype=self.dtype)
        # Unused; kept for attribute compatibility. Outputs live in static_eps_list.
        self.static_eps = None

        self.graphs = []
        self._capture()

    def _capture(self):
        """Capture one CUDA graph per timestep and keep each graph's output tensor."""
        # Warmup outside capture so lazy initialization (cuDNN plans, workspace
        # allocations) does not occur during graph capture.
        for _ in range(3):
            self.prediction_head(self.static_speech, self.static_t, condition=self.static_condition)

        self.graphs = []
        self.static_eps_list = []

        for t in self.timesteps:
            self.static_t.fill_(t.item())

            # Standard capture recipe: warm up on a side stream, then capture
            # on the current stream.
            g = torch.cuda.CUDAGraph()
            s = torch.cuda.Stream()
            s.wait_stream(torch.cuda.current_stream())
            with torch.cuda.stream(s):
                for _ in range(3):
                    self.prediction_head(self.static_speech, self.static_t, condition=self.static_condition)
            torch.cuda.current_stream().wait_stream(s)

            with torch.cuda.graph(g):
                eps = self.prediction_head(self.static_speech, self.static_t, condition=self.static_condition)
            self.graphs.append(g)
            # `eps` is the graph's static output buffer; replay() rewrites it in place.
            self.static_eps_list.append(eps)

    @torch.no_grad()
    def sample(self, positive_condition, negative_condition, cfg_scale=1.3):
        """Run the captured denoising loop and return one denoised latent.

        Args:
            positive_condition: (1, condition_dim) conditioning for the prompt.
            negative_condition: (1, condition_dim) conditioning for CFG.
            cfg_scale: Classifier-free guidance scale.

        Returns:
            (1, acoustic_vae_dim) denoised latent tensor.
        """
        condition = torch.cat([positive_condition, negative_condition], dim=0)
        self.static_condition.copy_(condition)

        speech = torch.randn(2, self.static_speech.shape[1], device=self.device, dtype=self.dtype)

        # Reset scheduler internal state (step counters) for a fresh trajectory.
        self.noise_scheduler.set_timesteps(self.num_steps)

        for i, t in enumerate(self.timesteps):
            # Both batch halves share the positive latent; CFG mixing is done
            # on the predicted noise below, so they evolve identically.
            half = speech[:1]
            combined = torch.cat([half, half], dim=0)
            self.static_speech.copy_(combined)
            self.static_t.fill_(t.item())

            self.graphs[i].replay()
            eps = self.static_eps_list[i]

            cond_eps, uncond_eps = eps[:1], eps[1:]
            half_eps = uncond_eps + cfg_scale * (cond_eps - uncond_eps)
            full_eps = torch.cat([half_eps, half_eps], dim=0)
            # Scheduler step stays outside the graph (host-side control flow).
            speech = self.noise_scheduler.step(full_eps, t, speech).prev_sample

        return speech[:1]


class CUDAGraphLMDecode:
    """CUDA-graph-captured single-token LM decode over a transformers StaticCache.

    Usage: run prefill normally into ``static_cache``, call
    ``warmup_and_capture`` once, then call ``decode`` once per generated token.
    """

    def __init__(self, model, max_seq_len=2048):
        self.model = model
        self.lm = model.model.language_model
        self.lm_head = model.lm_head
        self.device = next(model.parameters()).device
        self.dtype = next(model.parameters()).dtype
        self.max_seq_len = max_seq_len

        hidden_size = model.config.decoder_config.hidden_size

        # Static input buffers: graph replays read these fixed addresses, so
        # decode() copies fresh values in before each replay.
        self.static_inputs_embeds = torch.zeros(1, 1, hidden_size, device=self.device, dtype=self.dtype)
        self.static_attention_mask = torch.zeros(1, max_seq_len, device=self.device, dtype=torch.long)
        self.static_cache_position = torch.zeros(1, device=self.device, dtype=torch.long)
        self.static_position_ids = torch.zeros(1, 1, device=self.device, dtype=torch.long)

        # Fixed-size KV cache so the cache tensors keep stable addresses
        # across graph replays (a requirement for CUDA graph capture).
        from transformers import StaticCache
        config = model.config.decoder_config
        self.static_cache = StaticCache(
            config=config,
            batch_size=1,
            max_cache_len=max_seq_len,
            device=self.device,
            dtype=self.dtype,
        )

        self.static_logits = None
        self.static_hidden = None
        self.graph = None
        self._captured = False

    def warmup_and_capture(self, seq_len=150):
        """Warm up and capture the single-token decode graph.

        Must be called after prefill so capture sees a realistic cache state.

        Args:
            seq_len: Current sequence length (prompt + generated so far).
        """
        self.static_attention_mask[:, :seq_len] = 1
        self.static_cache_position.fill_(seq_len - 1)
        self.static_position_ids.fill_(seq_len - 1)

        # Warmup initializes lazy state (workspaces, autotuning) outside capture.
        for _ in range(3):
            out = self.lm(
                inputs_embeds=self.static_inputs_embeds,
                attention_mask=self.static_attention_mask,
                position_ids=self.static_position_ids,
                past_key_values=self.static_cache,
                use_cache=True,
                cache_position=self.static_cache_position,
            )
            _ = self.lm_head(out.last_hidden_state[:, -1:, :])

        # Capture
        self.graph = torch.cuda.CUDAGraph()
        with torch.cuda.graph(self.graph):
            out = self.lm(
                inputs_embeds=self.static_inputs_embeds,
                attention_mask=self.static_attention_mask,
                position_ids=self.static_position_ids,
                past_key_values=self.static_cache,
                use_cache=True,
                cache_position=self.static_cache_position,
            )
            # Outputs are static buffers; replay() rewrites them in place.
            self.static_hidden = out.last_hidden_state[:, -1:, :]
            self.static_logits = self.lm_head(self.static_hidden)

        self._captured = True

    def decode(self, inputs_embeds, attention_mask, cache_position):
        """Replay the captured graph for one decode step.

        Args:
            inputs_embeds: (1, 1, hidden_size) embedding of the current token.
            attention_mask: (1, seq_len) mask with seq_len <= max_seq_len.
            cache_position: single-element long tensor with the write position.

        Returns:
            Tuple ``(logits, hidden)`` — clones of the graph's static outputs.

        Raises:
            RuntimeError: If the graph has not been captured yet.
            ValueError: If the attention mask exceeds the captured cache length.
        """
        if not self._captured:
            raise RuntimeError("Call warmup_and_capture first")

        seq_len = attention_mask.shape[1]
        if seq_len > self.max_seq_len:
            raise ValueError(f"seq_len {seq_len} exceeds max_seq_len {self.max_seq_len}")

        self.static_inputs_embeds.copy_(inputs_embeds)
        self.static_attention_mask.zero_()
        self.static_attention_mask[:, :seq_len] = attention_mask
        self.static_cache_position.copy_(cache_position)
        # copy_ instead of fill_(cache_position.item()) avoids a host-device
        # sync on every generated token.
        self.static_position_ids.copy_(cache_position.reshape(1, 1))

        self.graph.replay()

        # Clone so callers keep values after the next replay overwrites the buffers.
        return self.static_logits.clone(), self.static_hidden.clone()

def benchmark_cuda_graphs(model, processor, voice_path, text, ddpm_steps=10, cfg_scale=1.3):
    """Benchmark generation with the diffusion head routed through CUDA graph replays.

    Temporarily monkey-patches ``model.sample_speech_tokens``; the original
    method is restored in a ``finally`` block even if generation raises.

    Args:
        model: Loaded VibeVoice inference model on CUDA.
        processor: Matching VibeVoiceProcessor.
        voice_path: Path to the reference voice sample.
        text: Script text; "Speaker 1:" prefix is added if missing.
        ddpm_steps: Number of diffusion denoising steps.
        cfg_scale: Classifier-free guidance scale.

    Returns:
        Tuple ``(metrics, outputs)`` where metrics has keys
        ``ttfb_ms``, ``gen_s``, ``audio_s``, ``rtf``.
    """
    if not text.startswith("Speaker"):
        text = f"Speaker 1: {text}"

    print("Setting up CUDA graph for diffusion head...")
    cg_diffusion = CUDAGraphDiffusion(
        model.model.prediction_head, model.model.noise_scheduler,
        model.config.acoustic_vae_dim, num_steps=ddpm_steps,
    )

    # Route diffusion sampling through the captured graphs.
    @torch.no_grad()
    def fast_sample(condition, neg_condition, cfg_scale=1.3):
        return cg_diffusion.sample(condition, neg_condition, cfg_scale)

    original_sample = model.sample_speech_tokens
    model.sample_speech_tokens = fast_sample

    try:
        inputs = processor(
            text=[text], voice_samples=[[voice_path]],
            padding=True, return_tensors="pt", return_attention_mask=True,
        )
        for k, v in inputs.items():
            if torch.is_tensor(v):
                inputs[k] = v.to("cuda")

        model.set_ddpm_inference_steps(num_steps=ddpm_steps)

        streamer = AudioStreamer(batch_size=1, stop_signal=None)
        first_chunk = [None]  # wall-clock time of first received audio chunk
        total_samp = [0]      # total audio samples received
        start = [None]

        def consumer():
            # Drain the stream so generation is never blocked on the consumer.
            for chunk in streamer.get_stream(0):
                t = time.perf_counter()
                if first_chunk[0] is None:
                    first_chunk[0] = t
                total_samp[0] += chunk.shape[-1]

        th = threading.Thread(target=consumer, daemon=True)
        th.start()

        torch.manual_seed(42)
        torch.cuda.manual_seed_all(42)

        torch.cuda.synchronize()
        start[0] = time.perf_counter()
        outputs = model.generate(
            **inputs, max_new_tokens=None, cfg_scale=cfg_scale,
            tokenizer=processor.tokenizer, generation_config={"do_sample": False},
            verbose=False, is_prefill=True, audio_streamer=streamer, show_progress_bar=False,
        )
        torch.cuda.synchronize()
        end = time.perf_counter()
        th.join(timeout=5)
    finally:
        # Always undo the monkey-patch, even on failure.
        model.sample_speech_tokens = original_sample

    gen = end - start[0]
    # Explicit None check: truthiness would misread a 0.0 timestamp as "no chunk".
    ttfb = (first_chunk[0] - start[0]) * 1000 if first_chunk[0] is not None else -1
    dur = total_samp[0] / 24000.0  # streamer emits 24 kHz audio
    rtf = gen / dur if dur > 0 else float("inf")

    return {"ttfb_ms": ttfb, "gen_s": gen, "audio_s": dur, "rtf": rtf}, outputs


def get_gpu_util():
    """Return ``(gpu_utilization_percent, power_draw_watts)`` via nvidia-smi.

    Queries the first GPU reported by ``nvidia-smi``. Raises (FileNotFoundError,
    ValueError, or IndexError) if the tool is missing or output is unparseable.
    """
    r = subprocess.run(
        ['nvidia-smi', '--query-gpu=utilization.gpu,power.draw', '--format=csv,noheader,nounits'],
        capture_output=True, text=True)
    # Split on ',' and strip each field rather than relying on the exact
    # ', ' separator, so minor output-format drift still parses.
    parts = [p.strip() for p in r.stdout.strip().split(',')]
    return float(parts[0]), float(parts[1])


def main():
    """Run the benchmark ladder: baseline, graphed diffusion, graphed diffusion + compiled LM."""
    voice_path = "demo/voices/modi.wav"
    model_path = "microsoft/VibeVoice-1.5B"
    text = "Speaker 1: मेरे प्यारे देशवासियों, आज मैं आपके साथ कुछ बहुत ज़रूरी बातें करना चाहता हूँ. हमारा देश एक नये दौर में प्रवेश कर रहा है."

    processor = VibeVoiceProcessor.from_pretrained(model_path)
    # Prefer flash-attention; fall back to SDPA when it is unavailable.
    # Narrow except (not bare) so Ctrl-C / SystemExit still propagate.
    try:
        model = VibeVoiceForConditionalGenerationInference.from_pretrained(
            model_path, torch_dtype=torch.bfloat16, device_map="cuda",
            attn_implementation="flash_attention_2")
    except Exception:
        model = VibeVoiceForConditionalGenerationInference.from_pretrained(
            model_path, torch_dtype=torch.bfloat16, device_map="cuda",
            attn_implementation="sdpa")
    model.eval()

    # Warmup baseline
    print("Warming up baseline...")
    model.set_ddpm_inference_steps(10)
    inp = processor(text=["Speaker 1: test."], voice_samples=[[voice_path]],
                    padding=True, return_tensors="pt", return_attention_mask=True)
    for k, v in inp.items():
        if torch.is_tensor(v): inp[k] = v.to("cuda")
    _ = model.generate(**inp, max_new_tokens=None, cfg_scale=1.3,
        tokenizer=processor.tokenizer, generation_config={"do_sample": False},
        verbose=False, is_prefill=True, show_progress_bar=False)

    # Baseline (no CUDA graphs, no compile)
    print("\n[1] Baseline: no graphs, no compile, 10 steps, cfg=1.3")
    streamer = AudioStreamer(batch_size=1, stop_signal=None)
    fc = [None]; ts = [0]; st = [None]
    def c():
        # Drain the stream; record first-chunk time and total sample count.
        for ch in streamer.get_stream(0):
            t = time.perf_counter()
            if fc[0] is None: fc[0] = t
            ts[0] += ch.shape[-1]
    th = threading.Thread(target=c, daemon=True); th.start()
    inp2 = processor(text=[text], voice_samples=[[voice_path]], padding=True, return_tensors="pt", return_attention_mask=True)
    for k,v in inp2.items():
        if torch.is_tensor(v): inp2[k]=v.to("cuda")
    torch.manual_seed(42); torch.cuda.manual_seed_all(42)
    torch.cuda.synchronize(); st[0] = time.perf_counter()
    out = model.generate(**inp2, max_new_tokens=None, cfg_scale=1.3, tokenizer=processor.tokenizer,
        generation_config={"do_sample": False}, verbose=False, is_prefill=True, audio_streamer=streamer, show_progress_bar=False)
    torch.cuda.synchronize(); end = time.perf_counter(); th.join(timeout=5)
    gen = end-st[0]; ttfb = (fc[0]-st[0])*1000 if fc[0] is not None else -1; dur = ts[0]/24000.0
    print(f"    RTF={gen/dur:.3f}x  TTFB={ttfb:.0f}ms  Audio={dur:.2f}s")

    # CUDA graph diffusion
    print("\n[2] CUDA graph diffusion head, 10 steps, cfg=1.3")
    r, out = benchmark_cuda_graphs(model, processor, voice_path, text, ddpm_steps=10, cfg_scale=1.3)
    print(f"    RTF={r['rtf']:.3f}x  TTFB={r['ttfb_ms']:.0f}ms  Audio={r['audio_s']:.2f}s")

    # CUDA graph diffusion + torch.compile LM
    print("\n[3] CUDA graph diffusion + torch.compile(LM), 10 steps, cfg=1.3")
    model.model.language_model = torch.compile(model.model.language_model, mode="default")
    # Two warmup generations so both prefill and decode paths are compiled.
    inp3 = processor(text=["Speaker 1: warmup."], voice_samples=[[voice_path]], padding=True, return_tensors="pt", return_attention_mask=True)
    for k,v in inp3.items():
        if torch.is_tensor(v): inp3[k]=v.to("cuda")
    _ = model.generate(**inp3, max_new_tokens=None, cfg_scale=1.3, tokenizer=processor.tokenizer,
        generation_config={"do_sample": False}, verbose=False, is_prefill=True, show_progress_bar=False)
    _ = model.generate(**inp3, max_new_tokens=None, cfg_scale=1.3, tokenizer=processor.tokenizer,
        generation_config={"do_sample": False}, verbose=False, is_prefill=True, show_progress_bar=False)

    r, out = benchmark_cuda_graphs(model, processor, voice_path, text, ddpm_steps=10, cfg_scale=1.3)
    print(f"    RTF={r['rtf']:.3f}x  TTFB={r['ttfb_ms']:.0f}ms  Audio={r['audio_s']:.2f}s")

    # Save sample
    import os
    os.makedirs("samples_cuda_graph", exist_ok=True)
    if out.speech_outputs[0] is not None:
        processor.save_audio(out.speech_outputs[0], output_path="samples_cuda_graph/modi_cg_compiled.wav")
        print(f"    Saved: samples_cuda_graph/modi_cg_compiled.wav")

    # 20 steps
    print("\n[4] CUDA graph diffusion + torch.compile(LM), 20 steps, cfg=1.3")
    r, out = benchmark_cuda_graphs(model, processor, voice_path, text, ddpm_steps=20, cfg_scale=1.3)
    print(f"    RTF={r['rtf']:.3f}x  TTFB={r['ttfb_ms']:.0f}ms  Audio={r['audio_s']:.2f}s")
    if out.speech_outputs[0] is not None:
        processor.save_audio(out.speech_outputs[0], output_path="samples_cuda_graph/modi_cg_compiled_20step.wav")

    print("\nDone.")


if __name__ == "__main__":
    main()
