"""
Rate limiting for Veena3 TTS API.

Implements sliding window rate limiting with in-memory storage.
Supports optional Redis backend for distributed rate limiting.

Usage:
    from veena3modal.api.rate_limiter import get_rate_limiter
    
    limiter = get_rate_limiter()
    allowed, remaining, reset_after = limiter.check(api_key_hash)
    if not allowed:
        return 429 response with Retry-After header
"""

import time
import os
from collections import defaultdict
from typing import Dict, List, Tuple, Optional
from threading import Lock

from veena3modal.shared.logging import get_logger

# Module logger; get_rate_limiter() uses it to record the configuration the
# singleton was built with.
logger = get_logger(__name__)


class RateLimiter:
    """
    In-memory sliding window rate limiter.

    For every key, the wall-clock timestamps (``time.time()``) of accepted
    requests are kept in a list. A new request is accepted when fewer than
    ``requests_per_minute`` of those timestamps fall strictly inside the
    trailing ``window_seconds`` window. All state access is serialized by a
    single ``threading.Lock``. Keys whose timestamps have all expired are swept
    out at most once per ``_cleanup_interval`` seconds.

    State lives in this instance only, so limits are per process.

    Attributes:
        requests_per_minute: Maximum requests allowed per window
        window_seconds: Window size in seconds (default: 60)
        enabled: Whether rate limiting is active
    """

    def __init__(
        self,
        requests_per_minute: int = 60,
        window_seconds: int = 60,
        enabled: bool = True,
    ):
        """
        Initialize rate limiter.

        Args:
            requests_per_minute: Max requests per window. A value <= 0
                rejects every request while enabled.
            window_seconds: Window size in seconds
            enabled: If False, all requests are allowed
        """
        self.requests_per_minute = requests_per_minute
        self.window_seconds = window_seconds
        self.enabled = enabled

        # {key: [timestamp1, timestamp2, ...]} - timestamps of accepted requests
        self._requests: Dict[str, List[float]] = defaultdict(list)
        self._lock = Lock()
        self._last_cleanup = time.time()
        self._cleanup_interval = 60  # Sweep idle keys at most every 60 seconds

    def check(self, key: str) -> Tuple[bool, int, float]:
        """
        Check if request is allowed under rate limit, recording it if so.

        Args:
            key: Unique identifier (API key hash, user ID, IP)

        Returns:
            Tuple of (allowed, remaining, reset_after_seconds)
            - allowed: True if request should proceed
            - remaining: Number of requests remaining in window (0 if blocked)
            - reset_after: Seconds until the oldest request in the window
              expires (0.0 when allowed; never negative)
        """
        if not self.enabled:
            return True, self.requests_per_minute, 0.0

        now = time.time()
        window_start = now - self.window_seconds

        with self._lock:
            # Sweep before deciding so the periodic cleanup also runs while
            # every incoming request is being rejected.
            if now - self._last_cleanup > self._cleanup_interval:
                self._do_cleanup(window_start)
                self._last_cleanup = now

            # .get() rather than indexing: indexing the defaultdict would
            # insert an empty entry for a key whose request is then rejected.
            recent = [ts for ts in self._requests.get(key, ()) if ts > window_start]
            current_count = len(recent)

            if current_count >= self.requests_per_minute:
                if recent:
                    self._requests[key] = recent
                    # min() rather than recent[0]: time.time() can step
                    # backwards, so the list is not guaranteed to be sorted.
                    oldest_in_window = min(recent)
                else:
                    # Only reachable when requests_per_minute <= 0; do not
                    # keep an empty entry around for this key.
                    self._requests.pop(key, None)
                    oldest_in_window = now
                reset_after = oldest_in_window + self.window_seconds - now
                return False, 0, max(0.0, reset_after)

            # Allow request - record timestamp
            recent.append(now)
            self._requests[key] = recent
            remaining = self.requests_per_minute - current_count - 1
            return True, remaining, 0.0

    def _do_cleanup(self, window_start: float) -> None:
        """
        Drop expired timestamps and remove keys that have none left.

        Both callers invoke this while holding ``self._lock``.

        Args:
            window_start: Timestamps at or before this value are expired.
        """
        # Iterate over a snapshot because keys are deleted inside the loop.
        for key, timestamps in list(self._requests.items()):
            valid = [ts for ts in timestamps if ts > window_start]
            if valid:
                self._requests[key] = valid
            else:
                del self._requests[key]

    def cleanup(self) -> None:
        """
        Manual cleanup trigger for testing.
        Removes all expired entries immediately, ignoring the cleanup interval.
        """
        now = time.time()
        window_start = now - self.window_seconds

        with self._lock:
            self._do_cleanup(window_start)

    def get_headers(
        self,
        allowed: bool,
        remaining: int,
        reset_after: float
    ) -> Dict[str, str]:
        """
        Get rate limit headers for HTTP response.

        Args:
            allowed: Whether request was allowed
            remaining: Remaining requests in window (clamped to >= 0)
            reset_after: Seconds until reset

        Returns:
            Dict of header name -> value. ``Retry-After`` is added only when
            the request was blocked.
        """
        headers = {
            "X-RateLimit-Limit": str(self.requests_per_minute),
            "X-RateLimit-Remaining": str(max(0, remaining)),
        }

        if not allowed:
            # int() truncates, +1 ensures the advertised delay is never
            # shorter than the actual time until a slot frees up.
            headers["Retry-After"] = str(int(reset_after) + 1)

        return headers

    def reset(self, key: str) -> None:
        """Reset rate limit for a specific key (admin use). Unknown keys are ignored."""
        with self._lock:
            self._requests.pop(key, None)


