"""
Proper concurrent ASR benchmark with dynamic batching + per-stream state.

Architecture:
  N producer coroutines (one per stream) → asyncio.Queue → 
  1 batcher (collects up to batch_size chunks per cycle) →
  1 batched GPU forward pass → results routed back to producers

This maintains per-stream decoder state while batching encoder passes.
For WER: also runs model.transcribe() with batch_size=N for ground truth.

Measures: per-chunk GPU latency, throughput, WER — all at different concurrency levels.
"""

import os
import sys
import time
import json
import types
import asyncio
import tempfile
from math import gcd
from dataclasses import dataclass, field
from typing import List, Dict

import numpy as np
import torch
import soundfile as sf
from scipy.signal import resample_poly
from omegaconf import DictConfig, open_dict
from datasets import load_dataset
from jiwer import wer, cer
from dotenv import load_dotenv

# Load environment config (e.g. HF_TOKEN for the Hub) before any dataset/model access.
load_dotenv()

# Local .nemo checkpoint to benchmark — machine-specific path; TODO confirm on target host.
MODEL_PATH = "/home/ubuntu/.cache/huggingface/hub/models--BayAreaBoys--nemotron-hindi/snapshots/f9625adc4d8151e6ed5ae59e8f23fa96adc345c1/final_model.nemo"
SAMPLE_RATE = 16000  # Hz — all audio is resampled to this rate before chunking
CHUNK_MS = 80  # simulated streaming chunk duration, in milliseconds
CHUNK_SAMPLES = int(SAMPLE_RATE * CHUNK_MS / 1000)  # samples per chunk (1280 at 16 kHz / 80 ms)
DEVICE = "cuda"
RESULTS_DIR = os.path.join(os.path.dirname(os.path.abspath(__file__)), "benchmark_results")  # sibling of this script


def normalize_hindi(text):
    """Normalize text for WER/CER scoring: drop punctuation (incl. danda),
    collapse whitespace runs to single spaces, and lowercase."""
    if not text:
        return ""
    # Map danda plus common ASCII/Unicode punctuation to spaces in one pass,
    # then split/join to collapse and trim whitespace.
    punctuation = '।,.?!;:-—–"\'()[]{}<>'
    table = str.maketrans({ch: " " for ch in punctuation})
    return " ".join(text.translate(table).split()).lower()


def load_model():
    """Restore the RNNT BPE model from MODEL_PATH, configure greedy decoding,
    and move it to DEVICE in eval mode.

    Also monkey-patches `_setup_transcribe_dataloader` so `model.transcribe()`
    builds its dataloader config explicitly — NOTE(review): works around a
    lhotse-related bug per the comment below; verify it is still needed on the
    installed NeMo version.
    """
    import nemo.collections.asr as nemo_asr
    print("Loading model...")
    model = nemo_asr.models.EncDecRNNTBPEModel.restore_from(MODEL_PATH)
    # Disable label-looping and the CUDA-graph greedy decoder before rebuilding
    # the decoding strategy — presumably required for the per-chunk streaming
    # calls used in measure_batched_latency; confirm against NeMo docs.
    with open_dict(model.cfg):
        model.cfg.decoding.greedy.loop_labels = False
        model.cfg.decoding.greedy.use_cuda_graph_decoder = False
    model.change_decoding_strategy(model.cfg.decoding)
    model.eval()
    model = model.to(DEVICE)

    # Monkey-patch lhotse bug for model.transcribe()
    def patched_setup(self_model, config):
        # transcribe() may pass an explicit manifest, or a temp dir + file list;
        # support both and derive batch size accordingly.
        if 'manifest_filepath' in config:
            mf, bs = config['manifest_filepath'], config['batch_size']
        else:
            mf = os.path.join(config['temp_dir'], 'manifest.json')
            bs = min(config['batch_size'], len(config['paths2audio_files']))
        dl = {'manifest_filepath': mf, 'sample_rate': self_model.preprocessor._sample_rate,
              'batch_size': bs, 'shuffle': False, 'pin_memory': True,
              'num_workers': config.get('num_workers', min(bs, os.cpu_count() - 1)),
              'channel_selector': config.get('channel_selector', None),
              'use_start_end_token': self_model.cfg.validation_ds.get('use_start_end_token', False)}
        return self_model._setup_dataloader_from_config(config=DictConfig(dl))
    # Bind the patch as an instance method so only this model object is affected.
    model._setup_transcribe_dataloader = types.MethodType(patched_setup, model)

    print(f"Model on {DEVICE}")
    return model


