#!/usr/bin/env python3
"""
Local Stress Test for Veena3 TTS Server.

Progressive load ramp with GPU monitoring, latency percentiles,
error categorization, and throughput measurement.

Unlike the Modal load tests, this runs locally with:
- No network latency (localhost)
- Direct GPU monitoring (nvidia-smi + torch.cuda)
- Higher concurrency ceiling (single-machine, no rate limits)
- Real streaming TTFB measurement (byte-level timing)

Usage:
    # Full suite (ramps 1 → 5 → 10 → 20 → 50 → 100 concurrent)
    python scripts/stress_test_local.py

    # Quick smoke test (1 → 5 → 10 only)
    python scripts/stress_test_local.py --quick

    # Custom ramp
    python scripts/stress_test_local.py --levels 1,5,10,20

    # Streaming only
    python scripts/stress_test_local.py --stream-only

    # Non-streaming only
    python scripts/stress_test_local.py --no-stream

    # Custom endpoint
    python scripts/stress_test_local.py --url http://localhost:8080
"""

import argparse
import asyncio
import json
import os
import statistics
import subprocess
import sys
import time
import threading
from dataclasses import dataclass, field
from typing import List, Optional, Dict, Any, Tuple

import httpx

# === Configuration ===

# Default server endpoint; override with --url on the command line.
DEFAULT_URL = "http://localhost:8000"
GENERATE_PATH = "/v1/tts/generate"

# Test texts at different lengths to stress different parts of the pipeline:
# short exercises the fast path, medium a typical request, long sustained decode.
TEST_TEXTS = {
    "short": "Hello, this is a quick test.",
    "medium": (
        "The quick brown fox jumps over the lazy dog. "
        "This sentence tests basic text normalization and generation quality "
        "with moderate length input that should complete in reasonable time."
    ),
    "long": (
        "In the heart of a bustling city, where towering skyscrapers cast long shadows "
        "over crowded streets, there lived an old bookkeeper named Margaret. She had spent "
        "forty years cataloging the stories of others, yet never found time to write her own. "
        "One rainy Tuesday afternoon, she discovered a leather-bound journal tucked behind "
        "a shelf of dusty encyclopedias. Its pages were blank, waiting. She picked up her pen "
        "and began to write. The words flowed like water from a broken dam."
    ),
}

# Speaker names cycled round-robin across requests (see run_level).
# NOTE(review): mixed capitalization ("Nandini") is preserved as-is —
# confirm the server treats speaker names case-sensitively.
SPEAKERS = ["lipakshi", "vardan", "reet", "Nandini", "krishna", "anika"]


# === GPU Monitoring ===

@dataclass
class GPUSnapshot:
    """One nvidia-smi sample collected by GPUMonitor."""
    timestamp: float  # Unix time (time.time()) when the sample was taken
    memory_used_mb: float  # memory.used as reported by nvidia-smi
    memory_total_mb: float  # memory.total as reported by nvidia-smi
    gpu_utilization_pct: float  # utilization.gpu, percent
    temperature_c: float  # temperature.gpu, degrees Celsius


