#!/usr/bin/env python3
"""
Direct-runtime stress benchmark with detailed per-stage timing for the optimized local TTS path.

This bypasses HTTP serving and calls `tts_runtime` directly so we can isolate
generation + decode bottlenecks under concurrency.
"""

from __future__ import annotations

import argparse
import asyncio
import json
import os
import statistics
import subprocess
import sys
import threading
import time
from dataclasses import dataclass, field
from typing import Any, Dict, List, Optional


# Import-path setup (mirrors local_server.py): make the vendored Spark-TTS and
# AP-BWE sources, plus the repository root, importable when run as a script.
REPO_ROOT = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
EXTERNAL_DIR = os.path.join(REPO_ROOT, "external")
_EXTRA_IMPORT_PATHS = (
    os.path.join(EXTERNAL_DIR, "sparktts"),
    os.path.join(EXTERNAL_DIR, "AP-BWE"),
    REPO_ROOT,
)
for _extra_path in _EXTRA_IMPORT_PATHS:
    # Each missing entry is prepended, so later entries end up earlier in sys.path.
    if _extra_path in sys.path:
        continue
    sys.path.insert(0, _extra_path)

# Force local auth off so behavior matches local_server.
os.environ["AUTH_BYPASS_MODE"] = "true"


# Benchmark prompts, selected on the command line with --text.
TEST_TEXTS = dict(
    short="Hello, this is a quick benchmark test.",
    medium=(
        "The quick brown fox jumps over the lazy dog. "
        "This benchmark sentence checks token generation speed, decode latency, "
        "and concurrency behavior for Spark TTS."
    ),
    long=(
        "In the center of a crowded city, a quiet storyteller documented voices from dawn to dusk. "
        "Every pause, every laugh, every whisper revealed a different cadence. "
        "She replayed each recording at night, aligning tone and timing until each phrase felt natural."
    ),
)

# Speaker ids assigned round-robin to requests.
# NOTE(review): "Nandini" is the only capitalized id — confirm it matches the
# runtime's speaker table if lookups are case-sensitive.
SPEAKERS = [
    "lipakshi",
    "vardan",
    "reet",
    "Nandini",
    "krishna",
    "anika",
]


@dataclass
class GPUSnapshot:
    """One GPU metrics sample parsed from a single ``nvidia-smi`` query row."""

    timestamp: float  # time.time() value recorded when the sample was taken
    memory_used_mb: float  # nvidia-smi memory.used (queried with nounits; MiB presumably — confirm)
    memory_total_mb: float  # nvidia-smi memory.total, same unit as memory_used_mb
    gpu_utilization_pct: float  # nvidia-smi utilization.gpu value
    temperature_c: float  # nvidia-smi temperature.gpu value


