"""
Edge case tests for concurrent access and thread safety.

Tests:
- Concurrent API calls
- Rate limiter under load
- Cache consistency
"""

import pytest
import asyncio
import time
from concurrent.futures import ThreadPoolExecutor


class TestRateLimiterConcurrency:
    """Concurrency tests for the shared rate limiter."""

    def test_rate_limiter_concurrent_same_key(self):
        """Multiple threads hitting same key should respect limit.

        20 threads race on one key with a limit of 10; exactly 10 checks
        may be granted.
        """
        from veena3modal.api.rate_limiter import RateLimiter

        limiter = RateLimiter(requests_per_minute=10, window_seconds=60)
        results = []

        def hit_shared_key():
            granted, remaining, _ = limiter.check("shared_key")
            results.append(granted)
            return granted

        # Fan out 20 simultaneous checks against the same key.
        with ThreadPoolExecutor(max_workers=20) as pool:
            for future in [pool.submit(hit_shared_key) for _ in range(20)]:
                future.result()  # re-raise any worker exception

        # Exactly the limit (10) must have been granted; the rest blocked.
        allowed_count = sum(map(bool, results))
        assert allowed_count == 10, f"Expected 10 allowed, got {allowed_count}"

    def test_rate_limiter_concurrent_different_keys(self):
        """Different keys should have independent limits.

        10 keys x 5 requests each against a per-key limit of 5 — every
        request should be granted because no key exceeds its own budget.
        """
        from veena3modal.api.rate_limiter import RateLimiter

        limiter = RateLimiter(requests_per_minute=5, window_seconds=60)

        def probe(key):
            return limiter.check(key)[0]

        with ThreadPoolExecutor(max_workers=50) as pool:
            pending = [
                pool.submit(probe, f"key_{key_id}")
                for key_id in range(10)
                for _ in range(5)
            ]
            results = [f.result() for f in pending]

        assert all(results), "All requests should be allowed"


class TestApiKeyCacheConcurrency:
    """Test API key cache under concurrent access."""

    def test_cache_concurrent_reads(self):
        """Concurrent cache reads should be consistent.

        100 threads read a single pre-populated entry; every reader must
        observe the exact value that was written.
        """
        from veena3modal.api.auth import ApiKeyCache

        cache = ApiKeyCache(ttl_seconds=60)
        cache.set("test_key", {"user_id": "user-123", "credits": 100})

        results = []

        def read_cache():
            data = cache.get("test_key")
            results.append(data)
            return data

        # 100 concurrent reads
        with ThreadPoolExecutor(max_workers=100) as executor:
            futures = [executor.submit(read_cache) for _ in range(100)]
            for f in futures:
                f.result()  # re-raise any worker exception

        # All reads should return the same value
        assert all(r == {"user_id": "user-123", "credits": 100} for r in results)

    def test_cache_concurrent_writes_and_reads(self):
        """Concurrent writes and reads should not corrupt data.

        Fix vs. previous version: the read futures race the writes, so the
        old assertion (``len(non_none) > 0``) could flake if every read lost
        the race, and it never checked the *content* of what was read.  Now:

        1. During the interleaved phase, any value a reader did observe must
           be exactly the payload written for that key (no torn/corrupt data).
        2. After all writes complete, a deterministic re-read verifies every
           key holds its correct value.
        """
        from veena3modal.api.auth import ApiKeyCache

        cache = ApiKeyCache(ttl_seconds=60)

        def write_cache(key_id):
            cache.set(f"key_{key_id}", {"id": key_id})

        def read_cache(key_id):
            return cache.get(f"key_{key_id}")

        # Interleaved writes and reads
        with ThreadPoolExecutor(max_workers=50) as executor:
            write_futures = [executor.submit(write_cache, i) for i in range(25)]
            read_futures = [executor.submit(read_cache, i) for i in range(25)]

            for f in write_futures:
                f.result()  # wait for (and surface errors from) all writes

            racing_results = [f.result() for f in read_futures]

        # A racing read may legitimately see None (read before write), but a
        # non-None result must be the exact payload for its key — never a
        # torn or corrupted value.
        for key_id, value in enumerate(racing_results):
            if value is not None:
                assert value == {"id": key_id}, (
                    f"Corrupted read for key_{key_id}: {value!r}"
                )

        # Deterministic phase: every write has completed, so every key must
        # now read back correctly.
        assert all(cache.get(f"key_{i}") == {"id": i} for i in range(25))


