"""
Unit tests for error handling and GPU fault recovery.
"""

import pytest
import os
from unittest.mock import patch, MagicMock


class TestErrorResponse:
    """Tests for building error payloads and mapping error codes to HTTP statuses."""

    def test_create_error_response(self):
        """A basic error payload carries the code name, message and request id."""
        from veena3modal.api.error_handlers import create_error_response, ErrorCode

        payload = create_error_response(
            code=ErrorCode.INVALID_API_KEY,
            message="API key not found",
            request_id="req-123",
        )
        err = payload["error"]

        assert err["code"] == "INVALID_API_KEY"
        assert err["message"] == "API key not found"
        assert err["request_id"] == "req-123"

    def test_create_error_response_with_details(self):
        """Extra ``details`` are passed through under the ``error`` key."""
        from veena3modal.api.error_handlers import create_error_response, ErrorCode

        extra = {"field": "text", "max_length": 50000}
        payload = create_error_response(
            code=ErrorCode.VALIDATION_ERROR,
            message="Invalid input",
            request_id="req-456",
            details=extra,
        )
        details = payload["error"]["details"]

        assert details["field"] == "text"
        assert details["max_length"] == 50000

    def test_get_error_status(self):
        """Selected error codes map to their expected HTTP status codes."""
        from veena3modal.api.error_handlers import get_error_status, ErrorCode

        expected_statuses = (
            (ErrorCode.INVALID_API_KEY, 401),
            (ErrorCode.RATE_LIMIT_EXCEEDED, 429),
            (ErrorCode.GPU_FAULT, 503),
            (ErrorCode.INTERNAL_ERROR, 500),
        )
        for error_code, http_status in expected_statuses:
            assert get_error_status(error_code) == http_status


class TestGpuFaultDetection:
    """Tests for classifying exceptions as GPU faults (or not)."""

    def test_detect_cuda_error(self):
        """CUDA and cuDNN runtime errors are classified as GPU faults."""
        from veena3modal.api.error_handlers import is_gpu_fault

        cuda_failures = (
            RuntimeError("CUDA error: out of memory"),
            RuntimeError("RuntimeError: CUDA out of memory"),
            RuntimeError("cudnn error"),
        )
        for exc in cuda_failures:
            assert is_gpu_fault(exc)

    def test_detect_oom(self):
        """Out-of-memory conditions are classified as GPU faults."""
        from veena3modal.api.error_handlers import is_gpu_fault

        oom_failures = (
            MemoryError("out of memory"),
            RuntimeError("GPU OOM"),
        )
        for exc in oom_failures:
            assert is_gpu_fault(exc)

    def test_detect_illegal_memory_access(self):
        """An illegal-memory-access message is classified as a GPU fault."""
        from veena3modal.api.error_handlers import is_gpu_fault

        exc = RuntimeError("illegal memory access was encountered")
        assert is_gpu_fault(exc)

    def test_non_gpu_error(self):
        """Ordinary application errors must not be mistaken for GPU faults."""
        from veena3modal.api.error_handlers import is_gpu_fault

        ordinary_failures = (
            ValueError("Invalid input"),
            RuntimeError("Connection refused"),
            Exception("Generic error"),
        )
        for exc in ordinary_failures:
            assert not is_gpu_fault(exc)


class TestGpuFaultHandling:
    """Test GPU fault handling and container draining.

    The GPU-fault flag is shared state (the tests must explicitly reset it).
    It is therefore reset before AND after every test via pytest's xunit-style
    hooks. Cleanup then still happens when an assertion fails mid-test, so a
    set flag can never leak into other tests.
    """

    @staticmethod
    def _reset_flag():
        """Clear the GPU-fault flag."""
        from veena3modal.api.error_handlers import reset_gpu_fault_flag

        reset_gpu_fault_flag()

    def setup_method(self, method):
        """Start every test with the fault flag cleared."""
        self._reset_flag()

    def teardown_method(self, method):
        """Clear the fault flag even if the test failed part-way through."""
        self._reset_flag()

    def test_handle_gpu_fault_sets_flag(self):
        """Handling GPU fault should set the fault flag."""
        from veena3modal.api.error_handlers import handle_gpu_fault, has_gpu_fault

        assert not has_gpu_fault()

        handle_gpu_fault(RuntimeError("CUDA OOM"), "req-001")

        assert has_gpu_fault()

    def test_reset_gpu_fault_flag(self):
        """GPU fault flag should be resettable."""
        from veena3modal.api.error_handlers import (
            handle_gpu_fault, has_gpu_fault, reset_gpu_fault_flag
        )

        handle_gpu_fault(RuntimeError("GPU error"), "req-002")
        assert has_gpu_fault()

        reset_gpu_fault_flag()
        assert not has_gpu_fault()