class GPUMonitor:
    """Background sampler that records GPU metrics from ``nvidia-smi``.

    A daemon thread runs ``nvidia-smi`` every ``interval_seconds`` and appends a
    :class:`GPUSnapshot` to :attr:`snapshots`. Sampling is best-effort: a failed
    or unparseable poll just drops that sample, and polling stops for good if the
    ``nvidia-smi`` executable cannot be found.

    Only the first GPU row printed by ``nvidia-smi`` is recorded.
    NOTE(review): nvidia-smi's GPU ordering may differ from CUDA's device order
    (CUDA_DEVICE_ORDER / CUDA_VISIBLE_DEVICES) — confirm it is the GPU the
    runtime uses on multi-GPU hosts.
    """

    _QUERY_CMD = (
        "nvidia-smi",
        "--query-gpu=memory.used,memory.total,utilization.gpu,temperature.gpu",
        "--format=csv,noheader,nounits",
    )

    def __init__(self, interval_seconds: float = 0.5):
        self.interval = interval_seconds
        self.snapshots: List[GPUSnapshot] = []
        self._running = False
        self._thread: Optional[threading.Thread] = None
        # Set by stop() so the poll loop wakes immediately instead of
        # finishing its current sleep interval.
        self._stop_event = threading.Event()

    def start(self) -> None:
        """Start the background polling thread."""
        self._stop_event.clear()
        self._running = True
        self._thread = threading.Thread(target=self._poll_loop, daemon=True)
        self._thread.start()

    def stop(self) -> None:
        """Signal the polling thread to exit and wait up to 3 s for it."""
        self._running = False
        self._stop_event.set()
        if self._thread:
            self._thread.join(timeout=3)

    @staticmethod
    def _parse_output(stdout: str) -> Optional[List[float]]:
        """Parse ``nvidia-smi`` CSV output into the four queried values.

        Only the first non-empty row is used. With several GPUs, nvidia-smi prints
        one row per GPU, and splitting the whole output on commas would merge
        rows (e.g. ``"60\\n2000"``) and break every sample. Returns ``None`` when
        the row is short or a field is not numeric (for example an ``[N/A]``
        placeholder).
        """
        rows = [line for line in stdout.splitlines() if line.strip()]
        if not rows:
            return None
        parts = [p.strip() for p in rows[0].split(",")]
        if len(parts) < 4:
            return None
        try:
            return [float(p) for p in parts[:4]]
        except ValueError:
            return None

    def _poll_loop(self) -> None:
        """Thread body: sample until stop() is called or nvidia-smi is missing."""
        while self._running:
            try:
                res = subprocess.run(
                    list(self._QUERY_CMD),
                    capture_output=True,
                    text=True,
                    timeout=5,
                )
            except FileNotFoundError:
                # No nvidia-smi on PATH: every later poll would fail identically.
                self._running = False
                break
            except Exception:
                # Best-effort sampling: a hung or failing call just drops one sample.
                res = None
            if res is not None and res.returncode == 0 and res.stdout:
                values = self._parse_output(res.stdout)
                if values is not None:
                    self.snapshots.append(
                        GPUSnapshot(
                            timestamp=time.time(),
                            memory_used_mb=values[0],
                            memory_total_mb=values[1],
                            gpu_utilization_pct=values[2],
                            temperature_c=values[3],
                        )
                    )
            self._stop_event.wait(self.interval)

    def summary(self) -> Dict[str, Any]:
        """Return min/avg/max of memory, utilization and temperature samples.

        Returns ``{"samples": 0, "error": ...}`` when nothing was captured.
        ``memory_mb.total`` is taken from the first sample.
        """
        # Copy first: the poll thread may append while we aggregate.
        snaps = list(self.snapshots)
        if not snaps:
            return {"samples": 0, "error": "no GPU samples captured"}

        def spread(values: List[float]) -> Dict[str, float]:
            return {"min": min(values), "avg": statistics.mean(values), "max": max(values)}

        memory = spread([s.memory_used_mb for s in snaps])
        memory["total"] = snaps[0].memory_total_mb
        return {
            "samples": len(snaps),
            "memory_mb": memory,
            "gpu_util_pct": spread([s.gpu_utilization_pct for s in snaps]),
            "temperature_c": spread([s.temperature_c for s in snaps]),
        }


@dataclass
class RequestResult:
    """Outcome of one benchmark request, as produced by run_one_request."""

    success: bool  # True only when the runtime returned non-empty audio
    latency_ms: float  # wall time from request start until completion or failure
    audio_bytes: int  # length of the returned audio payload in bytes (0 on failure)
    audio_seconds: float  # duration estimated from the payload size (0.0 on failure)
    speaker: str  # speaker id the request was issued with
    text_length: int  # len() of the input text, in characters
    error: Optional[str] = None  # e.g. "timeout", "no_audio", or a truncated exception message
    timing: Dict[str, Any] = field(default_factory=dict)  # timing dict returned by the runtime, if any