class GPUMonitor:
    """Background thread that polls nvidia-smi for GPU stats.

    Collects one GPUSnapshot (memory, utilization, temperature) per
    successful poll every ``interval_seconds`` until stop() is called.
    Polling failures (missing nvidia-smi, no GPU, parse errors) are
    deliberately swallowed so monitoring can never disturb the load test.
    """

    def __init__(self, interval_seconds: float = 1.0):
        self.interval = interval_seconds
        # Appended by the poll thread; run_level() clears this between levels.
        self.snapshots: List["GPUSnapshot"] = []
        # Event instead of a bare bool flag: lets stop() interrupt the sleep
        # immediately rather than waiting out a full polling interval.
        self._stop_event = threading.Event()
        self._thread: Optional[threading.Thread] = None

    def start(self):
        """Begin polling in a daemon thread."""
        self._stop_event.clear()
        self._thread = threading.Thread(target=self._poll_loop, daemon=True)
        self._thread.start()

    def stop(self):
        """Signal the poll loop to exit and wait (bounded) for the thread."""
        self._stop_event.set()
        if self._thread:
            self._thread.join(timeout=5)

    def _poll_loop(self):
        """Poll nvidia-smi until stopped; one GPUSnapshot per successful poll."""
        query = [
            "nvidia-smi",
            "--query-gpu=memory.used,memory.total,utilization.gpu,temperature.gpu",
            "--format=csv,noheader,nounits",
        ]
        while not self._stop_event.is_set():
            try:
                result = subprocess.run(
                    query, capture_output=True, text=True, timeout=5,
                )
                if result.returncode == 0:
                    parts = result.stdout.strip().split(", ")
                    if len(parts) >= 4:
                        self.snapshots.append(GPUSnapshot(
                            timestamp=time.time(),
                            memory_used_mb=float(parts[0]),
                            memory_total_mb=float(parts[1]),
                            gpu_utilization_pct=float(parts[2]),
                            temperature_c=float(parts[3]),
                        ))
            except Exception:
                # Best-effort monitoring: missing tool / transient errors
                # simply produce no data for this tick.
                pass
            # wait() returns early the moment stop() sets the event.
            self._stop_event.wait(self.interval)

    def summary(self) -> Dict[str, Any]:
        """Aggregate min/max/avg over the snapshots collected so far."""
        # Copy first: the poll thread may append while we aggregate.
        snaps = list(self.snapshots)
        if not snaps:
            return {"error": "no GPU data collected"}

        mem_used = [s.memory_used_mb for s in snaps]
        gpu_util = [s.gpu_utilization_pct for s in snaps]
        temps = [s.temperature_c for s in snaps]

        return {
            "samples": len(snaps),
            "memory_mb": {
                "min": min(mem_used),
                "max": max(mem_used),
                "avg": statistics.mean(mem_used),
                "total": snaps[0].memory_total_mb,
            },
            "gpu_util_pct": {
                "min": min(gpu_util),
                "max": max(gpu_util),
                "avg": statistics.mean(gpu_util),
            },
            "temperature_c": {
                "min": min(temps),
                "max": max(temps),
                "avg": statistics.mean(temps),
            },
        }


# === Request Result ===

@dataclass
class RequestResult:
    """Outcome of a single TTS request (success or failure)."""
    success: bool  # True only for an HTTP 200 response
    status_code: int  # 0 when no HTTP response was received (timeout/transport)
    latency_ms: float  # Total wall-clock time
    ttfb_ms: float = 0.0  # Time to first byte (from HTTP response start)
    server_ttfb_ms: float = 0.0  # Server-reported TTFB (X-TTFB-ms header)
    audio_bytes: int = 0  # Size of the response body
    audio_seconds: float = 0.0  # Duration from header or byte-size estimate
    error: Optional[str] = None  # Truncated error description, None on success
    speaker: str = ""  # Speaker used for this request
    text_length: int = 0  # len() of the input text
    stream: bool = False  # Whether this was a streaming request


# === Load Level Result ===

@dataclass
class LevelResult:
    """Aggregated results for one load level (a concurrency/mode/text combo)."""
    level_name: str
    concurrency: int
    total_requests: int
    mode: str  # "non-streaming" or "streaming"
    text_category: str  # "short", "medium", "long", "mixed"
    results: List[RequestResult] = field(default_factory=list)
    wall_time_seconds: float = 0.0
    gpu_summary: Dict[str, Any] = field(default_factory=dict)

    @property
    def successes(self) -> List[RequestResult]:
        """All results that completed successfully."""
        return [res for res in self.results if res.success]

    @property
    def failures(self) -> List[RequestResult]:
        """All results that did not succeed."""
        return [res for res in self.results if not res.success]

    @property
    def success_rate(self) -> float:
        """Fraction of successful requests; 0 when nothing was recorded."""
        if not self.results:
            return 0
        return len(self.successes) / len(self.results)

    @property
    def throughput_rps(self) -> float:
        """Requests per second of wall time; 0 before the level has run."""
        if self.wall_time_seconds > 0:
            return self.total_requests / self.wall_time_seconds
        return 0

    def _percentile(self, values: List[float], pct: float) -> float:
        """Nearest-rank percentile (truncated index, clamped to last element)."""
        if not values:
            return 0.0
        ordered = sorted(values)
        rank = int(len(ordered) * pct / 100)
        return ordered[min(rank, len(ordered) - 1)]

    @property
    def latency_stats(self) -> Dict[str, float]:
        """min/avg/p50/p95/p99/max over successful request latencies."""
        samples = [res.latency_ms for res in self.successes]
        if not samples:
            return {}
        return {
            "min": min(samples),
            "avg": statistics.mean(samples),
            "p50": self._percentile(samples, 50),
            "p95": self._percentile(samples, 95),
            "p99": self._percentile(samples, 99),
            "max": max(samples),
        }

    @property
    def ttfb_stats(self) -> Dict[str, float]:
        """min/avg/p50/p95/max over positive time-to-first-byte samples."""
        samples = [res.ttfb_ms for res in self.successes if res.ttfb_ms > 0]
        if not samples:
            return {}
        return {
            "min": min(samples),
            "avg": statistics.mean(samples),
            "p50": self._percentile(samples, 50),
            "p95": self._percentile(samples, 95),
            "max": max(samples),
        }

    @property
    def audio_duration_total(self) -> float:
        """Total seconds of audio produced by all successful requests."""
        return sum(res.audio_seconds for res in self.successes)

    @property
    def effective_rtf(self) -> float:
        """Wall time / total audio produced. Lower = better."""
        produced = self.audio_duration_total
        if produced > 0:
            return self.wall_time_seconds / produced
        return 0

    def error_summary(self) -> Dict[str, int]:
        """Count failures grouped by (80-char truncated) error message."""
        counts: Dict[str, int] = {}
        for res in self.failures:
            label = (res.error or f"HTTP {res.status_code}")[:80]
            counts[label] = counts.get(label, 0) + 1
        return counts


