#!/usr/bin/env python3
"""
Production Stress Test - Real-world sentences for Veena3 TTS.

Uses actual production-quality sentences across:
- Multiple languages (English, Hindi, Telugu)
- Various lengths (short greetings to long paragraphs)
- Emotion tags
- Multiple scripts (Latin, Devanagari, Telugu)
- Numbers and dates (spelled out as words) and punctuation

Runs the same concurrency ramp as the baseline test for direct comparison.
"""

import argparse
import asyncio
import json
import os
import statistics
import subprocess
import sys
import time
import threading
from dataclasses import dataclass, field
from typing import List, Optional, Dict, Any

import httpx

# Base URL used when --url is not given on the command line.
DEFAULT_URL = "http://localhost:8000"
# Endpoint path appended to the base URL for every TTS request.
GENERATE_PATH = "/v1/tts/generate"

# === Real production sentences ===
# These mimic actual API traffic: mixed languages, emotions, varied lengths.
#
# Keys are category labels: a language prefix (en / hi / te) plus a length
# bucket (short / medium / long), or "emotion_mixed". The label of the
# sentence sent is recorded on each request's result (RequestResult.category).
# Numbers and dates are written out as words (e.g. "thirty two degrees"),
# so no digits are sent. Some entries begin with a bracketed tag such as
# "[excited]" or "[whispers]"; presumably the server treats these as
# emotion/style cues -- TODO confirm against the server's tag handling.

REAL_SENTENCES = {
    "en_short": [
        "Welcome to Veena, your personal voice assistant.",
        "The weather today is sunny with a high of thirty two degrees.",
        "Your order has been confirmed and will arrive by tomorrow.",
        "Please hold on, I am transferring you to a specialist.",
        "Thank you for calling. Have a wonderful day!",
    ],
    "en_medium": [
        "Good morning! I hope you are having a great start to your day. Let me walk you through the key highlights from yesterday's meeting.",
        "The quarterly revenue report shows a fifteen percent increase compared to last year, driven primarily by growth in the South Asian market.",
        "I would like to remind you that your subscription is expiring on March fifteenth. Would you like me to help you with the renewal process?",
        "According to the latest research, artificial intelligence is expected to contribute over fifteen trillion dollars to the global economy by twenty thirty.",
        "[excited] Congratulations! You have been selected for our premium membership program. Let me tell you about the amazing benefits.",
    ],
    "en_long": [
        "In the heart of Hyderabad, where the old city meets the new, there is a small bookshop that has been running for over fifty years. The owner, Ramesh uncle, knows every book by heart. He can tell you exactly which shelf to find your favorite author on, and he always has a recommendation ready. Last Tuesday, a young girl walked in looking for a book on astronomy. Ramesh uncle smiled and led her to the science section.",
        "The Indian space program has achieved remarkable milestones in the past decade. From the Mars Orbiter Mission in twenty fourteen to the Chandrayaan three lunar landing in twenty twenty three, ISRO has consistently demonstrated that world class space exploration does not require a massive budget. The upcoming Gaganyaan mission will mark India's first crewed spaceflight.",
    ],
    "hi_short": [
        "नमस्ते! आज का दिन बहुत अच्छा है।",
        "कृपया अपना नाम और फोन नंबर बताइए।",
        "आपका ऑर्डर सफलतापूर्वक प्लेस हो गया है।",
        "धन्यवाद! आपकी कॉल हमारे लिए महत्वपूर्ण है।",
    ],
    "hi_medium": [
        "[curious] क्या आप जानते हैं कि भारत दुनिया का सबसे बड़ा लोकतंत्र है? यहाँ एक सौ चालीस करोड़ से अधिक लोग रहते हैं।",
        "आज के मौसम का हाल बताते हैं। दिल्ली में तापमान बत्तीस डिग्री रहेगा और हल्की बारिश की संभावना है।",
    ],
    "te_short": [
        "నమస్కారం! మీరు ఎలా ఉన్నారు?",
        # NOTE(review): this Telugu sentence ends with the Devanagari danda
        # (U+0964) instead of "." like its neighbours -- confirm intentional.
        "మీ ఆర్డర్ విజయవంతంగా ప్లేస్ చేయబడింది।",
        "దయచేసి కొంచెం ఆగండి, మేము మీకు సహాయం చేస్తాము.",
    ],
    "emotion_mixed": [
        "[laughs] Oh that is so funny! I cannot believe you said that.",
        "[whispers] Let me tell you a secret. The new product launch is happening next week.",
        "[excited] Amazing news everyone! We just crossed one million users!",
        "[sighs] Unfortunately, the flight has been delayed by two hours. I apologize for the inconvenience.",
        "[angry] This is unacceptable! The service was supposed to be restored by noon today.",
    ],
}

