#!/usr/bin/env python3
"""Detailed local stress benchmark for optimized TTS path.

This script is focused on collecting per-request GPU/LLM timing for local tuning:
- Request wall-clock latency and server reported TTFB
- GPU timing breakdown (LLM batch wall / GPU wall, parse, decode)
- Time per token and per batch
- Concurrency visibility from in-flight counters
- End-to-end throughput and RTF with GPU utilization snapshots

Usage:
    python scripts/stress_test_optimized_local.py
    python scripts/stress_test_optimized_local.py --levels 1,5,10,20 --concurrency-scaling both
    python scripts/stress_test_optimized_local.py --text-category long --output benchmarks.json
"""

import argparse
import asyncio
import json
import statistics
import subprocess
import sys
import threading
import time
from dataclasses import dataclass, field
from typing import Any, Dict, List, Optional, Tuple

import httpx


# Default local server endpoint and the TTS generation route exercised below.
DEFAULT_URL = "http://localhost:8000"
GENERATE_PATH = "/v1/tts/generate"


# Canned prompts of increasing length, selected via --text-category
# ("mixed" rotates through all of them round-robin in run_level).
TEST_TEXTS = {
    "short": "Hello, this is a quick test.",
    "medium": (
        "The quick brown fox jumps over the lazy dog. "
        "This is a baseline quality sentence used for benchmarking local TTS latency, "
        "TTFB, and tokenization behavior under load."
    ),
    "long": (
        "In the heart of a bustling city, where towering skyscrapers cast shadows over crowded streets, "
        "there lived a quiet linguist who believed every sentence carried a rhythm. "
        "She recorded conversations at dawn and dusk, trying to capture how meaning shifted with breath. "
        "By evening she would review them, matching sound to emotion, and every night she learned a new way to speak."
    ),
}

# Speaker voices rotated round-robin across requests (request i uses SPEAKERS[i % len]).
SPEAKERS = ["lipakshi", "vardan", "reet", "Nandini", "krishna", "anika"]


@dataclass
class GPUSnapshot:
    """One point-in-time GPU reading parsed from ``nvidia-smi`` CSV output."""

    timestamp: float  # wall-clock time.time() when the sample was taken
    memory_used_mb: float
    memory_total_mb: float
    gpu_utilization_pct: float
    temperature_c: float


class GPUMonitor:
    """Background polling for quick GPU health snapshots.

    Spawns a daemon thread that shells out to ``nvidia-smi`` every
    ``interval_seconds`` and appends one :class:`GPUSnapshot` per poll.
    All polling failures (missing binary, timeout, parse errors) are
    swallowed so benchmarking still runs on hosts without a GPU.
    """

    def __init__(self, interval_seconds: float = 1.0):
        self.interval = interval_seconds
        self.snapshots: List[GPUSnapshot] = []
        self._running = False
        self._thread: Optional[threading.Thread] = None

    def start(self):
        """Begin polling in a background daemon thread."""
        self._running = True
        self._thread = threading.Thread(target=self._poll_loop, daemon=True)
        self._thread.start()

    def stop(self):
        """Signal the poll loop to exit and wait briefly for the thread."""
        self._running = False
        if self._thread:
            self._thread.join(timeout=3)

    def _poll_loop(self):
        while self._running:
            try:
                result = subprocess.run(
                    [
                        "nvidia-smi",
                        "--query-gpu=memory.used,memory.total,utilization.gpu,temperature.gpu",
                        "--format=csv,noheader,nounits",
                    ],
                    capture_output=True,
                    text=True,
                    timeout=5,
                )
                if result.returncode == 0 and result.stdout:
                    # BUG FIX: nvidia-smi emits one CSV line *per device*.
                    # Splitting the whole stdout on "," glued adjacent lines
                    # together on multi-GPU hosts, float() raised, and the
                    # exception was silently swallowed -- so no samples were
                    # ever captured. Sample only the first GPU's line.
                    first_line = result.stdout.strip().splitlines()[0]
                    parts = [p.strip() for p in first_line.split(",")]
                    if len(parts) >= 4:
                        self.snapshots.append(
                            GPUSnapshot(
                                timestamp=time.time(),
                                memory_used_mb=float(parts[0]),
                                memory_total_mb=float(parts[1]),
                                gpu_utilization_pct=float(parts[2]),
                                temperature_c=float(parts[3]),
                            )
                        )
            except Exception:
                # Best-effort monitoring: never let GPU sampling break the run.
                pass
            time.sleep(self.interval)

    def summary(self) -> Dict[str, Any]:
        """Aggregate min/max/avg statistics over all collected snapshots."""
        if not self.snapshots:
            return {
                "samples": 0,
                "error": "no GPU samples captured",
            }

        mem_used = [s.memory_used_mb for s in self.snapshots]
        gpu_util = [s.gpu_utilization_pct for s in self.snapshots]
        temp = [s.temperature_c for s in self.snapshots]

        return {
            "samples": len(self.snapshots),
            "memory_mb": {
                "min": min(mem_used),
                "max": max(mem_used),
                "avg": statistics.mean(mem_used),
                # total memory is assumed constant; report the first sample's.
                "total": self.snapshots[0].memory_total_mb,
            },
            "gpu_util_pct": {
                "min": min(gpu_util),
                "max": max(gpu_util),
                "avg": statistics.mean(gpu_util),
            },
            "temperature_c": {
                "min": min(temp),
                "max": max(temp),
                "avg": statistics.mean(temp),
            },
        }


