"""
Gemini client for transcript variant generation.

This is the production counterpart to the prompt test harness. It reuses the
same cacheable prompt and schema, but exposes a worker-friendly API.
"""
from __future__ import annotations

import asyncio
import hashlib
import json
import logging
from dataclasses import dataclass
from typing import Any, Optional

import httpx

from .config import GEMINI_MODEL, THINKING_LEVEL
from .providers.base import TokenUsage
from .transcript_variant_prompt import (
    TranscriptVariantBatchResult,
    build_transcript_variant_user_prompt,
    get_cacheable_transcript_variant_prompt,
    get_transcript_variant_json_schema,
)


AISTUDIO_BASE = "https://generativelanguage.googleapis.com/v1beta"


def get_variant_cache_display_name() -> str:
    prompt_hash = hashlib.sha1(
        get_cacheable_transcript_variant_prompt().encode("utf-8")
    ).hexdigest()[:12]
    return f"transcript-variant-{prompt_hash}"


@dataclass
class VariantBatchResult:
    """Outcome of one variant-generation batch call."""
    # One plain dict per generated result item (schema-validated, then dumped).
    items: list[dict[str, Any]]
    # Token accounting extracted from the response's usageMetadata.
    token_usage: TokenUsage
    # The full decoded JSON response body, kept verbatim.
    raw_response: dict[str, Any]


class TranscriptVariantCacheManager:
    """Manages the explicit Gemini context cache for the variant system prompt.

    The cacheable prompt is stored server-side under a content-hash display
    name, so repeated generation requests only pay for the per-batch user
    prompt tokens.
    """

    def __init__(self, api_key: str):
        self.api_key = api_key

    async def ensure_cache(self, ttl_s: int) -> dict[str, Any]:
        """Return a matching existing cache, or create one with the given TTL.

        Args:
            ttl_s: Time-to-live in seconds for a newly created cache.
        """
        existing = await self._find_existing_cache()
        if existing:
            return existing
        return await self._create_cache(ttl_s)

    async def _find_existing_cache(self) -> dict[str, Any] | None:
        """Find a cache matching the current model and prompt hash, or None.

        Follows nextPageToken pagination so a match is not missed when the
        project holds more cached contents than fit in a single list page.
        Lookup is best-effort: any non-200 list response yields None, which
        makes the caller create a fresh cache.
        """
        base_url = f"{AISTUDIO_BASE}/cachedContents?key={self.api_key}"
        display_name = get_variant_cache_display_name()
        model_name = f"models/{GEMINI_MODEL}"
        async with httpx.AsyncClient(timeout=30.0) as client:
            page_token: str | None = None
            while True:
                url = base_url if page_token is None else f"{base_url}&pageToken={page_token}"
                resp = await client.get(url)
                if resp.status_code != 200:
                    return None
                payload = resp.json()
                for cache in payload.get("cachedContents", []):
                    if cache.get("model") == model_name and cache.get("displayName") == display_name:
                        # The list endpoint can omit detail fields; prefer the full GET.
                        detailed = await self.get_cache_info(cache["name"])
                        return detailed or cache
                page_token = payload.get("nextPageToken")
                if not page_token:
                    return None

    async def _create_cache(self, ttl_s: int) -> dict[str, Any]:
        """Create a new cached-content entry holding the system instruction.

        Raises:
            RuntimeError: if the API rejects the creation request.
        """
        url = f"{AISTUDIO_BASE}/cachedContents?key={self.api_key}"
        body = {
            "model": f"models/{GEMINI_MODEL}",
            "displayName": get_variant_cache_display_name(),
            "systemInstruction": {
                "parts": [{"text": get_cacheable_transcript_variant_prompt()}]
            },
            "ttl": f"{ttl_s}s",
        }
        async with httpx.AsyncClient(timeout=60.0) as client:
            resp = await client.post(url, json=body)
            if resp.status_code != 200:
                raise RuntimeError(f"Cache creation failed: {resp.status_code} {resp.text[:500]}")
            return resp.json()

    async def get_cache_info(self, cache_name: str) -> dict[str, Any] | None:
        """Fetch full metadata for one cache by resource name, or None on failure."""
        url = f"{AISTUDIO_BASE}/{cache_name}?key={self.api_key}"
        async with httpx.AsyncClient(timeout=30.0) as client:
            resp = await client.get(url)
            if resp.status_code == 200:
                return resp.json()
        return None