@dataclass
class LevelResult:
    """Aggregated outcome of one concurrency level of the benchmark.

    ``results`` holds one :class:`RequestResult` per issued request, covering
    successes and failures. The statistics helpers consider only successful
    requests and return ``{}`` when there is nothing to summarise.
    """

    concurrency: int  # max requests in flight at this level
    total_requests: int  # number of requests issued at this level
    results: List[RequestResult] = field(default_factory=list)
    wall_time_s: float = 0.0  # wall-clock seconds for the whole level
    gpu_summary: Dict[str, Any] = field(default_factory=dict)  # GPUMonitor.summary() output

    @property
    def successes(self) -> List[RequestResult]:
        """Successful requests, in the order they appear in ``results``."""
        return [r for r in self.results if r.success]

    @property
    def failures(self) -> List[RequestResult]:
        """Failed requests, in the order they appear in ``results``."""
        return [r for r in self.results if not r.success]

    @property
    def success_rate(self) -> float:
        """Fraction of ``results`` that succeeded (0.0 when there are none)."""
        return len(self.successes) / len(self.results) if self.results else 0.0

    @property
    def throughput_rps(self) -> float:
        """Issued requests per second of wall time.

        Counts every issued request, including failures and timeouts.
        """
        return self.total_requests / self.wall_time_s if self.wall_time_s > 0 else 0.0

    def p(self, values: List[float], pct: float) -> float:
        """Return the ``pct``-th percentile of ``values`` (0.0 when empty).

        Uses linear interpolation between the closest ranks, the same definition
        as NumPy's default ``percentile`` method. So ``p(vals, 50)`` is the true
        median for even-sized samples, and tail percentiles are not biased toward
        the lower neighbour on small samples. ``pct`` is clamped to [0, 100].
        """
        if not values:
            return 0.0
        ordered = sorted(values)
        pct = min(max(pct, 0.0), 100.0)
        rank = (len(ordered) - 1) * (pct / 100)
        lower = int(rank)  # floor, since rank >= 0
        frac = rank - lower
        if frac == 0:
            return float(ordered[lower])
        upper = min(lower + 1, len(ordered) - 1)
        return ordered[lower] + (ordered[upper] - ordered[lower]) * frac

    @staticmethod
    def _as_number(value: Any) -> Optional[float]:
        """Return ``value`` as float if it is a real int/float (bools excluded)."""
        if isinstance(value, bool) or not isinstance(value, (int, float)):
            return None
        return float(value)

    def _summarize(self, vals: List[float]) -> Dict[str, float]:
        """avg/p50/p95/max of ``vals``, or ``{}`` when empty."""
        if not vals:
            return {}
        return {
            "avg": statistics.mean(vals),
            "p50": self.p(vals, 50),
            "p95": self.p(vals, 95),
            "max": max(vals),
        }

    def timing_stats(self, key: str) -> Dict[str, float]:
        """Summarise the numeric ``timing[key]`` values of successful requests."""
        vals: List[float] = []
        for r in self.successes:
            num = self._as_number(r.timing.get(key))
            if num is not None:
                vals.append(num)
        return self._summarize(vals)

    def latency_stats(self) -> Dict[str, float]:
        """Summarise end-to-end ``latency_ms`` of successful requests."""
        return self._summarize([r.latency_ms for r in self.successes])

    def latency_over_metric_stats(self, key: str) -> Dict[str, float]:
        """Summarise ``latency_ms - timing[key]`` (clamped at 0) over successes.

        This shows how much end-to-end latency is not accounted for by the
        given timing metric.
        """
        vals: List[float] = []
        for r in self.successes:
            base = self._as_number(r.timing.get(key))
            if base is not None:
                vals.append(max(0.0, float(r.latency_ms) - base))
        return self._summarize(vals)


def init_runtime(
    model_path: str,
    gpu_mem: float,
    num_engines: int,
    max_num_batched_tokens: Optional[int],
    max_num_seqs: Optional[int],
    enable_chunked_prefill: Optional[bool],
    enable_prefix_caching: Optional[bool],
    disable_log_stats: Optional[bool],
    enforce_eager: Optional[bool],
) -> None:
    """Initialise the in-process TTS runtime unless it is already initialised.

    ``model_path`` is resolved to LLM and BiCodec paths via
    ``resolve_model_paths``. The other arguments are forwarded to
    ``initialize_runtime`` as-is. The device is fixed to ``"cuda"`` and
    speaker-global precomputation is disabled.
    """
    from veena3modal.local_server import resolve_model_paths
    from veena3modal.services.tts_runtime import initialize_runtime, is_initialized

    llm_path, bicodec_path = resolve_model_paths(model_path)
    if is_initialized():
        # Keep the runtime that already exists in this process.
        return

    runtime_kwargs = {
        "model_path": llm_path,
        "bicodec_path": bicodec_path,
        "device": "cuda",
        "gpu_memory_utilization": gpu_mem,
        "num_engines": num_engines,
        "max_num_batched_tokens": max_num_batched_tokens,
        "max_num_seqs": max_num_seqs,
        "enable_chunked_prefill": enable_chunked_prefill,
        "enable_prefix_caching": enable_prefix_caching,
        "disable_log_stats": disable_log_stats,
        "enforce_eager": enforce_eager,
        "precompute_speaker_globals": False,
    }
    initialize_runtime(**runtime_kwargs)


