#!/bin/bash
# Benchmark qwen3-asr-mixed ckpt-76000
# Step 1: Run inference with existing script
# Step 2: Recompute metrics with full schema (6 tiers)

# -e: abort on error; -u: error on unset vars; pipefail: fail a pipeline
# if any stage fails (plain `set -e` misses mid-pipeline failures).
set -euo pipefail

# Checkpoint under evaluation.
readonly CKPT_DIR="/home/ubuntu/training/checkpoints/qwen3-asr-mixed-ckpt-76000"
readonly CKPT_NAME="ckpt-76000"
readonly MODEL_ID="qwen3-asr-mixed"
# Final schema-formatted metrics land under OUTPUT_BASE; raw predictions
# from step 1 land under PREDICTIONS_DIR.
readonly OUTPUT_BASE="/home/ubuntu/training/benchmark_outputs"
readonly PREDICTIONS_DIR="/home/ubuntu/training/benchmark_results/${CKPT_NAME}"

echo "=== Step 1: Running inference ==="
# SECURITY: the HF token used to be hardcoded on this line, which leaks it
# through version control and `ps` output. Require it from the environment
# instead (the previously committed token should be revoked).
: "${HF_TOKEN:?HF_TOKEN must be set in the environment}"
HF_TOKEN="$HF_TOKEN" python3 /home/ubuntu/training/benchmark_qwen3_asr.py \
  --checkpoint "$CKPT_DIR" \
  --checkpoint-name "$CKPT_NAME" \
  --output-dir /home/ubuntu/training/benchmark_results \
  --batch-size 64 \
  --backend transformers \
  --device cuda:0

echo "=== Step 2: Converting to schema format ==="
# Pass paths/ids via the environment and feed Python a *quoted* heredoc so
# the program text is never subject to shell interpolation — a quote or
# backslash in any path would otherwise break (or inject into) the -c string.
PREDICTIONS_DIR="$PREDICTIONS_DIR" CKPT_DIR="$CKPT_DIR" CKPT_NAME="$CKPT_NAME" \
MODEL_ID="$MODEL_ID" OUTPUT_BASE="$OUTPUT_BASE" python3 - <<'PY'
import json, os, sys
from datetime import datetime, timezone
from pathlib import Path

# Metric/normalization helpers from our benchmark script.
sys.path.insert(0, '/home/ubuntu/training')
from benchmark_maya_asr_tdt import (
    compute_metrics_for_samples, build_sample_analysis, build_error_analysis,
)

predictions_dir = os.environ['PREDICTIONS_DIR']

# Load predictions produced by step 1.
with open(os.path.join(predictions_dir, 'predictions.json')) as f:
    preds = json.load(f)
print(f'Loaded {len(preds)} predictions')

# Recompute metrics with the full 6-tier schema.
metrics = compute_metrics_for_samples(preds)

# Carry over timing info from step 1's metrics file, if present.
orig_meta_path = os.path.join(predictions_dir, 'metrics.json')
orig_meta = {}
if os.path.exists(orig_meta_path):
    with open(orig_meta_path) as f:
        orig_meta = json.load(f).get('__meta__', {})

total_audio = sum(p.get('duration', 0) for p in preds)
metrics['__meta__'] = {
    'checkpoint': os.environ['CKPT_DIR'],
    'checkpoint_name': os.environ['CKPT_NAME'],
    'model_id': os.environ['MODEL_ID'],
    'model_type': 'qwen3-asr-1.7B-mixed',
    'dataset': 'BayAreaBoys/indic-asr-benchmark-6k',
    # Must match the --batch-size actually passed to the inference step
    # (was incorrectly recorded as 8 while inference ran with 64).
    'batch_size': 64,
    'inference_time_sec': orig_meta.get('inference_time_sec', 0),
    'total_audio_sec': round(total_audio, 2),
    'rtf': orig_meta.get('rtf', 0),
    'timestamp': datetime.now(timezone.utc).strftime('%Y-%m-%dT%H:%M:%SZ'),
    'gpu': 'NVIDIA A100-SXM4-80GB',
    'framework': 'transformers',
    'normalization_version': 'v1',
}

# Per-sample breakdown and aggregated error analysis.
sample_analysis = build_sample_analysis(preds)
error_analysis = build_error_analysis(preds)

# Write all three artifacts to the schema output directory.
out_dir = Path(os.environ['OUTPUT_BASE']) / os.environ['MODEL_ID'] / os.environ['CKPT_NAME']
out_dir.mkdir(parents=True, exist_ok=True)

with open(out_dir / 'metrics.json', 'w') as f:
    json.dump(metrics, f, indent=2, ensure_ascii=False)
with open(out_dir / 'sample_analysis.json', 'w') as f:
    json.dump(sample_analysis, f, indent=2, ensure_ascii=False)
with open(out_dir / 'error_analysis.json', 'w') as f:
    json.dump(error_analysis, f, indent=2, ensure_ascii=False)

o = metrics['__overall__']
print(f'Results: wer_norm={o["wer_norm"]:.2f}%  space_norm={o["space_norm_wer"]:.2f}%  mer={o["mer"]:.2f}%  cer_norm={o["cer_norm"]:.2f}%')
print(f'Written to {out_dir}/')
PY

# Completion banner.
printf '%s\n' "=== Done ==="
