"""
Benchmark fine-tuned Hindi Nemotron ASR on open-source datasets.
Datasets: FLEURS Hindi (google/fleurs hi_in), IndicVoices Hindi (ai4bharat/indicvoices_r)
Metrics:  WER (Word Error Rate), CER (Character Error Rate), RTF (Real-Time Factor)

Model: BayAreaBoys/nemotron-hindi (fine-tuned from nvidia/nemotron-speech-streaming-en-0.6b)
Architecture: FastConformer Cache-Aware RNNT — expects 16kHz mono audio
"""

import os
import sys
import csv
import json
import time
import types
import tempfile
from math import gcd
from typing import List, Dict, Tuple

import torch
import numpy as np
import soundfile as sf
from scipy.signal import resample_poly
from dotenv import load_dotenv
from huggingface_hub import hf_hub_download
from omegaconf import DictConfig, open_dict
from jiwer import wer, cer
from datasets import load_dataset

import nemo.collections.asr as nemo_asr

load_dotenv()  # pull HF_TOKEN / HF_MODEL from a local .env file, if present

# Hugging Face auth token and model repo id; both come from the environment.
# May be None when unset — hf_hub_download() in load_model() will then fail.
HF_TOKEN = os.getenv("HF_TOKEN")
HF_MODEL = os.getenv("HF_MODEL")
TARGET_SR = 16000  # model expects 16kHz mono audio (see module docstring)
BATCH_SIZE = 16  # batch inference for throughput
BASE_DIR = os.path.dirname(os.path.abspath(__file__))  # directory of this script
RESULTS_DIR = os.path.join(BASE_DIR, "benchmark_results")  # per-dataset CSV/JSON output


# ─────────────────────────────────────────────────────────────────────────────
# Hindi text normalization for fair WER/CER comparison
# ─────────────────────────────────────────────────────────────────────────────
def normalize_hindi(text: str) -> str:
    """Normalize Hindi text for WER/CER: strip punctuation, collapse whitespace, lowercase."""
    import re
    if not text:
        return ""
    # Punctuation (danda, ASCII marks, dashes, quotes, brackets) becomes a
    # space so words separated only by punctuation stay separate tokens.
    punct_pattern = re.compile(r'[।,\.\?\!;:\-\—\–\"\'\(\)\[\]\{\}]')
    whitespace_pattern = re.compile(r'\s+')
    cleaned = punct_pattern.sub(' ', text)
    cleaned = whitespace_pattern.sub(' ', cleaned).strip()
    # Lowercase so any embedded English compares case-insensitively.
    return cleaned.lower()