async def run_one_request(
    text: str,
    speaker: str,
    max_tokens: int,
    timeout_s: float,
    chunking: bool,
) -> RequestResult:
    """Run one in-process TTS request and time it end to end.

    Calls ``tts_runtime.generate_speech_chunked`` when ``chunking`` is set,
    otherwise ``tts_runtime.generate_speech``, bounded by ``timeout_s``. The
    function does not raise for request-level problems. Each problem becomes a
    ``success=False`` result, with ``error`` set as follows:

    - timeout: ``"timeout"``
    - runtime exception: the exception message, or its type name when the
      message is empty, truncated to 180 chars
    - empty audio: ``"no_audio"``
    """
    from veena3modal.services import tts_runtime

    def failed(error: str, elapsed_ms: float, timing: Optional[Dict[str, Any]] = None) -> RequestResult:
        return RequestResult(
            success=False,
            latency_ms=elapsed_ms,
            audio_bytes=0,
            audio_seconds=0.0,
            speaker=speaker,
            text_length=len(text),
            error=error,
            timing=timing or {},
        )

    def sample_rate_of(timing: Dict[str, Any]) -> int:
        # A missing/None/garbage/non-positive rate must not turn a request that
        # produced audio into a failure; fall back to the 16 kHz default.
        try:
            rate = int(timing.get("output_sample_rate", 16000))
        except (TypeError, ValueError, OverflowError):
            return 16000
        return rate if rate > 0 else 16000

    t0 = time.perf_counter()
    try:
        if chunking:
            coro = tts_runtime.generate_speech_chunked(
                text=text,
                speaker=speaker,
                max_tokens=max_tokens,
            )
        else:
            coro = tts_runtime.generate_speech(
                text=text,
                speaker=speaker,
                max_tokens=max_tokens,
            )
        audio_bytes, timing = await asyncio.wait_for(coro, timeout=timeout_s)
        elapsed_ms = (time.perf_counter() - t0) * 1000
        timing = timing or {}
        if not audio_bytes:
            return failed("no_audio", elapsed_ms, timing)
        # Assumes 16-bit mono PCM WAV with a canonical 44-byte header — TODO
        # confirm against the runtime's output format.
        sample_rate = sample_rate_of(timing)
        audio_seconds = max(0.0, (len(audio_bytes) - 44) / (sample_rate * 2))
        return RequestResult(
            success=True,
            latency_ms=elapsed_ms,
            audio_bytes=len(audio_bytes),
            audio_seconds=audio_seconds,
            speaker=speaker,
            text_length=len(text),
            timing=timing,
        )
    except asyncio.TimeoutError:
        return failed("timeout", (time.perf_counter() - t0) * 1000)
    except Exception as exc:
        # Some exceptions carry no message; use the type name so the failure is
        # still identifiable in the report.
        return failed((str(exc) or type(exc).__name__)[:180], (time.perf_counter() - t0) * 1000)


