#!/usr/bin/env python3
"""
Benchmark Cohere-Transcribe-Indic checkpoints on indic-asr-benchmark-6k dataset.

Usage:
    python benchmark_cohere_transcribe.py \
        --checkpoint /home/ubuntu/training/checkpoints/cohere-transcribe-ckpt-10000 \
        --checkpoint-name ckpt-10000 \
        --batch-size 8

Produces schema-v1 benchmark_outputs (metrics.json, sample_analysis.json, error_analysis.json).
"""

import argparse
import json
import os
import re
import sys
import time
import unicodedata
from collections import Counter, defaultdict
from datetime import datetime, timezone
from pathlib import Path

import numpy as np
import torch
import torchaudio
from datasets import Audio, load_dataset as hf_load_dataset
from jiwer import cer, wer

# Add repo for tokenizer_utils
sys.path.insert(0, "/home/ubuntu/training/cohere-transcribe-indic")
from tokenizer_utils import decode_tokens, load_extended_tokenizer

# ── Constants ────────────────────────────────────────────────────────────

# Recorded into metrics __meta__; bump whenever any norm_* rule changes so
# that stored benchmark runs remain comparable.
NORMALIZATION_VERSION = "v1"
GPU_NAME = "NVIDIA H200 80GB"  # hard-coded label written to __meta__, not auto-detected
MODEL_ID = "cohere-transcribe"
MODEL_TYPE = "Cohere-Transcribe-Indic-2B"
DATASET = "BayAreaBoys/indic-asr-benchmark-6k"

# Two-letter dataset lang_code -> display name for the 12 benchmark languages.
LANG_CODE_TO_NAME = {
    "as": "Assamese", "bn": "Bengali", "en": "English", "gu": "Gujarati",
    "hi": "Hindi", "kn": "Kannada", "ml": "Malayalam", "mr": "Marathi",
    "or": "Odia", "pa": "Punjabi", "ta": "Tamil", "te": "Telugu",
}

# Map language names to model's language codes
LANG_NAME_TO_CODE = {v.lower(): k for k, v in LANG_CODE_TO_NAME.items()}

LATIN_SCRIPT_LANGS = {"english"}  # NOTE(review): unused in the visible part of this file

# Zero-width characters to strip (ZWSP..RLM, bidi embeddings, BOM, soft hyphen)
ZW_CHARS = re.compile(r"[\u200b-\u200f\u202a-\u202e\ufeff\u00ad]")

# Language-aware punctuation sets
BASE_PUNCT = set("!\"#$%&'()*+,-./:;<=>?@[\\]^_`{|}~")  # all ASCII punctuation
INDIC_PUNCT = {"\u0964", "\u0965", "\u0970"}  # danda, double danda, abbreviation sign
# Subset of INDIC_PUNCT; retained for compatibility with external consumers.
INDIC_REMOVE_PUNCT = {"\u0965", "\u0970"}

# Curly quotes / dashes / ellipsis -> ASCII equivalents (applied before the
# punctuation strip, so the ASCII replacements are themselves removed).
QUOTE_MAP = {
    "\u2018": "'", "\u2019": "'", "\u201c": '"', "\u201d": '"',
    "\u2013": "-", "\u2014": "-", "\u2015": "-", "\u2026": "...",
}

# Spelled-out number words -> digit strings (single tokens only; no compounding
# like "twenty one" -> "21" — that maps to "20 1").
EN_NUMBER_WORDS = {
    "zero": "0", "one": "1", "two": "2", "three": "3", "four": "4",
    "five": "5", "six": "6", "seven": "7", "eight": "8", "nine": "9",
    "ten": "10", "eleven": "11", "twelve": "12", "thirteen": "13",
    "fourteen": "14", "fifteen": "15", "sixteen": "16", "seventeen": "17",
    "eighteen": "18", "nineteen": "19", "twenty": "20", "thirty": "30",
    "forty": "40", "fifty": "50", "sixty": "60", "seventy": "70",
    "eighty": "80", "ninety": "90", "hundred": "100", "thousand": "1000",
    "lakh": "100000", "crore": "10000000",
}

HI_NUMBER_WORDS = {
    "शून्य": "0", "एक": "1", "दो": "2", "तीन": "3", "चार": "4",
    "पांच": "5", "छह": "6", "सात": "7", "आठ": "8", "नौ": "9",
    "दस": "10", "ग्यारह": "11", "बारह": "12", "तेरह": "13",
    "चौदह": "14", "पंद्रह": "15", "सोलह": "16", "सत्रह": "17",
    "अठारह": "18", "उन्नीस": "19", "बीस": "20",
    "सौ": "100", "हज़ार": "1000", "हजार": "1000",
    "लाख": "100000", "करोड़": "10000000",
}


# ── Normalization Pipeline ───────────────────────────────────────────────

def norm_raw(text: str) -> str:
    """Tier 0: NFC-normalize and trim; keeps casing and punctuation intact."""
    text = unicodedata.normalize("NFC", text)
    return text.strip()


def norm_standard(text: str, language: str = "") -> str:
    """Tier 1: NFKC, strip zero-width chars, unify quotes/dashes, remove ASCII
    and Indic punctuation, lowercase, collapse whitespace.

    ``language`` is accepted for interface symmetry with the other tiers; the
    same rules currently apply to every language.
    """
    text = unicodedata.normalize("NFKC", text)
    text = ZW_CHARS.sub("", text)
    text = re.sub(r"\s+", " ", text).strip()
    for old, new in QUOTE_MAP.items():
        text = text.replace(old, new)
    # Drop every punctuation char in one pass.  INDIC_REMOVE_PUNCT is a subset
    # of INDIC_PUNCT, so a separate membership test was redundant; likewise
    # the old "\u0965 -> \u0964" remap had no observable effect because both
    # characters are removed here anyway.
    drop = BASE_PUNCT | INDIC_PUNCT
    text = "".join(c for c in text if c not in drop)
    text = text.lower()
    return re.sub(r"\s+", " ", text).strip()


