#!/usr/bin/env python3
"""
Modal Nemotron ASR Benchmark — WebSocket / HTTP with WER quality scoring.
Stress test with increasing concurrency. Finds the quality degradation point.

Usage:
    source venv/bin/activate
    python nemotron_asr/bench_modal.py --url <MODAL_WEB_URL> --mode ws --concurrency 1,5,10,25,50,100
    python nemotron_asr/bench_modal.py --url <MODAL_WEB_URL> --mode http --concurrency 1,5,10,25,50
    python nemotron_asr/bench_modal.py --url <MODAL_WEB_URL> --mode all  # ws+http, default sweep
"""

import asyncio
import aiohttp
import argparse
import base64
import csv
import json
import os
import sys
import time
import wave
import io
import statistics
from pathlib import Path
from datetime import datetime
from dataclasses import dataclass, field

from jiwer import wer as compute_wer, cer as compute_cer

# === CONFIGURATION ===

DEFAULT_CONCURRENCY_LEVELS = [1, 5, 10, 25, 50, 75, 100, 150, 200]
AUDIO_DIR = Path(__file__).parent.parent / "benchmark_data" / "fleurs_hindi"
RESULTS_DIR = Path(__file__).parent.parent / "benchmark_results"

# WebSocket chunk params
CHUNK_DURATION_MS = 160  # 160ms chunks
SAMPLE_RATE = 16000      # Hz of the raw PCM streamed in WS mode
BYTES_PER_SAMPLE = 2     # 16-bit PCM
# Whole samples per chunk first, then bytes, all in integer arithmetic. This keeps
# the chunk an exact multiple of BYTES_PER_SAMPLE. The previous float form
# int(ms / 1000 * rate * bytes) could yield an odd size when a chunk held a
# fractional number of samples (e.g. 22050 Hz, 10 ms -> 441 bytes), splitting a
# 16-bit sample across two frames.
CHUNK_SIZE_BYTES = SAMPLE_RATE * CHUNK_DURATION_MS // 1000 * BYTES_PER_SAMPLE

# Timeouts (seconds)
WS_CONNECT_TIMEOUT = 30
HTTP_TIMEOUT = 60
WS_STREAM_TIMEOUT = 120  # max silence between received WS frames


@dataclass
class StreamResult:
    """Result from a single stream/request.

    One instance is produced per WebSocket stream (mode "ws") or HTTP request
    (mode "http"). Client-side timings are milliseconds measured with
    time.perf_counter(); processing_time_ms is reported by the server. Fields a
    mode does not measure keep their defaults.
    """
    stream_id: int               # Index of the stream/request within its concurrency level
    mode: str                    # "ws" or "http"
    sample_id: int = -1          # Which audio sample was used (for ground truth matching)
    ok: bool = False             # Set True after the WS exchange finished / a 200 reply was decoded
    error: str = ""              # Error description (exception text is truncated to 200 chars)
    text: str = ""               # Predicted transcription
    reference_text: str = ""     # Ground truth from manifest
    wer: float = -1.0            # Word Error Rate (-1 = not computed)
    cer: float = -1.0            # Character Error Rate (-1 = not computed)
    audio_duration: float = 0.0  # Input audio length in seconds
    connect_time_ms: float = 0.0     # WS only: time to open the WebSocket
    first_token_ms: float = 0.0      # WS: first server message after streaming starts; HTTP: equals total_time_ms
    last_token_ms: float = 0.0       # WS only: last server message after streaming starts
    total_time_ms: float = 0.0       # From the start of the call until it completed or failed
    processing_time_ms: float = 0.0  # HTTP only: server-reported processing_time converted to ms
    messages_received: int = 0       # WS only: number of decoded server messages collected
    vram_gb: float = 0.0             # HTTP only: GPU memory in use, as reported in the response
    vram_total_gb: float = 0.0       # HTTP only: total GPU memory, as reported in the response


@dataclass
class BenchmarkLevel:
    """Results for a single concurrency level of one mode ("ws" or "http")."""
    concurrency: int                                   # Number of simultaneous streams/requests launched
    mode: str                                          # "ws" or "http"
    results: list = field(default_factory=list)        # One StreamResult per stream/request
    wall_time_ms: float = 0.0                          # From launching the first task until all completed
    health_before: dict = field(default_factory=dict)  # /health response captured before the level ran
    health_after: dict = field(default_factory=dict)   # /health response captured ~1 s after the level ended


# === WER/CER COMPUTATION ===

