"""
Concurrent WebSocket load test for Hindi Nemotron ASR on Modal.
Streams FLEURS Hindi audio to N parallel connections, collects transcriptions,
computes WER/CER, and reports latency metrics.

Usage:
  python load_test.py --concurrency 1     # single-stream baseline
  python load_test.py --concurrency 10    # 10 parallel streams
  python load_test.py --concurrency 50    # 50 parallel streams
"""

import os
import sys
import json
import time
import asyncio
import argparse
import csv
from dataclasses import dataclass, field
from typing import List, Dict, Optional

import numpy as np
import soundfile as sf
from scipy.signal import resample_poly
from math import gcd
from datasets import load_dataset
from jiwer import wer, cer

try:
    import websockets
except ImportError:
    print("Installing websockets...")
    os.system(f"{sys.executable} -m pip install websockets -q")
    import websockets

# ─── Config ──────────────────────────────────────────────────────────────────
# Deployed Modal WebSocket endpoint for the Hindi Nemotron ASR service.
WS_URL = "wss://mayaresearch--hindi-nemotron-asr-nemotronasr-webapp.modal.run/ws"
# Sample rate the server expects; all source audio is resampled to this.
TARGET_SR = 16000
CHUNK_DURATION_S = 0.56  # 560ms chunks (matches att_context_size [70,6])
CHUNK_SIZE_BYTES = int(CHUNK_DURATION_S * TARGET_SR * 2)  # 16-bit PCM = 2 bytes/sample
# Output directory for per-run CSV/JSON artifacts, created next to this script.
RESULTS_DIR = os.path.join(os.path.dirname(os.path.abspath(__file__)), "benchmark_results")


# ─── Hindi text normalization ────────────────────────────────────────────────
def normalize_hindi(text: str) -> str:
    """Normalize text for WER/CER scoring.

    Strips Devanagari danda and common punctuation, collapses runs of
    whitespace, and lowercases (lowercasing only affects embedded Latin
    text; Devanagari has no case).
    """
    import re
    if not text:
        return ""
    without_punct = re.sub(r'[।,\.\?\!;:\-\—\–\"\'\(\)\[\]\{\}]', ' ', text)
    collapsed = re.sub(r'\s+', ' ', without_punct)
    return collapsed.strip().lower()


# ─── Audio helpers ───────────────────────────────────────────────────────────
def audio_to_pcm_bytes(audio_array: np.ndarray, orig_sr: int) -> bytes:
    """Convert audio array → 16kHz 16-bit PCM bytes for WebSocket streaming.

    Resamples with a polyphase filter when the source rate differs from
    TARGET_SR, then clips to [-1, 1] before the int16 conversion: polyphase
    resampling can overshoot the original amplitude range, and out-of-range
    values would otherwise wrap around in int16 and corrupt the audio.

    Args:
        audio_array: float audio samples, nominally in [-1, 1].
        orig_sr: sample rate of `audio_array` in Hz.

    Returns:
        Little-endian 16-bit PCM bytes at TARGET_SR.
    """
    if orig_sr != TARGET_SR:
        g = gcd(orig_sr, TARGET_SR)
        up, down = TARGET_SR // g, orig_sr // g
        audio_array = resample_poly(audio_array, up, down)
    # Float32 [-1,1] → int16, clipping first to avoid integer wraparound
    pcm = (np.clip(audio_array, -1.0, 1.0) * 32767).astype(np.int16)
    return pcm.tobytes()


# ─── Load FLEURS Hindi ───────────────────────────────────────────────────────
def load_fleurs_samples(max_samples: Optional[int] = None) -> List[Dict]:
    """Load the FLEURS Hindi test split and pre-convert audio for streaming.

    Args:
        max_samples: cap on the number of samples to load; None loads all.
            (Annotated Optional[int] — the previous bare `int = None` was an
            invalid implicit-Optional annotation.)

    Returns:
        List of dicts with keys:
          "id"        — dataset sample id (falls back to the row index),
          "pcm_bytes" — 16kHz 16-bit PCM bytes ready to stream,
          "reference" — ground-truth transcription,
          "duration"  — audio length in seconds, computed from the PCM.
    """
    print("Loading FLEURS Hindi test set...")
    ds = load_dataset("google/fleurs", "hi_in", split="test")
    total = min(len(ds), max_samples) if max_samples else len(ds)

    samples = []
    for i in range(total):
        item = ds[i]
        audio = item["audio"]
        arr = np.array(audio["array"], dtype=np.float32)
        sr = audio["sampling_rate"]
        pcm = audio_to_pcm_bytes(arr, sr)
        # Duration from the resampled PCM: 2 bytes per sample at TARGET_SR.
        duration = len(pcm) / (TARGET_SR * 2)
        samples.append({
            "id": item.get("id", i),
            "pcm_bytes": pcm,
            "reference": item["transcription"],
            "duration": duration,
        })
    print(f"  Loaded {len(samples)} samples ({sum(s['duration'] for s in samples):.0f}s audio)")
    return samples


