#!/usr/bin/env python3
"""
Direct-runtime streaming stress benchmark for Veena3 TTS.

Focuses on streaming KPIs:
- TTFB (time to first chunk)
- End-to-end latency
- Chunk cadence (inter-chunk interval)
- Emission speed (audio seconds emitted per second after first chunk)
- Throughput under concurrency
"""

from __future__ import annotations

import argparse
import asyncio
import json
import os
import statistics
import subprocess
import sys
import threading
import time
from dataclasses import dataclass, field
from typing import Any, Dict, List, Optional


# Path setup (same pattern as local_server.py).
# Make vendored dependencies (sparktts, AP-BWE) and the repo root importable
# no matter which working directory this script is launched from.
REPO_ROOT = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
EXTERNAL_DIR = os.path.join(REPO_ROOT, "external")
for path in [os.path.join(EXTERNAL_DIR, "sparktts"), os.path.join(EXTERNAL_DIR, "AP-BWE"), REPO_ROOT]:
    if path not in sys.path:
        sys.path.insert(0, path)

# NOTE(review): presumably disables request auth in the imported server
# modules since this benchmark hits the runtime directly -- confirm against
# local_server.py before changing.
os.environ["AUTH_BYPASS_MODE"] = "true"

# Fixed benchmark prompts, bucketed by approximate length; selected via --text.
TEST_TEXTS = {
    "short": "Hello, this is a quick streaming benchmark for voice agents.",
    "medium": (
        "This streaming benchmark sentence is designed to evaluate first audio latency, "
        "chunk cadence, and sustained emission speed under concurrent request load."
    ),
    "long": (
        "In a bustling city, a narrator collected voice notes from people walking through markets, "
        "subways, and parks. The stream needed to start quickly, stay stable, and deliver clear speech "
        "without stalls as many listeners connected at once."
    ),
}

# Speaker IDs assigned round-robin across requests (see run_level).
# NOTE(review): mixed casing ("Nandini") kept as-is -- assumed to match the
# runtime's speaker registry; confirm.
SPEAKERS = ["lipakshi", "vardan", "reet", "Nandini", "krishna", "anika"]


@dataclass
class GPUSnapshot:
    """One nvidia-smi sample collected by GPUMonitor."""

    timestamp: float  # time.time() when the sample was taken
    memory_used_mb: float  # GPU memory in use, in MB
    memory_total_mb: float  # total GPU memory, in MB
    gpu_utilization_pct: float  # instantaneous GPU utilization, percent


class GPUMonitor:
    """Background nvidia-smi poller sampling GPU memory and utilization.

    A daemon thread shells out to ``nvidia-smi`` every ``interval_seconds``
    and appends one :class:`GPUSnapshot` per successful poll. Polling errors
    (missing binary, no GPU, subprocess timeout, unparseable output) are
    swallowed so monitoring never interferes with the benchmark itself.
    """

    def __init__(self, interval_seconds: float = 0.5):
        self.interval = interval_seconds
        self.snapshots: List[GPUSnapshot] = []
        self._running = False
        self._thread: Optional[threading.Thread] = None

    def start(self) -> None:
        """Begin polling on a daemon thread."""
        self._running = True
        self._thread = threading.Thread(target=self._poll_loop, daemon=True)
        self._thread.start()

    def stop(self) -> None:
        """Ask the poller to stop and wait briefly for the thread to exit."""
        self._running = False
        if self._thread:
            self._thread.join(timeout=3)

    def _poll_loop(self) -> None:
        """Poll nvidia-smi until stop() is called; best-effort, never raises."""
        while self._running:
            try:
                res = subprocess.run(
                    [
                        "nvidia-smi",
                        "--query-gpu=memory.used,memory.total,utilization.gpu",
                        "--format=csv,noheader,nounits",
                    ],
                    capture_output=True,
                    text=True,
                    timeout=5,
                )
                if res.returncode == 0 and res.stdout.strip():
                    # BUGFIX: nvidia-smi emits one CSV line per GPU. Splitting
                    # the raw stdout on commas merged adjacent lines on
                    # multi-GPU hosts, so float() raised and every sample was
                    # silently dropped. Sample GPU 0 (the first line) only.
                    first_line = res.stdout.strip().splitlines()[0]
                    parts = [p.strip() for p in first_line.split(",")]
                    if len(parts) >= 3:
                        self.snapshots.append(
                            GPUSnapshot(
                                timestamp=time.time(),
                                memory_used_mb=float(parts[0]),
                                memory_total_mb=float(parts[1]),
                                gpu_utilization_pct=float(parts[2]),
                            )
                        )
            except Exception:
                # Best-effort monitoring: never let GPU polling kill the run.
                pass
            time.sleep(self.interval)

    def summary(self) -> Dict[str, Any]:
        """Aggregate min/avg/max over all samples collected so far."""
        # Snapshot the list first: the poller thread may append concurrently.
        snaps = list(self.snapshots)
        if not snaps:
            return {"samples": 0, "error": "no GPU samples captured"}
        mem = [s.memory_used_mb for s in snaps]
        util = [s.gpu_utilization_pct for s in snaps]
        return {
            "samples": len(snaps),
            "memory_mb": {
                "min": min(mem),
                "avg": statistics.mean(mem),
                "max": max(mem),
                "total": snaps[0].memory_total_mb,
            },
            "gpu_util_pct": {
                "min": min(util),
                "avg": statistics.mean(util),
                "max": max(util),
            },
        }


