"""
Measure max concurrency on A100 80GB for VibeVoice 1.5B.
Tests batch sizes 1 -> N until OOM, measures throughput and memory.
"""

import gc
import time
import torch
from vibevoice.modular.modeling_vibevoice_inference import VibeVoiceForConditionalGenerationInference
from vibevoice.processor.vibevoice_processor import VibeVoiceProcessor


def get_gpu_mem():
    """Snapshot current CUDA memory usage on device 0, in gigabytes.

    Returns:
        dict with allocated, reserved, free (total - allocated), and
        total VRAM, each converted from bytes to GB (decimal, 1e9).
    """
    props = torch.cuda.get_device_properties(0)
    allocated = torch.cuda.memory_allocated()
    return {
        "allocated_gb": allocated / 1e9,
        "reserved_gb": torch.cuda.memory_reserved() / 1e9,
        "free_gb": (props.total_memory - allocated) / 1e9,
        "total_gb": props.total_memory / 1e9,
    }


def run_batch(model, processor, voice_path, batch_size, ddpm_steps=20, cfg_scale=1.3):
    """Run one batched TTS generation pass and collect timing/memory metrics.

    Args:
        model: loaded VibeVoice inference model, already on CUDA.
        processor: matching VibeVoiceProcessor.
        voice_path: path to the reference voice sample (wav).
        batch_size: number of identical requests batched together.
        ddpm_steps: diffusion denoising steps for the acoustic head.
        cfg_scale: classifier-free guidance scale passed to generate().

    Returns:
        dict with batch size, wall-clock generation time, total/average
        audio seconds produced, throughput (audio s per wall s), peak GPU
        memory in GB, and per-sample real-time factor.
    """
    # One fixed Hindi prompt replicated across the batch so every request
    # has identical content and length.
    text = "Speaker 1: मेरे प्यारे देशवासियों, आज मैं आपके साथ कुछ बहुत ज़रूरी बातें करना चाहता हूँ. हमारा देश एक नये दौर में प्रवेश कर रहा है."

    inputs = processor(
        text=[text] * batch_size,
        voice_samples=[[voice_path]] * batch_size,
        padding=True,
        return_tensors="pt",
        return_attention_mask=True,
    )
    # Move every tensor field onto the GPU in place.
    for key in inputs:
        value = inputs[key]
        if torch.is_tensor(value):
            inputs[key] = value.to("cuda")

    # Fixed seed so repeated runs are comparable across batch sizes.
    torch.manual_seed(42)
    torch.cuda.manual_seed_all(42)
    model.set_ddpm_inference_steps(num_steps=ddpm_steps)

    # Synchronize before timing so pending kernels don't pollute the window.
    torch.cuda.synchronize()
    torch.cuda.reset_peak_memory_stats()
    start = time.perf_counter()

    outputs = model.generate(
        **inputs,
        max_new_tokens=None,
        cfg_scale=cfg_scale,
        tokenizer=processor.tokenizer,
        generation_config={"do_sample": False},
        verbose=False,
        is_prefill=True,
        show_progress_bar=False,
    )

    torch.cuda.synchronize()
    elapsed = time.perf_counter() - start
    peak_gb = torch.cuda.max_memory_allocated() / 1e9

    # Output is 24 kHz audio, so samples / 24000 = seconds of speech.
    audio_secs = sum(
        clip.shape[-1] / 24000.0
        for clip in outputs.speech_outputs
        if clip is not None
    )

    per_sample = audio_secs / batch_size
    return {
        "batch_size": batch_size,
        "gen_time_s": elapsed,
        "total_audio_s": audio_secs,
        "avg_audio_s": per_sample,
        "throughput_audio_s_per_s": audio_secs / elapsed,
        "peak_mem_gb": peak_gb,
        "rtf_per_sample": elapsed / per_sample if audio_secs > 0 else float("inf"),
    }


def main():
    """Load VibeVoice 1.5B and sweep batch sizes until OOM, printing metrics.

    For each batch size it reports generation wall time, total audio
    produced, throughput, peak GPU memory, and per-sample real-time factor.
    Stops at the first OOM or unexpected error.
    """
    voice_path = "demo/voices/modi.wav"
    model_path = "microsoft/VibeVoice-1.5B"

    print("Loading model...")
    processor = VibeVoiceProcessor.from_pretrained(model_path)
    try:
        model = VibeVoiceForConditionalGenerationInference.from_pretrained(
            model_path, torch_dtype=torch.bfloat16, device_map="cuda",
            attn_implementation="flash_attention_2")
    except Exception:
        # flash-attn may be missing or unsupported on this GPU; fall back to
        # PyTorch's built-in scaled-dot-product attention. Catch Exception
        # (not a bare `except:`) so Ctrl-C / SystemExit still propagate.
        print("flash_attention_2 unavailable; falling back to sdpa")
        model = VibeVoiceForConditionalGenerationInference.from_pretrained(
            model_path, torch_dtype=torch.bfloat16, device_map="cuda",
            attn_implementation="sdpa")
    model.eval()

    mem = get_gpu_mem()
    print(f"GPU: {torch.cuda.get_device_name(0)}")
    print(f"Total VRAM: {mem['total_gb']:.1f} GB")
    print(f"Model loaded: {mem['allocated_gb']:.1f} GB used\n")

    # Warmup run so CUDA kernels/compilation don't skew the first measurement.
    print("Warming up...")
    _ = run_batch(model, processor, voice_path, 1, ddpm_steps=20)

    batch_sizes = [1, 2, 4, 6, 8, 10, 12, 16]

    print("\n" + "=" * 100)
    print(f"{'Batch':>6} {'GenTime':>9} {'TotalAudio':>11} {'Throughput':>12} {'PeakMem':>9} {'RTF/sample':>11} {'Status':<10}")
    print("-" * 100)

    for bs in batch_sizes:
        try:
            # Reclaim fragmented memory from the previous run before growing
            # the batch, so each measurement starts from a clean allocator.
            gc.collect()
            torch.cuda.empty_cache()
            r = run_batch(model, processor, voice_path, bs, ddpm_steps=20)
            print(
                f"{r['batch_size']:>6} "
                f"{r['gen_time_s']:>8.2f}s "
                f"{r['total_audio_s']:>10.2f}s "
                f"{r['throughput_audio_s_per_s']:>9.2f} s/s "
                f"{r['peak_mem_gb']:>8.1f}GB "
                f"{r['rtf_per_sample']:>10.3f}x "
                f"OK"
            )
        except torch.cuda.OutOfMemoryError:
            # Free what we can and stop the sweep — larger batches will
            # only OOM harder.
            torch.cuda.empty_cache()
            gc.collect()
            print(f"{bs:>6} {'':>8}  {'':>10}  {'':>9}     {'':>8}  {'':>10}  OOM")
            break
        except Exception as e:
            print(f"{bs:>6} {'':>8}  {'':>10}  {'':>9}     {'':>8}  {'':>10}  ERROR: {e}")
            break

    print("=" * 100)
    print("\nThroughput = total seconds of audio generated per second of wall time")
    print("RTF/sample = how long each individual request takes (lower = faster per user)")
    print("For production: pick batch size where throughput is highest AND RTF/sample is acceptable")


# Script entry point: run the benchmark only when executed directly,
# not when this module is imported.
if __name__ == "__main__":
    main()