@dataclass
class RequestResult:
    """Outcome and timing detail for a single TTS request."""

    success: bool
    status_code: int  # 0 when no HTTP response was received (timeout/transport error)
    latency_ms: float  # client-measured wall-clock latency for the whole request
    ttfb_ms: float = 0.0  # client-side time to first byte (streaming); full latency otherwise
    server_ttfb_ms: float = 0.0  # server-reported "x-ttfb-ms" header, 0.0 if absent
    audio_bytes: int = 0
    audio_seconds: float = 0.0  # from "x-audio-seconds" header or estimated from PCM size
    request_id: str = ""  # server's "x-request-id" header
    error: Optional[str] = None  # truncated error text for failures
    speaker: str = ""
    text_length: int = 0
    stream: bool = False
    timing: Dict[str, Any] = field(default_factory=dict)  # metrics from _parse_perf_headers
    request_inflight: Optional[int] = None  # server-reported in-flight counter, if present


@dataclass
class LevelResult:
    """Aggregated results and derived statistics for one benchmark level."""

    level_name: str
    concurrency: int
    total_requests: int
    mode: str
    text_category: str
    chunking: bool
    results: List[RequestResult] = field(default_factory=list)
    wall_time_seconds: float = 0.0
    gpu_summary: Dict[str, Any] = field(default_factory=dict)

    @property
    def successes(self) -> List[RequestResult]:
        """All results that completed successfully."""
        return [item for item in self.results if item.success]

    @property
    def failures(self) -> List[RequestResult]:
        """All results that failed (transport error or non-200 response)."""
        return [item for item in self.results if not item.success]

    @property
    def success_rate(self) -> float:
        """Fraction of recorded results that succeeded; 0.0 when empty."""
        if not self.results:
            return 0.0
        return len(self.successes) / len(self.results)

    @property
    def throughput_rps(self) -> float:
        """Requests per second over the measured wall time; 0.0 if unmeasured."""
        if self.wall_time_seconds <= 0:
            return 0.0
        return self.total_requests / self.wall_time_seconds

    @property
    def audio_duration_total(self) -> float:
        """Total seconds of audio produced across successful requests."""
        return sum(item.audio_seconds for item in self.successes)

    @property
    def effective_rtf(self) -> float:
        """Wall time divided by audio produced (real-time factor; lower is better)."""
        audio = self.audio_duration_total
        if audio <= 0:
            return 0.0
        return self.wall_time_seconds / audio

    def _percentile(self, values: List[float], pct: float) -> float:
        """Nearest-rank percentile via truncated index, clamped to the last element."""
        if not values:
            return 0.0
        ordered = sorted(values)
        rank = int(len(ordered) * pct / 100)
        return ordered[min(rank, len(ordered) - 1)]

    @property
    def latency_stats(self) -> Dict[str, float]:
        """min/avg/p50/p95/p99/max of successful request latencies (ms)."""
        samples = [item.latency_ms for item in self.successes]
        if not samples:
            return {}
        stats = {"min": min(samples), "avg": statistics.mean(samples)}
        for label, pct in (("p50", 50), ("p95", 95), ("p99", 99)):
            stats[label] = self._percentile(samples, pct)
        stats["max"] = max(samples)
        return stats

    @property
    def ttfb_stats(self) -> Dict[str, float]:
        """min/avg/p50/p95/max of non-negative server-reported TTFB values (ms)."""
        samples = [item.server_ttfb_ms for item in self.successes if item.server_ttfb_ms >= 0]
        if not samples:
            return {}
        stats = {"min": min(samples), "avg": statistics.mean(samples)}
        for label, pct in (("p50", 50), ("p95", 95)):
            stats[label] = self._percentile(samples, pct)
        stats["max"] = max(samples)
        return stats

    def timing_stats(self, key: str) -> Dict[str, float]:
        """Distribution stats for one numeric per-request timing metric."""
        samples: List[float] = []
        for item in self.successes:
            raw = item.timing.get(key)
            if isinstance(raw, (int, float)):
                samples.append(float(raw))
        if not samples:
            return {}
        stats = {"min": min(samples), "avg": statistics.mean(samples)}
        for label, pct in (("p50", 50), ("p95", 95), ("p99", 99)):
            stats[label] = self._percentile(samples, pct)
        stats["max"] = max(samples)
        return stats

    def timing_rate(self, numerator_key: str, denominator_key: str) -> Tuple[float, int]:
        """Ratio of summed numerator over summed denominator across successes.

        Returns (rate, denominator_total); (0.0, 0) when nothing qualifies.
        """
        total_num = 0.0
        total_den = 0
        for item in self.successes:
            num_val = item.timing.get(numerator_key)
            den_val = item.timing.get(denominator_key)
            if not isinstance(num_val, (int, float)):
                continue
            if not isinstance(den_val, (int, float)) or den_val <= 0:
                continue
            total_num += float(num_val)
            total_den += int(den_val)
        if total_den == 0:
            return 0.0, 0
        return total_num / total_den, total_den

    def observed_inflight_max(self) -> Optional[int]:
        """Highest server-reported in-flight count seen, or None if unreported."""
        candidates = [
            item.request_inflight
            for item in self.successes
            if isinstance(item.request_inflight, int)
        ]
        return max(candidates) if candidates else None

    def error_summary(self) -> Dict[str, int]:
        """Counts of failure reasons, keyed by (truncated) error text."""
        counts: Dict[str, int] = {}
        for item in self.failures:
            label = (item.error or f"HTTP {item.status_code}")[:100]
            counts[label] = counts.get(label, 0) + 1
        return counts