class GeminiRetryableError(Exception):
    """Transient Gemini failure (429, 5xx, timeout, empty response); safe to retry."""


class GeminiPermanentError(Exception):
    """Non-retryable Gemini failure (400 bad request, schema issues); fail fast."""


class TranscriptVariantClient:
    """Async Gemini client for transcript-variant batch generation.

    Sends one structured-output generateContent request per batch and retries
    transient failures (429/5xx/timeouts/empty responses) on a fixed backoff
    schedule. Call ``close()`` to release the pooled HTTP connections.
    """

    MAX_RETRIES = 4
    # Seconds to sleep before retries 2, 3, ...; the final value is reused if
    # MAX_RETRIES ever exceeds the schedule length.
    RETRY_BACKOFF = (5, 10, 20, 40)

    # Shared logger; avoids re-importing/re-resolving inside the retry loop.
    _log = logging.getLogger(__name__)

    def __init__(self, api_key: str, cache_name: Optional[str]):
        """
        Args:
            api_key: AI Studio API key (sent as a URL query parameter).
            cache_name: Resource name of an explicit context cache to attach,
                or None to send the system instruction inline on every call.
        """
        self.api_key = api_key
        self.cache_name = cache_name
        self.schema = get_transcript_variant_json_schema()
        self._http = httpx.AsyncClient(
            timeout=httpx.Timeout(connect=10.0, read=60.0, write=60.0, pool=20.0),
            limits=httpx.Limits(max_connections=100, max_keepalive_connections=50),
        )

    async def close(self):
        """Close the underlying HTTP connection pool."""
        await self._http.aclose()

    async def generate_batch(self, items: list[dict[str, Any]]) -> VariantBatchResult:
        """Generate variants for one batch of items, retrying transient errors.

        Raises:
            GeminiPermanentError: non-retryable failure (bad request, safety block).
            GeminiRetryableError: all MAX_RETRIES attempts failed transiently.
        """
        url = f"{AISTUDIO_BASE}/models/{GEMINI_MODEL}:generateContent?key={self.api_key}"
        body = self._build_request_body(items)

        last_error: Exception | None = None
        for attempt in range(self.MAX_RETRIES):
            try:
                try:
                    resp = await self._http.post(url, json=body)
                except httpx.TimeoutException as exc:
                    # Funnel timeouts into the single retryable path below.
                    raise GeminiRetryableError(f"Timeout: {exc}") from exc
                raw_response = self._handle_http_response(resp, attempt)
                return self._parse_response(raw_response)
            except GeminiPermanentError:
                raise
            except GeminiRetryableError as exc:
                last_error = exc
                if attempt + 1 >= self.MAX_RETRIES:
                    break  # no retry left; don't burn the final backoff sleeping
                wait = self.RETRY_BACKOFF[min(attempt, len(self.RETRY_BACKOFF) - 1)]
                self._log.warning(
                    "Gemini retryable error (attempt %s/%s): %s, retry in %ss",
                    attempt + 1, self.MAX_RETRIES, str(exc)[:200], wait,
                )
                await asyncio.sleep(wait)

        raise GeminiRetryableError(f"Exhausted {self.MAX_RETRIES} retries: {last_error}")

    def _build_request_body(self, items: list[dict[str, Any]]) -> dict[str, Any]:
        """Build the generateContent request body for one batch of items."""
        body: dict[str, Any] = {
            "contents": [
                {
                    "role": "user",
                    "parts": [{"text": build_transcript_variant_user_prompt(items)}],
                }
            ],
            "generationConfig": {
                # Deterministic output with schema-constrained JSON.
                "temperature": 0,
                "responseMimeType": "application/json",
                "responseJsonSchema": self.schema,
                "thinkingConfig": {
                    "thinkingLevel": THINKING_LEVEL.upper(),
                },
            },
        }
        if self.cache_name:
            # The system instruction lives server-side in the cached content.
            body["cachedContent"] = self.cache_name
        else:
            body["systemInstruction"] = {
                "parts": [{"text": get_cacheable_transcript_variant_prompt()}]
            }
        return body

    def _handle_http_response(self, resp: httpx.Response, attempt: int) -> dict[str, Any]:
        """Classify the HTTP status and decode the body, or raise.

        Raises:
            GeminiRetryableError: 429, 5xx, other unexpected statuses,
                RESOURCE_EXHAUSTED 400s, or a 200 with an undecodable body.
            GeminiPermanentError: any other 400 (bad request / schema issue).
        """
        if resp.status_code == 429:
            raise GeminiRetryableError(f"429 rate limited: {resp.text[:200]}")
        if resp.status_code >= 500:
            raise GeminiRetryableError(f"HTTP {resp.status_code}: {resp.text[:200]}")
        if resp.status_code == 400:
            body_text = resp.text[:500]
            # Some quota exhaustion surfaces as 400 RESOURCE_EXHAUSTED; retry those.
            if "RESOURCE_EXHAUSTED" in body_text:
                raise GeminiRetryableError(f"RESOURCE_EXHAUSTED: {body_text}")
            raise GeminiPermanentError(f"HTTP 400: {body_text}")
        if resp.status_code != 200:
            raise GeminiRetryableError(f"HTTP {resp.status_code}: {resp.text[:200]}")
        try:
            return resp.json()
        except ValueError as exc:
            # A truncated/garbled 200 body is transient; retry instead of crashing.
            raise GeminiRetryableError(f"Invalid JSON body: {resp.text[:200]}") from exc

    def _parse_response(self, raw_response: dict[str, Any]) -> VariantBatchResult:
        """Extract, schema-validate, and package the model output.

        Raises:
            GeminiRetryableError: missing candidates, empty text, or schema
                validation failure (typically transient model misbehavior).
            GeminiPermanentError: response blocked for safety.
        """
        candidates = raw_response.get("candidates", [])
        if not candidates:
            raise GeminiRetryableError(
                f"No candidates: {json.dumps(raw_response.get('promptFeedback', {}))[:200]}"
            )
        finish_reason = candidates[0].get("finishReason", "")
        if finish_reason == "SAFETY":
            raise GeminiPermanentError(f"Safety block: {candidates[0].get('safetyRatings', [])}")

        # Take the first text part; thinking/tool parts carry no "text" key.
        response_text = ""
        for part in candidates[0].get("content", {}).get("parts", []):
            if "text" in part:
                response_text = part["text"]
                break
        if not response_text:
            raise GeminiRetryableError("Empty text in response parts")

        try:
            parsed = TranscriptVariantBatchResult.model_validate_json(response_text)
        except Exception as exc:
            # Chain the cause so the validation traceback is preserved.
            raise GeminiRetryableError(f"Schema parse failed: {exc}") from exc

        return VariantBatchResult(
            items=[item.model_dump() for item in parsed.results],
            token_usage=_extract_token_usage(raw_response),
            raw_response=raw_response,
        )


def _extract_token_usage(response_json: dict[str, Any]) -> TokenUsage:
    """Build a TokenUsage record from a response's usageMetadata block.

    Missing counters default to 0; a nonzero cachedContentTokenCount is
    treated as a cache hit.
    """
    meta = response_json.get("usageMetadata", {})
    prompt_tokens = meta.get("promptTokenCount", 0)
    candidate_tokens = meta.get("candidatesTokenCount", 0)
    cached = meta.get("cachedContentTokenCount", 0)
    return TokenUsage(
        input_tokens=prompt_tokens,
        output_tokens=candidate_tokens,
        cached_tokens=cached,
        total_tokens=prompt_tokens + candidate_tokens,
        cache_hit=bool(cached),
    )


def serialize_result_row(row: dict[str, Any]) -> str:
    """Serialize a result row to deterministic JSON (sorted keys, raw unicode)."""
    return json.dumps(row, sort_keys=True, ensure_ascii=False)