# Module-level singleton: created lazily by get_rate_limiter() and cleared by
# reset_rate_limiter(). None means "not built yet".
_rate_limiter: Optional[RateLimiter] = None


def get_rate_limiter() -> RateLimiter:
    """
    Get the singleton rate limiter instance, creating it on first call.

    Configuration via environment variables (read only when the singleton is
    created; call reset_rate_limiter() to force a re-read):
    - RATE_LIMIT_REQUESTS_PER_MINUTE: Max requests (default: 60)
    - RATE_LIMIT_ENABLED: "true"/"1"/"yes"/"on" enable limiting, compared
      case-insensitively after stripping whitespace; any other value such as
      "false" disables it (default: true)

    Returns:
        Configured RateLimiter instance

    Raises:
        ValueError: If RATE_LIMIT_REQUESTS_PER_MINUTE is not an integer.
    """
    global _rate_limiter

    if _rate_limiter is None:
        raw_limit = os.environ.get("RATE_LIMIT_REQUESTS_PER_MINUTE", "60")
        try:
            requests_per_minute = int(raw_limit)
        except ValueError as err:
            raise ValueError(
                f"RATE_LIMIT_REQUESTS_PER_MINUTE must be an integer, got {raw_limit!r}"
            ) from err

        # Normalise so common spellings like "1" or " True " do not silently
        # switch rate limiting off.
        raw_enabled = os.environ.get("RATE_LIMIT_ENABLED", "true")
        enabled = raw_enabled.strip().lower() in ("true", "1", "yes", "on")

        _rate_limiter = RateLimiter(
            requests_per_minute=requests_per_minute,
            enabled=enabled,
        )

        logger.info(
            "rate_limiter_initialized",
            extra={
                "requests_per_minute": requests_per_minute,
                "enabled": enabled,
            }
        )

    return _rate_limiter


def reset_rate_limiter() -> None:
    """
    Discard the cached singleton (intended for tests).

    The next get_rate_limiter() call builds a brand-new RateLimiter, re-reading
    the environment variables and starting with no recorded requests.
    """
    global _rate_limiter
    _rate_limiter = None

