#!/usr/bin/env python3
"""
Generate benchmark_outputs conforming to BENCHMARK_SCHEMA.md v1.

Reads raw predictions from benchmark_results/ and produces:
  - metrics.json       (4 metric tiers + normalization_delta + meta)
  - sample_analysis.json  (per-sample with flags)
  - error_analysis.json   (top subs/ins/del + error buckets + diagnosis)
"""

import json
import re
import sys
import unicodedata
from collections import Counter, defaultdict
from datetime import datetime, timezone
from pathlib import Path

import numpy as np
from jiwer import cer, wer

# ── Constants ────────────────────────────────────────────────────────────

# Bumped whenever normalization rules change so downstream consumers can
# tell which pipeline produced a given metrics file.
NORMALIZATION_VERSION = "v1"
# NOTE(review): the H200 ships with 141 GB HBM3e; "80GB" looks like a
# leftover from an H100 config — confirm before publishing.
GPU_NAME = "NVIDIA H200 80GB"
MODEL_ID = "qwen3-asr"
MODEL_TYPE = "Qwen3-ASR-1.7B"
DATASET = "BayAreaBoys/indic-asr-benchmark-6k"

# Languages written in Latin script (where case folding is meaningful).
LATIN_SCRIPT_LANGS = {"english"}

# Zero-width / invisible characters to strip during normalization:
# ZWSP..RLM, bidi embedding/override controls, BOM, soft hyphen.
ZW_CHARS = re.compile(r"[\u200b-\u200f\u202a-\u202e\ufeff\u00ad]")

# Language-aware punctuation sets
# Base punctuation (safe to remove for all languages)
BASE_PUNCT = set("!\"#$%&'()*+,-./:;<=>?@[\\]^_`{|}~")
# Indic punctuation
INDIC_PUNCT = {
    "\u0964",  # Devanagari danda ।
    "\u0965",  # Devanagari double danda ॥
    "\u0970",  # Devanagari abbreviation sign
}
# NOTE(review): this comment described an earlier intent — the current
# norm_standard removes the full INDIC_PUNCT set as well, so the danda is
# stripped too (matching the schema note cited there). INDIC_REMOVE_PUNCT
# is therefore redundant with INDIC_PUNCT today.
INDIC_REMOVE_PUNCT = {"\u0965", "\u0970"}

# Curly quotes / special dashes mapped to ASCII equivalents before the
# punctuation-removal step (so removal catches them).
QUOTE_MAP = {
    "\u2018": "'", "\u2019": "'", "\u201c": '"', "\u201d": '"',
    "\u2013": "-", "\u2014": "-", "\u2015": "-",
    "\u2026": "...",
}

# Number word mappings for English (incl. Indian-system lakh/crore).
EN_NUMBER_WORDS = {
    "zero": "0", "one": "1", "two": "2", "three": "3", "four": "4",
    "five": "5", "six": "6", "seven": "7", "eight": "8", "nine": "9",
    "ten": "10", "eleven": "11", "twelve": "12", "thirteen": "13",
    "fourteen": "14", "fifteen": "15", "sixteen": "16", "seventeen": "17",
    "eighteen": "18", "nineteen": "19", "twenty": "20", "thirty": "30",
    "forty": "40", "fifty": "50", "sixty": "60", "seventy": "70",
    "eighty": "80", "ninety": "90", "hundred": "100", "thousand": "1000",
    "lakh": "100000", "crore": "10000000",
}

# Hindi number words (Devanagari); both nukta and plain spellings of 1000.
HI_NUMBER_WORDS = {
    "शून्य": "0", "एक": "1", "दो": "2", "तीन": "3", "चार": "4",
    "पांच": "5", "छह": "6", "सात": "7", "आठ": "8", "नौ": "9",
    "दस": "10", "ग्यारह": "11", "बारह": "12", "तेरह": "13",
    "चौदह": "14", "पंद्रह": "15", "सोलह": "16", "सत्रह": "17",
    "अठारह": "18", "उन्नीस": "19", "बीस": "20",
    "सौ": "100", "हज़ार": "1000", "हजार": "1000",
    "लाख": "100000", "करोड़": "10000000",
}