async def run_level(
    concurrency: int,
    total_requests: int,
    text: str,
    max_tokens: int,
    timeout_s: float,
    chunking: bool,
    gpu_monitor: GPUMonitor,
) -> LevelResult:
    """Issue ``total_requests`` requests with at most ``concurrency`` in flight.

    Speakers rotate round-robin through ``SPEAKERS``. GPU samples are cleared
    before the level starts, so ``gpu_summary`` covers only this level.

    Raises:
        ValueError: if ``concurrency`` is below 1. ``asyncio.Semaphore(0)`` never
            grants a permit, so the level would otherwise hang forever.
    """
    if concurrency < 1:
        raise ValueError(f"concurrency must be >= 1, got {concurrency}")

    level = LevelResult(concurrency=concurrency, total_requests=total_requests)
    semaphore = asyncio.Semaphore(concurrency)

    async def wrapped(i: int) -> RequestResult:
        async with semaphore:
            speaker = SPEAKERS[i % len(SPEAKERS)]
            return await run_one_request(
                text=text,
                speaker=speaker,
                max_tokens=max_tokens,
                timeout_s=timeout_s,
                chunking=chunking,
            )

    gpu_monitor.snapshots.clear()
    t0 = time.perf_counter()
    # gather() returns results in submission order, not completion order.
    level.results = await asyncio.gather(*[wrapped(i) for i in range(total_requests)])
    level.wall_time_s = time.perf_counter() - t0
    level.gpu_summary = gpu_monitor.summary()
    return level


