#!/usr/bin/env python3
"""
Benchmark Maya ASR TDT 1.1B (Hybrid FastConformer RNNT/TDT) on indic-asr-benchmark-6k.

Outputs: metrics.json, sample_analysis.json, error_analysis.json
following BENCHMARK_SCHEMA.md v1 normalization.

Usage:
    python3 benchmark_maya_asr_tdt.py \
        --checkpoint /home/ubuntu/training/checkpoints/maya-asr-tdt-1.1b-ckpt-60000/model.ckpt \
        --config /home/ubuntu/training/maya-asr-hybrid-fastconformer-rnnt-stage1/configs/train/stage1_prod_8xh200.yaml \
        --checkpoint-name ckpt-60000
"""

import argparse
import json
import os
import re
import sys
import time
import unicodedata
from collections import Counter, defaultdict
from datetime import datetime, timezone
from pathlib import Path

import numpy as np
import torch
from jiwer import wer as compute_wer, cer as compute_cer
from omegaconf import OmegaConf

# Select the soundfile audio backend (presumably read by the Hugging Face audio
# stack used in main(); confirm).
os.environ["HF_AUDIO_BACKEND"] = "soundfile"

# ── Language conditioning (from train_prod.py) ──

# Tuple position == language id == row of the language-embedding table loaded
# from the checkpoint, so this order must match training.
_LANG_CODES = ("hi", "bn", "ta", "te", "mr", "gu", "kn", "ml", "pa", "or", "as", "en")
LANG_TO_ID = {code: idx for idx, code in enumerate(_LANG_CODES)}
# Map full (lower-case) language names to codes; names listed in _LANG_CODES order.
LANG_NAME_TO_CODE = dict(zip(
    ("hindi", "bengali", "tamil", "telugu", "marathi", "gujarati",
     "kannada", "malayalam", "punjabi", "odia", "assamese", "english"),
    _LANG_CODES,
))
NUM_LANGUAGES = len(LANG_TO_ID)


class LanguageEmbedding(torch.nn.Module):
    """Learned per-language vector used as an additive encoder bias.

    Holds a ``num_languages x embed_dim`` table initialised N(0, 0.02).
    ``forward`` looks up the rows for ``lang_ids`` and inserts a singleton axis
    at dim 2. For 1-D ``lang_ids`` the result is ``(batch, embed_dim, 1)``,
    which broadcasts over the last axis of the tensor it is added to
    (presumably the encoder output in ``(batch, dim, time)`` layout; confirm).
    """

    def __init__(self, num_languages: int, embed_dim: int):
        super().__init__()
        table = torch.nn.Embedding(num_languages, embed_dim)
        torch.nn.init.normal_(table.weight, mean=0.0, std=0.02)
        # The attribute name ``embed`` is part of the state_dict key
        # ("...embed.weight") loaded from checkpoints; do not rename.
        self.embed = table

    def forward(self, lang_ids: torch.Tensor) -> torch.Tensor:
        vectors = self.embed(lang_ids)
        return torch.unsqueeze(vectors, 2)

# ── Normalization pipeline (v1, matching compute_metrics.py) ──
# NOTE: these patterns define the v1 normalization. Keep them byte-identical to
# compute_metrics.py, otherwise scores stop being comparable across runs.

# Zero-width / format characters, deleted outright:
# U+200B–U+200F (ZWSP, ZWNJ, ZWJ, LRM, RLM), U+2028–U+202F (line/paragraph
# separators, bidi embeddings/overrides, narrow no-break space), BOM (U+FEFF)
# and soft hyphen (U+00AD).
# NOTE(review): ZWNJ/ZWJ can affect rendering in some Indic scripts; removing
# them is presumably intentional in v1 — confirm against the schema.
ZW_CHARS = re.compile(r'[\u200b-\u200f\u2028-\u202f\ufeff\u00ad]')
# All 32 ASCII punctuation characters.
PUNCT_COMMON = re.compile(r'[!"#$%&\'()*+,\-./:;<=>?@\[\\\]^_`{|}~]')
# Devanagari danda, double danda, abbreviation sign and high spacing dot.
PUNCT_INDIC = re.compile(r'[\u0964\u0965\u0970\u0971]')
# General Punctuation (U+2010–U+2027, U+2030–U+205E) and Supplemental
# Punctuation (U+2E00–U+2E4F).
PUNCT_EXTENDED = re.compile(r'[\u2010-\u2027\u2030-\u205e\u2e00-\u2e4f]')
# Curly single/double quotes, low-9 quotes and guillemets; norm_standard maps
# these to an ASCII apostrophe.
QUOTES = re.compile(r'[\u2018\u2019\u201a\u201b\u201c\u201d\u201e\u201f\u00ab\u00bb]')
# Non-nested parenthesised spans, removed together with their contents
# (presumably transcript annotations such as "(noise)").
ANNOTATIONS = re.compile(r'\([^)]*\)')

# Code point of digit ZERO for each Indic script (digits are offset + 0..9), in
# order: Devanagari, Bengali, Gurmukhi, Gujarati, Odia, Tamil, Telugu, Kannada,
# Malayalam.
INDIC_DIGIT_OFFSETS = [
    0x0966, 0x09E6, 0x0A66, 0x0AE6, 0x0B66,
    0x0BE6, 0x0C66, 0x0CE6, 0x0D66,
]


def norm_raw(text: str) -> str:
    """Return ``text`` NFC-composed with surrounding whitespace trimmed (the "raw" level)."""
    composed = unicodedata.normalize('NFC', text)
    return composed.strip()