class TestMetricsConcurrency:
    """Test metrics recording under concurrent load."""

    def test_metrics_concurrent_recording(self):
        """Concurrent metrics recording should not lose data.

        This is a smoke test: 100 threads record simultaneously and the test
        fails if any recording raises.  (Exact counts aren't asserted — see
        note at the bottom.)

        Fix vs. previous version: ``get_metrics_registry()`` was bound to an
        unused local; the call is kept purely for its side effect of
        re-initializing the registry after the reset below.
        """
        from veena3modal.shared.metrics import (
            get_metrics_registry,
            record_request_received,
        )

        # Reset the module-level singletons so this test starts clean.
        import veena3modal.shared.metrics as metrics_module
        metrics_module._registry = None
        metrics_module._requests_total = None

        # Re-initialize the registry (side effect only; return value unused).
        get_metrics_registry()

        def record_request():
            record_request_received(speaker="lipakshi", stream=False, format="wav")

        # 100 concurrent recordings; f.result() re-raises any worker error.
        with ThreadPoolExecutor(max_workers=100) as executor:
            futures = [executor.submit(record_request) for _ in range(100)]
            for f in futures:
                f.result()

        # Metrics should have recorded all requests
        # (We can't easily verify the exact count without exposing internals)


class TestLoggingConcurrency:
    """Tests for request-context isolation in async logging."""

    def test_request_context_isolation(self):
        """Each async task should have isolated request context.

        10 tasks each bind their own request_id, yield to the event loop
        (giving other tasks a chance to clobber shared state if isolation
        is broken), then read their id back.
        """
        from veena3modal.shared.logging import (
            set_request_context,
            get_request_context,
            clear_request_context,
        )

        async def task_with_context(task_id):
            set_request_context(request_id=f"req-{task_id}")
            # Yield so sibling tasks run in between set and get.
            await asyncio.sleep(0.01)
            ctx = get_request_context()
            clear_request_context()
            return ctx.get("request_id")

        async def run_tasks():
            return await asyncio.gather(
                *(task_with_context(i) for i in range(10))
            )

        results = asyncio.run(run_tasks())

        # No task may have observed another task's request_id.
        expected = {f"req-{i}" for i in range(10)}
        actual = set(results)
        assert actual == expected, f"Expected {expected}, got {actual}"


class TestErrorRecovery:
    """Resilience tests for the rate limiter and API-key cache."""

    def test_rate_limiter_recovers_from_errors(self):
        """Rate limiter should be resilient to internal errors."""
        from veena3modal.api.rate_limiter import RateLimiter

        limiter = RateLimiter(requests_per_minute=100)

        # Baseline: a fresh key is well under the limit.
        assert limiter.check("test_key")[0]

        # cleanup() must not break subsequent checks.
        limiter.cleanup()
        assert limiter.check("test_key")[0]

    def test_cache_handles_expired_entries(self):
        """Cache should handle entry expiration gracefully."""
        from veena3modal.api.auth import ApiKeyCache

        cache = ApiKeyCache(ttl_seconds=1)
        cache.set("expiring_key", {"value": "test"})

        # Entry is visible before the TTL elapses.
        assert cache.get("expiring_key") is not None

        # Sleep past the 1-second TTL.
        time.sleep(1.1)

        # Expired lookup returns None rather than raising.
        assert cache.get("expiring_key") is None

        # The slot is reusable after expiry.
        cache.set("expiring_key", {"value": "new"})
        assert cache.get("expiring_key") == {"value": "new"}
