"""
Profile VibeVoice generation loop to find exact time breakdown per component.
Patches the generate() method to instrument each component.
"""

import time
import torch
import numpy as np
from collections import defaultdict

from vibevoice.modular.modeling_vibevoice_inference import VibeVoiceForConditionalGenerationInference
from vibevoice.processor.vibevoice_processor import VibeVoiceProcessor


def profile_generation(model, processor, voice_path, text, ddpm_steps=20):
    """Run one generation under torch.profiler and print CUDA time tables.

    Args:
        model: VibeVoice inference model, already placed on CUDA.
        processor: matching VibeVoiceProcessor.
        voice_path: path to the reference voice wav file.
        text: script to synthesize; a "Speaker 1:" prefix is added if absent.
        ddpm_steps: number of DDPM denoising steps to configure on the model.

    Returns:
        The generation outputs object returned by ``model.generate``.
    """
    # The processor expects an explicit speaker tag on the script.
    if not text.startswith("Speaker"):
        text = f"Speaker 1: {text}"

    batch = processor(
        text=[text], voice_samples=[[voice_path]],
        padding=True, return_tensors="pt", return_attention_mask=True,
    )
    # Move every tensor entry onto the GPU; leave non-tensor entries untouched.
    batch = {k: v.to("cuda") if torch.is_tensor(v) else v for k, v in batch.items()}

    # Fixed seeds so the profiled run is reproducible.
    torch.manual_seed(42)
    torch.cuda.manual_seed_all(42)
    model.set_ddpm_inference_steps(num_steps=ddpm_steps)

    activities = [
        torch.profiler.ProfilerActivity.CPU,
        torch.profiler.ProfilerActivity.CUDA,
    ]
    # Use torch profiler for detailed GPU timing.
    with torch.profiler.profile(
        activities=activities,
        record_shapes=True,
        with_stack=False,
    ) as prof:
        outputs = model.generate(
            **batch, max_new_tokens=None, cfg_scale=1.3,
            tokenizer=processor.tokenizer, generation_config={"do_sample": False},
            verbose=False, is_prefill=True, show_progress_bar=False,
        )

    averages = prof.key_averages()

    # Print top CUDA operations by aggregate time.
    print("\n=== TOP 30 CUDA OPERATIONS (by total CUDA time) ===")
    print(averages.table(sort_by="cuda_time_total", row_limit=30))

    # Print operations ranked by self (exclusive) CUDA time.
    print("\n=== TOP 20 SELF CUDA TIME ===")
    print(averages.table(sort_by="self_cuda_time_total", row_limit=20))

    return outputs