# One flat list of tagged records, in category order, so request ``i`` can
# pick ``ALL_SENTENCES[i % len(ALL_SENTENCES)]`` round-robin.
ALL_SENTENCES = [
    {"text": sentence, "category": category, "length": len(sentence)}
    for category, sentences in REAL_SENTENCES.items()
    for sentence in sentences
]

# Speaker ids sent verbatim in the request payload, also cycled round-robin.
SPEAKERS = [
    "lipakshi", "vardan", "reet", "Nandini",
    "krishna", "anika", "adarsh", "Nilay",
    "Aarvi", "Asha", "Bittu", "Mira",
]


# === GPU Monitor (reused from stress_test_local.py) ===

@dataclass
class GPUSnapshot:
    """A single GPU reading recorded by ``GPUMonitor``."""

    timestamp: float  # time.time() when the reading was parsed
    memory_used_mb: float  # nvidia-smi memory.used (nounits)
    memory_total_mb: float  # nvidia-smi memory.total (nounits)
    gpu_utilization_pct: float  # nvidia-smi utilization.gpu, percent
    temperature_c: float  # nvidia-smi temperature.gpu, degrees Celsius


class GPUMonitor:
    """Sample GPU memory, utilization and temperature on a background thread.

    ``nvidia-smi`` is queried every ``interval_seconds``. Monitoring is
    best-effort: failed or unparseable samples are skipped, and if
    ``nvidia-smi`` is not installed the poller stops after its first attempt,
    so ``summary()`` returns ``{}``. On multi-GPU hosts only the first GPU
    listed by ``nvidia-smi`` is recorded.
    """

    # nvidia-smi prints one CSV line per GPU with these four fields.
    _NVIDIA_SMI_CMD = (
        "nvidia-smi",
        "--query-gpu=memory.used,memory.total,utilization.gpu,temperature.gpu",
        "--format=csv,noheader,nounits",
    )

    def __init__(self, interval_seconds: float = 1.0):
        self.interval = interval_seconds
        self.snapshots: List[GPUSnapshot] = []
        self._running = False
        self._thread = None
        # Set by stop() so the poll loop wakes immediately instead of
        # finishing its current sleep interval.
        self._stop_event = threading.Event()

    def start(self):
        """Start polling on a daemon thread."""
        self._stop_event.clear()
        self._running = True
        self._thread = threading.Thread(target=self._poll, daemon=True)
        self._thread.start()

    def stop(self):
        """Stop polling and wait (up to 5 s) for the thread to exit."""
        self._running = False
        self._stop_event.set()
        if self._thread:
            self._thread.join(timeout=5)

    @staticmethod
    def _parse_nvidia_smi(stdout: str) -> Optional[GPUSnapshot]:
        """Parse the first GPU line of nvidia-smi CSV output.

        Returns ``None`` when the output is empty, has fewer than four
        fields, or contains non-numeric values such as ``[N/A]``.
        """
        lines = stdout.strip().splitlines()
        if not lines:
            return None
        parts = [p.strip() for p in lines[0].split(",")]
        if len(parts) < 4:
            return None
        try:
            used, total, util, temp = (float(p) for p in parts[:4])
        except ValueError:
            return None
        return GPUSnapshot(
            timestamp=time.time(),
            memory_used_mb=used,
            memory_total_mb=total,
            gpu_utilization_pct=util,
            temperature_c=temp,
        )

    def _poll(self):
        """Thread body: query nvidia-smi until stop() is called."""
        while self._running:
            try:
                proc = subprocess.run(self._NVIDIA_SMI_CMD, capture_output=True,
                                      text=True, timeout=5)
            except FileNotFoundError:
                # nvidia-smi is not installed; retrying can never succeed.
                self._running = False
                return
            except (subprocess.SubprocessError, OSError):
                proc = None  # timeout or transient launch failure: skip sample
            if proc is not None and proc.returncode == 0:
                snap = self._parse_nvidia_smi(proc.stdout)
                if snap is not None:
                    self.snapshots.append(snap)
            self._stop_event.wait(self.interval)

    def summary(self):
        """Aggregate recorded snapshots; returns ``{}`` when none exist."""
        # Copy first: the poll thread may still be appending.
        snaps = list(self.snapshots)
        if not snaps:
            return {}
        mem = [s.memory_used_mb for s in snaps]
        util = [s.gpu_utilization_pct for s in snaps]
        temps = [s.temperature_c for s in snaps]
        return {
            "memory_mb_avg": statistics.mean(mem),
            "memory_mb_peak": max(mem),
            "memory_mb_total": snaps[0].memory_total_mb,
            "gpu_util_avg": statistics.mean(util),
            "gpu_util_peak": max(util),
            "temp_avg": statistics.mean(temps),
            "temp_peak": max(temps),
        }


