"""
12-language canary test: determinism + concurrency on both AI Studio and OpenRouter.
Phase 1: Download 1 video per language, extract 5 segments each
Phase 2: Determinism test — 3 runs per provider, compare exact matches
Phase 3: 500-concurrency stress test on AI Studio
Phase 4: 500-concurrency stress test on OpenRouter
"""
import asyncio
import json
import logging
import os
import sys
import time
from collections import defaultdict
from pathlib import Path

# Make the repo root importable so `src.*` resolves when this file is run directly.
sys.path.insert(0, str(Path(__file__).resolve().parent.parent))
# Default to real providers; callers may still export MOCK_MODE=true beforehand.
os.environ.setdefault("MOCK_MODE", "false")

from dotenv import load_dotenv
# Load credentials from the repo-root .env BEFORE importing src.config below.
load_dotenv(Path(__file__).resolve().parent.parent / ".env")

from src.config import EnvConfig, LANGUAGE_MAP
from src.audio_polish import polish_all_segments
from src.providers.aistudio import AIStudioProvider
from src.providers.openrouter import OpenRouterProvider
from src.providers.base import TranscriptionRequest, RequestStatus
from src.cache_manager import CacheManager
from src.r2_client import R2Client

logging.basicConfig(level=logging.INFO, format="%(asctime)s [%(levelname)s] %(name)s: %(message)s")
logger = logging.getLogger("canary")

# Output locations, colocated with this script.
RESULTS_DIR = Path(__file__).parent / "canary_results"  # JSON result artifacts
WORK_DIR = Path(__file__).parent / "canary_data"  # downloaded tars / extracted audio

# One canary video id per language code (keys presumably align with
# LANGUAGE_MAP in src.config — verify if languages are added/removed).
CANARY_VIDEOS = {
    "as": "efw1HydOOfI",
    "bn": "D7iL59a39Bs",
    "en": "nM2KMwb86IU",
    "gu": "X7tqVj0IuOI",
    "hi": "BFU9eSKO_t4",
    "kn": "o-zn8bAjkBM",
    "ml": "dlo3pvcyqFY",
    "mr": "mqyv2attoI0",
    "or": "IYKHdR3-w_U",
    "pa": "l5eQd1E_Huo",
    "ta": "8ozSnoys7JE",
    "te": "5MDwLvubhR8",
}

SEGMENTS_PER_LANG = 5  # segments transcribed per language (Phase 1 quota)
DETERMINISM_RUNS = 3  # identical passes compared in Phase 2
CONCURRENCY_TARGET = 500  # simultaneous requests fired in Phases 3/4


def save_result(name: str, data: dict) -> None:
    """Write *data* as pretty-printed JSON to RESULTS_DIR/<name>.json.

    Creates RESULTS_DIR on first use. Values that are not JSON-native
    (e.g. Paths, datetimes) are stringified via ``default=str``.
    """
    RESULTS_DIR.mkdir(parents=True, exist_ok=True)
    path = RESULTS_DIR / f"{name}.json"
    # Force UTF-8: with ensure_ascii=False the dump contains raw Unicode,
    # which would break under a non-UTF-8 locale default encoding.
    with open(path, "w", encoding="utf-8") as f:
        json.dump(data, f, indent=2, ensure_ascii=False, default=str)
    logger.info(f"Saved -> {path}")


