"""
Issue: cache_control on system prompt is ignored when
user message contains audio (input_audio). Explicit caching never triggers
cache_write_tokens or cached_tokens for multimodal/audio requests, despite
working correctly for text-only requests.

REPRO: Set OPENROUTER_API_KEY env var and run:
    python openrouter_cache_bug_repro.py

Tested models: google/gemini-2.5-flash, google/gemini-3-flash-preview
Result: text-only caching works (cache_write > 0 on req1, cached > 0 on req2+).
        audio caching always returns cache_write=0, cached=0.
"""
import base64
import json
import os
import time
from pathlib import Path

import httpx
from dotenv import load_dotenv

# Load variables from a local .env file (python-dotenv) before the key is read below.
load_dotenv()

# OpenRouter credential and the chat-completions endpoint used by every request.
API_KEY = os.getenv("OPENROUTER_API_KEY")
API_URL = "https://openrouter.ai/api/v1/chat/completions"

# The system prompt used in testing is ~1037 tokens, which exceeds the 1028-token
# caching minimum for Gemini 2.5 Flash / Gemini 3 Flash Preview.
# NOTE: the literal below is a blank placeholder (a single space). Paste the real
# prompt here before running; as written, the prompt is far below that minimum.
SYSTEM_PROMPT = """ """

# Headers shared by all requests. Built at import time: if the key is missing this
# holds "Bearer None", which is why main() checks API_KEY before sending anything.
HEADERS = {
    "Authorization": f"Bearer {API_KEY}",
    "Content-Type": "application/json",
}


def fmt(data: dict) -> str:
    """Format the token-usage counters of a chat-completion response on one line.

    Args:
        data: Decoded JSON response. The ``usage`` object and its nested
            ``prompt_tokens_details`` object may each be missing or ``null``.

    Returns:
        A string of the form
        ``prompt=<n>, cached=<n>, cache_write=<n>, output=<n>``; counters
        that are absent from the response are shown as ``0``.
    """
    # `.get(key, {})` only covers a missing key; `or {}` also covers an
    # explicit JSON null, which would otherwise raise AttributeError below.
    usage = data.get("usage") or {}
    details = usage.get("prompt_tokens_details") or {}
    return (
        f"prompt={usage.get('prompt_tokens', 0)}, "
        f"cached={details.get('cached_tokens', 0)}, "
        f"cache_write={details.get('cache_write_tokens', 0)}, "
        f"output={usage.get('completion_tokens', 0)}"
    )


def send(body: dict, label: str) -> dict:
    """POST one chat-completion request and print a one-line report for it.

    Failures are reported and returned rather than raised, so that one bad
    request does not abort the remaining requests of the run.

    Args:
        body: JSON-serialisable request payload.
        label: Prefix identifying this request in the printed output.

    Returns:
        The decoded JSON response on success. Otherwise ``{"error": value}``,
        where ``value`` is the HTTP status code for a non-200 response, the
        API's own error object for a 200 response that carries one, or a
        short description for transport and decoding failures.
    """
    start = time.monotonic()
    try:
        resp = httpx.post(API_URL, json=body, headers=HEADERS, timeout=60)
    except httpx.HTTPError as exc:
        # Timeouts and connection failures: report them like any other
        # failed request instead of crashing the whole run.
        elapsed = (time.monotonic() - start) * 1000
        print(f"  {label}: {type(exc).__name__} ({elapsed:.0f}ms)")
        print(f"    {exc}")
        return {"error": f"{type(exc).__name__}: {exc}"}
    elapsed = (time.monotonic() - start) * 1000

    if resp.status_code != 200:
        print(f"  {label}: HTTP {resp.status_code} ({elapsed:.0f}ms)")
        print(f"    {resp.text[:300]}")
        return {"error": resp.status_code}

    try:
        data = resp.json()
    except ValueError:
        print(f"  {label}: HTTP 200 but body is not JSON ({elapsed:.0f}ms)")
        print(f"    {resp.text[:300]}")
        return {"error": "invalid JSON"}

    # A 200 response can still carry an error object. Printing its (absent)
    # usage as zeros would be indistinguishable from a genuine cache miss.
    api_error = data.get("error") if isinstance(data, dict) else "unexpected response shape"
    if api_error:
        print(f"  {label}: HTTP 200 with error payload ({elapsed:.0f}ms)")
        print(f"    {str(api_error)[:300]}")
        return {"error": api_error}

    req_id = data.get("id", "?")
    print(f"  {label}: {fmt(data)} | {elapsed:.0f}ms | id={req_id}")
    return data


def find_audio_file() -> Path | None:
    """Look for any FLAC file in common test locations."""
    candidates = [
        Path("preflight/canary_data/en/nM2KMwb86IU/nM2KMwb86IU/segments"),
        Path("test_audio"),
        Path("."),
    ]
    for d in candidates:
        if d.exists():
            flacs = sorted(d.glob("*.flac"))
            if flacs:
                return flacs[0]
    return None