def compute_quality(result: StreamResult) -> None:
    """Score ``result`` against its reference text, writing ``wer`` and ``cer``.

    Only successful results with a non-blank reference are scored. A blank
    hypothesis scores 1.0 on both metrics. If jiwer raises, the metrics that
    were not yet assigned keep their "not computed" value of -1. Nothing is
    returned; the result object is updated in place.
    """
    if not (result.ok and result.reference_text):
        return

    reference = result.reference_text.strip()
    hypothesis = result.text.strip()

    if reference and not hypothesis:
        # Real speech but nothing transcribed: count it as a complete miss.
        result.wer = result.cer = 1.0
    elif reference:
        try:
            result.wer = compute_wer(reference, hypothesis)
            result.cer = compute_cer(reference, hypothesis)
        except Exception:
            # Best-effort scoring: jiwer can reject edge-case inputs.
            pass


# === AUDIO LOADING ===

def load_audio_samples(max_samples=50):
    """Load up to ``max_samples`` FLEURS Hindi clips with their ground-truth text.

    Reads ``AUDIO_DIR / "manifest.json"``, a JSON list whose entries carry
    ``id``, ``audio_path`` (resolved against the parent of this script's
    directory) and an optional ``text``. Entries whose audio file does not exist
    are skipped. Each returned dict holds:

    - ``id``: manifest id
    - ``wav_bytes``: the complete WAV file (header included)
    - ``pcm_bytes``: the WAV frame data only, as streamed over the WebSocket
    - ``text``: reference transcript, ``""`` when missing or null
    - ``duration``: clip length in seconds
    - ``b64``: base64 of ``wav_bytes``, as posted in HTTP mode

    A warning is printed for clips that are not SAMPLE_RATE Hz mono
    BYTES_PER_SAMPLE-byte PCM. WS mode streams ``pcm_bytes`` verbatim and paces
    chunks assuming exactly that format, so such clips would be paced incorrectly.

    Exits the process with status 1 if the manifest is missing.
    """
    manifest_path = AUDIO_DIR / "manifest.json"
    if not manifest_path.exists():
        print(f"ERROR: No manifest at {manifest_path}")
        sys.exit(1)

    # JSON is UTF-8 by spec; the locale default encoding can fail on Devanagari text.
    with open(manifest_path, encoding="utf-8") as f:
        manifest = json.load(f)

    repo_root = Path(__file__).parent.parent
    samples = []
    for entry in manifest[:max_samples]:
        audio_path = repo_root / entry["audio_path"]
        if not audio_path.exists():
            continue

        wav_bytes = audio_path.read_bytes()
        with wave.open(io.BytesIO(wav_bytes), "rb") as wf:
            n_frames = wf.getnframes()
            sr = wf.getframerate()
            channels = wf.getnchannels()
            sample_width = wf.getsampwidth()
            pcm_data = wf.readframes(n_frames)

        if (sr, channels, sample_width) != (SAMPLE_RATE, 1, BYTES_PER_SAMPLE):
            print(f"WARNING: {audio_path.name} is {sr} Hz, {channels} ch, "
                  f"{sample_width * 8}-bit; WS mode assumes {SAMPLE_RATE} Hz mono "
                  f"{BYTES_PER_SAMPLE * 8}-bit PCM")

        samples.append({
            "id": entry["id"],
            "wav_bytes": wav_bytes,
            "pcm_bytes": pcm_data,
            "text": entry.get("text") or "",  # null text -> "" like a missing key
            "duration": n_frames / sr,
            "b64": base64.b64encode(wav_bytes).decode("ascii"),
        })

    total_dur = sum(s["duration"] for s in samples)
    gt_count = sum(1 for s in samples if s["text"])
    print(f"Loaded {len(samples)} audio samples ({total_dur:.1f}s total, {gt_count} with ground truth)")
    return samples


# === HEALTH CHECK ===

async def check_health(base_url: str, session: aiohttp.ClientSession) -> dict:
    """Return the decoded JSON from ``<base_url>/health`` (VRAM, GPU, status).

    Both 200 and 503 replies are decoded and returned as-is. Any other status,
    and any ordinary exception (connection error, the 10 s timeout, a body that
    is not JSON), is turned into ``{"error": "<description>"}`` instead of
    being raised.
    """
    try:
        health_url = f"{base_url.rstrip('/')}/health"
        request_timeout = aiohttp.ClientTimeout(total=10)
        async with session.get(health_url, timeout=request_timeout) as response:
            if response.status not in (200, 503):
                return {"error": f"status {response.status}"}
            return await response.json()
    except Exception as exc:
        return {"error": str(exc)}


# === WEBSOCKET BENCHMARK ===