@dataclass
class RequestResult:
    """Per-request streaming measurement (one generate_speech_streaming call)."""

    success: bool  # True when the stream completed within the timeout
    latency_ms: float  # end-to-end wall time from request start to last chunk
    ttfb_ms: float  # time to first chunk; 0.0 when no chunk arrived / on failure
    chunks_sent: int  # number of audio chunks received
    total_bytes: int  # sum of chunk byte lengths
    audio_seconds: float  # audio duration reported by runtime metrics (0.0 on failure)
    speaker: str  # speaker ID used for this request
    error: Optional[str] = None  # "timeout" or truncated exception text; None on success
    inter_chunk_ms: List[float] = field(default_factory=list)  # gaps between consecutive chunks
    tail_emit_speed_x: float = 0.0  # audio seconds emitted per wall second after first chunk
    final_metrics: Dict[str, Any] = field(default_factory=dict)  # last metrics dict from the stream


@dataclass
class LevelResult:
    """Aggregated outcome of one concurrency level of the benchmark."""

    concurrency: int
    total_requests: int
    results: List[RequestResult] = field(default_factory=list)
    wall_time_s: float = 0.0
    gpu_summary: Dict[str, Any] = field(default_factory=dict)

    @property
    def successes(self) -> List[RequestResult]:
        """Requests that completed without error."""
        return [res for res in self.results if res.success]

    @property
    def failures(self) -> List[RequestResult]:
        """Requests that timed out or raised."""
        return [res for res in self.results if not res.success]

    @property
    def success_rate(self) -> float:
        """Fraction of successful requests; 0.0 when nothing ran."""
        if not self.results:
            return 0.0
        return len(self.successes) / len(self.results)

    @property
    def throughput_rps(self) -> float:
        """Requests per wall-clock second across the whole level."""
        if self.wall_time_s <= 0:
            return 0.0
        return self.total_requests / self.wall_time_s

    def _pct(self, vals: List[float], p: int) -> float:
        """Nearest-rank (floor-index) percentile; 0.0 for empty input."""
        if not vals:
            return 0.0
        ordered = sorted(vals)
        return ordered[int((len(ordered) - 1) * (p / 100))]

    def _stats(self, vals: List[float]) -> Dict[str, float]:
        """avg/p50/p95/max summary of `vals`; empty dict when there is no data."""
        if not vals:
            return {}
        return {
            "avg": statistics.mean(vals),
            "p50": self._pct(vals, 50),
            "p95": self._pct(vals, 95),
            "max": max(vals),
        }

    def latency_stats(self) -> Dict[str, float]:
        """End-to-end latency summary over successful requests."""
        return self._stats([res.latency_ms for res in self.successes])

    def ttfb_stats(self) -> Dict[str, float]:
        """TTFB summary; zero TTFBs (no first chunk recorded) are excluded."""
        return self._stats([res.ttfb_ms for res in self.successes if res.ttfb_ms > 0])

    def chunk_count_stats(self) -> Dict[str, float]:
        """Chunks-per-request summary over successful requests."""
        return self._stats([float(res.chunks_sent) for res in self.successes])

    def emit_speed_stats(self) -> Dict[str, float]:
        """Tail emission-speed summary; requests without audio are excluded."""
        return self._stats([res.tail_emit_speed_x for res in self.successes if res.tail_emit_speed_x > 0])

    def inter_chunk_stats(self) -> Dict[str, float]:
        """Summary over every inter-chunk gap of every successful request."""
        return self._stats([gap for res in self.successes for gap in res.inter_chunk_ms])

    def timing_stats(self, key: str) -> Dict[str, float]:
        """Summary of a numeric runtime metric `key` across successes."""
        return self._stats(
            [
                float(val)
                for res in self.successes
                if isinstance(val := res.final_metrics.get(key), (int, float))
            ]
        )

    def latency_over_metric_stats(self, key: str) -> Dict[str, float]:
        """Summary of observed latency minus the runtime's own metric `key`.

        Captures benchmark-side overhead (clamped at zero) for requests
        whose metrics report the baseline value.
        """
        deltas: List[float] = []
        for res in self.successes:
            baseline = res.final_metrics.get(key)
            if isinstance(baseline, (int, float)):
                deltas.append(max(0.0, float(res.latency_ms) - float(baseline)))
        return self._stats(deltas)

    def failure_reason_counts(self) -> Dict[str, int]:
        """Histogram of normalized (stripped, lowercased) failure reasons."""
        counts: Dict[str, int] = {}
        for res in self.failures:
            reason = (res.error or "unknown").strip().lower() or "unknown"
            counts[reason] = counts.get(reason, 0) + 1
        return counts