def norm_numcanon(text: str, language: str = "") -> str:
    """Tier 2: norm_standard plus numeric canonicalization (strip digit-group
    commas, map English/Hindi number words to digit strings)."""
    text = norm_standard(text, language)
    # Strip digit-grouping commas ("1,2,3" -> "123").  The lookaround form
    # handles adjacent groups in one pass, which the old non-overlapping
    # forward sub needed two applications for.  In practice commas are already
    # removed by norm_standard (BASE_PUNCT includes ","), so this is defensive.
    text = re.sub(r"(?<=\d),(?=\d)", "", text)
    lookup = None
    if language == "english":
        lookup = EN_NUMBER_WORDS
    elif language == "hindi":
        lookup = HI_NUMBER_WORDS
    if lookup is not None:
        text = " ".join(lookup.get(w, w) for w in text.split())
    return text


def norm_mer(text: str, language: str = "") -> str:
    """Tier 3: norm_standard with all spaces removed (for match-error-rate)."""
    return norm_standard(text, language).replace(" ", "")


# ── space_norm_wer computation ───────────────────────────────────────

def space_norm_wer_sample(ref_norm: str, hyp_norm: str):
    """Spacing-insensitive word-error count for one sample.

    Aligns the space-stripped reference and hypothesis at character level
    (Levenshtein, unit costs), then maps each edited reference character back
    to the reference word it came from.  A word counts as errored if ANY of
    its characters was touched by an edit, so pure word-segmentation
    differences do not inflate the count.

    Returns:
        (errored_words, total_ref_words); (0, 0) when the reference is empty.
    """
    ref_words = ref_norm.split()
    if not ref_words:
        return (0, 0)
    ref_nospace = ref_norm.replace(" ", "")
    hyp_nospace = hyp_norm.replace(" ", "")
    total_words = len(ref_words)
    # Fast path: identical after space removal -> zero word errors.
    if ref_nospace == hyp_nospace:
        return (0, total_words)

    # char_to_word[k] = index of the reference word that contributed the k-th
    # character of ref_nospace.
    char_to_word = []
    for word_idx, word in enumerate(ref_words):
        for _ in word:
            char_to_word.append(word_idx)

    n = len(ref_nospace)
    m = len(hyp_nospace)

    # Standard O(n*m) edit-distance DP over characters.
    dp = [[0] * (m + 1) for _ in range(n + 1)]
    for i in range(n + 1):
        dp[i][0] = i
    for j in range(m + 1):
        dp[0][j] = j
    for i in range(1, n + 1):
        for j in range(1, m + 1):
            if ref_nospace[i - 1] == hyp_nospace[j - 1]:
                dp[i][j] = dp[i - 1][j - 1]
            else:
                dp[i][j] = 1 + min(dp[i - 1][j - 1], dp[i - 1][j], dp[i][j - 1])

    # Backtrace one optimal alignment, marking the word behind every edited
    # reference character.  Branch order (match, substitution, deletion,
    # insertion) is the tie-break when several alignments are optimal, so the
    # touched set may differ between equally-optimal paths.
    touched = set()
    i, j = n, m
    while i > 0 or j > 0:
        if i > 0 and j > 0 and ref_nospace[i-1] == hyp_nospace[j-1] and dp[i][j] == dp[i-1][j-1]:
            i -= 1; j -= 1
        elif i > 0 and j > 0 and dp[i][j] == dp[i-1][j-1] + 1:
            touched.add(char_to_word[i-1]); i -= 1; j -= 1
        elif i > 0 and dp[i][j] == dp[i-1][j] + 1:
            touched.add(char_to_word[i-1]); i -= 1
        elif j > 0 and dp[i][j] == dp[i][j-1] + 1:
            # Insertion in the hypothesis: attribute it to the adjacent
            # reference word (previous word if any, else the following one).
            if i > 0:
                touched.add(char_to_word[i-1])
            elif i < n:
                touched.add(char_to_word[i])
            j -= 1
        else:
            break  # defensive: unreachable for a consistent DP table

    return (len(touched), total_words)


# ── Per-sample WER ───────────────────────────────────────────────────────

def sample_wer(ref: str, hyp: str) -> float:
    """Per-sample WER as a percentage.

    Empty-string guards keep jiwer from seeing an empty reference: both sides
    empty scores 0.0, one side empty scores 100.0.
    """
    has_ref = bool(ref.strip())
    has_hyp = bool(hyp.strip())
    if not has_ref:
        return 100.0 if has_hyp else 0.0
    if not has_hyp:
        return 100.0
    return round(wer(ref, hyp) * 100, 2)


# ── Flag detection ───────────────────────────────────────────────────────

def detect_flags(ref: str, hyp: str, ref_n: str, hyp_n: str,
                 wer_norm_val: float, mer_val: float,
                 detected_language: str = "", expected_language: str = "") -> list:
    """Classify one sample with diagnostic flags.

    Flags (in emission order): exact_match, exact_match_norm,
    empty_hypothesis, punctuation_only_diff, numeric_mismatch, high_wer,
    spacing_error, lang_confusion, script_mismatch.
    """
    flags = []
    norm_equal = ref_n == hyp_n
    if ref == hyp:
        flags.append("exact_match")
    if norm_equal:
        flags.append("exact_match_norm")
    if not hyp.strip():
        flags.append("empty_hypothesis")
    # Differs only in characters that normalization strips.
    if norm_equal and ref != hyp:
        flags.append("punctuation_only_diff")
    nums_ref = re.findall(r"\d+", ref_n)
    nums_hyp = re.findall(r"\d+", hyp_n)
    if (nums_ref or nums_hyp) and nums_ref != nums_hyp:
        flags.append("numeric_mismatch")
    if wer_norm_val > 80:
        flags.append("high_wer")
    # Word error rate exceeding match error rate points at segmentation issues.
    if mer_val < 100 and wer_norm_val > mer_val:
        flags.append("spacing_error")
    # Language/script confusion
    both_langs_known = bool(detected_language) and bool(expected_language)
    if both_langs_known and detected_language.lower() != expected_language.lower():
        flags.append("lang_confusion")
        # Check if it's also a script mismatch (e.g. Latin output where
        # Devanagari was expected).
        ref_script = detect_output_language(ref)
        hyp_script = detect_output_language(hyp)
        if ref_script and hyp_script and ref_script != hyp_script:
            flags.append("script_mismatch")
    return flags


