"""
Fetch classification-oriented YouTube metadata for known video IDs.

Example:
  python scripts/fetch_youtube_video_metadata.py \
      --input-csv data/transcription_video_ids.csv \
      --output-csv data/youtube_metadata_sample_10.csv \
      --limit 10
"""
from __future__ import annotations

import asyncio
import argparse
import csv
import json
import os
import time
from dataclasses import dataclass, field
from datetime import datetime, timezone
from itertools import chain
from pathlib import Path
from typing import Any, Iterator, Sequence

import httpx
from dotenv import load_dotenv

# videos.list endpoint plus the response "part"s this script requests.
YOUTUBE_VIDEOS_ENDPOINT = "https://www.googleapis.com/youtube/v3/videos"
YOUTUBE_PARTS = "snippet,contentDetails,topicDetails"
# Partial-response selector: trims each item down to only the fields that
# normalize_video_item() actually maps into the output CSV.
YOUTUBE_FIELDS = (
    "items("
    "id,"
    "snippet(channelId,channelTitle,title,description,tags,categoryId,defaultLanguage,defaultAudioLanguage),"
    "contentDetails(duration,definition),"
    "topicDetails(topicCategories)"
    ")"
)
# videos.list accepts at most 50 IDs per request.
MAX_BATCH_SIZE = 50
# Env vars probed (in this order) for API keys when --api-key-env is not passed.
DEFAULT_API_KEY_ENV_NAMES = [
    "YOUTUBE_API_KEY",
    "GOOGLE_API_KEY",
    "GEMINI_KEY",
    "GEMINI_PROJECT2",
    "GEMINI_PROJECT3",
    "GEMINI_PROJECT4",
]
# Column order of the output CSV (csv.DictWriter fieldnames).
OUTPUT_FIELDS = [
    "video_id",
    "channel_id",
    "channel_title",
    "title",
    "description",
    "tags",
    "category_id",
    "default_language",
    "default_audio_language",
    "duration",
    "duration_seconds",
    "definition",
    "topic_categories",
    "crawl_timestamp_utc",
    "fetch_status",
    "error_detail",
    "raw_json",
]


@dataclass(frozen=True)
class ApiKey:
    """A single API key value and the environment variable it came from."""

    env_name: str
    value: str


@dataclass
class ApiKeyState:
    """Per-key rate limiter enforcing a minimum gap between request starts.

    ``next_request_at`` is a ``time.monotonic()`` timestamp of the earliest
    moment the next request for this key may begin.
    """

    api_key: ApiKey
    min_interval_seconds: float
    next_request_at: float = 0.0
    lock: asyncio.Lock = field(default_factory=asyncio.Lock)

    async def wait_for_slot(self) -> None:
        """Block until this key is allowed to start another request."""
        async with self.lock:
            now = time.monotonic()
            # Reserve the earliest slot that honors the minimum interval.
            scheduled_at = max(self.next_request_at, now)
            self.next_request_at = scheduled_at + self.min_interval_seconds
        # FIX: the original slept while still holding the lock, which blocked
        # every other coroutine from even reserving a slot for this key.
        # Reserving under the lock and sleeping outside it keeps the same
        # request spacing while letting waiters queue their slots concurrently.
        wait_seconds = scheduled_at - now
        if wait_seconds > 0:
            await asyncio.sleep(wait_seconds)