def build_system_msg():
    """Return the system message, with its text part marked as a cache breakpoint.

    The message content is a one-element list holding SYSTEM_PROMPT as a text
    part that carries an ``ephemeral`` ``cache_control`` marker.
    """
    cached_part = dict(
        type="text",
        text=SYSTEM_PROMPT,
        cache_control=dict(type="ephemeral"),
    )
    return dict(role="system", content=[cached_part])


def _chat_body(model: str, user_content: str | list, *, pin_provider: bool) -> dict:
    """Build a request body: cached system message first, then one user message."""
    body = {
        "model": model,
        "messages": [
            build_system_msg(),
            {"role": "user", "content": user_content},
        ],
        "temperature": 0,
    }
    if pin_provider:
        body["provider"] = {"order": ["Google AI Studio"], "allow_fallbacks": False}
    return body


def _run_series(label: str, body: dict, repeats: int = 3, pause_s: float = 2) -> list[dict]:
    """Send the same body ``repeats`` times, pausing after each, and collect the results."""
    results = []
    for attempt in range(1, repeats + 1):
        results.append(send(body, f"  {label}-req{attempt}"))
        time.sleep(pause_s)
    return results


def _cache_totals(results: list[dict]) -> str:
    """Sum the cache counters over one series of results for the summary line."""
    writes = hits = failed = 0
    for data in results:
        # send() returns a dict with an "error" key for any failed request.
        if "error" in data:
            failed += 1
            continue
        usage = data.get("usage") or {}
        details = usage.get("prompt_tokens_details") or {}
        writes += details.get("cache_write_tokens") or 0
        hits += details.get("cached_tokens") or 0
    return f"cache_write={writes}, cached={hits}, failed={failed}/{len(results)}"


def main():
    """Run the text-only control and the audio caching tests for each model.

    For every model this sends three identical requests per test:
      1. text-only user message, provider pinned to Google AI Studio (control);
      2. audio + text user message, provider pinned to Google AI Studio;
      3. audio + text user message, default routing.
    The audio tests are skipped when find_audio_file() finds no FLAC file.
    Per-request usage is printed by send(); the totals per series are printed
    in a closing summary.

    Raises:
        SystemExit: If OPENROUTER_API_KEY is not set.
    """
    # Explicit check rather than `assert`, which is stripped under `python -O`.
    if not API_KEY:
        raise SystemExit("Set OPENROUTER_API_KEY")

    # The audio payload is identical for every model, so locate and encode it once.
    audio_path = find_audio_file()
    audio_b64 = ""
    audio_content = None
    if audio_path:
        audio_b64 = base64.b64encode(audio_path.read_bytes()).decode()
        audio_content = [
            {
                "type": "input_audio",
                "input_audio": {"data": audio_b64, "format": "flac"},
            },
            {"type": "text", "text": "Transcribe this audio."},
        ]

    rule = "=" * 72
    summary: list[tuple[str, str, str]] = []

    for model in ["google/gemini-2.5-flash", "google/gemini-3-flash-preview"]:
        print(f"\n{rule}")
        print(f"MODEL: {model}")
        print(rule)

        # ── TEST 1: TEXT-ONLY (control) ──────────────────────────────────
        print("\n  ── TEST 1: Text-only + cache_control (3 identical requests) ──")
        print("  Expected: cache_write>0 on req1, cached>0 on req2+")
        text_body = _chat_body(model, "Say hello in 5 words.", pin_provider=True)
        summary.append((model, "text (pinned)", _cache_totals(_run_series("text", text_body))))

        if audio_content is None:
            print("\n  ── SKIPPING audio tests: no .flac file found ──")
            # Name locations that find_audio_file() actually searches.
            print("  Place a FLAC file in ./test_audio/ or the current directory")
            continue

        # ── TEST 2: AUDIO — SAME FILE 3x ────────────────────────────────
        print("\n  ── TEST 2: Same audio 3x + cache_control (pinned AI Studio) ──")
        # The size shown is that of the base64-encoded payload, not the file on disk.
        print(f"  Audio: {audio_path.name} ({len(audio_b64)//1024}KB)")
        print("  Expected: cache_write>0 on req1, cached>0 on req2+")
        print("  ACTUAL BUG: cache_write=0 and cached=0 on ALL requests")
        pinned_body = _chat_body(model, audio_content, pin_provider=True)
        summary.append((model, "audio (pinned)", _cache_totals(_run_series("audio", pinned_body))))

        # ── TEST 3: AUDIO — NO PROVIDER PIN ──────────────────────────────
        print("\n  ── TEST 3: Same audio 3x + cache_control (default routing) ──")
        default_body = _chat_body(model, audio_content, pin_provider=False)
        # Distinct label so TEST 3 lines are not confused with TEST 2 lines.
        summary.append(
            (model, "audio (default routing)", _cache_totals(_run_series("audio-default", default_body)))
        )

    print(f"\n{rule}")
    print("SUMMARY")
    print(rule)
    for model, series, totals in summary:
        print(f"  {model} | {series}: {totals}")

# Run the reproduction only when executed as a script, not when imported.
if __name__ == "__main__":
    main()
