#!/usr/bin/env python3
"""
Convert gemma3n-e2b benchmark predictions to BENCHMARK_SCHEMA v1.

Reads predictions.json from benchmark_results_gemma3n/<ckpt>/
Produces metrics.json, sample_analysis.json, error_analysis.json
in benchmark_outputs/gemma3n-e2b/<ckpt>/

Normalization tiers:
  raw:      NFC, trim whitespace
  norm:     NFKC, strip ZWJ/ZWNJ, collapse whitespace, lang-aware punct removal, case fold (Latin only)
  numcanon: norm + digit-grouping removal, Arabic numeral unification
"""

import json
import os
import re
import sys
import unicodedata
from collections import Counter, defaultdict
from datetime import datetime, timezone
from pathlib import Path

import jiwer
import numpy as np

# ── Normalization Pipeline (versioned: v1) ──────────────────────────────

# Invisible formatting characters stripped at the `norm` tier:
#   U+200B-U+200F  zero-width space/NJ/J, LRM/RLM direction marks
#   U+2028-U+202F  line/para separators, bidi controls, narrow NBSP
#   U+FEFF         BOM / zero-width no-break space
#   U+00AD         soft hyphen
ZERO_WIDTH = re.compile(r"[\u200b-\u200f\u2028-\u202f\ufeff\u00ad]")
# One character class of punctuation removed at the `norm` tier. The class
# opens on the first line (the interior `\]` is escaped) and is closed by the
# final r"]". ASCII punctuation plus the Unicode variants below; note that
# \u0022/\u0027 duplicate the straight quotes already listed on the first
# line — redundant but harmless.
PUNCT_COMMON = re.compile(
    r"[!\"#$%&'()*+,\-./:;<=>?@\[\\\]^_`{|}~"
    r"\u0964\u0965"              # Devanagari danda / double danda
    r"\u2013\u2014\u2015"        # en-dash, em-dash, horizontal bar
    r"\u2018\u2019\u201c\u201d"  # curly quotes
    r"\u2026"                    # ellipsis
    r"\u00ab\u00bb"              # guillemets
    r"\u0022\u0027"              # straight quotes
    r"]"
)
# A single comma or whitespace char sandwiched between digits — the grouping
# separator in "25,000" or "1 000" — removed at the `numcanon` tier.
DIGIT_GROUP_SEP = re.compile(r"(?<=\d)[,\s](?=\d)")


def norm_raw(text: str) -> str:
    """Tier 0 normalization: compose to NFC and trim surrounding whitespace."""
    return unicodedata.normalize("NFC", text).strip()


def norm_norm(text: str) -> str:
    """Tier 1: NFKC + strip zero-width chars + remove punctuation + collapse
    whitespace + case fold.

    Note: the former pre-removal "standardize punctuation variants" pass
    (curly quotes -> straight, en/em-dash -> hyphen, double danda -> danda)
    was dead code — every one of those characters, and every replacement
    character, is already in PUNCT_COMMON's class and removed in the same
    step — so it has been dropped with no change in output. Lowercasing
    only affects cased (mostly Latin) scripts; Indic text is unchanged.
    """
    text = unicodedata.normalize("NFKC", text)
    text = ZERO_WIDTH.sub("", text)
    # Remove punctuation in one pass, then collapse any whitespace runs
    # (including those exposed by the removal) and trim the ends.
    text = PUNCT_COMMON.sub("", text)
    text = re.sub(r"\s+", " ", text).strip()
    return text.lower()


def norm_numcanon(text: str) -> str:
    """Tier 2: Tier-1 normalization plus digit-grouping removal.

    Examples: "25,000" -> "25000", "1 000" -> "1000". (Commas are already
    stripped by the Tier-1 punctuation pass, so in practice this tier adds
    only whitespace-grouped numbers.)
    """
    return DIGIT_GROUP_SEP.sub("", norm_norm(text))


# ── Per-sample WER ──────────────────────────────────────────────────────