async def ws_stream(
    ws_url: str,
    sample: dict,
    stream_id: int,
    realtime_pace: bool = True,
) -> StreamResult:
    """Run one WebSocket transcription stream, then score it against ground truth.

    Sequence: open a fresh client session and WebSocket, then wait up to 10 s
    for the first frame. A JSON text frame must have ``"type": "Ready"``; a
    close/error frame aborts. Two tasks then run concurrently:

    - sender: ``sample["pcm_bytes"]`` in CHUNK_SIZE_BYTES binary frames,
      sleeping CHUNK_DURATION_MS between frames when ``realtime_pace`` is true,
      then the text frame ``"END"``;
    - receiver: collects JSON text frames and msgpack binary frames until an
      ``"END"`` text frame, a msgpack message with ``"type": "Marker"``, a
      close/error frame, or WS_STREAM_TIMEOUT seconds without any frame.

    The transcript is the concatenation of the collected messages' ``text``
    fields. First/last message latencies are relative to the moment the two
    tasks are started. If either task fails, the other is cancelled, and the
    socket and session are always closed.

    Args:
        ws_url: ``ws://``/``wss://`` URL of the streaming endpoint.
        sample: one dict from ``load_audio_samples``.
        stream_id: index of this stream within its concurrency level.
        realtime_pace: pace chunks at 1x real time when True; burst when False.

    Returns:
        A StreamResult. Errors are recorded in ``error`` rather than raised.
        NOTE: a receiver timeout is not treated as a failure. The stream is
        still marked ``ok`` and scored on whatever text arrived.
    """
    result = StreamResult(
        stream_id=stream_id, mode="ws",
        sample_id=sample["id"], reference_text=sample["text"],
        audio_duration=sample["duration"],
    )
    # Frame types after which nothing more will arrive on the socket.
    closed_types = (
        aiohttp.WSMsgType.CLOSE, aiohttp.WSMsgType.CLOSING,
        aiohttp.WSMsgType.CLOSED, aiohttp.WSMsgType.ERROR,
    )

    t_start = time.perf_counter()
    try:
        # `async with` closes the socket and the session on every exit path,
        # including exceptions raised mid-stream.
        async with aiohttp.ClientSession() as session:
            t_conn_start = time.perf_counter()
            async with session.ws_connect(ws_url, timeout=WS_CONNECT_TIMEOUT) as ws:
                result.connect_time_ms = (time.perf_counter() - t_conn_start) * 1000

                # Wait for Ready message
                ready_msg = await asyncio.wait_for(ws.receive(), timeout=10)
                if ready_msg.type in closed_types:
                    result.error = f"Connection closed before Ready ({ready_msg.type.name})"
                    result.total_time_ms = (time.perf_counter() - t_start) * 1000
                    return result
                if ready_msg.type == aiohttp.WSMsgType.TEXT:
                    ready_data = json.loads(ready_msg.data)
                    if ready_data.get("type") != "Ready":
                        result.error = f"Expected Ready, got {ready_data}"
                        result.total_time_ms = (time.perf_counter() - t_start) * 1000
                        return result

                pcm = sample["pcm_bytes"]
                chunk_interval = CHUNK_DURATION_MS / 1000
                messages = []

                async def sender():
                    for i in range(0, len(pcm), CHUNK_SIZE_BYTES):
                        await ws.send_bytes(pcm[i:i + CHUNK_SIZE_BYTES])
                        if realtime_pace:
                            await asyncio.sleep(chunk_interval)
                    await ws.send_str("END")

                async def receiver():
                    try:
                        while True:
                            msg = await asyncio.wait_for(ws.receive(), timeout=WS_STREAM_TIMEOUT)
                            t_recv = time.perf_counter()

                            if msg.type == aiohttp.WSMsgType.TEXT:
                                if msg.data == "END":
                                    break
                                try:
                                    parsed = json.loads(msg.data)
                                except json.JSONDecodeError:
                                    continue  # ignore non-JSON text frames
                                messages.append({"time": t_recv, "data": parsed})
                            elif msg.type == aiohttp.WSMsgType.BINARY:
                                import msgpack  # imported lazily: only binary frames need it
                                try:
                                    parsed = msgpack.unpackb(msg.data, raw=False)
                                    if parsed.get("type") == "Marker":
                                        break
                                    messages.append({"time": t_recv, "data": parsed})
                                except Exception:
                                    pass  # undecodable frame: skip it
                            elif msg.type in closed_types:
                                break
                    except asyncio.TimeoutError:
                        pass  # no frame for WS_STREAM_TIMEOUT s: stop collecting (see NOTE)

                t_stream_start = time.perf_counter()
                tasks = [asyncio.create_task(sender()), asyncio.create_task(receiver())]
                try:
                    await asyncio.gather(*tasks)
                    t_end = time.perf_counter()
                finally:
                    # gather() does not cancel the sibling when one task raises;
                    # cancel and reap it so it cannot outlive this socket.
                    for task in tasks:
                        task.cancel()
                    await asyncio.gather(*tasks, return_exceptions=True)

                result.ok = True
                result.messages_received = len(messages)
                result.total_time_ms = (t_end - t_start) * 1000

                if messages:
                    result.first_token_ms = (messages[0]["time"] - t_stream_start) * 1000
                    result.last_token_ms = (messages[-1]["time"] - t_stream_start) * 1000

                    # WHY: Server sends incremental text DELTAS per step (current_step_transcript).
                    # Each message's "text" field is a fragment with embedded spaces — concatenate as-is.
                    # DON'T strip individual parts — leading spaces indicate word boundaries.
                    # Non-dict payloads and non-string "text" values are skipped instead of
                    # aborting the whole stream.
                    all_text_parts = []
                    for m in messages:
                        d = m["data"]
                        t = d.get("text", "") if isinstance(d, dict) else ""
                        if isinstance(t, str) and t and t.strip() != "<s>":  # Skip BOS token artifact
                            all_text_parts.append(t)
                    result.text = "".join(all_text_parts).strip()

    except asyncio.TimeoutError:
        result.error = "timeout"
        result.total_time_ms = (time.perf_counter() - t_start) * 1000
    except Exception as e:
        result.error = str(e)[:200]
        result.total_time_ms = (time.perf_counter() - t_start) * 1000

    compute_quality(result)
    return result


