"""
OpenRouter-backed metadata classifier for TTS dataset filtering.

One video per request. The prompt is structured so the static rubric lives in
the first user message and per-video data is injected in a second user
message. This keeps the shared prompt prefix byte-identical across all
requests so provider-side implicit prompt caching (e.g. Gemini via
OpenRouter) can reuse the prefix; no explicit cache_control breakpoint is set.
"""
from __future__ import annotations

import asyncio
import json
import random
import re
import time
from dataclasses import dataclass
from typing import Any, Literal

import httpx
from pydantic import BaseModel, Field

# OpenRouter chat-completions endpoint; each video is classified in one POST.
OPENROUTER_API_URL = "https://openrouter.ai/api/v1/chat/completions"
# Default model slug; overridable per OpenRouterClassifier instance.
DEFAULT_OPENROUTER_MODEL = "google/gemini-3-flash-preview"

# YouTube category-id (string) → human-readable label. The label is injected
# into the per-video payload so the model can judge the category without any
# outside knowledge.
YOUTUBE_CATEGORY_LABELS = {
    "1": "Film & Animation", "2": "Autos & Vehicles", "10": "Music",
    "15": "Pets & Animals", "17": "Sports", "18": "Short Movies",
    "19": "Travel & Events", "20": "Gaming", "21": "Videoblogging",
    "22": "People & Blogs", "23": "Comedy", "24": "Entertainment",
    "25": "News & Politics", "26": "Howto & Style", "27": "Education",
    "28": "Science & Technology", "29": "Nonprofits & Activism",
    "30": "Movies", "35": "Documentary", "42": "Shorts", "43": "Shows",
    "44": "Trailers",
}

# ── Static rubric (cached across all requests) ──────────────────────────

# Sent verbatim in the first user message of every request. It must stay
# byte-identical across requests so provider-side prompt caching can reuse
# the shared prefix (see module docstring). Do not edit casually.
CACHED_RUBRIC = """You rank YouTube videos for TTS training suitability.

TARGET: long-form spoken-word — podcasts, lectures, interviews, oral histories, audiobooks, explainers, clean narration.
REJECT: meetings, webinars, elevator pitches, demo days, noisy events, gaming, music, sports highlights, memes, short-form clips.

Judge ONLY from the metadata payload below. Do not use outside knowledge.

Scoring (0-100 each):
- tts_suitability_score: overall keep-worthiness for TTS data.
- spoken_word_score: speech-dominated vs music/visuals/noise.
- clean_speech_likelihood_score: metadata guess at clean, stable audio.
- single_speaker_likelihood_score: one voice at a time vs crosstalk.
- metadata_confidence_score: how confidently metadata supports the call.

Decision:
- KEEP: strong spoken-word fit, low noise risk.
- REVIEW: mixed signals or ambiguous metadata.
- DROP: clear evidence of unsuitable content.

Hard negatives: meetings, webinars, launchpads, demo days, pitch sessions, ceremonies, Q&A, music videos, songs, remixes, gaming streams, gameplay, highlights, reactions, compilations, memes, shorts, live stage/crowd audio.
Hard positives: lecture series, university talks, archive interviews, structured podcasts, audiobooks, clean explainers, narrated history/stories.

Rules:
- Duration alone is not enough. Category alone is not enough.
- People & Blogs can be any action depending on title/channel/tags.
- Music category is almost always DROP unless metadata clearly shows a spoken interview.
- Meeting-like language (agenda, townhall, launchpad, elevator pitch, Q&A) is a strong negative.
- Scores >= 80 → usually KEEP. 55-79 → usually REVIEW. < 55 → usually DROP.
- If content is clearly a meeting/pitch/webinar/noisy event, set hard_reject=true.

Output: keep JSON compact. Lists max 4 items. short_rationale under 30 words. No info not in the payload."""

# System message; short and static so it is part of the cacheable prefix.
SYSTEM_PROMPT = (
    "You are a strict metadata-based gatekeeper for TTS training data. "
    "Think silently. Return only the JSON object."
)

# Static lead-in that precedes the rubric in the first user message.
STATIC_TASK_PREFIX = (
    "Classify this video's metadata for TTS dataset filtering. "
    "Return the structured JSON result."
)


def parse_json_list(raw_value: str) -> list[str]:
    """Decode a JSON-encoded list field, returning [] on any failure.

    Non-list JSON values (objects, numbers, null) also yield []; list
    entries are coerced to strings.
    """
    if not raw_value:
        return []
    try:
        decoded = json.loads(raw_value)
    except json.JSONDecodeError:
        return []
    if not isinstance(decoded, list):
        return []
    return [str(entry) for entry in decoded]


