"""
API key authentication and caching for Veena3 TTS.

Features:
- In-memory API key cache with TTL
- Fast validation (< 5ms target)
- Graceful bypass mode for development
- Header extraction (Bearer token or X-API-Key)

Usage:
    from veena3modal.api.auth import get_api_validator, extract_api_key, hash_api_key
    
    validator = get_api_validator()
    api_key = extract_api_key(request.headers)
    if api_key is None:
        ...  # reject with 401: no API key was supplied
    result = validator.validate(hash_api_key(api_key))
    
    if not result.is_valid:
        ...  # reject with 401/403 based on result.error_code
"""

import hashlib
import os
import time
from dataclasses import dataclass
from threading import Lock
from typing import Any, Callable, Dict, Optional

from veena3modal.shared.logging import get_logger

# Module-level logger for validator initialization and per-request validation events.
logger = get_logger(__name__)


def hash_api_key(api_key: str) -> str:
    """
    Return the SHA-256 hex digest of a raw API key.

    The key is UTF-8 encoded before hashing, so the same key always yields
    the same 64-character lowercase hex string. Callers look keys up by this
    digest rather than by the raw secret.

    Args:
        api_key: The raw API key as received from the client

    Returns:
        Lowercase hex SHA-256 digest of the key
    """
    hasher = hashlib.sha256()
    hasher.update(api_key.encode("utf-8"))
    return hasher.hexdigest()


def extract_api_key(headers: Dict[str, str]) -> Optional[str]:
    """
    Extract API key from request headers.

    Checks (in order of precedence):
    1. Authorization: Bearer <key>
    2. X-API-Key: <key>

    Header names and the ``Bearer`` scheme name are matched
    case-insensitively, and surrounding whitespace is stripped from the key.
    A header that is present but carries no key (e.g. ``Authorization: Bearer ``
    or a blank ``X-API-Key``) is treated as absent. An empty string is
    therefore never returned, and a blank bearer token falls through to
    ``X-API-Key``.

    Args:
        headers: Request headers; any mapping exposing ``.items()``.

    Returns:
        Non-empty API key string, or None if no usable key was found
    """
    # Normalize header names to lowercase for case-insensitive lookup.
    normalized = {name.lower(): value for name, value in headers.items()}

    # `or ""` tolerates a header that is present with a None value.
    auth_header = normalized.get("authorization") or ""
    if auth_header.lower().startswith("bearer "):
        token = auth_header[len("bearer "):].strip()
        if token:
            return token
        # Blank bearer token: fall through to X-API-Key instead of returning "".

    api_key = (normalized.get("x-api-key") or "").strip()
    return api_key or None


@dataclass
class ValidationResult:
    """
    Result of API key validation.

    Field order defines the dataclass's positional constructor signature,
    so fields must not be reordered.
    """
    
    is_valid: bool  # True when the key passed all checks (always True in bypass mode)
    user_id: Optional[str] = None  # "user_id" taken from the cached key data, when available
    rate_limit: int = 60  # Default requests per minute
    credits: float = 0.0  # Credit balance read from the cached key data
    error_code: Optional[str] = None  # Machine-readable failure code, e.g. "INVALID_API_KEY", "INSUFFICIENT_CREDITS"
    error_message: Optional[str] = None  # Human-readable failure description