def norm_standard(text: str) -> str:
    """Apply the v1 "standard" normalization used for wer_norm / cer_norm.

    Steps, in order:
    1. NFKC.
    2. Drop zero-width / bidi / separator characters.
    3. Drop parenthesised spans.
    4. Collapse whitespace.
    5. Map curly quotes and guillemets to an ASCII apostrophe, and em, en and
       figure dashes to '-'.
    6. Strip ASCII, Devanagari and extended Unicode punctuation.
    7. Lowercase, collapse whitespace again and trim.
    """
    result = unicodedata.normalize('NFKC', text)
    result = ZW_CHARS.sub('', result)
    result = ANNOTATIONS.sub('', result)
    result = QUOTES.sub("'", ' '.join(result.split()))
    # Em dash, en dash, figure dash -> ASCII hyphen (removed just below).
    for dash in ('\u2014', '\u2013', '\u2012'):
        result = result.replace(dash, '-')
    for punct in (PUNCT_COMMON, PUNCT_INDIC, PUNCT_EXTENDED):
        result = punct.sub('', result)
    return ' '.join(result.lower().split()).strip()


# Translation table mapping every Indic digit covered by INDIC_DIGIT_OFFSETS to
# its ASCII counterpart; built once at import so norm_numcanon needs one pass.
_INDIC_TO_ASCII_DIGITS = str.maketrans({
    chr(offset + digit): str(digit)
    for offset in INDIC_DIGIT_OFFSETS
    for digit in range(10)
})


def norm_numcanon(text: str) -> str:
    """Apply ``norm_standard``, then canonicalise numerals to ASCII digits.

    Digits from the nine Indic scripts in INDIC_DIGIT_OFFSETS are mapped to
    0-9, so e.g. "१२" and "12" compare equal. Digit-group commas ("1,000") are
    already removed by norm_standard's punctuation stripping.
    """
    text = norm_standard(text)
    # Kept for parity with compute_metrics.py. norm_standard already deletes
    # every ASCII comma (PUNCT_COMMON), so this currently never matches.
    text = re.sub(r'(\d),(\d)', r'\1\2', text)
    # A single translate() pass replaces the previous 90 chained str.replace()
    # calls; the result is identical because targets are ASCII, never sources.
    return text.translate(_INDIC_TO_ASCII_DIGITS)


def norm_mer(text: str) -> str:
    """Return the MER form of ``text``: ``norm_standard`` output with every space removed.

    Dropping word boundaries yields a single character stream, so scoring it
    with CER ignores spacing/segmentation differences.
    """
    return ''.join(norm_standard(text).split(' '))


# ── space_norm_wer computation ──

def _levenshtein_backtrace(ref_chars: str, hyp_chars: str):
    """Compute Levenshtein DP and backtrace to find edit positions in ref."""
    n, m = len(ref_chars), len(hyp_chars)
    # DP table
    dp = [[0] * (m + 1) for _ in range(n + 1)]
    for i in range(n + 1):
        dp[i][0] = i
    for j in range(m + 1):
        dp[0][j] = j
    for i in range(1, n + 1):
        for j in range(1, m + 1):
            if ref_chars[i - 1] == hyp_chars[j - 1]:
                dp[i][j] = dp[i - 1][j - 1]
            else:
                dp[i][j] = 1 + min(dp[i - 1][j - 1], dp[i - 1][j], dp[i][j - 1])

    # Backtrace to find which ref positions have edits
    edited_ref_positions = set()
    i, j = n, m
    while i > 0 or j > 0:
        if i > 0 and j > 0 and ref_chars[i - 1] == hyp_chars[j - 1]:
            i -= 1; j -= 1  # match
        elif i > 0 and j > 0 and dp[i][j] == dp[i - 1][j - 1] + 1:
            edited_ref_positions.add(i - 1)  # substitution
            i -= 1; j -= 1
        elif i > 0 and dp[i][j] == dp[i - 1][j] + 1:
            edited_ref_positions.add(i - 1)  # deletion from ref
            i -= 1
        elif j > 0 and dp[i][j] == dp[i][j - 1] + 1:
            # insertion — attribute to nearest ref char
            if i > 0:
                edited_ref_positions.add(i - 1)
            elif n > 0:
                edited_ref_positions.add(0)
            j -= 1
        else:
            break
    return edited_ref_positions


def compute_space_norm_wer_sample(ref_norm: str, hyp_norm: str):
    """Compute space-normalised word errors for one sample.

    Spaces are removed from both strings, and the character-level edits are
    mapped back onto the reference words they touch. A word counts as an
    error if any of its characters was substituted or deleted, or had an
    insertion charged to it. Pure spacing/segmentation differences therefore
    cost nothing.

    Args:
        ref_norm: normalised reference (normally ``norm_standard`` output).
        hyp_norm: normalised hypothesis.

    Returns:
        ``(error_words, total_words)``, where ``total_words`` is the number of
        reference words. Returns ``(0, 0)`` for an empty reference and
        ``(total, total)`` for an empty hypothesis.
    """
    ref_words = ref_norm.split()
    if not ref_words:
        return (0, 0)
    total_words = len(ref_words)

    # Build both space-free streams from split() so every index in ref_nospace
    # maps to a word even when the input holds tabs/newlines/repeated spaces
    # (replace(' ', '') kept those and desynchronised char_to_word).
    ref_nospace = ''.join(ref_words)
    hyp_nospace = ''.join(hyp_norm.split())

    if ref_nospace == hyp_nospace:
        return (0, total_words)
    if not hyp_nospace:
        return (total_words, total_words)

    # char_to_word[i] = index of the reference word owning ref_nospace[i]
    char_to_word = [wi for wi, word in enumerate(ref_words) for _ in word]

    # Every reported position is in range(len(ref_nospace)) == len(char_to_word).
    edited_positions = _levenshtein_backtrace(ref_nospace, hyp_nospace)
    touched_words = {char_to_word[pos] for pos in edited_positions}
    return (len(touched_words), total_words)


