"""
OpenRouter provider (SECONDARY): httpx async client, same interface as AI Studio,
cache verification via prompt_tokens_details.cached_tokens.
"""
from __future__ import annotations

import asyncio
import json
import logging
import time
from typing import Any

import httpx

from .base import (
    BaseProvider, TranscriptionRequest, TranscriptionResponse,
    TokenUsage, RequestStatus,
)
from ..config import TEMPERATURE, MAX_RETRIES_429
from ..prompt_builder import get_cacheable_system_prompt, build_user_prompt, get_json_schema

logger = logging.getLogger(__name__)

OPENROUTER_API_URL = "https://openrouter.ai/api/v1/chat/completions"
OPENROUTER_MODEL = "google/gemini-3-flash-preview"


class OpenRouterProvider(BaseProvider):
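    """Secondary transcription provider speaking OpenRouter's OpenAI-compatible
    chat completions API.

    Typical usage (sketch; the surrounding pipeline builds TranscriptionRequest
    objects with base64-encoded FLAC audio, and the key is assumed to come from
    an OPENROUTER_API_KEY environment variable):

        provider = OpenRouterProvider(api_key=os.environ["OPENROUTER_API_KEY"])
        responses = await provider.send_batch(requests)
    """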

    def __init__(self, api_key: str, mock_mode: bool = False):
        self.api_key = api_key
        self.mock_mode = mock_mode
        self._schema = get_json_schema()

    def get_provider_name(self) -> str:
        return "openrouter"

    async def send_batch(self, requests: list[TranscriptionRequest]) -> list[TranscriptionResponse]:
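        """Fan out one API call per request; gather preserves input order."""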
        if self.mock_mode:
            return [self._mock_response(req) for req in requests]

        if not requests:
            return []

        # The semaphore is sized to the batch, so it adds no extra cap here;
        # effective concurrency is whatever batch size the caller chose.
        semaphore = asyncio.Semaphore(len(requests))
        # read=20s pairs with the single timeout retry in _send_single: a hung
        # segment costs at most ~40s before it is discarded.
        timeout = httpx.Timeout(connect=10.0, read=20.0, write=30.0, pool=10.0)
        limits = httpx.Limits(max_connections=len(requests), max_keepalive_connections=100)
        async with httpx.AsyncClient(timeout=timeout, limits=limits) as client:
            tasks = [self._send_single(req, client, semaphore) for req in requests]
            return await asyncio.gather(*tasks)

    async def _send_single(self, req: TranscriptionRequest,
                           client: httpx.AsyncClient,
                           semaphore: asyncio.Semaphore) -> TranscriptionResponse:
        async with semaphore:
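            # Retry policy: exponential backoff on 429 (10s, 20s, 40s, ...),
            # one immediate retry on timeout, fail fast on everything else.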
            for attempt in range(MAX_RETRIES_429 + 1):
                try:
                    start = time.monotonic()
                    resp = await self._call_api(req, client)
                    latency = (time.monotonic() - start) * 1000

                    if resp.status_code == 429:
                        if attempt < MAX_RETRIES_429:
                            wait = 10 * (2 ** attempt)
                            logger.warning(f"429 on {req.segment_id}, retry {attempt+1} in {wait}s")
                            await asyncio.sleep(wait)
                            continue
                        return TranscriptionResponse(
                            segment_id=req.segment_id,
                            status=RequestStatus.RATE_LIMITED,
                            error_message="429 rate limited after retries",
                        )

                    if resp.status_code != 200:
                        error_body = resp.text[:500]
                        return TranscriptionResponse(
                            segment_id=req.segment_id,
                            status=RequestStatus.ERROR,
                            error_message=f"HTTP {resp.status_code}: {error_body}",
                            latency_ms=latency,
                        )
                    data = resp.json()

                    content = data.get("choices", [{}])[0].get("message", {}).get("content", "")
                    try:
                        parsed = json.loads(content) if content else {}
                    except json.JSONDecodeError:
                        parsed = {"raw_text": content}

                    token_usage = self.get_token_usage(data)

                    return TranscriptionResponse(
                        segment_id=req.segment_id,
                        status=RequestStatus.SUCCESS,
                        transcription_data=parsed,
                        token_usage=token_usage,
                        latency_ms=latency,
                        raw_response=data,
                    )

                except httpx.TimeoutException as e:
                    # Only the first attempt gets a timeout retry, even if 429
                    # backoff retries remain.
                    if attempt < 1:
                        logger.warning(f"Timeout on {req.segment_id}, retry {attempt+1}")
                        continue
                    # 20s read timeout x 2 attempts = 40s max waste per bad segment
                    logger.error(f"Discarding {req.segment_id}: timed out twice (bad sample / model hang)")
                    return TranscriptionResponse(
                        segment_id=req.segment_id,
                        status=RequestStatus.TIMEOUT,
                        error_message=f"Discarded: timed out after {attempt+1} attempts: {e}",
                    )
                except Exception as e:
                    logger.error(f"OpenRouter error for {req.segment_id}: {e}")
                    return TranscriptionResponse(
                        segment_id=req.segment_id,
                        status=RequestStatus.ERROR,
                        error_message=str(e),
                    )

        return TranscriptionResponse(
            segment_id=req.segment_id,
            status=RequestStatus.ERROR,
            error_message="Exhausted retries",
        )

    async def _call_api(self, req: TranscriptionRequest,
                        client: httpx.AsyncClient) -> httpx.Response:
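        """Build and POST one chat-completions payload for a single segment."""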
        # V2: uniform cacheable system prompt + per-request language hint
        # OpenRouter handles caching internally via prompt prefix matching
        system_prompt = get_cacheable_system_prompt()
        user_prompt = build_user_prompt(req.language_code)

        # System prompt uses cache_control breakpoint for Gemini caching on OpenRouter.
        # OpenRouter uses only the last breakpoint for Gemini. The system prompt is
        # uniform (~1036 tokens), so it gets cached and reused across all requests.
        payload = {
            "model": OPENROUTER_MODEL,
            "messages": [
                {
                    "role": "system",
                    "content": [
                        {
                            "type": "text",
                            "text": system_prompt,
                            "cache_control": {"type": "ephemeral"},
                        },
                    ],
                },
                {
                    "role": "user",
                    "content": [
                        {
                            "type": "input_audio",
                            "input_audio": {
                                "data": req.audio_base64,
                                "format": "flac",
                            },
                        },
                        {"type": "text", "text": user_prompt},
                    ],
                },
            ],
            "temperature": TEMPERATURE,
            "response_format": {
                "type": "json_schema",
                "json_schema": {
                    "name": "transcription_output",
                    "strict": True,
                    "schema": self._schema,
                },
            },
        }

        return await client.post(
            OPENROUTER_API_URL,
            headers={
                "Authorization": f"Bearer {self.api_key}",
                "Content-Type": "application/json",
            },
            json=payload,
        )

    def verify_cache_hit(self, response: Any) -> bool:
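        """True when OpenRouter reports any cached prompt tokens.

        Expected usage shape (sketch of the OpenAI-compatible accounting;
        field names match what get_token_usage reads):

            {"usage": {"prompt_tokens": 1200,
                       "prompt_tokens_details": {"cached_tokens": 1036}}}
        """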
        try:
            usage = response.get("usage", {}) or {}
            # prompt_tokens_details can be null in the JSON; guard before .get()
            details = usage.get("prompt_tokens_details") or {}
            cached = details.get("cached_tokens", 0) or 0
            return cached > 0
        except Exception:
            return False

    def get_token_usage(self, response: Any) -> TokenUsage:
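        """Extract token accounting from the usage block; cache_hit mirrors verify_cache_hit."""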
        try:
            usage = response.get("usage", {}) or {}
            input_t = usage.get("prompt_tokens", 0) or 0
            output_t = usage.get("completion_tokens", 0) or 0
            # prompt_tokens_details can be null in the JSON; guard before .get()
            details = usage.get("prompt_tokens_details") or {}
            cached = details.get("cached_tokens", 0) or 0
            return TokenUsage(
                input_tokens=input_t,
                output_tokens=output_t,
                cached_tokens=cached,
                total_tokens=input_t + output_t,
                cache_hit=cached > 0,
            )
        except Exception:
            return TokenUsage()

    def _mock_response(self, req: TranscriptionRequest) -> TranscriptionResponse:
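        """Deterministic canned success response used when mock_mode is on (no network I/O)."""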
        mock_data = {
            "transcription": f"[MOCK-OR] Transcription for {req.original_file}",
            "tagged": f"[MOCK-OR] [laugh] Transcription for {req.original_file}",
            "speaker": {
                "emotion": "happy",
                "speaking_style": "conversational",
                "pace": "normal",
                "accent": "",
            },
            "detected_language": req.language_code,
        }
        return TranscriptionResponse(
            segment_id=req.segment_id,
            status=RequestStatus.SUCCESS,
            transcription_data=mock_data,
            token_usage=TokenUsage(
                input_tokens=300, output_tokens=150, cached_tokens=180,
                total_tokens=450, cache_hit=True,
            ),
            latency_ms=65.0,
        )