class TestGracefulDegradation:
    """Test graceful degradation helpers.

    Environment changes are always made inside ``patch.dict(os.environ)``.
    It snapshots the environment and restores it on exit, so removing a
    variable here never deletes it for the rest of the test session.
    """

    def test_get_optional_config_with_value(self):
        """Get config when environment variable is set."""
        from veena3modal.api.error_handlers import get_optional_config

        with patch.dict(os.environ, {"TEST_CONFIG": "test_value"}):
            value = get_optional_config("TEST_CONFIG", default="default")
            assert value == "test_value"

    def test_get_optional_config_with_default(self):
        """Get default when environment variable is not set."""
        from veena3modal.api.error_handlers import get_optional_config

        # Unset the variable only for this test; any pre-existing value is
        # restored when the context exits.
        with patch.dict(os.environ):
            os.environ.pop("NONEXISTENT_CONFIG", None)
            value = get_optional_config("NONEXISTENT_CONFIG", default="default_value")
            assert value == "default_value"

    def test_check_required_config_present(self):
        """Check required config when present."""
        from veena3modal.api.error_handlers import check_required_config

        with patch.dict(os.environ, {"REQUIRED_VAR": "value"}):
            value = check_required_config("REQUIRED_VAR", "test_feature")
            assert value == "value"

    def test_check_required_config_missing(self):
        """Check required config when missing."""
        from veena3modal.api.error_handlers import check_required_config

        # Same isolation as above: the removal is undone on exit.
        with patch.dict(os.environ):
            os.environ.pop("MISSING_VAR", None)
            value = check_required_config("MISSING_VAR", "test_feature")
            assert value is None


class TestFeatureFlags:
    """Test feature flag checks.

    Tests that claim to check a *default* first remove the relevant variable
    from ``os.environ``. Without that, a value inherited from the shell or CI
    environment could make them pass or fail for the wrong reason. The
    removal is done inside ``patch.dict(os.environ)``, which restores the
    real environment on exit.
    """

    def test_supabase_enabled(self):
        """Check Supabase enabled when configured."""
        from veena3modal.api.error_handlers import FeatureFlags

        with patch.dict(os.environ, {
            "SUPABASE_URL": "https://example.supabase.co",
            "SUPABASE_SERVICE_KEY": "secret_key",
        }):
            assert FeatureFlags.is_supabase_enabled() is True

    def test_supabase_disabled(self):
        """Check Supabase disabled when not configured."""
        from veena3modal.api.error_handlers import FeatureFlags

        with patch.dict(os.environ):
            for key in ("SUPABASE_URL", "SUPABASE_SERVICE_KEY", "SUPABASE_KEY"):
                os.environ.pop(key, None)
            assert FeatureFlags.is_supabase_enabled() is False

    def test_auth_enabled_by_default(self):
        """Auth should be enabled by default and when bypass is explicitly off."""
        from veena3modal.api.error_handlers import FeatureFlags

        # True default: the bypass variable is not set at all.
        with patch.dict(os.environ):
            os.environ.pop("AUTH_BYPASS_MODE", None)
            assert FeatureFlags.is_auth_enabled() is True

        # Explicitly disabling the bypass must also keep auth on.
        with patch.dict(os.environ, {"AUTH_BYPASS_MODE": "false"}):
            assert FeatureFlags.is_auth_enabled() is True

    def test_auth_bypass_mode(self):
        """Auth can be disabled via bypass mode."""
        from veena3modal.api.error_handlers import FeatureFlags

        with patch.dict(os.environ, {"AUTH_BYPASS_MODE": "true"}):
            assert FeatureFlags.is_auth_enabled() is False

    def test_rate_limiting_enabled_by_default(self):
        """Rate limiting should be enabled by default."""
        from veena3modal.api.error_handlers import FeatureFlags

        # An empty, non-clearing patch.dict changes nothing, so the variable
        # has to be removed explicitly to exercise the default.
        with patch.dict(os.environ):
            os.environ.pop("RATE_LIMIT_ENABLED", None)
            assert FeatureFlags.is_rate_limiting_enabled() is True

    def test_rate_limiting_disabled(self):
        """Rate limiting can be disabled."""
        from veena3modal.api.error_handlers import FeatureFlags

        with patch.dict(os.environ, {"RATE_LIMIT_ENABLED": "false"}):
            assert FeatureFlags.is_rate_limiting_enabled() is False


class TestErrorCodes:
    """Sanity checks over the complete ErrorCode enumeration."""

    def test_all_error_codes_have_status(self):
        """Each ErrorCode member must appear in the HTTP status mapping."""
        from veena3modal.api.error_handlers import ErrorCode, ERROR_STATUS_CODES

        for member in ErrorCode:
            assert member in ERROR_STATUS_CODES, f"Missing status for {member}"

    def test_error_codes_are_strings(self):
        """Each ErrorCode member's value is a ``str``."""
        from veena3modal.api.error_handlers import ErrorCode

        for member in ErrorCode:
            member_value = member.value
            assert isinstance(member_value, str)