# ── Metric computation ──

def compute_metrics_for_samples(samples: list) -> dict:
    """Compute corpus-level error rates per language, overall and macro-averaged.

    Each sample is a dict with ``'language'``, ``'reference'`` and an optional
    ``'hypothesis'`` (missing or ``None`` is treated as an empty transcript).
    Samples whose reference is empty after ``norm_raw`` are left out of every
    score, because jiwer rejects empty references. They still count towards
    ``n_samples`` and ``empty_hypotheses``.

    Per-language entries contain:
    - ``wer_raw``, ``wer_norm``, ``wer_numcanon``, ``space_norm_wer``;
    - ``mer`` (CER on the space-free stream) and ``cer_norm``;
    - the empty-hypothesis count and the normalization deltas.
    All rates are percentages rounded to 2 dp.

    ``'__overall__'`` pools every scored pair into one corpus.
    ``'__macro_avg__'`` is the unweighted mean of the per-language values.

    A language with no scorable reference is skipped with a warning, since
    jiwer would fail on the empty corpus and abort the whole benchmark.

    Raises:
        ValueError: if no sample at all has a non-empty reference.
    """
    by_lang = defaultdict(list)
    for s in samples:
        by_lang[s['language']].append(s)

    result = {}
    all_ref_raw, all_hyp_raw = [], []
    all_ref_norm, all_hyp_norm = [], []
    all_ref_nc, all_hyp_nc = [], []
    all_ref_mer, all_hyp_mer = [], []
    all_snw_errors, all_snw_total = 0, 0

    for lang in sorted(by_lang):
        lang_samples = by_lang[lang]
        refs_raw, hyps_raw = [], []
        refs_norm, hyps_norm = [], []
        refs_nc, hyps_nc = [], []
        refs_mer, hyps_mer = [], []
        lang_snw_errors, lang_snw_total = 0, 0
        empty_count = 0

        for s in lang_samples:
            ref = s['reference']
            hyp = s.get('hypothesis', '') or ''

            if not hyp.strip():
                empty_count += 1

            rr, hr = norm_raw(ref), norm_raw(hyp)
            if not rr:
                # Empty reference: jiwer cannot score it, so skip the sample.
                continue
            rn, hn = norm_standard(ref), norm_standard(hyp)
            rnc, hnc = norm_numcanon(ref), norm_numcanon(hyp)
            rm, hm = norm_mer(ref), norm_mer(hyp)

            # '<empty>' stands in for blank strings, which jiwer refuses.
            refs_raw.append(rr)
            hyps_raw.append(hr or '<empty>')
            refs_norm.append(rn or '<empty>')
            hyps_norm.append(hn or '<empty>')
            refs_nc.append(rnc or '<empty>')
            hyps_nc.append(hnc or '<empty>')
            refs_mer.append(rm or '<empty>')
            hyps_mer.append(hm or '<empty>')

            # space_norm_wer is accumulated from per-sample word counts.
            ew, tw = compute_space_norm_wer_sample(rn, hn)
            lang_snw_errors += ew
            lang_snw_total += tw

        if not refs_raw:
            print(f"  WARNING: language {lang!r} has no non-empty references; "
                  f"excluded from metrics")
            continue

        w_raw = compute_wer(refs_raw, hyps_raw) * 100
        w_norm = compute_wer(refs_norm, hyps_norm) * 100
        w_nc = compute_wer(refs_nc, hyps_nc) * 100
        w_mer = compute_cer(refs_mer, hyps_mer) * 100
        c_norm = compute_cer(refs_norm, hyps_norm) * 100
        w_snw = (lang_snw_errors / lang_snw_total * 100) if lang_snw_total > 0 else 0.0

        result[lang] = {
            'n_samples': len(lang_samples),
            'wer_raw': round(w_raw, 2),
            'wer_norm': round(w_norm, 2),
            'wer_numcanon': round(w_nc, 2),
            'space_norm_wer': round(w_snw, 2),
            'mer': round(w_mer, 2),
            'cer_norm': round(c_norm, 2),
            'empty_hypotheses': empty_count,
            'normalization_delta': {
                'raw_to_norm': round(w_raw - w_norm, 2),
                'norm_to_numcanon': round(w_norm - w_nc, 2),
                'norm_to_space_norm': round(w_norm - w_snw, 2),
                'norm_to_mer': round(w_norm - w_mer, 2),
            }
        }

        all_ref_raw.extend(refs_raw)
        all_hyp_raw.extend(hyps_raw)
        all_ref_norm.extend(refs_norm)
        all_hyp_norm.extend(hyps_norm)
        all_ref_nc.extend(refs_nc)
        all_hyp_nc.extend(hyps_nc)
        all_ref_mer.extend(refs_mer)
        all_hyp_mer.extend(hyps_mer)
        all_snw_errors += lang_snw_errors
        all_snw_total += lang_snw_total

    if not all_ref_raw:
        raise ValueError("No samples with a non-empty reference; nothing to score.")

    overall_snw = (all_snw_errors / all_snw_total * 100) if all_snw_total > 0 else 0.0
    result['__overall__'] = {
        'n_samples': len(samples),
        'wer_raw': round(compute_wer(all_ref_raw, all_hyp_raw) * 100, 2),
        'wer_norm': round(compute_wer(all_ref_norm, all_hyp_norm) * 100, 2),
        'wer_numcanon': round(compute_wer(all_ref_nc, all_hyp_nc) * 100, 2),
        'space_norm_wer': round(overall_snw, 2),
        'mer': round(compute_cer(all_ref_mer, all_hyp_mer) * 100, 2),
        'cer_norm': round(compute_cer(all_ref_norm, all_hyp_norm) * 100, 2),
    }

    # Macro average: unweighted mean of the (already rounded) per-language values.
    metric_keys = ('wer_raw', 'wer_norm', 'wer_numcanon', 'space_norm_wer', 'mer', 'cer_norm')
    lang_keys = [k for k in result if not k.startswith('_')]
    n_langs = len(lang_keys)  # >= 1: all_ref_raw is non-empty, so some language was scored
    macro = {'n_languages': n_langs}
    for key in metric_keys:
        macro[key] = round(sum(result[k][key] for k in lang_keys) / n_langs, 2)
    result['__macro_avg__'] = macro

    return result


