#!/usr/bin/env python3
"""
Benchmark Gemma3n-E2B-ASR checkpoints on indic-asr-benchmark-6k dataset.

Usage:
    python benchmark_gemma3n_asr.py \
        --checkpoint /home/ubuntu/training/checkpoints/gemma3n-e2b-ckpt-10000 \
        --checkpoint-name ckpt-10000 \
        --output-dir /home/ubuntu/training/benchmark_results_gemma3n \
        --batch-size 1

Metrics computed per language and overall:
    - WER  (Word Error Rate)
    - CER  (Character Error Rate)
    - WER-norm (WER after Unicode NFC, lowercasing, punctuation/symbol removal
      and whitespace collapsing)
    - CER-norm (CER after the same normalization)
"""

import argparse
import json
import os
import re
import sys
import time
import unicodedata
from collections import defaultdict
from pathlib import Path

import numpy as np
import torch
import torchaudio
from datasets import load_dataset as hf_load_dataset
from jiwer import cer, wer
from transformers import AutoModelForImageTextToText, AutoProcessor

# Dataset ``lang_code`` value -> language name inserted into the transcription prompt.
# Codes missing from this table fall back to the sample's capitalized ``language`` field.
LANG_CODE_TO_NAME = {
    "as": "Assamese",
    "bn": "Bengali",
    "en": "English",
    "gu": "Gujarati",
    "hi": "Hindi",
    "kn": "Kannada",
    "ml": "Malayalam",
    "mr": "Marathi",
    "or": "Odia",
    "pa": "Punjabi",
    "ta": "Tamil",
    "te": "Telugu",
}


# ── Text normalization (standard ASR eval practice) ──────────────────────

def _is_kept_char(ch: str) -> bool:
    """Return True if *ch* survives the punctuation/symbol stripping in normalize_text."""
    # isalnum() / "_" / isspace() are exactly what the regex classes \w and \s accept
    # for str patterns. Combining marks (Unicode category M*) must be added explicitly:
    # \w does not match them, yet Indic vowel signs, anusvara, virama and nukta are marks.
    return (
        ch.isalnum()
        or ch == "_"
        or ch.isspace()
        or unicodedata.category(ch).startswith("M")
    )


def normalize_text(text: str) -> str:
    """Normalize text for fair WER/CER comparison.

    Unicode NFC -> lowercase -> remove punctuation/symbols -> collapse whitespace -> strip.

    Word characters (alphanumerics and underscore), whitespace and combining marks
    are kept; everything else (punctuation, symbols, format characters such as
    ZWJ/ZWNJ, control characters) is removed. Keeping marks matters for Indic
    scripts: a plain word/space regex filter drops vowel signs and viramas, so
    "हिंदी" would be scored as "हद".

    Args:
        text: Raw reference or hypothesis string.

    Returns:
        The normalized string; empty if nothing but punctuation/whitespace remained.
    """
    text = unicodedata.normalize("NFC", text)
    text = text.lower()
    text = "".join(ch for ch in text if _is_kept_char(ch))
    text = re.sub(r"\s+", " ", text).strip()
    return text


def load_audio_array(sample, target_sr=16000):
    """Return ``sample["audio"]`` as a float32 numpy waveform at *target_sr* Hz.

    The audio dict's ``array`` is copied into a new float32 array. When its
    ``sampling_rate`` differs from *target_sr*, the waveform is resampled with
    ``torchaudio.functional.resample``.
    """
    audio_field = sample["audio"]
    waveform = np.array(audio_field["array"], dtype=np.float32)
    native_sr = audio_field["sampling_rate"]

    if native_sr == target_sr:
        return waveform

    # resample operates on the last dimension; add a leading channel axis and drop it after.
    resampled = torchaudio.functional.resample(
        torch.from_numpy(waveform).unsqueeze(0), native_sr, target_sr
    )
    return resampled.squeeze(0).numpy()


def _build_batch_inputs(processor, batch):
    """Collect 16 kHz waveforms, chat-formatted prompts and result metadata for one dataset slice."""
    wavs, prompts, batch_meta = [], [], []
    for sample in batch:
        lang_name = LANG_CODE_TO_NAME.get(sample["lang_code"], sample["language"].capitalize())
        user_content = f"<audio_soft_token>\nTranscribe the audio.\nLanguage: {lang_name}"
        messages = [{"role": "user", "content": user_content}]

        wavs.append(load_audio_array(sample))
        prompts.append(processor.apply_chat_template(
            messages, add_generation_prompt=True, tokenize=False
        ))
        batch_meta.append({
            "id": sample["id"],
            "language": sample["language"],
            "lang_code": sample["lang_code"],
            "source": sample["source"],
            "duration": sample["duration"],
            "reference": sample["reference"],
        })
    return wavs, prompts, batch_meta