def normalize_metadata_row(row: dict[str, str]) -> dict[str, Any]:
    """Normalize one raw metadata row into the payload sent to the model.

    Truncates the description to 800 chars and the tag list to 20 entries to
    bound per-request token cost; maps the category id to a human label.

    Fix: the original did `int(dur)` unguarded, so a non-numeric
    duration_seconds string (e.g. "" is handled but "P1DT2H" or "12.5" is
    not) crashed the whole request. Malformed durations now become None.
    """
    category_id = str(row.get("category_id", "") or "")
    tags = parse_json_list(row.get("tags", ""))
    topic_categories = parse_json_list(row.get("topic_categories", ""))
    raw_duration = row.get("duration_seconds")
    try:
        duration_seconds = int(raw_duration) if raw_duration else None
    except (TypeError, ValueError):
        duration_seconds = None
    return {
        "video_id": str(row.get("video_id", "") or ""),
        "channel_title": str(row.get("channel_title", "") or ""),
        "title": str(row.get("title", "") or ""),
        "description": str(row.get("description", "") or "")[:800],
        "tags": tags[:20],
        "category_id": category_id,
        "category_label": YOUTUBE_CATEGORY_LABELS.get(category_id, "Unknown"),
        "default_audio_language": str(row.get("default_audio_language", "") or ""),
        "duration_seconds": duration_seconds,
        "definition": str(row.get("definition", "") or ""),
        "topic_categories": topic_categories,
    }


class VideoTtsClassification(BaseModel):
    """Structured per-video verdict parsed from the model's JSON reply.

    The field list is advertised to the model via this class's JSON schema,
    so field names here and in the prompt cannot drift apart.
    """

    # Final keep/review/drop decision.
    recommended_action: Literal["keep", "review", "drop"] = Field(
        description="Final action for this video."
    )
    # Free-form content-type label (not a closed enum).
    likely_content_type: str = Field(
        description="Short label: podcast, lecture, sports talk, meeting, audiobook, explainer, event, unknown, etc."
    )
    # All scores are 0-100 integers, validated by pydantic bounds.
    tts_suitability_score: int = Field(ge=0, le=100)
    spoken_word_score: int = Field(ge=0, le=100)
    clean_speech_likelihood_score: int = Field(ge=0, le=100)
    single_speaker_likelihood_score: int = Field(ge=0, le=100)
    metadata_confidence_score: int = Field(ge=0, le=100)
    # Hard-negative flag (meetings/pitches/webinars/noisy events).
    hard_reject: bool = Field(
        description="True when metadata strongly indicates exclusion."
    )
    # Supporting signal lists; the rubric caps these at 4 items each.
    hard_reject_reasons: list[str] = Field(default_factory=list)
    positive_signals: list[str] = Field(default_factory=list)
    risk_signals: list[str] = Field(default_factory=list)
    short_rationale: str = Field(
        description="Under 30 words explaining the action."
    )
    # True when metadata alone is not enough and audio should be checked.
    needs_audio_validation: bool = Field(
        description="True if audio-level check is still needed."
    )


# JSON schema derived once at import time; only its "properties" keys are
# used when building the prompt's schema instruction.
_SCHEMA = VideoTtsClassification.model_json_schema()


@dataclass
class ClassificationUsage:
    """Token accounting for one request, extracted from the API response."""

    prompt_tokens: int = 0
    completion_tokens: int = 0
    # Tokens served from the provider's prompt cache (prompt_tokens_details).
    cached_tokens: int = 0
    total_tokens: int = 0
    # True when cached_tokens > 0, i.e. the static prefix was reused.
    cache_hit: bool = False


@dataclass
class ClassificationResult:
    """Outcome of classifying a single video."""

    video_id: str
    classification: VideoTtsClassification
    usage: ClassificationUsage
    # Wall-clock latency of the (final, successful) HTTP round trip.
    latency_ms: float


