"""
Concurrency benchmark with STREAMING + TTFB measurement.
Tests batch 1 -> N, measures TTFB per sample and throughput.
"""

import gc
import time
import threading
import torch
from vibevoice.modular.modeling_vibevoice_inference import VibeVoiceForConditionalGenerationInference
from vibevoice.modular.streamer import AudioStreamer
from vibevoice.processor.vibevoice_processor import VibeVoiceProcessor


def run_streaming_batch(model, processor, voice_path, batch_size, ddpm_steps=20,
                        cfg_scale=1.3, sample_rate=24000):
    """Run one streaming generation pass and collect timing/throughput stats.

    One consumer thread is spawned per batch element; each drains its audio
    stream and records the wall-clock time of the first chunk it sees (TTFB).
    Generation runs on the main thread so `start_time` brackets only
    `model.generate`.

    Args:
        model: VibeVoice inference model, already placed on CUDA.
        processor: matching VibeVoiceProcessor used for tokenization.
        voice_path: path to the reference voice WAV, reused for every sample.
        batch_size: number of identical samples generated in one batch.
        ddpm_steps: diffusion inference steps configured on the model.
        cfg_scale: classifier-free-guidance scale passed to generate().
        sample_rate: audio sample rate in Hz used to convert emitted samples
            to seconds (default 24000 — assumed VibeVoice output rate; confirm
            against the model config if it changes).

    Returns:
        dict with keys: batch_size, gen_time_s, total_audio_s, throughput
        (audio-seconds per wall-second), peak_mem_gb, and min/avg/max TTFB in
        milliseconds (-1 when a stream produced no chunk).
    """
    text = "Speaker 1: मेरे प्यारे देशवासियों, आज मैं आपके साथ कुछ बहुत ज़रूरी बातें करना चाहता हूँ. हमारा देश एक नये दौर में प्रवेश कर रहा है."

    texts = [text] * batch_size
    voice_batch = [[voice_path]] * batch_size

    inputs = processor(
        text=texts, voice_samples=voice_batch,
        padding=True, return_tensors="pt", return_attention_mask=True,
    )
    for k, v in inputs.items():
        if torch.is_tensor(v):
            inputs[k] = v.to("cuda")

    # Fixed seed so repeated batch sizes are comparable run-to-run.
    torch.manual_seed(42)
    torch.cuda.manual_seed_all(42)
    model.set_ddpm_inference_steps(num_steps=ddpm_steps)

    streamer = AudioStreamer(batch_size=batch_size, stop_signal=None)

    # Per-stream state; each index is touched by exactly one consumer thread,
    # so no locking is needed.
    first_chunk_times = [None] * batch_size
    total_samples = [0] * batch_size

    def consumer(idx):
        # Drain stream `idx` until the generator signals end-of-stream.
        for chunk in streamer.get_stream(idx):
            now = time.perf_counter()
            if first_chunk_times[idx] is None:
                first_chunk_times[idx] = now
            total_samples[idx] += chunk.shape[-1]

    threads = [threading.Thread(target=consumer, args=(i,), daemon=True) for i in range(batch_size)]
    for th in threads:
        th.start()

    # Synchronize so pending GPU work doesn't leak into the timed region,
    # and reset the peak-memory counter for this batch only.
    torch.cuda.synchronize()
    torch.cuda.reset_peak_memory_stats()
    start_time = time.perf_counter()

    model.generate(
        **inputs, max_new_tokens=None, cfg_scale=cfg_scale,
        tokenizer=processor.tokenizer, generation_config={"do_sample": False},
        verbose=False, is_prefill=True, audio_streamer=streamer, show_progress_bar=False,
    )

    torch.cuda.synchronize()
    end_time = time.perf_counter()

    # Bounded join: daemon threads won't block interpreter exit if a stream
    # somehow never terminates.
    for th in threads:
        th.join(timeout=5)

    gen_time = end_time - start_time
    peak_mem = torch.cuda.max_memory_allocated() / 1e9

    # TTFB per stream, in ms; streams that never yielded a chunk are skipped.
    ttfbs = [
        (t0 - start_time) * 1000
        for t0 in first_chunk_times
        if t0 is not None
    ]

    total_audio = sum(s / float(sample_rate) for s in total_samples)
    avg_ttfb = sum(ttfbs) / len(ttfbs) if ttfbs else -1
    min_ttfb = min(ttfbs) if ttfbs else -1
    max_ttfb = max(ttfbs) if ttfbs else -1

    return {
        "batch_size": batch_size,
        "gen_time_s": gen_time,
        "total_audio_s": total_audio,
        "throughput": total_audio / gen_time if gen_time > 0 else 0,
        "peak_mem_gb": peak_mem,
        "min_ttfb_ms": min_ttfb,
        "avg_ttfb_ms": avg_ttfb,
        "max_ttfb_ms": max_ttfb,
    }


