"""
Unit tests for Prometheus metrics module.

Tests metric creation, recording, and exposure.
"""

import pytest
from unittest.mock import patch, MagicMock


class TestMetricsRegistry:
    """Exercise the shared metrics registry accessor."""

    def test_get_metrics_registry(self):
        """The accessor hands back a usable (non-None) registry object."""
        from veena3modal.shared.metrics import get_metrics_registry

        assert get_metrics_registry() is not None

    def test_registry_is_singleton(self):
        """Repeated calls must yield the very same registry instance."""
        from veena3modal.shared.metrics import get_metrics_registry

        first, second = get_metrics_registry(), get_metrics_registry()
        assert first is second


class TestRequestMetrics:
    """Smoke tests for the request-lifecycle metric recorders.

    Each test only checks that the recorder accepts the documented keyword
    arguments without raising.
    """

    def test_record_request_received(self):
        """An incoming request can be recorded."""
        from veena3modal.shared.metrics import record_request_received

        labels = {"speaker": "lipakshi", "stream": True, "format": "wav"}
        record_request_received(**labels)

    def test_record_request_completed(self):
        """A completed request can be recorded together with its duration."""
        from veena3modal.shared.metrics import record_request_completed

        outcome = {
            "status_code": 200,
            "duration_seconds": 1.5,
            "speaker": "lipakshi",
            "stream": True,
        }
        record_request_completed(**outcome)

    def test_record_request_failed(self):
        """A failed request can be recorded together with its error code."""
        from veena3modal.shared.metrics import record_request_failed

        failure = {
            "status_code": 500,
            "error_code": "GENERATION_FAILED",
            "speaker": "lipakshi",
        }
        record_request_failed(**failure)


class TestTTSMetrics:
    """Smoke tests for the TTS-specific metric recorders.

    These only verify that each recorder accepts its keyword arguments
    without raising.
    """

    def test_record_ttfb(self):
        """Time-to-first-byte can be recorded."""
        from veena3modal.shared.metrics import record_ttfb

        ttfb = 0.25
        record_ttfb(ttfb_seconds=ttfb, speaker="lipakshi", stream=True)

    def test_record_rtf(self):
        """Real-time factor can be recorded."""
        from veena3modal.shared.metrics import record_rtf

        factor = 0.43
        record_rtf(rtf=factor, speaker="lipakshi")

    def test_record_audio_duration(self):
        """The duration of generated audio can be recorded."""
        from veena3modal.shared.metrics import record_audio_duration

        seconds = 3.5
        record_audio_duration(duration_seconds=seconds, speaker="lipakshi")

    def test_record_chunks_sent(self):
        """The number of streamed chunks can be recorded."""
        from veena3modal.shared.metrics import record_chunks_sent

        chunk_count = 5
        record_chunks_sent(chunks=chunk_count, speaker="lipakshi")


class TestModelMetrics:
    """Smoke tests for the model-lifecycle metric helpers."""

    def test_record_model_load_time(self):
        """The time taken to load a model can be recorded."""
        from veena3modal.shared.metrics import record_model_load_time

        version = "spark_tts_4speaker"
        record_model_load_time(duration_seconds=15.5, model_version=version)

    def test_set_model_loaded(self):
        """The model-loaded gauge can be set."""
        from veena3modal.shared.metrics import set_model_loaded

        version = "spark_tts_4speaker"
        set_model_loaded(loaded=True, model_version=version)


class TestMetricsExport:
    """Test metrics export functionality."""

    def test_get_metrics_text(self):
        """Get Prometheus text format metrics; the exposition must be non-empty."""
        from veena3modal.shared.metrics import get_metrics_text, record_request_received

        # Record one sample first so the exposition has content regardless of
        # which other tests (if any) ran earlier in the session.
        record_request_received(speaker="lipakshi", stream=True, format="wav")

        text = get_metrics_text()

        # Should be a non-empty string in Prometheus text format. The original
        # test only checked the type, so an empty exposition slipped through.
        assert isinstance(text, str)
        assert text.strip(), "metrics exposition should not be empty"

    def test_metrics_endpoint_data(self):
        """Verify metrics can be serialized for HTTP response."""
        from veena3modal.shared.metrics import get_metrics_text

        text = get_metrics_text()

        # Check the type explicitly; otherwise a bytes return would fail below
        # with a confusing AttributeError on .encode().
        assert isinstance(text, str)

        # The HTTP layer sends bytes. str.encode raises UnicodeEncodeError on
        # un-encodable content (e.g. lone surrogates), and the round trip
        # confirms nothing is lost.
        payload = text.encode("utf-8")
        assert isinstance(payload, bytes)
        assert payload.decode("utf-8") == text


class TestMetricsLabels:
    """Checks for Prometheus label-value sanitisation."""

    def test_speaker_label_sanitized(self):
        """Speaker names are mapped onto Prometheus-safe label values."""
        from veena3modal.shared.metrics import sanitize_label

        cases = (
            # Already-safe speaker name passes through unchanged.
            ("lipakshi", "lipakshi"),
            # Hyphenated names are not expected in practice, but must still
            # be handled gracefully.
            ("speaker-1", "speaker_1"),
        )
        for raw, expected in cases:
            assert sanitize_label(raw) == expected

    def test_format_label(self):
        """Audio-format labels are already valid and stay unchanged."""
        from veena3modal.shared.metrics import sanitize_label

        for fmt in ("wav", "opus"):
            assert sanitize_label(fmt) == fmt