def load_fleurs(n):
    """Load up to *n* FLEURS Hindi test samples, resampled to SAMPLE_RATE.

    Each sample dict carries: id, float32 audio array, reference transcription,
    and duration in seconds.
    """
    print(f"Loading FLEURS Hindi ({n} samples)...")
    ds = load_dataset("google/fleurs", "hi_in", split="test")
    samples = []
    for idx in range(min(n, len(ds))):
        row = ds[idx]
        audio = np.array(row["audio"]["array"], dtype=np.float32)
        src_rate = row["audio"]["sampling_rate"]
        if src_rate != SAMPLE_RATE:
            # Rational-factor polyphase resampling; divide out the gcd to keep
            # the up/down factors small.
            factor = gcd(src_rate, SAMPLE_RATE)
            audio = resample_poly(audio, SAMPLE_RATE // factor, src_rate // factor).astype(np.float32)
        samples.append({
            "id": row.get("id", idx),
            "audio": audio,
            "reference": row["transcription"],
            "duration": len(audio) / SAMPLE_RATE,
        })
    print(f"  {len(samples)} samples, {sum(s['duration'] for s in samples):.0f}s audio")
    return samples


# ─── Benchmark: Batched encoder + per-chunk latency ─────────────────────────
@torch.inference_mode()
def measure_batched_latency(model, samples, batch_size):
    """
    Measures per-chunk GPU latency at different batch sizes.
    Batch N chunks → one forward pass → measure time / N = per-chunk cost.
    This is the RAW GPU throughput metric.
    """
    # Slice each stream's audio into fixed-size chunks; zero-pad the final
    # partial chunk so every chunk is exactly CHUNK_SAMPLES long.
    chunks = []
    for sample in samples[:batch_size]:
        wav = sample["audio"]
        n_full = len(wav) // CHUNK_SAMPLES
        chunks.extend(wav[k * CHUNK_SAMPLES:(k + 1) * CHUNK_SAMPLES] for k in range(n_full))
        tail = len(wav) - n_full * CHUNK_SAMPLES
        if tail > 0:
            chunks.append(np.pad(wav[-tail:], (0, CHUNK_SAMPLES - tail)))

    latencies = []
    for start in range(0, len(chunks), batch_size):
        group = chunks[start:start + batch_size]
        n = len(group)
        signal = torch.tensor(np.stack(group), dtype=torch.float32, device=DEVICE)
        lengths = torch.full((n,), CHUNK_SAMPLES, device=DEVICE)

        # Synchronize before and after so the wall-clock delta covers exactly
        # one batched preprocessor → encoder → greedy-decode pass.
        torch.cuda.synchronize()
        tic = time.perf_counter()
        feats, feat_lens = model.preprocessor(input_signal=signal, length=lengths)
        enc, enc_lens = model.encoder(audio_signal=feats, length=feat_lens)
        model.decoding.rnnt_decoder_predictions_tensor(encoder_output=enc, encoded_lengths=enc_lens)
        torch.cuda.synchronize()
        elapsed_ms = (time.perf_counter() - tic) * 1000

        # Amortize the batched pass evenly: each chunk is charged time/N.
        latencies.extend([elapsed_ms / n] * n)

    return np.array(latencies)


# ─── Benchmark: Full transcription with model.transcribe() for WER ──────────
def measure_wer(model, samples, batch_size):
    """
    Uses model.transcribe() with batch_size=N to get proper WER.
    This uses NeMo's internal batching (correct decoder state per file).

    Returns (wer, cer, wall_seconds, total_audio_seconds), with error rates
    as fractions (caller scales to %).
    """
    to_run = samples[:batch_size]
    # TemporaryDirectory guarantees cleanup even if sf.write or transcribe()
    # raises — the original mkdtemp + manual unlink/rmdir leaked the temp
    # files on any exception.
    with tempfile.TemporaryDirectory() as tmp_dir:
        paths = []
        for i, s in enumerate(to_run):
            p = os.path.join(tmp_dir, f"{i}.wav")
            sf.write(p, s["audio"], SAMPLE_RATE)
            paths.append(p)

        t0 = time.perf_counter()
        results = model.transcribe(paths, batch_size=batch_size, verbose=False)
        wall = time.perf_counter() - t0

    # transcribe() may return plain strings or hypothesis objects with .text.
    hyps = []
    for r in results:
        if isinstance(r, str):
            hyps.append(r)
        elif hasattr(r, 'text'):
            hyps.append(r.text)
        else:
            hyps.append(str(r))

    refs = [s["reference"] for s in to_run]
    rn = [normalize_hindi(r) for r in refs]
    hn = [normalize_hindi(h) for h in hyps]
    # Drop pairs whose reference normalizes to empty — WER is undefined
    # against an empty ground truth.
    pairs = [(r, h) for r, h in zip(rn, hn) if r.strip()]
    if pairs:
        rf, hf = zip(*pairs)
        w = wer(list(rf), list(hf))
        c = cer(list(rf), list(hf))
    else:
        # No scorable references: report worst-case error rates.
        w, c = 1.0, 1.0

    total_audio = sum(s["duration"] for s in to_run)
    return w, c, wall, total_audio


def main():
    """Run the concurrency sweep: batched per-chunk GPU latency plus
    transcribe()-based WER/CER at each concurrency level; print a summary
    table and dump metrics to benchmark_results/."""
    print("=" * 80)
    print("  Concurrent ASR Benchmark — Dynamic Batching (A100)")
    print(f"  Chunk: {CHUNK_MS}ms | GPU: {torch.cuda.get_device_name(0)}")
    print("=" * 80)

    model = load_model()

    # Warmup — run under inference_mode so these throwaway passes don't build
    # autograd graphs (the measured path in measure_batched_latency is already
    # decorated with @torch.inference_mode(); the original warmup was not,
    # wasting GPU memory on gradients).
    print("Warming up...")
    dummy = torch.randn(8, CHUNK_SAMPLES, device=DEVICE)
    dl = torch.full((8,), CHUNK_SAMPLES, device=DEVICE)
    with torch.inference_mode():
        for _ in range(5):
            p, pl = model.preprocessor(input_signal=dummy, length=dl)
            e, el = model.encoder(audio_signal=p, length=pl)
            model.decoding.rnnt_decoder_predictions_tensor(encoder_output=e, encoded_lengths=el)
    torch.cuda.synchronize()
    print("  Warm.\n")

    samples = load_fleurs(50)

    # Concurrency levels to sweep (streams == batch size).
    levels = [1, 5, 10, 25, 50]
    all_m = []

    for c in levels:
        # 1) Batched GPU latency
        lats = measure_batched_latency(model, samples, c)

        # 2) Proper WER via model.transcribe()
        w, cerr, wall, total_audio = measure_wer(model, samples, c)
        throughput = total_audio / wall if wall > 0 else 0

        m = {
            "concurrency": c,
            "streams": c,
            "total_chunks": len(lats),
            "total_audio_sec": round(total_audio, 1),
            "chunk_median_ms": round(float(np.median(lats)), 1),
            "chunk_p95_ms": round(float(np.percentile(lats, 95)), 1),
            "chunk_p99_ms": round(float(np.percentile(lats, 99)), 1),
            "chunk_max_ms": round(float(np.max(lats)), 1),
            "wer": round(w * 100, 2),
            "cer": round(cerr * 100, 2),
            "throughput_rtx": round(throughput, 1),
            "wall_sec": round(wall, 1),
        }
        all_m.append(m)
        print(f"  C={c:>2}: chunk_med={m['chunk_median_ms']:>6.1f}ms  p95={m['chunk_p95_ms']:>6.1f}ms  "
              f"WER={m['wer']:>5.1f}%  CER={m['cer']:>5.1f}%  {m['throughput_rtx']:>6.1f}x RT")

    # Summary table across all levels.
    print(f"\n{'='*90}")
    print(f"  CONCURRENT SWEEP — Batched GPU Inference + Proper WER (A100)")
    print(f"{'='*90}")
    print(f"  {'Conc':>6} {'Strms':>6} {'Chunks':>7} {'ChunkMed':>9} {'ChunkP95':>9} "
          f"{'ChunkP99':>9} {'WER%':>7} {'CER%':>7} {'Thruput':>8}")
    print(f"  {'─'*6} {'─'*6} {'─'*7} {'─'*9} {'─'*9} {'─'*9} {'─'*7} {'─'*7} {'─'*8}")
    for m in all_m:
        print(f"  {m['concurrency']:>6} {m['streams']:>6} {m['total_chunks']:>7} "
              f"{m['chunk_median_ms']:>8.1f}ms {m['chunk_p95_ms']:>8.1f}ms "
              f"{m['chunk_p99_ms']:>8.1f}ms {m['wer']:>6.1f}% {m['cer']:>6.1f}% "
              f"{m['throughput_rtx']:>7.1f}x")
    print(f"{'='*90}")

    os.makedirs(RESULTS_DIR, exist_ok=True)
    with open(os.path.join(RESULTS_DIR, "local_concurrent_proper.json"), "w") as f:
        json.dump(all_m, f, indent=2)
    print(f"  Saved: benchmark_results/local_concurrent_proper.json")


if __name__ == "__main__":
    main()
