"""
Load Testing for Modal TTS Service.

Tests concurrent request handling and measures:
- Throughput (requests per second)
- Latency distribution (p50, p95, p99)
- Error rates under load
- TTFB under concurrent load

Usage:
    export MODAL_ENDPOINT_URL="https://mayaresearch--veena3-tts-ttsservice-serve.modal.run"
    pytest veena3modal/tests/modal_live/test_load.py -v -s
    
    # Or run standalone
    python veena3modal/tests/modal_live/test_load.py
"""

import os
import time
import asyncio
import statistics
from dataclasses import dataclass, field
from typing import List, Optional
import pytest
import httpx

# Configuration
MODAL_ENDPOINT_URL = os.environ.get("MODAL_ENDPOINT_URL", "https://mayaresearch--veena3-tts-ttsservice-serve.modal.run")
MODAL_API_KEY = os.environ.get("MODAL_API_KEY", "test-key")

# Skip marker for live tests: run only when the endpoint was explicitly
# configured. Check the environment directly — MODAL_ENDPOINT_URL above falls
# back to a hard-coded default, so `not MODAL_ENDPOINT_URL` was always False
# and this marker could never actually skip anything.
skip_if_no_endpoint = pytest.mark.skipif(
    not os.environ.get("MODAL_ENDPOINT_URL"),
    reason="MODAL_ENDPOINT_URL environment variable not set"
)


@dataclass
class RequestResult:
    """Result of a single TTS request."""
    success: bool                     # True iff the HTTP response status was 200
    status_code: int                  # HTTP status; 0 when the request raised before any response
    duration_ms: float                # wall-clock duration of the full request/response cycle
    ttfb_ms: Optional[float] = None   # server-reported time-to-first-byte (X-TTFB-ms header), if present
    audio_bytes: int = 0              # size of the returned audio payload (set on success only)
    error: Optional[str] = None       # truncated error text (response body or exception message) on failure


@dataclass
class LoadTestResult:
    """Aggregated results from a load test."""
    total_requests: int = 0
    successful_requests: int = 0
    failed_requests: int = 0
    total_duration_seconds: float = 0.0
    latencies: List[float] = field(default_factory=list)  # per-request durations (ms), successes only
    ttfbs: List[float] = field(default_factory=list)      # server-reported TTFB values (ms)
    errors: List[str] = field(default_factory=list)       # error messages from failed requests

    @property
    def success_rate(self) -> float:
        """Fraction of requests that succeeded (0.0 when nothing ran)."""
        if self.total_requests == 0:
            return 0.0
        return self.successful_requests / self.total_requests

    @property
    def requests_per_second(self) -> float:
        """Overall throughput; 0.0 when no duration was recorded."""
        if self.total_duration_seconds == 0:
            return 0.0
        return self.total_requests / self.total_duration_seconds

    def _percentile(self, fraction: float) -> float:
        """Nearest-rank percentile of the recorded latencies, in ms."""
        if not self.latencies:
            return 0.0
        ordered = sorted(self.latencies)
        idx = int(len(ordered) * fraction)
        return ordered[min(idx, len(ordered) - 1)]

    @property
    def p50_latency(self) -> float:
        """Median latency in milliseconds."""
        if not self.latencies:
            return 0.0
        return statistics.median(self.latencies)

    @property
    def p95_latency(self) -> float:
        """95th-percentile latency in milliseconds (nearest rank)."""
        return self._percentile(0.95)

    @property
    def p99_latency(self) -> float:
        """99th-percentile latency in milliseconds (nearest rank)."""
        return self._percentile(0.99)

    @property
    def avg_latency(self) -> float:
        """Mean latency in milliseconds (0.0 with no samples)."""
        if not self.latencies:
            return 0.0
        return statistics.mean(self.latencies)

    def report(self) -> str:
        """Generate human-readable report."""
        lines = [
            "=" * 60,
            "LOAD TEST RESULTS",
            "=" * 60,
            f"Total Requests: {self.total_requests}",
            f"Successful: {self.successful_requests} ({self.success_rate:.1%})",
            f"Failed: {self.failed_requests}",
            f"Duration: {self.total_duration_seconds:.2f}s",
            f"Throughput: {self.requests_per_second:.2f} req/s",
            "",
            "LATENCY (ms):",
            f"  Average: {self.avg_latency:.0f}ms",
            f"  p50: {self.p50_latency:.0f}ms",
            f"  p95: {self.p95_latency:.0f}ms",
            f"  p99: {self.p99_latency:.0f}ms",
        ]

        if self.ttfbs:
            avg_ttfb = statistics.mean(self.ttfbs)
            lines.extend([
                "",
                "TTFB (from headers):",
                f"  Average: {avg_ttfb:.0f}ms",
            ])

        if self.errors:
            # De-duplicate while preserving first-seen order. The original
            # `set(self.errors)[:5]` raised TypeError (sets are not
            # subscriptable), and the header counted total rather than
            # unique errors.
            unique_errors = list(dict.fromkeys(self.errors))
            lines.extend([
                "",
                f"ERRORS ({len(unique_errors)} unique):",
            ])
            for error in unique_errors[:5]:  # Show first 5 unique errors
                lines.append(f"  - {error[:80]}")

        lines.append("=" * 60)
        return "\n".join(lines)