# === Request Result ===

@dataclass
class RequestResult:
    """Outcome and timing of a single TTS request, as seen by the client."""

    success: bool  # True only when the server answered HTTP 200
    status_code: int  # HTTP status; 0 when no response was obtained (timeout / exception)
    latency_ms: float  # elapsed time for the whole request, in milliseconds
    ttfb_ms: float = 0.0  # time to first body chunk when streaming; equals latency_ms for non-streaming successes
    audio_bytes: int = 0  # size of the response body in bytes
    audio_seconds: float = 0.0  # estimated duration of the returned audio, in seconds
    error: Optional[str] = None  # short error description for failed requests
    text_length: int = 0  # len() of the input text, in characters
    category: str = ""  # corpus category label of the input sentence (set by the caller)
    stream: bool = False  # whether a streamed response was requested


# === HTTP Client ===

def wav_audio_seconds(head, total_bytes):
    """Estimate the audio duration of a WAV payload from its size and header.

    Args:
        head: The first bytes of the payload (44 is enough; extra is ignored).
        total_bytes: Total payload size in bytes, header included.

    Returns:
        ``max(0, total_bytes - 44) / byte_rate`` seconds. ``byte_rate`` is read
        from a canonical RIFF/WAVE header (``fmt `` chunk first, byte-rate
        field at offset 28). When *head* is not such a header, 16 kHz mono
        16-bit PCM (32000 bytes/s) is assumed. The fixed 44-byte header size
        is an approximation for files carrying extra chunks.
    """
    byte_rate = 16000 * 2
    if (len(head) >= 32 and head[0:4] == b"RIFF" and head[8:12] == b"WAVE"
            and head[12:16] == b"fmt "):
        declared = int.from_bytes(head[28:32], "little")
        if declared > 0:
            byte_rate = declared
    return max(0, total_bytes - 44) / byte_rate


async def make_request(client, base_url, text, speaker, stream=False, timeout=120.0):
    """Send one TTS generate request and measure it.

    Args:
        client: An ``httpx.AsyncClient`` (shared across requests).
        base_url: Server base URL; ``GENERATE_PATH`` is appended to it.
        text: Text to synthesize.
        speaker: Speaker id, sent verbatim.
        stream: Request a streamed response and record time to first chunk.
        timeout: Per-request timeout in seconds.

    Returns:
        A ``RequestResult``. Exceptions raised while sending or reading are
        converted into failed results with ``status_code=0``: timeouts get
        ``error="TIMEOUT"``, anything else ``"<ExceptionType>: <message>"``.
    """
    url = f"{base_url}{GENERATE_PATH}"
    payload = {"text": text, "speaker": speaker, "stream": stream, "format": "wav"}
    headers = {"Content-Type": "application/json"}
    # perf_counter is monotonic, so wall-clock adjustments can't distort timings.
    start = time.perf_counter()

    def elapsed_ms():
        return (time.perf_counter() - start) * 1000

    try:
        if stream:
            async with client.stream("POST", url, json=payload, headers=headers,
                                     timeout=timeout) as resp:
                first_byte = True
                ttfb = 0.0
                total = 0
                head = bytearray()  # first 44 bytes: enough for a canonical WAV header
                async for chunk in resp.aiter_bytes():
                    if first_byte:
                        ttfb = elapsed_ms()
                        first_byte = False
                    if len(head) < 44:
                        head += chunk[:44 - len(head)]
                    total += len(chunk)
                latency = elapsed_ms()
                if resp.status_code == 200:
                    return RequestResult(True, 200, latency, ttfb, total,
                                         wav_audio_seconds(bytes(head), total),
                                         text_length=len(text), stream=True)
                return RequestResult(False, resp.status_code, latency, ttfb,
                                     error=f"HTTP {resp.status_code}",
                                     text_length=len(text), stream=True)

        resp = await client.post(url, json=payload, headers=headers, timeout=timeout)
        latency = elapsed_ms()
        if resp.status_code != 200:
            return RequestResult(False, resp.status_code, latency,
                                 error=resp.text[:200], text_length=len(text), stream=False)
        content = resp.content
        audio_sec = None
        header_value = resp.headers.get("X-Audio-Seconds")
        if header_value:
            try:
                audio_sec = float(header_value)
            except ValueError:
                pass  # malformed header: fall back to the size-based estimate
        if audio_sec is None:
            audio_sec = wav_audio_seconds(content[:44], len(content))
        # Non-streaming: the whole body arrives at once, so TTFB == latency.
        return RequestResult(True, 200, latency, latency, len(content), audio_sec,
                             text_length=len(text), stream=False)
    except httpx.TimeoutException:
        return RequestResult(False, 0, elapsed_ms(), error="TIMEOUT",
                             text_length=len(text), stream=stream)
    except Exception as e:
        # str(e) can be empty, so keep the type name for a useful error label.
        return RequestResult(False, 0, elapsed_ms(), error=f"{type(e).__name__}: {e}"[:200],
                             text_length=len(text), stream=stream)