def _coerce_number(raw: str) -> Optional[float]:
    if raw is None:
        return None
    v = str(raw).strip()
    if v == "":
        return None
    try:
        if "." in v:
            return float(v)
        return float(int(v))
    except ValueError:
        try:
            return float(v)
        except ValueError:
            return None


def _normalize_time_ms(raw: Optional[str]) -> float:
    value = _coerce_number(raw)
    if value is None:
        return 0.0
    return float(value)


def _parse_perf_headers(headers: Any) -> Dict[str, Any]:
    """Parse server timing headers into numeric metric fields.

    *headers* only needs ``.get`` and ``.items`` (an httpx Headers object
    or a plain dict both work). Values are coerced to int when integral,
    float otherwise; ``text_chunked`` is coerced to bool. An optional
    ``x-perf-details`` header carrying a JSON object is merged in first,
    and per-token / per-batch rates are (re)derived at the end.
    """
    metrics: Dict[str, Any] = {}

    # Explicit compact payload: a JSON dict merged wholesale into metrics.
    detail_payload = headers.get("x-perf-details") if hasattr(headers, "get") else None
    if detail_payload:
        try:
            parsed = json.loads(detail_payload)
            if isinstance(parsed, dict):
                metrics.update(parsed)
        except Exception:
            # Malformed payload: ignore and fall back to individual headers.
            pass

    # Lowercased header name -> metric key.
    # NOTE: the original map listed "x-llm-parse-avg-ms" twice with the same
    # value; the duplicate dict key has been removed (later keys silently
    # override earlier ones in a dict literal).
    header_map = {
        "x-api-preprocess-ms": "api_preprocess_ms",
        "x-api-generate-await-ms": "api_generate_await_ms",
        "x-api-postprocess-ms": "api_postprocess_ms",
        "x-api-total-ms": "api_total_ms",
        "x-api-overhead-vs-pipeline-ms": "api_overhead_vs_pipeline_ms",
        "x-api-overhead-vs-generation-ms": "api_overhead_vs_generation_ms",
        "x-timeline-first-batch-ms": "timeline_llm_first_batch_ms",
        "x-timeline-llm-done-ms": "timeline_llm_done_ms",
        "x-timeline-bicodec-done-ms": "timeline_bicodec_done_ms",
        "x-timeline-total-ms": "timeline_total_ms",
        "x-llm-token-total": "llm_token_total",
        "x-llm-batch-count": "llm_batch_count",
        "x-llm-batch-wall-ms": "llm_batch_wall_ms_total",
        "x-llm-batch-gpu-ms": "llm_batch_gpu_ms_total",
        "x-llm-batch-wall-ms-min": "llm_batch_wall_ms_min",
        "x-llm-batch-wall-ms-max": "llm_batch_wall_ms_max",
        "x-llm-batch-wall-ms-p50": "llm_batch_wall_ms_p50",
        "x-llm-batch-gpu-ms-min": "llm_batch_gpu_ms_min",
        "x-llm-batch-gpu-ms-max": "llm_batch_gpu_ms_max",
        "x-llm-batch-gpu-ms-p50": "llm_batch_gpu_ms_p50",
        "x-llm-decode-wall-ms": "llm_decode_wall_ms_total",
        "x-llm-decode-gpu-ms": "llm_decode_gpu_ms_total",
        "x-llm-decode-wall-ms-min": "llm_decode_wall_ms_min",
        "x-llm-decode-wall-ms-max": "llm_decode_wall_ms_max",
        "x-llm-decode-wall-ms-p50": "llm_decode_wall_ms_p50",
        "x-llm-decode-gpu-ms-min": "llm_decode_gpu_ms_min",
        "x-llm-decode-gpu-ms-max": "llm_decode_gpu_ms_max",
        "x-llm-decode-gpu-ms-p50": "llm_decode_gpu_ms_p50",
        "x-llm-decode-calls": "decode_calls",
        "x-llm-decode-cpu-ms": "llm_decode_cpu_ms",
        "x-llm-parse-ms": "llm_parse_ms",
        "x-llm-parse-avg-ms": "llm_parse_ms_avg",
        "x-llm-tokens-per-batch-min": "tokens_per_batch_min",
        "x-llm-tokens-per-batch-max": "tokens_per_batch_max",
        "x-llm-tokens-per-batch-p50": "tokens_per_batch_p50",
        "x-generation-ms": "generation_ms",
        "x-request-inflight": "request_inflight",
        "x-llm-text-chunked": "text_chunked",
        "x-llm-chunks-processed": "chunks_processed",
        "x-llm-time-per-token-ms": "llm_time_per_token_ms",
        "x-llm-time-per-batch-wall-ms": "llm_time_per_batch_wall_ms",
        "x-llm-time-per-batch-gpu-ms": "llm_time_per_batch_gpu_ms",
        "x-bicodec-decode-wall-ms": "bicodec_decode_wall_ms",
        "x-bicodec-decode-gpu-ms": "bicodec_decode_gpu_ms",
        "x-bicodec-decode-cpu-ms": "bicodec_decode_cpu_ms",
    }

    if hasattr(headers, "items"):
        for raw_key, raw_val in headers.items():
            lk = str(raw_key).lower()
            target_key = header_map.get(lk)
            if target_key is None:
                continue
            v = _coerce_number(raw_val)
            if v is not None:
                if target_key == "text_chunked":
                    # Flag header: any non-zero value means the text was chunked.
                    metrics[target_key] = bool(v)
                else:
                    metrics[target_key] = int(v) if float(v).is_integer() else float(v)

    # Derived rates below are recomputed locally and take precedence over
    # any server-supplied x-llm-time-per-* header values.
    if "llm_token_total" in metrics and "generation_ms" in metrics and metrics["llm_token_total"]:
        try:
            metrics["llm_time_per_token_ms"] = float(metrics["generation_ms"]) / float(metrics["llm_token_total"])
        except Exception:
            pass

    if "llm_batch_count" in metrics and "generation_ms" in metrics and metrics["llm_batch_count"]:
        try:
            metrics["llm_time_per_batch_wall_ms"] = float(metrics["generation_ms"]) / float(metrics["llm_batch_count"])
        except Exception:
            pass

    if "llm_batch_count" in metrics and "llm_batch_gpu_ms_total" in metrics and metrics["llm_batch_count"]:
        try:
            metrics["llm_time_per_batch_gpu_ms"] = float(metrics["llm_batch_gpu_ms_total"]) / float(metrics["llm_batch_count"])
        except Exception:
            pass

    return metrics


