"""Live Gemini test harness for the transcript variant prompt.

Runs batched generateContent calls (with an explicit context cache) over a
pool of transcript samples, then reports per-key determinism, schema errors,
and validation errors. Example invocation (adjust the script path to where
this file lives in the repo):

    python test_transcript_variant_prompt.py --key-indices 0,1 --runs 3
"""
from __future__ import annotations

import argparse
import asyncio
import hashlib
import json
import logging
import os
import sys
import time
from pathlib import Path
from typing import Any

import httpx
from dotenv import load_dotenv

ROOT = Path(__file__).resolve().parent.parent
sys.path.insert(0, str(ROOT))
load_dotenv(ROOT / ".env")

from src.config import EnvConfig, GEMINI_MODEL, LANGUAGE_MAP, THINKING_LEVEL
from src.transcript_variant_prompt import (
    InputScriptProfile,
    TranscriptVariantBatchResult,
    build_transcript_variant_user_prompt,
    classify_input_script,
    detect_script_counts,
    extract_protected_spans,
    get_cacheable_transcript_variant_prompt,
    get_target_script_block,
    get_transcript_variant_json_schema,
    romanized_text_is_ascii,
)


logging.basicConfig(level=logging.INFO, format="%(asctime)s [%(levelname)s] %(name)s: %(message)s")
logger = logging.getLogger("variant_prompt_test")

AISTUDIO_BASE = "https://generativelanguage.googleapis.com/v1beta"
DEFAULT_OUTPUT = ROOT / "final_data" / "transcript_variant_prompt_test_results.json"
DEFAULT_SAMPLE_FILE = ROOT / "data" / "transcript_variant_prompt_samples.json"


def write_json(path: Path, payload: dict[str, Any]) -> None:
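    """Write payload as pretty-printed, key-sorted JSON, creating parent dirs."""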
    path.parent.mkdir(parents=True, exist_ok=True)
    path.write_text(json.dumps(payload, indent=2, ensure_ascii=False, sort_keys=True) + "\n")


def get_cache_display_name() -> str:
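    """Derive the cache display name from a hash of the cacheable prompt.

    A prompt change produces a new name, so stale caches are never matched.
    """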
    prompt_hash = hashlib.sha1(
        get_cacheable_transcript_variant_prompt().encode("utf-8")
    ).hexdigest()[:12]
    return f"transcript-variant-{prompt_hash}"


def chunked(items: list[dict[str, Any]], size: int) -> list[list[dict[str, Any]]]:
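    """Split items into consecutive chunks of at most ``size`` elements."""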
    return [items[idx : idx + size] for idx in range(0, len(items), size)]


def load_samples(path: Path) -> list[dict[str, Any]]:
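    """Load the sample pool from a JSON file holding a top-level array."""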
    data = json.loads(path.read_text())
    if not isinstance(data, list):
        raise ValueError(f"Sample file must contain a JSON array: {path}")
    return data


def expand_samples(sample_pool: list[dict[str, Any]], target_count: int) -> list[dict[str, Any]]:
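    """Cycle through the sample pool until target_count items exist.

    Each copy keeps the original id in ``source_id`` and gets a unique
    numeric suffix on ``id``. A non-positive target returns the pool as-is.
    """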
    if target_count <= 0:
        return sample_pool
    if not sample_pool:
        raise ValueError("Sample pool is empty; nothing to expand.")
    expanded: list[dict[str, Any]] = []
    for idx in range(target_count):
        base = dict(sample_pool[idx % len(sample_pool)])
        base["source_id"] = base["id"]
        base["id"] = f"{base['id']}__{idx:03d}"
        expanded.append(base)
    return expanded


def apply_shard(items: list[dict[str, Any]], shard_index: int, num_shards: int) -> list[dict[str, Any]]:
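    """Round-robin shard: keep items whose index is shard_index mod num_shards."""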
    if num_shards <= 1:
        return items
    if not 0 <= shard_index < num_shards:
        raise ValueError(f"shard_index {shard_index} out of range for {num_shards} shards")
    return [item for idx, item in enumerate(items) if idx % num_shards == shard_index]


def prepare_items(samples: list[dict[str, Any]]) -> tuple[list[dict[str, Any]], list[dict[str, Any]], dict[str, int]]:
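    """Split samples into local skips and model-bound items, tallying profiles.

    Fully Roman inputs are resolved locally with empty variant fields and a
    skip reason; everything else is forwarded to the model.
    """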
    local_skips: list[dict[str, Any]] = []
    model_items: list[dict[str, Any]] = []
    profile_counts = {profile.value: 0 for profile in InputScriptProfile}

    for sample in samples:
        profile = classify_input_script(sample["text"], sample["language_code"])
        profile_counts[profile.value] += 1
        prepared = {
            "id": sample["id"],
            "source_id": sample.get("source_id", sample["id"]),
            "language_code": sample["language_code"],
            "text": sample["text"],
            "input_script_profile": profile.value,
        }
        if profile == InputScriptProfile.fully_roman:
            local_skips.append(
                {
                    "id": sample["id"],
                    "language_code": sample["language_code"],
                    "input_script_profile": profile.value,
                    "native_script_text": "",
                    "romanized_text": "",
                    "skip_reason": "fully_roman_local_skip",
                }
            )
            continue
        model_items.append(prepared)

    return local_skips, model_items, profile_counts


def get_key_map(config: EnvConfig) -> dict[int, str]:
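    """Map positional indices to the Gemini API keys from the environment config."""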
    return {idx: key for idx, key in enumerate(config.gemini_keys)}


class TranscriptVariantCacheManager:
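    """Creates or reuses the explicit Gemini context cache for the system prompt.

    Caches are matched on model name plus the hash-based display name; only
    the first page of the cachedContents listing is inspected.
    """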
    def __init__(self, api_key: str):
        self.api_key = api_key

    async def ensure_cache(self, ttl_s: int) -> dict[str, Any]:
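        """Return an existing matching cache, or create one with the given TTL."""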
        existing = await self._find_existing_cache()
        if existing:
            return existing
        return await self._create_cache(ttl_s)

    async def _find_existing_cache(self) -> dict[str, Any] | None:
        url = f"{AISTUDIO_BASE}/cachedContents?key={self.api_key}"
        display_name = get_cache_display_name()
        async with httpx.AsyncClient(timeout=30.0) as client:
            resp = await client.get(url)
            if resp.status_code != 200:
                logger.warning("Could not list caches: %s %s", resp.status_code, resp.text[:200])
                return None
            model_name = f"models/{GEMINI_MODEL}"
            for cache in resp.json().get("cachedContents", []):
                if cache.get("model") == model_name and cache.get("displayName") == display_name:
                    detailed = await self._get_cache(cache["name"])
                    if detailed:
                        return detailed
                    return cache
        return None

    async def _get_cache(self, cache_name: str) -> dict[str, Any] | None:
        url = f"{AISTUDIO_BASE}/{cache_name}?key={self.api_key}"
        async with httpx.AsyncClient(timeout=30.0) as client:
            resp = await client.get(url)
            if resp.status_code == 200:
                return resp.json()
        return None

    async def _create_cache(self, ttl_s: int) -> dict[str, Any]:
        url = f"{AISTUDIO_BASE}/cachedContents?key={self.api_key}"
        body = {
            "model": f"models/{GEMINI_MODEL}",
            "displayName": get_cache_display_name(),
            "systemInstruction": {"parts": [{"text": get_cacheable_transcript_variant_prompt()}]},
            "ttl": f"{ttl_s}s",
        }
        async with httpx.AsyncClient(timeout=60.0) as client:
            resp = await client.post(url, json=body)
            if resp.status_code != 200:
                raise RuntimeError(f"Cache creation failed: {resp.status_code} {resp.text[:500]}")
            return resp.json()


class TranscriptVariantClient:
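    """Thin generateContent wrapper: temperature 0, JSON-schema output.

    Uses the explicit context cache when one is available; otherwise inlines
    the cacheable prompt as the system instruction.
    """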
    def __init__(self, api_key: str, cache_name: str | None):
        self.api_key = api_key
        self.cache_name = cache_name
        self.schema = get_transcript_variant_json_schema()

    async def generate_batch(self, items: list[dict[str, Any]]) -> dict[str, Any]:
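        """Send one batch through generateContent and return the raw response JSON."""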
        url = f"{AISTUDIO_BASE}/models/{GEMINI_MODEL}:generateContent?key={self.api_key}"
        body: dict[str, Any] = {
            "contents": [
                {
                    "role": "user",
                    "parts": [{"text": build_transcript_variant_user_prompt(items)}],
                }
            ],
            "generationConfig": {
                "temperature": 0,
                "responseMimeType": "application/json",
                "responseJsonSchema": self.schema,
                "thinkingConfig": {
                    "thinkingLevel": THINKING_LEVEL.upper(),
                },
            },
        }
        if self.cache_name:
            body["cachedContent"] = self.cache_name
        else:
            body["systemInstruction"] = {"parts": [{"text": get_cacheable_transcript_variant_prompt()}]}

        async with httpx.AsyncClient(timeout=httpx.Timeout(connect=10.0, read=60.0, write=60.0, pool=10.0)) as client:
            resp = await client.post(url, json=body)
            if resp.status_code != 200:
                raise RuntimeError(f"Generate failed: {resp.status_code} {resp.text[:500]}")
            return resp.json()


def _extract_response_text(response_json: dict[str, Any]) -> str:
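    """Return the first text part of the first candidate, raising if absent."""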
    candidates = response_json.get("candidates", [])
    if not candidates:
        raise ValueError("No candidates in Gemini response")
    for part in candidates[0].get("content", {}).get("parts", []):
        if "text" in part:
            return part["text"]
    raise ValueError("No text part in Gemini response")


def validate_item(expected: dict[str, Any], actual: dict[str, Any]) -> list[str]:
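    """Validate one model output against its input item.

    Checks id echo, non-empty variant texts, preservation of protected spans
    in both variants, ASCII-only romanization, and presence/absence of the
    target script in the right fields. Returns human-readable error strings.
    """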
    errors: list[str] = []
    if actual["id"] != expected["id"]:
        errors.append(f"{expected['id']}: id mismatch ({actual['id']})")
    if not actual["native_script_text"].strip():
        errors.append(f"{expected['id']}: empty native_script_text")
    if not actual["romanized_text"].strip():
        errors.append(f"{expected['id']}: empty romanized_text")

    for span in extract_protected_spans(expected["text"]):
        if span not in actual["native_script_text"]:
            errors.append(f"{expected['id']}: protected span missing in native_script_text -> {span}")
        if span not in actual["romanized_text"]:
            errors.append(f"{expected['id']}: protected span missing in romanized_text -> {span}")

    if not romanized_text_is_ascii(actual["romanized_text"]):
        errors.append(f"{expected['id']}: romanized_text contains non-ASCII outside protected spans")

    target_script = get_target_script_block(expected["language_code"])
    native_counts = detect_script_counts(actual["native_script_text"])
    roman_counts = detect_script_counts(actual["romanized_text"])
    if target_script != "Latin" and native_counts.get(target_script, 0) == 0:
        errors.append(f"{expected['id']}: native_script_text missing target script characters")
    if target_script != "Latin" and roman_counts.get(target_script, 0) > 0:
        errors.append(f"{expected['id']}: romanized_text still contains target script characters")

    return errors


def compare_runs(run_outputs: list[dict[str, dict[str, Any]]], item_ids: list[str]) -> tuple[int, list[dict[str, Any]]]:
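    """Count items whose serialized output is identical across every run.

    Returns the deterministic count and diff records (first three runs shown)
    for the items that differ.
    """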
    deterministic = 0
    diffs: list[dict[str, Any]] = []
    for item_id in item_ids:
        serialized = [
            json.dumps(run.get(item_id, {}), ensure_ascii=False, sort_keys=True)
            for run in run_outputs
        ]
        if all(value == serialized[0] for value in serialized):
            deterministic += 1
            continue
        diffs.append(
            {
                "id": item_id,
                "runs": [run.get(item_id, {}) for run in run_outputs[:3]],
            }
        )
    return deterministic, diffs


async def run_for_key(
    *,
    key_index: int,
    api_key: str,
    args: argparse.Namespace,
    samples: list[dict[str, Any]],
    checkpoint_payload: dict[str, Any],
) -> dict[str, Any]:
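    """Run all passes for one API key and return its summary payload.

    Ensures the prompt cache exists, then for each run sends every batch,
    checkpoints progress to the output file, validates each item, and finally
    measures determinism across runs (all items and model-only items).
    """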
    local_skips, model_items, profile_counts = prepare_items(samples)
    item_ids = [sample["id"] for sample in samples]

    cache_manager = TranscriptVariantCacheManager(api_key)
    cache_info = await cache_manager.ensure_cache(args.ttl_s)
    cache_name = cache_info["name"]
    cached_prompt_tokens = cache_info.get("usageMetadata", {}).get("totalTokenCount", 0)
    client = TranscriptVariantClient(api_key=api_key, cache_name=cache_name)

    logger.info(
        "[key %s] samples=%s model_items=%s local_skips=%s cached_prompt_tokens=%s",
        key_index,
        len(samples),
        len(model_items),
        len(local_skips),
        cached_prompt_tokens,
    )

    run_outputs: list[dict[str, dict[str, Any]]] = []
    run_summaries: list[dict[str, Any]] = []
    validation_errors_all: list[str] = []
    schema_errors_all: list[str] = []

    for run_idx in range(args.runs):
        logger.info("[key %s] starting run %s/%s", key_index, run_idx + 1, args.runs)
        run_start = time.monotonic()
        merged_outputs: dict[str, dict[str, Any]] = {
            item["id"]: item for item in local_skips
        }
        usage_rows: list[dict[str, Any]] = []
        run_validation_errors: list[str] = []
        run_schema_errors: list[str] = []

        batches = chunked(model_items, args.batch_size)
        for batch_idx, batch in enumerate(batches, start=1):
            response_json = await client.generate_batch(batch)
            usage = response_json.get("usageMetadata", {})
            usage_rows.append(
                {
                    "batch": batch_idx,
                    "prompt_tokens": usage.get("promptTokenCount", 0),
                    "cached_tokens": usage.get("cachedContentTokenCount", 0),
                    "output_tokens": usage.get("candidatesTokenCount", 0),
                }
            )

            response_text = _extract_response_text(response_json)
            try:
                parsed = TranscriptVariantBatchResult.model_validate_json(response_text)
            except Exception as exc:
                run_schema_errors.append(f"run {run_idx + 1} batch {batch_idx}: {exc}")
                continue

            if len(parsed.results) != len(batch):
                run_schema_errors.append(
                    f"run {run_idx + 1} batch {batch_idx}: expected {len(batch)} results, got {len(parsed.results)}"
                )
                continue

            for expected, actual in zip(batch, parsed.results):
                actual_dict = actual.model_dump()
                errors = validate_item(expected, actual_dict)
                if errors:
                    run_validation_errors.extend(errors)
                merged_outputs[expected["id"]] = actual_dict

            checkpoint_payload.setdefault("per_key_progress", {})[str(key_index)] = {
                "current_run": run_idx + 1,
                "current_batch": batch_idx,
                "total_batches": len(batches),
                "completed_outputs": len(merged_outputs),
            }
            write_json(args.output, checkpoint_payload)

        elapsed_s = round(time.monotonic() - run_start, 2)
        run_summaries.append(
            {
                "run": run_idx + 1,
                "elapsed_s": elapsed_s,
                "outputs": len(merged_outputs),
                "schema_errors": len(run_schema_errors),
                "validation_errors": len(run_validation_errors),
                "usage": usage_rows,
            }
        )
        run_outputs.append(merged_outputs)
        schema_errors_all.extend(run_schema_errors)
        validation_errors_all.extend(run_validation_errors)
        logger.info(
            "[key %s] run %s done in %ss (schema_errors=%s validation_errors=%s)",
            key_index,
            run_idx + 1,
            elapsed_s,
            len(run_schema_errors),
            len(run_validation_errors),
        )

    deterministic_all, diffs_all = compare_runs(run_outputs, item_ids)
    deterministic_model, diffs_model = compare_runs(run_outputs, [item["id"] for item in model_items])

    return {
        "key_index": key_index,
        "cache_name": cache_name,
        "cached_prompt_tokens": cached_prompt_tokens,
        "cache_threshold_ok": cached_prompt_tokens >= 1024,
        "sample_count": len(samples),
        "model_item_count": len(model_items),
        "local_skip_count": len(local_skips),
        "profile_counts": profile_counts,
        "run_summaries": run_summaries,
        "determinism_all_items": {
            "deterministic": deterministic_all,
            "total": len(item_ids),
            "pct": round((deterministic_all / max(len(item_ids), 1)) * 100, 2),
            "diff_examples": diffs_all[:5],
        },
        "determinism_model_items": {
            "deterministic": deterministic_model,
            "total": len(model_items),
            "pct": round((deterministic_model / max(len(model_items), 1)) * 100, 2),
            "diff_examples": diffs_model[:5],
        },
        "schema_error_count": len(schema_errors_all),
        "validation_error_count": len(validation_errors_all),
        "schema_errors": schema_errors_all[:20],
        "validation_errors": validation_errors_all[:20],
        "sample_outputs": [run_outputs[-1].get(item_id, {}) for item_id in item_ids[:5]],
    }


def compare_keys(key_results: list[dict[str, Any]]) -> dict[str, Any]:
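    """Compare the stored sample outputs of the first key against each other key."""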
    if len(key_results) < 2:
        return {}
    reference = key_results[0]["sample_outputs"]
    comparison_rows = []
    for other in key_results[1:]:
        same = 0
        total = min(len(reference), len(other["sample_outputs"]))
        for left, right in zip(reference, other["sample_outputs"]):
            if json.dumps(left, ensure_ascii=False, sort_keys=True) == json.dumps(right, ensure_ascii=False, sort_keys=True):
                same += 1
        comparison_rows.append(
            {
                "left_key_index": key_results[0]["key_index"],
                "right_key_index": other["key_index"],
                "exact_match_on_sample_outputs": same,
                "sample_size": total,
                "pct": round((same / max(total, 1)) * 100, 2),
            }
        )
    return {"sample_output_cross_key": comparison_rows}


def parse_args() -> argparse.Namespace:
    parser = argparse.ArgumentParser(description="Live Gemini test for transcript variant prompt.")
    parser.add_argument("--sample-file", type=Path, default=DEFAULT_SAMPLE_FILE)
    parser.add_argument("--sample-target", type=int, default=100)
    parser.add_argument("--runs", type=int, default=5)
    parser.add_argument("--batch-size", type=int, default=10)
    parser.add_argument("--key-indices", type=str, default="0")
    parser.add_argument("--num-shards", type=int, default=1)
    parser.add_argument("--shard-index", type=int, default=0)
    parser.add_argument("--ttl-s", type=int, default=518400)
    parser.add_argument("--output", type=Path, default=DEFAULT_OUTPUT)
    return parser.parse_args()


async def main() -> None:
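    """Parse args, run every requested key sequentially, and write final results."""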
    args = parse_args()
    if args.runs < 1:
        raise ValueError("--runs must be at least 1.")
    config = EnvConfig()
    key_map = get_key_map(config)
    requested_key_indices = [int(part.strip()) for part in args.key_indices.split(",") if part.strip()]

    for key_index in requested_key_indices:
        if key_index not in key_map:
            raise ValueError(f"Key index {key_index} not available. Have: {sorted(key_map)}")

    sample_pool = load_samples(args.sample_file)
    expanded_samples = expand_samples(sample_pool, args.sample_target)
    sharded_samples = apply_shard(expanded_samples, args.shard_index, args.num_shards)

    if not sharded_samples:
        raise ValueError("No samples left after sharding.")

    checkpoint_payload: dict[str, Any] = {
        "started_at_epoch_s": round(time.time(), 3),
        "model": GEMINI_MODEL,
        "sample_file": str(args.sample_file),
        "sample_pool_size": len(sample_pool),
        "sample_target": args.sample_target,
        "shard_index": args.shard_index,
        "num_shards": args.num_shards,
        "runs": args.runs,
        "batch_size": args.batch_size,
        "requested_key_indices": requested_key_indices,
    }
    write_json(args.output, checkpoint_payload)

    key_results: list[dict[str, Any]] = []
    for key_index in requested_key_indices:
        result = await run_for_key(
            key_index=key_index,
            api_key=key_map[key_index],
            args=args,
            samples=sharded_samples,
            checkpoint_payload=checkpoint_payload,
        )
        key_results.append(result)
        checkpoint_payload.setdefault("completed_key_results", []).append(
            {
                "key_index": key_index,
                "cached_prompt_tokens": result["cached_prompt_tokens"],
                "determinism_model_items_pct": result["determinism_model_items"]["pct"],
                "validation_error_count": result["validation_error_count"],
                "schema_error_count": result["schema_error_count"],
            }
        )
        write_json(args.output, checkpoint_payload)

    final_payload = {
        **checkpoint_payload,
        "finished_at_epoch_s": round(time.time(), 3),
        "key_results": key_results,
        "cross_key_summary": compare_keys(key_results),
        "language_distribution": {
            code: sum(1 for sample in sharded_samples if sample["language_code"] == code)
            for code in sorted({sample["language_code"] for sample in sharded_samples})
        },
    }
    write_json(args.output, final_payload)

    logger.info("Saved results -> %s", args.output)
    for result in key_results:
        logger.info(
            "[key %s] cache_tokens=%s model_determinism=%s%% validation_errors=%s schema_errors=%s",
            result["key_index"],
            result["cached_prompt_tokens"],
            result["determinism_model_items"]["pct"],
            result["validation_error_count"],
            result["schema_error_count"],
        )


if __name__ == "__main__":
    asyncio.run(main())
