"""
Compute the four normalization-tier metrics (wer_raw, wer_norm, wer_numcanon, mer) plus cer_norm for API baseline ASR models.
Normalization version: v1
"""
import json
import re
import unicodedata
import sys
from collections import defaultdict
from pathlib import Path
from datetime import datetime, timezone

from jiwer import wer as compute_wer, cer as compute_cer

# ── Normalization pipeline ──

# Invisible format characters deleted outright: soft hyphen (U+00AD),
# zero-width space/non-joiner/joiner and LRM/RLM (U+200B-U+200F), bidi
# embedding/override controls (U+202A-U+202E), word joiner and invisible math
# operators (U+2060-U+2064), bidi isolates (U+2066-U+2069) and the BOM
# (U+FEFF). Line/paragraph separators (U+2028/U+2029) are deliberately NOT
# listed: they are whitespace, so deleting them would glue the neighbouring
# words together, whereas the whitespace-collapsing step turns them into spaces.
ZW_CHARS = re.compile(r'[\u00ad\u200b-\u200f\u202a-\u202e\u2060-\u2064\u2066-\u2069\ufeff]')

# Punctuation sets
# ASCII punctuation (the characters of string.punctuation).
PUNCT_COMMON = re.compile(r'[!"#$%&\'()*+,\-./:;<=>?@\[\\\]^_`{|}~]')
# Indic punctuation: danda, double danda, Devanagari abbreviation sign and
# Devanagari high spacing dot. Dandas are removed for normalization rather
# than kept as sentence separators.
PUNCT_INDIC = re.compile(r'[\u0964\u0965\u0970\u0971]')
# General Punctuation (U+2010-U+2027, U+2030-U+205E) and Supplemental
# Punctuation (U+2E00-U+2E4F): dashes, curly quotes, bullets, ellipsis, etc.
PUNCT_EXTENDED = re.compile(r'[\u2010-\u2027\u2030-\u205e\u2e00-\u2e4f]')
# Curly/low-9 quotes and guillemets, mapped to an ASCII apostrophe before
# punctuation stripping (guillemets fall outside every PUNCT_* class above).
QUOTES = re.compile(r'[\u2018\u2019\u201a\u201b\u201c\u201d\u201e\u201f\u00ab\u00bb]')

# Parenthetical annotations like "(coughing)", "(two seconds pause)"; the
# match stops at the first ')' so nested parentheses are not fully removed.
ANNOTATIONS = re.compile(r'\([^)]*\)')

def norm_raw(text: str) -> str:
    """Return *text* in Unicode NFC form with surrounding whitespace removed.

    This is the lightest tier (used for ``wer_raw``): punctuation, case and
    digits are left untouched.
    """
    return unicodedata.normalize('NFC', text).strip()

def norm_standard(text: str) -> str:
    """Apply the standard normalization used for ``wer_norm`` and ``cer_norm``.

    Pipeline, in order: NFKC -> strip ZW_CHARS -> drop parenthetical
    annotations -> collapse whitespace -> map QUOTES to ``'`` and
    figure/en/em dashes to ``-`` -> delete ASCII, Indic and extended
    punctuation -> lowercase -> collapse whitespace again.

    Punctuation is deleted rather than replaced by a space, so ``well-known``
    becomes ``wellknown``.
    """
    cleaned = unicodedata.normalize('NFKC', text)
    cleaned = ZW_CHARS.sub('', cleaned)
    cleaned = ANNOTATIONS.sub('', cleaned)
    cleaned = ' '.join(cleaned.split())

    cleaned = QUOTES.sub("'", cleaned)
    for dash in ('\u2014', '\u2013', '\u2012'):
        cleaned = cleaned.replace(dash, '-')

    for punct_pattern in (PUNCT_COMMON, PUNCT_INDIC, PUNCT_EXTENDED):
        cleaned = punct_pattern.sub('', cleaned)

    # str.lower only affects cased scripts (e.g. Latin); Indic scripts are uncased.
    return ' '.join(cleaned.lower().split())

def norm_mer(text: str) -> str:
    """Return the ``norm_standard`` form of *text* with every space removed (MER tier)."""
    standard = norm_standard(text)
    return ''.join(standard.split(' '))

# Code point of digit zero for each Indic script whose digits are mapped to
# ASCII: Devanagari, Bengali, Gurmukhi, Gujarati, Oriya, Tamil, Telugu,
# Kannada, Malayalam. Each script's digits 0-9 are contiguous from its zero.
INDIC_DIGIT_ZEROS = (0x0966, 0x09E6, 0x0A66, 0x0AE6, 0x0B66,
                     0x0BE6, 0x0C66, 0x0CE6, 0x0D66)
