"""
Concurrent WebSocket load test for Hindi Nemotron ASR on Modal.

WHY this rewrite: Previous version measured total wall-clock time to stream a file
at real-time pace — that's dominated by audio duration, NOT ASR latency.
The HF blog reports 182ms median CHUNK-LEVEL delay (audio_timestamp vs transcript_timestamp).

This version measures:
1. Per-chunk ASR processing delay (the REAL latency metric)
2. Throughput: total audio seconds / wall seconds
3. Latency drift: does per-chunk delay increase with concurrency?
4. WER/CER on collected transcriptions

Usage:
  python load_test.py --concurrency 10 --max-samples 50
  python load_test.py --sweep --max-samples 100
"""

import os
import sys
import json
import time
import asyncio
import argparse
import csv
from dataclasses import dataclass, field
from typing import List, Dict, Optional

import numpy as np
import soundfile as sf
from scipy.signal import resample_poly
from math import gcd
from datasets import load_dataset
from jiwer import wer, cer

# Best-effort bootstrap: install `websockets` on first run if it is missing.
# NOTE(review): shelling out to pip at import time is convenient for a
# standalone script but surprising in shared environments — confirm this is
# intended before reusing the module elsewhere.
try:
    import websockets
except ImportError:
    os.system(f"{sys.executable} -m pip install websockets -q")
    import websockets

# ─── Config ──────────────────────────────────────────────────────────────────
# Public Modal WebSocket endpoint for the Hindi Nemotron streaming ASR service.
WS_URL = "wss://mayaresearch--hindi-nemotron-asr-nemotronasr-webapp.modal.run/ws"
# All audio is resampled to 16 kHz before streaming (see audio_to_pcm_bytes).
TARGET_SR = 16000
CHUNK_DURATION_S = 0.56       # 560ms chunks (att_context_size [70,6])
CHUNK_SIZE_SAMPLES = int(CHUNK_DURATION_S * TARGET_SR)
CHUNK_SIZE_BYTES = CHUNK_SIZE_SAMPLES * 2  # 16-bit PCM = 2 bytes/sample
# CSV/JSON outputs land next to this script in benchmark_results/.
RESULTS_DIR = os.path.join(os.path.dirname(os.path.abspath(__file__)), "benchmark_results")


def normalize_hindi(text: str) -> str:
    """Normalize text for WER/CER scoring.

    Replaces the Devanagari danda and common punctuation with spaces,
    collapses all whitespace runs to single spaces, strips the ends,
    and lowercases (a no-op for Devanagari itself).
    """
    if not text:
        return ""
    # Single-character punctuation set — one C-level pass via str.translate.
    punct = "।,.?!;:-—–\"'()[]{}<>"
    table = str.maketrans(punct, " " * len(punct))
    cleaned = text.translate(table)
    # split() with no args splits on any whitespace and drops empty pieces,
    # which both collapses runs and strips the ends in one step.
    return " ".join(cleaned.split()).lower()


def audio_to_pcm_bytes(audio_array: np.ndarray, orig_sr: int) -> bytes:
    """Resample float audio to TARGET_SR and encode as 16-bit PCM bytes.

    Args:
        audio_array: float samples, nominally in [-1.0, 1.0].
        orig_sr: the array's sampling rate in Hz.

    Returns:
        Raw signed 16-bit little-endian PCM at TARGET_SR.
    """
    if orig_sr != TARGET_SR:
        # Polyphase resampling with the reduced up/down ratio keeps the
        # FIR filter short (e.g. 48000→16000 becomes up=1, down=3).
        g = gcd(orig_sr, TARGET_SR)
        up, down = TARGET_SR // g, orig_sr // g
        audio_array = resample_poly(audio_array, up, down)
    # WHY clip: resampling ripple (or out-of-range input) can push samples
    # past ±1.0; without clipping, the int16 cast wraps around and injects
    # full-scale glitches into the stream.
    clipped = np.clip(audio_array, -1.0, 1.0)
    return (clipped * 32767).astype(np.int16).tobytes()


