"""
Pre-compute Modi voice embeddings once, save to disk.
Patch the model to skip re-encoding when cache is available.
Benchmark TTFB with cached vs uncached.
"""

import time
import threading
import torch
from vibevoice.modular.modeling_vibevoice_inference import VibeVoiceForConditionalGenerationInference
from vibevoice.modular.streamer import AudioStreamer
from vibevoice.processor.vibevoice_processor import VibeVoiceProcessor


def patch_model_for_cache(model):
    """Monkey-patch ``model._process_speech_inputs`` to honor a voice cache.

    After patching, setting ``model._cached_speech_embeds`` to a dict with
    "acoustic_features" / "acoustic_connected" tensors makes the model return
    those directly (moved to the input device) instead of re-encoding the
    reference audio; setting it to ``None`` restores the original behavior.
    """
    fallback = model._process_speech_inputs

    def cached_or_fallback(speech_tensors, speech_masks, speech_type="audio"):
        cache = getattr(model, '_cached_speech_embeds', None)
        if cache is None:
            return fallback(speech_tensors, speech_masks, speech_type)
        device = speech_tensors.device
        return (
            cache["acoustic_features"].to(device),
            cache["acoustic_connected"].to(device),
        )

    model._process_speech_inputs = cached_or_fallback
    model._cached_speech_embeds = None


def measure_ttfb(model, processor, voice_path, text, use_cache=False):
    """Synthesize *text* once and measure latency / throughput metrics.

    Args:
        model: VibeVoice inference model, already patched by
            ``patch_model_for_cache`` and carrying ``_voice_cache``.
        processor: matching ``VibeVoiceProcessor``.
        voice_path: path to the reference voice wav file.
        text: script to synthesize; a "Speaker 1:" prefix is added if the
            text does not already start with "Speaker".
        use_cache: when True, skip voice re-encoding by feeding the
            pre-computed embeddings in ``model._voice_cache``.

    Returns:
        dict with ``ttfb_ms`` (time to first streamed audio chunk in ms,
        -1 if no chunk ever arrived), ``gen_time_s`` (total generation
        wall time), ``audio_s`` (seconds of audio produced) and ``rtf``
        (real-time factor = gen_time / audio duration).
    """
    if not text.startswith("Speaker"):
        text = f"Speaker 1: {text}"

    inputs = processor(
        text=[text], voice_samples=[[voice_path]],
        padding=True, return_tensors="pt", return_attention_mask=True,
    )
    for k, v in inputs.items():
        if torch.is_tensor(v):
            inputs[k] = v.to("cuda")

    streamer = AudioStreamer(batch_size=1, stop_signal=None)
    # Single-element lists act as mutable cells shared with the consumer thread.
    first_chunk_time = [None]
    total_samples = [0]
    start_time = [None]

    def consumer():
        # Drain the audio stream, recording the arrival time of the first chunk.
        for chunk in streamer.get_stream(0):
            t = time.perf_counter()
            if first_chunk_time[0] is None:
                first_chunk_time[0] = t
            total_samples[0] += chunk.shape[-1]

    consumer_thread = threading.Thread(target=consumer, daemon=True)
    consumer_thread.start()

    # Fixed seeds so cached and uncached runs are directly comparable.
    torch.manual_seed(42)
    torch.cuda.manual_seed_all(42)

    # The patched _process_speech_inputs short-circuits when this is non-None.
    model._cached_speech_embeds = model._voice_cache if use_cache else None

    torch.cuda.synchronize()
    start_time[0] = time.perf_counter()

    # Return value is intentionally discarded; metrics come from the streamer.
    model.generate(
        **inputs, max_new_tokens=None, cfg_scale=1.3,
        tokenizer=processor.tokenizer, generation_config={"do_sample": False},
        verbose=False, is_prefill=True, audio_streamer=streamer, show_progress_bar=False,
    )

    torch.cuda.synchronize()
    end_time = time.perf_counter()
    consumer_thread.join(timeout=5)

    gen_time = end_time - start_time[0]
    # Explicit None check: a falsy timestamp (0.0) would otherwise be
    # misreported as "no chunk received".
    if first_chunk_time[0] is not None:
        ttfb = (first_chunk_time[0] - start_time[0]) * 1000
    else:
        ttfb = -1
    # assumes the model emits 24 kHz audio — TODO confirm against model config
    audio_dur = total_samples[0] / 24000.0
    rtf = gen_time / audio_dur if audio_dur > 0 else float("inf")

    return {"ttfb_ms": ttfb, "gen_time_s": gen_time, "audio_s": audio_dur, "rtf": rtf}