def build_sample_analysis(samples: list) -> list:
    """Return one diagnostic record per sample, in input order.

    Each record holds:
    - the original reference and hypothesis;
    - their standard, numcanon and MER normal forms;
    - per-sample ``wer_raw``, ``wer_norm``, ``space_norm_wer`` and ``mer``
      (percent, 2 dp);
    - a list of flags: ``exact_match``, ``exact_match_norm``,
      ``punctuation_only_diff``, ``empty_hypothesis``, ``numeric_mismatch``,
      ``high_wer`` (wer_norm > 80) and ``spacing_error`` (wer_norm exceeds
      mer by more than 1 point).

    If any per-sample score computation raises (presumably jiwer rejecting an
    empty reference), all four scores are reported as 100.
    """
    def placeholder(value):
        # jiwer refuses empty strings, so blanks are scored as '<empty>'.
        return value if value else '<empty>'

    records = []
    for sample in samples:
        ref = sample['reference']
        hyp = sample.get('hypothesis', '') or ''
        ref_std, hyp_std = norm_standard(ref), norm_standard(hyp)
        ref_nc, hyp_nc = norm_numcanon(ref), norm_numcanon(hyp)
        ref_m, hyp_m = norm_mer(ref), norm_mer(hyp)
        ref_raw, hyp_raw = norm_raw(ref), norm_raw(hyp)

        try:
            wer_raw = compute_wer([ref_raw], [hyp_raw if hyp_raw.strip() else '<empty>']) * 100
            wer_std = compute_wer([placeholder(ref_std)], [placeholder(hyp_std)]) * 100
            mer = compute_cer([placeholder(ref_m)], [placeholder(hyp_m)]) * 100
            n_err, n_words = compute_space_norm_wer_sample(ref_std or '', hyp_std or '')
            snw = (n_err / n_words * 100) if n_words > 0 else 0.0
        except Exception:
            wer_raw = wer_std = mer = snw = 100.0

        has_digit = bool(re.search(r'\d', ref) or re.search(r'\d', hyp))
        checks = (
            ('exact_match', ref == hyp),
            ('exact_match_norm', ref_std == hyp_std),
            ('punctuation_only_diff', ref != hyp and ref_std == hyp_std),
            ('empty_hypothesis', not hyp.strip()),
            ('numeric_mismatch', has_digit and ref_std != hyp_std),
            ('high_wer', wer_std > 80),
            ('spacing_error', wer_std > mer + 1.0),
        )
        flags = [name for name, hit in checks if hit]

        records.append({
            'id': sample['id'],
            'language': sample['language'],
            'reference': ref,
            'hypothesis': hyp,
            'ref_norm': ref_std,
            'hyp_norm': hyp_std,
            'ref_numcanon': ref_nc,
            'hyp_numcanon': hyp_nc,
            'ref_mer': ref_m,
            'hyp_mer': hyp_m,
            'wer_raw': round(wer_raw, 2),
            'wer_norm': round(wer_std, 2),
            'space_norm_wer': round(snw, 2),
            'mer': round(mer, 2),
            'flags': flags,
        })
    return records


