"""
Performance benchmarks for TTS service.

Tests marked with @pytest.mark.slow for CI filtering.
These tests require GPU and model to be loaded.

Run with: pytest veena3modal/tests/performance -v -m slow
"""

import pytest
import time
import asyncio
from typing import List, Tuple


# Skip all benchmarks unless CUDA is available and the TTS runtime is initialized
def _can_run_benchmarks() -> Tuple[bool, str]:
    """Check if benchmarks can run."""
    try:
        import torch
        if not torch.cuda.is_available():
            return False, "CUDA not available"
        
        from veena3modal.services import tts_runtime
        if not tts_runtime.is_initialized():
            return False, "TTS runtime not initialized"
        
        return True, ""
    except Exception as e:
        return False, str(e)


# Probed once at import time; read by the skipif markers and the
# initialized_runtime fixture below.
CAN_RUN, SKIP_REASON = _can_run_benchmarks()


@pytest.fixture(scope="module")
def initialized_runtime():
    """Provide the TTS runtime to benchmarks, skipping when it is unavailable.

    Skips if the import-time probe failed, or if the runtime reports that it
    is not initialized at the moment the fixture is set up. Otherwise returns
    ``tts_runtime.get_runtime()``.
    """
    if not CAN_RUN:
        pytest.skip(f"Benchmarks skipped: {SKIP_REASON}")

    from veena3modal.services import tts_runtime as runtime_module

    # Re-check at fixture time rather than relying only on the import-time probe.
    if runtime_module.is_initialized():
        return runtime_module.get_runtime()

    pytest.skip("TTS runtime not initialized")


@pytest.mark.slow
@pytest.mark.skipif(not CAN_RUN, reason=SKIP_REASON or "Benchmark prerequisites not met")
class TestTTFBBenchmarks:
    """
    Time-to-First-Byte benchmarks.

    Target: < 500ms warm, < 1200ms cold

    The assertions below use looser ceilings than these targets (2000ms for
    the non-streaming call, 1500ms for the first streamed chunk). Elapsed
    time is measured with ``time.perf_counter()``, a monotonic
    high-resolution clock, so system clock adjustments cannot skew results.
    """

    @pytest.mark.asyncio
    async def test_ttfb_short_text_nonstreaming(self, initialized_runtime):
        """TTFB for short text non-streaming.

        The timer stops when the awaited ``generate_speech`` call returns, so
        the reported value is the latency of the whole call.
        """
        from veena3modal.services import tts_runtime

        text = "Hello, this is a test."

        start = time.perf_counter()
        audio_bytes, _metrics = await tts_runtime.generate_speech(
            text=text,
            speaker="lipakshi",
        )
        ttfb = (time.perf_counter() - start) * 1000

        assert audio_bytes is not None
        assert ttfb < 2000, f"TTFB too high: {ttfb:.0f}ms"
        print(f"TTFB (short, non-streaming): {ttfb:.0f}ms")

    @pytest.mark.asyncio
    async def test_ttfb_streaming_first_chunk(self, initialized_runtime):
        """TTFB for streaming (time to first chunk).

        Records the elapsed time when the first item is yielded, then keeps
        consuming the stream until it is exhausted.
        """
        from veena3modal.services import tts_runtime

        text = "Hello, this is a streaming test."

        start = time.perf_counter()
        first_chunk_time = None

        async for _audio_chunk, _metrics in tts_runtime.generate_speech_streaming(
            text=text,
            speaker="lipakshi",
        ):
            if first_chunk_time is None:
                first_chunk_time = (time.perf_counter() - start) * 1000
            # Don't break - let it complete

        # None here means the stream yielded nothing at all.
        assert first_chunk_time is not None
        assert first_chunk_time < 1500, f"Streaming TTFB too high: {first_chunk_time:.0f}ms"
        print(f"TTFB (streaming, first chunk): {first_chunk_time:.0f}ms")