def init_runtime(
    model_path: str,
    gpu_mem: float,
    stream_output_kind: str,
    num_engines: int,
    max_num_batched_tokens: Optional[int],
    max_num_seqs: Optional[int],
    enable_chunked_prefill: Optional[bool],
    enable_prefix_caching: Optional[bool],
    disable_log_stats: Optional[bool],
    enforce_eager: Optional[bool],
) -> None:
    """Resolve model paths and bring up the shared TTS runtime once.

    Publishes the stream output kind via the environment, then initializes
    the runtime unless a previous call already did so (idempotent).
    """
    from veena3modal.local_server import resolve_model_paths
    from veena3modal.services.tts_runtime import initialize_runtime, is_initialized

    os.environ["VEENA3_STREAM_OUTPUT_KIND"] = stream_output_kind

    llm_path, bicodec_path = resolve_model_paths(model_path)
    if is_initialized():
        # Runtime already up (e.g. warm process reuse); nothing to do.
        return
    initialize_runtime(
        model_path=llm_path,
        bicodec_path=bicodec_path,
        device="cuda",
        gpu_memory_utilization=gpu_mem,
        num_engines=num_engines,
        max_num_batched_tokens=max_num_batched_tokens,
        max_num_seqs=max_num_seqs,
        enable_chunked_prefill=enable_chunked_prefill,
        enable_prefix_caching=enable_prefix_caching,
        disable_log_stats=disable_log_stats,
        enforce_eager=enforce_eager,
        precompute_speaker_globals=False,
    )