# === HTTP BENCHMARK ===

async def http_request(
    base_url: str,
    sample: dict,
    stream_id: int,
    session: aiohttp.ClientSession,
) -> StreamResult:
    """Send one batch transcription request to ``<base_url>/transcribe`` and score it.

    The request body is ``{"audio_base64": sample["b64"]}``, i.e. the whole WAV
    file. On a 200 reply the JSON fields ``text``, ``processing_time``
    (seconds), ``vram_gb`` and ``vram_total_gb`` are read. Missing or null
    values fall back to ``""``/0. The result is marked ``ok`` only after every
    field has been parsed, so a malformed reply becomes an error rather than a
    half-filled success. Any other status is recorded as
    ``"HTTP <status>: <first 200 chars of body>"``.

    Latency (``total_time_ms``, also used as ``first_token_ms``) runs from just
    before the POST until the response headers arrive.

    Args:
        base_url: service base URL; a trailing slash is ignored.
        sample: one dict from ``load_audio_samples``.
        stream_id: index of this request within its concurrency level.
        session: caller-owned aiohttp session (not closed here).

    Returns:
        A StreamResult. Exceptions are recorded in ``error`` rather than raised.
    """
    result = StreamResult(
        stream_id=stream_id, mode="http",
        sample_id=sample["id"], reference_text=sample["text"],
        audio_duration=sample["duration"],
    )

    t_start = time.perf_counter()
    try:
        url = base_url.rstrip("/") + "/transcribe"
        payload = {"audio_base64": sample["b64"]}

        async with session.post(
            url, json=payload,
            timeout=aiohttp.ClientTimeout(total=HTTP_TIMEOUT),
        ) as resp:
            t_resp = time.perf_counter()
            result.total_time_ms = (t_resp - t_start) * 1000

            if resp.status == 200:
                data = await resp.json()
                # `or` maps both missing and null fields to the default. A null
                # "text" would otherwise crash compute_quality() below, which
                # runs outside this try block. float() rejects non-numeric values
                # instead of e.g. repeating a string 1000 times.
                text = str(data.get("text") or "")
                processing_ms = float(data.get("processing_time") or 0) * 1000
                vram_gb = float(data.get("vram_gb") or 0)
                vram_total_gb = float(data.get("vram_total_gb") or 0)

                result.text = text
                result.processing_time_ms = processing_ms
                result.first_token_ms = result.total_time_ms
                result.vram_gb = vram_gb
                result.vram_total_gb = vram_total_gb
                result.ok = True  # only once the reply parsed cleanly
            else:
                result.error = f"HTTP {resp.status}: {(await resp.text())[:200]}"

    except asyncio.TimeoutError:
        result.error = "timeout"
        result.total_time_ms = (time.perf_counter() - t_start) * 1000
    except Exception as e:
        result.error = str(e)[:200]
        result.total_time_ms = (time.perf_counter() - t_start) * 1000

    compute_quality(result)
    return result


# === BENCHMARK RUNNER ===