@pytest.mark.slow
@pytest.mark.skipif(not CAN_RUN, reason=SKIP_REASON or "Benchmark prerequisites not met")
class TestRTFBenchmarks:
    """
    Real-Time Factor benchmarks.

    RTF = generation_time / audio_duration
    Target: < 0.5 single request, < 0.8 under load

    The assertions below use looser ceilings than these targets (1.5 for
    short text, 2.0 for long text). Elapsed time is measured with
    ``time.perf_counter()``. When the returned metrics carry no positive
    ``audio_duration_seconds`` the test is skipped explicitly instead of
    passing without checking anything.
    """

    @staticmethod
    def _rtf_or_skip(generation_time, metrics):
        """Return ``(rtf, audio_duration)``, or skip if no duration was reported.

        Args:
            generation_time: Elapsed generation time in seconds.
            metrics: Mapping returned by the runtime; read for the
                ``audio_duration_seconds`` key.
        """
        audio_duration = metrics.get("audio_duration_seconds", 0)
        if not audio_duration > 0:
            # Without a duration the ratio is undefined; make that visible.
            pytest.skip("No positive audio_duration_seconds in metrics; RTF cannot be computed")
        return generation_time / audio_duration, audio_duration

    @pytest.mark.asyncio
    async def test_rtf_short_text(self, initialized_runtime):
        """RTF for short text."""
        from veena3modal.services import tts_runtime

        text = "Hello, this is a test sentence for measuring real-time factor."

        start = time.perf_counter()
        _audio_bytes, metrics = await tts_runtime.generate_speech(
            text=text,
            speaker="lipakshi",
        )
        generation_time = time.perf_counter() - start

        rtf, audio_duration = self._rtf_or_skip(generation_time, metrics)
        assert rtf < 1.5, f"RTF too high: {rtf:.2f}"
        print(f"RTF (short): {rtf:.2f} ({audio_duration:.1f}s audio in {generation_time:.1f}s)")

    @pytest.mark.asyncio
    async def test_rtf_long_text(self, initialized_runtime):
        """RTF for long text (with chunking)."""
        from veena3modal.services import tts_runtime

        # ~500 chars of text
        text = """
        This is a longer piece of text that will test the chunking capabilities
        of our TTS system. The text normalization should handle numbers like 12345
        and dates like January 15, 2024. It should also handle currency like $99.99
        and technical terms like API and HTTP. Let's see how well it performs with
        this amount of text that should trigger the chunking mechanism.
        """

        start = time.perf_counter()
        _audio_bytes, metrics = await tts_runtime.generate_speech_chunked(
            text=text,
            speaker="lipakshi",
        )
        generation_time = time.perf_counter() - start

        rtf, audio_duration = self._rtf_or_skip(generation_time, metrics)
        assert rtf < 2.0, f"RTF too high for long text: {rtf:.2f}"
        print(f"RTF (long, chunked): {rtf:.2f} ({audio_duration:.1f}s audio in {generation_time:.1f}s)")


@pytest.mark.slow
@pytest.mark.skipif(not CAN_RUN, reason=SKIP_REASON or "Benchmark prerequisites not met")
class TestThroughputBenchmarks:
    """
    Throughput benchmarks.

    Tests concurrent request handling.
    """

    @pytest.mark.asyncio
    async def test_concurrent_requests(self, initialized_runtime):
        """Measure throughput with concurrent requests.

        Submits five requests together through ``asyncio.gather`` and prints
        wall time, total audio duration, mean per-request time and the
        effective RTF. This is a reporting benchmark: it makes no pass/fail
        assertion on the measured numbers.
        """
        from veena3modal.services import tts_runtime

        texts = [
            "First test sentence.",
            "Second test sentence.",
            "Third test sentence.",
            "Fourth test sentence.",
            "Fifth test sentence.",
        ]

        async def generate(text):
            """Return ``(elapsed_seconds, audio_duration_seconds)`` for one request."""
            start = time.perf_counter()
            _audio_bytes, metrics = await tts_runtime.generate_speech(
                text=text,
                speaker="lipakshi",
            )
            return time.perf_counter() - start, metrics.get("audio_duration_seconds", 0)

        start = time.perf_counter()
        results = await asyncio.gather(*[generate(t) for t in texts])
        total_time = time.perf_counter() - start

        gen_times = [r[0] for r in results]
        audio_durations = [r[1] for r in results]

        total_audio = sum(audio_durations)

        print(f"\nConcurrent throughput ({len(texts)} requests):")
        print(f"  Total wall time: {total_time:.1f}s")
        print(f"  Total audio generated: {total_audio:.1f}s")
        print(f"  Avg per-request time: {sum(gen_times)/len(gen_times):.1f}s")
        if total_audio > 0:
            print(f"  Effective RTF: {total_time/total_audio:.2f}")
        else:
            # Say why the figure is missing instead of printing a blank line.
            print("  Effective RTF: n/a (no audio duration reported)")