async def make_request(
    client: httpx.AsyncClient,
    base_url: str,
    text: str,
    speaker: str,
    chunking: bool,
    max_tokens: int,
    stream: bool = False,
    timeout: float = 120.0,
) -> RequestResult:
    """Issue one TTS generate request and capture latency, TTFB, and perf headers.

    In streaming mode the body is drained chunk-by-chunk so client-side
    TTFB (first byte) can be measured; otherwise a plain POST is used.
    Every failure path (non-200 status, timeout, transport exception) is
    converted into a failed RequestResult rather than raised, so callers
    can gather results unconditionally.
    """
    url = f"{base_url}{GENERATE_PATH}"
    payload = {
        "text": text,
        "speaker": speaker,
        "stream": stream,
        "chunking": chunking,
        "format": "wav",
        "max_tokens": max_tokens,
    }

    start = time.time()
    ttfb_time = 0.0

    try:
        if stream:
            async with client.stream(
                "POST",
                url,
                json=payload,
                headers={"Content-Type": "application/json"},
                timeout=timeout,
            ) as response:
                first = True
                total_bytes = 0
                async for chunk in response.aiter_bytes():
                    if first:
                        # Client-observed time to first byte of the stream.
                        ttfb_time = (time.time() - start) * 1000
                        first = False
                    total_bytes += len(chunk)

                latency = (time.time() - start) * 1000
                timing = _parse_perf_headers(response.headers)

                if response.status_code == 200:
                    # Duration estimate assumes a 44-byte WAV header followed by
                    # 16 kHz, 16-bit (2 bytes/sample) mono PCM.
                    pcm_bytes = max(0, total_bytes - 44)
                    audio_seconds = pcm_bytes / (16000 * 2)
                    return RequestResult(
                        success=True,
                        status_code=200,
                        latency_ms=latency,
                        ttfb_ms=ttfb_time,
                        audio_bytes=total_bytes,
                        audio_seconds=audio_seconds,
                        request_id=response.headers.get("x-request-id", ""),
                        speaker=speaker,
                        text_length=len(text),
                        stream=True,
                        timing=timing,
                        request_inflight=timing.get("request_inflight"),
                        server_ttfb_ms=_normalize_time_ms(response.headers.get("x-ttfb-ms")),
                    )

                return RequestResult(
                    success=False,
                    status_code=response.status_code,
                    latency_ms=latency,
                    ttfb_ms=ttfb_time,
                    error=f"HTTP {response.status_code}",
                    speaker=speaker,
                    text_length=len(text),
                    stream=True,
                    timing=timing,
                    request_id=response.headers.get("x-request-id", ""),
                    request_inflight=timing.get("request_inflight"),
                )
        else:
            response = await client.post(
                url,
                json=payload,
                headers={"Content-Type": "application/json"},
                timeout=timeout,
            )
            latency = (time.time() - start) * 1000
            timing = _parse_perf_headers(response.headers)

            server_ttfb = _normalize_time_ms(response.headers.get("x-ttfb-ms"))
            audio_seconds = 0.0
            # Prefer the server-reported duration header; otherwise estimate
            # from the body size (44-byte WAV header + 16 kHz 16-bit mono PCM).
            audio_sec_header = response.headers.get("x-audio-seconds")
            if audio_sec_header:
                parsed = _coerce_number(audio_sec_header)
                if parsed is not None:
                    audio_seconds = parsed
            elif response.status_code == 200:
                pcm_bytes = max(0, len(response.content) - 44)
                audio_seconds = pcm_bytes / (16000 * 2)

            if response.status_code == 200:
                # Non-streaming: first byte arrives with the full body, so
                # ttfb_ms is simply the total latency.
                return RequestResult(
                    success=True,
                    status_code=200,
                    latency_ms=latency,
                    ttfb_ms=latency,
                    server_ttfb_ms=server_ttfb,
                    audio_bytes=len(response.content),
                    audio_seconds=audio_seconds,
                    request_id=response.headers.get("x-request-id", ""),
                    speaker=speaker,
                    text_length=len(text),
                    stream=False,
                    timing=timing,
                    request_inflight=timing.get("request_inflight"),
                )

            return RequestResult(
                success=False,
                status_code=response.status_code,
                latency_ms=latency,
                ttfb_ms=latency,
                server_ttfb_ms=server_ttfb,
                audio_bytes=len(response.content),
                error=(response.text[:200] if response.text else f"HTTP {response.status_code}"),
                speaker=speaker,
                text_length=len(text),
                stream=False,
                timing=timing,
                request_id=response.headers.get("x-request-id", ""),
                request_inflight=timing.get("request_inflight"),
            )

    except httpx.TimeoutException:
        return RequestResult(
            success=False,
            status_code=0,
            latency_ms=(time.time() - start) * 1000,
            error="TIMEOUT",
            speaker=speaker,
            text_length=len(text),
            stream=stream,
        )
    except Exception as e:
        # Broad catch is deliberate: a benchmark client must record, not crash.
        return RequestResult(
            success=False,
            status_code=0,
            latency_ms=(time.time() - start) * 1000,
            error=str(e)[:200],
            speaker=speaker,
            text_length=len(text),
            stream=stream,
        )