def build_error_analysis(samples: list) -> dict:
    """Summarise error patterns per language plus a corpus-level diagnosis.

    For each language (sorted by name) the result holds:
    - ``top_substitutions`` / ``top_insertions`` / ``top_deletions`` (top 20
      each). These come from a cheap alignment-free approximation: bag-of-words
      surplus for insertions/deletions, and a position-wise comparison for
      substitutions.
    - ``error_buckets``: counts of numeric mismatches, punctuation-only
      differences and empty hypotheses. The spacing, entity and script buckets
      are not detected yet and are always 0.
    - ``examples``: the 3 worst and 3 best sample ids by per-sample wer_norm,
      and the first 3 numeric-mismatch ids.

    ``'__summary__'`` holds:
    - a heuristic diagnosis from the punctuation-only and numeric-mismatch
      rates over all samples;
    - the 3 worst and 3 best languages by mean per-sample wer_norm.
    """
    by_lang = defaultdict(list)
    for s in samples:
        by_lang[s['language']].append(s)

    result = {}
    lang_wer_norms = {}

    for lang in sorted(by_lang):
        lang_samples = by_lang[lang]
        subs, ins, dels = Counter(), Counter(), Counter()
        numeric_mm = punct_only = empty_hyp = 0
        # Not detected by this script yet; reported as 0 to keep the schema shape.
        spacing_tok = entity_mm = script_conf = 0
        numeric_samples = []  # first 3 numeric-mismatch ids, in input order
        per_sample_wer = []

        for s in lang_samples:
            ref = s['reference']
            hyp = s.get('hypothesis', '') or ''
            rn = norm_standard(ref)
            hn = norm_standard(hyp)

            if not hyp.strip():
                empty_hyp += 1

            # Per-sample WER, used for best/worst ranking only.
            try:
                sw = compute_wer([rn or '<empty>'], [hn or '<empty>']) * 100
            except Exception:
                sw = 100.0
            per_sample_wer.append((s['id'], sw))

            ref_words, hyp_words = rn.split(), hn.split()

            # Alignment-free insert/delete counting: Counter subtraction keeps
            # only the positive surplus of each word on one side.
            ref_counter, hyp_counter = Counter(ref_words), Counter(hyp_words)
            dels.update(ref_counter - hyp_counter)
            ins.update(hyp_counter - ref_counter)

            # Substitution pairs, aligned by position over the common prefix length.
            for ref_word, hyp_word in zip(ref_words, hyp_words):
                if ref_word != hyp_word:
                    subs[(ref_word, hyp_word)] += 1

            # Error buckets
            if (re.search(r'\d', ref) or re.search(r'\d', hyp)) and rn != hn:
                numeric_mm += 1
                # Collected here instead of re-normalising every sample afterwards.
                if len(numeric_samples) < 3:
                    numeric_samples.append(s['id'])
            if ref != hyp and rn == hn:
                punct_only += 1

        per_sample_wer.sort(key=lambda item: item[1])
        worst = [sid for sid, _ in per_sample_wer[-3:]][::-1]
        best = [sid for sid, _ in per_sample_wer[:3]]

        lang_wer_norms[lang] = (
            sum(w for _, w in per_sample_wer) / len(per_sample_wer) if per_sample_wer else 100.0
        )

        result[lang] = {
            'top_substitutions': [
                {'ref': r, 'hyp': h, 'count': c}
                for (r, h), c in subs.most_common(20)
            ],
            'top_insertions': [
                {'word': w, 'count': c}
                for w, c in ins.most_common(20)
            ],
            'top_deletions': [
                {'word': w, 'count': c}
                for w, c in dels.most_common(20)
            ],
            'error_buckets': {
                'numeric_mismatch_count': numeric_mm,
                'punctuation_only_count': punct_only,
                'spacing_tokenization_count': spacing_tok,
                'entity_mismatch_count': entity_mm,
                'script_confusion_count': script_conf,
                'empty_hypothesis_count': empty_hyp,
            },
            'examples': {
                'worst_samples': worst,
                'best_samples': best,
                'numeric_mismatch_samples': numeric_samples,
            }
        }

    # Summary
    sorted_langs = sorted(lang_wer_norms.items(), key=lambda item: item[1])
    worst_langs = [name for name, _ in sorted_langs[-3:]][::-1]
    best_langs = [name for name, _ in sorted_langs[:3]]

    lang_keys = [k for k in result if not k.startswith('_')]
    total_punct_only = sum(result[k]['error_buckets']['punctuation_only_count'] for k in lang_keys)
    total_numeric = sum(result[k]['error_buckets']['numeric_mismatch_count'] for k in lang_keys)
    total_samples = len(samples)
    # Guard the ratios so an empty sample list yields the 'recognition'/'low'
    # defaults instead of ZeroDivisionError.
    punct_rate = total_punct_only / total_samples if total_samples else 0.0
    numeric_rate = total_numeric / total_samples if total_samples else 0.0

    # Determine diagnosis
    if punct_rate > 0.3:
        diagnosis = 'formatting-limited'
        primary_source = 'formatting'
    elif numeric_rate > 0.2:
        diagnosis = 'numeric-limited'
        primary_source = 'numeric'
    else:
        diagnosis = 'recognition-limited'
        primary_source = 'recognition'

    numeric_impact = 'high' if numeric_rate > 0.15 else ('moderate' if numeric_rate > 0.05 else 'low')
    formatting_impact = 'high' if punct_rate > 0.2 else ('moderate' if punct_rate > 0.05 else 'low')

    result['__summary__'] = {
        'model_diagnosis': diagnosis,
        'primary_error_source': primary_source,
        'numeric_verbalization_impact': numeric_impact,
        'formatting_impact': formatting_impact,
        'worst_languages': worst_langs,
        'best_languages': best_langs,
    }

    return result


# ── Model loading and inference ──