# str.translate table: native digit code point -> ASCII digit string.
INDIC_DIGITS_TO_ASCII = {
    zero + digit: str(digit) for zero in INDIC_DIGIT_ZEROS for digit in range(10)
}


def norm_numcanon(text: str) -> str:
    """Apply ``norm_standard`` and canonicalize numbers (``wer_numcanon`` tier).

    Commas between digits are removed (``25,000`` -> ``25000``) and native
    digits of the scripts listed in ``INDIC_DIGIT_ZEROS`` are rewritten as
    ASCII digits, so ``२५`` and ``25`` compare equal.
    """
    text = norm_standard(text)
    # Defensive: norm_standard already deletes ASCII commas via PUNCT_COMMON,
    # so this only matters if the punctuation rules change.
    text = re.sub(r'(\d),(\d)', r'\1\2', text)
    # One translate pass instead of 90 chained str.replace calls.
    return text.translate(INDIC_DIGITS_TO_ASCII)

# ── Metric computation ──

# Stand-in for an empty string before scoring: jiwer rejects empty references
# (and some versions reject empty hypotheses). It is a single character, so it
# is scored as exactly one word for WER *and* one character for CER/MER. A
# longer token such as '<empty>' would count as seven characters under CER/MER
# and inflate the error for references shorter than that.
EMPTY_PLACEHOLDER = '\u2205'

# Metric keys, in the order they are reported.
METRIC_KEYS = ('wer_raw', 'wer_norm', 'wer_numcanon', 'mer', 'cer_norm')


def _new_pool() -> dict[str, tuple[list[str], list[str]]]:
    """Return fresh parallel (references, hypotheses) lists for every normalization tier."""
    return {tier: ([], []) for tier in ('raw', 'norm', 'numcanon', 'mer')}


def _or_placeholder(text: str) -> str:
    """Return *text*, or EMPTY_PLACEHOLDER when it is empty or whitespace-only."""
    return text if text.strip() else EMPTY_PLACEHOLDER


def _corpus_rates(pool: dict[str, tuple[list[str], list[str]]]) -> dict[str, float]:
    """Return unrounded corpus-level error rates (percent) keyed like METRIC_KEYS."""
    return {
        'wer_raw': compute_wer(*pool['raw']) * 100,
        'wer_norm': compute_wer(*pool['norm']) * 100,
        'wer_numcanon': compute_wer(*pool['numcanon']) * 100,
        # MER: character error rate on space-stripped text, so differences
        # only in word spacing/segmentation are not counted as errors.
        'mer': compute_cer(*pool['mer']) * 100,
        'cer_norm': compute_cer(*pool['norm']) * 100,
    }


def compute_metrics_for_samples(samples: list[dict]) -> dict:
    """Compute every metric tier per language, plus micro and macro averages.

    Args:
        samples: Records with ``language`` and ``reference`` keys and an
            optional ``transcript`` (missing or None is treated as empty).

    Returns:
        ``{lang: {...}, '__overall__': {...}, '__macro_avg__': {...}}``.
        Each language entry holds ``n_samples``, the five metrics in
        METRIC_KEYS (percent, 2 dp), ``empty_hypotheses`` and
        ``normalization_delta``. ``__overall__`` is the corpus-level (micro)
        score over all scored samples. ``__macro_avg__`` is the unweighted
        mean of the rounded per-language values. Samples whose reference is
        blank are not scored. A language with no scorable sample is left out
        of the result, and a warning is printed.

    Raises:
        ValueError: if no sample at all has a non-blank reference.
    """
    by_lang: dict[str, list[dict]] = defaultdict(list)
    for s in samples:
        by_lang[s['language']].append(s)

    result: dict[str, dict] = {}
    overall = _new_pool()

    for lang in sorted(by_lang):
        lang_samples = by_lang[lang]
        pool = _new_pool()
        empty_count = 0

        for s in lang_samples:
            ref = s['reference']
            hyp = s.get('transcript', '') or ''
            if not hyp.strip():
                empty_count += 1

            rr = norm_raw(ref)
            if not rr.strip():
                continue  # blank reference: nothing to score against

            tier_pairs = (
                ('raw', rr, norm_raw(hyp)),
                ('norm', norm_standard(ref), norm_standard(hyp)),
                ('numcanon', norm_numcanon(ref), norm_numcanon(hyp)),
                ('mer', norm_mer(ref), norm_mer(hyp)),
            )
            for tier, ref_text, hyp_text in tier_pairs:
                refs, hyps = pool[tier]
                refs.append(_or_placeholder(ref_text))
                hyps.append(_or_placeholder(hyp_text))

        if not pool['raw'][0]:
            # jiwer cannot score an empty corpus; report instead of crashing.
            print(f'  WARNING: no non-empty references for language {lang!r}; skipped')
            continue

        rates = _corpus_rates(pool)
        result[lang] = {
            'n_samples': len(lang_samples),
            **{key: round(rates[key], 2) for key in METRIC_KEYS},
            'empty_hypotheses': empty_count,
            # Deltas use the unrounded rates so rounding does not compound.
            'normalization_delta': {
                'raw_to_norm': round(rates['wer_raw'] - rates['wer_norm'], 2),
                'norm_to_numcanon': round(rates['wer_norm'] - rates['wer_numcanon'], 2),
                'norm_to_mer': round(rates['wer_norm'] - rates['mer'], 2),
            },
        }
        for tier, (refs, hyps) in pool.items():
            overall[tier][0].extend(refs)
            overall[tier][1].extend(hyps)

    if not result:
        raise ValueError('compute_metrics_for_samples: no sample has a non-empty reference')

    lang_keys = list(result)

    # Micro average: one corpus-level score over every scored sample.
    overall_rates = _corpus_rates(overall)
    result['__overall__'] = {
        'n_samples': len(samples),
        **{key: round(overall_rates[key], 2) for key in METRIC_KEYS},
    }

    # Macro average: unweighted mean of the (rounded) per-language scores.
    n_langs = len(lang_keys)
    result['__macro_avg__'] = {
        'n_languages': n_langs,
        **{
            key: round(sum(result[code][key] for code in lang_keys) / n_langs, 2)
            for key in METRIC_KEYS
        },
    }

    return result


