"""
Unit tests for rate limiting middleware.

Tests in-memory rate limiting with sliding window algorithm.
"""

import pytest
import asyncio
import time
from unittest.mock import patch, MagicMock, AsyncMock


class TestInMemoryRateLimiter:
    """Test the in-memory rate limiter.

    Every test builds its own ``RateLimiter`` so no state is shared between
    tests. ``check(key)`` is unpacked as ``(allowed, remaining, reset_after)``.
    """

    def test_create_rate_limiter(self):
        """Create a rate limiter with default config."""
        from veena3modal.api.rate_limiter import RateLimiter

        limiter = RateLimiter(requests_per_minute=60)
        assert limiter is not None
        assert limiter.requests_per_minute == 60

    def test_rate_limiter_allows_first_request(self):
        """First request should always be allowed."""
        from veena3modal.api.rate_limiter import RateLimiter

        limiter = RateLimiter(requests_per_minute=60)
        allowed, remaining, _ = limiter.check("test_key")

        assert allowed is True
        assert remaining == 59

    def test_rate_limiter_tracks_requests(self):
        """Rate limiter should track request count."""
        from veena3modal.api.rate_limiter import RateLimiter

        limiter = RateLimiter(requests_per_minute=10)

        for i in range(5):
            allowed, remaining, _ = limiter.check("test_key")
            assert allowed is True
            assert remaining == 10 - (i + 1)

    def test_rate_limiter_blocks_after_limit(self):
        """Rate limiter should block after limit exceeded."""
        from veena3modal.api.rate_limiter import RateLimiter

        limiter = RateLimiter(requests_per_minute=5)

        # Exhaust the limit
        for _ in range(5):
            allowed, _, _ = limiter.check("test_key")
            assert allowed is True

        # Next request should be blocked
        allowed, remaining, reset_after = limiter.check("test_key")
        assert allowed is False
        assert remaining == 0
        assert reset_after > 0

    def test_rate_limiter_per_key_isolation(self):
        """Each key should have independent rate limits."""
        from veena3modal.api.rate_limiter import RateLimiter

        limiter = RateLimiter(requests_per_minute=3)

        # Exhaust limit for key1. Each of these must succeed; otherwise the
        # "key1 blocked" assertion below could pass for the wrong reason.
        for _ in range(3):
            allowed, _, _ = limiter.check("key1")
            assert allowed is True

        # key1 blocked, key2 should still work
        allowed1, _, _ = limiter.check("key1")
        allowed2, remaining2, _ = limiter.check("key2")

        assert allowed1 is False
        assert allowed2 is True
        assert remaining2 == 2

    def test_rate_limiter_window_slides(self):
        """Rate limiter window should slide over time."""
        from veena3modal.api.rate_limiter import RateLimiter

        # Very short window so the test only has to wait about one second
        limiter = RateLimiter(requests_per_minute=60, window_seconds=1)

        # Exhaust limit; every request up to the limit must be accepted
        for _ in range(60):
            allowed, _, _ = limiter.check("test_key")
            assert allowed is True

        # Should be blocked
        allowed, _, _ = limiter.check("test_key")
        assert allowed is False

        # Sleep slightly longer than the 1 s window so every recorded
        # request falls outside it.
        time.sleep(1.1)

        # Should be allowed again, as the first request of a fresh window
        allowed, remaining, _ = limiter.check("test_key")
        assert allowed is True
        assert remaining == 59


class TestRateLimiterHeaders:
    """Verify the HTTP headers built by ``RateLimiter.get_headers``."""

    def test_get_headers_when_allowed(self):
        """An accepted request reports the configured limit and remaining quota."""
        from veena3modal.api.rate_limiter import RateLimiter

        rl = RateLimiter(requests_per_minute=60)
        ok, left, reset_in = rl.check("test_key")
        hdrs = rl.get_headers(ok, left, reset_in)

        expected_pairs = (
            ("X-RateLimit-Limit", "60"),
            ("X-RateLimit-Remaining", "59"),
        )
        for header_name, header_value in expected_pairs:
            assert header_name in hdrs
            assert hdrs[header_name] == header_value

    def test_get_headers_when_blocked(self):
        """A rejected request reports zero remaining and carries Retry-After."""
        from veena3modal.api.rate_limiter import RateLimiter

        rl = RateLimiter(requests_per_minute=1)
        rl.check("test_key")  # consumes the only permitted request

        ok, left, reset_in = rl.check("test_key")
        hdrs = rl.get_headers(ok, left, reset_in)

        assert hdrs["X-RateLimit-Remaining"] == "0"
        assert "Retry-After" in hdrs


class TestRateLimiterGracefulDegradation:
    """Check the limiter stays usable in fallback and disabled configurations."""

    def test_limiter_works_without_redis(self):
        """The factory should hand back a working limiter without Redis (in-memory fallback)."""
        from veena3modal.api.rate_limiter import get_rate_limiter

        shared = get_rate_limiter()
        assert shared is not None

        is_allowed, _, _ = shared.check("test_key")
        assert is_allowed is True

    def test_disabled_limiter_always_allows(self):
        """With ``enabled=False`` no request is ever rejected."""
        from veena3modal.api.rate_limiter import RateLimiter

        off = RateLimiter(enabled=False)

        # 1000 consecutive calls on one key; every one must be accepted.
        for _attempt in range(1000):
            is_allowed, _, _ = off.check("test_key")
            assert is_allowed is True


class TestRateLimiterCleanup:
    """Test rate limiter memory cleanup."""

    def test_cleanup_old_entries(self):
        """cleanup() should release every expired key but keep keys still in their window.

        The previous ``< 100`` check passed even if only one stale key was
        freed, so a real leak went unnoticed. This version pins the full
        contract. It inspects the private ``_requests`` mapping, as before.
        """
        from veena3modal.api.rate_limiter import RateLimiter

        limiter = RateLimiter(requests_per_minute=60, window_seconds=1)

        # Add entries for multiple keys
        for i in range(100):
            limiter.check(f"key_{i}")

        # Baseline: one tracked entry per key before anything expires
        assert len(limiter._requests) == 100

        # Wait slightly longer than the 1 s window so all of the above expire
        time.sleep(1.1)

        # A key used after the sleep is inside its window and must survive;
        # dropping it would silently reset that caller's rate limit.
        limiter.check("fresh_key")

        limiter.cleanup()

        # Only the fresh key remains; every stale key was released
        assert len(limiter._requests) == 1

