#!/usr/bin/env python3
"""
Generate BENCHMARK_SCHEMA v1 outputs for gemma3n-e2b checkpoints.

Reads predictions.json → produces metrics.json, sample_analysis.json, error_analysis.json
with all 6 metric tiers: wer_raw, wer_norm, wer_numcanon, space_norm_wer, mer, cer_norm.

Usage:
    python generate_schema_v1.py --src benchmark_results_gemma3n/ckpt-20000 --ckpt ckpt-20000
    python generate_schema_v1.py --all   # process all gemma3n checkpoints
"""

import argparse
import json
import os
import re
import unicodedata
from collections import Counter, defaultdict
from datetime import datetime, timezone
from pathlib import Path

import jiwer
import numpy as np

NORMALIZATION_VERSION = "v1"  # recorded in __meta__; bump if any tier's behavior changes

# ── Normalization Pipeline ──────────────────────────────────────────────

# Invisible formatting characters stripped in Tier 1: zero-width chars and
# directional marks (U+200B–U+200F), line/paragraph separators and narrow
# spaces (U+2028–U+202F), BOM (U+FEFF), soft hyphen (U+00AD).
# NOTE: despite the name, the U+2028–U+202F range includes some
# non-zero-width separator characters.
ZERO_WIDTH = re.compile(r"[\u200b-\u200f\u2028-\u202f\ufeff\u00ad]")
# Punctuation removed in Tier 1: all ASCII punctuation, Devanagari danda and
# double danda (U+0964/U+0965), en/em/horizontal-bar dashes, curly quotes,
# ellipsis and guillemets. U+0022/U+0027 duplicate the ASCII quote escapes.
PUNCT_COMMON = re.compile(
    r"[!\"#$%&'()*+,\-./:;<=>?@\[\\\]^_`{|}~"
    r"\u0964\u0965"
    r"\u2013\u2014\u2015"
    r"\u2018\u2019\u201c\u201d"
    r"\u2026\u00ab\u00bb\u0022\u0027"
    r"]"
)
# Digit-group separator (comma or whitespace) between two digits,
# e.g. "1,000" / "1 000" → "1000" in Tier 2.
DIGIT_GROUP_SEP = re.compile(r"(?<=\d)[,\s](?=\d)")


def norm_raw(text: str) -> str:
    """Tier 0 normalization: canonical Unicode composition (NFC) plus
    removal of leading/trailing whitespace. Nothing else is altered."""
    composed = unicodedata.normalize("NFC", text)
    return composed.strip()


def norm_standard(text: str) -> str:
    """Tier 1 normalization.

    Applies, in order: NFKC fold, invisible-character strip, whitespace
    collapse, unification of typographic punctuation with ASCII
    equivalents, punctuation removal, a second whitespace collapse, and
    lowercasing.
    """
    out = unicodedata.normalize("NFKC", text)
    out = ZERO_WIDTH.sub("", out)
    out = re.sub(r"\s+", " ", out).strip()
    # Map typographic characters onto their ASCII counterparts before the
    # punctuation class removes them (kept for parity with v1 behavior).
    for fancy, plain in (
        ("\u2018", "'"), ("\u2019", "'"),
        ("\u201c", '"'), ("\u201d", '"'),
        ("\u2013", "-"), ("\u2014", "-"),
        ("\u0965", "\u0964"),
    ):
        out = out.replace(fancy, plain)
    out = PUNCT_COMMON.sub("", out)
    out = re.sub(r"\s+", " ", out).strip()
    return out.lower()


def norm_numcanon(text: str) -> str:
    """Tier 2 normalization: Tier 1 plus removal of digit-group
    separators, so "1,000" and "1000" compare equal."""
    return DIGIT_GROUP_SEP.sub("", norm_standard(text))


def norm_spaceless(text: str) -> str:
    """Space-free variant of Tier 1 normalization, used for MER: every
    space is deleted so text is scored as one character stream."""
    return "".join(norm_standard(text).split(" "))