def load_fleurs_samples(max_samples: Optional[int] = None) -> List[Dict]:
    """Load the FLEURS Hindi test split and pre-convert audio to PCM bytes.

    Args:
        max_samples: cap on the number of samples; None loads the full split.

    Returns:
        List of dicts with keys: "id", "pcm_bytes" (s16le @ TARGET_SR),
        "reference" (ground-truth transcription), "duration" (seconds).
    """
    print("Loading FLEURS Hindi test set...")
    ds = load_dataset("google/fleurs", "hi_in", split="test")
    total = min(len(ds), max_samples) if max_samples else len(ds)
    samples = []
    for i in range(total):
        item = ds[i]
        audio = item["audio"]
        arr = np.array(audio["array"], dtype=np.float32)
        pcm = audio_to_pcm_bytes(arr, audio["sampling_rate"])
        samples.append({
            "id": item.get("id", i),  # fall back to row index if "id" is absent
            "pcm_bytes": pcm,
            "reference": item["transcription"],
            "duration": len(pcm) / (TARGET_SR * 2),  # bytes / (rate × 2 bytes/sample)
        })
    total_dur = sum(s['duration'] for s in samples)
    print(f"  Loaded {len(samples)} samples ({total_dur:.0f}s audio)")
    return samples


# ─── Per-stream result ────────────────────────────────────────────────────────
@dataclass
class StreamResult:
    """Outcome and timing measurements for one streamed WebSocket session."""
    sample_id: int = 0                    # dataset sample id (or row index)
    reference: str = ""                   # ground-truth transcription
    hypothesis: str = ""                  # transcription assembled from server messages
    duration: float = 0.0                 # audio duration in seconds
    # WHY: These are the REAL metrics — per-chunk ASR processing delay
    chunk_delays: List[float] = field(default_factory=list)  # per-chunk: server_ts - audio_ts
    first_transcript_delay: float = 0.0   # time from first chunk sent → first transcript back
    wall_time: float = 0.0               # total wall clock (includes streaming)
    error: Optional[str] = None           # error message; None means the stream succeeded