def main():
    """Load the model, pre-encode the reference voice, and benchmark TTFB.

    Steps:
      1. Load processor + model (flash-attention if available, else SDPA).
      2. Encode the reference voice once and cache the resulting acoustic
         embeddings on the model and on disk.
      3. Benchmark TTFB / RTF with and without the cache (3 runs each).
    """
    voice_path = "demo/voices/modi.wav"
    model_path = "microsoft/VibeVoice-1.5B"

    print("Loading model...")
    processor = VibeVoiceProcessor.from_pretrained(model_path)
    try:
        model = VibeVoiceForConditionalGenerationInference.from_pretrained(
            model_path, torch_dtype=torch.bfloat16, device_map="cuda",
            attn_implementation="flash_attention_2")
    except Exception:
        # flash-attention not installed / unsupported on this GPU; fall back
        # to PyTorch's built-in scaled-dot-product attention.
        model = VibeVoiceForConditionalGenerationInference.from_pretrained(
            model_path, torch_dtype=torch.bfloat16, device_map="cuda",
            attn_implementation="sdpa")
    model.eval()
    model.set_ddpm_inference_steps(num_steps=20)

    # Install the cache-aware _process_speech_inputs shim.
    patch_model_for_cache(model)

    # =========================================================
    # Step 1: Pre-encode voice and cache
    # =========================================================
    print("\n[STEP 1] Pre-encoding Modi voice...")
    # The text here is a throwaway; we only need the processor to produce
    # speech_tensors / speech_masks for the reference voice.
    inputs_dummy = processor(
        text=["Speaker 1: test"],
        voice_samples=[[voice_path]],
        padding=True, return_tensors="pt", return_attention_mask=True,
    )
    for k, v in inputs_dummy.items():
        if torch.is_tensor(v):
            inputs_dummy[k] = v.to("cuda")

    torch.cuda.synchronize()
    t0 = time.perf_counter()
    with torch.no_grad():
        speech_tensors = inputs_dummy["speech_tensors"].to(model.dtype)
        speech_masks = inputs_dummy["speech_masks"]

        # Mirrors the model's internal voice-encoding path: tokenize audio,
        # sample latents, then scale/shift and project through the connector.
        encoder_output = model.model.acoustic_tokenizer.encode(speech_tensors.unsqueeze(1))
        acoustic_latents = encoder_output.sample(dist_type=model.model.acoustic_tokenizer.std_dist_type)[0]
        acoustic_features = (acoustic_latents + model.model.speech_bias_factor.to(acoustic_latents.device)) * model.model.speech_scaling_factor.to(acoustic_latents.device)
        acoustic_connected = model.model.acoustic_connector(acoustic_features)[speech_masks.cpu()]
    torch.cuda.synchronize()
    encode_time = time.perf_counter() - t0

    model._voice_cache = {
        "acoustic_features": acoustic_features,
        "acoustic_connected": acoustic_connected,
    }

    # Persist to disk (CPU tensors) so future processes can skip encoding too.
    cache_path = "demo/voices/modi_cached.pt"
    torch.save({k: v.cpu() for k, v in model._voice_cache.items()}, cache_path)
    print(f"  Encoding time: {encode_time*1000:.0f}ms (this is saved on every request with cache)")
    print(f"  Saved to: {cache_path}")

    # =========================================================
    # Step 2: Benchmark
    # =========================================================
    text = "Speaker 1: मेरे प्यारे देशवासियों, आज मैं आपके साथ कुछ बहुत ज़रूरी बातें करना चाहता हूँ. हमारा देश एक नये दौर में प्रवेश कर रहा है."

    # One untimed warm-up run so CUDA kernels / allocator are primed.
    print("\nWarming up...")
    _ = measure_ttfb(model, processor, voice_path, text, use_cache=False)

    print(f"\n{'='*80}")
    print(f"{'Mode':<30} {'TTFB':>8} {'GenTime':>9} {'Audio':>8} {'RTF':>7}")
    print(f"{'-'*80}")

    for i in range(3):
        r = measure_ttfb(model, processor, voice_path, text, use_cache=False)
        print(f"{'Uncached (re-encode) #'+str(i+1):<30} {r['ttfb_ms']:>7.0f}ms {r['gen_time_s']:>8.2f}s {r['audio_s']:>7.2f}s {r['rtf']:>6.3f}x")

    for i in range(3):
        r = measure_ttfb(model, processor, voice_path, text, use_cache=True)
        print(f"{'Cached (skip encode) #'+str(i+1):<30} {r['ttfb_ms']:>7.0f}ms {r['gen_time_s']:>8.2f}s {r['audio_s']:>7.2f}s {r['rtf']:>6.3f}x")

    print(f"{'='*80}")


# Run the benchmark only when executed as a script, not on import.
if __name__ == "__main__":
    main()