def _generate_transcripts(model, processor, prompts, wavs, max_new_tokens, **processor_kwargs):
    """Greedy-decode one stripped transcript per (prompt, waveform) pair."""
    inputs = processor(
        text=prompts, audio=wavs, sampling_rate=16000,
        return_tensors="pt", **processor_kwargs,
    )
    inputs = {k: v.to(model.device) if isinstance(v, torch.Tensor) else v
              for k, v in inputs.items()}
    # Generated tokens start after the (padded) prompt. Slicing at one shared offset
    # assumes left padding — main() loads the tokenizer with padding_side="left".
    prompt_len = inputs["input_ids"].shape[1]

    with torch.no_grad():
        output_ids = model.generate(
            **inputs, max_new_tokens=max_new_tokens, do_sample=False,
        )

    return [
        processor.tokenizer.decode(row[prompt_len:], skip_special_tokens=True).strip()
        for row in output_ids
    ]


def run_inference(model, processor, dataset, batch_size: int = 16, max_new_tokens: int = 256):
    """Run batched greedy transcription on an HF dataset and return one result dict per sample.

    Each slice of ``batch_size`` samples is generated in a single ``model.generate``
    call. If anything in that step fails (e.g. out of memory), the slice is retried
    one sample at a time, and a sample that still fails gets an empty hypothesis.
    Results for a slice are only committed once all its hypotheses exist, so every
    sample yields exactly one result and none is duplicated.

    Args:
        model: Generation model exposing ``.device`` and ``.generate``.
        processor: Processor with ``apply_chat_template``, ``__call__`` and ``.tokenizer``.
        dataset: HF dataset with ``audio``, ``id``, ``language``, ``lang_code``,
            ``source``, ``duration`` and ``reference`` columns.
        batch_size: Samples per ``generate`` call; must be >= 1.
        max_new_tokens: Generation length cap.

    Returns:
        List of dicts with keys id, language, lang_code, source, duration, reference
        and hypothesis, in dataset order.

    Raises:
        ValueError: If ``batch_size`` < 1.
    """
    if batch_size < 1:
        raise ValueError(f"batch_size must be >= 1, got {batch_size}")

    results = []
    total = len(dataset)

    for start in range(0, total, batch_size):
        end = min(start + batch_size, total)
        batch = dataset.select(range(start, end))
        wavs, prompts, batch_meta = _build_batch_inputs(processor, batch)

        try:
            hypotheses = _generate_transcripts(
                model, processor, prompts, wavs, max_new_tokens, padding=True,
            )
            if len(hypotheses) != len(batch_meta):
                raise RuntimeError(
                    f"expected {len(batch_meta)} outputs, got {len(hypotheses)}"
                )
        except Exception as e:
            print(f"  Batch {start}-{end} FAILED ({e}), falling back to sequential...", flush=True)
            # Free cached blocks from the failed batch so single-sample retries have headroom.
            if torch.cuda.is_available():
                torch.cuda.empty_cache()
            hypotheses = []
            # Reuse the already-decoded wavs/prompts instead of iterating `batch` again.
            for offset, (prompt, wav, meta) in enumerate(zip(prompts, wavs, batch_meta)):
                try:
                    hypothesis = _generate_transcripts(
                        model, processor, [prompt], [wav], max_new_tokens,
                    )[0]
                except Exception as e2:
                    print(f"  Sample {start+offset} ({meta['id']}) FAILED: {e2}", flush=True)
                    hypothesis = ""
                hypotheses.append(hypothesis)

        results.extend(
            {**meta, "hypothesis": hyp} for meta, hyp in zip(batch_meta, hypotheses)
        )

        progress_pct = end / total * 100
        print(f"  [{end}/{total}] ({progress_pct:.0f}%) processed", flush=True)

    return results