def load_model(config_path: str, checkpoint_path: str):
    """Build the hybrid RNNT/CTC model from a training config and load a checkpoint.

    Steps:
    1. Resolve the tokenizer directory relative to the repo root.
    2. Point train/validation manifests at a dummy file (needed for model init).
    3. Instantiate ``EncDecHybridRNNTCTCBPEModel``.
    4. Read the Lightning checkpoint's state_dict, tolerating pickled classes
       that no longer import.
    5. Attach the training-time language embedding and patch it into the
       encoder forward.

    Returns:
        The model in eval mode, on CUDA, cast to bfloat16.

    Raises:
        FileNotFoundError: if the resolved tokenizer directory does not exist.
    """
    import nemo.collections.asr as nemo_asr

    print(f"Loading config from {config_path}")
    cfg = OmegaConf.load(config_path)

    # Fix tokenizer path to be absolute (relative to repo root, not config dir).
    # The config lives at <repo>/configs/train/<name>.yaml, hence three parents.
    repo_dir = str(Path(config_path).resolve().parent.parent.parent)
    tok_dir = cfg.model.tokenizer.dir
    if not os.path.isabs(tok_dir):
        cfg.model.tokenizer.dir = os.path.join(repo_dir, tok_dir)
    # Explicit raise instead of `assert`, which `python -O` strips.
    if not os.path.isdir(cfg.model.tokenizer.dir):
        raise FileNotFoundError(f"Tokenizer dir not found: {cfg.model.tokenizer.dir}")
    print(f"Tokenizer dir: {cfg.model.tokenizer.dir}")

    # Need a dummy manifest for model init - create one
    dummy_manifest = Path("/tmp/dummy_manifest.jsonl")
    if not dummy_manifest.exists():
        dummy_manifest.write_text('{"audio_filepath": "/dev/null", "text": "dummy", "duration": 1.0}\n')
    cfg.model.train_ds.manifest_filepath = str(dummy_manifest)
    cfg.model.validation_ds.manifest_filepath = str(dummy_manifest)

    # Override trainer for single GPU inference
    cfg.trainer.devices = 1
    cfg.trainer.accelerator = "gpu"

    print("Initializing model from config...")
    model = nemo_asr.models.EncDecHybridRNNTCTCBPEModel(cfg=cfg.model, trainer=None)

    print(f"Loading checkpoint from {checkpoint_path}")
    ckpt = _load_checkpoint_with_stubs(checkpoint_path)

    # Register the language embedding BEFORE loading the state_dict so its
    # weights are restored from the checkpoint. Assigning a Module attribute
    # also registers it (nn.Module.__setattr__), so the same embedding appears
    # under both '_lang_embed' and '_lang_embed_module'.
    # NOTE(review): 1024 is assumed to be the encoder output width of this
    # config — confirm (a mismatch makes load_state_dict raise a size error).
    lang_embed = LanguageEmbedding(NUM_LANGUAGES, 1024)
    model._lang_embed = lang_embed
    model.register_module('_lang_embed_module', lang_embed)

    missing, unexpected = model.load_state_dict(ckpt["state_dict"], strict=False)
    print(f"  Loaded state_dict: {len(missing)} missing, {len(unexpected)} unexpected keys")
    if missing:
        print(f"  Missing (first 5): {missing[:5]}")
    if unexpected:
        print(f"  Unexpected (first 5): {unexpected[:5]}")
    lang_embed_keys = {'_lang_embed.embed.weight', '_lang_embed_module.embed.weight'}
    if lang_embed_keys.issubset(missing):
        # Without these weights the randomly initialised bias would silently
        # perturb every encoder output.
        print("  WARNING: checkpoint has no language-embedding weights; "
              "the language bias is randomly initialised")

    _patch_encoder_with_language_bias(model)
    print("  Language conditioning enabled (encoder patched)")

    model.eval()
    model.cuda()
    model = model.to(torch.bfloat16)
    print("Model loaded and ready for inference.")
    return model


def _load_checkpoint_with_stubs(checkpoint_path: str):
    """``torch.load`` a Lightning checkpoint, replacing unimportable classes with stubs.

    Training-environment classes that are missing here are replaced by empty
    stub classes, so the tensors in ``state_dict`` can still be read.
    ``torch.serialization._load`` is patched only for the duration of the call.

    SECURITY: this is a full ``weights_only=False`` unpickle and can execute
    arbitrary code embedded in the file — use only with trusted checkpoints.
    """
    import pickle
    import torch.serialization as _ts

    class _TolerantUnpickler(pickle.Unpickler):
        """Unpickler that creates stubs for any missing classes."""
        _stub_cache = {}

        def find_class(self, module, name):
            try:
                return super().find_class(module, name)
            except (ModuleNotFoundError, AttributeError):
                key = f"{module}.{name}"
                if key not in self._stub_cache:
                    class Stub:
                        def __init__(self, *a, **kw): pass
                        def __setstate__(self, state):
                            self.__dict__.update(state if isinstance(state, dict) else {})
                    Stub.__name__ = name
                    Stub.__qualname__ = name
                    self._stub_cache[key] = Stub
                return self._stub_cache[key]

    _orig_load = _ts._load

    def _patched_load(zip_file, map_location, pickle_module, pickle_file='data.pkl', **kwargs):
        # torch subclasses pickle_module.Unpickler, so substitute ours there.
        class _StubbingPickleModule:
            Unpickler = _TolerantUnpickler
            # Pass through everything else
            load = pickle.load
            dumps = pickle.dumps
            loads = pickle.loads
        return _orig_load(zip_file, map_location, _StubbingPickleModule, pickle_file, **kwargs)

    _ts._load = _patched_load
    try:
        return torch.load(checkpoint_path, map_location="cpu", weights_only=False)
    finally:
        _ts._load = _orig_load


def _patch_encoder_with_language_bias(model) -> None:
    """Wrap ``model.encoder.forward`` to add the language bias, as in training.

    Before each forward, the caller sets ``model._current_lang_ids``. When its
    batch dimension matches the encoder output, ``model._lang_embed(lang_ids)``
    is added to the output. The IDs are cleared after every encoder call that
    sees them, so one assignment applies to at most one forward.
    """
    original_encoder_forward = model.encoder.forward

    def encoder_forward_with_lang(audio_signal, length):
        encoded, encoded_len = original_encoder_forward(audio_signal=audio_signal, length=length)
        if (hasattr(model, '_current_lang_ids') and hasattr(model, '_lang_embed')
                and model._current_lang_ids is not None):
            lang_ids = model._current_lang_ids
            if lang_ids.shape[0] == encoded.shape[0]:
                lang_bias = model._lang_embed(lang_ids)
                encoded = encoded + lang_bias
            else:
                print(f"  WARNING: {lang_ids.shape[0]} language IDs for an encoder batch of "
                      f"{encoded.shape[0]}; language bias skipped")
            model._current_lang_ids = None
        return encoded, encoded_len

    model.encoder.forward = encoder_forward_with_lang