def parse_args() -> argparse.Namespace:
    """Define and parse this script's command-line options."""
    parser = argparse.ArgumentParser(description="Fetch YouTube metadata for a set of video IDs")
    # Declarative option table: (flag, add_argument keyword arguments).
    option_specs: list[tuple[str, dict[str, Any]]] = [
        (
            "--input-csv",
            {
                "default": "data/transcription_video_ids.csv",
                "help": "CSV containing video IDs (column: video_id/videoID, or first column if no header)",
            },
        ),
        (
            "--output-csv",
            {
                "default": "data/youtube_video_metadata.csv",
                "help": "Path for normalized metadata CSV output",
            },
        ),
        (
            "--video-id",
            {
                "action": "append",
                "default": [],
                "help": "Explicit video ID to fetch; can be passed multiple times",
            },
        ),
        (
            "--limit",
            {
                "type": int,
                "default": 0,
                "help": "Max number of video IDs to fetch after input ordering (0 = no limit)",
            },
        ),
        (
            "--batch-size",
            {
                "type": int,
                "default": 50,
                "help": "Number of video IDs per API request (YouTube API max is 50)",
            },
        ),
        (
            "--api-key-env",
            {
                "action": "append",
                "default": [],
                "help": "Env var name to read an API key from; can be passed multiple times",
            },
        ),
        (
            "--sleep-seconds",
            {
                "type": float,
                "default": 0.0,
                "help": "Legacy global sleep knob; prefer --per-key-delay-seconds",
            },
        ),
        (
            "--per-key-delay-seconds",
            {
                "type": float,
                "default": 0.2,
                "help": "Minimum delay between request starts for the same API key",
            },
        ),
        (
            "--max-concurrency",
            {
                "type": int,
                "default": 32,
                "help": "Maximum in-flight requests across all keys",
            },
        ),
        (
            "--progress-every",
            {
                "type": int,
                "default": 100,
                "help": "Print progress every N completed batches",
            },
        ),
        (
            "--overwrite",
            {
                "action": "store_true",
                "help": "Overwrite output CSV if it already exists",
            },
        ),
        (
            "--no-raw-json",
            {
                "action": "store_true",
                "help": "Skip the raw_json column payload to reduce output size",
            },
        ),
    ]
    for flag, options in option_specs:
        parser.add_argument(flag, **options)
    return parser.parse_args()


def load_env() -> None:
    """Load environment variables from the repository-root ``.env`` file."""
    repo_root = Path(__file__).resolve().parent.parent
    load_dotenv(repo_root / ".env")


def discover_api_keys(explicit_env_names: Sequence[str]) -> list[ApiKey]:
    """Collect unique, non-empty API keys from the given (or default) env vars.

    Duplicate key *values* are dropped so the same key never appears twice in
    the rotation. Exits with an error when no usable key is found.
    """
    env_names = list(explicit_env_names) or DEFAULT_API_KEY_ENV_NAMES
    discovered: list[ApiKey] = []
    seen_values: set[str] = set()
    for env_name in env_names:
        raw = os.getenv(env_name, "").strip()
        if raw and raw not in seen_values:
            discovered.append(ApiKey(env_name=env_name, value=raw))
            seen_values.add(raw)
    if discovered:
        return discovered
    joined = ", ".join(env_names)
    raise SystemExit(f"No API keys found. Checked env vars: {joined}")


def read_video_ids(input_csv: Path) -> list[str]:
    """Read unique video IDs (input order preserved) from *input_csv*.

    Uses the column located by detect_video_id_column() when the first row
    looks like a header; otherwise every row's first column is taken as the
    ID. Blank and duplicate IDs are skipped.

    FIX: opens with ``utf-8-sig`` so a UTF-8 BOM cannot break header
    detection (``"\\ufeffvideo_id"`` would not match) or pollute the first
    ID; plain UTF-8 files read identically.
    """
    if not input_csv.exists():
        raise SystemExit(f"Input CSV not found: {input_csv}")

    ids: list[str] = []
    seen: set[str] = set()
    with input_csv.open("r", encoding="utf-8-sig", newline="") as handle:
        reader = csv.reader(handle)
        try:
            first_row = next(reader)
        except StopIteration:
            return ids  # Empty file: no IDs to collect.

        header_index = detect_video_id_column(first_row)
        if header_index is not None:
            rows: Iterator[list[str]] = reader
        else:
            # No recognizable header: re-include the first row as data and
            # read IDs from column 0.
            rows = chain([first_row], reader)
            header_index = 0

        for row in rows:
            if header_index >= len(row):
                continue  # Short row that lacks the ID column.
            video_id = row[header_index].strip()
            if not video_id or video_id in seen:
                continue
            ids.append(video_id)
            seen.add(video_id)
    return ids


def detect_video_id_column(row: Sequence[str]) -> int | None:
    normalized = [value.strip().lower() for value in row]
    for name in ("video_id", "videoid", "id"):
        if name in normalized:
            return normalized.index(name)
    return None


def select_video_ids(args: argparse.Namespace) -> list[str]:
    """Resolve the final ID list: explicit --video-id flags beat the input CSV."""
    cleaned = [candidate.strip() for candidate in args.video_id]
    explicit_ids = [candidate for candidate in cleaned if candidate]
    selected = explicit_ids or read_video_ids(Path(args.input_csv))
    if args.limit > 0:
        selected = selected[: args.limit]
    if not selected:
        raise SystemExit("No video IDs selected")
    return selected


