"""
Error handling and GPU fault recovery for Veena3 TTS.

Features:
- Standardized error responses
- GPU fault detection and container draining
- Graceful degradation for missing configs

Usage:
    from veena3modal.api.error_handlers import (
        handle_gpu_fault,
        create_error_response,
        ErrorCode,
    )
"""

import os
import re
from enum import Enum
from typing import Any, Dict, Optional

from veena3modal.shared.logging import get_logger

logger = get_logger(__name__)


class ErrorCode(str, Enum):
    """
    Standard error codes for the TTS API.

    Mixes in ``str`` so members compare equal to their plain string
    values and serialize as strings in JSON error responses.
    Each code maps to an HTTP status via ERROR_STATUS_CODES.
    """
    
    # Authentication errors
    INVALID_API_KEY = "INVALID_API_KEY"
    EXPIRED_API_KEY = "EXPIRED_API_KEY"
    INSUFFICIENT_CREDITS = "INSUFFICIENT_CREDITS"
    
    # Rate limiting
    RATE_LIMIT_EXCEEDED = "RATE_LIMIT_EXCEEDED"
    
    # Validation errors (client-side, 4xx)
    VALIDATION_ERROR = "VALIDATION_ERROR"
    TEXT_TOO_LONG = "TEXT_TOO_LONG"
    INVALID_SPEAKER = "INVALID_SPEAKER"
    INVALID_FORMAT = "INVALID_FORMAT"
    
    # Infrastructure errors (GPU/model/runtime, mostly 5xx)
    MODEL_NOT_LOADED = "MODEL_NOT_LOADED"
    GPU_FAULT = "GPU_FAULT"
    GPU_OOM = "GPU_OOM"
    GENERATION_FAILED = "GENERATION_FAILED"
    GENERATION_TIMEOUT = "GENERATION_TIMEOUT"
    STREAMING_ERROR = "STREAMING_ERROR"
    
    # Service errors
    SERVICE_UNAVAILABLE = "SERVICE_UNAVAILABLE"
    INTERNAL_ERROR = "INTERNAL_ERROR"
    FORMAT_NOT_IMPLEMENTED = "FORMAT_NOT_IMPLEMENTED"


# HTTP status for each error code, grouped by status for readability and
# then flattened into the ErrorCode -> int lookup used by the handlers.
_CODES_FOR_STATUS: Dict[int, tuple] = {
    400: (
        ErrorCode.VALIDATION_ERROR,
        ErrorCode.TEXT_TOO_LONG,
        ErrorCode.INVALID_SPEAKER,
        ErrorCode.INVALID_FORMAT,
    ),
    401: (ErrorCode.INVALID_API_KEY, ErrorCode.EXPIRED_API_KEY),
    403: (ErrorCode.INSUFFICIENT_CREDITS,),
    429: (ErrorCode.RATE_LIMIT_EXCEEDED,),
    500: (
        ErrorCode.GENERATION_FAILED,
        ErrorCode.STREAMING_ERROR,
        ErrorCode.INTERNAL_ERROR,
    ),
    501: (ErrorCode.FORMAT_NOT_IMPLEMENTED,),
    503: (
        ErrorCode.MODEL_NOT_LOADED,
        ErrorCode.GPU_FAULT,
        ErrorCode.GPU_OOM,
        ErrorCode.SERVICE_UNAVAILABLE,
    ),
    504: (ErrorCode.GENERATION_TIMEOUT,),
}

ERROR_STATUS_CODES: Dict[ErrorCode, int] = {
    code: status
    for status, codes in _CODES_FOR_STATUS.items()
    for code in codes
}


def create_error_response(
    code: ErrorCode,
    message: str,
    request_id: str,
    details: Optional[Dict[str, Any]] = None,
) -> Dict[str, Any]:
    """
    Build the standardized error response body.

    Args:
        code: Error code enum
        message: Human-readable error message
        request_id: Request identifier for tracking
        details: Optional additional details

    Returns:
        Dictionary of the form
        ``{"error": {"code", "message", "request_id"[, "details"]}}``;
        the "details" key is present only when *details* is truthy.
    """
    error: Dict[str, Any] = {
        "code": code.value,
        "message": message,
        "request_id": request_id,
    }

    if details:
        error["details"] = details

    return {"error": error}


def get_error_status(code: ErrorCode) -> int:
    """Map an error code to its HTTP status; unknown codes fall back to 500."""
    try:
        return ERROR_STATUS_CODES[code]
    except KeyError:
        return 500


# GPU fault detection
# Module-level sticky flag: set by handle_gpu_fault() once a GPU fault is
# observed in this container; cleared only by reset_gpu_fault_flag() (tests).
_gpu_fault_detected = False


# Compiled once at import time instead of re-scanning a list per call.
# "oom" is matched on word boundaries so unrelated messages ("room",
# "zoom", "bloomed") are no longer misclassified as GPU faults; "gpu"
# stays a plain substring since it legitimately appears inside tokens
# such as "gpus" or "gpu0". Multi-word indicators are substring matches,
# exactly as before.
_GPU_FAULT_PATTERN = re.compile(
    r"cuda error"
    r"|cudnn error"
    r"|out of memory"
    r"|illegal memory access"
    r"|device-side assert"
    r"|nccl error"
    r"|gpu"
    r"|\boom\b"
)


