"""
Unit tests for structured logging module.

Tests logging schema, event emission, and request_id tracking.
"""

import pytest
import json
import io
import logging
from unittest.mock import patch, MagicMock
from datetime import datetime


class TestStructuredLogger:
    """Tests for ``get_logger`` and ``log_event`` with JSON-formatted output.

    Tests that attach an in-memory capture handler detach it and restore the
    logger's level in a ``finally`` block. If ``get_logger`` hands back a
    shared, name-keyed ``logging.Logger`` (as ``logging.getLogger`` does), a
    leaked handler would keep capturing output in unrelated tests.
    """

    def test_create_logger(self):
        """get_logger returns a logger carrying the requested name."""
        from veena3modal.shared.logging import get_logger

        logger = get_logger("test_service")
        assert logger is not None
        assert logger.name == "test_service"

    def test_log_with_request_id(self):
        """Logs should include request_id when provided."""
        from veena3modal.shared.logging import get_logger, log_event, JSONFormatter

        # Capture log output in memory, rendered by the JSON formatter.
        stream = io.StringIO()
        handler = logging.StreamHandler(stream)
        handler.setFormatter(JSONFormatter())
        handler.setLevel(logging.INFO)

        logger = get_logger("test_request_id")
        previous_level = logger.level
        logger.addHandler(handler)
        logger.setLevel(logging.INFO)
        try:
            log_event(logger, "test_event", request_id="req-123", extra_field="value")
        finally:
            # Always detach, even if log_event raises, so the capture handler
            # and the level change do not outlive this test.
            logger.removeHandler(handler)
            logger.setLevel(previous_level)
            handler.close()

        output = stream.getvalue()
        assert "req-123" in output
        assert "test_event" in output

    def test_log_without_pii(self):
        """Logs should carry text_length, never the raw text content."""
        from veena3modal.shared.logging import get_logger, log_event, JSONFormatter

        stream = io.StringIO()
        handler = logging.StreamHandler(stream)
        handler.setFormatter(JSONFormatter())
        handler.setLevel(logging.INFO)

        logger = get_logger("test_no_pii")
        previous_level = logger.level
        logger.addHandler(handler)
        logger.setLevel(logging.INFO)
        try:
            # Log the text's length only, not the text itself.
            log_event(
                logger,
                "tts_request_received",
                request_id="req-456",
                text_length=100,
                speaker="lipakshi",
            )
        finally:
            # Detach unconditionally so later tests are unaffected.
            logger.removeHandler(handler)
            logger.setLevel(previous_level)
            handler.close()

        output = stream.getvalue()
        assert "text_length" in output
        assert "100" in output
        # NOTE(review): no text is passed to log_event above, so this only
        # guards against raw text leaking in from elsewhere; it does not
        # exercise any redaction inside log_event.
        assert "Hello world" not in output