# ─────────────────────────────────────────────────────────────────────────────
# Model loading — reuses fixes from transcribe.py
# ─────────────────────────────────────────────────────────────────────────────
def load_model():
    """Download and load the fine-tuned Nemotron Hindi ASR model.

    Downloads ``final_model.nemo`` from the HF_MODEL repo (cached locally by
    huggingface_hub), restores it with NeMo, applies two runtime patches
    (see FIX comments below), and returns the model in eval mode on GPU
    when available, otherwise CPU.
    """
    print(f"\n[1/3] Downloading model: {HF_MODEL}")
    t0 = time.time()
    model_path = hf_hub_download(
        repo_id=HF_MODEL, filename="final_model.nemo", token=HF_TOKEN,
    )
    print(f"  Cached at: {model_path} ({time.time() - t0:.1f}s)")

    print(f"[2/3] Loading NeMo ASR model...")
    t0 = time.time()
    asr_model = nemo_asr.models.ASRModel.restore_from(model_path)

    # FIX: NeMo 2.6 hardcodes use_lhotse=True → lhotse DynamicCutSampler compat bug
    # The replacement builds a plain (non-lhotse) dataloader config instead.
    def patched_setup(self_model, config):
        # transcribe() calls this with one of two shapes: an explicit manifest,
        # or a temp dir plus a list of audio file paths.
        if 'manifest_filepath' in config:
            manifest_filepath = config['manifest_filepath']
            batch_size = config['batch_size']
        else:
            manifest_filepath = os.path.join(config['temp_dir'], 'manifest.json')
            batch_size = min(config['batch_size'], len(config['paths2audio_files']))
        dl_config = {
            'manifest_filepath': manifest_filepath,
            'sample_rate': self_model.preprocessor._sample_rate,
            'batch_size': batch_size,
            'shuffle': False,  # preserve input order so outputs align with paths
            # NOTE(review): os.cpu_count() - 1 yields 0 workers on a single-core
            # machine (loading happens in the main process) — confirm intended.
            'num_workers': config.get('num_workers', min(batch_size, os.cpu_count() - 1)),
            'pin_memory': True,
            'channel_selector': config.get('channel_selector', None),
            'use_start_end_token': self_model.cfg.validation_ds.get('use_start_end_token', False),
        }
        if config.get("augmentor"):
            dl_config['augmentor'] = config.get("augmentor")
        return self_model._setup_dataloader_from_config(config=DictConfig(dl_config))

    # Bind the patch as an instance method so only this model object is affected.
    asr_model._setup_transcribe_dataloader = types.MethodType(patched_setup, asr_model)

    # FIX: CUDA graph cu_call unpacking bug with newer CUDA drivers
    # Disable the label-looping and CUDA-graph greedy RNNT decode paths.
    with open_dict(asr_model.cfg):
        if hasattr(asr_model.cfg, 'decoding'):
            asr_model.cfg.decoding.greedy.loop_labels = False
            asr_model.cfg.decoding.greedy.use_cuda_graph_decoder = False
    asr_model.change_decoding_strategy(asr_model.cfg.decoding)

    device = "cuda" if torch.cuda.is_available() else "cpu"
    asr_model = asr_model.to(device)
    asr_model.eval()  # disable dropout etc. for inference
    print(f"  Loaded on {device} ({time.time() - t0:.1f}s)")
    return asr_model