async def run_one_request(
    text: str,
    speaker: str,
    max_tokens: int,
    timeout_s: float,
    enable_chunking: bool,
    seed: Optional[int],
) -> RequestResult:
    """Stream one TTS request and measure TTFB, chunk cadence, and latency.

    Consumes the runtime's streaming generator under `timeout_s`, recording
    the arrival time of every chunk. Returns a RequestResult: a success
    record with derived KPIs, or a failure record (`error="timeout"` or the
    truncated exception text) that still carries whatever was received.
    Never raises.
    """
    from veena3modal.services import tts_runtime

    t0 = time.perf_counter()
    first_chunk_t: Optional[float] = None
    prev_chunk_t: Optional[float] = None
    inter_chunk_ms: List[float] = []
    chunks = 0
    total_bytes = 0
    final_metrics: Dict[str, Any] = {}

    async def _consume_stream() -> None:
        # Drain the stream, timestamping each chunk for cadence stats.
        nonlocal first_chunk_t, prev_chunk_t, chunks, total_bytes, final_metrics
        gen = tts_runtime.generate_speech_streaming(
            text=text,
            speaker=speaker,
            max_tokens=max_tokens,
            enable_chunking=enable_chunking,
            seed=seed,
        )
        async for audio_chunk, metrics in gen:
            now = time.perf_counter()
            if first_chunk_t is None:
                first_chunk_t = now
            if prev_chunk_t is not None:
                inter_chunk_ms.append((now - prev_chunk_t) * 1000.0)
            prev_chunk_t = now
            chunks += 1
            total_bytes += len(audio_chunk)
            if isinstance(metrics, dict):
                # Keep reference; runtime may mutate this dict after the final
                # chunk when internal stream generation completes.
                final_metrics = metrics

    def _failure(reason: str) -> RequestResult:
        # Shared failure record (previously duplicated across two except
        # blocks). TTFB/audio are zeroed because the stream never completed;
        # partial chunk counts/bytes are preserved for diagnostics.
        return RequestResult(
            success=False,
            latency_ms=(time.perf_counter() - t0) * 1000.0,
            ttfb_ms=0.0,
            chunks_sent=chunks,
            total_bytes=total_bytes,
            audio_seconds=0.0,
            speaker=speaker,
            error=reason,
            inter_chunk_ms=inter_chunk_ms,
            final_metrics=dict(final_metrics),
        )

    try:
        await asyncio.wait_for(_consume_stream(), timeout=timeout_s)
    except asyncio.TimeoutError:
        return _failure("timeout")
    except Exception as exc:
        return _failure(str(exc)[:200])

    end = time.perf_counter()
    latency_ms = (end - t0) * 1000.0
    ttfb_ms = ((first_chunk_t - t0) * 1000.0) if first_chunk_t is not None else 0.0
    audio_seconds = float(final_metrics.get("audio_duration_seconds", 0.0) or 0.0)
    # Tail emission speed: audio seconds produced per wall second after the
    # first chunk arrived (floored to avoid division by ~zero).
    tail_s = max(1e-6, (latency_ms - ttfb_ms) / 1000.0)
    tail_emit_speed_x = audio_seconds / tail_s if audio_seconds > 0 else 0.0
    return RequestResult(
        success=True,
        latency_ms=latency_ms,
        ttfb_ms=ttfb_ms,
        chunks_sent=chunks,
        total_bytes=total_bytes,
        audio_seconds=audio_seconds,
        speaker=speaker,
        inter_chunk_ms=inter_chunk_ms,
        tail_emit_speed_x=tail_emit_speed_x,
        final_metrics=dict(final_metrics),
    )


async def run_level(
    concurrency: int,
    total_requests: int,
    text: str,
    max_tokens: int,
    timeout_s: float,
    enable_chunking: bool,
    gpu_monitor: GPUMonitor,
    seed_base: Optional[int],
) -> LevelResult:
    """Drive `total_requests` streaming requests with at most `concurrency`
    in flight, cycling speakers round-robin, and return the aggregated level.
    GPU samples are reset at the start so the summary covers this level only.
    """
    level = LevelResult(concurrency=concurrency, total_requests=total_requests)
    gate = asyncio.Semaphore(concurrency)

    async def _bounded(idx: int) -> RequestResult:
        # Semaphore caps in-flight requests at the level's concurrency.
        async with gate:
            return await run_one_request(
                text=text,
                speaker=SPEAKERS[idx % len(SPEAKERS)],
                max_tokens=max_tokens,
                timeout_s=timeout_s,
                enable_chunking=enable_chunking,
                seed=None if seed_base is None else seed_base + idx,
            )

    gpu_monitor.snapshots.clear()
    started = time.perf_counter()
    level.results = await asyncio.gather(*(_bounded(i) for i in range(total_requests)))
    level.wall_time_s = time.perf_counter() - started
    level.gpu_summary = gpu_monitor.summary()
    return level