def build_sample_analysis(samples: list[dict]) -> list[dict]:
    """Return one ``sample_analysis.json`` record per sample, in input order.

    Each record carries the raw reference/hypothesis, their standard, numcanon
    and MER normalizations, ``detected_lang`` (``''`` if absent) and a list of
    diagnostic flags, in this order:

    * ``exact_match``: the raw strings are identical.
    * ``exact_match_norm``: the ``norm_standard`` forms are identical.
    * ``punctuation_only_diff``: the raw strings differ but the normalized
      forms match. The difference may also be case, whitespace, quotes or
      annotations.
    * ``empty_hypothesis``: the hypothesis is missing or whitespace-only.
    * ``numeric_mismatch``: either raw string contains a digit and the
      normalized forms differ. The differing token need not be numeric.
    """
    def analyse(sample: dict) -> dict:
        reference = sample['reference']
        hypothesis = sample.get('transcript', '') or ''
        ref_norm, hyp_norm = norm_standard(reference), norm_standard(hypothesis)

        raw_equal = reference == hypothesis
        norm_equal = ref_norm == hyp_norm
        has_digit = bool(re.search(r'\d', reference) or re.search(r'\d', hypothesis))

        checks = (
            ('exact_match', raw_equal),
            ('exact_match_norm', norm_equal),
            ('punctuation_only_diff', not raw_equal and norm_equal),
            ('empty_hypothesis', not hypothesis.strip()),
            ('numeric_mismatch', has_digit and not norm_equal),
        )

        return {
            'id': sample['id'],
            'language': sample['language'],
            'reference': reference,
            'hypothesis': hypothesis,
            'ref_norm': ref_norm,
            'hyp_norm': hyp_norm,
            'ref_numcanon': norm_numcanon(reference),
            'hyp_numcanon': norm_numcanon(hypothesis),
            'ref_mer': norm_mer(reference),
            'hyp_mer': norm_mer(hypothesis),
            'detected_language': sample.get('detected_lang', ''),
            'flags': [name for name, hit in checks if hit],
        }

    return [analyse(sample) for sample in samples]


# ── Model registry ──

# Result file name -> model metadata recorded in metrics.json. The two Gemini 3
# Flash runs share one model_type but are reported under distinct model_ids.
MODELS = {
    filename: {'model_id': model_id, 'display_name': display_name, 'model_type': model_type}
    for filename, model_id, display_name, model_type in (
        ('elevenlabs_full_6k.jsonl', 'elevenlabs-scribe-v2',
         'ElevenLabs Scribe v2', 'elevenlabs-scribe-v2'),
        ('elevenlabs_v1_full_6k.jsonl', 'elevenlabs-scribe-v1',
         'ElevenLabs Scribe v1', 'elevenlabs-scribe-v1'),
        ('gemini_3_flash_strict_full_6k.jsonl', 'gemini-3-flash-strict',
         'Gemini 3 Flash (Strict)', 'gemini-3-flash-preview'),
        ('gemini_gemini_3_flash_preview_full_6k.jsonl', 'gemini-3-flash-preview',
         'Gemini 3 Flash (Preview)', 'gemini-3-flash-preview'),
        ('gemini_flash_v2_full_6k.jsonl', 'gemini-2.5-flash',
         'Gemini 2.5 Flash', 'gemini-2.5-flash'),
        ('gemini_gemini_3_1_pro_preview_full_6k.jsonl', 'gemini-3.1-pro',
         'Gemini 3.1 Pro', 'gemini-3.1-pro-preview'),
        ('sarvam_full_6k.jsonl', 'sarvam-saaras-v3',
         'Sarvam Saaras v3', 'saaras:v3'),
    )
}