# === HTTP Client ===

async def make_request(
    client: httpx.AsyncClient,
    base_url: str,
    text: str,
    speaker: str,
    stream: bool = False,
    timeout: float = 120.0,
) -> RequestResult:
    """Make a single TTS request with timing."""
    url = f"{base_url}{GENERATE_PATH}"
    payload = {
        "text": text,
        "speaker": speaker,
        "stream": stream,
        "format": "wav",
    }

    start = time.time()
    ttfb_time = 0.0

    try:
        if stream:
            # Streaming: measure true TTFB (time to first byte of audio)
            async with client.stream(
                "POST", url,
                json=payload,
                headers={"Content-Type": "application/json"},
                timeout=timeout,
            ) as response:
                first_byte = True
                total_bytes = 0
                async for chunk in response.aiter_bytes():
                    if first_byte:
                        ttfb_time = (time.time() - start) * 1000
                        first_byte = False
                    total_bytes += len(chunk)

                latency = (time.time() - start) * 1000

                if response.status_code == 200:
                    # Estimate audio duration from bytes (16kHz, 16-bit, mono, minus WAV header)
                    pcm_bytes = max(0, total_bytes - 44)
                    audio_seconds = pcm_bytes / (16000 * 2)

                    return RequestResult(
                        success=True,
                        status_code=200,
                        latency_ms=latency,
                        ttfb_ms=ttfb_time,
                        audio_bytes=total_bytes,
                        audio_seconds=audio_seconds,
                        speaker=speaker,
                        text_length=len(text),
                        stream=True,
                    )
                else:
                    return RequestResult(
                        success=False,
                        status_code=response.status_code,
                        latency_ms=latency,
                        ttfb_ms=ttfb_time,
                        error=f"HTTP {response.status_code}",
                        speaker=speaker,
                        text_length=len(text),
                        stream=True,
                    )
        else:
            # Non-streaming: standard request
            response = await client.post(
                url,
                json=payload,
                headers={"Content-Type": "application/json"},
                timeout=timeout,
            )
            latency = (time.time() - start) * 1000

            # Server-reported TTFB
            server_ttfb = 0.0
            ttfb_header = response.headers.get("X-TTFB-ms")
            if ttfb_header:
                try:
                    server_ttfb = float(ttfb_header)
                except ValueError:
                    pass

            # Audio duration from header or estimate
            audio_seconds = 0.0
            audio_sec_header = response.headers.get("X-Audio-Seconds")
            if audio_sec_header:
                try:
                    audio_seconds = float(audio_sec_header)
                except ValueError:
                    pass
            elif response.status_code == 200:
                pcm_bytes = max(0, len(response.content) - 44)
                audio_seconds = pcm_bytes / (16000 * 2)

            if response.status_code == 200:
                return RequestResult(
                    success=True,
                    status_code=200,
                    latency_ms=latency,
                    ttfb_ms=latency,  # Non-streaming: TTFB = total latency
                    server_ttfb_ms=server_ttfb,
                    audio_bytes=len(response.content),
                    audio_seconds=audio_seconds,
                    speaker=speaker,
                    text_length=len(text),
                    stream=False,
                )
            else:
                error_text = response.text[:200] if response.text else f"HTTP {response.status_code}"
                return RequestResult(
                    success=False,
                    status_code=response.status_code,
                    latency_ms=latency,
                    error=error_text,
                    speaker=speaker,
                    text_length=len(text),
                    stream=False,
                )

    except httpx.TimeoutException:
        return RequestResult(
            success=False, status_code=0,
            latency_ms=(time.time() - start) * 1000,
            error="TIMEOUT",
            speaker=speaker, text_length=len(text), stream=stream,
        )
    except Exception as e:
        return RequestResult(
            success=False, status_code=0,
            latency_ms=(time.time() - start) * 1000,
            error=str(e)[:200],
            speaker=speaker, text_length=len(text), stream=stream,
        )


