"""Benchmark FireRed-Edit-1.1 on A100 — measure VRAM and find max batch size."""

import gc
import time
import torch
from PIL import Image

def gpu_mem(device=0):
    """Return (allocated_GB, reserved_GB, total_GB) for a CUDA device.

    Args:
        device: CUDA device index to query. Defaults to 0, matching the
            original single-GPU behavior.

    Decimal GB (1e9) is used throughout so the numbers line up with the
    other readings in this script.
    """
    # The original queried allocated/reserved on the *current* device but
    # total_memory on device 0; pass `device` to all three so the triple is
    # consistent on multi-GPU hosts.
    allocated = torch.cuda.memory_allocated(device) / 1e9
    reserved = torch.cuda.memory_reserved(device) / 1e9
    total = torch.cuda.get_device_properties(device).total_memory / 1e9
    return allocated, reserved, total

def log_mem(label):
    """Print a one-line VRAM summary (allocated/reserved/total/free) tagged with *label*."""
    allocated, reserved, total = gpu_mem()
    free = total - reserved
    print(
        f"[{label}] Allocated: {allocated:.2f} GB | Reserved: {reserved:.2f} GB"
        f" | Total: {total:.2f} GB | Free: {free:.2f} GB"
    )

print("=" * 70)
print("FireRed-Edit-1.1 VRAM Benchmark on A100-80GB")
print("=" * 70)

# Baseline reading before anything heavy touches the GPU.
log_mem("Before load")

# ── Load model ──
print("\nLoading pipeline...")
t0 = time.time()  # brackets model + LoRA load; reported below

# Imported here (after the baseline log) rather than at the top of the file.
# NOTE(review): the FireRed checkpoint is loaded through the QwenImageEditPlus
# pipeline class — presumably it is Qwen-Image-Edit-compatible; confirm
# against the model card.
from diffusers import QwenImageEditPlusPipeline

# bf16 weights: half the memory of fp32 and the native fast path on A100.
pipe = QwenImageEditPlusPipeline.from_pretrained(
    "FireRedTeam/FireRed-Image-Edit-1.1",
    torch_dtype=torch.bfloat16,
).to("cuda")

log_mem("After model load")

# Lightning LoRA adapter (weight file is named "8steps") — enables the
# few-step inference used by the runs below.
print("\nLoading Lightning LoRA...")
pipe.load_lora_weights(
    "FireRedTeam/FireRed-Image-Edit-LoRA-Zoo",
    weight_name="FireRed-Image-Edit-Lightning-8steps-v1.0.safetensors",
    adapter_name="lightning",
)
pipe.set_adapters(["lightning"], adapter_weights=[1.0])  # adapter at full strength
log_mem("After LoRA load")

print(f"\nTotal load time: {time.time() - t0:.1f}s")

# ── Warmup ──
# One throwaway generation so one-time costs (CUDA context, kernel/module
# initialization, allocator growth) don't land in the timed runs below.
print("\nWarmup run...")
dummy = Image.new("RGB", (512, 512), (128, 128, 128))  # flat gray placeholder
gen = torch.Generator(device="cuda").manual_seed(0)
with torch.inference_mode():
    pipe(
        image=[dummy],
        prompt="test",
        negative_prompt=" ",
        # NOTE(review): LoRA weight file is named "8steps" but warmup uses 4
        # and the benchmark uses 6 — confirm the step counts are intentional.
        num_inference_steps=4,
        true_cfg_scale=1.0,   # true-CFG disabled
        guidance_scale=1.0,
        generator=gen,
        num_images_per_prompt=1,
    )
# Return warmup blocks to the driver so the next reading reflects steady state.
torch.cuda.empty_cache()
log_mem("After warmup + cache clear")

# ── Benchmark batch sizes ──
# Probe batch sizes upward until the first failure, reporting wall time,
# peak VRAM and throughput for each size that fits.
print("\n" + "=" * 70)
print("BATCH SIZE BENCHMARK (512x512 images, 6 steps, cfg=1.0)")
print("=" * 70)

test_image = Image.new("RGB", (512, 512), (100, 150, 200))
result = None  # previous batch's pipeline output; released at each loop top

for batch_size in [1, 2, 3, 4, 5, 6, 7, 8]:
    # Drop the previous batch's output BEFORE measuring. Keeping it referenced
    # through the next run would inflate that run's peak-VRAM reading and
    # could trigger a spurious OOM at a batch size that actually fits.
    result = None
    # gc.collect() must run before empty_cache() so blocks owned by
    # just-collected tensors are actually returned to the caching allocator
    # (the original order made the flush less effective).
    gc.collect()
    torch.cuda.empty_cache()
    torch.cuda.reset_peak_memory_stats()

    images = [test_image] * batch_size
    # Fixed seed per batch size so runs are comparable across sizes.
    gen = torch.Generator(device="cuda").manual_seed(42)

    try:
        t0 = time.perf_counter()
        with torch.inference_mode():
            result = pipe(
                image=images,
                prompt="Change the background to a sunset",
                negative_prompt=" ",
                num_inference_steps=6,
                true_cfg_scale=1.0,  # true-CFG disabled
                guidance_scale=1.0,
                generator=gen,
                num_images_per_prompt=1,
            )
        # pipe() returns decoded images, so the GPU work has completed by here
        # and no explicit synchronize is needed for timing.
        elapsed = time.perf_counter() - t0
        peak = torch.cuda.max_memory_allocated() / 1e9

        print(f"\nBatch {batch_size}: ✓ SUCCESS")
        print(f"  Time: {elapsed:.2f}s ({elapsed/batch_size:.2f}s/image)")
        print(f"  Peak VRAM: {peak:.2f} GB")
        print(f"  Throughput: {batch_size/elapsed:.2f} images/s")

    except torch.cuda.OutOfMemoryError:
        # Record how far allocation got before failing, then stop probing —
        # every larger batch would OOM as well.
        peak = torch.cuda.max_memory_allocated() / 1e9
        print(f"\nBatch {batch_size}: ✗ OOM (peak was {peak:.2f} GB)")
        gc.collect()
        torch.cuda.empty_cache()
        break

    except Exception as e:
        # Any non-OOM failure also ends the sweep.
        print(f"\nBatch {batch_size}: ✗ ERROR: {e}")
        gc.collect()
        torch.cuda.empty_cache()
        break

print("\n" + "=" * 70)
print("DONE")
print("=" * 70)
