#!/usr/bin/env python3
"""
Benchmark Qwen3-ASR checkpoints on the indic-asr-benchmark-6k dataset.

Usage:
    python benchmark_qwen3_asr.py \
        --checkpoint /home/ubuntu/training/checkpoints/qwen3-asr-ckpt-24000 \
        --checkpoint-name ckpt-24000 \
        --output-dir /home/ubuntu/training/benchmark_results \
        --batch-size 8

Metrics computed per language and overall:
    - WER  (Word Error Rate)
    - CER  (Character Error Rate)
    - WER-norm (WER after unicode/punctuation/space normalization)
    - CER-norm (CER after normalization)
"""

import argparse
import json
import os
import re
import sys
import time
import unicodedata
from collections import defaultdict
from pathlib import Path

# vLLM v1 (0.16+) forks a subprocess for the engine core, and CUDA must not
# be initialized in the parent process before that fork.  Setting the env var
# up-front tells vLLM to use 'spawn' instead of 'fork', which avoids the
# "Cannot re-initialize CUDA in forked subprocess" crash.
os.environ.setdefault("VLLM_WORKER_MULTIPROC_METHOD", "spawn")

import numpy as np
from datasets import load_dataset as hf_load_dataset
from jiwer import cer, wer

# Add the repo to path for local imports
sys.path.insert(0, "/home/ubuntu/training/qwen3-asr-1.7b-phase2-sft")

# Language code → Qwen3-ASR language name mapping
LANG_CODE_TO_NAME = {
    "as": "Assamese", "bn": "Bengali", "en": "English", "gu": "Gujarati",
    "hi": "Hindi", "kn": "Kannada", "ml": "Malayalam", "mr": "Marathi",
    "or": "Odia", "pa": "Punjabi", "ta": "Tamil", "te": "Telugu",
}

# Languages the model supports (base + Indic from Phase 2 fine-tuning)
_MODEL_SUPPORTED_LANGS = {
    "Chinese", "English", "Cantonese", "Arabic", "German", "French", "Spanish",
    "Portuguese", "Indonesian", "Italian", "Korean", "Russian", "Thai", "Vietnamese",
    "Japanese", "Turkish", "Hindi", "Malay", "Dutch", "Swedish", "Danish", "Finnish",
    "Polish", "Czech", "Filipino", "Persian", "Greek", "Romanian", "Hungarian", "Macedonian",
    # Indic languages (Phase 2 fine-tuning)
    "Assamese", "Bengali", "Gujarati", "Kannada", "Malayalam",
    "Marathi", "Odia", "Punjabi", "Tamil", "Telugu",
}


# ── Text normalization (standard ASR eval practice) ──────────────────────

def normalize_text(text: str) -> str:
    """Normalize text for fair WER/CER comparison.
    Unicode NFC → lowercase → remove punctuation → collapse whitespace → strip.
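
    Example:
        >>> normalize_text("  Hello,   World! ")
        'hello world'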
    """
    text = unicodedata.normalize("NFC", text)
    text = text.lower()
    text = re.sub(r"[^\w\s]", "", text, flags=re.UNICODE)
    text = re.sub(r"\s+", " ", text).strip()
    return text


def run_inference(model, dataset, batch_size: int, force_language: bool = True):
    """Run batched inference on HF dataset and return list of result dicts."""
    results = []
    total = len(dataset)

    for start in range(0, total, batch_size):
        end = min(start + batch_size, total)
        batch = dataset.select(range(start, end))

        audio_inputs = []
        for sample in batch:
            audio = sample["audio"]
            wav = np.array(audio["array"], dtype=np.float32)
            sr = audio["sampling_rate"]
            audio_inputs.append((wav, sr))

        # Force the decode language when the model supports it; unknown or
        # unsupported codes fall back to per-sample auto-detection (None).
        if force_language:
            forced_langs = []
            for sample in batch:
                lang_name = LANG_CODE_TO_NAME.get(sample["lang_code"])
                if lang_name not in _MODEL_SUPPORTED_LANGS:
                    lang_name = None
                forced_langs.append(lang_name)
        else:
            forced_langs = None

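        # Three-tier fallback: batched with forced languages, then batched
        # auto-detect, then per-sample auto-detect, so one bad clip cannot
        # sink the whole batch.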
        try:
            transcriptions = model.transcribe(
                audio=audio_inputs,
                language=forced_langs,
                return_time_stamps=False,
            )
        except Exception as e:
            print(f"  Batch {start}-{end} failed ({e}), retrying without forced language...")
            try:
                transcriptions = model.transcribe(
                    audio=audio_inputs, language=None, return_time_stamps=False,
                )
            except Exception as e2:
                print(f"  Batch {start}-{end} FAILED: {e2}, trying one-by-one...")
                transcriptions = []
                for ai in audio_inputs:
                    try:
                        r = model.transcribe(audio=ai, language=None, return_time_stamps=False)
                        transcriptions.extend(r)
                    except Exception as e3:
                        # Keep 1:1 alignment with the batch by emitting an
                        # empty transcription for samples that still fail.
                        print(f"    Sample at batch offset {len(transcriptions)} failed: {e3}")
                        from qwen_asr.inference.qwen3_asr import ASRTranscription
                        transcriptions.append(ASRTranscription(language="", text=""))

        for i, t in enumerate(transcriptions):
            sample = batch[i]
            results.append({
                "id": sample["id"],
                "language": sample["language"],
                "lang_code": sample["lang_code"],
                "source": sample["source"],
                "duration": sample["duration"],
                "reference": sample["reference"],
                "hypothesis": t.text,
                "detected_language": t.language,
            })

        progress_pct = end / total * 100
        print(f"  [{end}/{total}] ({progress_pct:.0f}%) processed", flush=True)

    return results


def compute_metrics(results: list) -> dict:
    """Compute WER, CER, normalized-WER/CER per language and overall."""
    by_lang = defaultdict(lambda: {"refs": [], "hyps": [], "refs_norm": [], "hyps_norm": []})

    for r in results:
        lang = r["language"]
        ref = r["reference"]
        hyp = r.get("hypothesis", "")

        ref_norm = normalize_text(ref)
        hyp_norm = normalize_text(hyp)

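        # Skip samples whose reference is empty after normalization;
        # WER/CER are undefined against an empty reference.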
        if not ref_norm:
            continue

        by_lang[lang]["refs"].append(ref)
        by_lang[lang]["hyps"].append(hyp)
        by_lang[lang]["refs_norm"].append(ref_norm)
        by_lang[lang]["hyps_norm"].append(hyp_norm)

    metrics = {}
    all_refs, all_hyps = [], []
    all_refs_norm, all_hyps_norm = [], []

    for lang in sorted(by_lang.keys()):
        d = by_lang[lang]
        n = len(d["refs"])
        lang_wer = wer(d["refs"], d["hyps"])
        lang_cer = cer(d["refs"], d["hyps"])
        lang_wer_norm = wer(d["refs_norm"], d["hyps_norm"])
        lang_cer_norm = cer(d["refs_norm"], d["hyps_norm"])

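        # Empty hypotheses usually signal decode failures rather than
        # genuine silence; surface them so regressions are visible.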
        empty_count = sum(1 for h in d["hyps"] if not h.strip())

        metrics[lang] = {
            "n_samples": n,
            "wer": round(lang_wer * 100, 2),
            "cer": round(lang_cer * 100, 2),
            "wer_normalized": round(lang_wer_norm * 100, 2),
            "cer_normalized": round(lang_cer_norm * 100, 2),
            "empty_hypotheses": empty_count,
        }

        all_refs.extend(d["refs"])
        all_hyps.extend(d["hyps"])
        all_refs_norm.extend(d["refs_norm"])
        all_hyps_norm.extend(d["hyps_norm"])

    # Overall (corpus-level)
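    # jiwer pools edit counts and reference lengths across all samples here,
    # so languages with more or longer references weigh more.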
    metrics["__overall__"] = {
        "n_samples": len(all_refs),
        "wer": round(wer(all_refs, all_hyps) * 100, 2),
        "cer": round(cer(all_refs, all_hyps) * 100, 2),
        "wer_normalized": round(wer(all_refs_norm, all_hyps_norm) * 100, 2),
        "cer_normalized": round(cer(all_refs_norm, all_hyps_norm) * 100, 2),
    }

    # Macro average across languages
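    # Unweighted mean of per-language scores: each language counts equally,
    # regardless of its sample count.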
    lang_keys = [k for k in metrics if not k.startswith("__")]
    metrics["__macro_avg__"] = {
        "n_languages": len(lang_keys),
        "wer": round(np.mean([metrics[k]["wer"] for k in lang_keys]), 2),
        "cer": round(np.mean([metrics[k]["cer"] for k in lang_keys]), 2),
        "wer_normalized": round(np.mean([metrics[k]["wer_normalized"] for k in lang_keys]), 2),
        "cer_normalized": round(np.mean([metrics[k]["cer_normalized"] for k in lang_keys]), 2),
    }

    return metrics


def print_metrics_table(metrics: dict, checkpoint_name: str):
    """Pretty print metrics as a table."""
    print(f"\n{'='*85}")
    print(f"  BENCHMARK RESULTS: {checkpoint_name}")
    print(f"{'='*85}")
    print(f"{'Language':<14} {'N':>5} {'WER%':>8} {'CER%':>8} {'WER-n%':>8} {'CER-n%':>8}")
    print(f"{'-'*14} {'-'*5} {'-'*8} {'-'*8} {'-'*8} {'-'*8}")

    for lang in sorted(k for k in metrics if not k.startswith("__")):
        m = metrics[lang]
        print(f"{lang:<14} {m['n_samples']:>5} {m['wer']:>8.2f} {m['cer']:>8.2f} "
              f"{m['wer_normalized']:>8.2f} {m['cer_normalized']:>8.2f}")

    print(f"{'-'*14} {'-'*5} {'-'*8} {'-'*8} {'-'*8} {'-'*8}")

    ov = metrics["__overall__"]
    print(f"{'OVERALL':<14} {ov['n_samples']:>5} {ov['wer']:>8.2f} {ov['cer']:>8.2f} "
          f"{ov['wer_normalized']:>8.2f} {ov['cer_normalized']:>8.2f}")

    ma = metrics["__macro_avg__"]
    print(f"{'MACRO-AVG':<14} {ma['n_languages']:>5} {ma['wer']:>8.2f} {ma['cer']:>8.2f} "
          f"{ma['wer_normalized']:>8.2f} {ma['cer_normalized']:>8.2f}")
    print(f"{'='*85}\n")


def save_sample_analysis(results: list, output_dir: Path, n_per_lang: int = 5):
    """Save a few sample predictions per language for manual inspection."""
    by_lang = defaultdict(list)
    for r in results:
        by_lang[r["language"]].append(r)

    samples = []
    for lang in sorted(by_lang.keys()):
        lang_results = by_lang[lang]
        # Pick first n_per_lang samples
        for r in lang_results[:n_per_lang]:
            samples.append({
                "id": r["id"],
                "language": r["language"],
                "reference": r["reference"],
                "hypothesis": r["hypothesis"],
                "detected_language": r.get("detected_language", ""),
                "ref_normalized": normalize_text(r["reference"]),
                "hyp_normalized": normalize_text(r["hypothesis"]),
            })

    path = output_dir / "sample_analysis.json"
    with open(path, "w", encoding="utf-8") as f:
        json.dump(samples, f, ensure_ascii=False, indent=2)
    print(f"Sample analysis saved: {path}")


def main():
    parser = argparse.ArgumentParser(description="Benchmark Qwen3-ASR on Indic ASR dataset")
    parser.add_argument("--checkpoint", required=True, help="Path to model checkpoint directory")
    parser.add_argument("--checkpoint-name", default=None,
                        help="Name for this checkpoint in output files (default: dirname)")
    parser.add_argument("--hf-dataset", default="BayAreaBoys/indic-asr-benchmark-6k",
                        help="HuggingFace dataset name")
    parser.add_argument("--hf-token", default=None, help="HuggingFace token (or set HF_TOKEN env)")
    parser.add_argument("--cache-dir", default="/home/ubuntu/training/datasets/indic-asr-benchmark-6k")
    parser.add_argument("--output-dir", default="/home/ubuntu/training/benchmark_results")
    parser.add_argument("--batch-size", type=int, default=8)
    parser.add_argument("--max-new-tokens", type=int, default=512)
    parser.add_argument("--device", default="cuda:0")
    parser.add_argument("--backend", choices=["transformers", "vllm"], default="vllm",
                        help="Inference backend (default: vllm)")
    parser.add_argument("--gpu-memory-utilization", type=float, default=0.85,
                        help="vLLM GPU memory utilization (default: 0.85)")
    parser.add_argument("--no-force-language", action="store_true",
                        help="Don't force language, let model auto-detect")
    parser.add_argument("--languages", nargs="+", default=None,
                        help="Filter to specific language codes (e.g. hi en ta)")
    args = parser.parse_args()

    ckpt_name = args.checkpoint_name or Path(args.checkpoint).name
    output_dir = Path(args.output_dir) / ckpt_name
    output_dir.mkdir(parents=True, exist_ok=True)

    # Load model
    from qwen_asr import Qwen3ASRModel

    print(f"Loading checkpoint: {args.checkpoint} (backend={args.backend})")
    t0 = time.time()
    if args.backend == "vllm":
        model = Qwen3ASRModel.LLM(
            model=args.checkpoint,
            gpu_memory_utilization=args.gpu_memory_utilization,
            max_inference_batch_size=args.batch_size if args.batch_size > 0 else -1,
            max_new_tokens=args.max_new_tokens,
        )
    else:
        import torch
        model = Qwen3ASRModel.from_pretrained(
            args.checkpoint,
            dtype=torch.bfloat16,
            device_map=args.device,
            max_inference_batch_size=args.batch_size,
            max_new_tokens=args.max_new_tokens,
        )
    load_time = time.time() - t0
    print(f"Model loaded in {load_time:.1f}s")

    # Load dataset
    hf_token = args.hf_token or os.environ.get("HF_TOKEN")
    print(f"Loading dataset: {args.hf_dataset}")
    from datasets import Audio
    dataset = hf_load_dataset(
        args.hf_dataset, token=hf_token, cache_dir=args.cache_dir, split="train"
    )
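    # Decode on access and resample everything to 16 kHz, the input rate the
    # model consumes.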
    dataset = dataset.cast_column("audio", Audio(sampling_rate=16000, decode=True))

    # Filter languages if specified
    if args.languages:
        dataset = dataset.filter(lambda x: x["lang_code"] in args.languages)
        print(f"Filtered to {len(dataset)} samples for languages: {args.languages}")

    print(f"Dataset: {len(dataset)} samples")

    # Run inference
    print(f"\nRunning inference (batch_size={args.batch_size}, "
          f"force_language={not args.no_force_language})...")
    t0 = time.time()
    results = run_inference(model, dataset, args.batch_size,
                            force_language=not args.no_force_language)
    inference_time = time.time() - t0

    total_audio_sec = sum(r["duration"] for r in results)
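    # RTF (real-time factor) = processing time / audio duration; values
    # below 1.0 mean faster than real time.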
    rtf = inference_time / total_audio_sec if total_audio_sec > 0 else 0
    print(f"Inference done in {inference_time:.1f}s for {total_audio_sec:.0f}s audio (RTF={rtf:.4f})")

    # Save raw predictions
    preds_file = output_dir / "predictions.json"
    with open(preds_file, "w", encoding="utf-8") as f:
        json.dump(results, f, ensure_ascii=False, indent=2)
    print(f"Predictions saved: {preds_file}")

    # Compute metrics
    metrics = compute_metrics(results)
    metrics["__meta__"] = {
        "checkpoint": args.checkpoint,
        "checkpoint_name": ckpt_name,
        "backend": args.backend,
        "dataset": args.hf_dataset,
        "batch_size": args.batch_size,
        "force_language": not args.no_force_language,
        "inference_time_sec": round(inference_time, 2),
        "total_audio_sec": round(total_audio_sec, 2),
        "rtf": round(rtf, 4),
    }

    metrics_file = output_dir / "metrics.json"
    with open(metrics_file, "w") as f:
        json.dump(metrics, f, indent=2, ensure_ascii=False)
    print(f"Metrics saved: {metrics_file}")

    # Print table
    print_metrics_table(metrics, ckpt_name)

    # Save sample analysis
    save_sample_analysis(results, output_dir)


if __name__ == "__main__":
    main()
