"""
Concurrent WebSocket load test for Hindi Nemotron ASR on Modal.
Streams FLEURS Hindi audio to N parallel connections, collects transcriptions,
computes WER/CER, and reports latency metrics.

Usage:
  python load_test.py --concurrency 1     # single-stream baseline
  python load_test.py --concurrency 10    # 10 parallel streams
  python load_test.py --concurrency 50    # 50 parallel streams
  python load_test.py --sweep             # run 1, 5, 10, 25, 50 concurrent in sequence
"""

import os
import sys
import json
import time
import asyncio
import argparse
import csv
from dataclasses import dataclass, field
from typing import List, Dict, Optional

import numpy as np
import soundfile as sf
from scipy.signal import resample_poly
from math import gcd
from datasets import load_dataset
from jiwer import wer, cer

try:
    import websockets
except ImportError:
    print("Installing websockets...")
    os.system(f"{sys.executable} -m pip install websockets -q")
    import websockets

# ─── Config ──────────────────────────────────────────────────────────────────
WS_URL = "wss://mayaresearch--hindi-nemotron-asr-nemotronasr-webapp.modal.run/ws"

# Audio is resampled to TARGET_SR and sent as int16 PCM (2 bytes per sample).
TARGET_SR = 16000  # Hz
CHUNK_DURATION_S = 0.56  # 560 ms per chunk (matches att_context_size [70,6])
CHUNK_SIZE_BYTES = int(TARGET_SR * CHUNK_DURATION_S * 2)  # bytes per streamed chunk

# Output files are written next to this script.
RESULTS_DIR = os.path.join(
    os.path.dirname(os.path.abspath(__file__)),
    "benchmark_results",
)


# ─── Hindi text normalization ────────────────────────────────────────────────
def normalize_hindi(text: str) -> str:
    """Normalize Hindi (or mixed Hindi/Latin) text before WER/CER scoring.

    Steps:
      1. Unicode NFC normalization, so that canonically equivalent spellings
         (e.g. the nukta letter U+0958 vs. U+0915 U+093C) compare equal
         between reference and hypothesis instead of counting as errors.
      2. Replace punctuation (danda and double danda, ASCII punctuation,
         em/en dashes, straight and curly quotes, ellipsis, brackets) with
         spaces.
      3. Collapse runs of whitespace, strip, and lowercase (lowercasing only
         affects cased scripts such as Latin).

    Returns "" for empty or None input.
    """
    import re
    import unicodedata

    if not text:
        return ""
    text = unicodedata.normalize("NFC", text)
    text = re.sub(r'[।॥,\.\?\!;:\-\—\–\"\'\(\)\[\]\{\}“”‘’…]', ' ', text)
    text = re.sub(r'\s+', ' ', text).strip().lower()
    return text


