"""
Measure Time To First Byte (TTFB) for VibeVoice streaming inference.
Uses AudioStreamer to capture exactly when the first audio chunk arrives.
"""

import time
import threading
import torch

from vibevoice.modular.modeling_vibevoice_inference import VibeVoiceForConditionalGenerationInference
from vibevoice.modular.streamer import AudioStreamer
from vibevoice.processor.vibevoice_processor import VibeVoiceProcessor


def measure_ttfb(
    model,
    processor,
    voice_path: str,
    text: str,
    ddpm_steps: int = 10,
    cfg_scale: float = 1.3,
    seed: int = 42,
    sample_rate: int = 24000,
    device: str = "cuda",
):
    """Run one streaming generation and measure Time To First Byte (TTFB).

    A consumer thread drains the ``AudioStreamer`` and timestamps every chunk
    as it arrives; TTFB is the delta between generation start and the first
    chunk's arrival.

    Args:
        model: Loaded VibeVoice inference model (already on ``device``).
        processor: Matching ``VibeVoiceProcessor``.
        voice_path: Path to the reference voice sample.
        text: Script to synthesize; a "Speaker 1:" prefix is added if absent.
        ddpm_steps: Number of DDPM inference steps per audio frame.
        cfg_scale: Classifier-free guidance scale.
        seed: RNG seed for reproducible runs.
        sample_rate: Output audio sample rate in Hz (VibeVoice emits 24 kHz).
        device: Device the prepared inputs are moved to.

    Returns:
        Dict with keys ``ttfb_ms``, ``total_gen_time_s``, ``audio_duration_s``,
        ``rtf``, ``num_chunks``, ``avg_chunk_interval_ms``, ``chunk_audio_ms``.
        ``ttfb_ms`` and ``rtf`` are ``inf`` when no audio was produced.
    """
    if not text.startswith("Speaker"):
        text = f"Speaker 1: {text}"

    inputs = processor(
        text=[text],
        voice_samples=[[voice_path]],
        padding=True,
        return_tensors="pt",
        return_attention_mask=True,
    )
    for k, v in inputs.items():
        if torch.is_tensor(v):
            inputs[k] = v.to(device)

    streamer = AudioStreamer(batch_size=1, stop_signal=None)

    # Single-element lists act as mutable cells shared with the consumer
    # thread; they are only read after the thread is joined.
    first_chunk_time = [None]
    all_chunk_times = []
    total_audio_samples = [0]
    start_time = [None]

    def consumer():
        # Drain stream 0, timestamping each chunk at arrival.
        for chunk in streamer.get_stream(0):
            t = time.perf_counter()
            if first_chunk_time[0] is None:
                first_chunk_time[0] = t
            all_chunk_times.append(t)
            total_audio_samples[0] += chunk.shape[-1]

    consumer_thread = threading.Thread(target=consumer, daemon=True)
    consumer_thread.start()

    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    model.set_ddpm_inference_steps(num_steps=ddpm_steps)

    # Drain any pending GPU work so start_time measures only this generation.
    torch.cuda.synchronize()
    start_time[0] = time.perf_counter()

    # Return value intentionally discarded: only the streamed timings matter.
    model.generate(
        **inputs,
        max_new_tokens=None,
        cfg_scale=cfg_scale,
        tokenizer=processor.tokenizer,
        generation_config={"do_sample": False},
        verbose=False,
        is_prefill=True,
        audio_streamer=streamer,
        show_progress_bar=False,
    )

    torch.cuda.synchronize()
    end_time = time.perf_counter()

    consumer_thread.join(timeout=5)

    total_gen_time = end_time - start_time[0]
    # Explicit None check: a perf_counter timestamp could in principle be
    # falsy (0.0), and truthiness would then misreport TTFB as inf.
    ttfb = (
        first_chunk_time[0] - start_time[0]
        if first_chunk_time[0] is not None
        else float("inf")
    )
    num_chunks = len(all_chunk_times)
    audio_dur = total_audio_samples[0] / float(sample_rate)

    # Inter-arrival gaps between consecutive chunks.
    chunk_deltas = [b - a for a, b in zip(all_chunk_times, all_chunk_times[1:])]

    return {
        "ttfb_ms": ttfb * 1000,
        "total_gen_time_s": total_gen_time,
        "audio_duration_s": audio_dur,
        "rtf": total_gen_time / audio_dur if audio_dur > 0 else float("inf"),
        "num_chunks": num_chunks,
        "avg_chunk_interval_ms": (sum(chunk_deltas) / len(chunk_deltas) * 1000) if chunk_deltas else 0,
        "chunk_audio_ms": (audio_dur / num_chunks * 1000) if num_chunks > 0 else 0,
    }


