"""
AI Studio provider (PRIMARY): raw httpx.AsyncClient for true async concurrency.
Hits Gemini REST API directly. 429 handling with exponential backoff.
Cache verification via usage_metadata.cached_content_token_count.
"""
from __future__ import annotations

import asyncio
import base64
import json
import logging
import time
from typing import Any, Optional

import httpx

from .base import (
    BaseProvider, TranscriptionRequest, TranscriptionResponse,
    TokenUsage, RequestStatus,
)
from ..config import (
    GEMINI_MODEL, TEMPERATURE, THINKING_LEVEL, MAX_RETRIES_429,
)
from ..prompt_builder import build_system_prompt, get_user_prompt, get_json_schema, build_user_prompt

logger = logging.getLogger(__name__)

AISTUDIO_BASE = "https://generativelanguage.googleapis.com/v1beta"

# Gemini API error counters — logged periodically for monitoring.
# Mutated in place by AIStudioProvider._send_single; "requests" is bumped once
# per attempt (including retries), so it can exceed the number of segments.
_api_stats = {
    "requests": 0, "success": 0,
    "http_429": 0, "http_500": 0, "http_other": 0,
    "timeouts": 0, "errors": 0, "retries": 0,
}


def log_api_stats():
    """Log a one-line summary of the cumulative Gemini API counters.

    Reads the module-level ``_api_stats`` dict. Uses lazy %-style logging
    arguments (instead of an eager f-string) so the message is only
    formatted when the INFO level is actually enabled.
    """
    s = _api_stats
    logger.info(
        "[GEMINI-STATS] sent=%s ok=%s 429=%s 500=%s other_http=%s "
        "timeouts=%s errors=%s retries=%s",
        s["requests"], s["success"], s["http_429"], s["http_500"],
        s["http_other"], s["timeouts"], s["errors"], s["retries"],
    )