# ─── Audio helpers ───────────────────────────────────────────────────────────
def audio_to_pcm_bytes(
    audio_array: np.ndarray,
    orig_sr: int,
    target_sr: Optional[int] = None,
) -> bytes:
    """Convert a float waveform into 16-bit PCM bytes for WebSocket streaming.

    Args:
        audio_array: float waveform, nominally in [-1, 1]. (Assumed 1-D/mono,
            as FLEURS provides — TODO confirm before using other datasets.)
        orig_sr: sampling rate of ``audio_array`` in Hz.
        target_sr: output sampling rate in Hz; defaults to the module-level
            TARGET_SR.

    Returns:
        The int16 samples (native byte order) as raw bytes.
    """
    if target_sr is None:
        target_sr = TARGET_SR
    if orig_sr != target_sr:
        g = gcd(orig_sr, target_sr)
        audio_array = resample_poly(audio_array, target_sr // g, orig_sr // g)
    # Polyphase resampling can overshoot [-1, 1] slightly (filter ringing), and
    # a float->int16 cast of an out-of-range value does not saturate (the result
    # is implementation-defined and typically wraps, producing loud clicks).
    # Clip first so full-scale audio maps to +/-32767.
    clipped = np.clip(audio_array, -1.0, 1.0)
    pcm = (clipped * 32767).astype(np.int16)
    return pcm.tobytes()


# ─── Load FLEURS Hindi ───────────────────────────────────────────────────────
def load_fleurs_samples(max_samples: int = None) -> List[Dict]:
    """Load the FLEURS Hindi test split and pre-encode each clip for streaming.

    Args:
        max_samples: cap on the number of clips taken from the start of the
            split; a falsy value (None or 0) loads the whole split.

    Returns:
        One dict per clip with keys "id", "pcm_bytes" (int16 PCM at
        TARGET_SR), "reference" (the dataset's "transcription" field) and
        "duration" (seconds, derived from the PCM byte length).
    """
    print("Loading FLEURS Hindi test set...")
    dataset = load_dataset("google/fleurs", "hi_in", split="test")
    count = len(dataset)
    if max_samples:
        count = min(count, max_samples)

    bytes_per_second = TARGET_SR * 2  # 2 bytes per int16 sample
    loaded: List[Dict] = []
    for idx in range(count):
        row = dataset[idx]
        clip = row["audio"]
        waveform = np.array(clip["array"], dtype=np.float32)
        pcm = audio_to_pcm_bytes(waveform, clip["sampling_rate"])
        loaded.append(
            {
                "id": row.get("id", idx),
                "pcm_bytes": pcm,
                "reference": row["transcription"],
                "duration": len(pcm) / bytes_per_second,
            }
        )

    audio_seconds = sum(entry["duration"] for entry in loaded)
    print(f"  Loaded {len(loaded)} samples ({audio_seconds:.0f}s audio)")
    return loaded


# ─── Single WebSocket session ─────────────────────────────────────────────────
@dataclass
class StreamResult:
    """Outcome of streaming one audio sample through one WebSocket session.

    Every field has a default so a result can be filled in incrementally, or
    created with only ``error`` set when the stream never produced output.
    """

    # identifier of the sample that was streamed
    sample_id: int = field(default=0)
    # ground-truth transcript and the transcript the server returned
    reference: str = field(default="")
    hypothesis: str = field(default="")
    # length of the streamed audio, in seconds
    duration: float = field(default=0.0)
    # seconds until the first non-empty transcription result
    latency_first_token: float = field(default=0.0)
    # total wall-clock seconds for this stream
    latency_total: float = field(default=0.0)
    # None on success, otherwise a description of what went wrong
    error: Optional[str] = field(default=None)


async def _receive_transcripts(ws, end_sent: asyncio.Event):
    """Read server messages until the stream's end marker; return (finals, first_text_time).

    ``finals`` holds the ``text`` of every message flagged ``is_final`` in
    arrival order. ``first_text_time`` is the ``time.perf_counter()`` value
    at which the first message with non-empty text was read (None if none).
    A 30 s gap between messages ends collection only once ``end_sent`` is
    set; silence while audio is still being uploaded is tolerated.
    """
    finals: List[str] = []
    first_text_time: Optional[float] = None
    while True:
        try:
            msg = await asyncio.wait_for(ws.recv(), timeout=30)
        except asyncio.TimeoutError:
            if end_sent.is_set():
                # Best effort: stop waiting if the server goes quiet after END.
                break
            continue
        if isinstance(msg, str):
            if msg == "END":  # END echo from the server terminates the stream
                break
            data = json.loads(msg)
        else:
            # Binary frames are msgpack; import lazily so msgpack is only
            # required when the server actually sends binary replies.
            import msgpack
            data = msgpack.unpackb(msg, raw=False)
            if data.get("type") == "Marker":
                break
        text = data.get("text", "")
        if text and first_text_time is None:
            first_text_time = time.perf_counter()
        # Partial (non-final) results only contribute to first-token latency;
        # the hypothesis is assembled from final segments.
        if data.get("is_final", False):
            finals.append(text)
    return finals, first_text_time


async def stream_one_sample(sample: Dict, session_id: int) -> StreamResult:
    """Stream one audio sample over a WebSocket and collect its transcription.

    After the server's "Ready" message, audio is sent in CHUNK_SIZE_BYTES
    pieces paced at 3x real time, followed by an "END" text frame. Server
    messages are consumed by a concurrent task *while* audio is uploading, so
    ``latency_first_token`` reflects when text actually arrived rather than
    when the upload finished. Both latencies are measured from before the
    connection is opened.

    Args:
        sample: dict with "id", "pcm_bytes", "reference" and "duration" keys.
        session_id: index of the connection slot (not used by this function).

    Returns:
        A StreamResult; failures are reported via ``error`` instead of raising.
    """
    result = StreamResult(
        sample_id=sample["id"],
        reference=sample["reference"],
        duration=sample["duration"],
    )

    pcm_bytes = sample["pcm_bytes"]
    t_start = time.perf_counter()
    first_token_time: Optional[float] = None
    transcription_parts: List[str] = []

    try:
        async with websockets.connect(
            WS_URL,
            max_size=10 * 1024 * 1024,
            open_timeout=30,
            close_timeout=10,
            ping_interval=20,
            ping_timeout=20,
        ) as ws:
            # Wait for the server's Ready message before sending audio.
            ready_msg = await asyncio.wait_for(ws.recv(), timeout=30)
            ready = json.loads(ready_msg)
            if ready.get("type") != "Ready":
                result.error = f"Expected Ready, got: {ready}"
                result.latency_total = time.perf_counter() - t_start
                return result

            end_sent = asyncio.Event()
            receiver = asyncio.create_task(_receive_transcripts(ws, end_sent))
            try:
                for offset in range(0, len(pcm_bytes), CHUNK_SIZE_BYTES):
                    if receiver.done():
                        # Receiver failed or the stream already ended; stop uploading.
                        break
                    await ws.send(pcm_bytes[offset:offset + CHUNK_SIZE_BYTES])
                    # Pace at 3x real time to keep the test fast while still streaming.
                    await asyncio.sleep(CHUNK_DURATION_S / 3)
                if not receiver.done():
                    await ws.send("END")  # signal end of audio
                    end_sent.set()
                # Re-raises any exception from the receiver task.
                transcription_parts, first_token_time = await receiver
            finally:
                # Never leak the receiver task, and always retrieve its outcome.
                if not receiver.done():
                    receiver.cancel()
                await asyncio.gather(receiver, return_exceptions=True)

    except Exception as e:
        # str() of some exceptions (e.g. asyncio.TimeoutError) is empty, so keep the type name.
        result.error = f"{type(e).__name__}: {e}"
        result.latency_total = time.perf_counter() - t_start
        return result

    t_end = time.perf_counter()
    result.hypothesis = " ".join(transcription_parts).strip()
    result.latency_total = t_end - t_start
    # Fall back to total latency when the server never returned any text.
    result.latency_first_token = (
        first_token_time - t_start if first_token_time is not None else result.latency_total
    )
    return result


# ─── Concurrent batch runner ─────────────────────────────────────────────────
async def run_concurrent_batch(
    samples: List[Dict],
    concurrency: int,
) -> List[StreamResult]:
    """Stream every sample using up to ``concurrency`` simultaneous connections.

    Samples are assigned round-robin to ``N = min(concurrency, len(samples))``
    slots (at least one): slot k handles samples k, k+N, k+2N, ... one after
    another. At most N sessions are open at a time, and every loaded sample is
    transcribed, so different concurrency levels are scored on the same
    sample set. Results are returned in the same order as ``samples``.

    A sample whose stream raises is recorded as a StreamResult with ``error``
    set; the rest of that slot's samples still run.
    """
    n_slots = max(1, min(concurrency, len(samples)))
    results: List[Optional[StreamResult]] = [None] * len(samples)

    async def _run_slot(slot: int) -> None:
        # Work through this slot's share of the samples sequentially.
        for idx in range(slot, len(samples), n_slots):
            sample = samples[idx]
            try:
                results[idx] = await stream_one_sample(sample, session_id=slot)
            except Exception as e:
                results[idx] = StreamResult(
                    sample_id=sample.get("id", idx),
                    reference=sample.get("reference", ""),
                    duration=sample.get("duration", 0.0),
                    error=f"{type(e).__name__}: {e}",
                )

    print(f"\n  Streaming {len(samples)} samples with {concurrency} concurrent connections...")
    t0 = time.perf_counter()
    await asyncio.gather(*(_run_slot(k) for k in range(n_slots)))
    elapsed = time.perf_counter() - t0

    final_results: List[StreamResult] = [r for r in results if r is not None]

    total_audio = sum(r.duration for r in final_results if r.error is None)
    print(f"  Completed in {elapsed:.1f}s ({total_audio:.0f}s audio, {len(final_results)} streams)")

    return final_results


# ─── Metrics ──────────────────────────────────────────────────────────────────
def compute_metrics(results: List[StreamResult]) -> Dict:
    """Summarize accuracy (corpus WER/CER, in percent) and latency statistics.

    A result is valid when it has no error and a non-blank reference; all
    other results are counted under "errors". When nothing can be scored,
    a dict containing only an "error" message is returned instead.
    """
    valid = []
    for res in results:
        if res.error is None and res.reference.strip():
            valid.append(res)
    if len(valid) == 0:
        return {"error": "No valid results"}

    # Normalize both sides, then keep only pairs whose normalized reference
    # is still non-empty (e.g. punctuation-only references drop out here).
    scored_refs: List[str] = []
    scored_hyps: List[str] = []
    for res in valid:
        ref_norm = normalize_hindi(res.reference)
        hyp_norm = normalize_hindi(res.hypothesis)
        if ref_norm.strip():
            scored_refs.append(ref_norm)
            scored_hyps.append(hyp_norm)
    if not scored_refs:
        return {"error": "No non-empty references"}

    corpus_wer = wer(scored_refs, scored_hyps)
    corpus_cer = cer(scored_refs, scored_hyps)

    total_lat = [res.latency_total for res in valid]
    first_lat = [res.latency_first_token for res in valid]

    return {
        "valid_samples": len(valid),
        "errors": len(results) - len(valid),
        "wer": round(corpus_wer * 100, 2),
        "cer": round(corpus_cer * 100, 2),
        "latency_median": round(np.median(total_lat), 3),
        "latency_p95": round(np.percentile(total_lat, 95), 3),
        "latency_first_token_median": round(np.median(first_lat), 3),
        "total_audio_sec": round(sum(res.duration for res in valid), 1),
    }


# ─── Save results ────────────────────────────────────────────────────────────
def save_results(concurrency: int, results: List[StreamResult], metrics: Dict):
    """Write per-sample details (CSV) and a run summary (JSON) under RESULTS_DIR.

    Files are named ``loadtest_concurrent_<N>_details.csv`` and
    ``loadtest_concurrent_<N>_summary.json``; re-running at the same
    concurrency overwrites them.
    """
    os.makedirs(RESULTS_DIR, exist_ok=True)
    tag = f"concurrent_{concurrency}"

    # Per-sample CSV; UTF-8 so Devanagari references/hypotheses round-trip.
    csv_path = os.path.join(RESULTS_DIR, f"loadtest_{tag}_details.csv")
    with open(csv_path, "w", newline="", encoding="utf-8") as f:
        w = csv.writer(f)
        w.writerow(["id", "duration", "reference", "hypothesis", "latency_total", "latency_first_token", "error"])
        for r in results:
            w.writerow([r.sample_id, f"{r.duration:.2f}", r.reference, r.hypothesis,
                        f"{r.latency_total:.3f}", f"{r.latency_first_token:.3f}", r.error or ""])

    # Summary JSON. ensure_ascii=False writes non-ASCII characters verbatim,
    # so pin the file encoding to UTF-8 instead of the platform default.
    summary = {
        "concurrency": concurrency,
        "ws_url": WS_URL,
        **metrics,
    }
    json_path = os.path.join(RESULTS_DIR, f"loadtest_{tag}_summary.json")
    with open(json_path, "w", encoding="utf-8") as f:
        json.dump(summary, f, indent=2, ensure_ascii=False)

    print(f"  Saved: {csv_path}")
    print(f"  Saved: {json_path}")


# ─── Main ─────────────────────────────────────────────────────────────────────
async def main_async(concurrency: int, max_samples: int):
    """Run one load-test pass at a fixed concurrency, print a report, and save it.

    Loads up to ``max_samples`` FLEURS clips, streams them through
    run_concurrent_batch, prints WER/CER/latency, and writes CSV/JSON files
    via save_results. Returns the metrics dict from compute_metrics.
    """
    samples = load_fleurs_samples(max_samples=max_samples)

    print(f"\n{'='*70}")
    print(f"  Hindi ASR Load Test — {concurrency} concurrent connections")
    print(f"  Endpoint: {WS_URL}")
    print(f"  Dataset: FLEURS Hindi test ({len(samples)} samples)")
    print(f"{'='*70}")

    results = await run_concurrent_batch(samples, concurrency)
    metrics = compute_metrics(results)

    # compute_metrics returns only {"error": ...} when nothing is scorable; in
    # that case report every stream as unscorable instead of printing "0 errors".
    n_errors = metrics.get("errors", len(results))
    first_error = next((r.error for r in results if r.error), None)

    print(f"\n{'─'*70}")
    print(f"  RESULTS (concurrency={concurrency})")
    print(f"{'─'*70}")
    if "error" in metrics:
        print(f"  ERROR:     {metrics['error']}")
    print(f"  Samples:   {metrics.get('valid_samples', 0)} valid, {n_errors} errors")
    if first_error:
        # Show one concrete failure so the report is actionable.
        print(f"  1st error: {first_error}")
    print(f"  WER:       {metrics.get('wer', 'N/A')}%")
    print(f"  CER:       {metrics.get('cer', 'N/A')}%")
    print(f"  Latency:   median={metrics.get('latency_median', 'N/A')}s  p95={metrics.get('latency_p95', 'N/A')}s")
    print(f"  First tok: median={metrics.get('latency_first_token_median', 'N/A')}s")
    print(f"  Audio:     {metrics.get('total_audio_sec', 0)}s total")
    print(f"{'─'*70}")

    save_results(concurrency, results, metrics)
    return metrics


def main():
    """Command-line entry point: run a single load test or a concurrency sweep."""
    parser = argparse.ArgumentParser(description="Hindi ASR concurrent load test")
    parser.add_argument("--concurrency", type=int, default=1, help="Number of parallel WebSocket streams")
    parser.add_argument("--max-samples", type=int, default=50, help="Max FLEURS samples to use")
    parser.add_argument("--sweep", action="store_true", help="Run sweep: 1, 5, 10, 25, 50 concurrent")
    args = parser.parse_args()

    # Reject values that would otherwise produce an empty or meaningless run.
    if args.concurrency < 1:
        parser.error("--concurrency must be at least 1")
    if args.max_samples < 0:
        parser.error("--max-samples must be non-negative (0 loads the full test set)")

    if args.sweep:
        all_metrics = []
        for c in [1, 5, 10, 25, 50]:
            # Load at least c samples so every connection slot has work.
            m = asyncio.run(main_async(c, max_samples=max(c, args.max_samples)))
            m["concurrency"] = c
            all_metrics.append(m)
            time.sleep(2)  # brief cooldown between sweeps

        print(f"\n{'='*70}")
        print("  SWEEP SUMMARY")
        print(f"{'='*70}")
        print(f"  {'Concur':>8} {'Samples':>8} {'WER%':>8} {'CER%':>8} {'Lat(med)':>10} {'Lat(p95)':>10}")
        print(f"  {'─'*8} {'─'*8} {'─'*8} {'─'*8} {'─'*10} {'─'*10}")
        for m in all_metrics:
            print(f"  {m['concurrency']:>8} {m.get('valid_samples',0):>8} {m.get('wer','N/A'):>7}% "
                  f"{m.get('cer','N/A'):>7}% {m.get('latency_median','N/A'):>9}s "
                  f"{m.get('latency_p95','N/A'):>9}s")
        print(f"{'='*70}")
    else:
        asyncio.run(main_async(args.concurrency, max_samples=args.max_samples))


if __name__ == "__main__":
    # Run the CLI only when executed as a script, not when imported.
    main()