async def download_and_extract(config: EnvConfig) -> dict[str, list]:
    """Download tars, extract, polish, return {lang: [TranscriptionRequest]}."""
    r2 = R2Client(config)
    prepared: dict[str, list[TranscriptionRequest]] = {}

    for lang_code, video_id in CANARY_VIDEOS.items():
        target_dir = WORK_DIR / lang_code
        target_dir.mkdir(parents=True, exist_ok=True)

        try:
            logger.info(f"[{lang_code}] Downloading {video_id}...")
            tar_path = r2.download_tar(video_id, target_dir)
            extracted = r2.extract_tar(tar_path, video_id)

            if not extracted.segment_paths:
                logger.warning(f"[{lang_code}] No segments in {video_id}")
                continue

            # Polish a few extras beyond the quota so discarded segments
            # can still be backfilled up to SEGMENTS_PER_LANG.
            candidates = extracted.segment_paths[:SEGMENTS_PER_LANG + 3]
            polished = polish_all_segments(candidates)
            kept = [p for p in polished if not p.trim_meta.discarded][:SEGMENTS_PER_LANG]

            if not kept:
                logger.warning(f"[{lang_code}] All segments discarded")
                continue

            batch = []
            for seg in kept:
                meta = seg.trim_meta
                # Split segments get a distinct id so repeats stay unique.
                seg_id = (
                    f"{meta.original_file}_split{meta.split_index}"
                    if meta.was_split
                    else meta.original_file
                )
                batch.append(TranscriptionRequest(
                    segment_id=seg_id,
                    audio_base64=seg.base64_audio,
                    language_code=lang_code,
                    original_file=meta.original_file,
                ))

            prepared[lang_code] = batch
            logger.info(f"[{lang_code}] Ready: {len(batch)} segments from {video_id}")

        except Exception as e:
            logger.error(f"[{lang_code}] Failed to prepare {video_id}: {e}")

    return prepared


def _analyze_determinism(
    all_runs: dict[str, list[dict]],
    num_runs: int,
) -> tuple[int, int, dict, list[dict]]:
    """Classify each segment's repeated outputs as deterministic or not.

    Args:
        all_runs: "lang:segment_id" -> list of transcription_data dicts,
            one entry per successful run.
        num_runs: how many runs were attempted per segment.

    Returns:
        (deterministic_count, non_deterministic_count, per_language_stats,
        diff_records). A segment that is missing any run (it failed at
        least once) is counted as non-deterministic.
    """
    deterministic = 0
    non_deterministic = 0
    lang_stats = defaultdict(lambda: {"total": 0, "deterministic": 0})
    diffs: list[dict] = []

    for key, runs in all_runs.items():
        lang = key.split(":")[0]
        lang_stats[lang]["total"] += 1

        if len(runs) < num_runs:
            non_deterministic += 1
            continue

        texts = [r.get("transcription", "") for r in runs]
        if all(t == texts[0] for t in texts):
            deterministic += 1
            lang_stats[lang]["deterministic"] += 1
        else:
            non_deterministic += 1
            # Diff run 1 against the FIRST run that actually differs.
            # Unconditionally diffing run 1 vs run 2 reports a bogus
            # position when those two agree but a later run diverges.
            other = next(t for t in texts if t != texts[0])
            first_diff = next(
                (i for i, (a, b) in enumerate(zip(texts[0], other)) if a != b),
                min(len(texts[0]), len(other)),
            )
            diffs.append({
                "segment": key,
                "texts": texts[:4],
                "first_diff_char": first_diff,
            })

    return deterministic, non_deterministic, lang_stats, diffs