# ── space_norm_wer (per-sample) ─────────────────────────────────────────

def compute_space_norm_wer_sample(ref_norm: str, hyp_norm: str) -> tuple:
    """Space-insensitive WER contribution for one sample.

    Both strings are compared with all spaces removed, so pure word
    segmentation differences ("ab c" vs "a bc") cost nothing. A
    character-level Levenshtein alignment is computed on the
    space-stripped strings, and every character edit is charged to the
    reference *word* owning the affected character; the error count is
    the number of distinct reference words touched.

    Returns (error_words, total_words); total_words is the reference
    word count, and (0, 0) is returned for an empty reference.
    """
    ref_words = ref_norm.split()
    total_words = len(ref_words)
    if total_words == 0:
        return (0, 0)

    ref_ns = ref_norm.replace(" ", "")
    hyp_ns = hyp_norm.replace(" ", "")

    # Fast path: identical up to spacing → zero errors.
    if ref_ns == hyp_ns:
        return (0, total_words)

    # Build char_to_word mapping
    # char_to_word[k] = index of the reference word that character k of
    # the space-stripped reference belongs to.
    char_to_word = []
    for wi, w in enumerate(ref_words):
        for _ in w:
            char_to_word.append(wi)

    # Levenshtein DP
    # dp[i][j] = edit distance between ref_ns[:i] and hyp_ns[:j].
    n, m = len(ref_ns), len(hyp_ns)
    dp = [[0] * (m + 1) for _ in range(n + 1)]
    for i in range(n + 1):
        dp[i][0] = i
    for j in range(m + 1):
        dp[0][j] = j
    for i in range(1, n + 1):
        for j in range(1, m + 1):
            if ref_ns[i - 1] == hyp_ns[j - 1]:
                dp[i][j] = dp[i - 1][j - 1]
            else:
                dp[i][j] = 1 + min(dp[i - 1][j], dp[i][j - 1], dp[i - 1][j - 1])

    # Backtrack to find touched words
    # Branch order is deliberate: match first, then substitution,
    # deletion, insertion — this recovers one valid optimal alignment
    # path (ties are broken by this order, not explored exhaustively).
    touched = set()
    i, j = n, m
    while i > 0 or j > 0:
        if i > 0 and j > 0 and ref_ns[i - 1] == hyp_ns[j - 1]:
            i -= 1
            j -= 1
        elif i > 0 and j > 0 and dp[i][j] == dp[i - 1][j - 1] + 1:
            # substitution
            touched.add(char_to_word[i - 1])
            i -= 1
            j -= 1
        elif i > 0 and dp[i][j] == dp[i - 1][j] + 1:
            # deletion
            touched.add(char_to_word[i - 1])
            i -= 1
        elif j > 0 and dp[i][j] == dp[i][j - 1] + 1:
            # insertion — attribute to nearest ref word
            if i > 0:
                touched.add(char_to_word[i - 1])
            elif char_to_word:
                touched.add(char_to_word[0])
            j -= 1
        else:
            # Defensive: unreachable for a correctly filled DP table.
            break

    return (len(touched), total_words)


# ── Per-sample WER/MER ──────────────────────────────────────────────────

def safe_wer(ref: str, hyp: str) -> float:
    """Word error rate in percent, guarded against edge cases.

    An empty reference scores 0.0 (if the hypothesis is also empty) or
    100.0; any jiwer failure is mapped to 100.0 instead of raising.
    """
    if not ref.strip():
        # Empty reference: perfect only when the hypothesis is empty too.
        return 100.0 if hyp.strip() else 0.0
    try:
        score = jiwer.wer(ref, hyp)
    except Exception:
        return 100.0
    return round(score * 100, 2)