async def stream_one_sample(sample: Dict, session_id: int) -> StreamResult:
    """
    Stream one sample's PCM audio over the WebSocket and measure per-chunk delay.

    Args:
        sample: dict with "id", "pcm_bytes" (s16le PCM), "reference", "duration".
        session_id: numeric tag for this connection (informational only).

    Returns:
        StreamResult with hypothesis text, per-chunk delays, and timing data.
        On any failure, result.error is set and other fields may be partial.

    WHY we DON'T sleep at real-time pace:
    - The server's batching loop (300ms cycle) handles timing
    - We send chunks rapidly, the server buffers and processes in batches
    - We measure delay from the server's own audio_timestamp vs wall clock
    - This stress-tests the server at max throughput, not at 1x pace

    CHANGED(review): removed dead code — a `send_done` asyncio.Event that was
    set but never awaited, and an `n_chunks` counter that was never read.
    """
    result = StreamResult(
        sample_id=sample["id"],
        reference=sample["reference"],
        duration=sample["duration"],
    )
    pcm_bytes = sample["pcm_bytes"]
    all_text_deltas = []   # incremental "text" deltas (fallback hypothesis)
    final_segments = []    # finalized "segment_text" entries (preferred hypothesis)
    chunk_delays = []      # per-chunk: (receipt time since first send) - audio_timestamp

    t_wall_start = time.perf_counter()
    t_first_send = None
    t_first_transcript = None

    try:
        async with websockets.connect(
            WS_URL,
            max_size=10 * 1024 * 1024,
            open_timeout=300,   # generous: a Modal cold start can take minutes
            close_timeout=30,
            ping_interval=30,
            ping_timeout=60,
        ) as ws:
            # Handshake: the server announces readiness before accepting audio.
            ready_msg = await asyncio.wait_for(ws.recv(), timeout=300)
            ready = json.loads(ready_msg)
            if ready.get("type") != "Ready":
                result.error = f"Expected Ready, got: {ready}"
                return result

            # === SEND + RECEIVE CONCURRENTLY ===
            # WHY: We run sender and receiver as parallel async tasks so we can
            # measure per-chunk latency as transcripts arrive WHILE streaming.
            async def sender():
                nonlocal t_first_send
                # Send chunks with minimal delay — just enough to not overwhelm
                # the WebSocket frame buffer. Server handles batching.
                for offset in range(0, len(pcm_bytes), CHUNK_SIZE_BYTES):
                    await ws.send(pcm_bytes[offset:offset + CHUNK_SIZE_BYTES])
                    if t_first_send is None:
                        t_first_send = time.perf_counter()
                    # WHY: Small yield to let receiver task run + avoid saturating
                    # the WebSocket. 10ms << 560ms chunk, so server sees continuous flow.
                    await asyncio.sleep(0.01)

                # Brief pause for server to process last buffered frames
                await asyncio.sleep(1.0)
                await ws.send("END")

            async def receiver():
                nonlocal t_first_transcript
                while True:
                    try:
                        msg = await asyncio.wait_for(ws.recv(), timeout=15)
                        t_recv = time.perf_counter()

                        if isinstance(msg, str):
                            if msg == "END":
                                break
                            data = json.loads(msg)
                        elif isinstance(msg, bytes):
                            # Lazy import: msgpack is only needed for binary frames.
                            import msgpack
                            data = msgpack.unpackb(msg, raw=False)
                            if data.get("type") == "Marker":
                                break
                        else:
                            continue

                        text = data.get("text", "")
                        seg_text = data.get("segment_text", "")
                        audio_ts = data.get("audio_timestamp", 0.0)

                        if text or seg_text:
                            if t_first_transcript is None:
                                t_first_transcript = t_recv

                            # WHY: Per-chunk delay = wall_time_of_receipt - audio_timestamp
                            # This is the REAL ASR latency metric from the blog.
                            # audio_ts tracks cumulative audio duration processed.
                            if audio_ts > 0 and t_first_send:
                                elapsed_since_start = t_recv - t_first_send
                                chunk_delays.append(elapsed_since_start - audio_ts)

                        if text:
                            all_text_deltas.append(text)
                        if seg_text and data.get("is_final", False):
                            final_segments.append(seg_text)

                    except asyncio.TimeoutError:
                        # WHY: 15s of silence is treated as end-of-stream rather
                        # than an error — the server may have nothing left to emit.
                        break

            # Run sender and receiver in parallel
            await asyncio.gather(sender(), receiver())

    except Exception as e:
        result.error = str(e)
        result.wall_time = time.perf_counter() - t_wall_start
        return result

    t_wall_end = time.perf_counter()

    # Prefer finalized segments; fall back to concatenated incremental deltas.
    if final_segments:
        result.hypothesis = " ".join(final_segments).strip()
    elif all_text_deltas:
        result.hypothesis = "".join(all_text_deltas).strip()

    result.wall_time = t_wall_end - t_wall_start
    result.chunk_delays = chunk_delays
    result.first_transcript_delay = (
        (t_first_transcript - t_first_send) if (t_first_transcript and t_first_send) else 0.0
    )
    return result


# ─── Concurrent runner ────────────────────────────────────────────────────────
async def run_concurrent_batch(samples: List[Dict], concurrency: int) -> List[StreamResult]:
    tasks = []
    for i in range(min(concurrency, len(samples))):
        tasks.append(stream_one_sample(samples[i], session_id=i))

    print(f"\n  Streaming {len(tasks)} concurrent connections...")
    t0 = time.perf_counter()
    results = await asyncio.gather(*tasks, return_exceptions=True)
    wall = time.perf_counter() - t0

    final = []
    for r in results:
        if isinstance(r, Exception):
            final.append(StreamResult(error=str(r)))
        else:
            final.append(r)

    total_audio = sum(r.duration for r in final if r.error is None)
    throughput = total_audio / wall if wall > 0 else 0
    print(f"  Done in {wall:.1f}s | {total_audio:.0f}s audio | throughput: {throughput:.1f}x real-time")
    return final