# ─── Single WebSocket session ─────────────────────────────────────────────────
@dataclass
class StreamResult:
    """Outcome of a single WebSocket streaming session against the ASR endpoint."""
    sample_id: int = 0                # dataset sample id (or row index)
    reference: str = ""               # ground-truth transcription
    hypothesis: str = ""              # transcription assembled from server messages
    duration: float = 0.0             # audio length in seconds
    latency_first_token: float = 0.0  # time to first transcription result
    latency_total: float = 0.0        # total wall time for this stream
    error: Optional[str] = None       # error message; None on success


async def stream_one_sample(sample: Dict, session_id: int) -> StreamResult:
    """Stream one audio sample over WebSocket and collect transcription.

    Sends 16-bit PCM chunks at ~real-time pace, signals "END", then drains
    transcription messages (JSON text frames or msgpack binary frames) until
    the server echoes END, sends a Marker, or the socket goes idle.

    Args:
        sample: dict with "id", "pcm_bytes", "reference", "duration".
        session_id: connection slot index (currently unused; kept for log
            correlation by callers).

    Returns:
        StreamResult with hypothesis, latency metrics, and error (if any).
    """
    result = StreamResult(
        sample_id=sample["id"],
        reference=sample["reference"],
        duration=sample["duration"],
    )

    pcm_bytes = sample["pcm_bytes"]
    t_start = time.perf_counter()
    first_token_time = None
    # BUG FIX: these accumulators were previously declared under different
    # names (all_text_deltas / final_segments) than the ones the receive loop
    # used, so every stream died with a NameError and reported an error.
    transcription_parts = []  # completed (is_final) segment texts, in order
    last_partial = ""         # most recent partial text; fallback hypothesis

    try:
        async with websockets.connect(
            WS_URL,
            max_size=10 * 1024 * 1024,
            open_timeout=300,   # 5 min for Modal cold start (model download + GPU warmup)
            close_timeout=30,
            ping_interval=30,
            ping_timeout=60,
        ) as ws:
            # Wait for Ready message (long timeout for cold start)
            ready_msg = await asyncio.wait_for(ws.recv(), timeout=300)
            ready = json.loads(ready_msg)
            if ready.get("type") != "Ready":
                result.error = f"Expected Ready, got: {ready}"
                return result

            # Stream audio chunks at real-time pace
            for offset in range(0, len(pcm_bytes), CHUNK_SIZE_BYTES):
                chunk = pcm_bytes[offset:offset + CHUNK_SIZE_BYTES]
                await ws.send(chunk)
                # WHY: Must stream at ~real-time to let server's batching loop
                # (300ms delay) process chunks properly. Too fast = lost context.
                await asyncio.sleep(CHUNK_DURATION_S * 0.9)

            # Give server time to process remaining buffered audio before END
            await asyncio.sleep(2.0)
            await ws.send("END")

            # Collect transcription results until END echo, Marker, or idle timeout
            while True:
                try:
                    msg = await asyncio.wait_for(ws.recv(), timeout=10)
                except asyncio.TimeoutError:
                    break

                data = None
                if isinstance(msg, str):
                    if msg == "END":
                        break
                    data = json.loads(msg)
                elif isinstance(msg, bytes):
                    import msgpack  # local import: only needed for binary frames
                    data = msgpack.unpackb(msg, raw=False)
                    if data.get("type") == "Marker":
                        break
                if data is None:
                    continue  # frame of unknown type; ignore

                text = data.get("text", "")
                # segment_text carries the full completed segment when present
                seg_text = data.get("segment_text", "")
                if (text or seg_text) and first_token_time is None:
                    first_token_time = time.perf_counter()
                if data.get("is_final", False):
                    transcription_parts.append(seg_text or text)
                elif text:
                    last_partial = text  # track latest partial for fallback

    except Exception as e:
        result.error = str(e)
        result.latency_total = time.perf_counter() - t_start
        return result

    t_end = time.perf_counter()
    # Prefer joined final segments; fall back to the last partial transcript
    if transcription_parts:
        result.hypothesis = " ".join(transcription_parts).strip()
    else:
        result.hypothesis = last_partial.strip()
    result.latency_total = t_end - t_start
    result.latency_first_token = (first_token_time - t_start) if first_token_time else result.latency_total

    return result