def sample_wer(ref: str, hyp: str) -> float:
    """WER for one ref/hyp pair as a 0-100 percentage, rounded to 2 decimals.

    An empty/whitespace reference scores 0.0 when the hypothesis is also
    empty, otherwise 100.0. Any jiwer failure is treated as a total miss.
    """
    if not ref.strip():
        return 100.0 if hyp.strip() else 0.0
    try:
        return round(100 * jiwer.wer(ref, hyp), 2)
    except Exception:
        # Best-effort scoring: never let one bad sample abort the run.
        return 100.0


# ── Flag Detection ──────────────────────────────────────────────────────

def detect_flags(ref_raw, hyp_raw, ref_n, hyp_n, ref_nc, hyp_nc, *, wer_norm=None) -> list:
    """Classify one sample by comparing it across the three normalization tiers.

    Args:
        ref_raw, hyp_raw: tier-0 (raw) reference/hypothesis.
        ref_n, hyp_n:     tier-1 (norm) reference/hypothesis.
        ref_nc, hyp_nc:   tier-2 (numcanon) reference/hypothesis.
        wer_norm: optional precomputed norm-tier WER. When None (the
            default, preserving the old behavior) it is computed here; a
            caller that already has the value can pass it to avoid a
            duplicate jiwer computation.

    Returns:
        List of flag strings (possibly empty).
    """
    flags = []
    if ref_raw == hyp_raw:
        flags.append("exact_match")
    if ref_n == hyp_n:
        flags.append("exact_match_norm")
    if not hyp_raw.strip():
        flags.append("empty_hypothesis")
    # Raw differs but norm matches → only punctuation/formatting differs.
    if ref_raw != hyp_raw and ref_n == hyp_n:
        flags.append("punctuation_only_diff")
    # Norm differs but numcanon matches → only digit grouping differs.
    if ref_n != hyp_n and ref_nc == hyp_nc:
        flags.append("numeric_mismatch")
    if wer_norm is None:
        wer_norm = sample_wer(ref_n, hyp_n)
    if wer_norm > 80:
        flags.append("high_wer")
    return flags


# ── Error Analysis (alignment-based) ────────────────────────────────────

def compute_error_ops(refs: list, hyps: list):
    """Compute word-level substitutions, insertions, deletions using jiwer alignment.

    Args:
        refs: reference strings (one per sample).
        hyps: hypothesis strings, paired with ``refs`` by index.

    Returns:
        Tuple ``(subs, ins, dels)`` of Counters: ``subs`` keyed by
        ``(ref_word, hyp_word)`` pairs, ``ins``/``dels`` keyed by word.
    """
    subs = Counter()
    ins = Counter()
    dels = Counter()

    for ref, hyp in zip(refs, hyps):
        # Split on whitespace so jiwer's alignment chunk indices can be
        # mapped back onto these word lists.
        # NOTE(review): assumes jiwer's default transform tokenizes the same
        # way as str.split() — confirm if the jiwer version changes.
        ref_words = ref.split()
        hyp_words = hyp.split()
        try:
            out = jiwer.process_words(ref, hyp)
            # alignments[0] is the chunk list for this single ref/hyp pair.
            for chunk in out.alignments[0]:
                if chunk.type == "substitute":
                    # Substituted spans pair up position-by-position.
                    for ri, hi in zip(
                        range(chunk.ref_start_idx, chunk.ref_end_idx),
                        range(chunk.hyp_start_idx, chunk.hyp_end_idx),
                    ):
                        # Bounds checks guard against any drift between our
                        # split and jiwer's internal tokenization.
                        if ri < len(ref_words) and hi < len(hyp_words):
                            subs[(ref_words[ri], hyp_words[hi])] += 1
                elif chunk.type == "insert":
                    for hi in range(chunk.hyp_start_idx, chunk.hyp_end_idx):
                        if hi < len(hyp_words):
                            ins[hyp_words[hi]] += 1
                elif chunk.type == "delete":
                    for ri in range(chunk.ref_start_idx, chunk.ref_end_idx):
                        if ri < len(ref_words):
                            dels[ref_words[ri]] += 1
        except Exception:
            # Best-effort: skip samples jiwer cannot align (e.g. empty ref).
            continue

    return subs, ins, dels


