"""Reroute the English-mixed transcript subset through Gemini for native/romanized variants."""
from __future__ import annotations

import argparse
import asyncio
import json
import os
import sys
import time
from pathlib import Path

import httpx
import pandas as pd
from dotenv import load_dotenv

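# Make the repo root importable so "src" resolves when this file is run directly.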
ROOT = Path(__file__).resolve().parent.parent
if str(ROOT) not in sys.path:
    sys.path.insert(0, str(ROOT))

from src.transcript_variant_prompt import (
    build_transcript_variant_user_prompt,
    classify_input_script,
    get_cacheable_transcript_variant_prompt,
    get_transcript_variant_json_schema,
)
from src.variant_provider import TranscriptVariantCacheManager


def _fail_record(item: dict, error: str) -> dict:
    """Build one failure row, truncating the source text so parquet rows stay bounded."""
    return {
        "row_id": item["id"],
        "video_id": item["video_id"],
        "segment_id": item["segment_id"],
        "language_code": item["language_code"],
        "text": item["text"][:500],
        "error": error,
    }


def parse_args() -> argparse.Namespace:
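    """Parse CLI arguments for one reroute part."""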
    parser = argparse.ArgumentParser(description="Reroute English-mixed subset through Gemini")
    parser.add_argument("--input", type=Path, required=True)
    parser.add_argument("--key-env", required=True, help="Env var name for the API key")
    parser.add_argument("--part-name", required=True)
    parser.add_argument("--concurrency", type=int, default=500)
    parser.add_argument("--batch-size", type=int, default=10)
    parser.add_argument("--output-dir", type=Path, default=Path("final_data/english_mixed_reroute/results"))
    return parser.parse_args()


async def main_async(args: argparse.Namespace) -> None:
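    """Process one input parquet through Gemini; write recovered, failed, and summary outputs."""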
    load_dotenv(Path(".env"))
    api_key = os.environ.get(args.key_env)
    if not api_key:
        raise SystemExit(f"{args.key_env} is not set; export the API key before running")
    args.output_dir.mkdir(parents=True, exist_ok=True)

    out_ok = args.output_dir / f"{args.part_name}_recovered.parquet"
    out_fail = args.output_dir / f"{args.part_name}_failed.parquet"
    out_meta = args.output_dir / f"{args.part_name}_summary.json"

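    # Idempotent resume: skip this part if all three outputs already exist.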
    if out_ok.exists() and out_fail.exists() and out_meta.exists():
        print(f"{args.part_name}: outputs already exist, skipping")
        return

    df = pd.read_parquet(args.input)
    print(f"{args.part_name}: loaded {len(df):,} rows from {args.input}")

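    # Build one request item per dataframe row, tagging each with its input-script profile.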
    items = []
    for row in df.to_dict(orient="records"):
        items.append(
            {
                "id": row["row_id"],
                "video_id": row["video_id"],
                "segment_id": row["segment_id"],
                "language_code": row["target_language_code"],
                "input_script_profile": classify_input_script(
                    row["text"], row["target_language_code"]
                ).value,
                "text": row["text"],
            }
        )

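    # Ensure the shared server-side prompt cache exists; 518400 s == 6 days.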
    cache = await TranscriptVariantCacheManager(api_key).ensure_cache(518400)
    cache_name = cache["name"]
    schema = get_transcript_variant_json_schema()
    prompt = get_cacheable_transcript_variant_prompt()

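    # Chunk items into fixed-size batches; each batch is one generateContent call.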
    batches = [
        items[i : i + args.batch_size] for i in range(0, len(items), args.batch_size)
    ]
    print(
        f"{args.part_name}: {len(batches):,} batches, concurrency={args.concurrency}, "
        f"cache={cache_name}"
    )

    recovered_rows: list[dict] = []
    failed_rows: list[dict] = []

    async def run_batch(client: httpx.AsyncClient, batch: list[dict]) -> tuple[list[dict], list[dict]]:
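        """POST one batch to Gemini; return (recovered_rows, failed_rows) for its items."""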
        # Send the key as a header instead of a query parameter so it never shows
        # up in logged URLs.
        url = (
            "https://generativelanguage.googleapis.com/v1beta/models/"
            "gemini-3-flash-preview:generateContent"
        )
        body = {
            "contents": [
                {
                    "role": "user",
                    "parts": [{"text": build_transcript_variant_user_prompt(batch)}],
                }
            ],
            "cachedContent": cache_name,
            "generationConfig": {
                "temperature": 0,
                "responseMimeType": "application/json",
                "responseJsonSchema": schema,
                "thinkingConfig": {"thinkingLevel": "LOW"},
            },
        }

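        # Up to 4 attempts: rate limits (429) and transient 5xx errors back off linearly.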
        for attempt in range(4):
            try:
                resp = await client.post(
                    url, json=body, headers={"x-goog-api-key": api_key}
                )
                if resp.status_code == 429:
                    await asyncio.sleep(3 * (attempt + 1))
                    continue
                if resp.status_code >= 500:
                    await asyncio.sleep(2 * (attempt + 1))
                    continue
                if resp.status_code != 200:
                    return [], [
                        _fail_record(item, f"HTTP {resp.status_code}") for item in batch
                    ]

                data = resp.json()
                candidates = data.get("candidates", [])
                if not candidates:
                    blocked = json.dumps(data.get("promptFeedback", {}))[:200]
                    return [], [
                        _fail_record(item, f"BLOCKED:{blocked}") for item in batch
                    ]

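                # Extract the first text part from the top candidate.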
                text = ""
                for part in candidates[0].get("content", {}).get("parts", []):
                    if "text" in part:
                        text = part["text"]
                        break
                if not text:
                    await asyncio.sleep(1)
                    continue

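                # The model returns {"results": [...]}; index by id to catch rows it skipped.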
                parsed = json.loads(text)
                results = parsed.get("results", [])
                result_map = {r["id"]: r for r in results}

                ok_rows: list[dict] = []
                fail_rows: list[dict] = []
                for item in batch:
                    r = result_map.get(item["id"])
                    if r is None:
                        fail_rows.append(_fail_record(item, "missing_from_response"))
                        continue
                    ok_rows.append(
                        {
                            "row_id": item["id"],
                            "video_id": item["video_id"],
                            "segment_id": item["segment_id"],
                            "language_code": item["language_code"],
                            "input_script_profile": item["input_script_profile"],
                            "processing_route": "gemini_en_mixed_reroute",
                            "native_script_text": r.get("native_script_text", ""),
                            "romanized_text": r.get("romanized_text", ""),
                            "validation_errors": "[]",
                        }
                    )
                return ok_rows, fail_rows
            except Exception as exc:
                if attempt < 3:
                    await asyncio.sleep(2 * (attempt + 1))
                    continue
                return [], [_fail_record(item, str(exc)[:200]) for item in batch]

        # Retries exhausted without a definitive response: report the whole batch
        # as failed instead of silently dropping it.
        return [], [_fail_record(item, "retries_exhausted") for item in batch]

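    # The semaphore caps in-flight requests at --concurrency even though all tasks start at once.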
    sem = asyncio.Semaphore(args.concurrency)
    start = time.time()
    completed = 0

    async with httpx.AsyncClient(
        timeout=httpx.Timeout(connect=10, read=60, write=30, pool=20),
        limits=httpx.Limits(
            max_connections=max(args.concurrency * 2, 1000),
            max_keepalive_connections=max(args.concurrency, 500),
        ),
    ) as client:

        async def bounded(batch: list[dict]) -> tuple[list[dict], list[dict]]:
            async with sem:
                return await run_batch(client, batch)

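        # Launch every batch up front; consume in completion order to report live progress.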
        tasks = [asyncio.create_task(bounded(batch)) for batch in batches]
        for coro in asyncio.as_completed(tasks):
            ok, fail = await coro
            recovered_rows.extend(ok)
            failed_rows.extend(fail)
            completed += 1
            if completed % 200 == 0 or completed == len(batches):
                elapsed = time.time() - start
                rpm = completed / max(elapsed / 60.0, 0.001)
                print(
                    f"{args.part_name}: {completed:,}/{len(batches):,} batches | "
                    f"{len(recovered_rows):,} ok | {len(failed_rows):,} fail | "
                    f"{rpm:.0f} batch-rpm"
                )

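    # Persist results; the summary JSON is written last, so its presence marks a finished part.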
    pd.DataFrame(recovered_rows).to_parquet(out_ok, index=False)
    pd.DataFrame(failed_rows).to_parquet(out_fail, index=False)
    summary = {
        "part_name": args.part_name,
        "input_rows": len(items),
        "batches": len(batches),
        "recovered_rows": len(recovered_rows),
        "failed_rows": len(failed_rows),
        "concurrency": args.concurrency,
        "batch_size": args.batch_size,
        "key_env": args.key_env,
        "elapsed_seconds": round(time.time() - start, 1),
        "cache_name": cache_name,
    }
    out_meta.write_text(json.dumps(summary, indent=2))
    print(json.dumps(summary, indent=2))


def main() -> None:
    args = parse_args()
    asyncio.run(main_async(args))


if __name__ == "__main__":
    main()