# ─── Concurrent batch runner ─────────────────────────────────────────────────
async def run_concurrent_batch(
    samples: List[Dict],
    concurrency: int,
) -> List[StreamResult]:
    """Open up to `concurrency` parallel WebSocket streams, one sample per slot.

    Only the first min(concurrency, len(samples)) samples are streamed;
    exceptions raised inside a stream become error-only StreamResults.
    """
    slot_count = min(concurrency, len(samples))
    tasks = [stream_one_sample(samples[slot], session_id=slot) for slot in range(slot_count)]

    print(f"\n  Streaming {len(tasks)} samples with {concurrency} concurrent connections...")
    t0 = time.perf_counter()
    raw_results = await asyncio.gather(*tasks, return_exceptions=True)
    elapsed = time.perf_counter() - t0

    # Convert raised exceptions into error-carrying results
    final_results = [
        StreamResult(error=str(item)) if isinstance(item, Exception) else item
        for item in raw_results
    ]

    total_audio = sum(r.duration for r in final_results if r.error is None)
    print(f"  Completed in {elapsed:.1f}s ({total_audio:.0f}s audio, {len(tasks)} streams)")

    return final_results


# ─── Metrics ──────────────────────────────────────────────────────────────────
def compute_metrics(results: List[StreamResult]) -> Dict:
    """Compute corpus WER/CER and latency percentiles over successful streams.

    Streams with an error or an empty reference are excluded; references
    that become empty after normalization are dropped from WER/CER (their
    latencies still count).
    """
    usable = [r for r in results if r.error is None and r.reference.strip()]
    if not usable:
        return {"error": "No valid results"}

    normalized = [
        (normalize_hindi(r.reference), normalize_hindi(r.hypothesis))
        for r in usable
    ]

    # Drop pairs whose normalized reference is empty
    scored = [(ref, hyp) for ref, hyp in normalized if ref.strip()]
    if not scored:
        return {"error": "No non-empty references"}
    ref_list = [ref for ref, _ in scored]
    hyp_list = [hyp for _, hyp in scored]

    corpus_wer = wer(ref_list, hyp_list)
    corpus_cer = cer(ref_list, hyp_list)

    total_lats = [r.latency_total for r in usable]
    first_lats = [r.latency_first_token for r in usable]

    return {
        "valid_samples": len(usable),
        "errors": len(results) - len(usable),
        "wer": round(corpus_wer * 100, 2),
        "cer": round(corpus_cer * 100, 2),
        "latency_median": round(np.median(total_lats), 3),
        "latency_p95": round(np.percentile(total_lats, 95), 3),
        "latency_first_token_median": round(np.median(first_lats), 3),
        "total_audio_sec": round(sum(r.duration for r in usable), 1),
    }