def _error_rates(refs: list, hyps: list, refs_norm: list, hyps_norm: list) -> dict:
    """Return corpus-level WER/CER (raw and normalized) in percent, rounded to 2 decimals.

    Every rate is NaN when *refs* is empty, so jiwer is never called on an empty corpus.
    """
    if not refs:
        nan = float("nan")
        return {"wer": nan, "cer": nan, "wer_normalized": nan, "cer_normalized": nan}
    return {
        "wer": round(wer(refs, hyps) * 100, 2),
        "cer": round(cer(refs, hyps) * 100, 2),
        "wer_normalized": round(wer(refs_norm, hyps_norm) * 100, 2),
        "cer_normalized": round(cer(refs_norm, hyps_norm) * 100, 2),
    }


def compute_metrics(results: list) -> dict:
    """Compute WER, CER, normalized-WER/CER per language and overall.

    Args:
        results: Dicts with at least ``language`` and ``reference``. ``hypothesis``
            may be missing or None; either is treated as an empty transcript.

    Returns:
        ``{<language>: {...}, "__overall__": {...}, "__macro_avg__": {...}}`` with
        rates in percent. Samples whose reference normalizes to an empty string are
        skipped entirely. ``__overall__`` pools all kept samples (corpus-level).
        ``__macro_avg__`` is the unweighted mean of the per-language rates. When no
        sample is scorable, the overall and macro rates are NaN instead of raising.
    """
    by_lang = defaultdict(lambda: {"refs": [], "hyps": [], "refs_norm": [], "hyps_norm": []})

    for r in results:
        ref = r["reference"]
        hyp = r.get("hypothesis") or ""

        ref_norm = normalize_text(ref)
        if not ref_norm:
            # jiwer rejects empty references; such samples cannot be scored.
            continue

        d = by_lang[r["language"]]
        d["refs"].append(ref)
        d["hyps"].append(hyp)
        d["refs_norm"].append(ref_norm)
        d["hyps_norm"].append(normalize_text(hyp))

    metrics = {}
    all_refs, all_hyps = [], []
    all_refs_norm, all_hyps_norm = [], []
    all_empty = 0

    for lang in sorted(by_lang):
        d = by_lang[lang]
        empty_count = sum(1 for h in d["hyps"] if not h.strip())
        metrics[lang] = {
            "n_samples": len(d["refs"]),
            **_error_rates(d["refs"], d["hyps"], d["refs_norm"], d["hyps_norm"]),
            "empty_hypotheses": empty_count,
        }

        all_refs.extend(d["refs"])
        all_hyps.extend(d["hyps"])
        all_refs_norm.extend(d["refs_norm"])
        all_hyps_norm.extend(d["hyps_norm"])
        all_empty += empty_count

    # Overall (corpus-level): errors pooled across languages, so larger languages weigh more.
    metrics["__overall__"] = {
        "n_samples": len(all_refs),
        **_error_rates(all_refs, all_hyps, all_refs_norm, all_hyps_norm),
        "empty_hypotheses": all_empty,
    }

    # Macro average across languages: each language counts equally.
    lang_keys = [k for k in metrics if not k.startswith("__")]
    macro = {"n_languages": len(lang_keys)}
    for field in ("wer", "cer", "wer_normalized", "cer_normalized"):
        values = [metrics[k][field] for k in lang_keys]
        macro[field] = round(float(np.mean(values)), 2) if values else float("nan")
    metrics["__macro_avg__"] = macro

    return metrics


def print_metrics_table(metrics: dict, checkpoint_name: str):
    """Print per-language, overall and macro-averaged error rates as a fixed-width table.

    Language rows come from every key of *metrics* not starting with ``__``, in
    sorted order. They are followed by the ``__overall__`` row (labelled OVERALL)
    and the ``__macro_avg__`` row (labelled MACRO-AVG, whose count column shows
    the number of languages).
    """
    rule = "=" * 85
    divider = " ".join("-" * width for width in (14, 5, 8, 8, 8, 8))

    def format_row(label, count, stats):
        return (f"{label:<14} {count:>5} {stats['wer']:>8.2f} {stats['cer']:>8.2f} "
                f"{stats['wer_normalized']:>8.2f} {stats['cer_normalized']:>8.2f}")

    print(f"\n{rule}")
    print(f"  BENCHMARK RESULTS: {checkpoint_name}")
    print(rule)
    print(f"{'Language':<14} {'N':>5} {'WER%':>8} {'CER%':>8} {'WER-n%':>8} {'CER-n%':>8}")
    print(divider)

    language_keys = sorted(key for key in metrics if not key.startswith("__"))
    for language in language_keys:
        stats = metrics[language]
        print(format_row(language, stats["n_samples"], stats))

    print(divider)

    overall = metrics["__overall__"]
    print(format_row("OVERALL", overall["n_samples"], overall))

    macro = metrics["__macro_avg__"]
    print(format_row("MACRO-AVG", macro["n_languages"], macro))
    print(f"{rule}\n")


