"""
Find the actual max batch size on A100 80GB.
Push until OOM to find the real VRAM ceiling.
"""

import gc
import time
import torch
from vibevoice.modular.modeling_vibevoice_inference import VibeVoiceForConditionalGenerationInference
from vibevoice.processor.vibevoice_processor import VibeVoiceProcessor


def run_batch(model, processor, voice_path, batch_size, ddpm_steps=10, *,
              cfg_scale=1.3, sample_rate=24000):
    """Generate one batch of identical requests and measure time and peak VRAM.

    The same fixed prompt and the same voice sample are replicated
    ``batch_size`` times, pushed through ``processor``, moved to the CUDA
    device and passed to ``model.generate`` with sampling disabled.

    Args:
        model: Object exposing ``set_ddpm_inference_steps(num_steps=...)`` and
            ``generate(...)``; the result of ``generate`` must expose
            ``speech_outputs``.
        processor: Callable that builds the model inputs; must expose
            ``tokenizer``.
        voice_path: Voice sample reference handed to the processor for every
            request in the batch.
        batch_size: Number of identical requests in the batch (>= 1).
        ddpm_steps: Value passed to ``model.set_ddpm_inference_steps``.
        cfg_scale: ``cfg_scale`` forwarded to ``model.generate``.
        sample_rate: Divisor that converts the last dimension of each output
            tensor into seconds. Assumes the model emits audio at this rate
            and that the last axis is the sample axis -- TODO confirm against
            the model configuration.

    Returns:
        Tuple ``(gen_time, total_audio, throughput, peak_mem)``: wall-clock
        seconds spent in ``generate``, summed audio seconds over all non-None
        outputs, audio seconds per wall-clock second, and peak allocated CUDA
        memory in GB (1e9 bytes) since the stats reset just before generation.

    Raises:
        ValueError: If ``batch_size`` is smaller than 1.
    """
    if batch_size < 1:
        raise ValueError(f"batch_size must be >= 1, got {batch_size}")

    text = "Speaker 1: मेरे प्यारे देशवासियों, आज मैं आपके साथ कुछ बहुत ज़रूरी बातें करना चाहता हूँ."
    texts = [text] * batch_size
    # Build a separate inner list per request: `[[voice_path]] * n` would make
    # every entry alias the same list object, so an in-place change by the
    # processor to one request's voice list would show up in all of them.
    voice_batch = [[voice_path] for _ in range(batch_size)]

    inputs = processor(
        text=texts, voice_samples=voice_batch,
        padding=True, return_tensors="pt", return_attention_mask=True,
    )
    # Only tensor entries are moved; anything else is left untouched.
    for k, v in inputs.items():
        if torch.is_tensor(v):
            inputs[k] = v.to("cuda")

    # Fixed seeds so every batch size runs under the same RNG state.
    torch.manual_seed(42)
    torch.cuda.manual_seed_all(42)
    model.set_ddpm_inference_steps(num_steps=ddpm_steps)

    # Drain pending GPU work and reset the peak counter so the measurement
    # covers only this generate() call (on top of what is already resident).
    torch.cuda.synchronize()
    torch.cuda.reset_peak_memory_stats()
    t0 = time.perf_counter()

    outputs = model.generate(
        **inputs, max_new_tokens=None, cfg_scale=cfg_scale,
        tokenizer=processor.tokenizer, generation_config={"do_sample": False},
        verbose=False, is_prefill=True, show_progress_bar=False,
    )

    # CUDA kernels run asynchronously; wait for them before stopping the clock.
    torch.cuda.synchronize()
    gen_time = time.perf_counter() - t0
    peak_mem = torch.cuda.max_memory_allocated() / 1e9

    total_audio = sum(
        audio.shape[-1] / sample_rate
        for audio in outputs.speech_outputs
        if audio is not None
    )

    throughput = total_audio / gen_time if gen_time > 0 else 0
    return gen_time, total_audio, throughput, peak_mem


