"""
Batched per-chunk latency test — simulates N concurrent streams via GPU batching.

WHY batching, not threading:
- A single GPU serializes CUDA ops anyway; a ThreadPoolExecutor only adds GIL contention (17% GPU utilization observed).
- Real concurrency = batch N chunks from N streams into ONE forward pass.
- This is exactly what the Modal deployment's central batching loop does.
- Measures: does per-chunk latency degrade as batch size (= concurrent streams) grows?
"""

import os
import time
import json
import numpy as np
import torch
from math import gcd
from scipy.signal import resample_poly
from datasets import load_dataset
from jiwer import wer, cer
from dotenv import load_dotenv

load_dotenv()

# Fine-tuned Hindi .nemo checkpoint. Defaults to the local Hugging Face cache
# snapshot; override it with the NEMOTRON_HINDI_MODEL_PATH environment variable
# (a .env entry works too, since load_dotenv() runs first).
MODEL_PATH = os.environ.get(
    "NEMOTRON_HINDI_MODEL_PATH",
    "/home/ubuntu/.cache/huggingface/hub/models--BayAreaBoys--nemotron-hindi/snapshots/f9625adc4d8151e6ed5ae59e8f23fa96adc345c1/final_model.nemo",
)
SAMPLE_RATE = 16000  # Hz; dataset audio is resampled to this rate before inference
CHUNK_MS = 80  # length of one streaming chunk, in milliseconds
CHUNK_SAMPLES = SAMPLE_RATE * CHUNK_MS // 1000  # samples per chunk (1280 at 16 kHz / 80 ms)
DEVICE = "cuda"
# Benchmark JSON goes next to this script, independent of the working directory.
RESULTS_DIR = os.path.join(os.path.dirname(os.path.abspath(__file__)), "benchmark_results")


def normalize_hindi(text):
    """Normalize a Hindi transcript before WER/CER scoring.

    Danda and common ASCII/Unicode punctuation are replaced by spaces, runs of
    whitespace collapse to one space, the ends are trimmed, and the result is
    lowercased (a no-op for Devanagari, which has no case). Falsy input
    (None or "") yields "".
    """
    import re

    if not text:
        return ""
    depunctuated = re.sub(r'[।,\.\?\!;:\-\—\–\"\'\(\)\[\]\{\}<>]', ' ', text)
    single_spaced = re.sub(r'\s+', ' ', depunctuated)
    return single_spaced.strip().lower()


def load_model():
    """Restore the RNNT-BPE checkpoint at MODEL_PATH and prepare it for inference.

    Greedy decoding is reconfigured with `loop_labels` and
    `use_cuda_graph_decoder` switched off. The model is then put in eval mode
    and moved to DEVICE.

    Returns:
        The restored `EncDecRNNTBPEModel`, in eval mode, on DEVICE.
    """
    import nemo.collections.asr as nemo_asr
    from omegaconf import open_dict

    print("Loading model...")
    asr_model = nemo_asr.models.EncDecRNNTBPEModel.restore_from(MODEL_PATH)

    # open_dict lifts OmegaConf's struct lock so these greedy-decoding keys can
    # be set even if the checkpoint config does not declare them.
    # NOTE(review): why these two decoder options are disabled is not visible
    # here (presumably stability/compat for tiny-chunk batches) — confirm.
    with open_dict(asr_model.cfg):
        greedy_cfg = asr_model.cfg.decoding.greedy
        greedy_cfg.loop_labels = False
        greedy_cfg.use_cuda_graph_decoder = False
    asr_model.change_decoding_strategy(asr_model.cfg.decoding)

    asr_model.eval()
    asr_model = asr_model.to(DEVICE)
    print(f"Model on {DEVICE}")
    return asr_model


