"""
Compute space_norm_wer for all models.

space_norm_wer: Word error rate after forgiving whitespace boundaries by doing
character alignment on space-stripped text and then counting how many reference
words are touched by real content edits.

This is NOT MER (CER on joined text). This is a word-level metric.
"""
import json
import os
import re
import unicodedata
from collections import defaultdict
from pathlib import Path

# === Normalization (same v1 pipeline) ===

# Invisible / formatting code points: U+200B-U+200F (zero-width space and
# joiners, LRM/RLM), U+2028-U+202F (line/paragraph separators, bidi embedding
# controls, narrow no-break space), U+FEFF (BOM) and U+00AD (soft hyphen).
ZW_CHARS = re.compile(r'[\u200b-\u200f\u2028-\u202f\ufeff\u00ad]')
# The 32 ASCII punctuation characters (same set as string.punctuation).
PUNCT_COMMON = re.compile(r'[!"#$%&\'()*+,\-./:;<=>?@\[\\\]^_`{|}~]')
# Devanagari danda (U+0964), double danda (U+0965), abbreviation sign (U+0970)
# and high spacing dot (U+0971).
PUNCT_INDIC = re.compile(r'[\u0964\u0965\u0970\u0971]')
# Ranges inside the Unicode "General Punctuation" block (U+2010-U+2027,
# U+2030-U+205E) plus U+2E00-U+2E4F from "Supplemental Punctuation".
PUNCT_EXTENDED = re.compile(r'[\u2010-\u2027\u2030-\u205e\u2e00-\u2e4f]')
# Curly single/double quotation marks (U+2018-U+201F) and guillemets (U+00AB, U+00BB).
QUOTES = re.compile(r'[\u2018\u2019\u201a\u201b\u201c\u201d\u201e\u201f\u00ab\u00bb]')
# A parenthesised span "( ... )" containing no ")" — nested parentheses are
# therefore not matched as a single unit.
ANNOTATIONS = re.compile(r'\([^)]*\)')

def norm_standard(text: str) -> str:
    """Return *text* normalised for scoring.

    Steps, in order: NFKC normalisation, removal of invisible characters,
    removal of parenthesised annotations, whitespace collapsing, quote and
    dash folding, punctuation removal, lower-casing, and a final whitespace
    collapse so the result contains single spaces only.
    """
    cleaned = unicodedata.normalize('NFKC', text)
    cleaned = ANNOTATIONS.sub('', ZW_CHARS.sub('', cleaned))
    cleaned = QUOTES.sub("'", ' '.join(cleaned.split()))

    # Fold em/en/figure dashes to '-' before the ASCII punctuation pass.
    for dash in ('\u2014', '\u2013', '\u2012'):
        cleaned = cleaned.replace(dash, '-')

    for pattern in (PUNCT_COMMON, PUNCT_INDIC, PUNCT_EXTENDED):
        cleaned = pattern.sub('', cleaned)

    # Removing punctuation can leave doubled spaces; split/join collapses them
    # and also guarantees no leading or trailing whitespace.
    return ' '.join(cleaned.lower().split())


# === space_norm_wer core ===

def space_norm_wer_sample(ref_norm: str, hyp_norm: str) -> tuple[int, int]:
    """
    Compute space_norm_wer for a single sample.

    Both strings are stripped of whitespace and aligned character by character
    with a unit-cost Levenshtein alignment. Each substitution or deletion marks
    the reference word owning the affected character; each insertion marks the
    reference word owning the preceding reference character (or the first word
    when the insertion precedes every reference character).

    Args:
        ref_norm: Normalised reference text; words are whitespace-separated.
        hyp_norm: Normalised hypothesis text.

    Returns:
        (error_words, total_words): the number of reference words touched by
        at least one edit and the number of reference words. (0, 0) when the
        reference contains no words.
    """
    ref_words = ref_norm.split()
    if not ref_words:
        return (0, 0)

    total_words = len(ref_words)

    # Derive the space-free strings from split() so that *every* whitespace
    # character (tabs, newlines, ...) is dropped, not just ' '. This keeps
    # len(char_to_word) == len(ref_nospace); using replace(' ', '') would leave
    # other whitespace in ref_nospace and desynchronise the mapping.
    ref_nospace = ''.join(ref_words)
    hyp_nospace = ''.join(hyp_norm.split())

    # Fast path: identical after whitespace removal = pure spacing difference
    if ref_nospace == hyp_nospace:
        return (0, total_words)

    # char_to_word[i] is the index of the reference word owning ref_nospace[i]
    char_to_word = [
        word_idx for word_idx, word in enumerate(ref_words) for _ in word
    ]

    n = len(ref_nospace)
    m = len(hyp_nospace)

    # Levenshtein DP; the full (n + 1) x (m + 1) table is kept because the
    # backtrack below needs every cell.
    dp = [[0] * (m + 1) for _ in range(n + 1)]
    for i in range(n + 1):
        dp[i][0] = i
    for j in range(m + 1):
        dp[0][j] = j

    for i in range(1, n + 1):
        ref_char = ref_nospace[i - 1]
        row, prev_row = dp[i], dp[i - 1]
        for j in range(1, m + 1):
            if ref_char == hyp_nospace[j - 1]:
                row[j] = prev_row[j - 1]
            else:
                row[j] = 1 + min(
                    prev_row[j - 1],  # substitution
                    prev_row[j],      # deletion from ref
                    row[j - 1],       # insertion in hyp
                )

    # Backtrack to find which reference words are touched by edits.
    # Branch order fixes the tie-break: match, substitution, deletion, insertion.
    touched_words = set()
    i, j = n, m

    while i > 0 or j > 0:
        if i > 0 and j > 0 and ref_nospace[i - 1] == hyp_nospace[j - 1]:
            # Match — no edit
            i -= 1
            j -= 1
        elif i > 0 and j > 0 and dp[i][j] == dp[i - 1][j - 1] + 1:
            # Substitution
            touched_words.add(char_to_word[i - 1])
            i -= 1
            j -= 1
        elif i > 0 and dp[i][j] == dp[i - 1][j] + 1:
            # Deletion from reference
            touched_words.add(char_to_word[i - 1])
            i -= 1
        elif j > 0 and dp[i][j] == dp[i][j - 1] + 1:
            # Insertion in hypothesis — charge the word of the preceding
            # reference character, or the first word when there is none
            # (n >= 1 here because ref_words is non-empty).
            touched_words.add(char_to_word[max(i - 1, 0)])
            j -= 1
        else:
            # Fallback (shouldn't happen with correct DP); still make progress
            # so the loop terminates.
            if j > 0:
                j -= 1
            elif i > 0:
                i -= 1

    return (len(touched_words), total_words)