class TestLifecycleEvents:
    """Checks the dicts produced by create_lifecycle_event for each request phase."""

    @staticmethod
    def _emit(event_type, **fields):
        """Build a lifecycle event through the module under test."""
        from veena3modal.shared.logging import create_lifecycle_event

        return create_lifecycle_event(event_type=event_type, **fields)

    def test_request_received_event(self):
        """request_received carries the request metadata plus a timestamp."""
        evt = self._emit(
            "request_received",
            request_id="req-001",
            text_length=50,
            speaker="lipakshi",
            stream=True,
            format="wav",
        )

        expected = {
            "event": "request_received",
            "request_id": "req-001",
            "text_length": 50,
            "speaker": "lipakshi",
        }
        for key, value in expected.items():
            assert evt[key] == value
        # Identity check: the flag must be the bool True, not merely truthy.
        assert evt["stream"] is True
        assert "timestamp" in evt

    def test_auth_validated_event(self):
        """auth_validated records the validation outcome."""
        evt = self._emit(
            "auth_validated",
            request_id="req-002",
            valid=True,
            api_key_prefix="vn3_",
        )

        assert evt["event"] == "auth_validated"
        assert evt["valid"] is True

    def test_first_audio_emitted_event(self):
        """first_audio_emitted records time-to-first-byte and chunk size."""
        evt = self._emit(
            "first_audio_emitted",
            request_id="req-003",
            ttfb_ms=250,
            chunk_size_bytes=4096,
        )

        expected = {
            "event": "first_audio_emitted",
            "ttfb_ms": 250,
            "chunk_size_bytes": 4096,
        }
        for key, value in expected.items():
            assert evt[key] == value

    def test_request_completed_event(self):
        """request_completed records status and timing metrics."""
        evt = self._emit(
            "request_completed",
            request_id="req-004",
            status_code=200,
            total_duration_ms=1500,
            audio_duration_seconds=3.5,
            rtf=0.43,
            chunks_sent=5,
        )

        expected = {
            "event": "request_completed",
            "status_code": 200,
            "total_duration_ms": 1500,
            "rtf": 0.43,
        }
        for key, value in expected.items():
            assert evt[key] == value

    def test_request_failed_event(self):
        """request_failed records the error code."""
        evt = self._emit(
            "request_failed",
            request_id="req-005",
            status_code=500,
            error_code="GENERATION_FAILED",
            error_message="Model inference timeout",
        )

        assert evt["event"] == "request_failed"
        assert evt["error_code"] == "GENERATION_FAILED"


class TestJSONFormatter:
    """Exercise JSONFormatter directly on hand-built LogRecords."""

    def test_json_formatter_output(self):
        """A plain INFO record renders as JSON with message, level and request_id."""
        from veena3modal.shared.logging import JSONFormatter

        rec = logging.LogRecord(
            name="test",
            level=logging.INFO,
            pathname="test.py",
            lineno=10,
            msg="Test message",
            args=(),
            exc_info=None,
        )
        # Extra attributes, set the way logging's ``extra=`` mechanism would.
        rec.request_id = "req-json-001"
        rec.custom_field = "value"

        # json.loads raises if the formatter did not emit valid JSON.
        payload = json.loads(JSONFormatter().format(rec))

        assert payload["message"] == "Test message"
        assert payload["level"] == "INFO"
        assert payload["request_id"] == "req-json-001"

    def test_json_formatter_handles_exceptions(self):
        """A record carrying exc_info renders the exception into an "exception" field."""
        from veena3modal.shared.logging import JSONFormatter

        try:
            raise ValueError("Test error")
        except ValueError as err:
            # The same (type, value, traceback) triple sys.exc_info() returns here.
            captured = (type(err), err, err.__traceback__)

        rec = logging.LogRecord(
            name="test",
            level=logging.ERROR,
            pathname="test.py",
            lineno=10,
            msg="Error occurred",
            args=(),
            exc_info=captured,
        )

        payload = json.loads(JSONFormatter().format(rec))

        assert "exception" in payload
        assert "ValueError" in payload["exception"]


class TestRequestContext:
    """Tests for set/get/clear of the per-request logging context."""

    def test_set_and_get_request_id(self):
        """Values passed to set_request_context are returned by get_request_context."""
        from veena3modal.shared.logging import set_request_context, get_request_context, clear_request_context

        # Start from a clean slate in case an earlier test left context behind.
        clear_request_context()
        try:
            set_request_context(request_id="ctx-001", user_id="user-123")

            ctx = get_request_context()
            assert ctx["request_id"] == "ctx-001"
            assert ctx["user_id"] == "user-123"
        finally:
            # Clear even when an assertion fails so the context cannot leak
            # into later tests.
            clear_request_context()

    def test_clear_context(self):
        """clear_request_context removes a previously set request_id."""
        from veena3modal.shared.logging import set_request_context, get_request_context, clear_request_context

        set_request_context(request_id="ctx-002")
        clear_request_context()

        ctx = get_request_context()
        assert ctx.get("request_id") is None