def main():
    """Load VibeVoice-1.5B and print a TTFB benchmark table.

    Tries flash_attention_2 first and falls back to sdpa if it is
    unavailable, warms the model up once, then benchmarks three prompt
    lengths at each DDPM step count.
    """
    voice_path = "demo/voices/modi.wav"
    model_path = "microsoft/VibeVoice-1.5B"

    print(f"Loading processor from {model_path}")
    processor = VibeVoiceProcessor.from_pretrained(model_path)

    print("Loading model with flash_attention_2")
    try:
        model = VibeVoiceForConditionalGenerationInference.from_pretrained(
            model_path,
            torch_dtype=torch.bfloat16,
            device_map="cuda",
            attn_implementation="flash_attention_2",
        )
    except Exception as e:
        # Surface WHY flash attention failed (missing wheel, unsupported GPU,
        # ...) instead of silently swallowing it, then fall back to sdpa.
        print(f"flash_attention_2 failed ({e}), falling back to sdpa")
        model = VibeVoiceForConditionalGenerationInference.from_pretrained(
            model_path,
            torch_dtype=torch.bfloat16,
            device_map="cuda",
            attn_implementation="sdpa",
        )

    model.eval()

    # Prompts of increasing length; all already carry the "Speaker 1:" prefix.
    texts = {
        "short": "Speaker 1: नमस्ते, मेरे प्यारे देशवासियों.",
        "medium": "Speaker 1: मेरे प्यारे देशवासियों, आज मैं आपके साथ कुछ बहुत ज़रूरी बातें करना चाहता हूँ. हमारा देश एक नये दौर में प्रवेश कर रहा है.",
        "long": "Speaker 1: मेरे प्यारे देशवासियों, आज मैं आपके साथ कुछ बहुत ज़रूरी बातें करना चाहता हूँ. हमारा देश एक नये दौर में प्रवेश कर रहा है, जहाँ टेक्नोलॉजी और इनोवेशन हमारी ताकत बन रही है. डिजिटल इंडिया ने देश की तस्वीर बदल दी है. गाँव गाँव में इंटरनेट पहुँच रहा है.",
    }

    # DDPM step counts to compare (fewer steps -> lower TTFB, lower quality).
    step_configs = [5, 10]

    # Warmup run so CUDA kernels/compilation don't pollute the first result.
    print("\nWarming up...")
    _ = measure_ttfb(model, processor, voice_path, "Speaker 1: test.", ddpm_steps=5)

    print("\n" + "=" * 80)
    print("TTFB BENCHMARK")
    print("=" * 80)
    print(f"{'Config':<22} {'TTFB':>8} {'GenTime':>9} {'AudioDur':>9} {'RTF':>7} {'Chunks':>7} {'ChunkInt':>10} {'ChunkAud':>10}")
    print("-" * 80)

    for label, text in texts.items():
        for steps in step_configs:
            r = measure_ttfb(model, processor, voice_path, text, ddpm_steps=steps)
            tag = f"{label}/steps={steps}"
            print(
                f"{tag:<22} "
                f"{r['ttfb_ms']:>7.0f}ms "
                f"{r['total_gen_time_s']:>8.2f}s "
                f"{r['audio_duration_s']:>8.2f}s "
                f"{r['rtf']:>6.3f}x "
                f"{r['num_chunks']:>6} "
                f"{r['avg_chunk_interval_ms']:>9.1f}ms "
                f"{r['chunk_audio_ms']:>9.1f}ms"
            )

    print("=" * 80)


# Run the benchmark only when executed as a script, not on import.
if __name__ == "__main__":
    main()