def main():
    """Load the model and sweep batch sizes, printing a streaming benchmark table.

    Attempts FlashAttention-2 first and falls back to SDPA if that backend is
    unavailable. Each batch size runs until the first OOM (which stops the
    sweep) or an unexpected error.
    """
    voice_path = "demo/voices/modi.wav"
    model_path = "microsoft/VibeVoice-1.5B"

    print("Loading model...")
    processor = VibeVoiceProcessor.from_pretrained(model_path)
    try:
        model = VibeVoiceForConditionalGenerationInference.from_pretrained(
            model_path, torch_dtype=torch.bfloat16, device_map="cuda",
            attn_implementation="flash_attention_2")
    except Exception:
        # Narrow fallback: a bare `except:` would also swallow KeyboardInterrupt.
        # flash-attn may be missing or unsupported on this GPU — use SDPA instead.
        print("flash_attention_2 unavailable, falling back to sdpa")
        model = VibeVoiceForConditionalGenerationInference.from_pretrained(
            model_path, torch_dtype=torch.bfloat16, device_map="cuda",
            attn_implementation="sdpa")
    model.eval()

    mem_gb = torch.cuda.memory_allocated() / 1e9
    total_gb = torch.cuda.get_device_properties(0).total_memory / 1e9
    print(f"GPU: {torch.cuda.get_device_name(0)} | {total_gb:.0f}GB total | {mem_gb:.1f}GB model")

    # One throwaway run so CUDA kernels/caches are warm before measuring.
    print("\nWarming up...")
    _ = run_streaming_batch(model, processor, voice_path, 1, ddpm_steps=20)

    batch_sizes = [1, 2, 4, 8, 12, 16, 24, 32]

    print(f"\n{'='*110}")
    print(f"{'Batch':>6} {'GenTime':>9} {'Audio':>8} {'Thruput':>10} {'VRAM':>7} {'MinTTFB':>9} {'AvgTTFB':>9} {'MaxTTFB':>9} {'Status':<6}")
    print(f"{'-'*110}")

    for bs in batch_sizes:
        try:
            # Reclaim allocator blocks between runs so peak-memory readings
            # reflect this batch size, not fragmentation from the previous one.
            gc.collect()
            torch.cuda.empty_cache()
            r = run_streaming_batch(model, processor, voice_path, bs, ddpm_steps=20)
            print(
                f"{r['batch_size']:>6} "
                f"{r['gen_time_s']:>8.2f}s "
                f"{r['total_audio_s']:>7.1f}s "
                f"{r['throughput']:>7.2f} s/s "
                f"{r['peak_mem_gb']:>6.1f}GB "
                f"{r['min_ttfb_ms']:>8.0f}ms "
                f"{r['avg_ttfb_ms']:>8.0f}ms "
                f"{r['max_ttfb_ms']:>8.0f}ms "
                f"OK"
            )
        except torch.cuda.OutOfMemoryError:
            torch.cuda.empty_cache()
            gc.collect()
            print(f"{bs:>6} {'OOM':>60}")
            break  # larger batches would also OOM
        except Exception as e:
            print(f"{bs:>6} ERROR: {e}")
            break

    print(f"{'='*110}")


if __name__ == "__main__":
    main()
