import base64
import io
import time
import torch
from PIL import Image, ImageOps


class Model:
    """Serves FireRed-Image-Edit-1.1 (a Qwen image-edit pipeline) with a
    Lightning LoRA for few-step text-guided image editing.

    Lifecycle: construct, call ``load()`` once (downloads weights, moves the
    pipeline to CUDA, attaches the LoRA, runs a warmup pass), then call
    ``predict(request)`` per request.
    """

    def __init__(self, **kwargs):
        # Populated by load(); predict() guards against it still being None.
        self._pipe = None
        # Deployment secrets (e.g. HF token) injected by the serving framework.
        self._secrets = kwargs.get("secrets", {})

    def load(self):
        """Load the base pipeline and Lightning LoRA onto CUDA, then warm up.

        Side effects: sets ``self._pipe``; prints timing to stdout. The first
        pipeline call pays one-time compilation/caching costs, so a dummy
        4-step run is executed here rather than on the first user request.
        """
        from diffusers import QwenImageEditPlusPipeline

        print("Loading FireRed-Edit-1.1...")
        t0 = time.time()

        self._pipe = QwenImageEditPlusPipeline.from_pretrained(
            "FireRedTeam/FireRed-Image-Edit-1.1",
            torch_dtype=torch.bfloat16,
        ).to("cuda")

        print(f"Base model loaded in {time.time() - t0:.1f}s")

        # The Lightning LoRA enables low-step (4-8) inference.
        print("Loading Lightning LoRA...")
        self._pipe.load_lora_weights(
            "FireRedTeam/FireRed-Image-Edit-LoRA-Zoo",
            weight_name="FireRed-Image-Edit-Lightning-8steps-v1.0.safetensors",
            adapter_name="lightning",
        )
        self._pipe.set_adapters(["lightning"], adapter_weights=[1.0])
        print(f"Lightning LoRA loaded. Total: {time.time() - t0:.1f}s")

        # Warmup with a gray 512x512 dummy image and a fixed seed.
        dummy = Image.new("RGB", (512, 512), (128, 128, 128))
        gen = torch.Generator(device="cuda").manual_seed(0)
        with torch.inference_mode():
            self._pipe(
                image=[dummy],
                prompt="test",
                negative_prompt=" ",
                num_inference_steps=4,
                true_cfg_scale=1.0,
                guidance_scale=1.0,
                generator=gen,
                num_images_per_prompt=1,
            )
        torch.cuda.empty_cache()
        print(f"Warmup done. Ready in {time.time() - t0:.1f}s")

    def predict(self, request):
        """Edit an image according to a text prompt.

        Args:
            request: dict with keys:
                image (str): base64-encoded input image (required)
                prompt (str): edit instruction (required)
                steps (int): diffusion steps, default 6
                cfg (float): true CFG scale, default 1.0
                seed (int): RNG seed, default 42

        Returns:
            On success, a dict with ``status="success"``, the base64 PNG
            ``image``, wall-clock ``time_s``, the effective ``config``, and
            input/output sizes. On any invalid input (or if ``load()`` has
            not been called), a dict with ``status="error"`` and a message —
            predict never raises for bad payloads.
        """
        t0 = time.perf_counter()

        # Parse input with documented defaults.
        image_b64 = request.get("image")
        prompt = request.get("prompt", "")
        steps = request.get("steps", 6)
        cfg = request.get("cfg", 1.0)
        seed = request.get("seed", 42)

        if not image_b64 or not prompt:
            return {"status": "error", "message": "Missing 'image' or 'prompt'"}

        # Guard: calling predict() before load() used to crash with an opaque
        # TypeError on None; surface it as a structured error instead.
        if self._pipe is None:
            return {"status": "error", "message": "Model not loaded; call load() first"}

        # BUGFIX: a malformed base64 payload or corrupt image used to raise out
        # of predict(), inconsistent with the structured error above. b64decode
        # raises binascii.Error (a ValueError subclass) on bad padding;
        # Image.open raises UnidentifiedImageError (an OSError subclass).
        try:
            image_bytes = base64.b64decode(image_b64)
            image = Image.open(io.BytesIO(image_bytes))
            # BUGFIX: apply the EXIF orientation BEFORE convert("RGB") —
            # converting first yields a new image whose EXIF orientation tag
            # may be dropped, turning the transpose into a no-op.
            image = ImageOps.exif_transpose(image).convert("RGB")
        except (ValueError, OSError) as e:
            return {"status": "error", "message": f"Invalid image payload: {e}"}

        # Run inference with a per-request seeded CUDA generator for
        # reproducibility at a fixed seed.
        gen = torch.Generator(device="cuda").manual_seed(seed)
        with torch.inference_mode():
            result = self._pipe(
                image=[image],
                prompt=prompt,
                negative_prompt=" ",
                num_inference_steps=steps,
                true_cfg_scale=cfg,
                guidance_scale=1.0,
                generator=gen,
                num_images_per_prompt=1,
            )

        output_image = result.images[0]
        elapsed = time.perf_counter() - t0

        # Encode output as base64 PNG.
        buf = io.BytesIO()
        output_image.save(buf, format="PNG")
        output_b64 = base64.b64encode(buf.getvalue()).decode("utf-8")

        return {
            "status": "success",
            "image": output_b64,
            "time_s": round(elapsed, 2),
            "config": {"steps": steps, "cfg": cfg, "seed": seed},
            "input_size": f"{image.size[0]}x{image.size[1]}",
            "output_size": f"{output_image.size[0]}x{output_image.size[1]}",
        }