def build_error_analysis(predictions: list) -> dict:
    """Build per-language error analysis.

    Returns a dict keyed by language code, plus a ``"__summary__"`` entry.
    The summary's diagnosis/impact fields are placeholders here; the caller
    (convert_checkpoint) overwrites them from the actual metric deltas.

    Single pass over ``predictions``: each sample is normalized once and
    everything — texts, per-sample WERs, and the punctuation-only /
    numeric-mismatch bucket counts — is accumulated per language. (The
    previous version rescanned and re-normalized all predictions once per
    language, which was O(languages x predictions).)
    """
    by_lang = defaultdict(lambda: {
        "refs": [], "hyps": [], "ids": [], "wers": [],
        "punct_only": 0, "numeric_mm": 0,
    })

    for p in predictions:
        d = by_lang[p["language"]]
        ref_raw = norm_raw(p["reference"])
        hyp_raw = norm_raw(p.get("hypothesis", ""))
        ref_n = norm_norm(p["reference"])
        hyp_n = norm_norm(p.get("hypothesis", ""))
        d["refs"].append(ref_n)
        d["hyps"].append(hyp_n)
        d["ids"].append(p["id"])
        d["wers"].append(sample_wer(ref_n, hyp_n))
        # Raw differs but norm matches → punctuation/formatting-only error.
        if ref_raw != hyp_raw and ref_n == hyp_n:
            d["punct_only"] += 1
        # Norm differs but numcanon matches → digit-grouping-only error.
        # The numcanon tier is only computed when the norm tier differs.
        if ref_n != hyp_n and norm_numcanon(p["reference"]) == norm_numcanon(p.get("hypothesis", "")):
            d["numeric_mm"] += 1

    analysis = {}
    all_lang_wers = {}

    for lang in sorted(by_lang.keys()):
        d = by_lang[lang]
        subs, ins, dels = compute_error_ops(d["refs"], d["hyps"])

        top_subs = [{"ref": r, "hyp": h, "count": c} for (r, h), c in subs.most_common(20)]
        top_ins = [{"word": w, "count": c} for w, c in ins.most_common(20)]
        top_dels = [{"word": w, "count": c} for w, c in dels.most_common(20)]

        empty_count = sum(1 for h in d["hyps"] if not h.strip())

        # Best/worst sample ids by per-sample norm-tier WER (worst first).
        sorted_idx = np.argsort(d["wers"])
        worst_ids = [d["ids"][i] for i in sorted_idx[-3:]][::-1]
        best_ids = [d["ids"][i] for i in sorted_idx[:3]]

        analysis[lang] = {
            "top_substitutions": top_subs,
            "top_insertions": top_ins,
            "top_deletions": top_dels,
            "error_buckets": {
                "numeric_mismatch_count": d["numeric_mm"],
                "punctuation_only_count": d["punct_only"],
                "spacing_tokenization_count": 0,
                "entity_mismatch_count": 0,
                "script_confusion_count": 0,
                "empty_hypothesis_count": empty_count,
            },
            "examples": {
                "worst_samples": worst_ids,
                "best_samples": best_ids,
            },
        }
        all_lang_wers[lang] = np.mean(d["wers"])

    # Rank languages by mean per-sample WER (lower is better).
    sorted_langs = sorted(all_lang_wers, key=all_lang_wers.get)
    best_langs = sorted_langs[:3]
    worst_langs = sorted_langs[-3:][::-1]

    # Placeholder diagnosis; convert_checkpoint() replaces these four fields
    # using the actual raw→norm / norm→numcanon WER deltas.
    analysis["__summary__"] = {
        "model_diagnosis": "recognition-limited",
        "primary_error_source": "recognition",
        "numeric_verbalization_impact": "low",
        "formatting_impact": "low",
        "worst_languages": worst_langs,
        "best_languages": best_langs,
    }

    return analysis


# ── Main Conversion ─────────────────────────────────────────────────────

