"""
Prometheus metrics for Veena3 TTS service.

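Features:
- Request counters (received, completed, failed; labelled by speaker plus
  stream, format, status code, or error code as applicable)
- TTFB histogram
- RTF histogram
- Audio duration histogram
- Streaming chunk-count histogram
- Model load time gauge and model-loaded gauge
- All recorders are no-ops when prometheus_client is not installed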

Usage:
    from veena3modal.shared.metrics import record_request_received, record_ttfb
    
    record_request_received(speaker="lipakshi", stream=True, format="wav")
    record_ttfb(ttfb_seconds=0.25, speaker="lipakshi", stream=True)
"""

import re
from typing import Optional

# Try to import prometheus_client, fall back to no-op if not installed
try:
    from prometheus_client import (
        Counter, Histogram, Gauge, 
        CollectorRegistry, generate_latest, CONTENT_TYPE_LATEST
    )
    PROMETHEUS_AVAILABLE = True
except ImportError:
    PROMETHEUS_AVAILABLE = False


# Module-level registry (singleton).
# Stays None until get_metrics_registry() creates it; that function returns
# early without creating it when prometheus_client could not be imported.
_registry: Optional['CollectorRegistry'] = None

# Metrics instances (lazy initialized).
# Each starts as None and is assigned by _initialize_metrics(). The record_*
# helpers treat a None metric as "metrics disabled" and silently do nothing.
_requests_total: Optional['Counter'] = None  # labels: speaker, stream, format
_requests_completed: Optional['Counter'] = None  # labels: speaker, stream, status_code
_requests_failed: Optional['Counter'] = None  # labels: speaker, error_code, status_code
_ttfb_seconds: Optional['Histogram'] = None  # labels: speaker, stream
_rtf: Optional['Histogram'] = None  # labels: speaker
_audio_duration_seconds: Optional['Histogram'] = None  # labels: speaker
_chunks_sent: Optional['Histogram'] = None  # labels: speaker
_model_load_seconds: Optional['Gauge'] = None  # labels: model_version
_model_loaded: Optional['Gauge'] = None  # labels: model_version


def get_metrics_registry() -> Optional['CollectorRegistry']:
    """
    Return the module's metrics registry, building it on first use.

    Returns:
        The shared CollectorRegistry, or None when prometheus_client
        could not be imported.
    """
    global _registry

    if PROMETHEUS_AVAILABLE:
        if _registry is None:
            # First call: create the registry, then attach every metric to it.
            _registry = CollectorRegistry()
            _initialize_metrics()
        return _registry

    return None


def _initialize_metrics() -> None:
    """
    Create every metric object and attach it to the module registry.

    Does nothing when prometheus_client is unavailable or the registry has
    not been created yet. Metrics are created in a fixed order: request
    counters, performance histograms, then model gauges.
    """
    global _requests_total, _requests_completed, _requests_failed
    global _ttfb_seconds, _rtf, _audio_duration_seconds, _chunks_sent
    global _model_load_seconds, _model_loaded

    if not (PROMETHEUS_AVAILABLE and _registry is not None):
        return

    # Histogram bucket upper bounds, named once so the constructors stay short.
    ttfb_buckets = (0.1, 0.25, 0.5, 0.75, 1.0, 1.5, 2.0, 3.0, 5.0, 10.0)
    rtf_buckets = (0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1.0, 1.5, 2.0)
    audio_duration_buckets = (1, 2, 5, 10, 20, 30, 60, 120, 300)
    chunk_count_buckets = (1, 2, 5, 10, 20, 50, 100)

    # --- Request counters ---
    _requests_total = Counter(
        'veena3_tts_requests_total',
        'Total TTS requests received',
        labelnames=['speaker', 'stream', 'format'],
        registry=_registry,
    )
    _requests_completed = Counter(
        'veena3_tts_requests_completed_total',
        'Total TTS requests completed successfully',
        labelnames=['speaker', 'stream', 'status_code'],
        registry=_registry,
    )
    _requests_failed = Counter(
        'veena3_tts_requests_failed_total',
        'Total TTS requests failed',
        labelnames=['speaker', 'error_code', 'status_code'],
        registry=_registry,
    )

    # --- Performance histograms ---
    _ttfb_seconds = Histogram(
        'veena3_tts_ttfb_seconds',
        'Time to first byte in seconds',
        labelnames=['speaker', 'stream'],
        buckets=ttfb_buckets,
        registry=_registry,
    )
    _rtf = Histogram(
        'veena3_tts_rtf',
        'Real-time factor (generation_time / audio_duration)',
        labelnames=['speaker'],
        buckets=rtf_buckets,
        registry=_registry,
    )
    _audio_duration_seconds = Histogram(
        'veena3_tts_audio_duration_seconds',
        'Generated audio duration in seconds',
        labelnames=['speaker'],
        buckets=audio_duration_buckets,
        registry=_registry,
    )
    _chunks_sent = Histogram(
        'veena3_tts_chunks_sent',
        'Number of audio chunks sent in streaming mode',
        labelnames=['speaker'],
        buckets=chunk_count_buckets,
        registry=_registry,
    )

    # --- Model gauges ---
    _model_load_seconds = Gauge(
        'veena3_tts_model_load_seconds',
        'Time taken to load the model',
        labelnames=['model_version'],
        registry=_registry,
    )
    _model_loaded = Gauge(
        'veena3_tts_model_loaded',
        'Whether the model is loaded (1) or not (0)',
        labelnames=['model_version'],
        registry=_registry,
    )


def sanitize_label(value: str) -> str:
    """
    Normalize a value to the character set [a-zA-Z0-9_] for use as a label value.

    The value is converted with str(); every character other than an ASCII
    letter, digit, or underscore is replaced by an underscore, and a result
    that begins with a digit is prefixed with an underscore. An empty input
    yields an empty string.

    Note: Prometheus only requires the [a-zA-Z_][a-zA-Z0-9_]* pattern for
    label *names*; applying it to values is this module's own normalization.
    Distinct inputs can collapse to the same output (e.g. "a-b" and "a.b").
    """
    text = re.sub(r'[^a-zA-Z0-9_]', '_', str(value))
    # "".isdigit() is False, so slicing also covers the empty-string case.
    return '_' + text if text[:1].isdigit() else text


def record_request_received(
    speaker: str,
    stream: bool,
    format: str,
) -> None:
    """
    Count one incoming TTS request.

    Args:
        speaker: Speaker name; passed through sanitize_label().
        stream: Streaming flag; stored as its lowercased string form
            (e.g. "true").
        format: Requested audio format; passed through sanitize_label().
    """
    if _requests_total is None:
        # Lazily build the registry and its metrics on first use.
        get_metrics_registry()
        if _requests_total is None:
            # Still unset (e.g. prometheus_client is missing): nothing to record.
            return

    label_values = {
        'speaker': sanitize_label(speaker),
        'stream': str(stream).lower(),
        'format': sanitize_label(format),
    }
    _requests_total.labels(**label_values).inc()


def record_request_completed(
    status_code: int,
    duration_seconds: float,
    speaker: str,
    stream: bool,
) -> None:
    """
    Count one successfully completed TTS request.

    Args:
        status_code: Response status code; stored as a string label.
        duration_seconds: Accepted for API compatibility; this function does
            not currently record it in any metric.
        speaker: Speaker name; passed through sanitize_label().
        stream: Streaming flag; stored as its lowercased string form
            (e.g. "true").
    """
    if _requests_completed is None:
        # Lazily build the registry and its metrics on first use.
        get_metrics_registry()
        if _requests_completed is None:
            # Still unset (e.g. prometheus_client is missing): nothing to record.
            return

    label_values = {
        'speaker': sanitize_label(speaker),
        'stream': str(stream).lower(),
        'status_code': str(status_code),
    }
    _requests_completed.labels(**label_values).inc()


def record_request_failed(
    status_code: int,
    error_code: str,
    speaker: str,
) -> None:
    """
    Count one failed TTS request.

    Args:
        status_code: Response status code; stored as a string label.
        error_code: Error identifier; passed through sanitize_label().
        speaker: Speaker name; passed through sanitize_label().
    """
    if _requests_failed is None:
        # Lazily build the registry and its metrics on first use.
        get_metrics_registry()
        if _requests_failed is None:
            # Still unset (e.g. prometheus_client is missing): nothing to record.
            return

    label_values = {
        'speaker': sanitize_label(speaker),
        'error_code': sanitize_label(error_code),
        'status_code': str(status_code),
    }
    _requests_failed.labels(**label_values).inc()


def record_ttfb(
    ttfb_seconds: float,
    speaker: str,
    stream: bool,
) -> None:
    """
    Add one time-to-first-byte observation to the TTFB histogram.

    Args:
        ttfb_seconds: Observed time to first byte, in seconds.
        speaker: Speaker name; passed through sanitize_label().
        stream: Streaming flag; stored as its lowercased string form
            (e.g. "true").
    """
    if _ttfb_seconds is None:
        # Lazily build the registry and its metrics on first use.
        get_metrics_registry()
        if _ttfb_seconds is None:
            # Still unset (e.g. prometheus_client is missing): nothing to record.
            return

    label_values = {
        'speaker': sanitize_label(speaker),
        'stream': str(stream).lower(),
    }
    _ttfb_seconds.labels(**label_values).observe(ttfb_seconds)


def record_rtf(
    rtf: float,
    speaker: str,
) -> None:
    """
    Add one real-time-factor observation to the RTF histogram.

    Args:
        rtf: Real-time factor value to observe.
        speaker: Speaker name; passed through sanitize_label().
    """
    if _rtf is None:
        # Lazily build the registry and its metrics on first use.
        get_metrics_registry()
        if _rtf is None:
            # Still unset (e.g. prometheus_client is missing): nothing to record.
            return

    speaker_label = sanitize_label(speaker)
    _rtf.labels(speaker=speaker_label).observe(rtf)


def record_audio_duration(
    duration_seconds: float,
    speaker: str,
) -> None:
    """
    Add one generated-audio duration to the audio duration histogram.

    Args:
        duration_seconds: Duration of the generated audio, in seconds.
        speaker: Speaker name; passed through sanitize_label().
    """
    if _audio_duration_seconds is None:
        # Lazily build the registry and its metrics on first use.
        get_metrics_registry()
        if _audio_duration_seconds is None:
            # Still unset (e.g. prometheus_client is missing): nothing to record.
            return

    speaker_label = sanitize_label(speaker)
    _audio_duration_seconds.labels(speaker=speaker_label).observe(duration_seconds)


def record_chunks_sent(
    chunks: int,
    speaker: str,
) -> None:
    """
    Add one chunk-count observation to the chunks-sent histogram.

    Args:
        chunks: Number of audio chunks sent for a streaming response.
        speaker: Speaker name; passed through sanitize_label().
    """
    if _chunks_sent is None:
        # Lazily build the registry and its metrics on first use.
        get_metrics_registry()
        if _chunks_sent is None:
            # Still unset (e.g. prometheus_client is missing): nothing to record.
            return

    speaker_label = sanitize_label(speaker)
    _chunks_sent.labels(speaker=speaker_label).observe(chunks)


def record_model_load_time(
    duration_seconds: float,
    model_version: str,
) -> None:
    """
    Set the model load time gauge for a model version.

    Args:
        duration_seconds: Time taken to load the model, in seconds.
        model_version: Model version; passed through sanitize_label(), so
            characters such as "." become "_".
    """
    if _model_load_seconds is None:
        # Lazily build the registry and its metrics on first use.
        get_metrics_registry()
        if _model_load_seconds is None:
            # Still unset (e.g. prometheus_client is missing): nothing to record.
            return

    version_label = sanitize_label(model_version)
    _model_load_seconds.labels(model_version=version_label).set(duration_seconds)


def set_model_loaded(
    loaded: bool,
    model_version: str,
) -> None:
    """
    Set the model-loaded gauge for a model version.

    Args:
        loaded: Truthy sets the gauge to 1, falsy sets it to 0.
        model_version: Model version; passed through sanitize_label(), so
            characters such as "." become "_".
    """
    if _model_loaded is None:
        # Lazily build the registry and its metrics on first use.
        get_metrics_registry()
        if _model_loaded is None:
            # Still unset (e.g. prometheus_client is missing): nothing to record.
            return

    if loaded:
        gauge_value = 1
    else:
        gauge_value = 0
    version_label = sanitize_label(model_version)
    _model_loaded.labels(model_version=version_label).set(gauge_value)


def get_metrics_text() -> str:
    """
    Render the registry's current metrics in Prometheus text format.

    Returns:
        The UTF-8 decoded exposition text, or a single comment line when
        prometheus_client is unavailable or no registry exists.
    """
    registry = get_metrics_registry()

    if PROMETHEUS_AVAILABLE and registry is not None:
        payload = generate_latest(registry)
        return payload.decode('utf-8')

    return "# Prometheus client not available\n"


def get_content_type() -> str:
    """
    Return the Content-Type to send with the metrics response.

    Returns:
        prometheus_client's CONTENT_TYPE_LATEST when the library is
        available, otherwise a plain-text UTF-8 content type.
    """
    return CONTENT_TYPE_LATEST if PROMETHEUS_AVAILABLE else "text/plain; charset=utf-8"