def chunked(values: Sequence[str], size: int) -> Iterator[list[str]]:
    """Yield consecutive slices of *values*, each at most *size* items long."""
    cursor = 0
    while cursor < len(values):
        yield list(values[cursor : cursor + size])
        cursor += size


async def fetch_video_batch(
    client: httpx.AsyncClient,
    video_ids: Sequence[str],
    key_states: Sequence[ApiKeyState],
    *,
    preferred_key_index: int,
    legacy_sleep_seconds: float,
) -> list[dict[str, Any]]:
    """Call videos.list for *video_ids*, rotating across keys until one succeeds.

    Starts at ``preferred_key_index`` and tries every key once; raises
    RuntimeError with all per-key failure details if none succeeds.
    """
    base_params = {
        "part": YOUTUBE_PARTS,
        "fields": YOUTUBE_FIELDS,
        "id": ",".join(video_ids),
    }
    total_keys = len(key_states)
    errors: list[str] = []
    for step in range(total_keys):
        state = key_states[(preferred_key_index + step) % total_keys]
        await state.wait_for_slot()
        try:
            response = await client.get(
                YOUTUBE_VIDEOS_ENDPOINT,
                params={**base_params, "key": state.api_key.value},
            )
        except httpx.HTTPError as exc:
            errors.append(f"{state.api_key.env_name}: transport error: {exc}")
            continue

        succeeded = response.status_code == 200
        if not succeeded:
            errors.append(
                f"{state.api_key.env_name}: {response.status_code} {extract_error_message(response)}"
            )
        # Legacy pacing knob: applied after every HTTP response, success or not.
        if legacy_sleep_seconds > 0:
            await asyncio.sleep(legacy_sleep_seconds)
        if succeeded:
            return response.json().get("items", [])

    raise RuntimeError("All API keys failed for batch: " + " | ".join(errors))


def extract_error_message(response: httpx.Response) -> str:
    """Pull a short human-readable error string out of an API error response."""
    try:
        payload = response.json()
    except ValueError:
        # Not JSON at all: fall back to a truncated slice of the raw body.
        return response.text[:200]
    error_obj = payload.get("error", {})
    message = error_obj.get("message") if isinstance(error_obj, dict) else None
    if message:
        return str(message)
    return json.dumps(payload, ensure_ascii=True)[:200]


def build_fetch_failure_message(exc: RuntimeError) -> str:
    """Expand a batch-failure error with remediation advice when keys are blocked."""
    message = str(exc)
    blocked_marker = "youtube.api.v3.V3DataVideoService.List are blocked"
    if blocked_marker not in message:
        return message
    return (
        f"{message}\n"
        "The configured Google API keys exist, but YouTube Data API access is blocked for them. "
        "Add a key with YouTube Data API v3 enabled and no API restriction blocking "
        "`youtube.v3.videos.list`, then rerun."
    )


def parse_iso8601_duration_seconds(duration: str) -> int | None:
    if not duration or not duration.startswith("P"):
        return None

    total_seconds = 0
    number = ""
    in_time_section = False
    for char in duration[1:]:
        if char == "T":
            in_time_section = True
            continue
        if char.isdigit():
            number += char
            continue
        if not number:
            continue
        value = int(number)
        number = ""
        if char == "D":
            total_seconds += value * 86_400
        elif char == "H":
            total_seconds += value * 3_600
        elif char == "M":
            total_seconds += value * (60 if in_time_section else 2_592_000)
        elif char == "S":
            total_seconds += value
    return total_seconds