def convert_checkpoint(src_dir: str, dst_dir: str, ckpt_name: str):
    """Convert one checkpoint's predictions to schema v1.

    Reads ``<src_dir>/predictions.json`` (a list of samples with "id",
    "language", "reference" and optionally "hypothesis") and
    ``<src_dir>/metrics.json`` (only for its "__meta__" block). Writes
    metrics.json, sample_analysis.json and error_analysis.json to
    ``<dst_dir>`` and returns the new metrics dict.
    """
    os.makedirs(dst_dir, exist_ok=True)

    # Load predictions
    pred_file = os.path.join(src_dir, "predictions.json")
    with open(pred_file) as f:
        predictions = json.load(f)
    print(f"  Loaded {len(predictions)} predictions from {pred_file}")

    # Load the old metrics file only to carry over run metadata
    # (checkpoint path, timing, RTF, ...) into the converted output.
    old_meta_file = os.path.join(src_dir, "metrics.json")
    with open(old_meta_file) as f:
        old_metrics = json.load(f)
    old_meta = old_metrics.get("__meta__", {})

    # ── Compute all normalized texts ────────────────────────────────
    # Per-language parallel lists at each normalization tier; index i in
    # every list refers to the same sample.
    by_lang = defaultdict(lambda: {
        "refs_raw": [], "hyps_raw": [],
        "refs_norm": [], "hyps_norm": [],
        "refs_nc": [], "hyps_nc": [],
    })

    sample_analysis = []

    for p in predictions:
        lang = p["language"]
        ref = p["reference"]
        hyp = p.get("hypothesis", "")  # a missing hypothesis counts as empty

        r_raw = norm_raw(ref)
        h_raw = norm_raw(hyp)
        r_norm = norm_norm(ref)
        h_norm = norm_norm(hyp)
        r_nc = norm_numcanon(ref)
        h_nc = norm_numcanon(hyp)

        by_lang[lang]["refs_raw"].append(r_raw)
        by_lang[lang]["hyps_raw"].append(h_raw)
        by_lang[lang]["refs_norm"].append(r_norm)
        by_lang[lang]["hyps_norm"].append(h_norm)
        by_lang[lang]["refs_nc"].append(r_nc)
        by_lang[lang]["hyps_nc"].append(h_nc)

        flags = detect_flags(r_raw, h_raw, r_norm, h_norm, r_nc, h_nc)
        w_raw = sample_wer(r_raw, h_raw)
        w_norm = sample_wer(r_norm, h_norm)

        # One schema-v1 record per sample, carrying every tier's text so
        # downstream tools never need to re-run the normalizers.
        sample_analysis.append({
            "id": p["id"],
            "language": lang,
            "reference": ref,
            "hypothesis": hyp,
            "ref_norm": r_norm,
            "hyp_norm": h_norm,
            "ref_numcanon": r_nc,
            "hyp_numcanon": h_nc,
            "wer_raw": w_raw,
            "wer_norm": w_norm,
            "flags": flags,
        })

    # ── Compute metrics per language ────────────────────────────────
    metrics = {}
    all_refs_raw, all_hyps_raw = [], []
    all_refs_norm, all_hyps_norm = [], []
    all_refs_nc, all_hyps_nc = [], []

    for lang in sorted(by_lang.keys()):
        d = by_lang[lang]
        n = len(d["refs_raw"])

        # Skip languages whose references are all empty after normalization;
        # jiwer cannot score an all-empty reference corpus.
        if not any(r.strip() for r in d["refs_norm"]):
            continue

        # Corpus-level (not averaged per-sample) WER/CER at each tier.
        w_raw = round(jiwer.wer(d["refs_raw"], d["hyps_raw"]) * 100, 2)
        w_norm = round(jiwer.wer(d["refs_norm"], d["hyps_norm"]) * 100, 2)
        w_nc = round(jiwer.wer(d["refs_nc"], d["hyps_nc"]) * 100, 2)
        c_norm = round(jiwer.cer(d["refs_norm"], d["hyps_norm"]) * 100, 2)

        empty_count = sum(1 for h in d["hyps_raw"] if not h.strip())

        metrics[lang] = {
            "n_samples": n,
            "wer_raw": w_raw,
            "wer_norm": w_norm,
            "wer_numcanon": w_nc,
            "cer_norm": c_norm,
            "empty_hypotheses": empty_count,
            # Positive delta = the normalization tier lowered WER, i.e. part
            # of the raw error was formatting (raw_to_norm) or digit
            # grouping (norm_to_numcanon).
            "normalization_delta": {
                "raw_to_norm": round(w_raw - w_norm, 2),
                "norm_to_numcanon": round(w_norm - w_nc, 2),
            },
        }

        # Pool texts for the micro-averaged overall metrics below.
        all_refs_raw.extend(d["refs_raw"])
        all_hyps_raw.extend(d["hyps_raw"])
        all_refs_norm.extend(d["refs_norm"])
        all_hyps_norm.extend(d["hyps_norm"])
        all_refs_nc.extend(d["refs_nc"])
        all_hyps_nc.extend(d["hyps_nc"])

    # ── Aggregates ──────────────────────────────────────────────────
    # Micro average: one corpus-level WER over all pooled samples.
    metrics["__overall__"] = {
        "n_samples": len(all_refs_raw),
        "wer_raw": round(jiwer.wer(all_refs_raw, all_hyps_raw) * 100, 2),
        "wer_norm": round(jiwer.wer(all_refs_norm, all_hyps_norm) * 100, 2),
        "wer_numcanon": round(jiwer.wer(all_refs_nc, all_hyps_nc) * 100, 2),
        "cer_norm": round(jiwer.cer(all_refs_norm, all_hyps_norm) * 100, 2),
    }

    # Macro average: unweighted mean of the per-language metrics.
    lang_keys = [k for k in metrics if not k.startswith("__")]
    metrics["__macro_avg__"] = {
        "n_languages": len(lang_keys),
        "wer_raw": round(np.mean([metrics[k]["wer_raw"] for k in lang_keys]), 2),
        "wer_norm": round(np.mean([metrics[k]["wer_norm"] for k in lang_keys]), 2),
        "wer_numcanon": round(np.mean([metrics[k]["wer_numcanon"] for k in lang_keys]), 2),
        "cer_norm": round(np.mean([metrics[k]["cer_norm"] for k in lang_keys]), 2),
    }

    # ── Meta ────────────────────────────────────────────────────────
    # Mostly carried over from the source run; timestamp is conversion time.
    # NOTE(review): "gpu" is hard-coded rather than read from old_meta —
    # confirm all source runs really used this hardware.
    metrics["__meta__"] = {
        "checkpoint": old_meta.get("checkpoint", ""),
        "checkpoint_name": ckpt_name,
        "model_id": "gemma3n-e2b",
        "model_type": old_meta.get("model_type", "gemma3n-E2B-asr"),
        "dataset": old_meta.get("dataset", "BayAreaBoys/indic-asr-benchmark-6k"),
        "batch_size": old_meta.get("batch_size", 128),
        "inference_time_sec": old_meta.get("inference_time_sec", 0),
        "total_audio_sec": old_meta.get("total_audio_sec", 0),
        "rtf": old_meta.get("rtf", 0),
        "timestamp": datetime.now(timezone.utc).strftime("%Y-%m-%dT%H:%M:%SZ"),
        "gpu": "NVIDIA A100-SXM4-80GB",
        "framework": "transformers",
        "normalization_version": "v1",
        "jiwer_version": getattr(jiwer, "__version__", "unknown"),
    }

    # ── Update error analysis summary with actual deltas ────────────
    ov = metrics["__overall__"]
    raw_norm_delta = ov["wer_raw"] - ov["wer_norm"]
    norm_nc_delta = ov["wer_norm"] - ov["wer_numcanon"]

    # ── Write metrics.json ──────────────────────────────────────────
    with open(os.path.join(dst_dir, "metrics.json"), "w") as f:
        json.dump(metrics, f, indent=2, ensure_ascii=False)
    print(f"  metrics.json: {len(lang_keys)} langs, overall wer_norm={ov['wer_norm']}%")

    # ── Write sample_analysis.json ──────────────────────────────────
    with open(os.path.join(dst_dir, "sample_analysis.json"), "w") as f:
        json.dump(sample_analysis, f, indent=2, ensure_ascii=False)
    print(f"  sample_analysis.json: {len(sample_analysis)} samples")

    # ── Build and write error_analysis.json ─────────────────────────
    error_analysis = build_error_analysis(predictions)

    # Replace build_error_analysis()'s placeholder summary fields with a
    # diagnosis derived from the measured deltas: a large raw→norm delta
    # means formatting dominates; a large norm→numcanon delta means number
    # verbalization dominates; otherwise errors are genuine recognition.
    fmt_impact = "high" if raw_norm_delta > 10 else "moderate" if raw_norm_delta > 5 else "low"
    num_impact = "high" if norm_nc_delta > 5 else "moderate" if norm_nc_delta > 2 else "low"

    if raw_norm_delta > 10:
        diagnosis = "formatting-limited"
        primary = "formatting"
    elif norm_nc_delta > 5:
        diagnosis = "numeric-limited"
        primary = "numeric_verbalization"
    else:
        diagnosis = "recognition-limited"
        primary = "recognition"

    error_analysis["__summary__"]["model_diagnosis"] = diagnosis
    error_analysis["__summary__"]["primary_error_source"] = primary
    error_analysis["__summary__"]["formatting_impact"] = fmt_impact
    error_analysis["__summary__"]["numeric_verbalization_impact"] = num_impact

    with open(os.path.join(dst_dir, "error_analysis.json"), "w") as f:
        json.dump(error_analysis, f, indent=2, ensure_ascii=False)
    print(f"  error_analysis.json: {len(lang_keys)} langs, diagnosis={diagnosis}")

    return metrics