def print_level(level: LevelResult) -> None:
    print(f"\n=== streaming concurrency={level.concurrency} requests={level.total_requests} ===")
    print(f"success={len(level.successes)}/{level.total_requests} ({level.success_rate:.0%})")
    print(f"wall={level.wall_time_s:.2f}s throughput={level.throughput_rps:.2f} req/s")
    if level.failures:
        reasons = level.failure_reason_counts()
        summary = ", ".join(f"{k}:{v}" for k, v in sorted(reasons.items(), key=lambda kv: kv[1], reverse=True))
        print(f"failures: {summary}")
    lat = level.latency_stats()
    if lat:
        print(f"latency_ms avg={lat['avg']:.0f} p50={lat['p50']:.0f} p95={lat['p95']:.0f} max={lat['max']:.0f}")
    ttfb = level.ttfb_stats()
    if ttfb:
        print(f"ttfb_ms avg={ttfb['avg']:.0f} p50={ttfb['p50']:.0f} p95={ttfb['p95']:.0f} max={ttfb['max']:.0f}")
    cc = level.chunk_count_stats()
    if cc:
        print(f"chunks_per_req avg={cc['avg']:.1f} p50={cc['p50']:.1f} p95={cc['p95']:.1f} max={cc['max']:.1f}")
    ic = level.inter_chunk_stats()
    if ic:
        print(
            "inter_chunk_ms avg={:.1f} p50={:.1f} p95={:.1f} max={:.1f}".format(
                ic["avg"], ic["p50"], ic["p95"], ic["max"]
            )
        )
    es = level.emit_speed_stats()
    if es:
        print(
            "tail_emit_speed_x avg={:.2f} p50={:.2f} p95={:.2f} max={:.2f}".format(
                es["avg"], es["p50"], es["p95"], es["max"]
            )
        )

    timing_keys = [
        "timeline_to_first_batch_ms",
        "timeline_to_first_chunk_emitted_ms",
        "timeline_total_ms",
        "llm_time_in_queue_ms",
        "llm_scheduler_ms",
        "llm_first_token_ms",
        "llm_request_lifecycle_ms",
        "llm_generation_wall_ms",
        "llm_time_per_token_ms",
        "llm_batch_wall_ms_total",
        "llm_decode_wall_ms_total",
        "tokens_per_batch_p50",
        "bicodec_decode_calls",
        "bicodec_decode_wall_ms_total",
        "bicodec_decode_interval_ms_p50",
        "decode_interval_applied_tokens_p50",
        "decode_pending_requests_p50",
        "admission_wait_ms",
        "admission_inflight_on_grant",
        "admission_waiters_on_grant",
        "admission_queue_depth_on_entry",
        "batch_avg_size",
        "batch_max_seen",
        "batch_workers_live",
        "batch_workers_target",
        "batch_queue_wait_ms_avg",
        "batch_compute_ms_avg",
        "batch_queue_depth_avg",
    ]
    for key in timing_keys:
        st = level.timing_stats(key)
        if st:
            print(f"{key}: avg={st['avg']:.2f} p50={st['p50']:.2f} p95={st['p95']:.2f} max={st['max']:.2f}")

    latency_over_timeline = level.latency_over_metric_stats("timeline_total_ms")
    if latency_over_timeline:
        print(
            "latency_over_timeline_ms: avg={:.2f} p50={:.2f} p95={:.2f} max={:.2f}".format(
                latency_over_timeline["avg"],
                latency_over_timeline["p50"],
                latency_over_timeline["p95"],
                latency_over_timeline["max"],
            )
        )

    tl_total = level.timing_stats("timeline_total_ms")
    llm_total = level.timing_stats("llm_generation_wall_ms")
    bicodec_total = level.timing_stats("bicodec_decode_wall_ms_total")
    if tl_total and llm_total:
        total_avg = tl_total.get("avg", 0.0)
        llm_avg = llm_total.get("avg", 0.0)
        bicodec_avg = bicodec_total.get("avg", 0.0) if bicodec_total else 0.0
        other_avg = max(0.0, total_avg - llm_avg - bicodec_avg)
        if total_avg > 0:
            print(
                "timeline_share: llm={:.1f}% bicodec={:.1f}% other={:.1f}%".format(
                    (llm_avg / total_avg) * 100,
                    (bicodec_avg / total_avg) * 100,
                    (other_avg / total_avg) * 100,
                )
            )

    if "memory_mb" in level.gpu_summary:
        g = level.gpu_summary
        print(
            "gpu: util_avg={:.0f}% util_max={:.0f}% mem_avg={:.0f}MB mem_max={:.0f}MB".format(
                g["gpu_util_pct"]["avg"],
                g["gpu_util_pct"]["max"],
                g["memory_mb"]["avg"],
                g["memory_mb"]["max"],
            )
        )


# Runtime metric keys aggregated into each saved level's "timing" section
# (order preserved; previously 28 hand-repeated timing_stats() entries).
_TIMING_KEYS = (
    "timeline_to_first_batch_ms",
    "timeline_to_first_chunk_emitted_ms",
    "timeline_total_ms",
    "llm_time_in_queue_ms",
    "llm_scheduler_ms",
    "llm_first_token_ms",
    "llm_request_lifecycle_ms",
    "llm_generation_wall_ms",
    "llm_time_per_token_ms",
    "llm_batch_wall_ms_total",
    "llm_decode_wall_ms_total",
    "tokens_per_batch_p50",
    "bicodec_decode_calls",
    "bicodec_decode_wall_ms_total",
    "bicodec_decode_interval_ms_p50",
    "decode_interval_applied_tokens_p50",
    "decode_pending_requests_p50",
    "admission_wait_ms",
    "admission_inflight_on_grant",
    "admission_waiters_on_grant",
    "admission_queue_depth_on_entry",
    "batch_avg_size",
    "batch_max_seen",
    "batch_workers_live",
    "batch_workers_target",
    "batch_queue_wait_ms_avg",
    "batch_compute_ms_avg",
    "batch_queue_depth_avg",
)