def normalize_video_item(
    item: dict[str, Any],
    *,
    crawl_timestamp_utc: str,
    include_raw_json: bool,
) -> dict[str, str]:
    """Flatten one videos.list item into the CSV row schema (all-string values)."""
    snippet = item.get("snippet") or {}
    content_details = item.get("contentDetails") or {}
    topic_details = item.get("topicDetails") or {}

    def as_text(value: Any) -> str:
        # Missing/None/empty values all collapse to the empty string.
        return str(value or "")

    def as_json_list(value: Any) -> str:
        return json.dumps(value or [], ensure_ascii=True)

    duration = as_text(content_details.get("duration"))
    duration_seconds = parse_iso8601_duration_seconds(duration)
    return {
        "video_id": as_text(item.get("id")),
        "channel_id": as_text(snippet.get("channelId")),
        "channel_title": as_text(snippet.get("channelTitle")),
        "title": as_text(snippet.get("title")),
        "description": as_text(snippet.get("description")),
        "tags": as_json_list(snippet.get("tags")),
        "category_id": as_text(snippet.get("categoryId")),
        "default_language": as_text(snippet.get("defaultLanguage")),
        "default_audio_language": as_text(snippet.get("defaultAudioLanguage")),
        "duration": duration,
        "duration_seconds": "" if duration_seconds is None else str(duration_seconds),
        "definition": as_text(content_details.get("definition")),
        "topic_categories": as_json_list(topic_details.get("topicCategories")),
        "crawl_timestamp_utc": crawl_timestamp_utc,
        "fetch_status": "ok",
        "error_detail": "",
        "raw_json": json.dumps(item, ensure_ascii=True, sort_keys=True) if include_raw_json else "",
    }


def build_missing_row(video_id: str, *, crawl_timestamp_utc: str) -> dict[str, str]:
    """Produce an output row marking *video_id* as absent from the API response."""
    # Most columns are simply empty for an unresolved video.
    row = {
        name: ""
        for name in (
            "channel_id",
            "channel_title",
            "title",
            "description",
            "category_id",
            "default_language",
            "default_audio_language",
            "duration",
            "duration_seconds",
            "definition",
            "raw_json",
        )
    }
    row.update(
        video_id=video_id,
        tags="[]",
        topic_categories="[]",
        crawl_timestamp_utc=crawl_timestamp_utc,
        fetch_status="not_found",
        error_detail="Video not returned by YouTube videos.list",
    )
    return row


async def fetch_and_normalize_batch(
    batch_index: int,
    video_id_batch: Sequence[str],
    client: httpx.AsyncClient,
    key_states: Sequence[ApiKeyState],
    crawl_timestamp_utc: str,
    include_raw_json: bool,
    legacy_sleep_seconds: float,
) -> tuple[int, list[dict[str, str]]]:
    """Fetch one ID batch and produce an output row for every requested ID.

    IDs absent from the API response get a ``not_found`` placeholder row, so
    the batch index and row count always match the request.
    """
    items = await fetch_video_batch(
        client,
        video_id_batch,
        key_states,
        preferred_key_index=batch_index % len(key_states),
        legacy_sleep_seconds=legacy_sleep_seconds,
    )
    items_by_id = {str(entry.get("id") or ""): entry for entry in items}
    rows = [
        normalize_video_item(
            items_by_id[video_id],
            crawl_timestamp_utc=crawl_timestamp_utc,
            include_raw_json=include_raw_json,
        )
        if video_id in items_by_id
        else build_missing_row(video_id, crawl_timestamp_utc=crawl_timestamp_utc)
        for video_id in video_id_batch
    ]
    return batch_index, rows