# === Load Runner ===

async def run_level(
    base_url: str,
    num_requests: int,
    concurrency: int,
    text_category: str = "short",
    stream: bool = False,
    gpu_monitor: Optional[GPUMonitor] = None,
) -> LevelResult:
    """Execute one load level: fire num_requests with bounded concurrency.

    Texts and speakers rotate round-robin across requests. The returned
    LevelResult carries per-request outcomes, wall time, and (if a monitor
    was supplied) a GPU summary covering just this level.
    """
    mode = "streaming" if stream else "non-streaming"
    outcome = LevelResult(
        level_name=f"{concurrency}c-{text_category}-{mode}",
        concurrency=concurrency,
        total_requests=num_requests,
        mode=mode,
        text_category=text_category,
    )

    # "mixed" cycles through every category; otherwise one fixed text.
    if text_category == "mixed":
        text_pool = list(TEST_TEXTS.values())
    else:
        text_pool = [TEST_TEXTS.get(text_category, TEST_TEXTS["short"])]

    gate = asyncio.Semaphore(concurrency)

    async def one_request(idx: int, client: httpx.AsyncClient) -> RequestResult:
        # The semaphore caps in-flight requests at `concurrency`.
        async with gate:
            return await make_request(
                client,
                base_url,
                text_pool[idx % len(text_pool)],
                SPEAKERS[idx % len(SPEAKERS)],
                stream=stream,
            )

    # Restrict GPU samples to this level only.
    if gpu_monitor:
        gpu_monitor.snapshots.clear()

    started = time.time()

    # Pooled, keep-alive connections keep the load realistic.
    pool_limits = httpx.Limits(
        max_connections=concurrency + 10,
        max_keepalive_connections=concurrency,
    )
    async with httpx.AsyncClient(limits=pool_limits) as client:
        outcomes = await asyncio.gather(
            *(one_request(i, client) for i in range(num_requests)),
            return_exceptions=True,
        )

    outcome.wall_time_seconds = time.time() - started

    for item in outcomes:
        if isinstance(item, RequestResult):
            outcome.results.append(item)
        elif isinstance(item, Exception):
            outcome.results.append(RequestResult(
                success=False, status_code=0, latency_ms=0,
                error=str(item)[:200], stream=stream,
            ))

    if gpu_monitor:
        outcome.gpu_summary = gpu_monitor.summary()

    return outcome


# === Reporting ===