async def run_benchmark_level(
    base_url: str,
    mode: str,
    concurrency: int,
    samples: list,
    realtime_pace: bool = True,
) -> BenchmarkLevel:
    """Run one concurrency level of a specific mode and print a one-line summary.

    Launches ``concurrency`` simultaneous WebSocket streams (mode "ws") or HTTP
    requests (mode "http"). Audio samples are assigned round-robin: stream ``i``
    uses ``samples[i % len(samples)]``. ``/health`` is queried before the level
    and again about one second after it finishes.

    Args:
        base_url: service base URL (``http://`` or ``https://``). The WS URL is
            derived from it by swapping the scheme to ``ws://``/``wss://`` and
            appending ``/ws``.
        mode: "ws" or "http"; any other value launches nothing.
        concurrency: number of simultaneous streams/requests.
        samples: audio samples as returned by ``load_audio_samples``. Must be
            non-empty, otherwise the modulo raises ZeroDivisionError.
        realtime_pace: forwarded to ``ws_stream``; unused in HTTP mode.

    Returns:
        A BenchmarkLevel holding every StreamResult, the wall time of the whole
        level, and both health snapshots.
    """
    level = BenchmarkLevel(concurrency=concurrency, mode=mode)

    # Pre-level /health snapshot on its own short-lived session.
    async with aiohttp.ClientSession() as session:
        level.health_before = await check_health(base_url, session)

    # https:// -> wss://, http:// -> ws://, then append the /ws endpoint path.
    ws_url = base_url.replace("https://", "wss://").replace("http://", "ws://").rstrip("/") + "/ws"

    print(f"  [{mode.upper()}] c={concurrency}: Starting {concurrency} {'streams' if mode == 'ws' else 'requests'}...")

    # Wall time spans launching all tasks until the last one has finished.
    t_wall_start = time.perf_counter()

    if mode == "ws":
        # ws_stream opens its own ClientSession per stream, so no shared pool here.
        tasks = []
        for i in range(concurrency):
            sample = samples[i % len(samples)]
            tasks.append(ws_stream(ws_url, sample, i, realtime_pace=realtime_pace))
        level.results = await asyncio.gather(*tasks)

    elif mode == "http":
        # Pool limit above the request count, so no request waits for a free connection.
        connector = aiohttp.TCPConnector(limit=concurrency + 10)
        async with aiohttp.ClientSession(connector=connector) as session:
            tasks = []
            for i in range(concurrency):
                sample = samples[i % len(samples)]
                tasks.append(http_request(base_url, sample, i, session))
            level.results = await asyncio.gather(*tasks)

    t_wall_end = time.perf_counter()
    level.wall_time_ms = (t_wall_end - t_wall_start) * 1000

    # Short pause before the post-level snapshot (presumably to let server-side
    # state settle — TODO confirm intent).
    await asyncio.sleep(1)
    async with aiohttp.ClientSession() as session:
        level.health_after = await check_health(base_url, session)

    # --- Print summary with WER ---
    ok_results = [r for r in level.results if r.ok]
    err = concurrency - len(ok_results)

    if ok_results:
        ft_p50 = statistics.median([r.first_token_ms for r in ok_results])
        # p99 uses index int(n * 0.99), clamped to the last element. For up to
        # 100 results this is simply the maximum.
        ft_p99 = sorted([r.first_token_ms for r in ok_results])[min(int(len(ok_results) * 0.99), len(ok_results) - 1)]
        total_audio = sum(r.audio_duration for r in ok_results)
        # Seconds of audio handled per second of wall time (printed with an "x" suffix).
        throughput = total_audio / (level.wall_time_ms / 1000) if level.wall_time_ms > 0 else 0
        msgs_avg = statistics.mean([r.messages_received for r in ok_results]) if mode == "ws" else 0

        # Average only results that were actually scored (-1 = not computed).
        wer_vals = [r.wer for r in ok_results if r.wer >= 0]
        cer_vals = [r.cer for r in ok_results if r.cer >= 0]
        wer_str = f" WER={statistics.mean(wer_vals)*100:.1f}%" if wer_vals else ""
        cer_str = f" CER={statistics.mean(cer_vals)*100:.1f}%" if cer_vals else ""

        # Prefer the post-level /health VRAM. Otherwise fall back to the first
        # successful result's own report (only HTTP results carry VRAM).
        vram_str = ""
        h = level.health_after
        if h.get("vram_gb"):
            vram_str = f" VRAM={h['vram_gb']:.1f}/{h.get('vram_total_gb', 0):.0f}GB"
        elif ok_results[0].vram_gb > 0:
            vram_str = f" VRAM={ok_results[0].vram_gb:.1f}/{ok_results[0].vram_total_gb:.0f}GB"

        print(f"  [{mode.upper()}] c={concurrency}: OK={len(ok_results)} Err={err} "
              f"FT_p50={ft_p50:.0f}ms FT_p99={ft_p99:.0f}ms "
              f"Thru={throughput:.1f}x"
              f"{wer_str}{cer_str}"
              f"{'  Msgs=' + f'{msgs_avg:.1f}' if mode == 'ws' else ''}"
              f"{vram_str}")
    else:
        print(f"  [{mode.upper()}] c={concurrency}: OK=0 Err={err} (ALL FAILED)")

    return level


# === STATS ===