@pytest.mark.slow
@pytest.mark.skipif(not CAN_RUN, reason=SKIP_REASON or "Benchmark prerequisites not met")
class TestMemoryBenchmarks:
    """
    GPU memory usage benchmarks.

    Both tests print the allocated and reserved figures reported by
    ``torch.cuda`` and assert an absolute ceiling on the reserved amount.
    """

    def test_gpu_memory_baseline(self, initialized_runtime):
        """Report GPU memory before this test issues any request; bound it at 40 GB."""
        import torch

        if not torch.cuda.is_available():
            pytest.skip("CUDA not available")

        bytes_per_gb = 1024 ** 3
        allocated_gb = torch.cuda.memory_allocated() / bytes_per_gb
        reserved_gb = torch.cuda.memory_reserved() / bytes_per_gb

        print("\nGPU Memory (baseline):")
        print(f"  Allocated: {allocated_gb:.2f} GB")
        print(f"  Reserved: {reserved_gb:.2f} GB")

        # Absolute ceiling considered reasonable for inference.
        assert reserved_gb < 40, f"GPU memory too high: {reserved_gb:.2f} GB"

    @pytest.mark.asyncio
    async def test_gpu_memory_after_generation(self, initialized_runtime):
        """Report GPU memory after three sequential generations; bound it at 50 GB."""
        import torch
        from veena3modal.services import tts_runtime

        if not torch.cuda.is_available():
            pytest.skip("CUDA not available")

        # Run a few requests one after another before sampling memory.
        for sentence_no in range(1, 4):
            await tts_runtime.generate_speech(
                text=f"Test sentence number {sentence_no}.",
                speaker="lipakshi",
            )

        bytes_per_gb = 1024 ** 3
        allocated_gb = torch.cuda.memory_allocated() / bytes_per_gb
        reserved_gb = torch.cuda.memory_reserved() / bytes_per_gb

        print("\nGPU Memory (after 3 generations):")
        print(f"  Allocated: {allocated_gb:.2f} GB")
        print(f"  Reserved: {reserved_gb:.2f} GB")

        # Coarse leak guard: an absolute bound, not a comparison with the baseline.
        assert reserved_gb < 50, f"GPU memory too high after generation: {reserved_gb:.2f} GB"


@pytest.mark.slow
@pytest.mark.skipif(not CAN_RUN, reason=SKIP_REASON or "Benchmark prerequisites not met")
class TestCacheBenchmarks:
    """
    Cache performance benchmarks.

    Each operation is timed individually with ``time.perf_counter()``; the
    wall clock from ``time.time()`` can be too coarse to resolve
    sub-millisecond calls on some platforms. Key strings are built outside
    the timed region so only the call under test is measured.
    """

    def test_rate_limiter_check_performance(self):
        """Rate limiter check should be < 1ms.

        Performs 100 checks cycling over 10 distinct keys and asserts on the
        mean per-call time; the maximum is printed for information only.
        """
        from veena3modal.api.rate_limiter import RateLimiter

        limiter = RateLimiter(requests_per_minute=1000)

        times = []
        for i in range(100):
            key = f"key_{i % 10}"
            start = time.perf_counter()
            limiter.check(key)
            times.append((time.perf_counter() - start) * 1000)

        avg_time = sum(times) / len(times)
        max_time = max(times)

        print("\nRate limiter check performance:")
        print(f"  Avg: {avg_time:.3f}ms")
        print(f"  Max: {max_time:.3f}ms")

        assert avg_time < 1, f"Rate limiter too slow: {avg_time:.3f}ms avg"

    def test_api_key_cache_performance(self):
        """API key cache lookup should be < 1ms.

        Stores 100 entries, then times one ``get`` per stored key and asserts
        on the mean per-call time; the maximum is printed for information only.
        """
        from veena3modal.api.auth import ApiKeyCache

        cache = ApiKeyCache()

        # Populate cache
        for i in range(100):
            cache.set(f"key_{i}", {"user_id": f"user-{i}", "credits": 100})

        times = []
        for i in range(100):
            key = f"key_{i}"
            start = time.perf_counter()
            cache.get(key)
            times.append((time.perf_counter() - start) * 1000)

        avg_time = sum(times) / len(times)
        max_time = max(times)

        print("\nAPI key cache lookup performance:")
        print(f"  Avg: {avg_time:.3f}ms")
        print(f"  Max: {max_time:.3f}ms")

        assert avg_time < 1, f"Cache lookup too slow: {avg_time:.3f}ms avg"