def transcribe_dataset(model, dataset, batch_size=32):
    """Transcribe every dataset row with ``model.transcribe`` via temporary WAV files.

    For each slice of ``batch_size`` rows:
    1. Write each decoded audio array to a WAV file in a private temp dir.
    2. Set ``model._current_lang_ids`` from the rows' ``language`` column; the
       encoder patch installed by ``load_model`` consumes these IDs.
    3. Transcribe, and collect one dict per row with ``id``, ``language``,
       ``reference`` and ``hypothesis``.

    Language values are matched case-insensitively against full names
    (``LANG_NAME_TO_CODE``) and ISO codes (``LANG_TO_ID``). Anything else is
    conditioned as English, with one warning per unknown value.

    Raises:
        RuntimeError: if ``model.transcribe`` returns a different number of
            results than inputs (rows would otherwise be silently misaligned).
    """
    import soundfile as sf
    import tempfile
    import shutil

    warned_langs = set()

    def _lang_id(lang_name) -> int:
        key = str(lang_name).strip().lower()
        code = key if key in LANG_TO_ID else LANG_NAME_TO_CODE.get(key)
        if code is None:
            if key not in warned_langs:
                warned_langs.add(key)
                print(f"  WARNING: unknown language {lang_name!r}; conditioning as 'en'")
            code = 'en'
        return LANG_TO_ID[code]

    results = []
    total = len(dataset)
    tmpdir = tempfile.mkdtemp(prefix="maya_asr_bench_")

    try:
        for batch_start in range(0, total, batch_size):
            batch_end = min(batch_start + batch_size, total)
            batch = dataset[batch_start:batch_end]

            # Write temp wav files for this batch
            wav_paths = []
            for offset, audio in enumerate(batch['audio']):
                arr = np.array(audio['array'], dtype=np.float32)
                wav_path = os.path.join(tmpdir, f"sample_{batch_start + offset}.wav")
                sf.write(wav_path, arr, audio['sampling_rate'])
                wav_paths.append(wav_path)

            model._current_lang_ids = torch.tensor(
                [_lang_id(name) for name in batch['language']],
                dtype=torch.long, device='cuda',
            )
            try:
                with torch.amp.autocast('cuda', dtype=torch.bfloat16):
                    transcriptions = model.transcribe(wav_paths, batch_size=batch_size)
            finally:
                # Never let this batch's IDs leak into a later call if the
                # encoder patch did not consume them.
                model._current_lang_ids = None

            # Handle NeMo return format (may be list or tuple)
            if isinstance(transcriptions, tuple):
                transcriptions = transcriptions[0]
            if len(transcriptions) != len(wav_paths):
                raise RuntimeError(
                    f"model.transcribe returned {len(transcriptions)} results "
                    f"for {len(wav_paths)} inputs (rows {batch_start}-{batch_end - 1})"
                )

            for sample_id, lang, ref, hyp in zip(
                    batch['id'], batch['language'], batch['reference'], transcriptions):
                # NeMo may return Hypothesis objects — extract .text
                if isinstance(hyp, str):
                    text = hyp
                elif hasattr(hyp, 'text'):
                    text = hyp.text
                else:
                    text = str(hyp)
                results.append({
                    'id': sample_id,
                    'language': lang,
                    'reference': ref,
                    'hypothesis': text,
                })

            # Clean up wav files for this batch
            for p in wav_paths:
                os.unlink(p)

            print(f"  Transcribed {batch_end}/{total} samples ({batch_end/total*100:.1f}%)")

    finally:
        shutil.rmtree(tmpdir, ignore_errors=True)

    return results