# ── Normalization Pipeline ───────────────────────────────────────────────

def norm_raw(text: str) -> str:
    """Tier 1 normalization: canonical composition (NFC) plus trimming of
    surrounding whitespace — nothing else is touched."""
    return unicodedata.normalize("NFC", text).strip()


def norm_standard(text: str, language: str = "") -> str:
    """Tier 2 (wer_norm): NFKC + strip ZW + normalize whitespace + standardize punct + remove punct + case fold.

    Parameters
    ----------
    text : str
        Raw reference or hypothesis text.
    language : str
        Lowercase language name. Currently unused: case folding is applied
        unconditionally because lower() is a no-op for caseless Indic
        scripts; the parameter is kept for interface stability.
    """
    # Step 1: NFKC (compatibility decomposition also folds width variants)
    text = unicodedata.normalize("NFKC", text)
    # Step 2: Strip zero-width chars
    text = ZW_CHARS.sub("", text)
    # Step 3: Normalize whitespace
    text = re.sub(r"\s+", " ", text).strip()
    # Step 4: Standardize punctuation variants so removal below catches them
    for old, new in QUOTE_MAP.items():
        text = text.replace(old, new)
    # Normalize double danda to single danda (both removed below anyway)
    text = text.replace("\u0965", "\u0964")
    # Step 5: Remove punctuation in ONE pass over the combined set.
    # Schema: "Remove punctuation (norm/numcanon only, using language-aware
    # set)" — the Indic set (incl. danda ।) is removed along with ASCII.
    removable = BASE_PUNCT | INDIC_PUNCT | INDIC_REMOVE_PUNCT
    text = "".join(c for c in text if c not in removable)
    # Step 6: Case fold. The previous version branched on LATIN_SCRIPT_LANGS
    # but both branches called lower(); collapsed into one unconditional
    # call (a no-op for caseless Indic scripts).
    text = text.lower()
    # Punctuation removal can leave doubled spaces; collapse again
    text = re.sub(r"\s+", " ", text).strip()
    return text


def norm_mer(text: str, language: str = "") -> str:
    """Tier 4 (MER): apply the tier-2 normalization, then delete every
    space so the text becomes a single character stream."""
    return norm_standard(text, language).replace(" ", "")


# ── space_norm_wer computation ───────────────────────────────────────

def space_norm_wer_sample(ref_norm: str, hyp_norm: str) -> tuple:
    """
    Word error rate after space-insensitive character alignment.

    Both strings are stripped of spaces and character-aligned with unit-cost
    Levenshtein edit distance; every reference *word* touched by at least
    one character edit counts as a single word error. Pure re-spacing
    ("ab c" vs "a bc") therefore scores zero errors.

    Returns (error_words, total_words); (0, 0) when the reference is empty.
    """
    ref_words = ref_norm.split()
    if not ref_words:
        return (0, 0)

    ref_nospace = ref_norm.replace(" ", "")
    hyp_nospace = hyp_norm.replace(" ", "")
    total_words = len(ref_words)

    # Fast path: identical up to spacing → zero word errors.
    if ref_nospace == hyp_nospace:
        return (0, total_words)

    # char_to_word: ref_nospace[i] → word index
    char_to_word = []
    for word_idx, word in enumerate(ref_words):
        for _ in word:
            char_to_word.append(word_idx)

    n = len(ref_nospace)
    m = len(hyp_nospace)

    # Standard O(n*m) Levenshtein DP (unit costs for sub/ins/del).
    dp = [[0] * (m + 1) for _ in range(n + 1)]
    for i in range(n + 1):
        dp[i][0] = i
    for j in range(m + 1):
        dp[0][j] = j
    for i in range(1, n + 1):
        for j in range(1, m + 1):
            if ref_nospace[i - 1] == hyp_nospace[j - 1]:
                dp[i][j] = dp[i - 1][j - 1]
            else:
                dp[i][j] = 1 + min(dp[i - 1][j - 1], dp[i - 1][j], dp[i][j - 1])

    # Backtrack ONE optimal path and mark every ref word an edit touches.
    # Branch order (match, substitution, deletion, insertion) selects an
    # arbitrary-but-valid path when several paths are equally optimal.
    touched = set()
    i, j = n, m
    while i > 0 or j > 0:
        if i > 0 and j > 0 and ref_nospace[i - 1] == hyp_nospace[j - 1] and dp[i][j] == dp[i - 1][j - 1]:
            i -= 1; j -= 1  # match: no word touched
        elif i > 0 and j > 0 and dp[i][j] == dp[i - 1][j - 1] + 1:
            touched.add(char_to_word[i - 1]); i -= 1; j -= 1  # substitution
        elif i > 0 and dp[i][j] == dp[i - 1][j] + 1:
            touched.add(char_to_word[i - 1]); i -= 1  # deletion from ref
        elif j > 0 and dp[i][j] == dp[i][j - 1] + 1:
            # Insertion: charge the inserted hyp char to the adjacent ref
            # word — the preceding word if any, else the following one.
            if i > 0:
                touched.add(char_to_word[i - 1])
            elif i < n:
                touched.add(char_to_word[i])
            j -= 1
        else:
            break  # defensive: unreachable for a well-formed DP table

    return (len(touched), total_words)