def percentile(values, pct):
    """Return a percentile of *values* by direct rank lookup (no interpolation).

    The element at index ``int(len(values) * pct / 100)`` of the sorted data
    is returned, clamped to the last element, so ``pct=100`` gives the
    maximum. This index sits one above textbook nearest-rank when
    ``len * pct / 100`` is a whole number (p95 of 1..100 gives 96).
    Empty input yields ``0``.
    """
    if not values:
        return 0
    ordered = sorted(values)
    last = len(ordered) - 1
    rank = int(len(ordered) * pct / 100)
    return ordered[rank if rank < last else last]


async def run_level(base_url, num_requests, concurrency, stream, gpu_monitor=None):
    """Run one load level, print a report, and return its metrics.

    Request ``i`` sends ``ALL_SENTENCES[i % len(ALL_SENTENCES)]`` with speaker
    ``SPEAKERS[i % len(SPEAKERS)]``. At most ``concurrency`` requests are in
    flight at once, all sharing one ``httpx.AsyncClient``.

    Args:
        base_url: Server base URL.
        num_requests: Total number of requests to send.
        concurrency: Maximum number of in-flight requests; must be >= 1.
        stream: Request streamed responses when true.
        gpu_monitor: Optional running ``GPUMonitor``. Its snapshots are
            cleared first, so the GPU summary covers (roughly) this level.

    Returns:
        A dict with throughput, latency/TTFB percentiles, effective RTF
        (wall time / total audio seconds), GPU summary and an error histogram.

    Raises:
        ValueError: if ``concurrency`` < 1 (a zero-permit semaphore would
            block every request forever).
    """
    if concurrency < 1:
        raise ValueError(f"concurrency must be >= 1, got {concurrency}")

    sem = asyncio.Semaphore(concurrency)

    async def req(i, client):
        async with sem:
            item = ALL_SENTENCES[i % len(ALL_SENTENCES)]
            speaker = SPEAKERS[i % len(SPEAKERS)]
            r = await make_request(client, base_url, item["text"], speaker, stream)
            r.category = item["category"]
            return r

    if gpu_monitor:
        gpu_monitor.snapshots.clear()

    # perf_counter is monotonic, so wall time can't be skewed by clock changes.
    start = time.perf_counter()
    limits = httpx.Limits(max_connections=concurrency + 10, max_keepalive_connections=concurrency)
    async with httpx.AsyncClient(limits=limits) as client:
        tasks = [req(i, client) for i in range(num_requests)]
        raw = await asyncio.gather(*tasks, return_exceptions=True)
    wall = time.perf_counter() - start

    results = []
    for r in raw:
        # gather(return_exceptions=True) can also return BaseException
        # subclasses such as CancelledError, not only Exception.
        if isinstance(r, BaseException):
            results.append(RequestResult(False, 0, 0, error=f"{type(r).__name__}: {r}"[:200],
                                         stream=stream))
        else:
            results.append(r)

    successes = [r for r in results if r.success]
    failures = [r for r in results if not r.success]
    lats = [r.latency_ms for r in successes]
    ttfbs = [r.ttfb_ms for r in successes if r.ttfb_ms > 0]
    audio_total = sum(r.audio_seconds for r in successes)
    rps = num_requests / wall if wall > 0 else 0
    rtf = wall / audio_total if audio_total > 0 else 0
    # Guarded once and reused by both the report and the returned dict.
    success_rate = len(successes) / num_requests if num_requests else 0

    gpu = gpu_monitor.summary() if gpu_monitor else {}
    mode = "streaming" if stream else "non-streaming"
    label = f"{concurrency}c-production-{mode}"

    print(f"\n{'=' * 72}")
    print(f"  {label.upper()}")
    print(f"  {num_requests} requests, {concurrency} concurrent, {mode}")
    print(f"  Sentences: real production ({len(ALL_SENTENCES)} unique, multilingual)")
    print(f"{'=' * 72}")
    print(f"  Success:     {len(successes)}/{num_requests} ({success_rate:.0%})")
    print(f"  Failures:    {len(failures)}")
    print(f"  Wall time:   {wall:.2f}s")
    print(f"  Throughput:  {rps:.2f} req/s")
    print(f"  Audio total: {audio_total:.1f}s")
    print(f"  Eff. RTF:    {rtf:.3f}")

    if lats:
        print(f"\n  Latency (ms):")
        print(f"    min={min(lats):.0f}  avg={statistics.mean(lats):.0f}  "
              f"p50={percentile(lats,50):.0f}  p95={percentile(lats,95):.0f}  "
              f"p99={percentile(lats,99):.0f}  max={max(lats):.0f}")

    if ttfbs and stream:
        print(f"  TTFB (ms):")
        print(f"    min={min(ttfbs):.0f}  avg={statistics.mean(ttfbs):.0f}  "
              f"p50={percentile(ttfbs,50):.0f}  p95={percentile(ttfbs,95):.0f}  max={max(ttfbs):.0f}")

    if gpu:
        print(f"\n  GPU:")
        print(f"    Memory: {gpu['memory_mb_avg']:.0f}MB avg, {gpu['memory_mb_peak']:.0f}MB peak / {gpu['memory_mb_total']:.0f}MB")
        print(f"    Util:   {gpu['gpu_util_avg']:.0f}% avg, {gpu['gpu_util_peak']:.0f}% peak")
        print(f"    Temp:   {gpu['temp_avg']:.0f}C avg, {gpu['temp_peak']:.0f}C peak")

    # Histogram of failure reasons, truncated to 60 chars so similar errors group.
    errors = {}
    for r in failures:
        k = (r.error or f"HTTP {r.status_code}")[:60]
        errors[k] = errors.get(k, 0) + 1
    if errors:
        print(f"\n  Errors:")
        for e, c in sorted(errors.items(), key=lambda x: -x[1]):
            print(f"    [{c}x] {e}")

    return {
        "level": label, "concurrency": concurrency, "requests": num_requests,
        "mode": mode, "success_rate": success_rate,
        "throughput_rps": rps, "wall_time_s": wall, "eff_rtf": rtf,
        "audio_total_s": audio_total,
        "latency_ms": {"avg": statistics.mean(lats), "p50": percentile(lats, 50),
                       "p95": percentile(lats, 95), "p99": percentile(lats, 99),
                       "max": max(lats)} if lats else {},
        "ttfb_ms": {"avg": statistics.mean(ttfbs), "p50": percentile(ttfbs, 50),
                    "p95": percentile(ttfbs, 95), "max": max(ttfbs)} if ttfbs else {},
        "gpu": gpu, "errors": errors,
    }