def print_level(level: LevelResult) -> None:
    """Print a plain-text report for one concurrency level to stdout.

    Sections, in order:

    - headline counts, wall time and throughput
    - end-to-end latency
    - stats for each known ``timing`` key present on successful requests
    - latency not covered by ``timeline_total_ms`` / ``generation_ms``
    - LLM/BiCodec/other share of the average timeline, plus the timeline of
      the median-latency request
    - GPU utilisation and memory
    - an error histogram, most frequent first
    """
    stat_names = ("avg", "p50", "p95", "max")

    def fmt_stats(stats: Dict[str, float], spec: str) -> str:
        # Renders "avg=… p50=… p95=… max=…" with each value formatted by `spec`.
        return " ".join(f"{name}={stats[name]:{spec}}" for name in stat_names)

    print(f"\n=== concurrency={level.concurrency} requests={level.total_requests} ===")
    n_ok = len(level.successes)
    print(f"success={n_ok}/{level.total_requests} ({level.success_rate:.0%})")
    print(f"wall={level.wall_time_s:.2f}s throughput={level.throughput_rps:.2f} req/s")

    latency = level.latency_stats()
    if latency:
        print(f"latency_ms {fmt_stats(latency, '.0f')}")

    timing_keys = (
        "generation_ms",
        "llm_time_in_queue_ms",
        "llm_scheduler_ms",
        "llm_first_token_ms",
        "llm_request_lifecycle_ms",
        "llm_queued_to_scheduled_ms",
        "llm_scheduled_to_first_token_ms",
        "llm_first_to_last_token_ms",
        "llm_queued_to_last_token_ms",
        "timeline_llm_first_batch_ms",
        "timeline_llm_done_ms",
        "timeline_parse_done_ms",
        "timeline_bicodec_done_ms",
        "timeline_total_ms",
        "llm_batch_wall_ms_total",
        "llm_batch_gpu_ms_total",
        "llm_decode_wall_ms_total",
        "llm_decode_gpu_ms_total",
        "llm_parse_ms",
        "llm_time_per_token_ms",
        "bicodec_decode_wall_ms",
        "bicodec_decode_gpu_ms",
    )
    for key in timing_keys:
        stats = level.timing_stats(key)
        if stats:
            print(f"{key}: {fmt_stats(stats, '.2f')}")

    # End-to-end latency not explained by the runtime's own measurements.
    for label, metric in (
        ("latency_over_timeline_ms", "timeline_total_ms"),
        ("latency_over_generation_ms", "generation_ms"),
    ):
        overhead = level.latency_over_metric_stats(metric)
        if overhead:
            print(f"{label}: {fmt_stats(overhead, '.2f')}")

    timeline = level.timing_stats("timeline_total_ms")
    llm_wall = level.timing_stats("llm_batch_wall_ms_total")
    if timeline and llm_wall:
        bicodec = level.timing_stats("bicodec_decode_wall_ms")
        total_avg = timeline.get("avg", 0.0)
        llm_avg = llm_wall.get("avg", 0.0)
        bicodec_avg = bicodec.get("avg", 0.0) if bicodec else 0.0
        other_avg = max(0.0, total_avg - llm_avg - bicodec_avg)
        if total_avg > 0:
            llm_pct = (llm_avg / total_avg) * 100
            bicodec_pct = (bicodec_avg / total_avg) * 100
            other_pct = (other_avg / total_avg) * 100
            print(f"timeline_share: llm={llm_pct:.1f}% bicodec={bicodec_pct:.1f}% other={other_pct:.1f}%")

        timed = [
            r for r in level.successes
            if isinstance(r.timing.get("timeline_total_ms"), (int, float))
        ]
        if timed:
            # Stable sort, then take the middle element: the median-latency request.
            timed.sort(key=lambda r: r.latency_ms)
            t = timed[len(timed) // 2].timing
            sample_fields = (
                ("first_batch", "timeline_llm_first_batch_ms"),
                ("llm_done", "timeline_llm_done_ms"),
                ("parse_done", "timeline_parse_done_ms"),
                ("bicodec_done", "timeline_bicodec_done_ms"),
                ("total", "timeline_total_ms"),
            )
            rendered = " ".join(
                f"{label}={float(t.get(field_key, 0.0)):.1f}" for label, field_key in sample_fields
            )
            print(f"timeline_sample_ms: {rendered}")

    gpu = level.gpu_summary
    if "memory_mb" in gpu:
        util = gpu["gpu_util_pct"]
        mem = gpu["memory_mb"]
        print(
            f"gpu: util_avg={util['avg']:.0f}% util_max={util['max']:.0f}% "
            f"mem_avg={mem['avg']:.0f}MB mem_max={mem['max']:.0f}MB"
        )

    error_counts: Dict[str, int] = {}
    for failure in level.failures:
        label = failure.error or "failed"
        error_counts[label] = error_counts.get(label, 0) + 1
    # reverse=True keeps sort stability, so ties stay in first-seen order.
    for label, count in sorted(error_counts.items(), key=lambda item: item[1], reverse=True):
        print(f"error[{count}] {label}")


def save_results(path: str, levels: List[LevelResult], args: argparse.Namespace) -> None:
    """Write the benchmark configuration and per-level results to ``path`` as JSON.

    The file holds a ``config`` section (CLI arguments plus a unix timestamp) and
    one entry per level. Each entry has aggregate stats, per-key timing stats
    and every individual request. Missing parent directories are created.
    Values the json module cannot encode are converted by a fallback rather than
    aborting, because this runs only after the whole benchmark has finished.
    """
    timing_keys = (
        "generation_ms",
        "llm_time_in_queue_ms",
        "llm_scheduler_ms",
        "llm_first_token_ms",
        "llm_request_lifecycle_ms",
        "llm_queued_to_scheduled_ms",
        "llm_scheduled_to_first_token_ms",
        "llm_first_to_last_token_ms",
        "llm_queued_to_last_token_ms",
        "llm_batch_wall_ms_total",
        "llm_batch_gpu_ms_total",
        "llm_decode_wall_ms_total",
        "llm_decode_gpu_ms_total",
        "llm_parse_ms",
        "llm_time_per_token_ms",
        "bicodec_decode_wall_ms",
        "bicodec_decode_gpu_ms",
        "timeline_llm_first_batch_ms",
        "timeline_llm_done_ms",
        "timeline_parse_done_ms",
        "timeline_bicodec_done_ms",
        "timeline_total_ms",
    )

    def _json_fallback(obj: Any) -> Any:
        # NOTE(review): runtime timing values are opaque here. Array/scalar types
        # exposing .tolist() (numpy/torch style) are converted natively; anything
        # else is stringified so one odd value cannot discard the whole run.
        to_list = getattr(obj, "tolist", None)
        if callable(to_list):
            return to_list()
        return str(obj)

    payload = {
        "config": {
            "levels": args.levels,
            "text": args.text,
            "max_tokens": args.max_tokens,
            "timeout_s": args.timeout,
            "chunking": args.chunking,
            "requests_multiplier": args.requests_multiplier,
            "gpu_memory": args.gpu_memory,
            "num_engines": args.num_engines,
            "max_num_batched_tokens": args.max_num_batched_tokens,
            "max_num_seqs": args.max_num_seqs,
            "enable_chunked_prefill": args.enable_chunked_prefill,
            "disable_chunked_prefill": args.disable_chunked_prefill,
            "disable_prefix_caching": args.disable_prefix_caching,
            "disable_engine_stats_logs": args.disable_engine_stats_logs,
            "enforce_eager": args.enforce_eager,
            "enable_gpu_decode_timing": args.enable_gpu_decode_timing,
            "non_stream_final_only": args.non_stream_final_only,
            "timestamp": int(time.time()),
        },
        "levels": [],
    }
    for lvl in levels:
        timing = {key: lvl.timing_stats(key) for key in timing_keys}
        timing["latency_over_timeline_ms"] = lvl.latency_over_metric_stats("timeline_total_ms")
        timing["latency_over_generation_ms"] = lvl.latency_over_metric_stats("generation_ms")
        payload["levels"].append(
            {
                "concurrency": lvl.concurrency,
                "total_requests": lvl.total_requests,
                "success_rate": lvl.success_rate,
                "throughput_rps": lvl.throughput_rps,
                "wall_time_s": lvl.wall_time_s,
                "latency_ms": lvl.latency_stats(),
                "gpu": lvl.gpu_summary,
                "timing": timing,
                "requests": [
                    {
                        "success": r.success,
                        "latency_ms": r.latency_ms,
                        "audio_bytes": r.audio_bytes,
                        "audio_seconds": r.audio_seconds,
                        "speaker": r.speaker,
                        "text_length": r.text_length,
                        "error": r.error,
                        "timing": r.timing,
                    }
                    for r in lvl.results
                ],
            }
        )

    # A missing output directory would otherwise fail only after the full run.
    os.makedirs(os.path.dirname(os.path.abspath(path)), exist_ok=True)
    with open(path, "w", encoding="utf-8") as fd:
        json.dump(payload, fd, indent=2, default=_json_fallback)


async def main_async(args: argparse.Namespace) -> None:
    """Warm up the runtime, run each concurrency level, then save the results.

    Levels run sequentially with a 2 s pause between them. Each level's report
    is printed as it completes, and the full results are written to
    ``args.output`` at the end. The GPU monitor is always stopped, even if a
    level raises.

    Raises:
        ValueError: if ``args.levels`` contains a concurrency below 1. A
            zero-permit semaphore would otherwise hang that level forever.
    """
    levels = [int(x.strip()) for x in args.levels.split(",") if x.strip()]
    if any(c < 1 for c in levels):
        # Validate before the (expensive) warmup request.
        raise ValueError(f"--levels must contain positive integers, got {args.levels!r}")
    text = TEST_TEXTS[args.text]

    gpu_monitor = GPUMonitor(interval_seconds=0.5)
    gpu_monitor.start()
    results: List[LevelResult] = []
    try:
        # Warmup: its result is not recorded, but a failure is worth surfacing
        # because every later level would likely fail the same way.
        warmup = await run_one_request(
            text=text,
            speaker=SPEAKERS[0],
            max_tokens=args.max_tokens,
            timeout_s=args.timeout,
            chunking=args.chunking,
        )
        if not warmup.success:
            print(f"warning: warmup request failed: {warmup.error}", file=sys.stderr)

        for idx, c in enumerate(levels):
            n = max(args.min_requests, c * args.requests_multiplier)
            level = await run_level(
                concurrency=c,
                total_requests=n,
                text=text,
                max_tokens=args.max_tokens,
                timeout_s=args.timeout,
                chunking=args.chunking,
                gpu_monitor=gpu_monitor,
            )
            print_level(level)
            results.append(level)
            if idx < len(levels) - 1:
                # Short pause between levels.
                await asyncio.sleep(2)
    finally:
        gpu_monitor.stop()

    save_results(args.output, results, args)
    print(f"\nresults saved: {args.output}")


def main() -> None:
    """Command-line entry point for the direct-runtime stress benchmark.

    Parses flags, exports the ``VEENA3_PERF_GPU_TIMING`` and
    ``VEENA3_NON_STREAM_FINAL_ONLY`` environment toggles (set before runtime
    initialisation, presumably so the runtime picks them up — confirm), then
    initialises the runtime once and runs :func:`main_async`.

    ``--enable-chunked-prefill`` and ``--disable-chunked-prefill`` are mutually
    exclusive. Passing both is rejected by argparse instead of silently letting
    the disable flag win.
    """
    parser = argparse.ArgumentParser(description="Direct runtime stress benchmark (detailed timings)")
    parser.add_argument("--levels", default="1,2,4,8", help="Comma-separated concurrency levels")
    parser.add_argument("--text", choices=["short", "medium", "long"], default="short")
    parser.add_argument("--max-tokens", type=int, default=128, help="max_tokens sent to generation")
    parser.add_argument("--timeout", type=float, default=120.0, help="Per-request timeout seconds")
    parser.add_argument("--requests-multiplier", type=int, default=4, help="Requests per level ~= concurrency * multiplier")
    parser.add_argument("--min-requests", type=int, default=8)
    parser.add_argument("--chunking", action="store_true", default=False, help="Use generate_speech_chunked")
    parser.add_argument("--gpu-memory", type=float, default=0.25)
    parser.add_argument("--num-engines", type=int, default=1, help="Number of vLLM engines")
    parser.add_argument("--max-num-batched-tokens", type=int, default=None)
    parser.add_argument("--max-num-seqs", type=int, default=None)
    # Contradictory prefill flags are rejected up front.
    prefill_group = parser.add_mutually_exclusive_group()
    prefill_group.add_argument("--enable-chunked-prefill", action="store_true", default=None)
    prefill_group.add_argument("--disable-chunked-prefill", action="store_true", default=False)
    parser.add_argument("--disable-prefix-caching", action="store_true", default=False)
    parser.add_argument("--disable-engine-stats-logs", action="store_true", default=False)
    parser.add_argument("--enforce-eager", action="store_true", default=False)
    parser.add_argument(
        "--non-stream-final-only",
        action="store_true",
        default=False,
        help="Use vLLM FINAL_ONLY output mode (less per-token driver overhead, less granular timeline)",
    )
    parser.add_argument(
        "--enable-gpu-decode-timing",
        action="store_true",
        default=False,
        help="Enable precise BiCodec GPU timing (adds synchronize overhead)",
    )
    parser.add_argument("--model-path", type=str, default="")
    parser.add_argument("--output", default="stress_runtime_detailed.json")
    args = parser.parse_args()

    from veena3modal.local_server import DEFAULT_LOCAL_MODEL_DIR
    model_path = args.model_path or DEFAULT_LOCAL_MODEL_DIR

    # An explicit flag forces "1"; otherwise keep any value already in the
    # environment and default to "0".
    if args.enable_gpu_decode_timing:
        os.environ["VEENA3_PERF_GPU_TIMING"] = "1"
    else:
        os.environ.setdefault("VEENA3_PERF_GPU_TIMING", "0")
    if args.non_stream_final_only:
        os.environ["VEENA3_NON_STREAM_FINAL_ONLY"] = "1"
    else:
        os.environ.setdefault("VEENA3_NON_STREAM_FINAL_ONLY", "0")

    # None means "leave the runtime's default"; the flags only ever force a value.
    enable_chunked_prefill = False if args.disable_chunked_prefill else args.enable_chunked_prefill
    init_runtime(
        model_path=model_path,
        gpu_mem=args.gpu_memory,
        num_engines=args.num_engines,
        max_num_batched_tokens=args.max_num_batched_tokens,
        max_num_seqs=args.max_num_seqs,
        enable_chunked_prefill=enable_chunked_prefill,
        enable_prefix_caching=(False if args.disable_prefix_caching else None),
        disable_log_stats=(True if args.disable_engine_stats_logs else None),
        enforce_eager=(True if args.enforce_eager else None),
    )

    asyncio.run(main_async(args))


if __name__ == "__main__":
    # Run the benchmark only when executed as a script, not when imported.
    main()