def norm_numcanon(text: str, language: str = "") -> str:
    """Tier 3 (wer_numcanon): norm_standard + number canonicalization.

    Strips digit-grouping commas and maps verbalized number words to
    digits for English and Hindi; other languages pass through unchanged.
    """
    text = norm_standard(text, language)
    # Remove digit-grouping commas in a single pass. The lookaround form
    # handles Indian-style grouping (1,00,000) without the second
    # repeated-substitution pass the old code needed.
    # (NOTE: norm_standard already strips commas, so this is a safeguard
    # in case the punctuation set ever changes.)
    text = re.sub(r"(?<=\d),(?=\d)", "", text)
    # Replace known number words with digits, language permitting.
    # Deduplicates the formerly copy-pasted English/Hindi loops.
    number_words = {"english": EN_NUMBER_WORDS, "hindi": HI_NUMBER_WORDS}.get(language)
    if number_words:
        text = " ".join(number_words.get(w, w) for w in text.split())
    return text


# ── Per-sample WER ───────────────────────────────────────────────────────

def sample_wer(ref: str, hyp: str) -> float:
    """Compute WER for a single sample, returned as a 0-100 float.

    Empty-input convention: empty ref with empty hyp scores 0.0; an empty
    side paired with a non-empty one scores 100.0.
    """
    ref_blank = not ref.strip()
    hyp_blank = not hyp.strip()
    if ref_blank:
        return 100.0 if not hyp_blank else 0.0
    if hyp_blank:
        return 100.0
    return round(100 * wer(ref, hyp), 2)


# ── Flag detection ───────────────────────────────────────────────────────

def detect_flags(ref: str, hyp: str, ref_n: str, hyp_n: str,
                 detected_lang: str, expected_lang: str, wer_norm_val: float) -> list:
    """Return diagnostic flags for one sample.

    Parameters
    ----------
    ref, hyp : str
        Tier-1 (raw-normalized) reference and hypothesis.
    ref_n, hyp_n : str
        Tier-2 (standard-normalized) reference and hypothesis.
    detected_lang, expected_lang : str
        Model-detected vs dataset language names (compared case-insensitively).
    wer_norm_val : float
        Per-sample wer_norm (0-100); > 80 triggers "high_wer".
    """
    flags = []
    if ref == hyp:
        flags.append("exact_match")
    if ref_n == hyp_n:
        flags.append("exact_match_norm")
    if not hyp.strip():
        flags.append("empty_hypothesis")
    if ref != hyp and ref_n == hyp_n:
        flags.append("punctuation_only_diff")
    # Numeric mismatch: digit sequences differ between normalized ref/hyp
    ref_digits = re.findall(r"\d+", ref_n)
    hyp_digits = re.findall(r"\d+", hyp_n)
    if ref_digits != hyp_digits and (ref_digits or hyp_digits):
        flags.append("numeric_mismatch")
    # Script mismatch: the hypothesis was produced in a different Indic
    # script than the reference.  The old code checked only three hard-coded
    # directed pairs (and carried a dead `unicodedata.script` probe —
    # the stdlib has no such function, so that branch never ran).
    # Generalized: flag whenever ref and hyp each contain Indic-script
    # characters but share no script at all.
    if hyp.strip() and ref.strip():
        script_ranges = {
            "devanagari": r"[\u0900-\u097F]",
            "bengali": r"[\u0980-\u09FF]",
            "gurmukhi": r"[\u0A00-\u0A7F]",
            "gujarati": r"[\u0A80-\u0AFF]",
            "oriya": r"[\u0B00-\u0B7F]",
            "tamil": r"[\u0B80-\u0BFF]",
            "telugu": r"[\u0C00-\u0C7F]",
            "kannada": r"[\u0C80-\u0CFF]",
            "malayalam": r"[\u0D00-\u0D7F]",
        }
        ref_scripts = {s for s, rng in script_ranges.items() if re.search(rng, ref)}
        hyp_scripts = {s for s, rng in script_ranges.items() if re.search(rng, hyp)}
        if ref_scripts and hyp_scripts and ref_scripts.isdisjoint(hyp_scripts):
            flags.append("script_mismatch")
    # Language confusion: detected language disagrees with the dataset label
    if detected_lang and expected_lang:
        if detected_lang.lower() != expected_lang.lower():
            flags.append("lang_confusion")
    if wer_norm_val > 80:
        flags.append("high_wer")
    return flags


