"""Fine-grained benchmark: find max concurrency within 15s on A100."""

import gc
import time
import torch
from PIL import Image

from diffusers import QwenImageEditPlusPipeline

print("Loading model...")
# Base FireRed image-edit checkpoint, loaded in bfloat16 and moved to the GPU.
# DiffusionPipeline.to() returns the pipeline itself, so rebinding is safe.
pipe = QwenImageEditPlusPipeline.from_pretrained(
    "FireRedTeam/FireRed-Image-Edit-1.1", torch_dtype=torch.bfloat16
)
pipe = pipe.to("cuda")

# Lightning LoRA (8-step variant, per its file name), registered under the
# adapter name "lightning" and then activated at full weight.
_lora_kwargs = dict(
    weight_name="FireRed-Image-Edit-Lightning-8steps-v1.0.safetensors",
    adapter_name="lightning",
)
pipe.load_lora_weights("FireRedTeam/FireRed-Image-Edit-LoRA-Zoo", **_lora_kwargs)
pipe.set_adapters(["lightning"], adapter_weights=[1.0])

# Warmup: run one throwaway edit so first-call overheads are not billed to the
# first timed run in the benchmark loop.
dummy = Image.new("RGB", (512, 512), (128, 128, 128))
_warmup_gen = torch.Generator(device="cuda").manual_seed(0)
with torch.inference_mode():
    pipe(
        image=[dummy],
        prompt="test",
        negative_prompt=" ",
        num_inference_steps=4,
        true_cfg_scale=1.0,
        guidance_scale=1.0,
        generator=_warmup_gen,
        num_images_per_prompt=1,
    )
# Return cached-but-unused allocator blocks before benchmarking starts.
torch.cuda.empty_cache()

# Banner describing the benchmark goal.
_rule = "=" * 70
print(f"\n{_rule}")
print("TARGET: each image must complete within 15 seconds")
print("Testing different resolutions x batch sizes x steps")
print(_rule)

# Benchmark matrix of (resolution, steps, label) tuples: every resolution is
# tried at 6 steps first, then at 4 steps.
configs = [
    (res, steps, f"{res}x{res}, {steps} steps")
    for res in (512, 768, 1024)
    for steps in (6, 4)
]

# Time budget per run. All images in a batch are produced by one call and
# finish together, so a run passes only if the whole call fits in the budget.
TIME_BUDGET_S = 15.0
# Largest batch size tried per config (batch sizes 1..MAX_BATCH).
MAX_BATCH = 8

for res, steps, label in configs:
    print(f"\n{'─'*70}")
    print(f"Config: {label}")
    print(f"{'─'*70}")

    test_image = Image.new("RGB", (res, res), (100, 150, 200))

    for batch_size in range(1, MAX_BATCH + 1):
        # Drop Python references first, then the allocator cache can actually
        # hand the freed blocks back to the device.
        gc.collect()
        torch.cuda.empty_cache()
        # max_memory_allocated() is a high-water mark since the last reset.
        # Without this reset, every row would report the worst earlier run's
        # peak instead of its own.
        torch.cuda.reset_peak_memory_stats()

        gen = torch.Generator(device="cuda").manual_seed(42)

        # CUDA kernels launch asynchronously. Synchronize on both sides so the
        # wall-clock interval covers exactly this call's GPU work.
        torch.cuda.synchronize()
        t0 = time.perf_counter()
        try:
            with torch.inference_mode():
                result = pipe(
                    # A list passed as `image` is consumed by the Plus pipeline
                    # as several reference images for ONE edit. Repeating the
                    # image does not batch outputs (see the image count printed
                    # below). Batch via num_images_per_prompt instead.
                    image=[test_image],
                    prompt="Change the background to a sunset",
                    negative_prompt=" ",
                    # Set the output size explicitly so it follows `res`.
                    # Otherwise the pipeline picks its own default output size.
                    # NOTE(review): that default appears to be ~1 MP regardless
                    # of input size; confirm for the diffusers version in use.
                    height=res,
                    width=res,
                    num_inference_steps=steps,
                    true_cfg_scale=1.0,
                    guidance_scale=1.0,
                    generator=gen,
                    num_images_per_prompt=batch_size,
                )
        except torch.cuda.OutOfMemoryError:
            oom = True
        else:
            oom = False
            torch.cuda.synchronize()
            elapsed = time.perf_counter() - t0

        if oom:
            # Clean up after the handler has exited. By then the handled
            # exception, whose traceback references the failed call's frames,
            # has been released.
            gc.collect()
            torch.cuda.empty_cache()
            print(f"  Batch {batch_size}: OOM")
            print(f"  → Max concurrency for {label} = {batch_size - 1}")
            break

        peak = torch.cuda.max_memory_allocated() / 1e9
        n_out = len(result.images)  # sanity check: should equal batch_size
        del result

        within_budget = elapsed <= TIME_BUDGET_S
        status = "✓" if within_budget else "✗"
        print(
            f"  Batch {batch_size}: {elapsed:.1f}s  "
            f"({n_out} images, peak {peak:.1f}GB)  {status}"
        )

        if not within_budget:
            print(f"  → Max concurrency for {label} = {batch_size - 1}")
            break
    else:
        print(f"  → All batches up to {MAX_BATCH} fit within {TIME_BUDGET_S:g}s!")

# Closing banner: a blank line, a 70-char rule, then the DONE marker.
print("\n" + "=" * 70, "DONE", sep="\n")