def compute_stats(levels: list) -> list:
    """Summarise each BenchmarkLevel into one flat dict (one CSV/JSON row).

    Latency percentiles, throughput, messages per stream and mean WER/CER are
    computed over successful (``ok``) results only. ``wer_pct``/``cer_pct`` are
    -1 when no result was scored. Throughput is seconds of audio per second of
    wall time. VRAM and ``active_streams`` come from the post-level ``/health``
    snapshot; when it reports no VRAM, the first successful result's own VRAM
    figures are used. Null values are treated as 0. Every row has the same
    keys, and rows for levels where nothing succeeded still carry the wall
    time and the ``/health`` readings.

    Args:
        levels: BenchmarkLevel objects (anything with ``mode``, ``concurrency``,
            ``results``, ``wall_time_ms`` and a dict ``health_after``).

    Returns:
        A list of dicts, one per level, in input order.
    """
    def pick_p99(sorted_vals):
        # Index int(n * 0.99), clamped to the last element (same rule as the live summary).
        return sorted_vals[min(int(len(sorted_vals) * 0.99), len(sorted_vals) - 1)]

    rows = []
    for level in levels:
        ok_results = [r for r in level.results if r.ok]
        err_count = len(level.results) - len(ok_results)

        # `or 0` maps both missing and null health fields to 0, so round() below
        # never sees None.
        h = level.health_after
        vram = h.get("vram_gb") or 0
        vram_total = h.get("vram_total_gb") or 0
        active_streams = h.get("active_streams") or 0

        if not ok_results:
            # Keep the server-side readings: VRAM at a level where every stream
            # failed may still be meaningful for the peak-VRAM summary.
            rows.append({
                "mode": level.mode, "concurrency": level.concurrency,
                "ok": 0, "errors": err_count,
                "wer_pct": -1, "cer_pct": -1,
                "ft_p50_ms": 0, "ft_p99_ms": 0,
                "total_p50_ms": 0, "total_p99_ms": 0,
                "wall_ms": round(level.wall_time_ms, 0),
                "throughput_x": 0, "msgs_per_stream": 0,
                "vram_gb": round(vram, 1), "vram_total_gb": round(vram_total, 0),
                "active_streams": active_streams,
            })
            continue

        first_ok = ok_results[0]
        if vram == 0 and (first_ok.vram_gb or 0) > 0:
            vram = first_ok.vram_gb
            vram_total = first_ok.vram_total_gb or 0

        fts = sorted(r.first_token_ms for r in ok_results)
        totals = sorted(r.total_time_ms for r in ok_results)
        total_audio = sum(r.audio_duration for r in ok_results)
        throughput = total_audio / (level.wall_time_ms / 1000) if level.wall_time_ms > 0 else 0

        wer_vals = [r.wer for r in ok_results if r.wer >= 0]
        cer_vals = [r.cer for r in ok_results if r.cer >= 0]

        rows.append({
            "mode": level.mode,
            "concurrency": level.concurrency,
            "ok": len(ok_results),
            "errors": err_count,
            "wer_pct": round(statistics.mean(wer_vals) * 100, 1) if wer_vals else -1,
            "cer_pct": round(statistics.mean(cer_vals) * 100, 1) if cer_vals else -1,
            "ft_p50_ms": round(statistics.median(fts), 1),
            "ft_p99_ms": round(pick_p99(fts), 1),
            "total_p50_ms": round(statistics.median(totals), 1),
            "total_p99_ms": round(pick_p99(totals), 1),
            "wall_ms": round(level.wall_time_ms, 0),
            "throughput_x": round(throughput, 1),
            "msgs_per_stream": round(statistics.mean([r.messages_received for r in ok_results]), 1) if level.mode == "ws" else 0,
            "vram_gb": round(vram, 1),
            "vram_total_gb": round(vram_total, 0),
            "active_streams": active_streams,
        })

    return rows


def save_results(stats_rows: list, prefix: str, results_dir=None):
    """Write the summary rows to timestamped CSV and JSON files.

    Files are named ``<prefix>_<YYYYmmdd_HHMMSS>.csv`` / ``.json`` inside
    ``results_dir`` (default: RESULTS_DIR), which is created if needed. The CSV
    header comes from the first row's keys, so no CSV is written when
    ``stats_rows`` is empty; the JSON file is always written. Only files that
    were actually written are listed in the printed message.

    Args:
        stats_rows: rows as produced by ``compute_stats``.
        prefix: file-name prefix.
        results_dir: optional output directory (str or Path).

    Returns:
        ``(csv_path, json_path)``. ``csv_path`` is returned even when no CSV
        was written (empty input), so check ``csv_path.exists()`` if needed.
    """
    out_dir = Path(results_dir) if results_dir is not None else RESULTS_DIR
    out_dir.mkdir(parents=True, exist_ok=True)
    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")

    csv_path = out_dir / f"{prefix}_{timestamp}.csv"
    json_path = out_dir / f"{prefix}_{timestamp}.json"
    written = []

    if stats_rows:
        fieldnames = list(stats_rows[0].keys())
        with open(csv_path, "w", newline="", encoding="utf-8") as f:
            writer = csv.DictWriter(f, fieldnames=fieldnames)
            writer.writeheader()
            writer.writerows(stats_rows)
        written.append(csv_path)

    with open(json_path, "w", encoding="utf-8") as f:
        json.dump(stats_rows, f, indent=2)
    written.append(json_path)

    print("\nResults saved to:\n" + "\n".join(f"  {path}" for path in written))
    return csv_path, json_path