# ── Error Analysis ───────────────────────────────────────────────────────

def compute_error_analysis(samples_by_lang: dict) -> dict:
    """Compute top substitutions/insertions/deletions per language.

    Parameters
    ----------
    samples_by_lang : dict
        language -> list of sample dicts (the sample_analysis.json records;
        each must carry id / reference / hypothesis / ref_norm / hyp_norm /
        flags / wer_norm).

    Returns
    -------
    dict
        One entry per language (top errors, error buckets, example ids)
        plus a "__summary__" entry whose best/worst language lists and
        diagnosis fields are filled in by the caller.
    """
    analysis = {}

    for lang, samples in sorted(samples_by_lang.items()):
        subs = Counter()
        insertions = Counter()
        deletions = Counter()
        numeric_mismatch_count = 0
        punct_only_count = 0
        spacing_count = 0
        entity_count = 0
        script_confusion_count = 0
        empty_count = 0

        wer_scores = []

        for s in samples:
            ref_words = s["ref_norm"].split()
            hyp_words = s["hyp_norm"].split()
            flags = s.get("flags", [])

            if "empty_hypothesis" in flags:
                empty_count += 1
            if "numeric_mismatch" in flags:
                numeric_mismatch_count += 1
            if "punctuation_only_diff" in flags:
                punct_only_count += 1
            if "script_mismatch" in flags:
                script_confusion_count += 1
            # BUGFIX: the spacing bucket was declared but never counted,
            # even though per-sample "spacing_error" flags exist upstream.
            if "spacing_error" in flags:
                spacing_count += 1

            wer_scores.append((s["id"], s.get("wer_norm", 100.0)))

            # Bag-of-words approximation of alignment errors
            ref_counter = Counter(ref_words)
            hyp_counter = Counter(hyp_words)

            # Deletions: in ref but not hyp
            for word, count in (ref_counter - hyp_counter).items():
                deletions[word] += count

            # Insertions: in hyp but not ref
            for word, count in (hyp_counter - ref_counter).items():
                insertions[word] += count

            # Substitutions: approximate by pairing mismatched words positionally
            for rw, hw in zip(ref_words, hyp_words):
                if rw != hw:
                    subs[(rw, hw)] += 1

            # Entity mismatches: capitalized words that differ.
            # BUGFIX: this used to inspect the *normalized* (lowercased)
            # words, so isupper() was always False and the count stayed 0.
            # Compare the original-cased texts instead.  (Caveat: positional
            # pairing on raw text can overcount when punctuation differs.)
            for rw, hw in zip(s.get("reference", "").split(),
                              s.get("hypothesis", "").split()):
                if rw != hw and any(c.isupper() for c in rw):
                    entity_count += 1

        # Sort samples by WER for best/worst
        wer_scores.sort(key=lambda x: x[1])
        best = [sid for sid, _ in wer_scores[:3]]
        worst = [sid for sid, _ in wer_scores[-3:]]
        numeric_examples = [s["id"] for s in samples if "numeric_mismatch" in s.get("flags", [])][:3]
        # NOTE(review): no upstream code appears to emit an "entity_mismatch"
        # flag, so this list is likely always empty; kept for schema parity.
        entity_examples = [s["id"] for s in samples if "entity_mismatch" in s.get("flags", [])][:3]

        analysis[lang] = {
            "top_substitutions": [
                {"ref": r, "hyp": h, "count": c}
                for (r, h), c in subs.most_common(20)
            ],
            "top_insertions": [
                {"word": w, "count": c}
                for w, c in insertions.most_common(20)
            ],
            "top_deletions": [
                {"word": w, "count": c}
                for w, c in deletions.most_common(20)
            ],
            "error_buckets": {
                "numeric_mismatch_count": numeric_mismatch_count,
                "punctuation_only_count": punct_only_count,
                "spacing_tokenization_count": spacing_count,
                "entity_mismatch_count": entity_count,
                "script_confusion_count": script_confusion_count,
                "empty_hypothesis_count": empty_count,
            },
            "examples": {
                "worst_samples": worst,
                "best_samples": best,
                "numeric_mismatch_samples": numeric_examples,
                "entity_mismatch_samples": entity_examples,
            }
        }

    # Summary placeholder: the caller overwrites the diagnosis fields and
    # the best/worst language lists from the computed metrics.
    analysis["__summary__"] = {
        "model_diagnosis": "recognition-limited",
        "primary_error_source": "recognition",
        "numeric_verbalization_impact": "moderate",
        "formatting_impact": "low",
        "worst_languages": [],  # filled in by caller
        "best_languages": [],
    }

    return analysis