# ─────────────────────────────────────────────────────────────────────────────
# Audio helpers
# ─────────────────────────────────────────────────────────────────────────────
def resample_array(audio_array: np.ndarray, orig_sr: int, target_sr: int = TARGET_SR) -> np.ndarray:
    """Resample numpy audio array to target sample rate."""
    # Fast path: already at the target rate, just enforce float32 dtype.
    if orig_sr == target_sr:
        return audio_array.astype(np.float32)
    # Reduce the up/down ratio by the gcd so resample_poly gets small factors.
    divisor = gcd(orig_sr, target_sr)
    resampled = resample_poly(audio_array, target_sr // divisor, orig_sr // divisor)
    return resampled.astype(np.float32)


def save_temp_wav(audio_array: np.ndarray, sr: int, tmp_dir: str, idx: int) -> str:
    """Save audio array as temporary WAV file for NeMo transcription."""
    # Zero-padded index keeps files lexically ordered inside the temp dir.
    wav_path = os.path.join(tmp_dir, f"sample_{idx:05d}.wav")
    sf.write(wav_path, audio_array, sr)
    return wav_path


# ─────────────────────────────────────────────────────────────────────────────
# Dataset loaders — return list of (audio_path, reference_text, duration_sec)
# ─────────────────────────────────────────────────────────────────────────────
def load_fleurs_hindi(tmp_dir: str, max_samples: int = None) -> List[Dict]:
    """Load FLEURS Hindi test set → save as 16kHz WAVs → return metadata."""
    print("\n  Loading FLEURS Hindi (google/fleurs, hi_in, test)...")
    ds = load_dataset("google/fleurs", "hi_in", split="test")
    total = len(ds)
    if max_samples:
        total = min(total, max_samples)

    prepared = []
    for idx in range(total):
        record = ds[idx]
        audio_info = record["audio"]
        waveform = np.array(audio_info["array"], dtype=np.float32)

        # Resample if needed (FLEURS is 16kHz, but just in case)
        waveform = resample_array(waveform, audio_info["sampling_rate"], TARGET_SR)

        prepared.append({
            "id": record.get("id", idx),
            "audio_path": save_temp_wav(waveform, TARGET_SR, tmp_dir, idx),
            "reference": record["transcription"],
            "duration": len(waveform) / TARGET_SR,
        })

        if (idx + 1) % 100 == 0:
            print(f"    Prepared {idx + 1}/{total} samples...")

    print(f"    Total: {len(prepared)} samples prepared")
    return prepared


def load_indicvoices_hindi(tmp_dir: str, max_samples: int = 200) -> List[Dict]:
    """Load IndicVoices Hindi test set → resample 48kHz→16kHz → return metadata."""
    print("\n  Loading IndicVoices Hindi (ai4bharat/indicvoices_r, Hindi, test)...")
    try:
        # Streaming avoids downloading the full split just for a sample prefix.
        ds = load_dataset("ai4bharat/indicvoices_r", "Hindi", split="test", streaming=True)
    except Exception as e:
        print(f"    ⚠ Could not load IndicVoices: {e}")
        return []

    prepared = []
    for idx, record in enumerate(ds):
        if max_samples and idx >= max_samples:
            break
        audio_info = record["audio"]
        waveform = np.array(audio_info["array"], dtype=np.float32)
        waveform = resample_array(waveform, audio_info["sampling_rate"], TARGET_SR)

        prepared.append({
            "id": idx,
            "audio_path": save_temp_wav(waveform, TARGET_SR, tmp_dir, idx),
            # Prefer the raw "text" field, fall back to "normalized".
            "reference": record.get("text", record.get("normalized", "")),
            "duration": len(waveform) / TARGET_SR,
        })

        if (idx + 1) % 50 == 0:
            print(f"    Prepared {idx + 1} samples...")

    print(f"    Total: {len(prepared)} samples prepared")
    return prepared


# ─────────────────────────────────────────────────────────────────────────────
# Batch transcription
# ─────────────────────────────────────────────────────────────────────────────
def transcribe_batch(asr_model, audio_paths: List[str], batch_size: int = BATCH_SIZE) -> Tuple[List[str], float]:
    """Transcribe a list of audio files in batches. Returns (texts, total_time)."""
    start = time.time()
    raw = asr_model.transcribe(audio_paths, batch_size=batch_size, verbose=False)
    elapsed = time.time() - start

    # Unpack results — handle both str and Hypothesis returns
    if not isinstance(raw, list):
        return [str(raw)], elapsed

    texts = []
    for item in raw:
        if isinstance(item, str):
            texts.append(item)
        else:
            texts.append(item.text if hasattr(item, 'text') else str(item))
    return texts, elapsed


# ─────────────────────────────────────────────────────────────────────────────
# Compute metrics
# ─────────────────────────────────────────────────────────────────────────────
def compute_metrics(references: List[str], hypotheses: List[str]) -> Dict:
    """Compute WER, CER on normalized Hindi text. Also per-sample metrics."""
    # Normalize both sides so punctuation/casing differences don't count as errors.
    normalized_pairs = [
        (normalize_hindi(ref), normalize_hindi(hyp))
        for ref, hyp in zip(references, hypotheses)
    ]

    # Drop pairs with an empty reference — WER is undefined there.
    normalized_pairs = [(r, h) for r, h in normalized_pairs if r.strip()]
    if not normalized_pairs:
        return {"wer": 1.0, "cer": 1.0, "valid_samples": 0}

    refs_valid = [r for r, _ in normalized_pairs]
    hyps_valid = [h for _, h in normalized_pairs]

    # Per-sample WER for distribution analysis; a failed computation counts as 1.0.
    per_sample_wer = []
    for ref, hyp in normalized_pairs:
        try:
            per_sample_wer.append(wer(ref, hyp))
        except Exception:
            per_sample_wer.append(1.0)

    # Corpus-level WER/CER (jiwer treats list as concatenated corpus)
    return {
        "wer": wer(refs_valid, hyps_valid),
        "cer": cer(refs_valid, hyps_valid),
        "valid_samples": len(refs_valid),
        "median_wer": float(np.median(per_sample_wer)),
        "p90_wer": float(np.percentile(per_sample_wer, 90)),
        "per_sample_wer": per_sample_wer,
    }


# ─────────────────────────────────────────────────────────────────────────────
# Save detailed results
# ─────────────────────────────────────────────────────────────────────────────
def save_results(dataset_name: str, samples: List[Dict], hypotheses: List[str], metrics: Dict):
    """Save per-sample CSV and summary JSON.

    Args:
        dataset_name: File-name prefix for outputs inside RESULTS_DIR.
        samples: Metadata dicts containing "id", "duration", "reference".
        hypotheses: Model transcriptions, aligned by index with samples.
        metrics: Output of compute_metrics() for this dataset.

    Returns:
        The summary dict that was written to the JSON file.
    """
    os.makedirs(RESULTS_DIR, exist_ok=True)

    # Per-sample CSV: raw + normalized text side by side for error analysis.
    csv_path = os.path.join(RESULTS_DIR, f"{dataset_name}_details.csv")
    with open(csv_path, "w", newline="", encoding="utf-8") as f:
        writer = csv.writer(f)
        writer.writerow(["id", "duration_s", "reference", "hypothesis", "ref_normalized", "hyp_normalized", "sample_wer"])
        refs_norm = [normalize_hindi(s["reference"]) for s in samples]
        hyps_norm = [normalize_hindi(h) for h in hypotheses]
        per_wer = metrics.get("per_sample_wer", [])
        # Index guards keep the writer robust if list lengths ever diverge.
        for i, (s, h) in enumerate(zip(samples, hypotheses)):
            rn = refs_norm[i] if i < len(refs_norm) else ""
            hn = hyps_norm[i] if i < len(hyps_norm) else ""
            sw = per_wer[i] if i < len(per_wer) else ""
            writer.writerow([s["id"], f'{s["duration"]:.2f}', s["reference"], h, rn, hn, sw])
    print(f"    Saved: {csv_path}")

    # Summary JSON — rates reported as percentages rounded to 2 decimals.
    summary = {
        "dataset": dataset_name,
        "model": HF_MODEL,
        "total_samples": len(samples),
        "valid_samples": metrics["valid_samples"],
        "wer": round(metrics["wer"] * 100, 2),
        "cer": round(metrics["cer"] * 100, 2),
        "median_wer": round(metrics.get("median_wer", 0) * 100, 2),
        "p90_wer": round(metrics.get("p90_wer", 0) * 100, 2),
    }
    json_path = os.path.join(RESULTS_DIR, f"{dataset_name}_summary.json")
    # FIX: ensure_ascii=False emits raw Devanagari, so the file must be opened
    # as UTF-8 explicitly — the platform default encoding (e.g. cp1252 on
    # Windows) would raise UnicodeEncodeError or garble the text.
    with open(json_path, "w", encoding="utf-8") as f:
        json.dump(summary, f, indent=2, ensure_ascii=False)
    print(f"    Saved: {json_path}")

    return summary


# ─────────────────────────────────────────────────────────────────────────────
# Main benchmark
# ─────────────────────────────────────────────────────────────────────────────
def run_benchmark(dataset_name: str, samples: List[Dict], asr_model) -> Dict:
    """Run benchmark on a dataset: transcribe → compute metrics → save.

    Args:
        dataset_name: Name used as the output-file prefix.
        samples: Prepared sample dicts ("audio_path", "reference", "duration").
        asr_model: Loaded NeMo ASR model (see load_model()).

    Returns:
        Summary dict with WER/CER plus timing fields, or {} when there are
        no samples to benchmark.
    """
    if not samples:
        print(f"  ⚠ No samples for {dataset_name}, skipping.")
        return {}

    audio_paths = [s["audio_path"] for s in samples]
    references = [s["reference"] for s in samples]
    total_audio_dur = sum(s["duration"] for s in samples)

    print(f"\n  Transcribing {len(samples)} samples ({total_audio_dur:.1f}s audio)...")
    hypotheses, elapsed = transcribe_batch(asr_model, audio_paths)

    # RTF = wall-clock time / audio duration (lower is faster).
    rtf = elapsed / total_audio_dur if total_audio_dur > 0 else 0
    # FIX: guard the speedup display — 1/rtf raised ZeroDivisionError when
    # total_audio_dur (and hence rtf) was 0.
    speedup = f"{1 / rtf:.0f}x real-time" if rtf > 0 else "n/a"
    print(f"    Inference: {elapsed:.1f}s | RTF: {rtf:.4f} ({speedup})")

    print(f"  Computing WER/CER...")
    metrics = compute_metrics(references, hypotheses)

    summary = save_results(dataset_name, samples, hypotheses, metrics)
    summary["total_audio_sec"] = round(total_audio_dur, 1)
    summary["inference_sec"] = round(elapsed, 1)
    summary["rtf"] = round(rtf, 4)

    return summary


def main():
    """Load the model, run both Hindi benchmarks, and print a summary table."""
    banner = "=" * 70
    divider = "─" * 70

    print(banner)
    print("  Hindi Nemotron ASR — Benchmark Suite")
    print(f"  Model: {HF_MODEL}")
    gpu_label = torch.cuda.get_device_name(0) if torch.cuda.is_available() else 'CPU'
    print(f"  GPU:   {gpu_label}")
    print(banner)

    asr_model = load_model()
    summaries = []

    # Temp WAVs live only for the duration of the benchmarks.
    with tempfile.TemporaryDirectory(prefix="hindi_bench_") as tmp_dir:
        # ── Benchmark 1: FLEURS Hindi ──────────────────────────────────────
        print("\n" + divider)
        print("  BENCHMARK: FLEURS Hindi (google/fleurs hi_in test)")
        print(divider)
        fleurs_summary = run_benchmark("fleurs_hindi", load_fleurs_hindi(tmp_dir), asr_model)
        if fleurs_summary:
            summaries.append(fleurs_summary)

        # ── Benchmark 2: IndicVoices Hindi (first 200 samples) ────────────
        print("\n" + divider)
        print("  BENCHMARK: IndicVoices Hindi (ai4bharat/indicvoices_r test, 200 samples)")
        print(divider)
        iv_samples = load_indicvoices_hindi(tmp_dir, max_samples=200)
        if iv_samples:
            iv_summary = run_benchmark("indicvoices_hindi", iv_samples, asr_model)
            if iv_summary:
                summaries.append(iv_summary)

    # ── Final Report ──────────────────────────────────────────────────────
    print("\n" + banner)
    print("  BENCHMARK RESULTS SUMMARY")
    print(banner)
    print(f"  {'Dataset':<25} {'Samples':>8} {'Audio(s)':>10} {'WER%':>8} {'CER%':>8} {'RTF':>8}")
    print(f"  {'─'*25} {'─'*8} {'─'*10} {'─'*8} {'─'*8} {'─'*8}")
    for entry in summaries:
        print(f"  {entry['dataset']:<25} {entry['valid_samples']:>8} {entry['total_audio_sec']:>10.1f} "
              f"{entry['wer']:>7.2f}% {entry['cer']:>7.2f}% {entry['rtf']:>8.4f}")
    print(banner)
    print(f"  Results saved to: {RESULTS_DIR}/")
    print(banner)


# Script entry point: run the full benchmark suite when executed directly.
if __name__ == "__main__":
    main()