# Root directory for per-model report output.
OUTPUT_BASE = Path('/home/ubuntu/training/benchmark_outputs')

def process_model(filename: str, info: dict):
    """Score one model's JSONL results and write its metrics and per-sample analysis.

    Args:
        filename: Path to a JSONL file, one sample record per line.
        info: Registry entry with ``model_id``, ``display_name`` and
            ``model_type`` keys.

    Returns:
        The dict from ``compute_metrics_for_samples`` with an added
        ``__meta__`` entry.

    Writes ``metrics.json`` and ``sample_analysis.json`` into
    ``OUTPUT_BASE / model_id / 'baseline'`` (created if missing) and prints a
    short progress summary.
    """
    model_id = info['model_id']
    print(f'\n{"="*60}')
    print(f'Processing: {info["display_name"]} ({filename})')
    print(f'{"="*60}')

    # Transcripts contain Indic text, so read as UTF-8 explicitly instead of
    # the locale default. Blank lines are skipped because json.loads('') raises.
    with open(filename, encoding='utf-8') as f:
        samples = [json.loads(line) for line in f if line.strip()]

    print(f'  Loaded {len(samples)} samples')

    # Compute metrics
    metrics = compute_metrics_for_samples(samples)

    # Run metadata. Inference time and RTF are not measured here and are
    # recorded as 0. A missing or null 'duration' counts as 0 seconds.
    total_audio = sum(s.get('duration') or 0 for s in samples)
    metrics['__meta__'] = {
        'checkpoint': f'api/{model_id}',
        'checkpoint_name': 'baseline',
        'model_id': model_id,
        'model_type': info['model_type'],
        'dataset': 'BayAreaBoys/indic-asr-benchmark-6k',
        'batch_size': 1,
        'inference_time_sec': 0,
        'total_audio_sec': round(total_audio, 2),
        'rtf': 0,
        'timestamp': datetime.now(timezone.utc).strftime('%Y-%m-%dT%H:%M:%SZ'),
        'normalization_version': 'v1',
        'framework': 'api',
    }

    # Print summary
    o = metrics['__overall__']
    print(f'  Overall: wer_raw={o["wer_raw"]:.2f}  wer_norm={o["wer_norm"]:.2f}  mer={o["mer"]:.2f}  cer_norm={o["cer_norm"]:.2f}')

    # Build sample analysis
    sample_analysis = build_sample_analysis(samples)

    # Write outputs
    out_dir = OUTPUT_BASE / model_id / 'baseline'
    out_dir.mkdir(parents=True, exist_ok=True)

    # ensure_ascii=False writes Indic text verbatim, so the files must be
    # opened as UTF-8 or the write can fail under a non-UTF-8 locale.
    with open(out_dir / 'metrics.json', 'w', encoding='utf-8') as f:
        json.dump(metrics, f, indent=2, ensure_ascii=False)

    with open(out_dir / 'sample_analysis.json', 'w', encoding='utf-8') as f:
        json.dump(sample_analysis, f, indent=2, ensure_ascii=False)

    print(f'  Written to {out_dir}/')
    return metrics


if __name__ == '__main__':
    # Look for each result file in the working directory first, then in the
    # shared api_results directory. Models with no file are skipped.
    fallback_dir = Path('/home/ubuntu/training/api_results')
    results = {}
    for filename, info in MODELS.items():
        found = next(
            (p for p in (Path(filename), fallback_dir / filename) if p.exists()),
            None,
        )
        if found is None:
            print(f'SKIP: {filename} not found')
            continue
        results[info['model_id']] = process_model(str(found), info)

    # Comparison table, lowest normalized WER first.
    wide_rule = '=' * 80
    print(f'\n\n{wide_rule}')
    print('COMPARISON TABLE (all models)')
    print(wide_rule)
    print(f'{"Model":<30} {"WER Raw":>8} {"WER Norm":>9} {"WER Num":>8} {"CER Norm":>9}')
    print('-' * 70)
    ranked = sorted(results.items(), key=lambda kv: kv[1]['__overall__']['wer_norm'])
    for model_id, model_metrics in ranked:
        ov = model_metrics['__overall__']
        print(
            f'{model_id:<30} {ov["wer_raw"]:>7.2f}% {ov["wer_norm"]:>8.2f}% '
            f'{ov["wer_numcanon"]:>7.2f}% {ov["cer_norm"]:>8.2f}%'
        )