def main():
    """Convert each checkpoint's predictions and print a summary table.

    Optional CLI overrides (fully backward compatible — with no arguments
    the original hard-coded paths and checkpoint list are used):
        argv[1]   source base directory
        argv[2]   destination base directory
        argv[3:]  checkpoint directory names
    """
    args = sys.argv[1:]
    src_base = args[0] if len(args) > 0 else "/home/ubuntu/training/benchmark_results_gemma3n"
    dst_base = args[1] if len(args) > 1 else "/home/ubuntu/training/benchmark_outputs/gemma3n-e2b"
    ckpts = args[2:] or ["ckpt-10000", "ckpt-20000"]

    for ckpt in ckpts:
        src = os.path.join(src_base, ckpt)
        dst = os.path.join(dst_base, ckpt)
        print(f"\n{'='*60}")
        print(f"Converting {ckpt}")
        print(f"{'='*60}")
        metrics = convert_checkpoint(src, dst, ckpt)

        # Per-language summary table mirroring the metrics.json fields.
        ov = metrics["__overall__"]
        ma = metrics["__macro_avg__"]
        print(f"\n  {'Language':<14} {'wer_raw':>8} {'wer_norm':>9} {'wer_nc':>8} {'cer_norm':>9} {'Δraw→norm':>10} {'Δnorm→nc':>10}")
        print(f"  {'-'*14} {'-'*8} {'-'*9} {'-'*8} {'-'*9} {'-'*10} {'-'*10}")
        for lang in sorted(k for k in metrics if not k.startswith("__")):
            m = metrics[lang]
            d = m["normalization_delta"]
            print(f"  {lang:<14} {m['wer_raw']:>8.2f} {m['wer_norm']:>9.2f} {m['wer_numcanon']:>8.2f} "
                  f"{m['cer_norm']:>9.2f} {d['raw_to_norm']:>+10.2f} {d['norm_to_numcanon']:>+10.2f}")
        print(f"  {'-'*14} {'-'*8} {'-'*9} {'-'*8} {'-'*9} {'-'*10} {'-'*10}")
        print(f"  {'OVERALL':<14} {ov['wer_raw']:>8.2f} {ov['wer_norm']:>9.2f} {ov['wer_numcanon']:>8.2f} {ov['cer_norm']:>9.2f}")
        print(f"  {'MACRO-AVG':<14} {ma['wer_raw']:>8.2f} {ma['wer_norm']:>9.2f} {ma['wer_numcanon']:>8.2f} {ma['cer_norm']:>9.2f}")


# Only run the conversion when executed as a script, so the normalizers
# can be imported elsewhere without side effects.
if __name__ == "__main__":
    main()
