"""
Classify YouTube metadata rows for TTS suitability via OpenRouter.

One video per request. An asyncio.Semaphore caps in-flight request concurrency.
Resumable: reruns without --overwrite skip already-classified video IDs.

Usage:
  python scripts/classify_video_tts_metadata.py \
      --concurrency 1000 --overwrite --no-raw-response
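
  # Resume an interrupted run (already-classified video IDs are skipped):
  python scripts/classify_video_tts_metadata.py --concurrency 1000

  # Classify specific videos only (VIDEO_ID is a placeholder):
  python scripts/classify_video_tts_metadata.py --video-id VIDEO_ID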
"""
from __future__ import annotations

import argparse
import asyncio
import csv
import json
import os
import sys
import time
from pathlib import Path
from collections.abc import Iterator

import httpx
from dotenv import load_dotenv

PROJECT_ROOT = Path(__file__).resolve().parent.parent
if str(PROJECT_ROOT) not in sys.path:
    sys.path.insert(0, str(PROJECT_ROOT))

from src.video_tts_classifier import OpenRouterClassifier

# One output row per video. This list defines the CSV header; resumed runs
# append to the same file, so keep the order stable across runs.
OUTPUT_FIELDS = [
    "video_id", "model", "recommended_action", "likely_content_type",
    "tts_suitability_score", "spoken_word_score",
    "clean_speech_likelihood_score", "single_speaker_likelihood_score",
    "metadata_confidence_score", "hard_reject", "hard_reject_reasons",
    "positive_signals", "risk_signals", "short_rationale",
    "needs_audio_validation", "cache_hit", "cached_tokens",
    "prompt_tokens", "completion_tokens", "total_tokens",
    "latency_ms", "error",
]


def parse_args() -> argparse.Namespace:
    p = argparse.ArgumentParser()
    p.add_argument("--input-csv", default="data/youtube_video_metadata_all.csv")
    p.add_argument("--output-csv", default="data/video_tts_classification.csv")
    p.add_argument("--video-id", action="append", default=[],
                   help="Classify only this video_id (repeatable)")
    p.add_argument("--video-id-file", default="", help="File with one video_id per line")
    p.add_argument("--limit", type=int, default=0, help="Max rows to classify (0 = all)")
    p.add_argument("--concurrency", type=int, default=500,
                   help="Max concurrent API requests")
    p.add_argument("--temperature", type=float, default=0.2)
    p.add_argument("--model", default="google/gemini-3-flash-preview")
    p.add_argument("--reasoning-effort", default="low")
    p.add_argument("--progress-every", type=int, default=500,
                   help="Print progress every N completed rows")
    p.add_argument("--overwrite", action="store_true",
                   help="Delete any existing output CSV instead of resuming")
    # NOTE: accepted but currently not referenced anywhere in this script.
    p.add_argument("--no-raw-response", action="store_true")
    return p.parse_args()


def load_env() -> None:
    load_dotenv(PROJECT_ROOT / ".env")


def read_done_ids(path: Path) -> set[str]:
    if not path.exists():
        return set()
    ids: set[str] = set()
    with path.open("r", encoding="utf-8", newline="") as f:
        for row in csv.DictReader(f):
            vid = row.get("video_id", "")
            if vid:
                ids.add(vid)
    return ids


def iter_rows(
    path: Path, *, targets: list[str], skip: set[str], limit: int,
) -> Iterator[dict[str, str]]:
    target_set = {v.strip() for v in targets if v.strip()}
    n = 0
    with path.open("r", encoding="utf-8", newline="") as f:
        for row in csv.DictReader(f):
            vid = row.get("video_id", "")
            if not vid or vid in skip:
                continue
            if target_set and vid not in target_set:
                continue
            yield row
            n += 1
            if limit > 0 and n >= limit:
                break


def result_row(r, model: str) -> dict[str, str]:
    # r is the classifier's per-video result; it carries .classification,
    # .usage, and .latency_ms.
    c = r.classification
    u = r.usage
    return {
        "video_id": r.video_id, "model": model,
        "recommended_action": c.recommended_action,
        "likely_content_type": c.likely_content_type,
        "tts_suitability_score": str(c.tts_suitability_score),
        "spoken_word_score": str(c.spoken_word_score),
        "clean_speech_likelihood_score": str(c.clean_speech_likelihood_score),
        "single_speaker_likelihood_score": str(c.single_speaker_likelihood_score),
        "metadata_confidence_score": str(c.metadata_confidence_score),
        "hard_reject": str(c.hard_reject).lower(),
        "hard_reject_reasons": json.dumps(c.hard_reject_reasons),
        "positive_signals": json.dumps(c.positive_signals),
        "risk_signals": json.dumps(c.risk_signals),
        "short_rationale": c.short_rationale,
        "needs_audio_validation": str(c.needs_audio_validation).lower(),
        "cache_hit": str(u.cache_hit).lower(),
        "cached_tokens": str(u.cached_tokens),
        "prompt_tokens": str(u.prompt_tokens),
        "completion_tokens": str(u.completion_tokens),
        "total_tokens": str(u.total_tokens),
        "latency_ms": f"{r.latency_ms:.0f}",
        "error": "",
    }


def error_row(row: dict[str, str], model: str, error: str) -> dict[str, str]:
    # Blank every field not set explicitly here so the row still matches
    # OUTPUT_FIELDS; failed rows are routed to manual "review".
    return {
        "video_id": row.get("video_id", ""), "model": model,
        "recommended_action": "review",
        "likely_content_type": "error",
        **{
            k: ""
            for k in OUTPUT_FIELDS
            if k not in ("video_id", "model", "recommended_action",
                         "likely_content_type", "error")
        },
        "error": error[:500],
    }


def notfound_row(row: dict[str, str], model: str) -> dict[str, str]:
    return {
        "video_id": row.get("video_id", ""), "model": model,
        "recommended_action": "drop",
        "likely_content_type": "missing_metadata",
        "tts_suitability_score": "0", "spoken_word_score": "0",
        "clean_speech_likelihood_score": "0",
        "single_speaker_likelihood_score": "0",
        "metadata_confidence_score": "100",
        "hard_reject": "true",
        "hard_reject_reasons": '["no_metadata"]',
        "positive_signals": "[]", "risk_signals": '["no_metadata"]',
        "short_rationale": "Metadata unavailable.",
        "needs_audio_validation": "false",
        "cache_hit": "false", "cached_tokens": "0",
        "prompt_tokens": "0", "completion_tokens": "0",
        "total_tokens": "0", "latency_ms": "0", "error": "",
    }


async def run(args: argparse.Namespace) -> None:
    load_env()
    api_key = os.getenv("OPENROUTER_API_KEY", "").strip()
    if not api_key:
        raise SystemExit("OPENROUTER_API_KEY missing")

    output_csv = Path(args.output_csv)
    output_csv.parent.mkdir(parents=True, exist_ok=True)

    # Resume support: unless --overwrite is set, IDs already present in the
    # output CSV are skipped on this run.
    skip: set[str] = set()
    write_hdr = True
    if output_csv.exists():
        if args.overwrite:
            output_csv.unlink()
        else:
            skip = read_done_ids(output_csv)
            write_hdr = False
            print(f"Resuming: {len(skip):,} already done")

    classifier = OpenRouterClassifier(
        api_key, model=args.model, temperature=args.temperature,
        reasoning_effort=args.reasoning_effort,
    )

    # Optional allow-list: --video-id flags plus any IDs from --video-id-file.
    targets = list(args.video_id)
    if args.video_id_file:
        vf = Path(args.video_id_file)
        if vf.exists():
            with vf.open("r", encoding="utf-8") as fh:
                for line in fh:
                    v = line.strip()
                    if v and v != "video_id":  # skip blanks and a header line
                        targets.append(v)

    rows = list(iter_rows(
        Path(args.input_csv), targets=targets,
        skip=skip, limit=args.limit,
    ))
    total = len(rows)
    print(f"Videos to classify: {total:,}")
    if total == 0:
        return

    # The semaphore caps simultaneous OpenRouter requests at --concurrency.
    sem = asyncio.Semaphore(args.concurrency)
    t0 = time.monotonic()
    completed = 0
    cache_hits = 0
    keep_n = 0
    review_n = 0
    drop_n = 0
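    # One lock guards the shared csv writer and the progress counters.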
    lock = asyncio.Lock()

    # The connection pool is sized above the semaphore so no in-flight
    # request waits on a free socket.
    async with httpx.AsyncClient(
        timeout=httpx.Timeout(connect=15.0, read=120.0, write=30.0, pool=60.0),
        limits=httpx.Limits(
            max_connections=args.concurrency + 50,
            max_keepalive_connections=args.concurrency,
        ),
    ) as client:
        # A context manager guarantees the output file is closed (and
        # flushed) even if something escapes the gather below.
        with output_csv.open("a", encoding="utf-8", newline="") as fh:
            writer = csv.DictWriter(fh, fieldnames=OUTPUT_FIELDS)
            if write_hdr:
                writer.writeheader()
                fh.flush()

            async def process(row: dict[str, str]) -> None:
                nonlocal completed, cache_hits, keep_n, review_n, drop_n

                # Rows whose metadata fetch failed never reach the API.
                if row.get("fetch_status") != "ok":
                    out = notfound_row(row, args.model)
                else:
                    async with sem:
                        try:
                            r = await classifier.classify(row, client)
                            out = result_row(r, args.model)
                        except Exception as exc:
                            out = error_row(row, args.model, str(exc))

                async with lock:
                    writer.writerow(out)
                    completed += 1
                    if out.get("cache_hit") == "true":
                        cache_hits += 1
                    act = out.get("recommended_action", "")
                    if act == "keep":
                        keep_n += 1
                    elif act == "review":
                        review_n += 1
                    else:
                        drop_n += 1

                    if completed == total or (
                        args.progress_every > 0
                        and completed % args.progress_every == 0
                    ):
                        elapsed = time.monotonic() - t0
                        rate = completed / elapsed if elapsed > 0 else 0.0
                        eta_s = (total - completed) / rate if rate > 0 else 0.0
                        print(
                            f"{completed:,}/{total:,} in {elapsed:.0f}s "
                            f"({rate:.1f}/s, ETA {eta_s/60:.1f}m) "
                            f"keep={keep_n:,} review={review_n:,} drop={drop_n:,} "
                            f"cache_hits={cache_hits:,}",
                            flush=True,
                        )
                        fh.flush()

            # process() converts per-row failures into error rows, so
            # return_exceptions=True only shields sibling tasks from
            # unexpected bugs.
            tasks = [asyncio.create_task(process(row)) for row in rows]
            await asyncio.gather(*tasks, return_exceptions=True)

    elapsed = time.monotonic() - t0
    print(f"\nDone: {completed:,} in {elapsed:.0f}s ({completed/elapsed:.1f}/s)")
    print(f"keep={keep_n:,} review={review_n:,} drop={drop_n:,} cache_hits={cache_hits:,}")
    print(f"Output: {output_csv}")


def main() -> None:
    asyncio.run(run(parse_args()))


if __name__ == "__main__":
    main()