def print_summary_table(stats_rows: list):
    """Print the per-level summary rows as a 140-column fixed-width table.

    Columns: mode, concurrency, OK/error counts, WER%/CER% ("N/A" when -1),
    first-token p50/p99, total p50, wall time, throughput, messages per
    stream ("-" when zero) and VRAM used/total ("N/A" when the total is not
    known). An empty list prints nothing at all.
    """
    def pct_cell(value):
        # compute_stats uses -1 for "not computed".
        return f"{value:.1f}" if value >= 0 else "N/A"

    if stats_rows:
        heavy_rule = "=" * 140
        print("\n" + heavy_rule)
        print(
            f"{'Mode':>6} | {'Conc':>5} | {'OK':>4} | {'Err':>4} | {'WER%':>6} | {'CER%':>6} | "
            f"{'FT p50':>8} | {'FT p99':>8} | {'Total p50':>9} | "
            f"{'Wall':>8} | {'Thru':>7} | {'Msgs':>5} | {'VRAM':>12}"
        )
        print("-" * 140)

        for row in stats_rows:
            if row["vram_total_gb"] > 0:
                vram_cell = f"{row['vram_gb']:.1f}/{row['vram_total_gb']:.0f}GB"
            else:
                vram_cell = "N/A"
            if row["msgs_per_stream"] > 0:
                msgs_cell = f"{row['msgs_per_stream']:.1f}"
            else:
                msgs_cell = "-"

            cells = (
                f"{row['mode']:>6}",
                f"{row['concurrency']:>5}",
                f"{row['ok']:>4}",
                f"{row['errors']:>4}",
                f"{pct_cell(row['wer_pct']):>6}",
                f"{pct_cell(row['cer_pct']):>6}",
                f"{row['ft_p50_ms']:>7.0f}ms",
                f"{row['ft_p99_ms']:>7.0f}ms",
                f"{row['total_p50_ms']:>8.0f}ms",
                f"{row['wall_ms']:>7.0f}ms",
                f"{row['throughput_x']:>6.1f}x",
                f"{msgs_cell:>5}",
                f"{vram_cell:>12}",
            )
            print(" | ".join(cells))

        print(heavy_rule)


# === MAIN ===

def _parse_concurrency_levels(raw):
    """Parse a comma-separated list of positive integers such as ``"1, 5,10"``.

    Blank entries (e.g. from a trailing comma) are ignored. Raises ValueError
    for non-integer or non-positive entries, or when no level remains.
    """
    levels = []
    for token in raw.split(","):
        token = token.strip()
        if not token:
            continue
        value = int(token)  # ValueError for non-integers
        if value < 1:
            raise ValueError(f"concurrency levels must be >= 1, got {value}")
        levels.append(value)
    if not levels:
        raise ValueError("no concurrency levels given")
    return levels


