"""
Shared cache manager for Gemini explicit Context Caching.
Creates one cache with V2 system prompt, shared by all workers.
TTL: 6 days (518400s). First worker to start creates it, rest reuse.
"""
from __future__ import annotations

import logging
from typing import Optional

import httpx

from .config import GEMINI_MODEL
from .prompt_builder import get_cacheable_system_prompt

logger = logging.getLogger(__name__)

AISTUDIO_BASE = "https://generativelanguage.googleapis.com/v1beta"
DEFAULT_CACHE_TTL_S = 518400  # 6 days


class CacheManager:
    """Manages one shared Gemini explicit content cache per API key.

    The cache holds the large, static V2 system prompt so requests can
    reference it instead of resending it. The first worker to call
    :meth:`ensure_cache` creates the cache; later workers find and reuse
    the same one via the list endpoint.
    """

    def __init__(self, api_key: str):
        # NOTE(review): the key is sent as a ?key= query parameter in every
        # request below; query strings can end up in proxy/server access
        # logs. The `x-goog-api-key` header avoids that — confirm before
        # switching.
        self.api_key = api_key
        # Full resource name ("cachedContents/..."), populated once a cache
        # has been found or created.
        self.cache_name: Optional[str] = None

    async def ensure_cache(self, ttl_s: int = DEFAULT_CACHE_TTL_S) -> str:
        """Return the shared cache name, creating the cache if needed.

        Resolution order: the name already resolved on this instance, then
        any existing server-side cache for our model, then a freshly
        created cache.

        Args:
            ttl_s: TTL for a newly created cache, in seconds. Ignored when
                an existing cache is reused.

        Returns:
            The cache resource name, e.g. "cachedContents/abc123".

        Raises:
            RuntimeError: if a new cache has to be created and the API
                call fails.
        """
        # Fast path: already resolved in this process — skip the redundant
        # list round-trip. (Assumes the cache outlives the process; the
        # default TTL is 6 days.)
        if self.cache_name:
            return self.cache_name

        existing = await self._find_existing_cache()
        if existing:
            self.cache_name = existing
            logger.info("Reusing existing cache: %s", self.cache_name)
            return self.cache_name

        self.cache_name = await self._create_cache(ttl_s)
        logger.info("Created new cache: %s (TTL=%ss)", self.cache_name, ttl_s)
        return self.cache_name

    async def _find_existing_cache(self) -> Optional[str]:
        """Return the name of an existing cache for our model, or None.

        Any listing failure is treated as "no cache" so the caller can
        fall back to creating one. NOTE(review): entries are not checked
        for imminent expiry (`expireTime`) — presumably the API only lists
        live caches; verify.
        """
        url = f"{AISTUDIO_BASE}/cachedContents?key={self.api_key}"
        async with httpx.AsyncClient(timeout=30.0) as client:
            resp = await client.get(url)
            if resp.status_code != 200:
                logger.warning(
                    "Failed to list caches: %s %s", resp.status_code, resp.text[:200]
                )
                return None

            model_full = f"models/{GEMINI_MODEL}"
            for cache in resp.json().get("cachedContents", []):
                if cache.get("model") == model_full:
                    name = cache["name"]
                    logger.info("Found existing cache: %s (model=%s)", name, model_full)
                    return name
        return None

    async def _create_cache(self, ttl_s: int) -> str:
        """Create an explicit cache holding the V2 system prompt.

        Args:
            ttl_s: Cache time-to-live in seconds.

        Returns:
            The new cache's resource name.

        Raises:
            RuntimeError: on any non-200 response from the API.
        """
        url = f"{AISTUDIO_BASE}/cachedContents?key={self.api_key}"
        body = {
            "model": f"models/{GEMINI_MODEL}",
            "systemInstruction": {
                "parts": [{"text": get_cacheable_system_prompt()}]
            },
            # The API expects a duration string like "518400s".
            "ttl": f"{ttl_s}s",
        }
        async with httpx.AsyncClient(timeout=30.0) as client:
            resp = await client.post(url, json=body)
            if resp.status_code != 200:
                raise RuntimeError(f"Cache creation failed: {resp.status_code} {resp.text[:500]}")
            data = resp.json()
            name = data["name"]
            token_count = data.get("usageMetadata", {}).get("totalTokenCount", "unknown")
            logger.info("Cache created: %s, tokens cached: %s", name, token_count)
            return name

    async def get_cache_info(self) -> Optional[dict]:
        """Fetch metadata for the current cache (expiry, token counts).

        Returns None when no cache has been resolved yet, or when the
        lookup fails (best-effort: failures are not raised).
        """
        if not self.cache_name:
            return None
        url = f"{AISTUDIO_BASE}/{self.cache_name}?key={self.api_key}"
        async with httpx.AsyncClient(timeout=30.0) as client:
            resp = await client.get(url)
            if resp.status_code == 200:
                return resp.json()
        return None

    async def delete_all_caches(self) -> None:
        """Delete every cache for this API key (cleanup utility).

        Best-effort: a failed listing is logged and skipped, and
        individual delete results are only logged. Clears
        ``self.cache_name`` afterwards so a deleted cache is never
        referenced again.
        """
        url = f"{AISTUDIO_BASE}/cachedContents?key={self.api_key}"
        async with httpx.AsyncClient(timeout=30.0) as client:
            resp = await client.get(url)
            if resp.status_code != 200:
                # Previously swallowed silently; leave a trace for operators.
                logger.warning("Failed to list caches for deletion: %s", resp.status_code)
                return
            for cache in resp.json().get("cachedContents", []):
                name = cache["name"]
                del_resp = await client.delete(f"{AISTUDIO_BASE}/{name}?key={self.api_key}")
                logger.info("Deleted cache %s: %s", name, del_resp.status_code)
        # All caches (including ours, if any) are gone — drop the stale name.
        self.cache_name = None