async def run_level(
    base_url: str,
    num_requests: int,
    concurrency: int,
    text_category: str,
    chunking: bool,
    max_tokens: int,
    request_timeout: float,
    stream: bool,
    gpu_monitor: Optional[GPUMonitor] = None,
) -> LevelResult:
    """Fire ``num_requests`` at the server with bounded concurrency.

    Requests rotate through the text corpus and speaker list round-robin.
    Exceptions escaping individual requests are recorded as failed results.
    """
    mode = "streaming" if stream else "non-streaming"
    level = LevelResult(
        level_name=f"{concurrency}c-{text_category}-{mode}",
        concurrency=concurrency,
        total_requests=num_requests,
        mode=mode,
        text_category=text_category,
        chunking=chunking,
    )

    # "mixed" cycles through all canned texts; any other category picks one,
    # falling back to the short text for unknown names.
    if text_category == "mixed":
        corpus = list(TEST_TEXTS.values())
    else:
        corpus = [TEST_TEXTS.get(text_category, TEST_TEXTS["short"])]

    gate = asyncio.Semaphore(concurrency)

    async def issue(index: int, client: httpx.AsyncClient) -> RequestResult:
        # Bound in-flight requests with the semaphore.
        async with gate:
            return await make_request(
                client=client,
                base_url=base_url,
                text=corpus[index % len(corpus)],
                speaker=SPEAKERS[index % len(SPEAKERS)],
                chunking=chunking,
                max_tokens=max_tokens,
                stream=stream,
                timeout=request_timeout,
            )

    if gpu_monitor:
        # Start this level with a clean sample window.
        gpu_monitor.snapshots.clear()

    started = time.time()
    pool_limits = httpx.Limits(
        max_connections=concurrency + 10,
        max_keepalive_connections=concurrency,
    )
    async with httpx.AsyncClient(limits=pool_limits) as client:
        outcomes = await asyncio.gather(
            *(issue(i, client) for i in range(num_requests)),
            return_exceptions=True,
        )

    level.wall_time_seconds = time.time() - started

    for outcome in outcomes:
        if isinstance(outcome, RequestResult):
            level.results.append(outcome)
            continue
        # gather(return_exceptions=True) hands back raw exceptions; wrap them.
        level.results.append(
            RequestResult(
                success=False,
                status_code=0,
                latency_ms=0,
                error=str(outcome)[:200],
                stream=stream,
            )
        )

    if gpu_monitor:
        level.gpu_summary = gpu_monitor.summary()

    return level