# ─── Metrics ──────────────────────────────────────────────────────────────────
def compute_metrics(results: List[StreamResult], wall_time: float = 0) -> Dict:
    valid = [r for r in results if r.error is None and r.reference.strip()]
    if not valid:
        return {"error": "No valid results"}

    refs = [normalize_hindi(r.reference) for r in valid]
    hyps = [normalize_hindi(r.hypothesis) for r in valid]
    pairs = [(r, h) for r, h in zip(refs, hyps) if r.strip()]
    if not pairs:
        return {"error": "No non-empty references"}
    refs_f, hyps_f = zip(*pairs)

    corpus_wer = wer(list(refs_f), list(hyps_f))
    corpus_cer = cer(list(refs_f), list(hyps_f))

    # Aggregate per-chunk delays across ALL streams
    all_delays = []
    for r in valid:
        all_delays.extend(r.chunk_delays)
    
    first_tok_delays = [r.first_transcript_delay for r in valid if r.first_transcript_delay > 0]
    total_audio = sum(r.duration for r in valid)

    m = {
        "valid_samples": len(valid),
        "errors": len(results) - len(valid),
        "wer": round(corpus_wer * 100, 2),
        "cer": round(corpus_cer * 100, 2),
        "total_audio_sec": round(total_audio, 1),
    }
    
    # WHY: These are the metrics that matter — per-chunk ASR processing delay
    if all_delays:
        m["chunk_delay_median_ms"] = round(np.median(all_delays) * 1000, 1)
        m["chunk_delay_p95_ms"] = round(np.percentile(all_delays, 95) * 1000, 1)
        m["chunk_delay_p99_ms"] = round(np.percentile(all_delays, 99) * 1000, 1)
        m["chunk_delay_max_ms"] = round(max(all_delays) * 1000, 1)
        m["total_chunks"] = len(all_delays)
    
    if first_tok_delays:
        m["first_token_median_ms"] = round(np.median(first_tok_delays) * 1000, 1)
    
    return m


def save_results(concurrency: int, results: List[StreamResult], metrics: Dict):
    """Persist per-stream details (CSV) and the run summary (JSON) to RESULTS_DIR."""
    os.makedirs(RESULTS_DIR, exist_ok=True)
    tag = f"concurrent_{concurrency}"

    # Per-stream detail rows.
    csv_path = os.path.join(RESULTS_DIR, f"loadtest_{tag}_details.csv")
    header = ["id", "duration", "reference", "hypothesis", "wall_time",
              "first_transcript_delay", "n_chunks", "chunk_delay_median_ms", "error"]
    with open(csv_path, "w", newline="", encoding="utf-8") as f:
        writer = csv.writer(f)
        writer.writerow(header)
        for r in results:
            if r.chunk_delays:
                med = round(np.median(r.chunk_delays) * 1000, 1)
            else:
                med = ""
            writer.writerow([
                r.sample_id,
                f"{r.duration:.2f}",
                r.reference,
                r.hypothesis,
                f"{r.wall_time:.3f}",
                f"{r.first_transcript_delay:.3f}",
                len(r.chunk_delays),
                med,
                r.error or "",
            ])

    # Run-level summary.
    json_path = os.path.join(RESULTS_DIR, f"loadtest_{tag}_summary.json")
    summary = {"concurrency": concurrency, "ws_url": WS_URL, **metrics}
    with open(json_path, "w") as f:
        json.dump(summary, f, indent=2, ensure_ascii=False)
    print(f"  Saved: {csv_path}")