class OpenRouterClassifier:
    """Single-video-per-request classifier with stable cached prefix.

    The system message and the first user message (rubric + schema
    instruction) are byte-identical for every request; only the second user
    message carries per-video metadata, so provider-side prompt caching can
    reuse the shared prefix.
    """

    def __init__(
        self,
        api_key: str,
        *,
        model: str = DEFAULT_OPENROUTER_MODEL,
        temperature: float = 0.2,
        max_retries: int = 3,
        reasoning_effort: str = "low",
    ):
        """Configure the classifier.

        Args:
            api_key: OpenRouter API key, sent as a Bearer token.
            model: OpenRouter model slug.
            temperature: Sampling temperature for the completion.
            max_retries: Additional attempts after the first for retryable
                failures (transport errors, 429, and 5xx responses).
            reasoning_effort: Effort level forwarded via the request's
                "reasoning" field.
        """
        self.api_key = api_key
        self.model = model
        self.temperature = temperature
        self.max_retries = max_retries
        self.reasoning_effort = reasoning_effort
        self._headers = {
            "Authorization": f"Bearer {api_key}",
            "Content-Type": "application/json",
        }

    def _build_body(self, metadata_row: dict[str, str]) -> dict[str, Any]:
        """Build the chat-completions request body for one metadata row."""
        video_json = json.dumps(
            normalize_metadata_row(metadata_row), ensure_ascii=False
        )
        # Derive the field list from the pydantic schema so the prompt and
        # the parser cannot drift apart.
        schema_instruction = (
            "\n\nReturn ONLY a JSON object with these exact fields:\n"
            + json.dumps(list(_SCHEMA.get("properties", {}).keys()))
            + "\nField types: recommended_action=(keep|review|drop), "
            "all scores=int 0-100, hard_reject=bool, lists=string[], "
            "short_rationale=string, needs_audio_validation=bool."
        )
        body: dict[str, Any] = {
            "model": self.model,
            "messages": [
                {
                    "role": "system",
                    "content": SYSTEM_PROMPT,
                },
                {
                    # Static across all requests: the cacheable prefix.
                    "role": "user",
                    "content": STATIC_TASK_PREFIX + "\n\n" + CACHED_RUBRIC + schema_instruction,
                },
                {
                    # Per-video payload: the only varying message.
                    "role": "user",
                    "content": video_json,
                },
            ],
            "temperature": self.temperature,
            "max_tokens": 600,
            "response_format": {"type": "json_object"},
            # Fix: reasoning_effort was stored by __init__ but never sent;
            # forward it so the configured effort reaches the provider.
            "reasoning": {"effort": self.reasoning_effort},
        }
        return body

    async def classify(
        self,
        metadata_row: dict[str, str],
        client: httpx.AsyncClient,
    ) -> ClassificationResult:
        """Classify one video's metadata via OpenRouter.

        Retries transport errors, 429s, and 5xx responses with jittered
        exponential backoff.

        Raises:
            RuntimeError: on non-retryable HTTP errors, empty responses, or
                once all retries are exhausted.
        """
        video_id = str(metadata_row.get("video_id", "") or "")
        body = self._build_body(metadata_row)
        failures: list[str] = []

        for attempt in range(self.max_retries + 1):
            try:
                t0 = time.monotonic()
                resp = await client.post(
                    OPENROUTER_API_URL,
                    headers=self._headers,
                    json=body,
                )
                latency_ms = (time.monotonic() - t0) * 1000
            except httpx.HTTPError as exc:
                failures.append(f"transport: {exc}")
                if attempt < self.max_retries:
                    await asyncio.sleep(_retry_delay(attempt))
                    continue
                raise RuntimeError(
                    f"Transport failed {video_id}: {' | '.join(failures)}"
                ) from exc

            if resp.status_code == 429 or resp.status_code >= 500:
                failures.append(f"{resp.status_code}")
                if attempt < self.max_retries:
                    await asyncio.sleep(_retry_delay(attempt))
                    continue
                # Fix: the original fell through here on the final attempt
                # and tried to parse the 429/5xx error body as a successful
                # completion. Fail explicitly instead.
                raise RuntimeError(
                    f"All retries exhausted for {video_id}: {' | '.join(failures)}"
                )
            if resp.status_code != 200:
                raise RuntimeError(
                    f"HTTP {resp.status_code} for {video_id}: {resp.text[:300]}"
                )

            payload = resp.json()
            content = payload.get("choices", [{}])[0].get("message", {}).get("content", "")
            if not content:
                raise RuntimeError(f"Empty content for {video_id}")

            parsed = _parse_json(content)
            classification = VideoTtsClassification.model_validate(parsed)
            usage = _extract_usage(payload)

            return ClassificationResult(
                video_id=video_id,
                classification=classification,
                usage=usage,
                latency_ms=latency_ms,
            )

        # Defensive: every loop iteration above either returns or raises.
        raise RuntimeError(
            f"All retries exhausted for {video_id}: {' | '.join(failures)}"
        )


def _parse_json(content: Any) -> dict[str, Any]:
    if isinstance(content, dict):
        return content
    if isinstance(content, list):
        content = "".join(
            str(item.get("text", item)) if isinstance(item, dict) else str(item)
            for item in content
        )
    if not isinstance(content, str):
        raise json.JSONDecodeError("bad type", str(content), 0)
    s = content.strip()
    if s.startswith("```"):
        s = s.strip("`")
        if s.startswith("json"):
            s = s[4:].lstrip()
    i, j = s.find("{"), s.rfind("}")
    if i != -1 and j > i:
        s = s[i : j + 1]
    import re
    s = re.sub(r',\s*([}\]])', r'\1', s)
    return json.loads(s)


def _extract_usage(payload: dict[str, Any]) -> ClassificationUsage:
    """Pull token accounting out of an OpenRouter response payload.

    Missing or null usage fields default to 0; a cache hit is inferred
    from a positive cached_tokens count.
    """
    usage_block = payload.get("usage", {}) or {}
    prompt_tokens = int(usage_block.get("prompt_tokens", 0) or 0)
    completion_tokens = int(usage_block.get("completion_tokens", 0) or 0)
    prompt_details = usage_block.get("prompt_tokens_details", {}) or {}
    cached_tokens = int(prompt_details.get("cached_tokens", 0) or 0)
    total_default = prompt_tokens + completion_tokens
    total_tokens = int(usage_block.get("total_tokens", total_default) or 0)
    return ClassificationUsage(
        prompt_tokens=prompt_tokens,
        completion_tokens=completion_tokens,
        cached_tokens=cached_tokens,
        total_tokens=total_tokens,
        cache_hit=cached_tokens > 0,
    )


def _retry_delay(attempt: int) -> float:
    return min(2 ** attempt, 8) + random.uniform(0.0, 0.3)