def main():
    """Sweep increasing batch sizes until generation fails and report the result.

    Loads the processor and model onto the GPU (FlashAttention-2 first, SDPA
    as a fallback), runs a single-request warm-up, then calls ``run_batch``
    for each batch size in ascending order, printing one table row per size.
    The sweep stops at the first out-of-memory error or any other exception,
    and the largest batch size that completed is printed at the end.
    """
    voice_path = "demo/voices/modi.wav"
    model_path = "microsoft/VibeVoice-1.5B"
    ddpm_steps = 10

    print("Loading model...")
    processor = VibeVoiceProcessor.from_pretrained(model_path)
    model = None
    try:
        model = VibeVoiceForConditionalGenerationInference.from_pretrained(
            model_path, torch_dtype=torch.bfloat16, device_map="cuda",
            attn_implementation="flash_attention_2")
    except Exception as exc:
        # `except Exception` rather than a bare `except:` so that Ctrl-C and
        # SystemExit still abort the script, and the reason is not hidden.
        print(f"flash_attention_2 load failed ({type(exc).__name__}: {exc}); falling back to sdpa")
    if model is None:
        # Retry outside the handler: once the except block has exited, the
        # failed attempt's exception no longer keeps its frames alive, so
        # anything it allocated can be collected before loading again.
        gc.collect()
        torch.cuda.empty_cache()
        model = VibeVoiceForConditionalGenerationInference.from_pretrained(
            model_path, torch_dtype=torch.bfloat16, device_map="cuda",
            attn_implementation="sdpa")
    model.eval()

    total_gb = torch.cuda.get_device_properties(0).total_memory / 1e9
    # Memory allocated right after loading is used as the model's footprint.
    model_gb = torch.cuda.memory_allocated() / 1e9
    print(f"GPU: {torch.cuda.get_device_name(0)} | {total_gb:.0f}GB total | {model_gb:.1f}GB model")

    # Warmup
    print("Warming up...")
    _ = run_batch(model, processor, voice_path, 1, ddpm_steps=ddpm_steps)

    # Ascending sweep; every size runs with the same number of DDPM steps.
    batch_sizes = [1, 8, 16, 32, 48, 64, 96, 128, 160, 192, 224, 256]

    print(f"\n{'='*90}")
    print(f"{'Batch':>6} {'GenTime':>9} {'TotalAudio':>11} {'Throughput':>12} {'PeakVRAM':>10} {'VRAM/req':>9}")
    print(f"{'-'*90}")

    last_ok_batch = 0
    stop_reason = None  # None: sweep completed; "oom" / "error": stopped early
    for bs in batch_sizes:
        # Start each measurement from a clean allocator state.
        gc.collect()
        torch.cuda.empty_cache()
        try:
            gen_time, total_audio, throughput, peak_mem = run_batch(
                model, processor, voice_path, bs, ddpm_steps=ddpm_steps)
        except torch.cuda.OutOfMemoryError:
            print(f"{bs:>6} {'':>50} OOM")
            stop_reason = "oom"
            break
        except Exception as e:
            print(f"{bs:>6} ERROR: {e}")
            stop_reason = "error"
            break

        # Per-request cost = peak above the resident model, split evenly.
        vram_per_req = (peak_mem - model_gb) / bs
        print(
            f"{bs:>6} "
            f"{gen_time:>8.2f}s "
            f"{total_audio:>10.1f}s "
            f"{throughput:>9.1f} s/s "
            f"{peak_mem:>9.1f}GB "
            f"{vram_per_req*1000:>7.0f}MB "
            f"OK"
        )
        last_ok_batch = bs

    # Clean up after the loop rather than inside the except handlers: while a
    # handler is running, the in-flight exception still references the failed
    # call's frames, so their tensors could not be released there.
    gc.collect()
    torch.cuda.empty_cache()

    print(f"{'='*90}")
    if stop_reason == "oom":
        print(f"\nMax batch before OOM: {last_ok_batch}")
    elif stop_reason == "error":
        print(f"\nLargest successful batch: {last_ok_batch} (sweep stopped on a non-OOM error)")
    else:
        print(
            f"\nLargest successful batch: {last_ok_batch} "
            f"(no OOM up to {batch_sizes[-1]}; extend batch_sizes to find the ceiling)"
        )
    avail = total_gb - model_gb
    print(f"Available VRAM for batching: {avail:.1f}GB")


# Run the sweep only when executed as a script, not when imported.
if __name__ == "__main__":
    main()