def print_level_report(level: LevelResult):
    """Print a compact, human-readable report for one load level."""
    bar = "=" * 72
    lines = [
        f"\n{bar}",
        f"  {level.level_name.upper()}",
        f"  {level.total_requests} requests, {level.concurrency} concurrent, {level.mode}",
        bar,
        f"  Success:     {len(level.successes)}/{level.total_requests} ({level.success_rate:.0%})",
        f"  Wall time:   {level.wall_time_seconds:.2f}s",
        f"  Throughput:  {level.throughput_rps:.2f} req/s",
        f"  Audio total: {level.audio_duration_total:.1f}s",
        f"  Eff. RTF:    {level.effective_rtf:.3f}",
    ]

    latency = level.latency_stats
    if latency:
        lines.append("\n  Latency (ms):")
        lines.append(
            f"    min={latency['min']:.0f}  avg={latency['avg']:.0f}  "
            f"p50={latency['p50']:.0f}  p95={latency['p95']:.0f}  "
            f"p99={latency['p99']:.0f}  max={latency['max']:.0f}"
        )

    ttfb = level.ttfb_stats
    if ttfb and level.mode == "streaming":
        lines.append("  TTFB (ms):")
        lines.append(
            f"    min={ttfb['min']:.0f}  avg={ttfb['avg']:.0f}  "
            f"p50={ttfb['p50']:.0f}  p95={ttfb['p95']:.0f}  max={ttfb['max']:.0f}"
        )

    gpu = level.gpu_summary
    if gpu and "memory_mb" in gpu:
        mem, util, temp = gpu["memory_mb"], gpu["gpu_util_pct"], gpu["temperature_c"]
        lines.append("\n  GPU:")
        lines.append(f"    Memory: {mem['avg']:.0f}MB avg, {mem['max']:.0f}MB peak / {mem['total']:.0f}MB")
        lines.append(f"    Util:   {util['avg']:.0f}% avg, {util['max']:.0f}% peak")
        lines.append(f"    Temp:   {temp['avg']:.0f}C avg, {temp['max']:.0f}C peak")

    failure_counts = level.error_summary()
    if failure_counts:
        lines.append("\n  Errors:")
        for message, count in sorted(failure_counts.items(), key=lambda kv: -kv[1]):
            lines.append(f"    [{count}x] {message}")

    # One joined print emits byte-identical output to per-line prints.
    print("\n".join(lines))


def print_summary_table(all_levels: List[LevelResult]):
    """Print a one-row-per-level comparison table across all levels."""
    rule = "=" * 100
    print(f"\n{rule}")
    print("  STRESS TEST SUMMARY")
    print(rule)
    print(
        f"{'Level':<28} {'OK%':>5} {'RPS':>7} {'p50':>7} "
        f"{'p95':>7} {'p99':>7} {'RTF':>6} {'GPU%':>5} {'GPUmem':>8}"
    )
    print("-" * 100)

    for entry in all_levels:
        stats = entry.latency_stats
        util_avg = entry.gpu_summary.get("gpu_util_pct", {}).get("avg", 0)
        mem_peak = entry.gpu_summary.get("memory_mb", {}).get("max", 0)

        if stats:
            p50, p95, p99 = (f"{stats.get(k, 0):.0f}" for k in ("p50", "p95", "p99"))
        else:
            p50 = p95 = p99 = "---"

        print(
            f"{entry.level_name:<28} "
            f"{entry.success_rate:>4.0%} "
            f"{entry.throughput_rps:>7.2f} "
            f"{p50:>7} "
            f"{p95:>7} "
            f"{p99:>7} "
            f"{entry.effective_rtf:>6.3f} "
            f"{util_avg:>4.0f}% "
            f"{mem_peak:>6.0f}MB"
        )

    print(rule)


# === Main ===