class ApiKeyCache:
    """
    In-memory cache for API key data.

    Every operation is serialized by an internal lock, and entries have a
    configurable TTL. Expired entries are evicted lazily, when ``get``
    encounters them. ``invalidate`` and ``clear`` remove entries
    immediately.

    Ages are measured with a monotonic clock by default. A wall-clock jump
    (NTP correction, manual change) therefore cannot keep a stale entry, such
    as a since-revoked key, alive past its TTL, nor expire entries early.

    Attributes:
        ttl_seconds: Time-to-live for cache entries (default: 30s)
    """
    
    def __init__(
        self,
        ttl_seconds: int = 30,
        clock: Optional[Callable[[], float]] = None,
    ):
        """
        Initialize cache.
        
        Args:
            ttl_seconds: Cache entry TTL. An entry is served while its age
                is <= this many seconds.
            clock: Zero-argument callable returning the current time in
                seconds. Defaults to ``time.monotonic``; inject a fake
                clock in tests.
        """
        self.ttl_seconds = ttl_seconds
        self._clock: Callable[[], float] = clock if clock is not None else time.monotonic
        self._cache: Dict[str, Dict[str, Any]] = {}
        self._timestamps: Dict[str, float] = {}
        self._lock = Lock()
    
    def get(self, key_hash: str) -> Optional[Dict[str, Any]]:
        """
        Get cached key data.
        
        Args:
            key_hash: Hashed API key
        
        Returns:
            The cached dict itself (not a copy), or None if the key is not
            cached or its entry has expired.
        """
        with self._lock:
            if key_hash not in self._cache:
                return None
            
            cached_at = self._timestamps.get(key_hash)
            # A missing timestamp should be impossible (set() writes both
            # maps under the lock). Treat it as expired rather than trust
            # the data.
            if cached_at is None or self._clock() - cached_at > self.ttl_seconds:
                # Expired: evict now so stale data is never served.
                del self._cache[key_hash]
                self._timestamps.pop(key_hash, None)
                return None
            
            return self._cache[key_hash]
    
    def set(self, key_hash: str, data: Dict[str, Any]) -> None:
        """
        Cache key data, (re)starting its TTL from now.
        
        Args:
            key_hash: Hashed API key
            data: Key data to cache (stored by reference)
        """
        with self._lock:
            self._cache[key_hash] = data
            self._timestamps[key_hash] = self._clock()
    
    def invalidate(self, key_hash: str) -> None:
        """Remove a key from cache; a no-op if the key is not cached."""
        with self._lock:
            self._cache.pop(key_hash, None)
            self._timestamps.pop(key_hash, None)
    
    def clear(self) -> None:
        """Clear all cached entries."""
        with self._lock:
            self._cache.clear()
            self._timestamps.clear()


class ApiKeyValidator:
    """
    API key validator with caching.

    Validates keys against cached data, checking in order:
    - Key is present in the cache (a cache miss is currently treated as invalid)
    - Key is active (``is_active`` truthy; missing counts as inactive)
    - Credits are available (``credits`` > 0; missing or None counts as 0)

    Outside bypass mode, every call emits exactly one ``api_key_validation``
    log record describing the outcome and its duration.

    Attributes:
        cache: ApiKeyCache instance
        bypass_mode: If True, all keys are considered valid
    """
    
    def __init__(
        self,
        cache: Optional[ApiKeyCache] = None,
        bypass_mode: bool = False,
    ):
        """
        Initialize validator.
        
        Args:
            cache: Cache instance (creates new if None)
            bypass_mode: Skip validation (for development)
        """
        # Explicit None check: a caller-supplied cache must never be swapped
        # out just because it happens to be falsy.
        self.cache = cache if cache is not None else ApiKeyCache()
        self.bypass_mode = bypass_mode
    
    def validate(self, key_hash: str) -> ValidationResult:
        """
        Validate an API key.
        
        Args:
            key_hash: Hashed API key
        
        Returns:
            ValidationResult with is_valid flag and details. Failures carry
            error_code "INVALID_API_KEY" (unknown or inactive key) or
            "INSUFFICIENT_CREDITS".
        """
        # perf_counter is the right clock for measuring short durations.
        start = time.perf_counter()
        
        # Bypass mode for development: no lookup, no logging.
        if self.bypass_mode:
            return ValidationResult(
                is_valid=True,
                user_id="bypass_user",
                rate_limit=9999,
                credits=999999.0,
            )
        
        key_data = self.cache.get(key_hash)
        if key_data is None:
            # Cache miss. In production this would trigger a DB lookup;
            # for now an uncached key is treated as invalid.
            self._log_validation(start, valid=False, reason="not_found")
            return ValidationResult(
                is_valid=False,
                error_code="INVALID_API_KEY",
                error_message="API key not found or invalid",
            )
        
        user_id = key_data.get("user_id")
        
        if not key_data.get("is_active", False):
            self._log_validation(
                start, valid=False, reason="inactive", user_id=user_id
            )
            return ValidationResult(
                is_valid=False,
                user_id=user_id,
                error_code="INVALID_API_KEY",
                error_message="API key is inactive",
            )
        
        # A NULL/missing balance means "no credits". Comparing None with 0
        # would raise TypeError instead of rejecting cleanly.
        credits = key_data.get("credits")
        if credits is None:
            credits = 0.0
        if credits <= 0:
            self._log_validation(
                start,
                valid=False,
                reason="no_credits",
                user_id=user_id,
                credits=credits,
            )
            return ValidationResult(
                is_valid=False,
                user_id=user_id,
                credits=credits,
                error_code="INSUFFICIENT_CREDITS",
                error_message="Insufficient credits",
            )
        
        self._log_validation(start, valid=True, user_id=user_id, credits=credits)
        return ValidationResult(
            is_valid=True,
            user_id=user_id,
            rate_limit=key_data.get("rate_limit", 60),
            credits=credits,
        )
    
    @staticmethod
    def _log_validation(start: float, **fields: Any) -> None:
        """Log one ``api_key_validation`` event with the ms elapsed since ``start``."""
        fields["duration_ms"] = round((time.perf_counter() - start) * 1000, 2)
        logger.info("api_key_validation", extra=fields)