# === Process all models ===

# Languages whose per-language entry in metrics.json is updated (only when
# that entry already exists and is a dict).
LANGUAGES = ['assamese', 'bengali', 'english', 'gujarati', 'hindi', 'kannada',
             'malayalam', 'marathi', 'odia', 'punjabi', 'tamil', 'telugu']

# Root directory, walked as BASE/<model>/<checkpoint>/.
BASE = Path('/home/ubuntu/training/benchmark_outputs')

for model_dir in sorted(BASE.iterdir()):
    if not model_dir.is_dir():
        continue
    for ckpt_dir in sorted(model_dir.iterdir()):
        sa_path = ckpt_dir / 'sample_analysis.json'
        metrics_path = ckpt_dir / 'metrics.json'
        if not sa_path.exists() or not metrics_path.exists():
            continue

        # Explicit UTF-8: the JSON carries Indic text and must not depend on
        # the process locale. The context manager closes the handle promptly.
        with open(sa_path, encoding='utf-8') as fh:
            samples = json.load(fh)
        if not samples:
            # Nothing to score; leave metrics.json untouched instead of
            # failing on samples[0] below.
            continue

        # The hypothesis field name is decided from the first sample only.
        hyp_key = 'hypothesis' if 'hypothesis' in samples[0] else 'transcript'

        by_lang = defaultdict(lambda: {'err': 0, 'total': 0})
        overall_err = 0
        overall_total = 0

        for s in samples:
            lang = s['language']
            ref = s['reference']
            hyp = s.get(hyp_key, '') or ''

            ref_norm = norm_standard(ref)
            hyp_norm = norm_standard(hyp)

            # NOTE(review): samples whose reference or hypothesis normalises
            # to empty are excluded from both numerator and denominator, so an
            # empty hypothesis is not counted as an error — confirm this
            # matches how wer_norm is computed.
            if not ref_norm or not hyp_norm:
                continue

            err, tot = space_norm_wer_sample(ref_norm, hyp_norm)
            by_lang[lang]['err'] += err
            by_lang[lang]['total'] += tot
            overall_err += err
            overall_total += tot

        # Update metrics.json
        with open(metrics_path, encoding='utf-8') as fh:
            metrics = json.load(fh)

        for lang in LANGUAGES:
            if lang in metrics and isinstance(metrics[lang], dict):
                e = by_lang[lang]['err']
                t = by_lang[lang]['total']
                # A language with no scored words is reported as 0.
                metrics[lang]['space_norm_wer'] = round((e / t * 100) if t > 0 else 0, 2)

        # setdefault: a metrics file lacking the aggregate sections should not
        # abort the whole run with a KeyError.
        overall = metrics.setdefault('__overall__', {})
        overall['space_norm_wer'] = round((overall_err / overall_total * 100) if overall_total > 0 else 0, 2)

        # Macro avg: unweighted mean of the per-language values written above.
        lang_vals = [
            metrics[name]['space_norm_wer']
            for name in LANGUAGES
            if name in metrics
            and isinstance(metrics[name], dict)
            and 'space_norm_wer' in metrics[name]
        ]
        if lang_vals:
            metrics.setdefault('__macro_avg__', {})['space_norm_wer'] = round(sum(lang_vals) / len(lang_vals), 2)

        # Serialise before opening for write so a serialisation error cannot
        # leave a truncated metrics.json behind.
        payload = json.dumps(metrics, indent=2, ensure_ascii=False)
        with open(metrics_path, 'w', encoding='utf-8') as fh:
            fh.write(payload)

        o_snw = overall['space_norm_wer']
        o_wn = overall.get('wer_norm', 0)
        o_mer = overall.get('mer', 0)
        gap = o_wn - o_snw
        print(f'{model_dir.name}/{ckpt_dir.name:<15}  wer_norm={o_wn:>6.2f}  space_norm_wer={o_snw:>6.2f}  mer={o_mer:>6.2f}  wer-snw={gap:>+6.2f}')