async def run_stress_test(args):
    """Run the full stress test suite.

    Flow: health check -> warmup -> one run_level() per (mode, concurrency)
    combination -> summary table -> optional JSON dump.

    Args:
        args: Parsed argparse namespace (url, levels, text, stream_only,
              no_stream, output).
    """
    base_url = args.url

    # Verify server is up before generating any load.
    print(f"Checking server at {base_url}...")
    try:
        async with httpx.AsyncClient() as client:
            resp = await client.get(f"{base_url}/v1/tts/health", timeout=10)
            health = resp.json()
            print(f"  Status: {health.get('status')}")
            print(f"  Model:  {health.get('model_version')}")
            print(f"  GPU:    {health.get('gpu_available')}")
            if health.get("status") != "healthy":
                print("WARNING: Server not fully healthy, results may be unreliable")
    except Exception as e:
        print(f"ERROR: Cannot reach server at {base_url}: {e}")
        sys.exit(1)

    # Parse concurrency levels
    levels = [int(x) for x in args.levels.split(",")]

    # Determine test modes; both flags together would select nothing,
    # so fail loudly instead of silently running zero tests.
    test_modes = []
    if not args.stream_only:
        test_modes.append(False)  # non-streaming
    if not args.no_stream:
        test_modes.append(True)  # streaming
    if not test_modes:
        print("ERROR: --stream-only and --no-stream are mutually exclusive")
        sys.exit(1)

    # GPU monitor
    gpu_monitor = GPUMonitor(interval_seconds=0.5)
    gpu_monitor.start()

    all_results: List[LevelResult] = []

    # try/finally ensures the monitor thread is stopped even if a level raises.
    try:
        # Warmup so model/caches are hot before any timed level.
        print(f"\nWarming up (2 requests)...")
        warmup = await run_level(base_url, num_requests=2, concurrency=1, text_category="short", stream=False)
        if warmup.success_rate < 1.0:
            print(f"  WARNING: Warmup had failures ({warmup.success_rate:.0%} success)")
        else:
            lat = warmup.latency_stats
            print(f"  OK ({lat.get('avg', 0):.0f}ms avg)")

        # Run each level
        for stream in test_modes:
            mode_label = "STREAMING" if stream else "NON-STREAMING"
            print(f"\n{'#' * 72}")
            print(f"  {mode_label} TESTS")
            print(f"{'#' * 72}")

            for idx, conc in enumerate(levels):
                # Scale requests: at least 3x concurrency (min 10) for
                # meaningful percentiles, capped at 200 to bound run time.
                num_requests = min(max(conc * 3, 10), 200)

                # (Removed a tautological conditional; args.text is used as-is.)
                text_cat = args.text

                print(f"\n>>> {conc} concurrent, {num_requests} total requests, text={text_cat}, stream={stream}")

                level_result = await run_level(
                    base_url=base_url,
                    num_requests=num_requests,
                    concurrency=conc,
                    text_category=text_cat,
                    stream=stream,
                    gpu_monitor=gpu_monitor,
                )

                print_level_report(level_result)
                all_results.append(level_result)

                # Brief cooldown between levels. Positional check so custom
                # non-monotonic ramps (e.g. 10,5,20) also cool down correctly.
                if idx < len(levels) - 1:
                    print(f"\n  Cooling down 3s...")
                    await asyncio.sleep(3)
    finally:
        gpu_monitor.stop()

    # Summary
    print_summary_table(all_results)

    # Save results to JSON
    if args.output:
        save_results(all_results, args.output)
        print(f"\nResults saved to {args.output}")


def save_results(all_results: List[LevelResult], path: str):
    """Save results as JSON for later analysis."""

    def serialize(level: LevelResult) -> Dict[str, Any]:
        # Flatten one LevelResult into plain JSON-friendly values.
        return {
            "level": level.level_name,
            "concurrency": level.concurrency,
            "total_requests": level.total_requests,
            "mode": level.mode,
            "text_category": level.text_category,
            "success_rate": level.success_rate,
            "throughput_rps": level.throughput_rps,
            "wall_time_s": level.wall_time_seconds,
            "latency_ms": level.latency_stats,
            "ttfb_ms": level.ttfb_stats,
            "effective_rtf": level.effective_rtf,
            "audio_total_seconds": level.audio_duration_total,
            "errors": level.error_summary(),
            "gpu": level.gpu_summary,
        }

    with open(path, "w") as f:
        json.dump([serialize(lv) for lv in all_results], f, indent=2, default=str)


def main():
    """Parse CLI arguments and launch the async stress-test suite."""
    parser = argparse.ArgumentParser(
        description="Veena3 TTS Local Stress Test",
        formatter_class=argparse.RawDescriptionHelpFormatter,
    )
    parser.add_argument("--url", default=DEFAULT_URL, help=f"Server URL (default: {DEFAULT_URL})")
    parser.add_argument("--levels", default="1,5,10,20,50", help="Comma-separated concurrency levels (default: 1,5,10,20,50)")
    parser.add_argument("--quick", action="store_true", help="Quick mode: 1,5,10 only")
    parser.add_argument("--text", default="short", choices=["short", "medium", "long", "mixed"], help="Text category (default: short)")
    # Passing both flags at once would select zero test modes, so have
    # argparse reject the combination up front with a clear error.
    mode_group = parser.add_mutually_exclusive_group()
    mode_group.add_argument("--stream-only", action="store_true", help="Only test streaming")
    mode_group.add_argument("--no-stream", action="store_true", help="Only test non-streaming")
    parser.add_argument("--output", default="stress_test_results.json", help="Output JSON path")

    args = parser.parse_args()

    # --quick overrides any --levels value with a short ramp.
    if args.quick:
        args.levels = "1,5,10"

    asyncio.run(run_stress_test(args))


if __name__ == "__main__":
    main()
