#!/usr/bin/env python3
"""
Benchmark Qwen3-ASR on Kathbath (valid) + Svarah datasets.

Loads parquets directly, decodes audio bytes via soundfile, runs batched
inference, then writes predictions.json + metrics.json + sample_analysis.json
+ error_analysis.json under benchmark_outputs/<dataset>/<ckpt-name>/.
"""

import argparse
import io
import json
import os
import sys
import time
from collections import defaultdict
from pathlib import Path

# vLLM defaults to forking worker processes, which can break CUDA
# re-initialization in children; request "spawn" before any vLLM-backed
# import reads the variable.
os.environ.setdefault("VLLM_WORKER_MULTIPROC_METHOD", "spawn")

import numpy as np
import pyarrow.parquet as pq
import soundfile as sf

sys.path.insert(0, "/home/ubuntu/training/qwen3-asr-1.7b-phase2-sft")
sys.path.insert(0, "/home/ubuntu/training")

from benchmark_maya_asr_tdt import (
    compute_metrics_for_samples, build_sample_analysis, build_error_analysis,
)

# lang folder name → (display_name, lang_code) for Qwen3-ASR
KATHBATH_LANG_MAP = {
    "bengali":   ("Bengali",   "bn"),
    "gujarati":  ("Gujarati",  "gu"),
    "hindi":     ("Hindi",     "hi"),
    "kannada":   ("Kannada",   "kn"),
    "malayalam": ("Malayalam", "ml"),
    "marathi":   ("Marathi",   "mr"),
    "odia":      ("Odia",      "or"),
    "punjabi":   ("Punjabi",   "pa"),
    "tamil":     ("Tamil",     "ta"),
    "telugu":    ("Telugu",    "te"),
}
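# IndicVoices reuses the same folder-name convention (see INDIC_LANG_MAP in
# load_indicvoices_valid below) and adds Assamese on top of these ten languages.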


def load_kathbath_valid(root: Path):
    """Yield samples (dict) from all kathbath valid parquets."""
    samples = []
    files = sorted(root.glob("*-valid-*.parquet"))
    for f in files:
        # filename format: "<lang>-valid-NNNNN-of-NNNNN.parquet"
        lang_folder = f.name.split("-valid-")[0]
        if lang_folder not in KATHBATH_LANG_MAP:
            continue
        lang_name, lang_code = KATHBATH_LANG_MAP[lang_folder]
        table = pq.read_table(f)
        df = table.to_pandas()
        for _, row in df.iterrows():
            audio_bytes = row["audio_filepath"]["bytes"]
            samples.append({
                "id": row["fname"],
                "language": lang_name,
                "lang_code": lang_code,
                "source": "kathbath",
                "duration": float(row["duration"]),
                "reference": row["text"],
                "_audio_bytes": audio_bytes,
            })
        print(f"  loaded {lang_folder}: {len(df)} samples")
    return samples


def load_indicvoices_valid(root: Path):
    """Load samples (list of dicts) from IndicVoices valid parquets (one per language)."""
    INDIC_LANG_MAP = {
        "assamese":  ("Assamese",  "as"),
        "bengali":   ("Bengali",   "bn"),
        "gujarati":  ("Gujarati",  "gu"),
        "hindi":     ("Hindi",     "hi"),
        "kannada":   ("Kannada",   "kn"),
        "malayalam": ("Malayalam", "ml"),
        "marathi":   ("Marathi",   "mr"),
        "odia":      ("Odia",      "or"),
        "punjabi":   ("Punjabi",   "pa"),
        "tamil":     ("Tamil",     "ta"),
        "telugu":    ("Telugu",    "te"),
    }
    samples = []
    files = sorted(root.glob("*-valid.parquet"))
    for f in files:
        lang_folder = f.stem.replace("-valid", "")
        if lang_folder not in INDIC_LANG_MAP:
            continue
        lang_name, lang_code = INDIC_LANG_MAP[lang_folder]
        df = pq.read_table(f).to_pandas()
        for i, row in df.iterrows():
            # Prefer the normalized transcript, falling back to raw text.
            # NaN is truthy, so a bare `a or b` chain would let missing
            # (NaN) values slip through; guard with isinstance instead.
            ref = row.get("normalized")
            if not isinstance(ref, str) or not ref:
                ref = row.get("text") if isinstance(row.get("text"), str) else ""
            samples.append({
                "id": f"{lang_folder}-{i}",
                "language": lang_name,
                "lang_code": lang_code,
                "source": "indicvoices",
                "duration": float(row["duration"]),
                "reference": ref,
                "_audio_bytes": row["audio_filepath"]["bytes"],
            })
        print(f"  loaded {lang_folder}: {len(df)} samples")
    return samples


def load_svarah(root: Path):
    """Load all svarah test parquets."""
    samples = []
    files = sorted(root.glob("test-*.parquet"))
    for f in files:
        df = pq.read_table(f).to_pandas()
        for i, row in df.iterrows():
            samples.append({
                "id": row["audio_filepath"]["path"] or f"svarah-{f.stem}-{i}",
                "language": "English",
                "lang_code": "en",
                "source": "svarah",
                "duration": float(row["duration"]),
                "reference": row["text"],
                "_audio_bytes": row["audio_filepath"]["bytes"],
            })
        print(f"  loaded {f.name}: {len(df)} samples")
    return samples


def decode_audio(audio_bytes: bytes):
    wav, sr = sf.read(io.BytesIO(audio_bytes), dtype="float32")
    if wav.ndim > 1:
        wav = wav.mean(axis=1)  # downmix multi-channel audio to mono
    if sr != 16000:
        # Linear-interpolation resample to 16 kHz: crude, but adequate for an
        # ASR front-end (see the note after this function for alternatives).
        n_new = int(round(len(wav) * 16000 / sr))
        wav = np.interp(
            np.linspace(0, len(wav) - 1, n_new), np.arange(len(wav)), wav
        ).astype(np.float32)
        sr = 16000
    return wav, sr
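
# Note on decode_audio: if linear interpolation proves too lossy (e.g. for
# 8 kHz telephony audio), a polyphase resampler is a near drop-in swap; a
# sketch assuming scipy is installed:
#
#   from fractions import Fraction
#   from scipy.signal import resample_poly
#
#   frac = Fraction(16000, sr)  # exact up/down ratio, auto-reduced
#   wav = resample_poly(wav, frac.numerator, frac.denominator).astype(np.float32)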


def run_inference(model, samples, batch_size: int):
    results = []
    total = len(samples)
    for start in range(0, total, batch_size):
        end = min(start + batch_size, total)
        batch = samples[start:end]
        audio_inputs = []
        forced_langs = []
        for s in batch:
            wav, sr = decode_audio(s["_audio_bytes"])
            audio_inputs.append((wav, sr))
            forced_langs.append(s["language"])
        try:
            transcriptions = model.transcribe(
                audio=audio_inputs, language=forced_langs, return_time_stamps=False,
            )
        except Exception as e:
            print(f"  Batch {start}-{end} failed ({e}); retrying without forced language")
            try:
                transcriptions = model.transcribe(
                    audio=audio_inputs, language=None, return_time_stamps=False,
                )
            except Exception as e2:
                print(f"  Batch {start}-{end} failed again ({e2}); falling back to one-by-one")
                from qwen_asr.inference.qwen3_asr import ASRTranscription
                transcriptions = []
                for ai in audio_inputs:
                    try:
                        transcriptions.extend(
                            model.transcribe(audio=ai, language=None, return_time_stamps=False)
                        )
                    except Exception:
                        transcriptions.append(ASRTranscription(language="", text=""))
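        # Failed samples received an empty ASRTranscription placeholder above,
        # so `batch` and `transcriptions` stay the same length for this zip.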
        for s, t in zip(batch, transcriptions):
            results.append({
                "id": s["id"],
                "language": s["language"],
                "lang_code": s["lang_code"],
                "source": s["source"],
                "duration": s["duration"],
                "reference": s["reference"],
                "hypothesis": t.text,
                "detected_language": t.language,
            })
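        # Log progress every 4 batches, and always on the final batch.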
        if end % (batch_size * 4) == 0 or end == total:
            print(f"  [{end}/{total}] ({end/total*100:.0f}%)", flush=True)
    return results


def main():
    p = argparse.ArgumentParser()
    p.add_argument("--checkpoint", required=True)
    p.add_argument("--checkpoint-name", required=True)
    p.add_argument("--dataset", choices=["kathbath", "svarah", "indicvoices"], required=True)
    p.add_argument("--data-root", required=True)
    p.add_argument("--output-base", default="/home/ubuntu/training/benchmark_outputs")
    p.add_argument("--batch-size", type=int, default=128)
    p.add_argument("--max-new-tokens", type=int, default=512)
    p.add_argument("--device", default="cuda:0")
    args = p.parse_args()

    print(f"Loading dataset: {args.dataset} from {args.data_root}")
    if args.dataset == "kathbath":
        samples = load_kathbath_valid(Path(args.data_root))
    elif args.dataset == "indicvoices":
        samples = load_indicvoices_valid(Path(args.data_root))
    else:
        samples = load_svarah(Path(args.data_root))
    print(f"Total: {len(samples)} samples")

    print(f"Loading model: {args.checkpoint}")
    import torch
    from qwen_asr import Qwen3ASRModel
    t0 = time.time()
    model = Qwen3ASRModel.from_pretrained(
        args.checkpoint,
        dtype=torch.bfloat16,
        device_map=args.device,
        max_inference_batch_size=args.batch_size,
        max_new_tokens=args.max_new_tokens,
    )
    print(f"Model loaded in {time.time()-t0:.1f}s")

    print(f"Running inference (batch_size={args.batch_size})...")
    t0 = time.time()
    results = run_inference(model, samples, args.batch_size)
    inference_time = time.time() - t0
    total_audio = sum(r["duration"] for r in results)
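    # Real-time factor: seconds of compute per second of audio
    # (lower is better; < 1.0 means faster than real time).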
    rtf = inference_time / total_audio if total_audio > 0 else 0
    print(f"Inference: {inference_time:.0f}s for {total_audio:.0f}s audio (RTF={rtf:.4f})")

    out_dir = Path(args.output_base) / f"qwen3-asr-mixed-v2-{args.dataset}" / args.checkpoint_name
    out_dir.mkdir(parents=True, exist_ok=True)

    metrics = compute_metrics_for_samples(results)
    from datetime import datetime, timezone
    metrics["__meta__"] = {
        "checkpoint": args.checkpoint,
        "checkpoint_name": args.checkpoint_name,
        "model_id": f"qwen3-asr-mixed-v2-{args.dataset}",
        "model_type": "qwen3-asr-1.7B-mixed-v2",
        "dataset": f"ai4bharat/{args.dataset.capitalize()}",
        "batch_size": args.batch_size,
        "inference_time_sec": round(inference_time, 2),
        "total_audio_sec": round(total_audio, 2),
        "rtf": round(rtf, 4),
        "timestamp": datetime.now(timezone.utc).strftime("%Y-%m-%dT%H:%M:%SZ"),
        "framework": "transformers",
        "normalization_version": "v1",
    }
    with open(out_dir / "metrics.json", "w") as f:
        json.dump(metrics, f, indent=2, ensure_ascii=False)
    with open(out_dir / "predictions.json", "w") as f:
        json.dump(results, f, ensure_ascii=False, indent=2)
    with open(out_dir / "sample_analysis.json", "w") as f:
        json.dump(build_sample_analysis(results), f, ensure_ascii=False, indent=2)
    with open(out_dir / "error_analysis.json", "w") as f:
        json.dump(build_error_analysis(results), f, ensure_ascii=False, indent=2)

    o = metrics["__overall__"]
    print(f"\nResults: wer_raw={o['wer_raw']:.2f}%  wer_norm={o['wer_norm']:.2f}%  "
          f"space_norm={o['space_norm_wer']:.2f}%  mer={o['mer']:.2f}%  cer_norm={o['cer_norm']:.2f}%")
    print(f"Per-language:")
    for lang in sorted(k for k in metrics if not k.startswith("_")):
        m = metrics[lang]
        print(f"  {lang:12} n={m['n_samples']:5} wer_raw={m['wer_raw']:6.2f}  wer_norm={m['wer_norm']:6.2f}  "
              f"space_norm={m['space_norm_wer']:6.2f}  mer={m['mer']:5.2f}  cer_norm={m['cer_norm']:5.2f}")
    print(f"\nWritten to {out_dir}/")


if __name__ == "__main__":
    main()