# ── Error Analysis ───────────────────────────────────────────────────────

def compute_error_analysis(samples_by_lang: dict) -> dict:
    """Aggregate per-language error statistics from enriched sample dicts.

    Each sample is expected to carry "id", "ref_norm", "hyp_norm" and
    optionally "flags" / "wer_norm".  Returns a dict keyed by language name
    plus a "__summary__" entry.
    """
    analysis = {}

    for lang, samples in sorted(samples_by_lang.items()):
        # Word-level tallies.  Substitutions use a positional heuristic, not a
        # true alignment; insertions/deletions use a multiset difference.
        subs = Counter()
        insertions = Counter()
        deletions = Counter()
        numeric_mismatch_count = 0
        punct_only_count = 0
        spacing_count = 0
        entity_count = 0
        script_confusion_count = 0
        empty_count = 0

        wer_scores = []

        for s in samples:
            ref_words = s["ref_norm"].split()
            hyp_words = s["hyp_norm"].split()
            flags = s.get("flags", [])

            # Bucket counters come straight from the flags set by detect_flags().
            if "empty_hypothesis" in flags: empty_count += 1
            if "numeric_mismatch" in flags: numeric_mismatch_count += 1
            if "punctuation_only_diff" in flags: punct_only_count += 1
            if "script_mismatch" in flags: script_confusion_count += 1
            if "spacing_error" in flags: spacing_count += 1

            # Missing wer_norm is treated as a total miss (100.0).
            wer_scores.append((s["id"], s.get("wer_norm", 100.0)))

            # Order-insensitive insertion/deletion estimate via multiset diff.
            ref_counter = Counter(ref_words)
            hyp_counter = Counter(hyp_words)
            for word, count in (ref_counter - hyp_counter).items():
                deletions[word] += count
            for word, count in (hyp_counter - ref_counter).items():
                insertions[word] += count
            # Positional substitution pairs; tails beyond min_len are ignored.
            min_len = min(len(ref_words), len(hyp_words))
            for i in range(min_len):
                if ref_words[i] != hyp_words[i]:
                    subs[(ref_words[i], hyp_words[i])] += 1
            # NOTE(review): ref_norm is lowercased by norm_standard, so these
            # isupper() tests can never fire and entity_count stays 0 — this
            # heuristic looks dead as written; confirm intended input.
            for rw, hw in zip(ref_words[:min_len], hyp_words[:min_len]):
                if rw != hw and (rw[0:1].isupper() or any(c.isupper() for c in rw)):
                    entity_count += 1

        # Ascending sort by WER: head = best samples, tail = worst samples.
        wer_scores.sort(key=lambda x: x[1])
        best = [sid for sid, _ in wer_scores[:3]]
        worst = [sid for sid, _ in wer_scores[-3:]]
        numeric_examples = [s["id"] for s in samples if "numeric_mismatch" in s.get("flags", [])][:3]
        # NOTE(review): detect_flags() in this file never emits
        # "entity_mismatch", so this list appears to always be empty — verify.
        entity_examples = [s["id"] for s in samples if "entity_mismatch" in s.get("flags", [])][:3]

        analysis[lang] = {
            "top_substitutions": [
                {"ref": r, "hyp": h, "count": c} for (r, h), c in subs.most_common(20)
            ],
            "top_insertions": [
                {"word": w, "count": c} for w, c in insertions.most_common(20)
            ],
            "top_deletions": [
                {"word": w, "count": c} for w, c in deletions.most_common(20)
            ],
            "error_buckets": {
                "numeric_mismatch_count": numeric_mismatch_count,
                "punctuation_only_count": punct_only_count,
                "spacing_tokenization_count": spacing_count,
                "entity_mismatch_count": entity_count,
                "script_confusion_count": script_confusion_count,
                "empty_hypothesis_count": empty_count,
            },
            "examples": {
                "worst_samples": worst,
                "best_samples": best,
                "numeric_mismatch_samples": numeric_examples,
                "entity_mismatch_samples": entity_examples,
            }
        }

    # Fixed placeholder verdicts — presumably filled in manually or by a
    # downstream step; nothing in this function computes them.
    analysis["__summary__"] = {
        "model_diagnosis": "recognition-limited",
        "primary_error_source": "recognition",
        "numeric_verbalization_impact": "moderate",
        "formatting_impact": "low",
        "worst_languages": [],
        "best_languages": [],
    }

    return analysis


# ── Script detection ─────────────────────────────────────────────────────

# Unicode code-point ranges per script.  The "Latin" range (0x41-0x7A) spans
# some ASCII punctuation, but the isalpha() guard below keeps non-letters out.
SCRIPT_RANGES = [
    (0x0900, 0x097F, "Devanagari"), (0x0980, 0x09FF, "Bengali"),
    (0x0A00, 0x0A7F, "Gurmukhi"), (0x0A80, 0x0AFF, "Gujarati"),
    (0x0B00, 0x0B7F, "Odia"), (0x0B80, 0x0BFF, "Tamil"),
    (0x0C00, 0x0C7F, "Telugu"), (0x0C80, 0x0CFF, "Kannada"),
    (0x0D00, 0x0D7F, "Malayalam"), (0x0041, 0x007A, "Latin"),
]

# Script name -> representative benchmark language (Devanagari could also be
# Marathi; Hindi is used as the canonical guess).
SCRIPT_TO_LANG = {
    "Devanagari": "Hindi", "Bengali": "Bengali", "Gurmukhi": "Punjabi",
    "Gujarati": "Gujarati", "Odia": "Odia", "Tamil": "Tamil",
    "Telugu": "Telugu", "Kannada": "Kannada", "Malayalam": "Malayalam",
    "Latin": "English",
}

def detect_output_language(text: str) -> str:
    """Detect the dominant script/language of model output text.

    Tallies alphabetic characters per script range and returns the language
    mapped from the most frequent script, or "" when no alphabetic character
    falls inside a known range.  (The redundant function-local Counter import
    was removed — Counter is already imported at module level.)
    """
    scripts = Counter()
    for c in text:
        if not c.isalpha():
            continue
        cp = ord(c)
        for lo, hi, name in SCRIPT_RANGES:
            if lo <= cp <= hi:
                scripts[name] += 1
                break
    if not scripts:
        return ""
    dominant = scripts.most_common(1)[0][0]
    return SCRIPT_TO_LANG.get(dominant, dominant)