def print_level_report(level: LevelResult):
    """Pretty-print one level's throughput, latency, timing breakdown, GPU stats, and errors."""
    print(f"\n{'=' * 72}")
    print(f"  {level.level_name.upper()}")
    print(f"  {level.total_requests} requests, {level.concurrency} concurrent, {level.mode}, chunking={level.chunking}")
    print(f"{'=' * 72}")

    print(f"  Success:     {len(level.successes)}/{level.total_requests} ({level.success_rate:.0%})")
    print(f"  Wall time:   {level.wall_time_seconds:.2f}s")
    print(f"  Throughput:  {level.throughput_rps:.2f} req/s")
    print(f"  Audio total: {level.audio_duration_total:.2f}s")
    print(f"  Eff. RTF:    {level.effective_rtf:.3f}")

    if level.observed_inflight_max() is not None:
        print(f"  Peak inflight: {level.observed_inflight_max()}")

    latency = level.latency_stats
    if latency:
        print("\n  Latency (ms):")
        print(
            f"    min={latency['min']:.0f} avg={latency['avg']:.0f} "
            f"p50={latency['p50']:.0f} p95={latency['p95']:.0f} p99={latency['p99']:.0f} max={latency['max']:.0f}"
        )

    ttfb = level.ttfb_stats
    # Skip for streaming runs: mode is "streaming"/"non-streaming", and only
    # the former matches startswith("stream").
    if ttfb and not level.mode.startswith("stream"):
        print("\n  Server TTFB (ms):")
        print(
            f"    min={ttfb['min']:.0f} avg={ttfb['avg']:.0f} "
            f"p50={ttfb['p50']:.0f} p95={ttfb['p95']:.0f} max={ttfb['max']:.0f}"
        )

    # Detailed GPU/LLM stage breakdown is only emitted for non-streaming runs.
    if level.mode == "non-streaming":
        llm = level.timing_stats("generation_ms")
        if llm:
            print("\n  Generation (ms):")
            print(f"    min={llm['min']:.0f} avg={llm['avg']:.0f} p50={llm['p50']:.0f} p95={llm['p95']:.0f} max={llm['max']:.0f}")
        wall = level.timing_stats("llm_batch_wall_ms_total")
        if wall:
            print("\n  LLM batch wall (ms):")
            print(f"    min={wall['min']:.0f} avg={wall['avg']:.0f} p50={wall['p50']:.0f} p95={wall['p95']:.0f} max={wall['max']:.0f}")
        gpu = level.timing_stats("llm_batch_gpu_ms_total")
        if gpu:
            print("\n  LLM batch GPU (ms):")
            print(f"    min={gpu['min']:.0f} avg={gpu['avg']:.0f} p50={gpu['p50']:.0f} p95={gpu['p95']:.0f} max={gpu['max']:.0f}")
        parse = level.timing_stats("llm_parse_ms")
        if parse:
            print("\n  Parse (ms):")
            print(f"    avg={parse['avg']:.0f} p95={parse['p95']:.0f}")
        decode = level.timing_stats("llm_decode_wall_ms_total")
        if decode:
            print("\n  Decode (ms):")
            print(f"    avg={decode['avg']:.0f} p95={decode['p95']:.0f}")
        bdecode = level.timing_stats("bicodec_decode_wall_ms")
        if bdecode:
            print("\n  BiCodec decode (ms):")
            print(f"    avg={bdecode['avg']:.0f} p95={bdecode['p95']:.0f}")
        bdecode_gpu = level.timing_stats("bicodec_decode_gpu_ms")
        if bdecode_gpu:
            print("\n  BiCodec decode GPU (ms):")
            print(f"    avg={bdecode_gpu['avg']:.0f} p95={bdecode_gpu['p95']:.0f}")
        api_pre = level.timing_stats("api_preprocess_ms")
        api_wait = level.timing_stats("api_generate_await_ms")
        api_post = level.timing_stats("api_postprocess_ms")
        api_total = level.timing_stats("api_total_ms")
        api_over_pipeline = level.timing_stats("api_overhead_vs_pipeline_ms")
        api_over_generation = level.timing_stats("api_overhead_vs_generation_ms")
        if api_total:
            print("\n  API overhead (ms):")
            if api_pre:
                print(f"    preprocess avg={api_pre['avg']:.0f} p95={api_pre['p95']:.0f}")
            if api_wait:
                print(f"    generate_await avg={api_wait['avg']:.0f} p95={api_wait['p95']:.0f}")
            if api_post:
                print(f"    postprocess avg={api_post['avg']:.0f} p95={api_post['p95']:.0f}")
            print(f"    api_total avg={api_total['avg']:.0f} p95={api_total['p95']:.0f}")
            if api_over_pipeline:
                print(
                    f"    overhead_vs_pipeline avg={api_over_pipeline['avg']:.0f} p95={api_over_pipeline['p95']:.0f}"
                )
            if api_over_generation:
                print(
                    f"    overhead_vs_generation avg={api_over_generation['avg']:.0f} p95={api_over_generation['p95']:.0f}"
                )
            if latency:
                # Client-measured latency minus server api_total approximates
                # network + client + queueing overhead (clamped at zero).
                client_overhead_avg = max(0.0, latency["avg"] - api_total["avg"])
                client_overhead_p95 = max(0.0, latency["p95"] - api_total["p95"])
                print(
                    f"    client+queue overhead avg={client_overhead_avg:.0f} p95={client_overhead_p95:.0f}"
                )

        tl_first = level.timing_stats("timeline_llm_first_batch_ms")
        tl_llm_done = level.timing_stats("timeline_llm_done_ms")
        tl_parse_done = level.timing_stats("timeline_parse_done_ms")
        tl_bicodec_done = level.timing_stats("timeline_bicodec_done_ms")
        tl_total = level.timing_stats("timeline_total_ms")
        if tl_total:
            print("\n  Timeline markers (ms):")
            if tl_first:
                print(f"    first_batch avg={tl_first['avg']:.0f} p95={tl_first['p95']:.0f}")
            if tl_llm_done:
                print(f"    llm_done avg={tl_llm_done['avg']:.0f} p95={tl_llm_done['p95']:.0f}")
            if tl_parse_done:
                print(f"    parse_done avg={tl_parse_done['avg']:.0f} p95={tl_parse_done['p95']:.0f}")
            if tl_bicodec_done:
                print(f"    bicodec_done avg={tl_bicodec_done['avg']:.0f} p95={tl_bicodec_done['p95']:.0f}")
            print(f"    request_done avg={tl_total['avg']:.0f} p95={tl_total['p95']:.0f}")

            # Rough stage attribution: LLM batch wall + BiCodec decode; the
            # remainder is lumped into "other".
            llm_avg = wall.get("avg", 0.0) if wall else 0.0
            bicodec_avg = bdecode.get("avg", 0.0) if bdecode else 0.0
            total_avg = tl_total.get("avg", 0.0)
            if total_avg > 0:
                other_avg = max(0.0, total_avg - llm_avg - bicodec_avg)
                print(
                    "    Stage share: llm={:.1f}% bicodec={:.1f}% other={:.1f}%".format(
                        (llm_avg / total_avg) * 100,
                        (bicodec_avg / total_avg) * 100,
                        (other_avg / total_avg) * 100,
                    )
                )

            # Show one concrete request's timeline: the median-latency sample
            # among those that reported timeline_total_ms.
            timeline_samples = [
                r for r in level.successes
                if isinstance(r.timing.get("timeline_total_ms"), (int, float))
            ]
            if timeline_samples:
                sample = sorted(timeline_samples, key=lambda r: r.latency_ms)[len(timeline_samples) // 2]
                print(
                    "    Sample (median-latency): first_batch={:.1f} llm_done={:.1f} parse_done={:.1f} "
                    "bicodec_done={:.1f} total={:.1f}".format(
                        float(sample.timing.get("timeline_llm_first_batch_ms", 0.0)),
                        float(sample.timing.get("timeline_llm_done_ms", 0.0)),
                        float(sample.timing.get("timeline_parse_done_ms", 0.0)),
                        float(sample.timing.get("timeline_bicodec_done_ms", 0.0)),
                        float(sample.timing.get("timeline_total_ms", 0.0)),
                    )
                )

        # token_ms is generation ms per token; * 1000 converts to microseconds.
        token_ms, token_count = level.timing_rate("generation_ms", "llm_token_total")
        if token_count:
            print(f"  Token timing: avg {token_ms * 1000:.3f} us/token")

    if level.gpu_summary and "memory_mb" in level.gpu_summary:
        gpu_summary = level.gpu_summary
        mem = gpu_summary["memory_mb"]
        util = gpu_summary["gpu_util_pct"]
        temp = gpu_summary["temperature_c"]
        print("\n  GPU:")
        print(f"    Memory: {mem['avg']:.0f}MB avg, {mem['max']:.0f}MB peak / {mem['total']:.0f}MB")
        print(f"    Util:   {util['avg']:.0f}% avg, {util['max']:.0f}% peak")
        print(f"    Temp:   {temp['avg']:.0f}C avg, {temp['max']:.0f}C peak")

    errors = level.error_summary()
    if errors:
        print("\n  Errors:")
        # Most frequent error first.
        for key, count in sorted(errors.items(), key=lambda kv: -kv[1]):
            print(f"    [{count}x] {key}")


def print_summary_table(levels: List[LevelResult]):
    """Print one aligned summary row per benchmark level."""
    rule = "=" * 110
    print(f"\n{rule}")
    print("  STRESS BENCHMARK SUMMARY")
    print(rule)
    header = f"{'Level':<26} {'OK%':>5} {'RPS':>7} {'p50':>7} {'p95':>7} {'GPU%':>5} {'GPUmem':>8} {'RTF':>6} {'Inflight':>8}"
    print(header)
    print("-" * 110)

    for entry in levels:
        stats = entry.latency_stats
        # Dashes when no successful requests produced latency numbers.
        p50 = f"{stats.get('p50', 0):.0f}" if stats else "---"
        p95 = f"{stats.get('p95', 0):.0f}" if stats else "---"
        util_avg = entry.gpu_summary.get("gpu_util_pct", {}).get("avg", 0)
        mem_peak = entry.gpu_summary.get("memory_mb", {}).get("max", 0)
        peak_inflight = entry.observed_inflight_max() or 0
        print(
            f"{entry.level_name:<26} "
            f"{entry.success_rate:>4.0%} "
            f"{entry.throughput_rps:>7.2f} "
            f"{p50:>7} "
            f"{p95:>7} "
            f"{util_avg:>4.0f}% "
            f"{mem_peak:>7.0f}MB "
            f"{entry.effective_rtf:>6.3f} "
            f"{peak_inflight:>8d}"
        )

    print("=" * 110)


# Timing keys exported under "timing_averages"; kept in one place so adding a
# new server-side timing field means touching a single tuple, not 18 dict lines.
_TIMING_KEYS = (
    "generation_ms",
    "llm_batch_wall_ms_total",
    "llm_batch_gpu_ms_total",
    "llm_parse_ms",
    "llm_decode_wall_ms_total",
    "bicodec_decode_wall_ms",
    "bicodec_decode_gpu_ms",
    "api_preprocess_ms",
    "api_generate_await_ms",
    "api_postprocess_ms",
    "api_total_ms",
    "api_overhead_vs_pipeline_ms",
    "api_overhead_vs_generation_ms",
    "timeline_llm_first_batch_ms",
    "timeline_llm_done_ms",
    "timeline_parse_done_ms",
    "timeline_bicodec_done_ms",
    "timeline_total_ms",
)


def save_results(levels: List[LevelResult], path: str):
    """Serialize per-level benchmark results to a JSON file.

    Args:
        levels: Completed level results, one entry per concurrency level.
        path: Output file path; the file is overwritten if it exists.

    Each level entry includes aggregate stats, GPU summary, per-timing-key
    averages, and the raw per-request rows for offline analysis.
    """
    payload = []
    for level in levels:
        payload.append({
            "level": level.level_name,
            "concurrency": level.concurrency,
            "mode": level.mode,
            "chunking": level.chunking,
            "text_category": level.text_category,
            "success_rate": level.success_rate,
            "throughput_rps": level.throughput_rps,
            "wall_time_s": level.wall_time_seconds,
            "latency_ms": level.latency_stats,
            "ttfb_ms": level.ttfb_stats,
            "effective_rtf": level.effective_rtf,
            "audio_total_seconds": level.audio_duration_total,
            "observed_inflight_max": level.observed_inflight_max(),
            "errors": level.error_summary(),
            "gpu": level.gpu_summary,
            "timing_averages": {key: level.timing_stats(key) for key in _TIMING_KEYS},
            "requests": [
                {
                    "request_id": r.request_id,
                    "success": r.success,
                    "status_code": r.status_code,
                    "latency_ms": r.latency_ms,
                    "ttfb_ms": r.ttfb_ms,
                    "server_ttfb_ms": r.server_ttfb_ms,
                    "audio_bytes": r.audio_bytes,
                    "audio_seconds": r.audio_seconds,
                    "request_inflight": r.request_inflight,
                    "error": r.error,
                    "timing": r.timing,
                }
                for r in level.results
            ],
        })

    # Explicit encoding so the report is portable regardless of locale.
    with open(path, "w", encoding="utf-8") as fd:
        json.dump(payload, fd, indent=2)


async def run_stress_test(args):
    """Drive the full benchmark: health check, warmup, then one run per level.

    Args:
        args: Parsed CLI namespace (url, levels, text, chunking, stream,
              max_tokens, timeout, output).

    Exits the process with status 1 if the server health endpoint is
    unreachable. The background GPU monitor is always stopped, even when a
    level run raises.
    """
    base_url = args.url

    # Fail fast: everything below depends on the server being up.
    try:
        async with httpx.AsyncClient() as client:
            response = await client.get(f"{base_url}/v1/tts/health", timeout=10)
            health = response.json()
            print(f"  Health: {health.get('status')}")
            print(f"  Version: {health.get('model_version')}")
            print(f"  GPU: {health.get('gpu_available')}")
    except Exception as e:
        print(f"ERROR: server not reachable at {base_url}: {e}")
        sys.exit(1)

    levels = [int(v) for v in args.levels.split(",") if v.strip()]
    gpu_monitor = GPUMonitor(interval_seconds=0.5)
    gpu_monitor.start()

    results: List[LevelResult] = []

    try:
        print("\nWarming up (1 request, non-streaming)...")
        warmup = await run_level(
            base_url=base_url,
            num_requests=1,
            concurrency=1,
            text_category=args.text,
            chunking=args.chunking,
            max_tokens=args.max_tokens,
            request_timeout=args.timeout,
            stream=False,
            gpu_monitor=gpu_monitor,
        )
        print(f"  warmup ok: {warmup.success_rate:.0%}")

        for idx, conc in enumerate(levels):
            # ~3 requests per worker for stable percentiles, clamped to [10, 200].
            num_requests = min(max(conc * 3, 10), 200)
            level = await run_level(
                base_url=base_url,
                num_requests=num_requests,
                concurrency=conc,
                text_category=args.text,
                chunking=args.chunking,
                max_tokens=args.max_tokens,
                request_timeout=args.timeout,
                stream=args.stream,
                gpu_monitor=gpu_monitor,
            )
            print_level_report(level)
            results.append(level)
            # Cool down between levels; compare by index (not value) so a
            # duplicated level value doesn't skip the pause mid-run.
            if idx != len(levels) - 1:
                print("  cooldown 3s...")
                await asyncio.sleep(3)
    finally:
        # Don't leak the monitor thread if a level run raises.
        gpu_monitor.stop()

    print_summary_table(results)

    if args.output:
        save_results(results, args.output)
        print(f"\nResults written to {args.output}")


def _build_arg_parser() -> argparse.ArgumentParser:
    """Construct the CLI argument parser for the benchmark script."""
    parser = argparse.ArgumentParser(description="Detailed local optimized-path TTS benchmark")
    parser.add_argument("--url", default=DEFAULT_URL, help=f"Server URL (default: {DEFAULT_URL})")
    parser.add_argument("--levels", default="1,5,10,20,50", help="Comma-separated concurrency levels")
    parser.add_argument("--text", default="short", choices=["short", "medium", "long", "mixed"], help="Input text category")
    # Chunking defaults on; --no-chunking flips the same destination off.
    parser.add_argument("--chunking", action="store_true", help="Enable server-side chunking", default=True)
    parser.add_argument("--no-chunking", action="store_false", dest="chunking", help="Disable chunking")
    parser.add_argument("--stream", action="store_true", help="Use streaming endpoint")
    parser.add_argument("--max-tokens", type=int, default=128, help="Max tokens per request")
    parser.add_argument("--timeout", type=float, default=120.0, help="Per-request timeout seconds")
    parser.add_argument("--output", default="stress_test_optimized_local.json", help="Output JSON path")
    return parser


def main():
    """CLI entry point: parse arguments and run the async benchmark driver."""
    cli_args = _build_arg_parser().parse_args()
    asyncio.run(run_stress_test(cli_args))


if __name__ == "__main__":
    main()