# ─── Save results ────────────────────────────────────────────────────────────
def save_results(concurrency: int, results: List[StreamResult], metrics: Dict):
    """Persist one run's artifacts under RESULTS_DIR.

    Writes a per-sample details CSV and a summary JSON, both tagged with
    the concurrency level.
    """
    os.makedirs(RESULTS_DIR, exist_ok=True)
    tag = f"concurrent_{concurrency}"

    # Per-sample CSV
    csv_path = os.path.join(RESULTS_DIR, f"loadtest_{tag}_details.csv")
    with open(csv_path, "w", newline="", encoding="utf-8") as f:
        writer = csv.writer(f)
        writer.writerow(["id", "duration", "reference", "hypothesis", "latency_total", "latency_first_token", "error"])
        for res in results:
            writer.writerow([
                res.sample_id,
                f"{res.duration:.2f}",
                res.reference,
                res.hypothesis,
                f"{res.latency_total:.3f}",
                f"{res.latency_first_token:.3f}",
                res.error or "",
            ])

    # Summary JSON
    summary = {
        "concurrency": concurrency,
        "ws_url": WS_URL,
        **metrics,
    }
    json_path = os.path.join(RESULTS_DIR, f"loadtest_{tag}_summary.json")
    with open(json_path, "w") as f:
        json.dump(summary, f, indent=2, ensure_ascii=False)

    print(f"  Saved: {csv_path}")
    print(f"  Saved: {json_path}")


# ─── Main ─────────────────────────────────────────────────────────────────────
async def main_async(concurrency: int, max_samples: int):
    """Run one load-test pass: load data, stream concurrently, report, save.

    Returns the metrics dict from compute_metrics (may be an error dict).
    """
    samples = load_fleurs_samples(max_samples=max_samples)

    banner = "=" * 70
    print(f"\n{banner}")
    print(f"  Hindi ASR Load Test — {concurrency} concurrent connections")
    print(f"  Endpoint: {WS_URL}")
    print(f"  Dataset: FLEURS Hindi test ({len(samples)} samples)")
    print(banner)

    stream_results = await run_concurrent_batch(samples, concurrency)
    stats = compute_metrics(stream_results)

    rule = "─" * 70
    print(f"\n{rule}")
    print(f"  RESULTS (concurrency={concurrency})")
    print(rule)
    print(f"  Samples:   {stats.get('valid_samples', 0)} valid, {stats.get('errors', 0)} errors")
    print(f"  WER:       {stats.get('wer', 'N/A')}%")
    print(f"  CER:       {stats.get('cer', 'N/A')}%")
    print(f"  Latency:   median={stats.get('latency_median', 'N/A')}s  p95={stats.get('latency_p95', 'N/A')}s")
    print(f"  First tok: median={stats.get('latency_first_token_median', 'N/A')}s")
    print(f"  Audio:     {stats.get('total_audio_sec', 0)}s total")
    print(rule)

    save_results(concurrency, stream_results, stats)
    return stats


def main():
    """CLI entry point: run a single load test or a concurrency sweep."""
    parser = argparse.ArgumentParser(description="Hindi ASR concurrent load test")
    parser.add_argument("--concurrency", type=int, default=1, help="Number of parallel WebSocket streams")
    parser.add_argument("--max-samples", type=int, default=50, help="Max FLEURS samples to use")
    parser.add_argument("--sweep", action="store_true", help="Run sweep: 1, 5, 10, 25, 50 concurrent")
    args = parser.parse_args()

    # Single-level run: guard clause keeps the sweep logic unindented below
    if not args.sweep:
        asyncio.run(main_async(args.concurrency, max_samples=args.max_samples))
        return

    all_metrics = []
    for level in [1, 5, 10, 25, 50]:
        # Ensure at least one sample per connection at high concurrency
        metrics = asyncio.run(main_async(level, max_samples=max(level, args.max_samples)))
        metrics["concurrency"] = level
        all_metrics.append(metrics)
        time.sleep(2)  # brief cooldown between sweeps

    bar = "=" * 70
    print(f"\n{bar}")
    print("  SWEEP SUMMARY")
    print(bar)
    print(f"  {'Concur':>8} {'Samples':>8} {'WER%':>8} {'CER%':>8} {'Lat(med)':>10} {'Lat(p95)':>10}")
    print(f"  {'─'*8} {'─'*8} {'─'*8} {'─'*8} {'─'*10} {'─'*10}")
    for metrics in all_metrics:
        print(f"  {metrics['concurrency']:>8} {metrics.get('valid_samples',0):>8} {metrics.get('wer','N/A'):>7}% "
              f"{metrics.get('cer','N/A'):>7}% {metrics.get('latency_median','N/A'):>9}s "
              f"{metrics.get('latency_p95','N/A'):>9}s")
    print(bar)


# Script entry point: parse CLI arguments and launch the load test.
if __name__ == "__main__":
    main()