async def main():
    """Command-line entry point: parse arguments, run the sweep, report and save.

    For every mode selected by ``--mode`` and every level in ``--concurrency``
    (in the given order), one benchmark level is run. ``--cooldown`` seconds
    pass between every two consecutive levels, including across the WS -> HTTP
    switch. Afterwards the summary table is printed, CSV/JSON are saved (file
    names tagged with the GPU name when known), and the WER degradation curve
    (relative to each mode's first scored level) and peak VRAM are reported.
    """
    parser = argparse.ArgumentParser(description="Modal Nemotron ASR Benchmark (WS + HTTP with WER)")
    parser.add_argument("--url", required=True, help="Modal web URL")
    parser.add_argument("--mode", default="all", choices=["ws", "http", "all"],
                        help="Benchmark mode (default: all = ws + http)")
    parser.add_argument("--concurrency", default=None,
                        help="Comma-separated concurrency levels (default: 1,5,10,25,50,75,100,150,200)")
    parser.add_argument("--max-samples", type=int, default=50, help="Max audio samples to load")
    parser.add_argument("--no-realtime", action="store_true", help="Send WS audio as fast as possible (not 1x pace)")
    parser.add_argument("--cooldown", type=int, default=5, help="Seconds between concurrency levels")
    parser.add_argument("--prefix", default="modal_bench", help="Output file prefix")
    args = parser.parse_args()

    if args.concurrency:
        try:
            concurrency_levels = _parse_concurrency_levels(args.concurrency)
        except ValueError as exc:
            parser.error(f"invalid --concurrency {args.concurrency!r}: {exc}")  # exits
    else:
        concurrency_levels = DEFAULT_CONCURRENCY_LEVELS

    modes = ["ws", "http"] if args.mode == "all" else [args.mode]

    samples = load_audio_samples(args.max_samples)
    if not samples:
        print("ERROR: No audio samples found")
        sys.exit(1)

    base_url = args.url.rstrip("/")

    print(f"\n{'='*80}")
    print(f"Modal Nemotron ASR Benchmark (with WER scoring)")
    print(f"{'='*80}")
    print(f"URL: {base_url}")
    print(f"Modes: {', '.join(modes)}")
    print(f"Concurrency levels: {concurrency_levels}")
    print(f"Audio samples: {len(samples)}")
    print(f"Realtime pace (WS): {'No (burst)' if args.no_realtime else 'Yes (1x)'}")

    async with aiohttp.ClientSession() as session:
        health = await check_health(base_url, session)

    if "error" in health:
        print(f"\nWARNING: Health check failed: {health['error']}")
        print("The app may need to cold-start. Continuing anyway...")
    else:
        print(f"\nHealth: {health.get('status', 'unknown')}")
        print(f"  GPU: {health.get('gpu', '?')}")
        print(f"  VRAM: {health.get('vram_gb', '?')}/{health.get('vram_total_gb', '?')} GB")
        print(f"  Active streams: {health.get('active_streams', '?')}")

    print(f"\n{'='*80}")
    print(f"Starting benchmark...")
    print(f"{'='*80}\n")

    all_levels = []
    total_levels = len(modes) * len(concurrency_levels)

    for mode in modes:
        print(f"\n--- {mode.upper()} MODE ---")

        for conc in concurrency_levels:
            level = await run_benchmark_level(
                base_url=base_url,
                mode=mode,
                concurrency=conc,
                samples=samples,
                realtime_pace=not args.no_realtime,
            )
            all_levels.append(level)

            # Decide by position, not by value: a repeated level still gets its
            # cooldown, and the first HTTP level does not start right after the
            # last (heaviest) WS level.
            if len(all_levels) < total_levels:
                print(f"  Cooling down {args.cooldown}s...")
                await asyncio.sleep(args.cooldown)

    stats = compute_stats(all_levels)
    print_summary_table(stats)

    # Tag output files with the GPU name. The initial health check can fail on a
    # cold start, so fall back to the first post-level snapshot that reports one.
    gpu_name = health.get("gpu") if isinstance(health, dict) else None
    if not gpu_name:
        gpu_name = next(
            (lvl.health_after.get("gpu") for lvl in all_levels
             if isinstance(lvl.health_after, dict) and lvl.health_after.get("gpu")),
            None,
        )
    if gpu_name and gpu_name != "unknown":
        gpu_tag = str(gpu_name).replace(" ", "_").replace("/", "-")
    else:
        gpu_tag = ""
    prefix = f"{args.prefix}_{gpu_tag}" if gpu_tag else args.prefix
    save_results(stats, prefix)

    # --- WER degradation summary ---
    ws_rows = [r for r in stats if r["mode"] == "ws" and r["wer_pct"] >= 0]
    http_rows = [r for r in stats if r["mode"] == "http" and r["wer_pct"] >= 0]

    if ws_rows or http_rows:
        print(f"\n--- QUALITY DEGRADATION CURVE ---")
        for label, rows in [("WS", ws_rows), ("HTTP", http_rows)]:
            if not rows:
                continue
            baseline_wer = rows[0]["wer_pct"]
            print(f"  {label} baseline WER (c={rows[0]['concurrency']}): {baseline_wer:.1f}%")
            for r in rows[1:]:
                delta = r["wer_pct"] - baseline_wer
                marker = ""
                if delta > 10:
                    marker = " << DEGRADED"
                elif delta > 5:
                    marker = " << WARNING"
                print(f"    c={r['concurrency']:>4}: WER={r['wer_pct']:.1f}% (delta={delta:+.1f}%){marker}")

    # VRAM summary
    vram_entries = [r for r in stats if r["vram_total_gb"] > 0]
    if vram_entries:
        max_vram = max(r["vram_gb"] for r in vram_entries)
        total_vram = vram_entries[0]["vram_total_gb"]
        pct = (max_vram / total_vram) * 100 if total_vram > 0 else 0
        print(f"\n--- VRAM ---")
        print(f"Peak: {max_vram:.1f} / {total_vram:.0f} GB ({pct:.1f}%)")


if __name__ == "__main__":
    # Script entry point: run the async CLI on a fresh event loop.
    asyncio.run(main())