def save_sample_analysis(results: list, output_dir: Path, n_per_lang: int = 5):
    """Write the first ``n_per_lang`` predictions of each language to sample_analysis.json.

    Languages appear in sorted order. Within a language, entries keep their order
    from *results*. Each entry holds the raw and normalized reference/hypothesis
    for side-by-side manual inspection.
    """
    grouped = {}
    for record in results:
        grouped.setdefault(record["language"], []).append(record)

    picked = [
        {
            "id": rec["id"],
            "language": rec["language"],
            "reference": rec["reference"],
            "hypothesis": rec["hypothesis"],
            "ref_normalized": normalize_text(rec["reference"]),
            "hyp_normalized": normalize_text(rec["hypothesis"]),
        }
        for language in sorted(grouped)
        for rec in grouped[language][:n_per_lang]
    ]

    out_path = output_dir / "sample_analysis.json"
    with open(out_path, "w", encoding="utf-8") as fh:
        json.dump(picked, fh, ensure_ascii=False, indent=2)
    print(f"Sample analysis saved: {out_path}")


def _load_processor(checkpoint: str):
    """Load the left-padded tokenizer and the processor for *checkpoint*.

    Gemma3nProcessor expects special-token attributes on the tokenizer (added in
    transformers 5.x). Missing ones are patched in from the model config, or from
    fixed fallback values, for compatibility with transformers 4.x.
    """
    # Local import: only needed for this version-mismatch workaround.
    from transformers import AutoConfig, AutoTokenizer

    config = AutoConfig.from_pretrained(checkpoint, trust_remote_code=True)
    # Left padding: batched decoding slices generated tokens at the padded prompt length.
    tokenizer = AutoTokenizer.from_pretrained(checkpoint, trust_remote_code=True, padding_side="left")

    _token_patches = {
        "audio_token_id": getattr(config, "audio_token_id", 262273),
        "audio_token": "<audio_soft_token>",
        "boa_token_id": getattr(config, "boa_token_id", 256000),
        "boa_token": "<start_of_audio>",
        "eoa_token_id": getattr(config, "eoa_token_id", 262272),
        "eoa_token": "<end_of_audio>",
        "boi_token_id": getattr(config, "boi_token_id", 255999),
        "boi_token": "<start_of_image>",
        "eoi_token_id": getattr(config, "eoi_token_id", 262144),
        "eoi_token": "<end_of_image>",
        "image_token_id": getattr(config, "image_token_id", 262145),
        "image_token": "<image_soft_token>",
    }
    for attr, val in _token_patches.items():
        if not hasattr(tokenizer, attr):
            setattr(tokenizer, attr, val)

    return AutoProcessor.from_pretrained(
        checkpoint, trust_remote_code=True, tokenizer=tokenizer
    )


def _load_model(checkpoint: str, device: str):
    """Load *checkpoint* in bfloat16 on *device*, in eval mode with the KV cache enabled."""
    model = AutoModelForImageTextToText.from_pretrained(
        checkpoint,
        torch_dtype=torch.bfloat16,
        device_map=device,
        trust_remote_code=True,
        attn_implementation="sdpa",
    )
    model.eval()
    # Enable KV cache for faster autoregressive generation (training disables it)
    model.config.use_cache = True
    if hasattr(model.config, "text_config"):
        model.config.text_config.use_cache = True
    return model


def _load_eval_dataset(dataset_name: str, token, cache_dir: str, languages):
    """Load the dataset's ``train`` split, optionally keeping only the given lang codes."""
    dataset = hf_load_dataset(
        dataset_name, token=token, cache_dir=cache_dir, split="train"
    )
    if languages:
        wanted = set(languages)
        # input_columns passes only lang_code to the predicate, so the audio column
        # is not decoded for every row just to evaluate the filter.
        dataset = dataset.filter(lambda code: code in wanted, input_columns="lang_code")
        print(f"Filtered to {len(dataset)} samples for languages: {languages}")
    return dataset