def save_results(path: str, levels: List[LevelResult], args: argparse.Namespace) -> None:
    """Serialize the benchmark config plus per-level and per-request results
    to `path` as indented JSON.

    The "config" section echoes the CLI arguments; each entry in "levels"
    carries aggregate stats, per-metric timing summaries (see _TIMING_KEYS),
    failure reason counts, and raw per-request records.
    """
    payload = {
        "config": {
            "levels": args.levels,
            "text": args.text,
            "max_tokens": args.max_tokens,
            "timeout_s": args.timeout,
            "seed_base": args.seed_base,
            "requests_multiplier": args.requests_multiplier,
            "min_requests": args.min_requests,
            "gpu_memory": args.gpu_memory,
            "stream_output_kind": args.stream_output_kind,
            "chunking": args.chunking,
            "num_engines": args.num_engines,
            "max_num_batched_tokens": args.max_num_batched_tokens,
            "max_num_seqs": args.max_num_seqs,
            "enable_chunked_prefill": args.enable_chunked_prefill,
            "disable_chunked_prefill": args.disable_chunked_prefill,
            "disable_prefix_caching": args.disable_prefix_caching,
            "disable_engine_stats_logs": args.disable_engine_stats_logs,
            "enforce_eager": args.enforce_eager,
            "stream_decode_interval": args.stream_decode_interval,
            "stream_window_size": args.stream_window_size,
            "stream_min_semantic_first": args.stream_min_semantic_first,
            "stream_crossfade_ms": args.stream_crossfade_ms,
            "adaptive_decode": args.adaptive_decode,
            "adaptive_decode_first_interval": args.adaptive_decode_first_interval,
            "adaptive_decode_busy_interval": args.adaptive_decode_busy_interval,
            "adaptive_decode_busy_pending": args.adaptive_decode_busy_pending,
            "disable_windowed_decode": args.disable_windowed_decode,
            "disable_bicodec_batching": args.disable_bicodec_batching,
            "bicodec_batch_max": args.bicodec_batch_max,
            "bicodec_batch_timeout_ms": args.bicodec_batch_timeout_ms,
            "bicodec_batch_workers": args.bicodec_batch_workers,
            "bicodec_batch_scale_pending": args.bicodec_batch_scale_pending,
            "bicodec_batch_scale_mode": args.bicodec_batch_scale_mode,
            "stream_admission_max_inflight": args.stream_admission_max_inflight,
            "stream_admission_max_queue": args.stream_admission_max_queue,
            "stream_admission_max_wait_ms": args.stream_admission_max_wait_ms,
            "stream_admission_poll_ms": args.stream_admission_poll_ms,
            "timestamp": int(time.time()),
        },
        "levels": [
            {
                "concurrency": lvl.concurrency,
                "total_requests": lvl.total_requests,
                "success_rate": lvl.success_rate,
                "throughput_rps": lvl.throughput_rps,
                "wall_time_s": lvl.wall_time_s,
                "latency_ms": lvl.latency_stats(),
                "ttfb_ms": lvl.ttfb_stats(),
                "chunks_per_req": lvl.chunk_count_stats(),
                "inter_chunk_ms": lvl.inter_chunk_stats(),
                "tail_emit_speed_x": lvl.emit_speed_stats(),
                "gpu": lvl.gpu_summary,
                "timing": {
                    **{key: lvl.timing_stats(key) for key in _TIMING_KEYS},
                    "latency_over_timeline_ms": lvl.latency_over_metric_stats("timeline_total_ms"),
                },
                "failure_reasons": lvl.failure_reason_counts(),
                "requests": [
                    {
                        "success": r.success,
                        "latency_ms": r.latency_ms,
                        "ttfb_ms": r.ttfb_ms,
                        "chunks_sent": r.chunks_sent,
                        "total_bytes": r.total_bytes,
                        "audio_seconds": r.audio_seconds,
                        "speaker": r.speaker,
                        "error": r.error,
                        "inter_chunk_ms": r.inter_chunk_ms,
                        "tail_emit_speed_x": r.tail_emit_speed_x,
                        "final_metrics": r.final_metrics,
                    }
                    for r in lvl.results
                ],
            }
            for lvl in levels
        ],
    }
    with open(path, "w", encoding="utf-8") as fd:
        json.dump(payload, fd, indent=2)