# ─── Main ─────────────────────────────────────────────────────────────────────
async def main_async(concurrency: int, max_samples: int):
    samples = load_fleurs_samples(max_samples=max_samples)

    print(f"\n{'='*70}")
    print(f"  Hindi ASR Latency Benchmark — {concurrency} concurrent streams")
    print(f"  Endpoint: {WS_URL}")
    print(f"  Chunk: {CHUNK_DURATION_S*1000:.0f}ms | Samples: {len(samples)}")
    print(f"{'='*70}")

    t0 = time.perf_counter()
    results = await run_concurrent_batch(samples, concurrency)
    wall = time.perf_counter() - t0
    metrics = compute_metrics(results, wall)
    
    total_audio = metrics.get('total_audio_sec', 0)
    throughput = total_audio / wall if wall > 0 else 0
    metrics["wall_time"] = round(wall, 2)
    metrics["throughput_rtx"] = round(throughput, 1)

    print(f"\n{'─'*70}")
    print(f"  RESULTS (concurrency={concurrency})")
    print(f"{'─'*70}")
    print(f"  Streams:     {metrics.get('valid_samples', 0)} ok, {metrics.get('errors', 0)} errors")
    print(f"  WER / CER:   {metrics.get('wer', 'N/A')}% / {metrics.get('cer', 'N/A')}%")
    print(f"  Throughput:   {throughput:.1f}x real-time ({total_audio:.0f}s audio in {wall:.1f}s)")
    cd_med = metrics.get('chunk_delay_median_ms', 'N/A')
    cd_p95 = metrics.get('chunk_delay_p95_ms', 'N/A')
    cd_p99 = metrics.get('chunk_delay_p99_ms', 'N/A')
    n_chunks = metrics.get('total_chunks', 0)
    print(f"  Chunk delay:  median={cd_med}ms  p95={cd_p95}ms  p99={cd_p99}ms  ({n_chunks} chunks)")
    print(f"  First token:  {metrics.get('first_token_median_ms', 'N/A')}ms median")
    print(f"{'─'*70}")

    save_results(concurrency, results, metrics)
    return metrics


def main():
    """CLI entry point: run a single concurrency level, or a sweep with --sweep."""
    parser = argparse.ArgumentParser(description="Hindi ASR latency benchmark")
    parser.add_argument("--concurrency", type=int, default=1)
    parser.add_argument("--max-samples", type=int, default=50)
    parser.add_argument("--sweep", action="store_true", help="Run 1,5,10,25,50 concurrent")
    args = parser.parse_args()

    # Single-level run: no summary table needed.
    if not args.sweep:
        asyncio.run(main_async(args.concurrency, max_samples=args.max_samples))
        return

    all_metrics = []
    for level in [1, 5, 10, 25, 50]:
        # Load at least `level` samples so every connection has audio to stream.
        m = asyncio.run(main_async(level, max_samples=max(level, args.max_samples)))
        m["concurrency"] = level
        all_metrics.append(m)
        time.sleep(2)  # brief pause to let the server drain between runs

    bar = "=" * 80
    print(f"\n{bar}")
    print("  CONCURRENCY SWEEP SUMMARY")
    print(bar)
    print(f"  {'Conc':>6} {'Strms':>6} {'WER%':>7} {'CER%':>7} "
          f"{'ChunkMed':>10} {'ChunkP95':>10} {'ChunkP99':>10} {'1stTok':>8} {'Thruput':>8}")
    print(f"  {'─'*6} {'─'*6} {'─'*7} {'─'*7} {'─'*10} {'─'*10} {'─'*10} {'─'*8} {'─'*8}")
    for m in all_metrics:
        print(f"  {m.get('concurrency',0):>6} {m.get('valid_samples',0):>6} "
              f"{m.get('wer','?'):>6}% {m.get('cer','?'):>6}% "
              f"{m.get('chunk_delay_median_ms','?'):>9}ms "
              f"{m.get('chunk_delay_p95_ms','?'):>9}ms "
              f"{m.get('chunk_delay_p99_ms','?'):>9}ms "
              f"{m.get('first_token_median_ms','?'):>7}ms "
              f"{m.get('throughput_rtx','?'):>7}x")
    print(bar)


# Standard script entry point.
if __name__ == "__main__":
    main()