async def run_determinism_test(
    requests_by_lang: dict[str, list],
    provider,
    provider_name: str,
    num_runs: int = DETERMINISM_RUNS,
) -> dict:
    """Run same segments N times, compare transcription outputs for exact match.

    Args:
        requests_by_lang: language code -> list of TranscriptionRequest.
        provider: object exposing ``async send_batch(requests)``.
        provider_name: label used in logs and the result summary.
        num_runs: number of identical passes to compare.

    Returns:
        {"summary": aggregate determinism stats, "run_details": per-run data}.
    """
    logger.info(f"\n{'='*70}")
    logger.info(f"DETERMINISM TEST: {provider_name} ({num_runs} runs)")
    logger.info(f"{'='*70}")

    all_runs: dict[str, list[dict]] = defaultdict(list)  # seg_id -> [run1, run2, ...]
    run_summaries = []

    for run_idx in range(num_runs):
        logger.info(f"\n--- Run {run_idx+1}/{num_runs} ---")
        run_start = time.monotonic()
        run_results = {}

        for lang_code, requests in sorted(requests_by_lang.items()):
            logger.info(f"  [{lang_code}] Sending {len(requests)} segments...")
            responses = await provider.send_batch(requests)

            for resp in responses:
                if resp.status == RequestStatus.SUCCESS and resp.transcription_data:
                    key = f"{lang_code}:{resp.segment_id}"
                    all_runs[key].append(resp.transcription_data)
                    run_results[key] = {
                        "transcription": resp.transcription_data.get("transcription", ""),
                        "detected_language": resp.transcription_data.get("detected_language", ""),
                        "latency_ms": resp.latency_ms,
                        "cache_hit": resp.token_usage.cache_hit,
                        "input_tokens": resp.token_usage.input_tokens,
                        "cached_tokens": resp.token_usage.cached_tokens,
                        "output_tokens": resp.token_usage.output_tokens,
                    }
                elif resp.status != RequestStatus.SUCCESS:
                    logger.error(f"  [{lang_code}] {resp.segment_id}: {resp.status} - {resp.error_message}")

        run_time = time.monotonic() - run_start
        run_summaries.append({
            "run": run_idx + 1,
            "segments_ok": len(run_results),
            "time_s": round(run_time, 2),
            "results": run_results,
        })
        logger.info(f"  Run {run_idx+1} done: {len(run_results)} OK in {run_time:.1f}s")

        # Brief pause between runs so back-to-back passes don't hammer the provider.
        if run_idx < num_runs - 1:
            await asyncio.sleep(2)

    deterministic, non_deterministic, lang_stats, diffs = _analyze_determinism(
        all_runs, num_runs
    )

    total = deterministic + non_deterministic
    pct = (deterministic / total * 100) if total > 0 else 0

    summary = {
        "provider": provider_name,
        "num_runs": num_runs,
        "total_segments": total,
        "deterministic": deterministic,
        "non_deterministic": non_deterministic,
        "determinism_pct": round(pct, 1),
        "by_language": dict(lang_stats),
        "diffs": diffs[:20],
    }

    logger.info(f"\n=== DETERMINISM RESULTS ({provider_name}) ===")
    logger.info(f"  {deterministic}/{total} segments deterministic ({pct:.1f}%)")
    for lang, stats in sorted(lang_stats.items()):
        logger.info(f"  {lang}: {stats['deterministic']}/{stats['total']}")
    if diffs:
        logger.info("  Non-deterministic segments:")
        for d in diffs[:5]:
            logger.info(f"    {d['segment']}: first diff at char {d['first_diff_char']}")

    return {
        "summary": summary,
        "run_details": run_summaries,
    }