def safe_cer(ref: str, hyp: str) -> float:
    """Character error rate in percent, guarded against edge cases.

    Mirrors safe_wer: an empty reference scores 0.0/100.0 depending on
    the hypothesis, and jiwer failures map to 100.0 instead of raising.
    """
    if not ref.strip():
        return 100.0 if hyp.strip() else 0.0
    try:
        score = jiwer.cer(ref, hyp)
    except Exception:
        return 100.0
    return round(score * 100, 2)


# ── Flag detection ──────────────────────────────────────────────────────

def detect_flags(r_raw, h_raw, r_n, h_n, r_nc, h_nc, r_ns, h_ns, wer_norm_val) -> list:
    """Classify one sample by comparing ref/hyp across normalization tiers.

    Takes the raw (Tier 0), standard (Tier 1), numeric-canonical (Tier 2)
    and space-stripped forms of reference and hypothesis, plus the
    sample's normalized WER, and returns the list of matching flag names
    in a fixed order.
    """
    checks = (
        ("exact_match", r_raw == h_raw),
        ("exact_match_norm", r_n == h_n),
        ("empty_hypothesis", not h_raw.strip()),
        # Raw forms differ but Tier-1 forms agree → only punctuation/case/space.
        ("punctuation_only_diff", r_raw != h_raw and r_n == h_n),
        # Tier-1 forms differ but digit-canonical forms agree → number formatting.
        ("numeric_mismatch", r_n != h_n and r_nc == h_nc),
        ("high_wer", wer_norm_val > 80),
        # Tier-1 forms differ but space-stripped forms agree → segmentation only.
        ("spacing_error", r_n != h_n and r_ns == h_ns),
    )
    return [name for name, hit in checks if hit]


# ── Error analysis ──────────────────────────────────────────────────────

def compute_error_ops(refs, hyps):
    """Aggregate word-level edit operations over paired ref/hyp lists.

    Uses jiwer's alignment to count substitutions (keyed by the
    (ref_word, hyp_word) pair), insertions and deletions (keyed by the
    word). Samples that jiwer fails to align are skipped silently.
    Returns (subs, ins, dels) as three Counters.
    """
    subs, ins, dels = Counter(), Counter(), Counter()
    for ref, hyp in zip(refs, hyps):
        rw = ref.split()
        hw = hyp.split()
        try:
            result = jiwer.process_words(ref, hyp)
            for chunk in result.alignments[0]:
                if chunk.type == "substitute":
                    ref_span = range(chunk.ref_start_idx, chunk.ref_end_idx)
                    hyp_span = range(chunk.hyp_start_idx, chunk.hyp_end_idx)
                    for ri, hi in zip(ref_span, hyp_span):
                        # Index guards: jiwer's tokenization may disagree
                        # with a plain split() in rare cases.
                        if ri < len(rw) and hi < len(hw):
                            subs[(rw[ri], hw[hi])] += 1
                elif chunk.type == "insert":
                    for hi in range(chunk.hyp_start_idx, chunk.hyp_end_idx):
                        if hi < len(hw):
                            ins[hw[hi]] += 1
                elif chunk.type == "delete":
                    for ri in range(chunk.ref_start_idx, chunk.ref_end_idx):
                        if ri < len(rw):
                            dels[rw[ri]] += 1
        except Exception:
            continue
    return subs, ins, dels


# ── Main conversion ─────────────────────────────────────────────────────

