"""
Quick 10-sample determinism test for GEMINI_PROJECT2 and GEMINI_PROJECT3.
Reuses canary_data segments (Telugu + Hindi, 5 each).
3 runs per key, compares exact transcript match.
"""
import asyncio
import json
import logging
import os
import sys
import time
from collections import defaultdict
from pathlib import Path

# Make the repo root importable so `src.*` packages resolve when this script
# is run directly from its own directory.
sys.path.insert(0, str(Path(__file__).resolve().parent.parent))
# Default MOCK_MODE off — presumably this makes the providers issue real API
# calls instead of mocked ones; an existing env value still wins. TODO confirm.
os.environ.setdefault("MOCK_MODE", "false")

from dotenv import load_dotenv
# Pull API keys (GEMINI_PROJECT2/3) and other settings from the repo-root .env.
load_dotenv(Path(__file__).resolve().parent.parent / ".env")

from src.audio_polish import polish_all_segments
from src.providers.aistudio import AIStudioProvider
from src.providers.base import TranscriptionRequest, RequestStatus
from src.cache_manager import CacheManager

logging.basicConfig(level=logging.INFO, format="%(asctime)s [%(levelname)s] %(name)s: %(message)s")
logger = logging.getLogger("key_test")

# Pre-downloaded canary audio lives beside this script; results are written
# next to it as JSON.
CANARY_DATA = Path(__file__).parent / "canary_data"
RESULTS_DIR = Path(__file__).parent / "key_test_results"

# Language code -> video id of the canary clip whose segments we reuse
# (presumably YouTube ids, downloaded by canary_test.py — see main()).
TEST_LANGS = {
    "te": "5MDwLvubhR8",
    "hi": "BFU9eSKO_t4",
}
SEGMENTS_PER_LANG = 5   # segments kept per language after polishing
DETERMINISM_RUNS = 3    # identical runs compared for exact transcript match


def load_segments(lang_code: str, video_id: str) -> list[TranscriptionRequest]:
    """Build up to SEGMENTS_PER_LANG transcription requests from canary audio.

    Reads pre-downloaded FLAC segments for the given language/video, runs
    them through the polishing step, drops discarded segments, and wraps the
    survivors in TranscriptionRequest objects.

    Args:
        lang_code: Language code (e.g. "te", "hi"); also used as the request
            language.
        video_id: Video identifier naming the canary_data subdirectory.

    Returns:
        A list of at most SEGMENTS_PER_LANG requests; empty if the segment
        directory or its FLAC files are missing.
    """
    segments_path = CANARY_DATA / lang_code / video_id / video_id / "segments"
    if not segments_path.exists():
        logger.error(f"No segments at {segments_path}")
        return []

    # Take a few extra files so discards during polishing still leave enough.
    flac_files = sorted(segments_path.glob("*.flac"))[:SEGMENTS_PER_LANG + 3]
    if not flac_files:
        return []

    kept = [
        seg
        for seg in polish_all_segments(flac_files)
        if not seg.trim_meta.discarded
    ][:SEGMENTS_PER_LANG]

    out: list[TranscriptionRequest] = []
    for seg in kept:
        meta = seg.trim_meta
        # Split segments get a suffixed id so each piece is tracked separately.
        seg_id = (
            f"{meta.original_file}_split{meta.split_index}"
            if meta.was_split
            else meta.original_file
        )
        out.append(TranscriptionRequest(
            segment_id=seg_id,
            audio_base64=seg.base64_audio,
            language_code=lang_code,
            original_file=meta.original_file,
        ))
    return out