async def run_concurrency_test(
    requests_by_lang: dict[str, list],
    provider,
    provider_name: str,
    target_concurrency: int = CONCURRENCY_TARGET,
) -> dict:
    """Fire target_concurrency requests simultaneously, measure throughput.

    Flattens every language's segments into one batch, duplicating them
    (with unique "_rep<i>" segment ids) until the target size is reached,
    then sends everything in a single ``provider.send_batch`` call.

    Args:
        requests_by_lang: language code -> list of TranscriptionRequest.
        provider: object exposing ``async send_batch(requests)``.
        provider_name: label used in logs and the result summary.
        target_concurrency: number of requests to send at once.

    Returns:
        {"summary": throughput/latency/token metrics, "errors": sampled
        error details (first 10 hard errors + first 5 rate-limits)}.
    """
    logger.info(f"\n{'='*70}")
    logger.info(f"CONCURRENCY TEST: {provider_name} ({target_concurrency} concurrent)")
    logger.info(f"{'='*70}")

    # Build a big batch by repeating segments across all languages
    all_requests = []
    for lang_code, requests in sorted(requests_by_lang.items()):
        all_requests.extend(requests)

    # Guard: with no segments at all, the duplication step below would
    # divide by zero. Bail out with an empty-but-well-formed result.
    if not all_requests:
        logger.error("No requests available for concurrency test")
        return {
            "summary": {
                "provider": provider_name,
                "target_concurrency": target_concurrency,
                "actual_sent": 0,
                "successes": 0,
                "errors": 0,
                "rate_limited": 0,
                "timeouts": 0,
            },
            "errors": [],
        }

    # Duplicate to reach target if needed
    if len(all_requests) < target_concurrency:
        multiplier = (target_concurrency // len(all_requests)) + 1
        expanded = []
        for i in range(multiplier):
            for req in all_requests:
                expanded.append(TranscriptionRequest(
                    segment_id=f"{req.segment_id}_rep{i}",
                    audio_base64=req.audio_base64,
                    language_code=req.language_code,
                    original_file=req.original_file,
                ))
        all_requests = expanded[:target_concurrency]

    logger.info(f"Sending {len(all_requests)} requests concurrently...")
    start = time.monotonic()
    responses = await provider.send_batch(all_requests)
    wall_time = time.monotonic() - start

    successes = [r for r in responses if r.status == RequestStatus.SUCCESS]
    errors = [r for r in responses if r.status == RequestStatus.ERROR]
    rate_limited = [r for r in responses if r.status == RequestStatus.RATE_LIMITED]
    timeouts = [r for r in responses if r.status == RequestStatus.TIMEOUT]
    # Sort once; every percentile lookup below indexes into this sorted list.
    latencies = sorted(r.latency_ms for r in successes if r.latency_ms > 0)

    total_input = sum(r.token_usage.input_tokens for r in successes)
    total_output = sum(r.token_usage.output_tokens for r in successes)
    total_cached = sum(r.token_usage.cached_tokens for r in successes)
    cache_hits = sum(1 for r in successes if r.token_usage.cache_hit)

    summary = {
        "provider": provider_name,
        "target_concurrency": target_concurrency,
        "actual_sent": len(all_requests),
        "successes": len(successes),
        "errors": len(errors),
        "rate_limited": len(rate_limited),
        "timeouts": len(timeouts),
        "wall_time_s": round(wall_time, 2),
        "throughput_rps": round(len(successes) / wall_time, 1) if wall_time > 0 else 0,
        "avg_latency_ms": round(sum(latencies) / len(latencies), 1) if latencies else 0,
        "p50_latency_ms": round(latencies[len(latencies) // 2], 1) if latencies else 0,
        "p95_latency_ms": round(latencies[int(len(latencies) * 0.95)], 1) if latencies else 0,
        "max_latency_ms": round(latencies[-1], 1) if latencies else 0,
        "total_input_tokens": total_input,
        "total_output_tokens": total_output,
        "total_cached_tokens": total_cached,
        "cache_hits": cache_hits,
        "cache_hit_rate": round(cache_hits / max(len(successes), 1), 3),
    }

    # Collect error details; error_message may be None, so coalesce before slicing.
    error_details = []
    for r in errors[:10]:
        error_details.append({"segment": r.segment_id, "error": (r.error_message or "")[:200]})
    for r in rate_limited[:5]:
        error_details.append({"segment": r.segment_id, "status": "429", "error": (r.error_message or "")[:200]})

    logger.info(f"\n=== CONCURRENCY RESULTS ({provider_name}) ===")
    logger.info(f"  {len(successes)}/{len(all_requests)} OK, {len(errors)} errors, {len(rate_limited)} 429s, {len(timeouts)} timeouts")
    logger.info(f"  Wall time: {wall_time:.1f}s, Throughput: {summary['throughput_rps']} req/s")
    logger.info(f"  Latency: avg={summary['avg_latency_ms']}ms, p50={summary['p50_latency_ms']}ms, p95={summary['p95_latency_ms']}ms")
    logger.info(f"  Tokens: input={total_input}, output={total_output}, cached={total_cached}")
    logger.info(f"  Cache hits: {cache_hits}/{len(successes)} ({summary['cache_hit_rate']*100:.1f}%)")
    if error_details:
        logger.info("  Sample errors:")
        for e in error_details[:3]:
            logger.info(f"    {e}")

    return {"summary": summary, "errors": error_details}


async def main() -> None:
    """Run the full canary sequence: prepare data, then phases 2a/2b/3/4.

    Each phase's raw results are persisted to RESULTS_DIR as JSON as soon
    as the phase completes, so a crash mid-run keeps earlier results.
    """
    config = EnvConfig()
    RESULTS_DIR.mkdir(parents=True, exist_ok=True)

    # Phase 1: Download and prepare segments
    logger.info("=" * 70)
    logger.info("PHASE 1: Downloading canary videos for all 12 languages")
    logger.info("=" * 70)

    requests_by_lang = await download_and_extract(config)
    total_segments = sum(len(r) for r in requests_by_lang.values())
    logger.info(f"\nPrepared {total_segments} segments across {len(requests_by_lang)} languages")
    save_result("phase1_preparation", {
        "languages": {k: len(v) for k, v in requests_by_lang.items()},
        "total_segments": total_segments,
    })

    # Without any prepared segments the remaining phases are meaningless.
    if not requests_by_lang:
        logger.error("No segments available. Aborting.")
        return

    # Phase 2a: Determinism test on AI Studio (with V2 cache)
    logger.info("\n\nSetting up AI Studio with V2 cache...")
    cm = CacheManager(config.gemini_key)
    cache_name = await cm.ensure_cache()
    aistudio = AIStudioProvider(api_key=config.gemini_key, cached_content_name=cache_name)

    det_aistudio = await run_determinism_test(requests_by_lang, aistudio, "aistudio", DETERMINISM_RUNS)
    save_result("phase2a_determinism_aistudio", det_aistudio)

    # Phase 2b: Determinism test on OpenRouter (V2 prompt, no explicit cache)
    logger.info("\n\nSetting up OpenRouter...")
    openrouter = OpenRouterProvider(api_key=config.openrouter_api_key)

    det_openrouter = await run_determinism_test(requests_by_lang, openrouter, "openrouter", DETERMINISM_RUNS)
    save_result("phase2b_determinism_openrouter", det_openrouter)

    # Phase 3: 500-concurrency on AI Studio
    conc_aistudio = await run_concurrency_test(requests_by_lang, aistudio, "aistudio", CONCURRENCY_TARGET)
    save_result("phase3_concurrency_aistudio", conc_aistudio)

    # Phase 4: 500-concurrency on OpenRouter
    conc_openrouter = await run_concurrency_test(requests_by_lang, openrouter, "openrouter", CONCURRENCY_TARGET)
    save_result("phase4_concurrency_openrouter", conc_openrouter)

    # Combined summary (aggregates the per-phase "summary" sections only).
    combined = {
        "languages_tested": list(requests_by_lang.keys()),
        "segments_per_language": {k: len(v) for k, v in requests_by_lang.items()},
        "determinism": {
            "aistudio": det_aistudio["summary"],
            "openrouter": det_openrouter["summary"],
        },
        "concurrency": {
            "aistudio": conc_aistudio["summary"],
            "openrouter": conc_openrouter["summary"],
        },
    }
    save_result("combined_summary", combined)

    logger.info("\n" + "=" * 70)
    logger.info("CANARY TEST COMPLETE")
    logger.info("=" * 70)
    logger.info(f"  Languages: {len(requests_by_lang)}")
    logger.info(f"  AI Studio determinism: {det_aistudio['summary']['determinism_pct']}%")
    logger.info(f"  OpenRouter determinism: {det_openrouter['summary']['determinism_pct']}%")
    logger.info(f"  AI Studio concurrency: {conc_aistudio['summary']['successes']}/{conc_aistudio['summary']['actual_sent']} OK")
    logger.info(f"  OpenRouter concurrency: {conc_openrouter['summary']['successes']}/{conc_openrouter['summary']['actual_sent']} OK")


if __name__ == "__main__":
    # Script entry point: run all four canary phases to completion.
    asyncio.run(main())