# ── Inference ────────────────────────────────────────────────────────────

def _has_degenerate_loop(token_ids: list, min_run: int = 4) -> bool:
    """True if the same token id repeats >= min_run times consecutively."""
    if len(token_ids) < min_run:
        return False
    run = 1
    prev = token_ids[0]
    for t in token_ids[1:]:
        if t == prev:
            run += 1
            if run >= min_run:
                return True
        else:
            run = 1
            prev = t
    return False


def run_inference(model, processor, tokenizer, dataset, batch_size=1, max_new_tokens=256, device="cuda:0",
                  no_repeat_ngram_size=3):
    """Run batched inference on HF dataset.

    Adds defensive decode flags so single-word commands don't loop to max_new_tokens:
      - no_repeat_ngram_size: kills repeat-ngram loops at decode time
    Records per-sample diagnostics (output_token_len, hit_max_tokens, degenerate_loop)
    in each result dict.

    Returns a list of dicts, one per dataset row, carrying the sample metadata
    plus "hypothesis", "detected_language" and the decode diagnostics above.
    """
    results = []
    total = len(dataset)

    # Pre-cache prompt token ids per language code — one encode per language
    # for the whole run instead of one per sample.
    prompt_cache = {}
    for lc in LANG_CODE_TO_NAME:
        prompt_str = model.build_prompt(language=lc, punctuation=True)
        prompt_cache[lc] = tokenizer.encode(prompt_str, add_special_tokens=False)

    eos_id = tokenizer.convert_tokens_to_ids("<|endoftext|>")

    def _process_output_ids(token_ids_tensor, prompt_len: int):
        """Strip prompt prefix, truncate at first EOS, return (decoded_text, gen_len, hit_max, degenerate)."""
        ids = token_ids_tensor.tolist()
        # generate() echoes decoder_input_ids at the front of the output, so
        # the first prompt_len ids are stripped positionally.  (The prompt ids
        # themselves aren't available here, so the prefix content can't be
        # verified — the previous code's `ids[:prompt_len] == ids[:prompt_len]`
        # self-comparison was a tautology and has been removed.)
        gen_ids = ids[prompt_len:] if len(ids) >= prompt_len else ids
        # Truncate at first EOS; if no EOS was emitted, the model most likely
        # exhausted the max_new_tokens budget.
        try:
            eos_pos = gen_ids.index(eos_id)
            gen_ids_trimmed = gen_ids[:eos_pos]
            hit_max = False
        except ValueError:
            gen_ids_trimmed = gen_ids
            hit_max = len(gen_ids_trimmed) >= max_new_tokens - 1
        degenerate = _has_degenerate_loop(gen_ids_trimmed, min_run=4)
        text = decode_tokens(tokenizer, gen_ids_trimmed).strip()
        return text, len(gen_ids_trimmed), hit_max, degenerate

    for start in range(0, total, batch_size):
        end = min(start + batch_size, total)
        batch = dataset.select(range(start, end))

        # Collect audio arrays, resampling anything that isn't already 16 kHz.
        wavs = []
        batch_meta = []
        for sample in batch:
            audio = sample["audio"]
            wav = np.array(audio["array"], dtype=np.float32)
            sr = audio["sampling_rate"]
            if sr != 16000:
                wav_t = torch.from_numpy(wav).unsqueeze(0)
                wav_t = torchaudio.functional.resample(wav_t, sr, 16000)
                wav = wav_t.squeeze(0).numpy()
            wavs.append(wav)
            batch_meta.append({
                "id": sample["id"],
                "language": sample["language"],
                "lang_code": sample["lang_code"],
                "source": sample["source"],
                "duration": sample["duration"],
                "reference": sample["reference"],
            })

        # Batched mel extraction
        features = processor.feature_extractor(
            wavs, sampling_rate=16000, return_tensors="pt"
        )
        input_features = features["input_features"].to(device, dtype=torch.bfloat16)
        length = features["length"].to(device)

        # Build batched decoder prompts (all prompts are same length)
        prompt_ids_list = [prompt_cache[m["lang_code"]] for m in batch_meta]
        prompt_len = len(prompt_ids_list[0])
        decoder_input_ids = torch.tensor(prompt_ids_list, device=device)

        gen_kwargs = dict(
            input_features=input_features,
            length=length,
            decoder_input_ids=decoder_input_ids,
            max_new_tokens=max_new_tokens,
        )
        if no_repeat_ngram_size and no_repeat_ngram_size > 0:
            gen_kwargs["no_repeat_ngram_size"] = no_repeat_ngram_size

        # Batched generate; on failure, fall back to one-by-one so a single
        # bad sample can't sink the whole batch.
        with torch.no_grad():
            try:
                outputs = model.generate(**gen_kwargs)
                for i, meta in enumerate(batch_meta):
                    hyp, n_tok, hit_max, degen = _process_output_ids(outputs[i], prompt_len)
                    detected = detect_output_language(hyp)
                    results.append({
                        **meta, "hypothesis": hyp, "detected_language": detected,
                        "output_token_len": n_tok, "hit_max_tokens": hit_max,
                        "degenerate_loop": degen,
                    })
            except Exception as e:
                print(f"  Batch {start}-{end} failed ({e}), falling back to one-by-one...", flush=True)
                for i, meta in enumerate(batch_meta):
                    try:
                        feat_i = input_features[i:i+1]
                        len_i = length[i:i+1]
                        did_i = torch.tensor([prompt_ids_list[i]], device=device)
                        sgk = dict(input_features=feat_i, length=len_i, decoder_input_ids=did_i,
                                   max_new_tokens=max_new_tokens)
                        if no_repeat_ngram_size and no_repeat_ngram_size > 0:
                            sgk["no_repeat_ngram_size"] = no_repeat_ngram_size
                        out_i = model.generate(**sgk)
                        hyp, n_tok, hit_max, degen = _process_output_ids(out_i[0], prompt_len)
                    except Exception as e2:
                        print(f"  Sample {meta['id']} failed: {e2}")
                        # An empty hypothesis marks the sample as failed and is
                        # flagged as empty_hypothesis downstream.
                        hyp, n_tok, hit_max, degen = "", 0, False, False
                    detected = detect_output_language(hyp)
                    results.append({
                        **meta, "hypothesis": hyp, "detected_language": detected,
                        "output_token_len": n_tok, "hit_max_tokens": hit_max,
                        "degenerate_loop": degen,
                    })

        elapsed_pct = end / total * 100
        print(f"  [{end}/{total}] ({elapsed_pct:.0f}%) processed", flush=True)

    return results