# ── Main Generator ───────────────────────────────────────────────────────

def process_checkpoint(src_dir: str, ckpt_name: str, out_dir: str) -> dict:
    """Process one checkpoint: read predictions, compute all tiers, write 3 files.

    Reads <src_dir>/predictions.json (a list of records with id, language,
    reference and optionally hypothesis / detected_language) and the old
    <src_dir>/metrics.json (only for its "__meta__" block), then writes
    metrics.json, sample_analysis.json and error_analysis.json to <out_dir>.

    Returns the metrics dict that was written to metrics.json.
    """
    preds_path = Path(src_dir) / "predictions.json"
    old_metrics_path = Path(src_dir) / "metrics.json"

    with open(preds_path, encoding="utf-8") as f:
        preds = json.load(f)

    # Carry over run metadata (batch size, timings, rtf, ...) from the
    # previous-format metrics file; missing keys get defaults further down.
    with open(old_metrics_path) as f:
        old_meta = json.load(f).get("__meta__", {})

    out_path = Path(out_dir)
    out_path.mkdir(parents=True, exist_ok=True)

    # ── Compute normalized texts for all samples ──
    # Each prediction is enriched with every normalization tier plus the
    # per-sample metrics and diagnostic flags.
    enriched = []
    for p in preds:
        lang = p["language"]
        ref = p["reference"]
        hyp = p.get("hypothesis", "")

        ref_raw = norm_raw(ref)
        hyp_raw = norm_raw(hyp)
        ref_n = norm_standard(ref, lang)
        hyp_n = norm_standard(hyp, lang)
        ref_nc = norm_numcanon(ref, lang)
        hyp_nc = norm_numcanon(hyp, lang)
        ref_m = norm_mer(ref, lang)
        hyp_m = norm_mer(hyp, lang)

        detected = p.get("detected_language", "")
        # detect_flags compares case-insensitively, so capitalize() is cosmetic
        expected_lang_name = lang.capitalize()

        wer_raw_val = sample_wer(ref_raw, hyp_raw)
        wer_norm_val = sample_wer(ref_n, hyp_n)
        # MER: use CER on space-stripped text (single char stream comparison)
        if not ref_m:
            mer_val = 0.0 if not hyp_m else 100.0
        elif not hyp_m:
            mer_val = 100.0
        else:
            mer_val = round(cer(ref_m, hyp_m) * 100, 2)

        # space_norm_wer per sample (word errors after space-insensitive
        # char alignment; see space_norm_wer_sample)
        snw_err, snw_total = space_norm_wer_sample(ref_n, hyp_n)
        snw_val = round(snw_err / snw_total * 100, 2) if snw_total > 0 else 0.0

        flags = detect_flags(ref_raw, hyp_raw, ref_n, hyp_n,
                             detected, expected_lang_name, wer_norm_val)
        # spacing_error flag: content is right but word boundaries are wrong
        # (heuristic: word-level error exceeds char-level error)
        if wer_norm_val > mer_val and mer_val < 100:
            flags.append("spacing_error")

        enriched.append({
            "id": p["id"],
            "language": lang,
            "reference": ref,
            "hypothesis": hyp,
            "ref_raw": ref_raw,
            "hyp_raw": hyp_raw,
            "ref_norm": ref_n,
            "hyp_norm": hyp_n,
            "ref_numcanon": ref_nc,
            "hyp_numcanon": hyp_nc,
            "ref_mer": ref_m,
            "hyp_mer": hyp_m,
            "detected_language": detected,
            "wer_raw": wer_raw_val,
            "wer_norm": wer_norm_val,
            "space_norm_wer": snw_val,
            "snw_err": snw_err,
            "snw_total": snw_total,
            "mer": mer_val,
            "flags": flags,
        })

    # ── Compute per-language metrics ──
    by_lang = defaultdict(list)
    for s in enriched:
        by_lang[s["language"]].append(s)

    metrics = {}
    all_ref_raw, all_hyp_raw = [], []
    all_ref_n, all_hyp_n = [], []
    all_ref_nc, all_hyp_nc = [], []
    all_ref_m, all_hyp_m = [], []
    all_snw_err, all_snw_total = 0, 0

    for lang in sorted(by_lang.keys()):
        samples = by_lang[lang]
        n = len(samples)

        # Samples whose normalized reference is empty are excluded from
        # corpus-level WER/CER (the denominator would be undefined) but
        # still appear in sample_analysis and in n_samples.
        refs_raw = [s["ref_raw"] for s in samples if s["ref_norm"]]
        hyps_raw = [s["hyp_raw"] for s in samples if s["ref_norm"]]
        refs_n = [s["ref_norm"] for s in samples if s["ref_norm"]]
        hyps_n = [s["hyp_norm"] for s in samples if s["ref_norm"]]
        refs_nc = [s["ref_numcanon"] for s in samples if s["ref_norm"]]
        hyps_nc = [s["hyp_numcanon"] for s in samples if s["ref_norm"]]
        refs_m = [s["ref_mer"] for s in samples if s["ref_norm"]]
        hyps_m = [s["hyp_mer"] for s in samples if s["ref_norm"]]

        w_raw = round(wer(refs_raw, hyps_raw) * 100, 2) if refs_raw else 100.0
        w_norm = round(wer(refs_n, hyps_n) * 100, 2) if refs_n else 100.0
        w_nc = round(wer(refs_nc, hyps_nc) * 100, 2) if refs_nc else 100.0
        # MER: CER on space-stripped text (since it's one "word", WER is meaningless)
        w_mer = round(cer(refs_m, hyps_m) * 100, 2) if refs_m else 100.0
        c_norm = round(cer(refs_n, hyps_n) * 100, 2) if refs_n else 100.0
        empty = sum(1 for s in samples if "empty_hypothesis" in s["flags"])

        # space_norm_wer: micro-average over ref words
        lang_snw_err = sum(s["snw_err"] for s in samples if s["ref_norm"])
        lang_snw_total = sum(s["snw_total"] for s in samples if s["ref_norm"])
        # NOTE(review): the empty-denominator fallback here is 100.0 while the
        # per-sample fallback above is 0.0 — confirm which convention is wanted.
        w_snw = round(lang_snw_err / lang_snw_total * 100, 2) if lang_snw_total > 0 else 100.0

        metrics[lang] = {
            "n_samples": n,
            "wer_raw": w_raw,
            "wer_norm": w_norm,
            "wer_numcanon": w_nc,
            "space_norm_wer": w_snw,
            "mer": w_mer,
            "cer_norm": c_norm,
            "empty_hypotheses": empty,
            # Deltas isolate how much each normalization step "forgives":
            # positive raw_to_norm means formatting inflated the raw WER.
            "normalization_delta": {
                "raw_to_norm": round(w_raw - w_norm, 2),
                "norm_to_numcanon": round(w_norm - w_nc, 2),
                "norm_to_space_norm": round(w_norm - w_snw, 2),
                "norm_to_mer": round(w_norm - w_mer, 2),
            }
        }

        all_ref_raw.extend(refs_raw)
        all_hyp_raw.extend(hyps_raw)
        all_ref_n.extend(refs_n)
        all_hyp_n.extend(hyps_n)
        all_ref_nc.extend(refs_nc)
        all_hyp_nc.extend(hyps_nc)
        all_ref_m.extend(refs_m)
        all_hyp_m.extend(hyps_m)
        all_snw_err += lang_snw_err
        all_snw_total += lang_snw_total

    # ── Aggregates ──
    # __overall__ is the micro average (pooled samples); note n_samples here
    # counts only samples with a non-empty normalized reference.
    metrics["__overall__"] = {
        "n_samples": len(all_ref_raw),
        "wer_raw": round(wer(all_ref_raw, all_hyp_raw) * 100, 2),
        "wer_norm": round(wer(all_ref_n, all_hyp_n) * 100, 2),
        "wer_numcanon": round(wer(all_ref_nc, all_hyp_nc) * 100, 2),
        "space_norm_wer": round(all_snw_err / all_snw_total * 100, 2) if all_snw_total > 0 else 100.0,
        "mer": round(cer(all_ref_m, all_hyp_m) * 100, 2),
        "cer_norm": round(cer(all_ref_n, all_hyp_n) * 100, 2),
    }

    # __macro_avg__ weights every language equally regardless of sample count.
    lang_keys = sorted(k for k in metrics if not k.startswith("__"))
    metrics["__macro_avg__"] = {
        "n_languages": len(lang_keys),
        "wer_raw": round(np.mean([metrics[k]["wer_raw"] for k in lang_keys]), 2),
        "wer_norm": round(np.mean([metrics[k]["wer_norm"] for k in lang_keys]), 2),
        "wer_numcanon": round(np.mean([metrics[k]["wer_numcanon"] for k in lang_keys]), 2),
        "space_norm_wer": round(np.mean([metrics[k]["space_norm_wer"] for k in lang_keys]), 2),
        "mer": round(np.mean([metrics[k]["mer"] for k in lang_keys]), 2),
        "cer_norm": round(np.mean([metrics[k]["cer_norm"] for k in lang_keys]), 2),
    }

    # Local import: only needed to record the jiwer version in metadata.
    import importlib.metadata
    jiwer_ver = importlib.metadata.version("jiwer")
    metrics["__meta__"] = {
        "checkpoint": old_meta.get("checkpoint", ""),
        "checkpoint_name": ckpt_name,
        "model_id": MODEL_ID,
        "model_type": MODEL_TYPE,
        "dataset": DATASET,
        "batch_size": old_meta.get("batch_size", 64),
        "inference_time_sec": old_meta.get("inference_time_sec", 0),
        "total_audio_sec": old_meta.get("total_audio_sec", 40354.42),
        "rtf": old_meta.get("rtf", 0),
        "timestamp": datetime.now(timezone.utc).strftime("%Y-%m-%dT%H:%M:%SZ"),
        "gpu": GPU_NAME,
        "framework": "vllm",
        "normalization_version": NORMALIZATION_VERSION,
        "jiwer_version": jiwer_ver,
    }

    # ── sample_analysis.json ──
    # Projection of the enriched records minus internal fields
    # (ref_raw/hyp_raw and the snw_* accumulators are dropped).
    sample_analysis = []
    for s in enriched:
        sample_analysis.append({
            "id": s["id"],
            "language": s["language"],
            "reference": s["reference"],
            "hypothesis": s["hypothesis"],
            "ref_norm": s["ref_norm"],
            "hyp_norm": s["hyp_norm"],
            "ref_numcanon": s["ref_numcanon"],
            "hyp_numcanon": s["hyp_numcanon"],
            "ref_mer": s["ref_mer"],
            "hyp_mer": s["hyp_mer"],
            "detected_language": s["detected_language"],
            "wer_raw": s["wer_raw"],
            "wer_norm": s["wer_norm"],
            "space_norm_wer": s["space_norm_wer"],
            "mer": s["mer"],
            "flags": s["flags"],
        })

    # ── error_analysis.json ──
    samples_for_errors = defaultdict(list)
    for s in sample_analysis:
        samples_for_errors[s["language"]].append(s)

    error_analysis = compute_error_analysis(samples_for_errors)

    # Fill in summary worst/best languages (sorted ascending by wer_norm,
    # so the head is best and the tail is worst)
    sorted_langs = sorted(lang_keys, key=lambda k: metrics[k]["wer_norm"])
    error_analysis["__summary__"]["best_languages"] = sorted_langs[:3]
    error_analysis["__summary__"]["worst_languages"] = sorted_langs[-3:]

    # Determine diagnosis from normalization deltas: large raw→norm delta
    # means formatting dominates; large norm→numcanon delta means number
    # verbalization dominates; otherwise errors are genuine recognition.
    avg_raw_to_norm = np.mean([metrics[k]["normalization_delta"]["raw_to_norm"] for k in lang_keys])
    avg_norm_to_nc = np.mean([metrics[k]["normalization_delta"]["norm_to_numcanon"] for k in lang_keys])

    if avg_raw_to_norm > 10:
        error_analysis["__summary__"]["model_diagnosis"] = "formatting-limited"
        error_analysis["__summary__"]["formatting_impact"] = "high"
    elif avg_norm_to_nc > 5:
        error_analysis["__summary__"]["model_diagnosis"] = "numeric-limited"
        error_analysis["__summary__"]["numeric_verbalization_impact"] = "high"
    else:
        error_analysis["__summary__"]["model_diagnosis"] = "recognition-limited"
        error_analysis["__summary__"]["formatting_impact"] = "low" if avg_raw_to_norm < 5 else "moderate"
        error_analysis["__summary__"]["numeric_verbalization_impact"] = "low" if avg_norm_to_nc < 1 else "moderate"

    # ── Write files ──
    with open(out_path / "metrics.json", "w", encoding="utf-8") as f:
        json.dump(metrics, f, indent=2, ensure_ascii=False)

    with open(out_path / "sample_analysis.json", "w", encoding="utf-8") as f:
        json.dump(sample_analysis, f, indent=2, ensure_ascii=False)

    with open(out_path / "error_analysis.json", "w", encoding="utf-8") as f:
        json.dump(error_analysis, f, indent=2, ensure_ascii=False)

    ov = metrics["__overall__"]
    print(f"  {ckpt_name}: wer_raw={ov['wer_raw']}  wer_norm={ov['wer_norm']}  "
          f"wer_numcanon={ov['wer_numcanon']}  cer_norm={ov['cer_norm']}")
    return metrics


# ── Entry point ──────────────────────────────────────────────────────────

if __name__ == "__main__":
    # (output checkpoint name, benchmark_results subdirectory) pairs
    CHECKPOINTS = [
        ("ckpt-24000", "ckpt-24000-vllm"),
        ("ckpt-72000", "ckpt-72000-vllm"),
        ("ckpt-100000", "ckpt-100000-vllm"),
        ("ckpt-170000", "ckpt-170000-vllm"),
        ("ckpt-200000", "ckpt-200000-vllm"),
        ("ckpt-250000", "ckpt-250000-vllm"),
        ("ckpt-300000", "ckpt-300000-vllm"),
    ]

    base_results = "/home/ubuntu/training/benchmark_results"
    base_output = "/home/ubuntu/training/benchmark_outputs/qwen3-asr"

    print("Generating schema-v1 outputs for qwen3-asr checkpoints...")
    for name, results_subdir in CHECKPOINTS:
        process_checkpoint(
            f"{base_results}/{results_subdir}",
            name,
            f"{base_output}/{name}",
        )

    print("\nDone. All outputs in /home/ubuntu/training/benchmark_outputs/qwen3-asr/")