def main():
    """CLI entry point: run (or reuse cached) inference and write the report files.

    Outputs go to ``<output-dir>/<model-id>/<checkpoint-name>/``: metrics.json,
    sample_analysis.json and error_analysis.json.

    Raw predictions are cached there in ``_predictions_cache.json``, so a crash
    after inference does not force a re-run. The cache is reused unless it is
    unreadable or records a different ``--checkpoint``; caches from older
    versions of this script record no checkpoint and are still reused.

    The model is loaded only when inference actually has to run. All JSON is
    written as UTF-8 through a temp file plus ``os.replace``, so an
    interrupted write cannot leave a truncated file behind.
    """
    parser = argparse.ArgumentParser(description="Benchmark Maya ASR TDT 1.1B")
    parser.add_argument("--checkpoint", type=str, required=True, help="Path to .ckpt file")
    parser.add_argument("--config", type=str, required=True, help="Path to training YAML config")
    parser.add_argument("--checkpoint-name", type=str, default="ckpt-60000")
    parser.add_argument("--model-id", type=str, default="parakeet-tdt-1.1b-lang")
    parser.add_argument("--batch-size", type=int, default=32)
    parser.add_argument("--output-dir", type=str, default="/home/ubuntu/training/benchmark_outputs")
    args = parser.parse_args()

    out_dir = Path(args.output_dir) / args.model_id / args.checkpoint_name
    out_dir.mkdir(parents=True, exist_ok=True)

    # Load dataset
    print("Loading benchmark dataset...")
    from datasets import load_dataset, Audio
    ds = load_dataset(
        "BayAreaBoys/indic-asr-benchmark-6k",
        split="train",
        token=os.environ.get("HF_TOKEN", True),
    )
    ds = ds.cast_column("audio", Audio(sampling_rate=16000, decode=True))
    print(f"  Loaded {len(ds)} samples across {len(set(ds['language']))} languages")

    total_audio_sec = sum(ds['duration'])
    print(f"  Total audio: {total_audio_sec:.1f}s ({total_audio_sec/3600:.2f}h)")

    # Run inference (with caching to avoid re-running on crash)
    predictions_cache = out_dir / '_predictions_cache.json'
    cache = _read_predictions_cache(predictions_cache, args.checkpoint)
    if cache is not None:
        samples = cache['samples']
        inference_time = cache['inference_time']
        print(f"  Loaded {len(samples)} cached predictions (original inference: {inference_time:.1f}s)")
    else:
        # The model is only needed when predictions have to be computed.
        model = load_model(args.config, args.checkpoint)
        print("\nStarting transcription...")
        t0 = time.time()
        samples = transcribe_dataset(model, ds, batch_size=args.batch_size)
        inference_time = time.time() - t0
        rtf = _real_time_factor(inference_time, total_audio_sec)
        rtf_text = f"{rtf:.4f}" if rtf is not None else "n/a"
        print(f"\nTranscription complete in {inference_time:.1f}s (RTF: {rtf_text})")
        _write_json_atomic(
            predictions_cache,
            {'samples': samples, 'inference_time': inference_time, 'checkpoint': args.checkpoint},
            indent=None,
        )
        print(f"  Cached predictions to {predictions_cache}")

    # Compute metrics
    print("\nComputing metrics...")
    metrics = compute_metrics_for_samples(samples)

    # Add meta
    gpu_name = "unknown"
    try:
        gpu_name = torch.cuda.get_device_name(0)
    except Exception:
        pass  # best-effort: no GPU visible (e.g. metrics recomputed from cache)

    rtf = _real_time_factor(inference_time, total_audio_sec)
    metrics['__meta__'] = {
        'checkpoint': args.checkpoint,
        'checkpoint_name': args.checkpoint_name,
        'model_id': args.model_id,
        'model_type': 'maya-asr-hybrid-tdt-1.1b',
        'dataset': 'BayAreaBoys/indic-asr-benchmark-6k',
        'batch_size': args.batch_size,
        'inference_time_sec': round(inference_time, 2),
        'total_audio_sec': round(total_audio_sec, 2),
        'rtf': round(rtf, 4) if rtf is not None else None,
        'timestamp': datetime.now(timezone.utc).strftime('%Y-%m-%dT%H:%M:%SZ'),
        'gpu': gpu_name,
        'framework': 'nemo',
        'normalization_version': 'v1',
        'jiwer_version': getattr(__import__('jiwer'), '__version__', 'unknown'),
    }

    _print_summary(metrics)

    # Build sample analysis
    print("\nBuilding sample analysis...")
    sample_analysis = build_sample_analysis(samples)

    # Build error analysis
    print("Building error analysis...")
    error_analysis = build_error_analysis(samples)

    # Write outputs
    _write_json_atomic(out_dir / 'metrics.json', metrics)
    _write_json_atomic(out_dir / 'sample_analysis.json', sample_analysis)
    _write_json_atomic(out_dir / 'error_analysis.json', error_analysis)

    print(f"\nResults written to {out_dir}/")
    print("  - metrics.json")
    print("  - sample_analysis.json")
    print("  - error_analysis.json")


def _read_predictions_cache(path: Path, checkpoint: str):
    """Return the cached predictions dict at ``path``, or None if it should not be reused.

    The cache is ignored, with a warning, when:
    - it cannot be read or parsed;
    - it lacks ``samples`` or ``inference_time``;
    - it records a checkpoint other than ``checkpoint``.
    """
    if not path.exists():
        return None
    print(f"\nLoading cached predictions from {path}")
    try:
        with open(path, encoding='utf-8') as f:
            cache = json.load(f)
    except (OSError, ValueError) as err:
        print(f"  WARNING: could not read cache ({err}); re-running inference")
        return None
    if not isinstance(cache, dict) or 'samples' not in cache or 'inference_time' not in cache:
        print("  WARNING: cache lacks 'samples'/'inference_time'; re-running inference")
        return None
    cached_ckpt = cache.get('checkpoint')
    if cached_ckpt is not None and cached_ckpt != checkpoint:
        print(f"  WARNING: cache was produced from {cached_ckpt}, not {checkpoint}; "
              f"re-running inference")
        return None
    return cache


def _write_json_atomic(path: Path, obj, indent=2) -> None:
    """Write ``obj`` as UTF-8 JSON to ``path`` via a temp file and ``os.replace``.

    Readers of ``path`` never see a partially written file. Output is explicitly
    UTF-8 because ``ensure_ascii=False`` emits raw Indic text, which a non-UTF-8
    locale default encoding would fail on.
    """
    tmp_path = path.with_name(path.name + '.tmp')
    with open(tmp_path, 'w', encoding='utf-8') as f:
        json.dump(obj, f, indent=indent, ensure_ascii=False)
    os.replace(tmp_path, path)


def _real_time_factor(inference_time: float, total_audio_sec: float):
    """Return inference seconds per audio second, or None when there is no audio."""
    return inference_time / total_audio_sec if total_audio_sec > 0 else None


def _print_summary(metrics: dict) -> None:
    """Print the overall scores and a per-language table to stdout."""
    o = metrics['__overall__']
    print(f"\n{'='*80}")
    print(f"OVERALL: wer_norm={o['wer_norm']:.2f}%  space_norm={o['space_norm_wer']:.2f}%  "
          f"mer={o['mer']:.2f}%  cer_norm={o['cer_norm']:.2f}%")
    print(f"{'='*80}")

    print(f"\n{'Language':<15} {'WER Raw':>8} {'WER Norm':>9} {'SpaceNorm':>10} {'MER':>7} {'CER Norm':>9}")
    print('-' * 65)
    for lang in sorted(k for k in metrics if not k.startswith('_')):
        m = metrics[lang]
        print(f"{lang:<15} {m['wer_raw']:>7.2f}% {m['wer_norm']:>8.2f}% {m['space_norm_wer']:>9.2f}% {m['mer']:>6.2f}% {m['cer_norm']:>8.2f}%")


if __name__ == "__main__":
    # Script entry point; main() returns None, so success exits with status 0.
    sys.exit(main())
