"""
Recompute our models (qwen3-asr, gemma3n-e2b) using the SAME v1 normalization pipeline
so comparisons against API baselines are apples-to-apples.
"""
import json
from pathlib import Path
from datetime import datetime, timezone
from compute_metrics import compute_metrics_for_samples, build_sample_analysis, norm_raw, norm_standard

# Root directory holding per-model / per-checkpoint benchmark outputs.
OUTPUT_BASE = Path('/home/ubuntu/training/benchmark_outputs')

# Checkpoints to recompute, keyed by model id. Each entry is expected to have
# <OUTPUT_BASE>/<model_id>/<checkpoint>/sample_analysis.json already on disk.
OUR_MODELS = {
    'qwen3-asr': ['ckpt-24000', 'ckpt-72000'],
    'gemma3n-e2b': ['ckpt-10000', 'ckpt-20000'],
}

# Recompute metrics for every configured checkpoint, overwriting metrics.json
# and sample_analysis.json in place with the v1-normalized results.
for model_id, checkpoints in OUR_MODELS.items():
    for ckpt in checkpoints:
        ckpt_dir = OUTPUT_BASE / model_id / ckpt
        sa_path = ckpt_dir / 'sample_analysis.json'
        if not sa_path.exists():
            print(f'SKIP: {sa_path} not found')
            continue

        print(f'\nRecomputing {model_id}/{ckpt}...')
        # Explicit UTF-8: these files are written with ensure_ascii=False, so
        # relying on the locale-dependent default encoding can break reads.
        raw_samples = json.loads(sa_path.read_text(encoding='utf-8'))

        # Convert old sample format to match our compute pipeline
        # (old key 'hypothesis' -> 'transcript'; 'detected_language' may be absent).
        samples = [
            {
                'id': s['id'],
                'language': s['language'],
                'reference': s['reference'],
                'transcript': s['hypothesis'],
                'detected_lang': s.get('detected_language', ''),
            }
            for s in raw_samples
        ]

        metrics = compute_metrics_for_samples(samples)

        # Carry over any pre-existing metadata so provenance isn't lost;
        # the keys below take precedence over the old values.
        metrics_path = ckpt_dir / 'metrics.json'
        old_meta = {}
        if metrics_path.exists():
            old_data = json.loads(metrics_path.read_text(encoding='utf-8'))
            old_meta = old_data.get('__meta__', {})

        metrics['__meta__'] = {
            **old_meta,
            'model_id': model_id,
            'checkpoint_name': ckpt,
            'normalization_version': 'v1',
            'timestamp': datetime.now(timezone.utc).strftime('%Y-%m-%dT%H:%M:%SZ'),
        }

        # Write updated metrics
        metrics_path.write_text(
            json.dumps(metrics, indent=2, ensure_ascii=False),
            encoding='utf-8',
        )

        # Rebuild sample_analysis with new schema
        new_samples = build_sample_analysis(samples)
        sa_path.write_text(
            json.dumps(new_samples, indent=2, ensure_ascii=False),
            encoding='utf-8',
        )

        o = metrics['__overall__']
        print(f'  {model_id}/{ckpt}: wer_raw={o["wer_raw"]:.2f}  wer_norm={o["wer_norm"]:.2f}  wer_numcanon={o["wer_numcanon"]:.2f}  cer_norm={o["cer_norm"]:.2f}')
