"""LTX-2 throughput test: measures how many videos per minute a persistent engine can generate."""

import os
import time

os.environ.setdefault("PYTORCH_ALLOC_CONF", "expandable_segments:True")

import torch
import ltx_core.loader

from ltx_core.quantization import QuantizationPolicy
from ltx2_server import LTX2PersistentEngine

# Directory holding the checkpoints loaded below. Override with the
# LTX2_MODEL_DIR environment variable; the default keeps the original path.
MODEL_DIR = os.environ.get("LTX2_MODEL_DIR", "/home/ubuntu/ltx2-models")

# A single engine instance is created here and reused for every generation in
# this script, so the timed runs below do not include construction/loading.
engine = LTX2PersistentEngine(
    distilled_checkpoint_path=f"{MODEL_DIR}/ltx-2.3-22b-distilled-fp8.safetensors",
    spatial_upsampler_path=f"{MODEL_DIR}/ltx-2.3-spatial-upscaler-x2-1.0.safetensors",
    gemma_root=f"{MODEL_DIR}/gemma3",
    quantization=QuantizationPolicy.fp8_cast(),
)

# Benchmark prompts, one per line. Each burst below generates one video per
# entry, using the entry's index as its seed.
prompts = """\
A golden retriever running through wildflowers at sunset
Futuristic cityscape at night with neon lights on wet streets
Ocean waves crashing against rocky cliffs, slow motion
Chef preparing sushi in a Japanese restaurant, close-up
Astronaut floating in space with Earth in background
Rain falling on a quiet Japanese garden with koi pond
Sports car drifting on a mountain road, drone shot
Northern lights over a frozen lake, timelapse
Street musician playing guitar in Paris at dusk
Underwater coral reef with tropical fish swimming
""".splitlines()

# Where burst videos are written. Override with LTX2_BENCH_OUTPUT_DIR; the
# default keeps the original path.
OUTPUT_DIR = os.environ.get("LTX2_BENCH_OUTPUT_DIR", "/home/ubuntu/ltx2_bench_output")
# Frame count used for every generation (warmup and both bursts).
NUM_FRAMES = 41

os.makedirs(OUTPUT_DIR, exist_ok=True)


def _sync_cuda():
    """Wait for all queued CUDA work to finish; no-op when CUDA is unavailable.

    CUDA kernels are launched asynchronously, so a wall-clock timer can
    otherwise start before earlier work has drained or stop before the last
    generation has actually completed on the GPU.
    """
    if torch.cuda.is_available():
        torch.cuda.synchronize()


def _print_header(title):
    """Print *title* framed by 60-character '=' rules, preceded by a blank line."""
    print("\n" + "=" * 60)
    print(title)
    print("=" * 60)


def _run_burst(save):
    """Generate one video per prompt back-to-back and return elapsed seconds.

    Args:
        save: When true, each video is written to OUTPUT_DIR/burst_<i>.mp4,
            so the measurement includes file encoding. When false, no
            ``output_path`` is passed to ``engine.generate``.

    Returns:
        Wall-clock seconds for the whole burst, measured with a monotonic
        high-resolution clock between CUDA synchronization points.
    """
    count = len(prompts)
    _sync_cuda()  # don't charge leftover work (e.g. warmup) to this burst
    start = time.perf_counter()
    for i, p in enumerate(prompts):
        extra = {"output_path": f"{OUTPUT_DIR}/burst_{i}.mp4"} if save else {}
        t = engine.generate(prompt=p, seed=i, num_frames=NUM_FRAMES, **extra)
        if save:
            denoise = t["stage1_denoise"] + t["stage2_denoise"]
            encode = t.get("file_encode", 0)
            print(
                f"  [{i + 1}/{count}] {t['total']:.2f}s  "
                f"(denoise: {denoise:.2f}s, encode: {encode:.2f}s)"
            )
        else:
            print(f"  [{i + 1}/{count}] {t['total']:.2f}s")
    _sync_cuda()  # make sure the final generation has really finished
    return time.perf_counter() - start


# Warmup: one untimed generation before the measured bursts.
print("Warmup...")
engine.generate(prompt=prompts[0], seed=0, num_frames=NUM_FRAMES)

# Burst size is derived from the prompt list so rates stay correct if it changes.
n_videos = len(prompts)

# Burst 1: generations with file save.
_print_header(f"BURST TEST: {n_videos} videos WITH file save")
burst_with_save = _run_burst(save=True)
print(f"Total: {burst_with_save:.1f}s | Rate: {n_videos / burst_with_save * 60:.1f} videos/min")

# Burst 2: generations without file save (pure GPU throughput).
_print_header(f"BURST TEST: {n_videos} videos WITHOUT file save (pure GPU)")
burst_no_save = _run_burst(save=False)
print(f"Total: {burst_no_save:.1f}s | Rate: {n_videos / burst_no_save * 60:.1f} videos/min")

_print_header("SUMMARY")
print(f"With file save:    {n_videos / burst_with_save * 60:.1f} videos/min  ({burst_with_save / n_videos:.2f}s each)")
print(f"Without file save: {n_videos / burst_no_save * 60:.1f} videos/min  ({burst_no_save / n_videos:.2f}s each)")
# Peak since process start, so it also covers engine construction and warmup.
print(f"Peak VRAM: {torch.cuda.max_memory_allocated()/1024**3:.2f} GB")