class AIStudioProvider(BaseProvider):
    """Gemini transcription provider that hits the AI Studio REST API directly.

    Uses a raw ``httpx.AsyncClient`` so every request in a batch runs with
    true async concurrency. 429 and 5xx responses are retried with
    exponential backoff; cache hits are verified via
    ``usageMetadata.cachedContentTokenCount``.
    """

    def __init__(self, api_key: str, mock_mode: bool = False,
                 cached_content_name: Optional[str] = None):
        """
        Args:
            api_key: Gemini API key (sent via the ``x-goog-api-key`` header).
            mock_mode: If True, return canned responses without any network I/O.
            cached_content_name: Name of a server-side cached-content entry.
                When set, the system prompt comes from the cache instead of
                being sent inline with every request.
        """
        self.api_key = api_key
        self.mock_mode = mock_mode
        self.cached_content_name = cached_content_name
        # Structured-output JSON schema, built once and reused per request.
        self._schema = get_json_schema()

    def get_provider_name(self) -> str:
        """Return the stable identifier used to select this provider."""
        return "aistudio"

    async def send_batch(self, requests: list[TranscriptionRequest]) -> list[TranscriptionResponse]:
        """Send all requests concurrently; responses align 1:1 with inputs."""
        if self.mock_mode:
            return [self._mock_response(req) for req in requests]

        # FIX: guard the empty batch. Without this, Semaphore(0) and
        # Limits(max_connections=0) would be constructed from a zero-length
        # list (max_connections=0 is not a valid httpx pool size).
        if not requests:
            return []

        # Semaphore sized to the batch, i.e. effectively unbounded: callers
        # are expected to size batches to their own rate limits.
        semaphore = asyncio.Semaphore(len(requests))
        timeout = httpx.Timeout(connect=10.0, read=20.0, write=30.0, pool=10.0)
        limits = httpx.Limits(max_connections=len(requests), max_keepalive_connections=100)

        async with httpx.AsyncClient(timeout=timeout, limits=limits) as client:
            tasks = [self._send_single(req, client, semaphore) for req in requests]
            return await asyncio.gather(*tasks)

    async def _send_single(self, req: TranscriptionRequest,
                           client: httpx.AsyncClient,
                           semaphore: asyncio.Semaphore) -> TranscriptionResponse:
        """Send one request with retry handling; never raises.

        Retry policy:
          * 429 / RESOURCE_EXHAUSTED / 5xx: exponential backoff
            (10s, 20s, 40s, ...) up to MAX_RETRIES_429 retries.
          * Timeout: one immediate retry with thinking=MINIMAL, then discard
            (a repeated hang usually indicates a bad sample).
        """
        async with semaphore:
            for attempt in range(MAX_RETRIES_429 + 1):
                # Counts every attempt, including retries.
                _api_stats["requests"] += 1
                try:
                    start = time.monotonic()
                    # Retries drop to MINIMAL thinking to cut latency/cost.
                    thinking_override = "MINIMAL" if attempt > 0 else None
                    resp = await self._call_api(req, client, thinking_override=thinking_override)
                    latency = (time.monotonic() - start) * 1000

                    if resp.status_code == 429:
                        _api_stats["http_429"] += 1
                        if attempt < MAX_RETRIES_429:
                            _api_stats["retries"] += 1
                            wait = 10 * (2 ** attempt)
                            logger.warning("429 on %s, retry %s in %ss",
                                           req.segment_id, attempt + 1, wait)
                            await asyncio.sleep(wait)
                            continue
                        return TranscriptionResponse(
                            segment_id=req.segment_id,
                            status=RequestStatus.RATE_LIMITED,
                            error_message=f"429 after {MAX_RETRIES_429} retries",
                        )

                    if resp.status_code >= 500:
                        _api_stats["http_500"] += 1
                    if resp.status_code != 200:
                        error_body = resp.text[:500]
                        if resp.status_code < 500:
                            _api_stats["http_other"] += 1
                        # RESOURCE_EXHAUSTED can arrive on non-429 statuses;
                        # treat it (and any 5xx) as retryable.
                        if ("RESOURCE_EXHAUSTED" in error_body or resp.status_code >= 500) and attempt < MAX_RETRIES_429:
                            _api_stats["retries"] += 1
                            wait = 10 * (2 ** attempt)
                            logger.warning("HTTP %s on %s, retry %s in %ss",
                                           resp.status_code, req.segment_id, attempt + 1, wait)
                            await asyncio.sleep(wait)
                            continue
                        return TranscriptionResponse(
                            segment_id=req.segment_id,
                            status=RequestStatus.ERROR,
                            error_message=f"HTTP {resp.status_code}: {error_body}",
                            latency_ms=latency,
                        )

                    _api_stats["success"] += 1
                    data = resp.json()
                    return self._parse_response(req.segment_id, data, latency)

                except httpx.TimeoutException as e:
                    _api_stats["timeouts"] += 1
                    # Timeouts are retried only once, and only on the first
                    # attempt — a later attempt that times out is discarded.
                    if attempt < 1:
                        _api_stats["retries"] += 1
                        logger.warning("Timeout on %s, retrying with thinking=MINIMAL",
                                       req.segment_id)
                        continue
                    # FIX: the old message claimed "timed out twice", which is
                    # wrong when this attempt was reached via a 429 retry.
                    logger.error("Discarding %s: timed out on attempt %s (bad sample / model hang)",
                                 req.segment_id, attempt + 1)
                    return TranscriptionResponse(
                        segment_id=req.segment_id,
                        status=RequestStatus.TIMEOUT,
                        error_message=f"Discarded: timed out after {attempt+1} attempts: {e}",
                    )
                except Exception as e:
                    _api_stats["errors"] += 1
                    logger.error("Error for %s: %s", req.segment_id, e)
                    return TranscriptionResponse(
                        segment_id=req.segment_id,
                        status=RequestStatus.ERROR,
                        error_message=str(e),
                    )

        # Defensive: every loop path above returns or continues, so this is
        # only reachable if MAX_RETRIES_429 < 0.
        return TranscriptionResponse(
            segment_id=req.segment_id,
            status=RequestStatus.ERROR,
            error_message="Exhausted retries",
        )

    async def _call_api(self, req: TranscriptionRequest,
                        client: httpx.AsyncClient,
                        thinking_override: Optional[str] = None) -> httpx.Response:
        """POST one generateContent request and return the raw response.

        FIX: the API key is sent via the ``x-goog-api-key`` header rather
        than a ``?key=`` query parameter, so it cannot leak into access
        logs, proxies, or error messages that include the URL.
        """
        url = f"{AISTUDIO_BASE}/models/{GEMINI_MODEL}:generateContent"
        headers = {"x-goog-api-key": self.api_key}

        # With a server-side cache the system prompt lives in the cache, so
        # the user prompt must carry the language; otherwise the system
        # prompt is built per-request and a generic user prompt suffices.
        if self.cached_content_name:
            user_prompt = build_user_prompt(req.language_code)
        else:
            user_prompt = get_user_prompt()

        thinking_level = thinking_override or THINKING_LEVEL.upper()

        body: dict[str, Any] = {
            "contents": [
                {
                    "role": "user",
                    "parts": [
                        {
                            "inlineData": {
                                "mimeType": req.mime_type,
                                "data": req.audio_base64,
                            }
                        },
                        {"text": user_prompt},
                    ],
                }
            ],
            "generationConfig": {
                "temperature": TEMPERATURE,
                "responseMimeType": "application/json",
                "responseJsonSchema": self._schema,
                "thinkingConfig": {
                    "thinkingLevel": thinking_level,
                },
            },
        }

        if self.cached_content_name:
            body["cachedContent"] = self.cached_content_name
        else:
            body["systemInstruction"] = {
                "parts": [{"text": build_system_prompt(req.language_code)}]
            }

        return await client.post(url, json=body, headers=headers)

    def _parse_response(self, segment_id: str, data: dict,
                        latency: float) -> TranscriptionResponse:
        """Turn a decoded 200-response body into a TranscriptionResponse.

        Extracts the first text part of the first candidate and parses it as
        JSON; malformed JSON is preserved under the "raw_text" key rather
        than discarded.
        """
        try:
            candidates = data.get("candidates", [])
            if not candidates:
                return TranscriptionResponse(
                    segment_id=segment_id,
                    status=RequestStatus.ERROR,
                    error_message=f"No candidates in response: {json.dumps(data)[:300]}",
                    latency_ms=latency,
                )

            parts = candidates[0].get("content", {}).get("parts", [])
            text = ""
            for part in parts:
                if "text" in part:
                    text = part["text"]
                    break

            try:
                parsed = json.loads(text) if text else {}
            except json.JSONDecodeError:
                # Keep the model output for debugging instead of dropping it.
                parsed = {"raw_text": text}

            token_usage = self.get_token_usage(data)

            return TranscriptionResponse(
                segment_id=segment_id,
                status=RequestStatus.SUCCESS,
                transcription_data=parsed,
                token_usage=token_usage,
                latency_ms=latency,
                raw_response=data,
            )
        except Exception as e:
            return TranscriptionResponse(
                segment_id=segment_id,
                status=RequestStatus.ERROR,
                error_message=f"Parse error: {e}",
                latency_ms=latency,
            )

    def verify_cache_hit(self, response: Any) -> bool:
        """True if the response body reports any cached-content tokens."""
        try:
            usage = response.get("usageMetadata", {})
            cached = usage.get("cachedContentTokenCount", 0)
            return cached > 0
        except Exception:
            return False

    def get_token_usage(self, response: Any) -> TokenUsage:
        """Extract token counts from usageMetadata; empty TokenUsage on failure."""
        try:
            usage = response.get("usageMetadata", {})
            input_tokens = usage.get("promptTokenCount", 0)
            output_tokens = usage.get("candidatesTokenCount", 0)
            cached = usage.get("cachedContentTokenCount", 0)
            return TokenUsage(
                input_tokens=input_tokens,
                output_tokens=output_tokens,
                cached_tokens=cached,
                total_tokens=input_tokens + output_tokens,
                cache_hit=cached > 0,
            )
        except Exception:
            return TokenUsage()

    def _mock_response(self, req: TranscriptionRequest) -> TranscriptionResponse:
        """Build a deterministic SUCCESS response for mock_mode (no network)."""
        lang_name = req.language_code
        mock_data = {
            "transcription": f"[MOCK] Sample transcription for {req.original_file} in {lang_name}",
            "tagged": f"[MOCK] [noise] Sample transcription for {req.original_file} in {lang_name}",
            "speaker": {
                "emotion": "neutral",
                "speaking_style": "conversational",
                "pace": "normal",
                "accent": "",
            },
            "detected_language": req.language_code,
        }
        return TranscriptionResponse(
            segment_id=req.segment_id,
            status=RequestStatus.SUCCESS,
            transcription_data=mock_data,
            token_usage=TokenUsage(
                input_tokens=300, output_tokens=150, cached_tokens=200,
                total_tokens=450, cache_hit=True,
            ),
            latency_ms=50.0,
        )