async def _run_fetch_impl(args: argparse.Namespace) -> tuple[Path, int, int, int]:
    """Fetch metadata for every selected video ID and write the output CSV.

    Rows are written in input order via a staging ``.tmp`` file that is
    atomically renamed on success. Returns
    ``(output_path, total_rows, ok_rows, missing_rows)``. Raises SystemExit
    for configuration errors or when a batch exhausts all API keys.
    """
    load_env()

    output_csv = Path(args.output_csv)
    if output_csv.exists() and not args.overwrite:
        raise SystemExit(f"Output already exists: {output_csv} (pass --overwrite to replace it)")

    # Clamp to the documented videos.list maximum of MAX_BATCH_SIZE IDs.
    batch_size = min(max(args.batch_size, 1), MAX_BATCH_SIZE)
    if args.batch_size != batch_size:
        print(f"Adjusted batch size to {batch_size} (YouTube API max is {MAX_BATCH_SIZE})")

    api_keys = discover_api_keys(args.api_key_env)
    key_states = [
        ApiKeyState(api_key=api_key, min_interval_seconds=max(args.per_key_delay_seconds, 0.0))
        for api_key in api_keys
    ]
    selected_video_ids = select_video_ids(args)
    video_id_batches = list(chunked(selected_video_ids, batch_size))
    selected_video_count = len(selected_video_ids)
    crawl_timestamp_utc = datetime.now(timezone.utc).isoformat()
    output_csv.parent.mkdir(parents=True, exist_ok=True)
    temp_output_csv = output_csv.with_name(f"{output_csv.name}.tmp")
    if temp_output_csv.exists():
        temp_output_csv.unlink()
    row_count = 0
    ok_count = 0
    missing_count = 0
    try:
        with temp_output_csv.open("w", encoding="utf-8", newline="") as handle:
            writer = csv.DictWriter(handle, fieldnames=OUTPUT_FIELDS)
            writer.writeheader()
            # Batches finish out of order; buffer them and flush in order so
            # the CSV preserves the input ordering.
            pending_rows: dict[int, list[dict[str, str]]] = {}
            next_batch_to_write = 0
            semaphore = asyncio.Semaphore(max(args.max_concurrency, 1))

            async with httpx.AsyncClient(timeout=30.0) as client:
                async def guarded_fetch(batch_index: int, video_id_batch: Sequence[str]) -> tuple[int, list[dict[str, str]]]:
                    # Cap total in-flight requests across all keys.
                    async with semaphore:
                        return await fetch_and_normalize_batch(
                            batch_index,
                            video_id_batch,
                            client,
                            key_states,
                            crawl_timestamp_utc,
                            include_raw_json=not args.no_raw_json,
                            legacy_sleep_seconds=max(args.sleep_seconds, 0.0),
                        )

                tasks = [
                    asyncio.create_task(guarded_fetch(batch_index, video_id_batch))
                    for batch_index, video_id_batch in enumerate(video_id_batches)
                ]

                completed_batches = 0
                try:
                    for task in asyncio.as_completed(tasks):
                        batch_index, rows = await task
                        pending_rows[batch_index] = rows
                        completed_batches += 1
                        if args.progress_every > 0 and (
                            completed_batches % args.progress_every == 0
                            or completed_batches == len(video_id_batches)
                        ):
                            print(
                                f"Completed {completed_batches:,}/{len(video_id_batches):,} batches "
                                f"({min(completed_batches * batch_size, selected_video_count):,}/{selected_video_count:,} videos)"
                            )

                        # Flush every batch that is now contiguous with the
                        # last one written.
                        while next_batch_to_write in pending_rows:
                            ready_rows = pending_rows.pop(next_batch_to_write)
                            for row in ready_rows:
                                writer.writerow(row)
                                row_count += 1
                                if row["fetch_status"] == "ok":
                                    ok_count += 1
                                else:
                                    missing_count += 1
                            next_batch_to_write += 1
                except BaseException:
                    # FIX: a failing batch used to propagate while sibling
                    # tasks kept running against a client that was about to
                    # close, leaving orphaned tasks and "exception was never
                    # retrieved" noise. Cancel and drain them before re-raising.
                    for pending_task in tasks:
                        pending_task.cancel()
                    await asyncio.gather(*tasks, return_exceptions=True)
                    raise

        temp_output_csv.replace(output_csv)
        return output_csv, row_count, ok_count, missing_count
    except RuntimeError as exc:
        raise SystemExit(build_fetch_failure_message(exc)) from exc
    finally:
        # The staging file only still exists when we failed before the rename.
        if temp_output_csv.exists():
            temp_output_csv.unlink()


async def run_fetch(args: argparse.Namespace, *, timeout_seconds: float = 500.0) -> tuple[Path, int, int, int]:
    """Run the fetch pipeline under a wall-clock cap.

    Generalized over the original hard-coded 500-second limit: callers may
    now pass ``timeout_seconds``; the default preserves the old behavior
    (``{500.0:g}`` renders as ``500``, so the default message is unchanged).
    Raises SystemExit when the cap is exceeded.
    """
    try:
        return await asyncio.wait_for(_run_fetch_impl(args), timeout=timeout_seconds)
    except asyncio.TimeoutError as exc:
        raise SystemExit(
            f"Fetch exceeded the {timeout_seconds:g}-second cap before completion. "
            "Lower per-key delay, increase concurrency, or reduce output size and rerun."
        ) from exc

def main() -> None:
    """CLI entry point: run the fetch pipeline and print a short summary."""
    args = parse_args()
    output_csv, total_rows, resolved, missing = asyncio.run(run_fetch(args))
    print(f"Wrote {total_rows:,} rows to {output_csv}")
    print(f"Resolved videos: {resolved:,}")
    print(f"Missing videos: {missing:,}")


if __name__ == "__main__":
    main()