async def make_single_request(
    client: httpx.AsyncClient,
    text: str,
    speaker: str = "Aarvi",
    stream: bool = False,
    timeout: float = 120.0,
) -> RequestResult:
    """Make a single TTS request and measure timing.

    Args:
        client: Shared async HTTP client.
        text: Text to synthesize.
        speaker: Speaker voice name.
        stream: Whether to request streaming output.
        timeout: Per-request timeout in seconds.

    Returns:
        RequestResult capturing success, status, timing, and payload size.
        Transport-level failures are returned with status_code=0 rather
        than raised, so load tests keep running.
    """
    url = f"{MODAL_ENDPOINT_URL}/v1/tts/generate"
    payload = {
        "text": text,
        "speaker": speaker,
        "stream": stream,
        "format": "wav",
    }

    # perf_counter is monotonic; time.time() can step (e.g. NTP adjustment)
    # and skew latency measurements.
    start_time = time.perf_counter()

    try:
        response = await client.post(
            url,
            json=payload,
            headers={
                "Content-Type": "application/json",
                "Authorization": f"Bearer {MODAL_API_KEY}",
            },
            timeout=timeout,
        )

        duration_ms = (time.perf_counter() - start_time) * 1000

        # Extract server-reported TTFB from headers if available.
        ttfb_ms = None
        ttfb_header = response.headers.get("X-TTFB-ms")
        if ttfb_header:
            try:
                ttfb_ms = float(ttfb_header)
            except ValueError:
                pass  # Malformed header: record the request without TTFB.

        if response.status_code == 200:
            return RequestResult(
                success=True,
                status_code=response.status_code,
                duration_ms=duration_ms,
                ttfb_ms=ttfb_ms,
                audio_bytes=len(response.content),
            )
        return RequestResult(
            success=False,
            status_code=response.status_code,
            duration_ms=duration_ms,
            ttfb_ms=ttfb_ms,
            error=response.text[:200],
        )

    except Exception as e:
        # Broad catch is deliberate: any transport error (timeout, DNS,
        # connection reset) becomes a failed RequestResult, not a crash.
        duration_ms = (time.perf_counter() - start_time) * 1000
        return RequestResult(
            success=False,
            status_code=0,
            duration_ms=duration_ms,
            error=str(e)[:200],
        )

async def run_load_test(
    num_requests: int,
    concurrency: int,
    text: str = "Hello, this is a load test.",
    speaker: str = "Aarvi",
    stream: bool = False,
) -> LoadTestResult:
    """
    Run load test with specified concurrency.

    Args:
        num_requests: Total number of requests to make
        concurrency: Maximum concurrent requests
        text: Text to synthesize
        speaker: Speaker to use
        stream: Whether to use streaming

    Returns:
        LoadTestResult with aggregated metrics
    """
    result = LoadTestResult()
    # Semaphore caps in-flight requests at `concurrency` even though all
    # tasks are created up front.
    semaphore = asyncio.Semaphore(concurrency)

    async def limited_request(client: httpx.AsyncClient) -> RequestResult:
        async with semaphore:
            return await make_single_request(client, text, speaker, stream)

    # perf_counter is monotonic; time.time() can step under clock adjustment.
    start_time = time.perf_counter()

    async with httpx.AsyncClient() as client:
        # Create all tasks
        tasks = [limited_request(client) for _ in range(num_requests)]

        # return_exceptions=True: one failed task must not cancel the rest.
        results = await asyncio.gather(*tasks, return_exceptions=True)

    result.total_duration_seconds = time.perf_counter() - start_time
    result.total_requests = num_requests

    for r in results:
        if isinstance(r, Exception):
            result.failed_requests += 1
            result.errors.append(str(r)[:200])
        elif isinstance(r, RequestResult):
            if r.success:
                result.successful_requests += 1
                result.latencies.append(r.duration_ms)
                # Explicit None check: a legitimate 0.0 TTFB is falsy and
                # was previously dropped by `if r.ttfb_ms:`.
                if r.ttfb_ms is not None:
                    result.ttfbs.append(r.ttfb_ms)
            else:
                result.failed_requests += 1
                if r.error:
                    result.errors.append(r.error)

    return result


# === Test Classes ===

@skip_if_no_endpoint
class TestConcurrency1:
    """Load tests with 1 concurrent request (baseline)."""

    def test_sequential_5_requests(self):
        """5 sequential requests to establish baseline."""
        outcome = asyncio.run(
            run_load_test(
                num_requests=5,
                concurrency=1,
                text="Sequential baseline test.",
            )
        )

        print(f"\n{outcome.report()}")

        failure_msg = f"Too many failures: {outcome.success_rate:.1%}"
        assert outcome.success_rate >= 0.8, failure_msg

        latency_msg = f"Average latency too high: {outcome.avg_latency:.0f}ms"
        assert outcome.avg_latency < 10000, latency_msg