async def main_async(args: argparse.Namespace) -> None:
    """Run a warmup request, then every concurrency level, then save results."""
    text = TEST_TEXTS[args.text]
    concurrency_levels = [int(tok) for tok in (x.strip() for x in args.levels.split(",")) if tok]
    gpu_monitor = GPUMonitor(interval_seconds=0.5)
    gpu_monitor.start()

    # Warmup: pay one-time costs before the first measured level.
    await run_one_request(
        text=text,
        speaker=SPEAKERS[0],
        max_tokens=args.max_tokens,
        timeout_s=args.timeout,
        enable_chunking=args.chunking,
        seed=args.seed_base,
    )

    results: List[LevelResult] = []
    last_idx = len(concurrency_levels) - 1
    for idx, concurrency in enumerate(concurrency_levels):
        level = await run_level(
            concurrency=concurrency,
            total_requests=max(args.min_requests, concurrency * args.requests_multiplier),
            text=text,
            max_tokens=args.max_tokens,
            timeout_s=args.timeout,
            enable_chunking=args.chunking,
            gpu_monitor=gpu_monitor,
            seed_base=args.seed_base,
        )
        print_level(level)
        results.append(level)
        if idx != last_idx:
            # Brief cool-down between levels.
            await asyncio.sleep(2)

    gpu_monitor.stop()
    save_results(args.output, results, args)
    print(f"\nresults saved: {args.output}")


def _build_parser() -> argparse.ArgumentParser:
    """Construct the CLI parser for the streaming stress benchmark."""
    parser = argparse.ArgumentParser(description="Direct runtime streaming stress benchmark")
    parser.add_argument("--levels", default="16,32,64", help="Comma-separated concurrency levels")
    parser.add_argument("--text", choices=["short", "medium", "long"], default="short")
    parser.add_argument("--max-tokens", type=int, default=512, help="max_tokens sent to streaming generation")
    parser.add_argument("--seed-base", type=int, default=1337, help="Base random seed; request i uses seed_base+i")
    parser.add_argument("--timeout", type=float, default=180.0, help="Per-request timeout seconds")
    parser.add_argument("--requests-multiplier", type=int, default=2, help="Requests per level ~= concurrency * multiplier")
    parser.add_argument("--min-requests", type=int, default=8)
    parser.add_argument("--chunking", action="store_true", default=False, help="Enable long-text chunking path")
    parser.add_argument("--gpu-memory", type=float, default=0.25)
    parser.add_argument("--stream-output-kind", choices=["delta", "cumulative"], default="delta")
    parser.add_argument("--num-engines", type=int, default=1, help="Number of vLLM engines")
    parser.add_argument("--max-num-batched-tokens", type=int, default=None)
    parser.add_argument("--max-num-seqs", type=int, default=None)
    parser.add_argument("--enable-chunked-prefill", action="store_true", default=None)
    parser.add_argument("--disable-chunked-prefill", action="store_true", default=False)
    parser.add_argument("--disable-prefix-caching", action="store_true", default=False)
    parser.add_argument("--disable-engine-stats-logs", action="store_true", default=False)
    parser.add_argument("--enforce-eager", action="store_true", default=False)
    parser.add_argument("--stream-decode-interval", type=int, default=48)
    parser.add_argument("--stream-window-size", type=int, default=128)
    parser.add_argument("--stream-min-semantic-first", type=int, default=10)
    parser.add_argument("--stream-crossfade-ms", type=int, default=50)
    parser.add_argument("--adaptive-decode", action="store_true", default=False)
    parser.add_argument("--adaptive-decode-first-interval", type=int, default=48)
    parser.add_argument("--adaptive-decode-busy-interval", type=int, default=64)
    parser.add_argument("--adaptive-decode-busy-pending", type=int, default=32)
    parser.add_argument("--disable-windowed-decode", action="store_true", default=False)
    parser.add_argument("--disable-bicodec-batching", action="store_true", default=False)
    parser.add_argument("--bicodec-batch-max", type=int, default=16)
    parser.add_argument("--bicodec-batch-timeout-ms", type=float, default=6.0)
    parser.add_argument("--bicodec-batch-workers", type=int, default=1)
    parser.add_argument("--bicodec-batch-scale-pending", type=int, default=0)
    parser.add_argument("--bicodec-batch-scale-mode", choices=["sticky", "dynamic"], default="sticky")
    parser.add_argument("--stream-admission-max-inflight", type=int, default=0)
    parser.add_argument("--stream-admission-max-queue", type=int, default=0)
    parser.add_argument("--stream-admission-max-wait-ms", type=float, default=0.0)
    parser.add_argument("--stream-admission-poll-ms", type=float, default=2.0)
    parser.add_argument("--model-path", type=str, default="")
    parser.add_argument("--output", default="stress_streaming_runtime.json")
    return parser