# Module-level singleton, created lazily by get_api_validator().
_api_validator: Optional[ApiKeyValidator] = None
# Guards singleton creation and reset. Without it, concurrent first calls
# could build two validators with two separate caches, and entries cached
# via one would be invisible to the other.
_api_validator_lock = Lock()

_DEFAULT_CACHE_TTL = 30


def _read_cache_ttl() -> int:
    """Parse AUTH_CACHE_TTL, falling back to the default when it is malformed or negative."""
    raw = os.environ.get("AUTH_CACHE_TTL", str(_DEFAULT_CACHE_TTL))
    try:
        ttl: Optional[int] = int(raw)
    except ValueError:
        ttl = None
    if ttl is None or ttl < 0:
        # Do not fail every request over a bad config value; use the default.
        logger.warning(
            "invalid_auth_cache_ttl",
            extra={"value": raw, "fallback": _DEFAULT_CACHE_TTL},
        )
        return _DEFAULT_CACHE_TTL
    return ttl


def get_api_validator() -> ApiKeyValidator:
    """
    Get the singleton API key validator, creating it on first use.
    
    Configuration via environment variables (read once, at creation):
    - AUTH_BYPASS_MODE: "true" (case-insensitive) to skip validation (dev only)
    - AUTH_CACHE_TTL: Cache TTL in seconds (default: 30; an invalid or
      negative value falls back to 30 with a warning)
    
    Returns:
        Configured ApiKeyValidator instance
    """
    global _api_validator
    
    # Fast path avoids taking the lock once the singleton exists.
    if _api_validator is None:
        with _api_validator_lock:
            # Re-check under the lock: another thread may have created it
            # while this one was waiting.
            if _api_validator is None:
                bypass_mode = os.environ.get("AUTH_BYPASS_MODE", "false").lower() == "true"
                cache_ttl = _read_cache_ttl()
                
                cache = ApiKeyCache(ttl_seconds=cache_ttl)
                _api_validator = ApiKeyValidator(cache=cache, bypass_mode=bypass_mode)
                
                logger.info(
                    "api_validator_initialized",
                    extra={
                        "bypass_mode": bypass_mode,
                        "cache_ttl": cache_ttl,
                    }
                )
                if bypass_mode:
                    # Make disabled authentication visible at warning level.
                    logger.warning(
                        "api_key_auth_bypassed",
                        extra={"bypass_mode": True},
                    )
    
    return _api_validator


def reset_api_validator() -> None:
    """Reset the singleton (for testing); the next get_api_validator() call rebuilds it."""
    global _api_validator
    with _api_validator_lock:
        _api_validator = None