def process_checkpoint(src_dir: str, dst_dir: str, ckpt_name: str):
    """Convert one checkpoint's predictions into BENCHMARK_SCHEMA v1 outputs.

    Reads ``predictions.json`` (and the ``__meta__`` block of the legacy
    ``metrics.json``) from ``src_dir`` and writes three files into
    ``dst_dir``: ``metrics.json`` (per-language, overall and macro-average
    metrics for all 6 tiers), ``sample_analysis.json`` (per-sample
    normalized forms, metrics and flags) and ``error_analysis.json``
    (edit-operation counts, error buckets and an overall diagnosis).
    Returns the metrics dict.
    """
    os.makedirs(dst_dir, exist_ok=True)

    with open(os.path.join(src_dir, "predictions.json")) as f:
        predictions = json.load(f)
    print(f"  Loaded {len(predictions)} predictions")

    # Load old meta
    # Only the "__meta__" block of the legacy metrics.json is carried over.
    with open(os.path.join(src_dir, "metrics.json")) as f:
        old_meta = json.load(f).get("__meta__", {})

    # ── Per-sample processing ───────────────────────────────────────
    # One bucket per language: parallel ref/hyp lists at every
    # normalization tier, plus space_norm_wer error/total counts.
    by_lang = defaultdict(lambda: {
        "refs_raw": [], "hyps_raw": [],
        "refs_norm": [], "hyps_norm": [],
        "refs_nc": [], "hyps_nc": [],
        "refs_ns": [], "hyps_ns": [],
        "snwer_errors": [], "snwer_totals": [],
    })

    samples = []

    for p in predictions:
        # NOTE(review): assumes every prediction has "id", "language" and
        # "reference" keys; only "hypothesis" is treated as optional.
        lang = p["language"]
        ref, hyp = p["reference"], p.get("hypothesis", "")

        r_raw = norm_raw(ref)
        h_raw = norm_raw(hyp)
        r_n = norm_standard(ref)
        h_n = norm_standard(hyp)
        r_nc = norm_numcanon(ref)
        h_nc = norm_numcanon(hyp)
        r_ns = norm_spaceless(ref)
        h_ns = norm_spaceless(hyp)

        w_raw = safe_wer(r_raw, h_raw)
        w_norm = safe_wer(r_n, h_n)
        # MER = CER on space-stripped text (since it's a single "word")
        mer_val = safe_cer(r_ns, h_ns) if r_ns else (0.0 if not h_ns else 100.0)
        flags = detect_flags(r_raw, h_raw, r_n, h_n, r_nc, h_nc, r_ns, h_ns, w_norm)

        # space_norm_wer per sample
        sn_err, sn_tot = compute_space_norm_wer_sample(r_n, h_n)

        d = by_lang[lang]
        d["refs_raw"].append(r_raw); d["hyps_raw"].append(h_raw)
        d["refs_norm"].append(r_n); d["hyps_norm"].append(h_n)
        d["refs_nc"].append(r_nc); d["hyps_nc"].append(h_nc)
        d["refs_ns"].append(r_ns); d["hyps_ns"].append(h_ns)
        d["snwer_errors"].append(sn_err); d["snwer_totals"].append(sn_tot)

        samples.append({
            "id": p["id"],
            "language": lang,
            "reference": ref,
            "hypothesis": hyp,
            "ref_norm": r_n,
            "hyp_norm": h_n,
            "ref_numcanon": r_nc,
            "hyp_numcanon": h_nc,
            "ref_mer": r_ns,
            "hyp_mer": h_ns,
            "wer_raw": w_raw,
            "wer_norm": w_norm,
            "mer": mer_val,
            "flags": flags,
        })

    # ── Per-language metrics ────────────────────────────────────────
    # agg accumulates every language's lists so the "__overall__" block
    # can be computed corpus-wide (micro-averaged) afterwards.
    metrics = {}
    agg = {"raw": ([], []), "norm": ([], []), "nc": ([], []), "ns": ([], []),
           "snwer_e": [], "snwer_t": []}

    for lang in sorted(by_lang.keys()):
        d = by_lang[lang]
        n = len(d["refs_raw"])

        # NOTE(review): jiwer raises ValueError when a reference in the
        # list is empty; the per-sample path guards via safe_wer, but
        # these corpus-level calls assume every normalized reference is
        # non-empty — confirm the dataset guarantees this.
        w_raw = round(jiwer.wer(d["refs_raw"], d["hyps_raw"]) * 100, 2)
        w_norm = round(jiwer.wer(d["refs_norm"], d["hyps_norm"]) * 100, 2)
        w_nc = round(jiwer.wer(d["refs_nc"], d["hyps_nc"]) * 100, 2)
        c_norm = round(jiwer.cer(d["refs_norm"], d["hyps_norm"]) * 100, 2)

        # MER (corpus-level CER on space-stripped)
        all_ref_ns = "".join(d["refs_ns"])
        all_hyp_ns = "".join(d["hyps_ns"])
        mer_val = round(jiwer.cer(all_ref_ns, all_hyp_ns) * 100, 2) if all_ref_ns else 0.0

        # space_norm_wer (micro-average)
        sn_e = sum(d["snwer_errors"])
        sn_t = sum(d["snwer_totals"])
        snwer = round(sn_e / sn_t * 100, 2) if sn_t > 0 else 0.0

        empty_count = sum(1 for h in d["hyps_raw"] if not h.strip())

        metrics[lang] = {
            "n_samples": n,
            "wer_raw": w_raw,
            "wer_norm": w_norm,
            "wer_numcanon": w_nc,
            "space_norm_wer": snwer,
            "mer": mer_val,
            "cer_norm": c_norm,
            "empty_hypotheses": empty_count,
            # Deltas show how much each extra normalization step "forgives".
            "normalization_delta": {
                "raw_to_norm": round(w_raw - w_norm, 2),
                "norm_to_numcanon": round(w_norm - w_nc, 2),
                "norm_to_space_norm": round(w_norm - snwer, 2),
                "norm_to_mer": round(w_norm - mer_val, 2),
            },
        }

        agg["raw"][0].extend(d["refs_raw"]); agg["raw"][1].extend(d["hyps_raw"])
        agg["norm"][0].extend(d["refs_norm"]); agg["norm"][1].extend(d["hyps_norm"])
        agg["nc"][0].extend(d["refs_nc"]); agg["nc"][1].extend(d["hyps_nc"])
        agg["ns"][0].extend(d["refs_ns"]); agg["ns"][1].extend(d["hyps_ns"])
        agg["snwer_e"].extend(d["snwer_errors"]); agg["snwer_t"].extend(d["snwer_totals"])

    # ── Overall ─────────────────────────────────────────────────────
    # Micro-averaged corpus-wide metrics across every language.
    all_ns_ref = "".join(agg["ns"][0])
    all_ns_hyp = "".join(agg["ns"][1])
    sn_e_all = sum(agg["snwer_e"])
    sn_t_all = sum(agg["snwer_t"])

    metrics["__overall__"] = {
        "n_samples": len(agg["raw"][0]),
        "wer_raw": round(jiwer.wer(agg["raw"][0], agg["raw"][1]) * 100, 2),
        "wer_norm": round(jiwer.wer(agg["norm"][0], agg["norm"][1]) * 100, 2),
        "wer_numcanon": round(jiwer.wer(agg["nc"][0], agg["nc"][1]) * 100, 2),
        "space_norm_wer": round(sn_e_all / sn_t_all * 100, 2) if sn_t_all else 0.0,
        "mer": round(jiwer.cer(all_ns_ref, all_ns_hyp) * 100, 2) if all_ns_ref else 0.0,
        "cer_norm": round(jiwer.cer(agg["norm"][0], agg["norm"][1]) * 100, 2),
    }

    # ── Macro average ───────────────────────────────────────────────
    # Unweighted mean over languages (each language counts equally,
    # regardless of its sample count).
    lk = [k for k in metrics if not k.startswith("__")]
    metrics["__macro_avg__"] = {
        "n_languages": len(lk),
        "wer_raw": round(np.mean([metrics[k]["wer_raw"] for k in lk]), 2),
        "wer_norm": round(np.mean([metrics[k]["wer_norm"] for k in lk]), 2),
        "wer_numcanon": round(np.mean([metrics[k]["wer_numcanon"] for k in lk]), 2),
        "space_norm_wer": round(np.mean([metrics[k]["space_norm_wer"] for k in lk]), 2),
        "mer": round(np.mean([metrics[k]["mer"] for k in lk]), 2),
        "cer_norm": round(np.mean([metrics[k]["cer_norm"] for k in lk]), 2),
    }

    # ── Meta ────────────────────────────────────────────────────────
    # Carries forward fields from the legacy metrics.json where present;
    # everything else is filled with this run's constants.
    metrics["__meta__"] = {
        "checkpoint": old_meta.get("checkpoint", ""),
        "checkpoint_name": ckpt_name,
        "model_id": "gemma3n-e2b",
        "model_type": old_meta.get("model_type", "gemma3n-E2B-asr"),
        "dataset": old_meta.get("dataset", "BayAreaBoys/indic-asr-benchmark-6k"),
        "batch_size": old_meta.get("batch_size", 128),
        "inference_time_sec": old_meta.get("inference_time_sec", 0),
        "total_audio_sec": old_meta.get("total_audio_sec", 0),
        "rtf": old_meta.get("rtf", 0),
        "timestamp": datetime.now(timezone.utc).strftime("%Y-%m-%dT%H:%M:%SZ"),
        "gpu": "NVIDIA A100-SXM4-80GB",
        "framework": "transformers",
        "normalization_version": NORMALIZATION_VERSION,
        "jiwer_version": getattr(jiwer, "__version__", "unknown"),
    }

    # ── Write metrics.json ──────────────────────────────────────────
    with open(os.path.join(dst_dir, "metrics.json"), "w") as f:
        json.dump(metrics, f, indent=2, ensure_ascii=False)

    # ── Write sample_analysis.json ──────────────────────────────────
    with open(os.path.join(dst_dir, "sample_analysis.json"), "w") as f:
        json.dump(samples, f, indent=2, ensure_ascii=False)

    # ── Build error_analysis.json ───────────────────────────────────
    # Re-group normalized samples by language for edit-op aggregation.
    ea_by_lang = defaultdict(lambda: {"refs": [], "hyps": [], "ids": [], "wers": []})
    for s in samples:
        ea_by_lang[s["language"]]["refs"].append(s["ref_norm"])
        ea_by_lang[s["language"]]["hyps"].append(s["hyp_norm"])
        ea_by_lang[s["language"]]["ids"].append(s["id"])
        ea_by_lang[s["language"]]["wers"].append(s["wer_norm"])

    error_analysis = {}
    lang_avg_wers = {}

    for lang in sorted(ea_by_lang.keys()):
        d = ea_by_lang[lang]
        subs, ins, dels = compute_error_ops(d["refs"], d["hyps"])

        # Count error buckets from flags
        lang_samples = [s for s in samples if s["language"] == lang]
        num_mm = sum(1 for s in lang_samples if "numeric_mismatch" in s["flags"])
        punct_only = sum(1 for s in lang_samples if "punctuation_only_diff" in s["flags"])
        spacing_err = sum(1 for s in lang_samples if "spacing_error" in s["flags"])
        empty_hyp = sum(1 for s in lang_samples if "empty_hypothesis" in s["flags"])

        # worst = 3 highest-WER sample ids (highest first); best = 3 lowest.
        sorted_idx = np.argsort(d["wers"])
        worst = [d["ids"][i] for i in sorted_idx[-3:]][::-1]
        best = [d["ids"][i] for i in sorted_idx[:3]]
        num_samples = [s["id"] for s in lang_samples if "numeric_mismatch" in s["flags"]][:3]

        error_analysis[lang] = {
            "top_substitutions": [{"ref": r, "hyp": h, "count": c} for (r, h), c in subs.most_common(20)],
            "top_insertions": [{"word": w, "count": c} for w, c in ins.most_common(20)],
            "top_deletions": [{"word": w, "count": c} for w, c in dels.most_common(20)],
            "error_buckets": {
                "numeric_mismatch_count": num_mm,
                "punctuation_only_count": punct_only,
                "spacing_tokenization_count": spacing_err,
                # entity/script buckets are schema placeholders — no detector yet.
                "entity_mismatch_count": 0,
                "script_confusion_count": 0,
                "empty_hypothesis_count": empty_hyp,
            },
            "examples": {
                "worst_samples": worst,
                "best_samples": best,
                "numeric_mismatch_samples": num_samples,
            },
        }
        lang_avg_wers[lang] = np.mean(d["wers"])

    # Diagnosis thresholds: >10 pts recovered by normalization → formatting
    # is the bottleneck; >5 pts recovered by digit canonicalization →
    # numeric verbalization; otherwise genuine recognition errors dominate.
    sorted_langs = sorted(lang_avg_wers, key=lang_avg_wers.get)
    ov = metrics["__overall__"]
    raw_norm_d = ov["wer_raw"] - ov["wer_norm"]
    norm_nc_d = ov["wer_norm"] - ov["wer_numcanon"]

    if raw_norm_d > 10:
        diag, primary = "formatting-limited", "formatting"
    elif norm_nc_d > 5:
        diag, primary = "numeric-limited", "numeric_verbalization"
    else:
        diag, primary = "recognition-limited", "recognition"

    error_analysis["__summary__"] = {
        "model_diagnosis": diag,
        "primary_error_source": primary,
        "numeric_verbalization_impact": "high" if norm_nc_d > 5 else "moderate" if norm_nc_d > 2 else "low",
        "formatting_impact": "high" if raw_norm_d > 10 else "moderate" if raw_norm_d > 5 else "low",
        "worst_languages": sorted_langs[-3:][::-1],
        "best_languages": sorted_langs[:3],
    }

    with open(os.path.join(dst_dir, "error_analysis.json"), "w") as f:
        json.dump(error_analysis, f, indent=2, ensure_ascii=False)

    # ── Print summary ───────────────────────────────────────────────
    print(f"  Overall: wer_raw={ov['wer_raw']} wer_norm={ov['wer_norm']} wer_nc={ov['wer_numcanon']} "
          f"snwer={ov['space_norm_wer']} mer={ov['mer']} cer_norm={ov['cer_norm']}")
    print(f"  Diagnosis: {diag}")
    return metrics