async def main_async(args):
    """Run the full production stress test described by parsed CLI *args*.

    Validates ``args.levels``, checks the server's health endpoint, sends a
    3-request warmup, then runs every concurrency level twice (non-streaming,
    then streaming) with ``min(max(3 * level, 10), 200)`` requests each.
    Prints a summary table and writes all level results to ``args.output``
    as JSON. Exits with status 1 on invalid levels or a failed health check.
    """
    base_url = args.url

    # Validate before touching the server: a level of 0 would create a
    # zero-permit semaphore and hang forever.
    try:
        levels = [int(x) for x in args.levels.split(",") if x.strip()]
    except ValueError:
        print(f"ERROR: --levels must be comma-separated integers, got {args.levels!r}")
        sys.exit(1)
    if not levels or min(levels) < 1:
        print(f"ERROR: --levels needs at least one value and every value must be >= 1, got {args.levels!r}")
        sys.exit(1)

    print(f"Checking server at {base_url}...")
    try:
        async with httpx.AsyncClient() as c:
            r = await c.get(f"{base_url}/v1/tts/health", timeout=10)
            h = r.json()
            print(f"  Status: {h.get('status')}, Model: {h.get('model_version')}, GPU: {h.get('gpu_available')}")
    except Exception as e:
        print(f"ERROR: {e}")
        sys.exit(1)

    gpu_monitor = GPUMonitor(interval_seconds=0.5)
    gpu_monitor.start()

    all_results = []
    try:
        # Warmup with 3 real sentences
        print(f"\nWarming up (3 requests with real sentences)...")
        await run_level(base_url, 3, 1, stream=False, gpu_monitor=None)

        for stream in [False, True]:
            mode = "STREAMING" if stream else "NON-STREAMING"
            print(f"\n{'#' * 72}")
            print(f"  {mode} - PRODUCTION SENTENCES")
            print(f"{'#' * 72}")

            for idx, conc in enumerate(levels):
                n = min(max(conc * 3, 10), 200)
                print(f"\n>>> {conc} concurrent, {n} total, real sentences, stream={stream}")
                result = await run_level(base_url, n, conc, stream, gpu_monitor)
                all_results.append(result)

                # Cool down between levels but not after the last one of a pass.
                # Index-based, so unsorted or repeated --levels still pause.
                if idx < len(levels) - 1:
                    print(f"\n  Cooling down 3s...")
                    await asyncio.sleep(3)
    finally:
        gpu_monitor.stop()

    # Summary table
    print(f"\n{'=' * 105}")
    print(f"  PRODUCTION STRESS TEST SUMMARY (Tier 1 + Tier 2 Optimizations)")
    print(f"{'=' * 105}")
    header = f"{'Level':<32} {'OK%':>5} {'RPS':>7} {'p50':>7} {'p95':>7} {'p99':>7} {'RTF':>6} {'GPU%':>5} {'Mem':>7}"
    print(header)
    print("-" * 105)
    for r in all_results:
        lat = r.get("latency_ms", {})
        gpu = r.get("gpu", {})
        print(f"{r['level']:<32} "
              f"{r['success_rate']:>4.0%} "
              f"{r['throughput_rps']:>7.2f} "
              f"{lat.get('p50',0):>7.0f} "
              f"{lat.get('p95',0):>7.0f} "
              f"{lat.get('p99',0):>7.0f} "
              f"{r['eff_rtf']:>6.3f} "
              f"{gpu.get('gpu_util_avg',0):>4.0f}% "
              f"{gpu.get('memory_mb_peak',0):>6.0f}M")
    print("=" * 105)

    # Streaming TTFB comparison
    stream_results = [r for r in all_results if r["mode"] == "streaming"]
    if stream_results:
        print(f"\n  STREAMING TTFB DETAIL:")
        print(f"  {'Level':<32} {'TTFB avg':>10} {'TTFB p50':>10} {'TTFB p95':>10} {'TTFB max':>10}")
        print(f"  {'-'*75}")
        for r in stream_results:
            t = r.get("ttfb_ms", {})
            print(f"  {r['level']:<32} "
                  f"{t.get('avg',0):>9.0f}ms "
                  f"{t.get('p50',0):>9.0f}ms "
                  f"{t.get('p95',0):>9.0f}ms "
                  f"{t.get('max',0):>9.0f}ms")

    # Save results
    output = args.output
    with open(output, "w", encoding="utf-8") as f:
        json.dump(all_results, f, indent=2, default=str)
    print(f"\nResults saved to {output}")


def main():
    """Command-line entry point: parse options and run the stress test."""
    parser = argparse.ArgumentParser(description="Veena3 Production Stress Test")
    parser.add_argument("--url", default=DEFAULT_URL,
                        help="TTS server base URL (default: %(default)s)")
    parser.add_argument("--levels", default="1,5,10,20,50",
                        help="comma-separated concurrency levels, each >= 1 "
                             "(default: %(default)s)")
    parser.add_argument("--output", default="stress_test_production.json",
                        help="path of the JSON results file (default: %(default)s)")
    args = parser.parse_args()
    asyncio.run(main_async(args))


if __name__ == "__main__":
    main()