@skip_if_no_endpoint
class TestConcurrency10:
    """Load tests with 10 concurrent requests."""

    def test_concurrent_10_requests(self):
        """10 concurrent requests."""
        load_result = asyncio.run(
            run_load_test(
                num_requests=10,
                concurrency=10,
                text="Concurrent load test with ten requests.",
            )
        )

        print(f"\n{load_result.report()}")

        assert load_result.success_rate >= 0.7, (
            f"Too many failures: {load_result.success_rate:.1%}"
        )
        # Latency is expected to rise under concurrency, so only bound p95.
        assert load_result.p95_latency < 30000, (
            f"p95 latency too high: {load_result.p95_latency:.0f}ms"
        )


@skip_if_no_endpoint
class TestConcurrency50:
    """Load tests with 50 concurrent requests."""

    @pytest.mark.slow
    def test_concurrent_50_requests(self):
        """50 concurrent requests (stress test)."""
        outcome = asyncio.run(
            run_load_test(
                num_requests=50,
                concurrency=50,
                text="High concurrency stress test.",
            )
        )

        print(f"\n{outcome.report()}")

        # At this level of load some failures are expected; only require
        # that at least half of the requests succeed.
        assert outcome.success_rate >= 0.5, f"Too many failures: {outcome.success_rate:.1%}"


@skip_if_no_endpoint
class TestSustainedLoad:
    """Tests for sustained load over time."""

    @pytest.mark.slow
    def test_sustained_20_requests_5_concurrent(self):
        """20 total requests with 5 concurrent (sustained load)."""
        outcome = asyncio.run(
            run_load_test(
                num_requests=20,
                concurrency=5,
                text="Sustained load test with moderate concurrency.",
            )
        )

        print(f"\n{outcome.report()}")

        failure_msg = f"Too many failures: {outcome.success_rate:.1%}"
        assert outcome.success_rate >= 0.8, failure_msg
        print(f"Throughput: {outcome.requests_per_second:.2f} req/s")


@skip_if_no_endpoint
class TestStreamingLoad:
    """Load tests for streaming mode."""

    def test_streaming_5_concurrent(self):
        """5 concurrent streaming requests."""
        outcome = asyncio.run(
            run_load_test(
                num_requests=5,
                concurrency=5,
                text="Streaming load test.",
                stream=True,
            )
        )

        print(f"\n{outcome.report()}")

        # Streaming may have issues with long text (known bug), so this
        # test only verifies the run completed without crashing.
        assert outcome.total_requests == 5


# === Standalone Runner ===

async def run_full_load_test_suite():
    """Run escalating load tests and print a summary table.

    Performs a single warm-up request first (so cold-start latency does not
    pollute the measured configurations), then runs progressively heavier
    request/concurrency combinations and prints per-config and summary
    metrics to stdout.
    """
    print("=" * 60)
    print("FULL LOAD TEST SUITE")
    print(f"Endpoint: {MODAL_ENDPOINT_URL}")
    print("=" * 60)

    # Warm up
    print("\n🔥 Warming up (1 request)...")
    warmup = await run_load_test(num_requests=1, concurrency=1)
    if warmup.success_rate < 1.0:
        print("⚠️  Warmup failed, service may be cold")
    else:
        print(f"✅ Warmup complete ({warmup.avg_latency:.0f}ms)")

    # Test configurations: (total requests, concurrency, label).
    configs = [
        (5, 1, "Sequential baseline"),
        (10, 5, "Light load"),
        (20, 10, "Medium load"),
        (50, 25, "Heavy load"),
        (100, 50, "Stress test"),
    ]

    all_results = []

    for total, concurrency, name in configs:
        print(f"\n🧪 Running: {name} ({total} requests, {concurrency} concurrent)...")

        result = await run_load_test(
            num_requests=total,
            concurrency=concurrency,
            text="Load test sentence for performance measurement.",
        )

        all_results.append((name, result))

        print(f"   ✅ {result.successful_requests}/{result.total_requests} succeeded")
        print(f"   ⏱️  p50={result.p50_latency:.0f}ms, p95={result.p95_latency:.0f}ms")
        print(f"   📈 {result.requests_per_second:.2f} req/s")

    # Summary
    print("\n" + "=" * 60)
    print("SUMMARY")
    print("=" * 60)
    print(f"{'Test':<25} {'Success':<10} {'p50 (ms)':<10} {'p95 (ms)':<10} {'RPS':<10}")
    print("-" * 60)

    for name, result in all_results:
        # Pad the rendered percentage as a string so the columns line up
        # with the 10-wide header regardless of the value's width; the
        # previous fixed run of spaces broke alignment for e.g. "100%".
        success = f"{result.success_rate:.0%}"
        print(
            f"{name:<25} {success:<10} {result.p50_latency:<10.0f} "
            f"{result.p95_latency:<10.0f} {result.requests_per_second:<10.2f}"
        )

    print("=" * 60)


if __name__ == "__main__":
    # Standalone mode: run the full suite directly instead of via pytest.
    asyncio.run(run_full_load_test_suite())