def load_fleurs(n):
    """Load the first `n` utterances of the FLEURS Hindi (hi_in) test split.

    Audio is converted to float32 and, when its rate differs from SAMPLE_RATE,
    resampled with a polyphase filter (up/down factors reduced by their gcd).

    Args:
        n: maximum number of utterances to load (clipped to the split size).

    Returns:
        List of dicts with keys "id", "audio" (1-D float32 array at SAMPLE_RATE),
        "reference" (the dataset transcription) and "duration" (seconds).
    """
    print(f"Loading FLEURS Hindi ({n} samples)...")
    ds = load_dataset("google/fleurs", "hi_in", split="test")
    samples = []
    for i in range(min(n, len(ds))):
        # Fetch each row exactly once: every ds[i] access rebuilds the row,
        # which includes decoding its audio.
        row = ds[i]
        audio = row["audio"]
        arr = np.array(audio["array"], dtype=np.float32)
        sr = audio["sampling_rate"]
        if sr != SAMPLE_RATE:
            g = gcd(sr, SAMPLE_RATE)
            arr = resample_poly(arr, SAMPLE_RATE // g, sr // g).astype(np.float32)
        samples.append({
            "id": row.get("id", i),
            "audio": arr,
            "reference": row["transcription"],
            "duration": len(arr) / SAMPLE_RATE,
        })
    print(f"  {len(samples)} samples, {sum(s['duration'] for s in samples):.0f}s audio")
    return samples


def _split_into_chunks(audio, chunk_samples):
    """Split 1-D `audio` into consecutive `chunk_samples`-long pieces.

    A trailing partial piece is zero-padded to full length; empty audio yields [].
    """
    n_full, rem = divmod(len(audio), chunk_samples)
    chunks = [audio[k * chunk_samples:(k + 1) * chunk_samples] for k in range(n_full)]
    if rem > 0:
        chunks.append(np.pad(audio[-rem:], (0, chunk_samples - rem)))
    return chunks


def _hypothesis_text(hyp):
    """Return the decoded text of one decoder output, or "" if it carries none.

    Accepts plain strings as well as objects with a string `.text` attribute
    (NeMo `Hypothesis`), because the decoder's return type varies across NeMo versions.
    """
    if isinstance(hyp, str):
        return hyp
    text = getattr(hyp, "text", None)
    return text if isinstance(text, str) else ""


@torch.inference_mode()
def bench_batch_size(model, samples, batch_size):
    """
    Simulate `batch_size` concurrent streams with batched GPU inference.

    - Take the first `batch_size` samples (one per stream) and split each into
      CHUNK_SAMPLES-long chunks, zero-padding the last one.
    - Round t runs ONE batched forward pass over chunk t of every stream that
      still has audio, i.e. at most `batch_size` chunks from distinct streams.
    - A chunk's latency is the wall time of the pass that carried it: every
      stream in the batch waits for the whole pass. The amortized GPU cost per
      chunk (total pass time / total chunks) is reported as "amortized_chunk_ms".

    Chunks are decoded independently (no state is carried between chunks); a
    stream's transcript is its non-empty chunk texts joined by spaces.

    Args:
        model: NeMo RNNT model exposing `preprocessor`, `encoder` and `decoding`.
        samples: list of dicts with "audio", "reference" and "duration" keys.
        batch_size: number of concurrent streams to simulate (>= 1).

    Returns:
        Dict of per-chunk latency stats (ms), WER/CER (%) and throughput (x real time).

    Raises:
        ValueError: if `batch_size` < 1.
    """
    if batch_size < 1:
        raise ValueError(f"batch_size must be >= 1, got {batch_size}")

    to_run = samples[:batch_size]
    n_streams = len(to_run)
    refs = [s["reference"] for s in to_run]
    total_audio = sum(s["duration"] for s in to_run)

    stream_chunks = [_split_into_chunks(s["audio"], CHUNK_SAMPLES) for s in to_run]
    n_rounds = max((len(chunks) for chunks in stream_chunks), default=0)
    needs_sync = torch.device(DEVICE).type == "cuda"

    latencies = []  # one entry per chunk, in ms
    total_pass_s = 0.0
    per_stream_text = [[] for _ in range(n_streams)]

    for t in range(n_rounds):
        # Streams still producing audio at round t (shorter utterances drop out).
        active = [sidx for sidx, chunks in enumerate(stream_chunks) if t < len(chunks)]
        actual_bs = len(active)

        audio_batch = torch.tensor(
            np.stack([stream_chunks[sidx][t] for sidx in active]),
            dtype=torch.float32, device=DEVICE,
        )  # (actual_bs, CHUNK_SAMPLES)
        audio_lengths = torch.full((actual_bs,), CHUNK_SAMPLES, device=DEVICE)

        # CUDA kernels launch asynchronously: drain pending work before starting
        # the clock and wait for this pass to finish before stopping it.
        if needs_sync:
            torch.cuda.synchronize()
        t0 = time.perf_counter()

        processed, processed_len = model.preprocessor(input_signal=audio_batch, length=audio_lengths)
        encoded, encoded_len = model.encoder(audio_signal=processed, length=processed_len)
        hypotheses = model.decoding.rnnt_decoder_predictions_tensor(
            encoder_output=encoded, encoded_lengths=encoded_len)

        if needs_sync:
            torch.cuda.synchronize()
        pass_s = time.perf_counter() - t0

        total_pass_s += pass_s
        latencies.extend([pass_s * 1000] * actual_bs)

        # Some NeMo versions return (best_hypotheses, all_hypotheses); others the list itself.
        if isinstance(hypotheses, tuple):
            hypotheses = hypotheses[0]
        for i, sidx in enumerate(active):
            if i < len(hypotheses):
                text = _hypothesis_text(hypotheses[i])
                if text:
                    per_stream_text[sidx].append(text)

    hyps = [" ".join(texts) for texts in per_stream_text]

    # WER/CER over streams whose normalized reference is non-empty.
    rn = [normalize_hindi(r) for r in refs]
    hn = [normalize_hindi(h) for h in hyps]
    pairs = [(r, h) for r, h in zip(rn, hn) if r.strip()]
    if pairs:
        rf, hf = zip(*pairs)
        w = wer(list(rf), list(hf))
        c = cer(list(rf), list(hf))
    else:
        w, c = 1.0, 1.0

    n_chunks = len(latencies)
    if latencies:
        lats = np.asarray(latencies)
        med, mean = float(np.median(lats)), float(np.mean(lats))
        p95, p99 = float(np.percentile(lats, 95)), float(np.percentile(lats, 99))
        worst = float(np.max(lats))
    else:
        med = mean = p95 = p99 = worst = 0.0
    throughput = total_audio / total_pass_s if total_pass_s > 0 else 0
    amortized_ms = total_pass_s * 1000 / n_chunks if n_chunks else 0.0

    return {
        "concurrency": batch_size,
        "streams": n_streams,
        "total_chunks": n_chunks,
        "batch_passes": n_rounds,
        "total_audio_sec": round(total_audio, 1),
        "chunk_median_ms": round(med, 1),
        "chunk_mean_ms": round(mean, 1),
        "chunk_p95_ms": round(p95, 1),
        "chunk_p99_ms": round(p99, 1),
        "chunk_max_ms": round(worst, 1),
        "amortized_chunk_ms": round(amortized_ms, 2),
        "wer": round(w * 100, 2),
        "cer": round(c * 100, 2),
        "throughput_rtx": round(throughput, 1),
    }


def main():
    """Run the batched concurrency sweep end to end.

    Steps:
      - Load the model and warm it up.
      - Benchmark each concurrency level on the first 50 FLEURS Hindi test
        utterances.
      - Print a summary table.
      - Write the per-level metrics to RESULTS_DIR/local_batched_sweep.json.
    """
    on_cuda = torch.device(DEVICE).type == "cuda"
    # Label output with the accelerator actually in use rather than assuming one.
    hw_name = torch.cuda.get_device_name(DEVICE) if on_cuda else DEVICE

    print("=" * 80)
    print(f"  BATCHED Concurrent Latency Test — {hw_name}")
    print(f"  Chunk: {CHUNK_MS}ms | Simulates N concurrent streams via GPU batching")
    print("=" * 80)

    model = load_model()

    # Warmup under inference_mode, matching the benchmark, so no autograd
    # graphs are recorded and the warmed-up code path is the one measured.
    print("Warming up...")
    with torch.inference_mode():
        dummy = torch.randn(4, CHUNK_SAMPLES, device=DEVICE)
        dummy_len = torch.full((4,), CHUNK_SAMPLES, device=DEVICE)
        for _ in range(5):
            p, pl = model.preprocessor(input_signal=dummy, length=dummy_len)
            e, el = model.encoder(audio_signal=p, length=pl)
            model.decoding.rnnt_decoder_predictions_tensor(encoder_output=e, encoded_lengths=el)
    if on_cuda:
        torch.cuda.synchronize()
    print("  Warm.\n")

    samples = load_fleurs(50)

    levels = [1, 5, 10, 25, 50]
    all_m = []

    for c in levels:
        m = bench_batch_size(model, samples, c)
        all_m.append(m)
        print(f"  C={c:>2}: med={m['chunk_median_ms']:>6.1f}ms  p95={m['chunk_p95_ms']:>6.1f}ms  "
              f"p99={m['chunk_p99_ms']:>6.1f}ms  WER={m['wer']:>5.1f}%  {m['throughput_rtx']:>6.1f}x RT")

    print(f"\n{'='*80}")
    print(f"  CONCURRENT SWEEP — Batched GPU Inference ({hw_name})")
    print(f"{'='*80}")
    print(f"  {'Conc':>6} {'Strms':>6} {'Chunks':>7} {'Med':>8} {'P95':>8} "
          f"{'P99':>8} {'Max':>8} {'WER%':>7} {'Thru':>8}")
    print(f"  {'─'*6} {'─'*6} {'─'*7} {'─'*8} {'─'*8} {'─'*8} {'─'*8} {'─'*7} {'─'*8}")
    for m in all_m:
        print(f"  {m['concurrency']:>6} {m['streams']:>6} {m['total_chunks']:>7} "
              f"{m['chunk_median_ms']:>7.1f}ms {m['chunk_p95_ms']:>7.1f}ms "
              f"{m['chunk_p99_ms']:>7.1f}ms {m['chunk_max_ms']:>7.1f}ms "
              f"{m['wer']:>6.1f}% {m['throughput_rtx']:>7.1f}x")
    print(f"{'='*80}")

    os.makedirs(RESULTS_DIR, exist_ok=True)
    out_path = os.path.join(RESULTS_DIR, "local_batched_sweep.json")
    with open(out_path, "w") as f:
        json.dump(all_m, f, indent=2)
    # Print the real location: RESULTS_DIR is script-relative, not cwd-relative.
    print(f"  Saved: {out_path}")


if __name__ == "__main__":
    # Run the sweep only when executed as a script, not when imported.
    main()