async def test_key(key_name: str, api_key: str, all_requests: dict[str, list]) -> dict:
    """Run the determinism test for a single API key.

    Sends every segment in ``all_requests`` through the provider
    DETERMINISM_RUNS times and compares transcripts across runs. A segment
    counts as deterministic only when all runs succeeded AND produced
    byte-identical transcription text.

    Args:
        key_name: Label used in logs and in the returned summary.
        api_key: The API key under test.
        all_requests: Mapping of language code -> list of TranscriptionRequest.

    Returns:
        Summary dict with overall and per-language determinism counts, up to
        10 example diffs, and per-run timing summaries.
    """
    logger.info(f"\n{'='*60}")
    logger.info(f"Testing {key_name} ({api_key[:10]}...)")
    logger.info(f"{'='*60}")

    # Set up a content cache for this key; if that fails, continue uncached —
    # the determinism comparison itself does not depend on caching.
    cm = CacheManager(api_key)
    try:
        cache_name = await cm.ensure_cache()
        logger.info(f"  Cache: {cache_name}")
    except Exception as e:
        logger.warning(f"  Cache setup failed: {e}, running uncached")
        cache_name = None

    provider = AIStudioProvider(api_key=api_key, cached_content_name=cache_name)

    # "lang:segment_id" -> one transcription_data dict per successful run.
    all_runs: dict[str, list[dict]] = defaultdict(list)
    run_summaries = []

    for run_idx in range(DETERMINISM_RUNS):
        logger.info(f"\n  --- Run {run_idx+1}/{DETERMINISM_RUNS} ---")
        start = time.monotonic()
        run_results = {}

        # Sorted language order keeps request ordering identical across runs.
        for lang_code, requests in sorted(all_requests.items()):
            logger.info(f"    [{lang_code}] Sending {len(requests)} segments...")
            responses = await provider.send_batch(requests)

            for resp in responses:
                if resp.status == RequestStatus.SUCCESS and resp.transcription_data:
                    key = f"{lang_code}:{resp.segment_id}"
                    all_runs[key].append(resp.transcription_data)
                    run_results[key] = {
                        "transcription": resp.transcription_data.get("transcription", ""),
                        "detected_language": resp.transcription_data.get("detected_language", ""),
                        "latency_ms": resp.latency_ms,
                        "cache_hit": resp.token_usage.cache_hit,
                        "cached_tokens": resp.token_usage.cached_tokens,
                    }
                else:
                    logger.error(f"    [{lang_code}] {resp.segment_id}: {resp.status} - {resp.error_message}")

        elapsed = time.monotonic() - start
        run_summaries.append({
            "run": run_idx + 1,
            "segments_ok": len(run_results),
            "time_s": round(elapsed, 2),
        })
        logger.info(f"  Run {run_idx+1}: {len(run_results)} OK in {elapsed:.1f}s")

        # Brief pause between runs (skipped after the final one).
        if run_idx < DETERMINISM_RUNS - 1:
            await asyncio.sleep(1)

    # Analyze: a segment with any failed run is counted as non-deterministic.
    deterministic = 0
    non_deterministic = 0
    lang_stats = defaultdict(lambda: {"total": 0, "deterministic": 0})
    diffs = []

    for key, runs in all_runs.items():
        lang = key.split(":")[0]
        lang_stats[lang]["total"] += 1

        if len(runs) < DETERMINISM_RUNS:
            non_deterministic += 1
            continue

        texts = [r.get("transcription", "") for r in runs]
        if all(t == texts[0] for t in texts):
            deterministic += 1
            lang_stats[lang]["deterministic"] += 1
        else:
            non_deterministic += 1
            diffs.append({
                "segment": key,
                "texts": texts,
                # Index of the first differing character between the first two
                # runs; falls back to the shorter length when one is a prefix
                # of the other.
                "first_diff_char": next(
                    (i for i, (a, b) in enumerate(zip(texts[0], texts[1])) if a != b),
                    min(len(texts[0]), len(texts[1]))
                ),
            })

    total = deterministic + non_deterministic
    pct = (deterministic / total * 100) if total > 0 else 0

    # NOTE: previous revision computed `last_run` and a bogus `cache_hits`
    # metric here; both were unused, so the dead code has been removed.

    result = {
        "key_name": key_name,
        "total_segments": total,
        "deterministic": deterministic,
        "non_deterministic": non_deterministic,
        "determinism_pct": round(pct, 1),
        "by_language": dict(lang_stats),
        "diffs": diffs[:10],  # cap stored examples; logs show the first 3
        "runs": run_summaries,
    }

    logger.info(f"\n  === {key_name} RESULTS ===")
    logger.info(f"  {deterministic}/{total} deterministic ({pct:.1f}%)")
    for lang, stats in sorted(lang_stats.items()):
        logger.info(f"    {lang}: {stats['deterministic']}/{stats['total']}")
    if diffs:
        for d in diffs[:3]:
            logger.info(f"    DIFF {d['segment']}: char {d['first_diff_char']}")
            for i, t in enumerate(d['texts']):
                logger.info(f"      run{i+1}: {t[:100]}")

    return result


async def main():
    """Load canary segments, test both project keys, and persist the results."""
    RESULTS_DIR.mkdir(parents=True, exist_ok=True)

    key2 = os.getenv("GEMINI_PROJECT2")
    key3 = os.getenv("GEMINI_PROJECT3")
    if not key2 or not key3:
        logger.error("GEMINI_PROJECT2 and GEMINI_PROJECT3 must be set in .env")
        return

    # Build transcription requests from the pre-downloaded canary data.
    logger.info("Loading segments from canary_data...")
    all_requests: dict[str, list] = {}
    for lang_code, video_id in TEST_LANGS.items():
        if segs := load_segments(lang_code, video_id):
            all_requests[lang_code] = segs
            logger.info(f"  {lang_code}: {len(segs)} segments ready")
        else:
            logger.warning(f"  {lang_code}: no segments found")

    total = sum(len(v) for v in all_requests.values())
    logger.info(f"Total: {total} segments across {len(all_requests)} languages\n")

    if not all_requests:
        logger.error("No segments available. Run canary_test.py first to download data.")
        return

    # Exercise each key in sequence against the identical request set.
    results: dict[str, dict] = {}
    for name, key in (("GEMINI_PROJECT2", key2), ("GEMINI_PROJECT3", key3)):
        results[name] = await test_key(name, key, all_requests)

    # Combined report: config block first, then one entry per key
    # (dict insertion order preserves GEMINI_PROJECT2 before GEMINI_PROJECT3).
    combined = {
        "test_config": {
            "languages": list(all_requests.keys()),
            "segments_per_lang": {k: len(v) for k, v in all_requests.items()},
            "total_segments": total,
            "determinism_runs": DETERMINISM_RUNS,
        },
        **results,
    }

    out_path = RESULTS_DIR / "key_consistency_results.json"
    with open(out_path, "w") as f:
        json.dump(combined, f, indent=2, ensure_ascii=False, default=str)

    logger.info(f"\n{'='*60}")
    logger.info("KEY CONSISTENCY TEST COMPLETE")
    logger.info(f"{'='*60}")
    for name, res in results.items():
        # "GEMINI_PROJECT2" -> "PROJECT2" to match the established log format.
        logger.info(
            f"  {name.removeprefix('GEMINI_')}: {res['deterministic']}/{res['total_segments']}"
            f" deterministic ({res['determinism_pct']}%)"
        )
    logger.info(f"  Results: {out_path}")

if __name__ == "__main__":
    # Script entry point: run the async main() on a fresh event loop.
    asyncio.run(main())