def _export_stream_env(args: argparse.Namespace) -> None:
    """Publish streaming/runtime tuning knobs via environment variables.

    These must be set before the runtime is initialized so it picks them up.
    """
    os.environ.update(
        {
            "VEENA3_STREAM_OUTPUT_KIND": args.stream_output_kind,
            "VEENA3_STREAM_DECODE_INTERVAL": str(args.stream_decode_interval),
            "VEENA3_STREAM_WINDOW_SIZE": str(args.stream_window_size),
            "VEENA3_STREAM_MIN_SEMANTIC_FIRST": str(args.stream_min_semantic_first),
            "VEENA3_STREAM_CROSSFADE_MS": str(args.stream_crossfade_ms),
            "VEENA3_STREAM_ADAPTIVE_DECODE": "1" if args.adaptive_decode else "0",
            "VEENA3_STREAM_DECODE_INTERVAL_FIRST": str(args.adaptive_decode_first_interval),
            "VEENA3_STREAM_DECODE_INTERVAL_BUSY": str(args.adaptive_decode_busy_interval),
            "VEENA3_STREAM_DECODE_BUSY_PENDING": str(args.adaptive_decode_busy_pending),
            "VEENA3_STREAM_WINDOWED_DECODE": "0" if args.disable_windowed_decode else "1",
            "VEENA3_BICODEC_BATCHING": "0" if args.disable_bicodec_batching else "1",
            "VEENA3_BICODEC_BATCH_MAX": str(args.bicodec_batch_max),
            "VEENA3_BICODEC_BATCH_TIMEOUT_MS": str(args.bicodec_batch_timeout_ms),
            "VEENA3_BICODEC_BATCH_WORKERS": str(args.bicodec_batch_workers),
            "VEENA3_BICODEC_BATCH_SCALE_PENDING": str(args.bicodec_batch_scale_pending),
            "VEENA3_BICODEC_BATCH_SCALE_MODE": str(args.bicodec_batch_scale_mode),
            "VEENA3_STREAM_ADMISSION_MAX_INFLIGHT": str(args.stream_admission_max_inflight),
            "VEENA3_STREAM_ADMISSION_MAX_QUEUE": str(args.stream_admission_max_queue),
            "VEENA3_STREAM_ADMISSION_MAX_WAIT_MS": str(args.stream_admission_max_wait_ms),
            "VEENA3_STREAM_ADMISSION_POLL_MS": str(args.stream_admission_poll_ms),
        }
    )


def main() -> None:
    """Parse the CLI, configure the environment, boot the runtime, run levels."""
    args = _build_parser().parse_args()

    from veena3modal.local_server import DEFAULT_LOCAL_MODEL_DIR

    _export_stream_env(args)

    chunked_prefill = args.enable_chunked_prefill
    if args.disable_chunked_prefill:
        # Explicit disable overrides the (possibly None) enable flag.
        chunked_prefill = False

    init_runtime(
        model_path=args.model_path or DEFAULT_LOCAL_MODEL_DIR,
        gpu_mem=args.gpu_memory,
        stream_output_kind=args.stream_output_kind,
        num_engines=args.num_engines,
        max_num_batched_tokens=args.max_num_batched_tokens,
        max_num_seqs=args.max_num_seqs,
        enable_chunked_prefill=chunked_prefill,
        enable_prefix_caching=(False if args.disable_prefix_caching else None),
        disable_log_stats=(True if args.disable_engine_stats_logs else None),
        enforce_eager=(True if args.enforce_eager else None),
    )
    asyncio.run(main_async(args))


# Script entry point: run the benchmark only when executed directly.
if __name__ == "__main__":
    main()