def is_gpu_fault(exception: Exception) -> bool:
    """
    Check if an exception indicates a GPU fault.

    Args:
        exception: The caught exception

    Returns:
        True if the lowercased exception message matches any known
        GPU fault indicator.
    """
    return _GPU_FAULT_PATTERN.search(str(exception).lower()) is not None


def handle_gpu_fault(exception: Exception, request_id: str) -> None:
    """
    Record a GPU fault and ask the platform to drain this container.

    In Modal, this calls ``modal.experimental.stop_fetching_inputs()`` (when
    available) so the container stops accepting new requests while it drains.
    Outside a Modal environment it only logs and sets the fault flag.

    Args:
        exception: The GPU fault exception
        request_id: Request ID for logging
    """
    global _gpu_fault_detected

    _gpu_fault_detected = True

    logger.error(
        "gpu_fault_detected",
        extra={
            "request_id": request_id,
            "exception": str(exception),
            "exception_type": type(exception).__name__,
        }
    )

    try:
        import modal
    except ImportError:
        # Not running in Modal environment
        return

    try:
        experimental = getattr(modal, "experimental", None)
        stop = getattr(experimental, "stop_fetching_inputs", None)
        if stop is not None:
            stop()
            logger.info(
                "container_drain_initiated",
                extra={"request_id": request_id},
            )
    except Exception as exc:
        logger.warning(
            "container_drain_failed",
            extra={
                "request_id": request_id,
                "error": str(exc),
            }
        )


def has_gpu_fault() -> bool:
    """
    Check if a GPU fault has been detected in this container.

    Returns the module-level flag set by handle_gpu_fault().
    """
    return _gpu_fault_detected


def reset_gpu_fault_flag() -> None:
    """Reset the module-level GPU fault flag (intended for testing only)."""
    global _gpu_fault_detected
    _gpu_fault_detected = False


# Graceful degradation helpers

def get_optional_config(
    env_var: str,
    default: Any = None,
    required_for: str = "",
) -> Any:
    """
    Read an optional configuration value with graceful degradation.

    A missing value never crashes: when *required_for* names a feature,
    a warning is logged; in either case *default* is returned.

    Args:
        env_var: Environment variable name
        default: Default value if not set
        required_for: Feature name (for warning message)

    Returns:
        Config value or default
    """
    value = os.environ.get(env_var)
    if value is not None:
        return value

    if required_for:
        logger.warning(
            "optional_config_missing",
            extra={
                "env_var": env_var,
                "feature": required_for,
                "using_default": str(default),
            }
        )

    return default


def check_required_config(
    env_var: str,
    feature_name: str,
) -> Optional[str]:
    """
    Look up a config value that a feature requires.

    Unlike get_optional_config, no default is supplied: a missing value
    logs a warning and yields None so the caller can degrade gracefully.

    Args:
        env_var: Environment variable name
        feature_name: Feature that requires this config

    Returns:
        Config value or None
    """
    if env_var in os.environ:
        return os.environ[env_var]

    logger.warning(
        "required_config_missing",
        extra={
            "env_var": env_var,
            "feature": feature_name,
        }
    )
    return None


class FeatureFlags:
    """
    Feature flags for graceful degradation.

    Flags are read from the environment at call time, so changes made to
    os.environ take effect immediately (no caching).
    """

    @staticmethod
    def _truthy(env_var: str, default: str) -> bool:
        # A flag counts as enabled only for the exact string "true",
        # case-insensitively; any other value disables it.
        return os.environ.get(env_var, default).lower() == "true"

    @staticmethod
    def is_supabase_enabled() -> bool:
        """Supabase is configured when both a URL and a key are present."""
        if not os.environ.get("SUPABASE_URL"):
            return False
        key = os.environ.get("SUPABASE_SERVICE_KEY") or os.environ.get("SUPABASE_KEY")
        return bool(key)

    @staticmethod
    def is_redis_enabled() -> bool:
        """Redis is configured when REDIS_URL is set to a non-empty value."""
        return bool(os.environ.get("REDIS_URL"))

    @staticmethod
    def is_auth_enabled() -> bool:
        """Authentication is on unless AUTH_BYPASS_MODE is explicitly "true"."""
        return not FeatureFlags._truthy("AUTH_BYPASS_MODE", "false")

    @staticmethod
    def is_rate_limiting_enabled() -> bool:
        """Rate limiting defaults to on (RATE_LIMIT_ENABLED)."""
        return FeatureFlags._truthy("RATE_LIMIT_ENABLED", "true")

    @staticmethod
    def is_metrics_enabled() -> bool:
        """Prometheus metrics default to on (PROMETHEUS_ENABLED)."""
        return FeatureFlags._truthy("PROMETHEUS_ENABLED", "true")