def manual_profile(model, processor, voice_path, text, ddpm_steps=20):
    """Manual instrumented profiling with CUDA-synchronized timing per component.

    Monkey-patches the LM forward pass, diffusion sampling, acoustic decode and
    semantic encode entry points with timing wrappers, runs one generation, then
    prints a per-component and per-speech-token time breakdown.

    Args:
        model: VibeVoice inference model, already placed on CUDA.
        processor: matching VibeVoiceProcessor.
        voice_path: path to the reference voice wav file.
        text: script to synthesize; a "Speaker 1:" prefix is added if absent.
        ddpm_steps: number of DDPM denoising steps per speech token.
    """
    if not text.startswith("Speaker"):
        text = f"Speaker 1: {text}"

    inputs = processor(
        text=[text], voice_samples=[[voice_path]],
        padding=True, return_tensors="pt", return_attention_mask=True,
    )
    for k, v in inputs.items():
        if torch.is_tensor(v):
            inputs[k] = v.to("cuda")

    # Fixed seeds so the profiled run is reproducible.
    torch.manual_seed(42)
    torch.cuda.manual_seed_all(42)
    model.set_ddpm_inference_steps(num_steps=ddpm_steps)

    timings = defaultdict(list)

    def _timed(fn, key):
        # Wrap `fn` so each call's CUDA-synchronized wall time lands in timings[key].
        def wrapper(*args, **kwargs):
            torch.cuda.synchronize()
            t0 = time.perf_counter()
            result = fn(*args, **kwargs)
            torch.cuda.synchronize()
            timings[key].append(time.perf_counter() - t0)
            return result
        return wrapper

    # Instrument the key methods (instance-level patch; originals kept for restore).
    orig_forward = model.forward
    orig_sample = model.sample_speech_tokens
    orig_acoustic_decode = model.model.acoustic_tokenizer.decode
    orig_semantic_encode = model.model.semantic_tokenizer.encode

    model.forward = _timed(orig_forward, "lm_forward")
    model.sample_speech_tokens = _timed(orig_sample, "diffusion_sample")
    model.model.acoustic_tokenizer.decode = _timed(orig_acoustic_decode, "acoustic_decode")
    model.model.semantic_tokenizer.encode = _timed(orig_semantic_encode, "semantic_encode")

    try:
        torch.cuda.synchronize()
        t_total_start = time.perf_counter()

        outputs = model.generate(
            **inputs, max_new_tokens=None, cfg_scale=1.3,
            tokenizer=processor.tokenizer, generation_config={"do_sample": False},
            verbose=False, is_prefill=True, show_progress_bar=False,
        )

        torch.cuda.synchronize()
        t_total = time.perf_counter() - t_total_start
    finally:
        # Always restore the original methods, even if generate() raises,
        # so the model is left un-instrumented for subsequent callers.
        model.forward = orig_forward
        model.sample_speech_tokens = orig_sample
        model.model.acoustic_tokenizer.decode = orig_acoustic_decode
        model.model.semantic_tokenizer.encode = orig_semantic_encode

    # 24 kHz output sample rate — TODO confirm against the model config.
    audio_dur = outputs.speech_outputs[0].shape[-1] / 24000.0 if outputs.speech_outputs[0] is not None else 0

    print(f"\n{'='*80}")
    # Fix: header previously hard-coded "20 DDPM steps" regardless of ddpm_steps.
    print(f"COMPONENT PROFILING ({ddpm_steps} DDPM steps, batch=1)")
    print(f"{'='*80}")
    print(f"Total generation time: {t_total:.3f}s")
    print(f"Audio duration: {audio_dur:.2f}s")
    # Fix: guard against ZeroDivisionError when no audio was produced.
    if audio_dur > 0:
        print(f"RTF: {t_total/audio_dur:.3f}x")
    else:
        print("RTF: n/a (no audio produced)")
    print(f"Generated tokens: {len(timings['diffusion_sample'])} speech + {len(timings['lm_forward']) - 2*len(timings['diffusion_sample'])} control")
    print()

    total_profiled = 0
    # lm_forward includes both positive and negative passes:
    # - diffusion steps take 2 LM forwards (positive + negative) per speech token,
    # - control tokens take 1 LM forward per token.
    lm_calls = len(timings["lm_forward"])
    diffusion_calls = len(timings["diffusion_sample"])

    # Per-component totals, largest first.
    for name, vals in sorted(timings.items(), key=lambda x: -sum(x[1])):
        total = sum(vals)
        count = len(vals)
        avg = total / count if count > 0 else 0
        pct = total / t_total * 100
        total_profiled += total
        print(f"  {name:<25} total={total:>7.3f}s  count={count:>4}  avg={avg*1000:>7.2f}ms  ({pct:>5.1f}%)")

    # Whatever is not captured by the instrumented calls is Python-side overhead.
    overhead = t_total - total_profiled
    print(f"  {'python_overhead':<25} total={overhead:>7.3f}s  {'':>12}  {'':>12}  ({overhead/t_total*100:>5.1f}%)")
    print(f"{'='*80}")

    # Per-speech-token breakdown.
    if diffusion_calls > 0:
        # 2 LM calls per diffusion token (pos+neg), so avg-per-call * 2.
        lm_per_diff = sum(timings["lm_forward"]) / lm_calls * 2
        diff_per_token = sum(timings["diffusion_sample"]) / diffusion_calls
        acoustic_per_token = sum(timings["acoustic_decode"]) / diffusion_calls
        semantic_per_token = sum(timings["semantic_encode"]) / diffusion_calls

        total_per_token = lm_per_diff + diff_per_token + acoustic_per_token + semantic_per_token
        # 133.3ms of audio per token at an assumed 7.5 Hz token rate — TODO confirm.
        audio_per_token = 1000 / 7.5

        print(f"\nPER SPEECH TOKEN BREAKDOWN (avg):")
        print(f"  LM forward (pos+neg):  {lm_per_diff*1000:>7.2f}ms")
        print(f"  Diffusion ({ddpm_steps} steps):  {diff_per_token*1000:>7.2f}ms")
        print(f"  Acoustic decode:       {acoustic_per_token*1000:>7.2f}ms")
        print(f"  Semantic encode:       {semantic_per_token*1000:>7.2f}ms")
        print(f"  Total per token:       {total_per_token*1000:>7.2f}ms")
        print(f"  Audio per token:       {audio_per_token:>7.2f}ms")
        print(f"  Token RTF:             {total_per_token*1000/audio_per_token:.3f}x")
        print(f"{'='*80}")


def main():
    """Load the model, run one warmup generation, then profile the components."""
    voice_path = "demo/voices/modi.wav"
    model_path = "microsoft/VibeVoice-1.5B"

    print("Loading model...")
    processor = VibeVoiceProcessor.from_pretrained(model_path)
    try:
        # Prefer flash attention when available for realistic timings.
        model = VibeVoiceForConditionalGenerationInference.from_pretrained(
            model_path, torch_dtype=torch.bfloat16, device_map="cuda",
            attn_implementation="flash_attention_2")
    except Exception:
        # Fix: was a bare `except:` that also swallowed KeyboardInterrupt/
        # SystemExit. Fall back to SDPA when flash attention is unavailable
        # (not installed, or unsupported GPU).
        model = VibeVoiceForConditionalGenerationInference.from_pretrained(
            model_path, torch_dtype=torch.bfloat16, device_map="cuda",
            attn_implementation="sdpa")
    model.eval()

    text = "Speaker 1: मेरे प्यारे देशवासियों, आज मैं आपके साथ कुछ बहुत ज़रूरी बातें करना चाहता हूँ. हमारा देश एक नये दौर में प्रवेश कर रहा है."

    # Warmup: one short generation so CUDA kernel compilation / allocator
    # warm-up does not pollute the profiled run.
    print("Warming up...")
    model.set_ddpm_inference_steps(num_steps=20)
    torch.manual_seed(42)
    torch.cuda.manual_seed_all(42)
    inp = processor(text=["Speaker 1: test."], voice_samples=[[voice_path]],
                    padding=True, return_tensors="pt", return_attention_mask=True)
    for k, v in inp.items():
        if torch.is_tensor(v):
            inp[k] = v.to("cuda")
    _ = model.generate(**inp, max_new_tokens=None, cfg_scale=1.3,
        tokenizer=processor.tokenizer, generation_config={"do_sample": False},
        verbose=False, is_prefill=True, show_progress_bar=False)

    print("\n\n" + "#" * 80)
    print("# MANUAL COMPONENT PROFILING")
    print("#" * 80)
    manual_profile(model, processor, voice_path, text, ddpm_steps=20)


# Script entry point: only run the profiling workflow when executed directly,
# not when imported as a module.
if __name__ == "__main__":
    main()