def main():
    """CLI entry point: benchmark one checkpoint and write results to disk.

    Writes predictions.json, metrics.json and sample_analysis.json under
    ``<output-dir>/<checkpoint-name>/`` and prints a metrics table.
    """
    parser = argparse.ArgumentParser(description="Benchmark Gemma3n-E2B-ASR on Indic ASR dataset")
    parser.add_argument("--checkpoint", required=True, help="Path to model checkpoint directory")
    parser.add_argument("--checkpoint-name", default=None,
                        help="Name for this checkpoint in output files (default: dirname)")
    parser.add_argument("--hf-dataset", default="BayAreaBoys/indic-asr-benchmark-6k",
                        help="HuggingFace dataset name")
    parser.add_argument("--hf-token", default=None, help="HuggingFace token (or set HF_TOKEN env)")
    parser.add_argument("--cache-dir", default="/home/ubuntu/training/datasets/indic-asr-benchmark-6k")
    parser.add_argument("--output-dir", default="/home/ubuntu/training/benchmark_results_gemma3n")
    parser.add_argument("--batch-size", type=int, default=16,
                        help="Samples per batch for offline batched inference")
    parser.add_argument("--max-new-tokens", type=int, default=256)
    parser.add_argument("--device", default="cuda:0")
    parser.add_argument("--languages", nargs="+", default=None,
                        help="Filter to specific language codes (e.g. hi en ta)")
    args = parser.parse_args()
    if args.batch_size < 1:
        # Fail fast, before the expensive model load.
        parser.error("--batch-size must be >= 1")

    ckpt_name = args.checkpoint_name or Path(args.checkpoint).name
    output_dir = Path(args.output_dir) / ckpt_name
    output_dir.mkdir(parents=True, exist_ok=True)

    # Load model and processor
    print(f"Loading checkpoint: {args.checkpoint}")
    t0 = time.time()
    processor = _load_processor(args.checkpoint)
    model = _load_model(args.checkpoint, args.device)
    load_time = time.time() - t0
    print(f"Model loaded in {load_time:.1f}s")

    # Load dataset
    hf_token = args.hf_token or os.environ.get("HF_TOKEN")
    print(f"Loading dataset: {args.hf_dataset}")
    dataset = _load_eval_dataset(args.hf_dataset, hf_token, args.cache_dir, args.languages)
    print(f"Dataset: {len(dataset)} samples")

    # Run inference
    print(f"\nRunning inference (batch_size={args.batch_size})...")
    t0 = time.time()
    results = run_inference(model, processor, dataset, args.batch_size, args.max_new_tokens)
    inference_time = time.time() - t0

    total_audio_sec = sum(r["duration"] for r in results)
    # Real-time factor: compute seconds per second of audio (lower is faster).
    rtf = inference_time / total_audio_sec if total_audio_sec > 0 else 0
    print(f"Inference done in {inference_time:.1f}s for {total_audio_sec:.0f}s audio (RTF={rtf:.4f})")

    # Save raw predictions
    preds_file = output_dir / "predictions.json"
    with open(preds_file, "w", encoding="utf-8") as f:
        json.dump(results, f, ensure_ascii=False, indent=2)
    print(f"Predictions saved: {preds_file}")

    # Compute metrics
    metrics = compute_metrics(results)
    metrics["__meta__"] = {
        "checkpoint": args.checkpoint,
        "checkpoint_name": ckpt_name,
        "model_type": "gemma3n-E2B-asr",
        "dataset": args.hf_dataset,
        "batch_size": args.batch_size,
        "max_new_tokens": args.max_new_tokens,
        "languages": args.languages,
        "inference_time_sec": round(inference_time, 2),
        "total_audio_sec": round(total_audio_sec, 2),
        "rtf": round(rtf, 4),
    }

    metrics_file = output_dir / "metrics.json"
    # Explicit UTF-8: ensure_ascii=False may emit non-ASCII (e.g. checkpoint paths).
    with open(metrics_file, "w", encoding="utf-8") as f:
        json.dump(metrics, f, indent=2, ensure_ascii=False)
    print(f"Metrics saved: {metrics_file}")

    # Print table
    print_metrics_table(metrics, ckpt_name)

    # Save sample analysis
    save_sample_analysis(results, output_dir)


if __name__ == "__main__":
    # Run the benchmark only when executed as a script, not when imported.
    main()