def main():
    """CLI entry point.

    Two mutually exclusive modes:
      * ``--all``: process every checkpoint directory under the gemma3n
        results root that contains a predictions.json.
      * ``--src`` + ``--ckpt``: process a single checkpoint directory.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--src", help="Source predictions dir")
    parser.add_argument("--ckpt", help="Checkpoint name")
    parser.add_argument("--all", action="store_true", help="Process all gemma3n checkpoints")
    args = parser.parse_args()

    dst_base = "/home/ubuntu/training/benchmark_outputs/gemma3n-e2b"

    if args.all:
        src_base = "/home/ubuntu/training/benchmark_results_gemma3n"
        for ckpt_dir in sorted(Path(src_base).iterdir()):
            if ckpt_dir.is_dir() and (ckpt_dir / "predictions.json").exists():
                ckpt = ckpt_dir.name
                print(f"\n{'='*60}\n  {ckpt}\n{'='*60}")
                process_checkpoint(str(ckpt_dir), f"{dst_base}/{ckpt}", ckpt)
    else:
        # Fix: previously, omitting --src/--ckpt crashed deep inside
        # process_checkpoint with an opaque TypeError from
        # os.path.join(None, ...). Fail early with a clear usage message.
        if not args.src or not args.ckpt:
            parser.error("--src and --ckpt are required unless --all is given")
        print(f"\n{'='*60}\n  {args.ckpt}\n{'='*60}")
        process_checkpoint(args.src, f"{dst_base}/{args.ckpt}", args.ckpt)


# Standard script entry guard: run main() only when executed directly,
# not when this module is imported.
if __name__ == "__main__":
    main()