# ── Main ─────────────────────────────────────────────────────────────────

def main():
    parser = argparse.ArgumentParser(description="Benchmark Cohere-Transcribe on Indic ASR dataset")
    parser.add_argument("--checkpoint", required=True, help="Path to model checkpoint directory")
    parser.add_argument("--checkpoint-name", default=None)
    parser.add_argument("--hf-dataset", default="BayAreaBoys/indic-asr-benchmark-6k")
    parser.add_argument("--cache-dir", default="/home/ubuntu/training/datasets/indic-asr-benchmark-6k")
    parser.add_argument("--output-dir", default="/home/ubuntu/training/benchmark_outputs/cohere-transcribe")
    parser.add_argument("--batch-size", type=int, default=1)
    parser.add_argument("--max-new-tokens", type=int, default=256)
    parser.add_argument("--device", default="cuda:0")
    parser.add_argument("--languages", nargs="+", default=None)
    parser.add_argument("--no-repeat-ngram-size", type=int, default=3,
                        help="Block repeat n-grams during decode (0 to disable). Default 3.")
    args = parser.parse_args()

    ckpt_name = args.checkpoint_name or Path(args.checkpoint).name
    output_dir = Path(args.output_dir) / ckpt_name
    output_dir.mkdir(parents=True, exist_ok=True)

    # Load model
    from transformers import AutoModelForSpeechSeq2Seq, AutoFeatureExtractor

    print(f"Loading checkpoint: {args.checkpoint}")
    t0 = time.time()

    # Load feature extractor from checkpoint (has custom processing code)
    feature_extractor = AutoFeatureExtractor.from_pretrained(
        args.checkpoint, trust_remote_code=True,
    )
    # FilterbankFeatures is lazily created — set device in both _device and
    # _fb_config so the torch.Generator inside FilterbankFeatures is on GPU
    feature_extractor._device = args.device
    feature_extractor._fb_config["device"] = args.device

    # Load extended tokenizer
    tokenizer = load_extended_tokenizer(args.checkpoint)

    # Create a simple processor-like namespace
    class SimpleProcessor:
        pass
    processor = SimpleProcessor()
    processor.feature_extractor = feature_extractor
    processor.tokenizer = tokenizer

    # Load model — checkpoint has different key prefixes from training wrapper,
    # so we load the architecture first, then manually load+remap weights.
    from safetensors.torch import load_file as load_safetensors

    model = AutoModelForSpeechSeq2Seq.from_pretrained(
        args.checkpoint,
        trust_remote_code=True,
        torch_dtype=torch.bfloat16,
        ignore_mismatched_sizes=True,
    )

    # Remap checkpoint keys to match model's expected names
    ckpt_weights = load_safetensors(str(Path(args.checkpoint) / "model.safetensors"))
    KEY_REMAP = [
        ("encoder_encoder_decoder_proj.", "encoder_decoder_proj."),
        ("transf_decoder._transf_decoder._decoder.", "transf_decoder._decoder."),
    ]
    remapped = {}
    for k, v in ckpt_weights.items():
        new_k = k
        for old_prefix, new_prefix in KEY_REMAP:
            if new_k.startswith(old_prefix):
                new_k = new_prefix + new_k[len(old_prefix):]
        remapped[new_k] = v

    missing, unexpected = model.load_state_dict(remapped, strict=False)
    if missing:
        print(f"WARNING: {len(missing)} missing keys (first 5): {missing[:5]}")
    if unexpected:
        print(f"WARNING: {len(unexpected)} unexpected keys (first 5): {unexpected[:5]}")

    model = model.to(args.device)
    model.eval()

    load_time = time.time() - t0
    print(f"Model loaded in {load_time:.1f}s")

    # Load dataset
    hf_token = os.environ.get("HF_TOKEN")
    print(f"Loading dataset: {args.hf_dataset}")
    dataset = hf_load_dataset(
        args.hf_dataset, token=hf_token, cache_dir=args.cache_dir, split="train"
    )
    dataset = dataset.cast_column("audio", Audio(sampling_rate=16000, decode=True))

    if args.languages:
        dataset = dataset.filter(lambda x: x["lang_code"] in args.languages)
        print(f"Filtered to {len(dataset)} samples for languages: {args.languages}")

    print(f"Dataset: {len(dataset)} samples")

    # Run inference
    print(f"\nRunning inference (batch_size={args.batch_size})...")
    t0 = time.time()
    results = run_inference(
        model, processor, tokenizer, dataset,
        batch_size=args.batch_size,
        max_new_tokens=args.max_new_tokens,
        device=args.device,
        no_repeat_ngram_size=args.no_repeat_ngram_size,
    )
    inference_time = time.time() - t0
    total_audio_sec = sum(r["duration"] for r in results)
    rtf = inference_time / total_audio_sec if total_audio_sec > 0 else 0
    print(f"Inference done in {inference_time:.1f}s for {total_audio_sec:.0f}s audio (RTF={rtf:.4f})")

    # Save raw predictions
    preds_file = output_dir / "predictions.json"
    with open(preds_file, "w", encoding="utf-8") as f:
        json.dump(results, f, ensure_ascii=False, indent=2)
    print(f"Predictions saved: {preds_file}")

    # ── Compute all metric tiers ──
    print("\nComputing metrics...")
    enriched = []
    for p in results:
        lang = p["language"]
        ref = p["reference"]
        hyp = p.get("hypothesis", "")

        ref_raw = norm_raw(ref)
        hyp_raw = norm_raw(hyp)
        ref_n = norm_standard(ref, lang)
        hyp_n = norm_standard(hyp, lang)
        ref_nc = norm_numcanon(ref, lang)
        hyp_nc = norm_numcanon(hyp, lang)
        ref_m = norm_mer(ref, lang)
        hyp_m = norm_mer(hyp, lang)

        wer_raw_val = sample_wer(ref_raw, hyp_raw)
        wer_norm_val = sample_wer(ref_n, hyp_n)

        if not ref_m:
            mer_val = 0.0 if not hyp_m else 100.0
        elif not hyp_m:
            mer_val = 100.0
        else:
            mer_val = round(cer(ref_m, hyp_m) * 100, 2)

        snw_err, snw_total = space_norm_wer_sample(ref_n, hyp_n)
        snw_val = round(snw_err / snw_total * 100, 2) if snw_total > 0 else 0.0

        detected = p.get("detected_language", "")
        expected = LANG_CODE_TO_NAME.get(p.get("lang_code", ""), lang.capitalize())
        flags = detect_flags(ref_raw, hyp_raw, ref_n, hyp_n, wer_norm_val, mer_val,
                             detected_language=detected, expected_language=expected)

        enriched.append({
            "id": p["id"],
            "language": lang,
            "duration": p.get("duration", 0.0),
            "reference": ref,
            "hypothesis": hyp,
            "ref_raw": ref_raw, "hyp_raw": hyp_raw,
            "ref_norm": ref_n, "hyp_norm": hyp_n,
            "ref_numcanon": ref_nc, "hyp_numcanon": hyp_nc,
            "ref_mer": ref_m, "hyp_mer": hyp_m,
            "detected_language": p.get("detected_language", ""),
            "wer_raw": wer_raw_val, "wer_norm": wer_norm_val,
            "space_norm_wer": snw_val, "snw_err": snw_err, "snw_total": snw_total,
            "mer": mer_val, "flags": flags,
            "output_token_len": p.get("output_token_len", 0),
            "hit_max_tokens": p.get("hit_max_tokens", False),
            "degenerate_loop": p.get("degenerate_loop", False),
        })

    # ── Per-language metrics ──
    by_lang = defaultdict(list)
    for s in enriched:
        by_lang[s["language"]].append(s)

    metrics = {}
    all_ref_raw, all_hyp_raw = [], []
    all_ref_n, all_hyp_n = [], []
    all_ref_nc, all_hyp_nc = [], []
    all_ref_m, all_hyp_m = [], []
    all_snw_err, all_snw_total = 0, 0

    for lang in sorted(by_lang.keys()):
        samples = by_lang[lang]
        n = len(samples)

        refs_raw = [s["ref_raw"] for s in samples if s["ref_norm"]]
        hyps_raw = [s["hyp_raw"] for s in samples if s["ref_norm"]]
        refs_n = [s["ref_norm"] for s in samples if s["ref_norm"]]
        hyps_n = [s["hyp_norm"] for s in samples if s["ref_norm"]]
        refs_nc = [s["ref_numcanon"] for s in samples if s["ref_norm"]]
        hyps_nc = [s["hyp_numcanon"] for s in samples if s["ref_norm"]]
        refs_m = [s["ref_mer"] for s in samples if s["ref_norm"]]
        hyps_m = [s["hyp_mer"] for s in samples if s["ref_norm"]]

        w_raw = round(wer(refs_raw, hyps_raw) * 100, 2) if refs_raw else 100.0
        w_norm = round(wer(refs_n, hyps_n) * 100, 2) if refs_n else 100.0
        w_nc = round(wer(refs_nc, hyps_nc) * 100, 2) if refs_nc else 100.0
        w_mer = round(cer(refs_m, hyps_m) * 100, 2) if refs_m else 100.0
        c_norm = round(cer(refs_n, hyps_n) * 100, 2) if refs_n else 100.0
        empty = sum(1 for s in samples if "empty_hypothesis" in s["flags"])

        lang_snw_err = sum(s["snw_err"] for s in samples if s["ref_norm"])
        lang_snw_total = sum(s["snw_total"] for s in samples if s["ref_norm"])
        w_snw = round(lang_snw_err / lang_snw_total * 100, 2) if lang_snw_total > 0 else 100.0

        metrics[lang] = {
            "n_samples": n,
            "wer_raw": w_raw, "wer_norm": w_norm, "wer_numcanon": w_nc,
            "space_norm_wer": w_snw, "mer": w_mer, "cer_norm": c_norm,
            "empty_hypotheses": empty,
            "normalization_delta": {
                "raw_to_norm": round(w_raw - w_norm, 2),
                "norm_to_numcanon": round(w_norm - w_nc, 2),
                "norm_to_space_norm": round(w_norm - w_snw, 2),
                "norm_to_mer": round(w_norm - w_mer, 2),
            }
        }

        all_ref_raw.extend(refs_raw); all_hyp_raw.extend(hyps_raw)
        all_ref_n.extend(refs_n); all_hyp_n.extend(hyps_n)
        all_ref_nc.extend(refs_nc); all_hyp_nc.extend(hyps_nc)
        all_ref_m.extend(refs_m); all_hyp_m.extend(hyps_m)
        all_snw_err += lang_snw_err; all_snw_total += lang_snw_total

    # Aggregates
    metrics["__overall__"] = {
        "n_samples": len(all_ref_raw),
        "wer_raw": round(wer(all_ref_raw, all_hyp_raw) * 100, 2),
        "wer_norm": round(wer(all_ref_n, all_hyp_n) * 100, 2),
        "wer_numcanon": round(wer(all_ref_nc, all_hyp_nc) * 100, 2),
        "space_norm_wer": round(all_snw_err / all_snw_total * 100, 2) if all_snw_total > 0 else 100.0,
        "mer": round(cer(all_ref_m, all_hyp_m) * 100, 2),
        "cer_norm": round(cer(all_ref_n, all_hyp_n) * 100, 2),
    }

    lang_keys = sorted(k for k in metrics if not k.startswith("__"))
    metrics["__macro_avg__"] = {
        "n_languages": len(lang_keys),
        "wer_raw": round(np.mean([metrics[k]["wer_raw"] for k in lang_keys]), 2),
        "wer_norm": round(np.mean([metrics[k]["wer_norm"] for k in lang_keys]), 2),
        "wer_numcanon": round(np.mean([metrics[k]["wer_numcanon"] for k in lang_keys]), 2),
        "space_norm_wer": round(np.mean([metrics[k]["space_norm_wer"] for k in lang_keys]), 2),
        "mer": round(np.mean([metrics[k]["mer"] for k in lang_keys]), 2),
        "cer_norm": round(np.mean([metrics[k]["cer_norm"] for k in lang_keys]), 2),
    }

    # ── Decode stats (per-language and overall) ──
    decode_stats_per_lang = {}
    for lang in sorted(by_lang.keys()):
        samples = by_lang[lang]
        n = len(samples)
        n_hit_max = sum(1 for s in samples if s.get("hit_max_tokens"))
        n_degen = sum(1 for s in samples if s.get("degenerate_loop"))
        avg_tok = float(np.mean([s.get("output_token_len", 0) for s in samples])) if samples else 0.0
        max_tok = max((s.get("output_token_len", 0) for s in samples), default=0)
        decode_stats_per_lang[lang] = {
            "n_samples": n,
            "hit_max_tokens": n_hit_max,
            "degenerate_loop": n_degen,
            "avg_output_tokens": round(avg_tok, 2),
            "max_output_tokens": int(max_tok),
        }
    # Aggregate decode-health counters across all languages: how many samples
    # hit the generation token cap and how many were flagged as degenerate
    # repetition loops (per-language counters are computed upstream, not
    # visible in this chunk).
    metrics["__decode_stats__"] = {
        "per_language": decode_stats_per_lang,
        "total_hit_max_tokens": sum(v["hit_max_tokens"] for v in decode_stats_per_lang.values()),
        "total_degenerate_loop": sum(v["degenerate_loop"] for v in decode_stats_per_lang.values()),
        # getattr fallback: tolerate older CLI invocations that predate the
        # --no-repeat-ngram-size flag (default 3).
        "no_repeat_ngram_size": getattr(args, "no_repeat_ngram_size", 3),
        "max_new_tokens": args.max_new_tokens,
    }

    # ── Filtered subset (duration >= 0.5s) — what serving distribution looks like ──
    # Re-score on samples at least MIN_DUR seconds long with a non-empty
    # normalized reference; very short clips are excluded from this view.
    MIN_DUR = 0.5
    enriched_filt = [s for s in enriched if s.get("duration", 0.0) >= MIN_DUR and s["ref_norm"]]
    filt_by_lang = defaultdict(list)
    for s in enriched_filt:
        filt_by_lang[s["language"]].append(s)

    filtered = {"min_duration_sec": MIN_DUR, "languages": {}}
    # Pooled ref/hyp lists for the corpus-level "__overall__" entry:
    # *_n = normalized text, *_m = MER-variant text (see per-sample fields).
    f_all_ref_n, f_all_hyp_n = [], []
    f_all_ref_m, f_all_hyp_m = [], []
    for lang in sorted(filt_by_lang.keys()):
        samples = filt_by_lang[lang]
        refs_n = [s["ref_norm"] for s in samples]
        hyps_n = [s["hyp_norm"] for s in samples]
        refs_m = [s["ref_mer"] for s in samples]
        hyps_m = [s["hyp_mer"] for s in samples]
        if not refs_n:
            continue
        # "mer" is CER over the *_mer text variants (presumably a
        # space-stripped normalization — TODO confirm against the
        # normalization helpers earlier in the file); this mirrors how
        # the unfiltered "__overall__" block is computed below.
        w = round(wer(refs_n, hyps_n) * 100, 2)
        m = round(cer(refs_m, hyps_m) * 100, 2)
        c = round(cer(refs_n, hyps_n) * 100, 2)
        filtered["languages"][lang] = {
            "n_samples": len(samples),
            # Excluded count is relative to the unfiltered per-language total.
            "n_excluded": metrics[lang]["n_samples"] - len(samples),
            "wer_norm": w, "mer": m, "cer_norm": c,
        }
        f_all_ref_n.extend(refs_n); f_all_hyp_n.extend(hyps_n)
        f_all_ref_m.extend(refs_m); f_all_hyp_m.extend(hyps_m)
    if f_all_ref_n:
        # Corpus-level (sample-weighted) metrics over the pooled lists.
        filtered["__overall__"] = {
            "n_samples": len(f_all_ref_n),
            "n_excluded": len(all_ref_n) - len(f_all_ref_n),
            "wer_norm": round(wer(f_all_ref_n, f_all_hyp_n) * 100, 2),
            "mer": round(cer(f_all_ref_m, f_all_hyp_m) * 100, 2),
            "cer_norm": round(cer(f_all_ref_n, f_all_hyp_n) * 100, 2),
        }
        # Macro average: unweighted mean across languages, so low-resource
        # languages count equally with high-sample ones.
        f_lang_keys = sorted(filtered["languages"].keys())
        filtered["__macro_avg__"] = {
            "n_languages": len(f_lang_keys),
            "wer_norm": round(np.mean([filtered["languages"][k]["wer_norm"] for k in f_lang_keys]), 2),
            "mer": round(np.mean([filtered["languages"][k]["mer"] for k in f_lang_keys]), 2),
            "cer_norm": round(np.mean([filtered["languages"][k]["cer_norm"] for k in f_lang_keys]), 2),
        }
    metrics["__filtered_min_duration__"] = filtered

    # Record run provenance (checkpoint, dataset, timing, scorer version)
    # so results remain reproducible and comparable across runs.
    import importlib.metadata
    jiwer_ver = importlib.metadata.version("jiwer")
    metrics["__meta__"] = {
        "checkpoint": args.checkpoint,
        "checkpoint_name": ckpt_name,
        "model_id": MODEL_ID,
        "model_type": MODEL_TYPE,
        "dataset": DATASET,
        "batch_size": args.batch_size,
        # inference_time / total_audio_sec / rtf are computed upstream
        # (not visible here); rtf = real-time factor of the decode pass.
        "inference_time_sec": round(inference_time, 2),
        "total_audio_sec": round(total_audio_sec, 2),
        "rtf": round(rtf, 4),
        "timestamp": datetime.now(timezone.utc).strftime("%Y-%m-%dT%H:%M:%SZ"),
        "gpu": GPU_NAME,
        "framework": "transformers",
        "normalization_version": NORMALIZATION_VERSION,
        "jiwer_version": jiwer_ver,
    }

    # ── sample_analysis.json ──
    # Flatten each enriched sample into the schema-v1 per-sample record
    # (raw + normalized texts, per-sample metrics, and diagnostic flags).
    sample_analysis = []
    for s in enriched:
        sample_analysis.append({
            "id": s["id"], "language": s["language"],
            "reference": s["reference"], "hypothesis": s["hypothesis"],
            "ref_norm": s["ref_norm"], "hyp_norm": s["hyp_norm"],
            "ref_numcanon": s["ref_numcanon"], "hyp_numcanon": s["hyp_numcanon"],
            "ref_mer": s["ref_mer"], "hyp_mer": s["hyp_mer"],
            "detected_language": s["detected_language"],
            "wer_raw": s["wer_raw"], "wer_norm": s["wer_norm"],
            "space_norm_wer": s["space_norm_wer"], "mer": s["mer"],
            "flags": s["flags"],
        })

    # ── error_analysis.json ──
    # Group per-sample records by language and delegate the detailed error
    # breakdown to compute_error_analysis (defined elsewhere in this file).
    samples_for_errors = defaultdict(list)
    for s in sample_analysis:
        samples_for_errors[s["language"]].append(s)

    error_analysis = compute_error_analysis(samples_for_errors)

    # Rank languages by normalized WER (ascending: lowest error first).
    sorted_langs = sorted(lang_keys, key=lambda k: metrics[k]["wer_norm"])
    error_analysis["__summary__"]["best_languages"] = sorted_langs[:3]
    error_analysis["__summary__"]["worst_languages"] = sorted_langs[-3:]

    # Average WER deltas (in percentage points) attributable to text
    # normalization and to number canonicalization, across languages.
    avg_raw_to_norm = np.mean([metrics[k]["normalization_delta"]["raw_to_norm"] for k in lang_keys])
    avg_norm_to_nc = np.mean([metrics[k]["normalization_delta"]["norm_to_numcanon"] for k in lang_keys])

    # Coarse triage: >10 pts recovered by normalization => formatting is the
    # bottleneck; else >5 pts recovered by number canonicalization => numeric
    # verbalization is; otherwise genuine recognition errors dominate.
    # NOTE(review): the first two branches each set only one impact key
    # (formatting_impact XOR numeric_verbalization_impact), so consumers
    # should treat the missing key as "not assessed" — confirm intentional.
    if avg_raw_to_norm > 10:
        error_analysis["__summary__"]["model_diagnosis"] = "formatting-limited"
        error_analysis["__summary__"]["formatting_impact"] = "high"
    elif avg_norm_to_nc > 5:
        error_analysis["__summary__"]["model_diagnosis"] = "numeric-limited"
        error_analysis["__summary__"]["numeric_verbalization_impact"] = "high"
    else:
        error_analysis["__summary__"]["model_diagnosis"] = "recognition-limited"
        error_analysis["__summary__"]["formatting_impact"] = "low" if avg_raw_to_norm < 5 else "moderate"
        error_analysis["__summary__"]["numeric_verbalization_impact"] = "low" if avg_norm_to_nc < 1 else "moderate"

    # ── Write output files ──
    # ensure_ascii=False keeps Indic-script text human-readable in the JSON.
    with open(output_dir / "metrics.json", "w", encoding="utf-8") as f:
        json.dump(metrics, f, indent=2, ensure_ascii=False)

    with open(output_dir / "sample_analysis.json", "w", encoding="utf-8") as f:
        json.dump(sample_analysis, f, indent=2, ensure_ascii=False)

    with open(output_dir / "error_analysis.json", "w", encoding="utf-8") as f:
        json.dump(error_analysis, f, indent=2, ensure_ascii=False)

    # Print a fixed-width per-language summary table, then overall
    # (sample-weighted) and macro-average (language-weighted) rows.
    ov = metrics["__overall__"]
    print(f"\n{'='*85}")
    print(f"  BENCHMARK RESULTS: {ckpt_name} ({MODEL_ID})")
    print(f"{'='*85}")
    print(f"{'Language':<14} {'N':>5} {'WER-raw':>8} {'WER-n':>8} {'WER-nc':>8} {'SNW':>8} {'MER':>8} {'CER-n':>8}")
    print(f"{'-'*14} {'-'*5} {'-'*8} {'-'*8} {'-'*8} {'-'*8} {'-'*8} {'-'*8}")

    # Dunder-prefixed keys (__overall__, __meta__, ...) are aggregates, not
    # languages — skip them when listing per-language rows.
    for lang in sorted(k for k in metrics if not k.startswith("__")):
        m = metrics[lang]
        print(f"{lang:<14} {m['n_samples']:>5} {m['wer_raw']:>8.2f} {m['wer_norm']:>8.2f} "
              f"{m['wer_numcanon']:>8.2f} {m['space_norm_wer']:>8.2f} {m['mer']:>8.2f} {m['cer_norm']:>8.2f}")

    print(f"{'-'*14} {'-'*5} {'-'*8} {'-'*8} {'-'*8} {'-'*8} {'-'*8} {'-'*8}")
    print(f"{'OVERALL':<14} {ov['n_samples']:>5} {ov['wer_raw']:>8.2f} {ov['wer_norm']:>8.2f} "
          f"{ov['wer_numcanon']:>8.2f} {ov['space_norm_wer']:>8.2f} {ov['mer']:>8.2f} {ov['cer_norm']:>8.2f}")
    ma = metrics["__macro_avg__"]
    print(f"{'MACRO-AVG':<14} {ma['n_languages']:>5} {ma['wer_raw']:>8.2f} {ma['wer_norm']:>8.2f} "
          f"{ma['wer_numcanon']:>8.2f} {ma['space_norm_wer']:>8.2f} {ma['mer']:>8.2f} {ma['cer_norm']:>8.2f}")
    print(f"{'='*85}")
    print(f"\nOutputs saved to: {output_dir}")
    print(f"  metrics.json, sample_analysis.json, error_analysis.json")


# Script entry point: run the full benchmark when executed directly
# (argument parsing and the benchmark loop live in main(), defined above).
if __name__ == "__main__":
    main()
