"""
Structured logging for Veena3 TTS service.

Features:
- JSON-formatted logs for structured queries
- Request ID tracking via context
- Lifecycle event helpers (request_received, first_audio_emitted, etc.)
- NO PII in logs (text content stored in Supabase, not logs)

Usage:
    from veena3modal.shared.logging import get_logger, log_event, set_request_context
    
    logger = get_logger(__name__)
    set_request_context(request_id="req-123")
    log_event(logger, "request_received", text_length=100, speaker="lipakshi")
"""

import logging
import json
import sys
import traceback
from datetime import datetime, timezone
from typing import Any, Dict, Optional
from contextvars import ContextVar

# Context variable for request-scoped data: the dict stored by
# set_request_context (request_id plus any extra fields). Values are scoped to
# the current contextvars context rather than being process-global.
# NOTE: the ``{}`` default is ONE dict object, returned by ``.get()`` in every
# context where nothing has been set yet; it must never be mutated in place,
# or the mutation becomes visible to all such contexts.
_request_context: ContextVar[Dict[str, Any]] = ContextVar('request_context', default={})


def set_request_context(request_id: Optional[str] = None, **kwargs) -> None:
    """
    Bind request-scoped fields to the current execution context.

    A brand-new dict is stored on every call, so this replaces (rather than
    merges into) any previously set context. ``request_id`` is always
    present in the stored dict, even when it is None.

    Args:
        request_id: Unique request identifier
        **kwargs: Additional context fields (user_id, etc.)
    """
    context: Dict[str, Any] = dict(request_id=request_id)
    context.update(kwargs)
    _request_context.set(context)


def get_request_context() -> Dict[str, Any]:
    """
    Return a snapshot of the current request context.

    A shallow copy is returned so callers cannot mutate the stored context in
    place. Without the copy, mutating the result while no context has been set
    would modify the ContextVar's single shared ``{}`` default, leaking fields
    into every other request/task that has not called set_request_context.
    To change the context, call set_request_context / clear_request_context.

    Returns:
        A new dict with the current context fields (empty if none are set).
    """
    return dict(_request_context.get())


def clear_request_context() -> None:
    """
    Reset the request context to an empty mapping.

    Intended to be called once a request finishes, so later log lines in the
    same context no longer carry that request's request_id.
    """
    empty_context: Dict[str, Any] = dict()
    _request_context.set(empty_context)


class JSONFormatter(logging.Formatter):
    """
    JSON log formatter for structured logging.

    Each record is rendered as one JSON object on a single line, e.g.:
    {
        "timestamp": "2024-01-01T12:00:00.123456+00:00",
        "level": "INFO",
        "logger": "veena3modal.api",
        "message": "Request received",
        "request_id": "req-123",
        ...extra fields...
    }

    ``timestamp`` is the record's creation time (``record.created``) in UTC,
    rendered with ``datetime.isoformat()``. ``exception`` / ``stack_info``
    keys are added only when the record carries that information.
    """

    # Standard LogRecord attributes; anything else found in record.__dict__
    # is treated as a caller-supplied extra field and copied into the output.
    RESERVED_ATTRS = {
        'args', 'asctime', 'created', 'exc_info', 'exc_text', 'filename',
        'funcName', 'levelname', 'levelno', 'lineno', 'module', 'msecs',
        'message', 'msg', 'name', 'pathname', 'process', 'processName',
        'relativeCreated', 'stack_info', 'thread', 'threadName',
        'taskName',  # Python 3.12+
    }

    def format(self, record: logging.LogRecord) -> str:
        """Format ``record`` as a single-line JSON string."""
        log_entry: Dict[str, Any] = {
            # Use the time the record was created, not "now": formatting can
            # happen later than the logging call (e.g. when records are queued),
            # and the event time is what downstream queries care about.
            "timestamp": datetime.fromtimestamp(record.created, tz=timezone.utc).isoformat(),
            "level": record.levelname,
            "logger": record.name,
            "message": record.getMessage(),
        }

        # Request-scoped id from the context, if one is set.
        ctx = get_request_context()
        if ctx.get("request_id"):
            log_entry["request_id"] = ctx["request_id"]

        # Extras are applied after the base keys, so an explicit extra
        # (e.g. request_id passed to log_event) overrides the context value.
        for key, value in record.__dict__.items():
            if key in self.RESERVED_ATTRS or key.startswith('_'):
                continue
            try:
                # Probe serializability per field so one bad value (or a dict
                # with non-string keys, which default= cannot fix) degrades to
                # its str() instead of breaking the whole entry.
                json.dumps(value)
            except (TypeError, ValueError):
                value = str(value)
            log_entry[key] = value

        if record.exc_info:
            # formatException is the Formatter hook; subclasses may override it.
            log_entry["exception"] = self.formatException(record.exc_info)
        if record.stack_info:
            # Previously dropped: logger.info(..., stack_info=True) now surfaces.
            log_entry["stack_info"] = self.formatStack(record.stack_info)

        return json.dumps(log_entry, default=str)


def get_logger(name: str) -> logging.Logger:
    """
    Return the named logger, attaching a stdout JSON handler on first use.

    If the logger already has any handler it is returned untouched (no new
    handler, level unchanged), which keeps repeated calls from duplicating
    output. Otherwise a ``StreamHandler(sys.stdout)`` using ``JSONFormatter``
    is attached and the level is set to INFO.

    NOTE(review): ``propagate`` is left at its default, so if an ancestor
    logger (e.g. root) also has a handler, records are emitted by both.

    Args:
        name: Logger name (typically __name__)

    Returns:
        The ``logging.Logger`` registered under ``name``.
    """
    target = logging.getLogger(name)
    if target.handlers:
        return target

    stdout_handler = logging.StreamHandler(sys.stdout)
    stdout_handler.setFormatter(JSONFormatter())
    target.addHandler(stdout_handler)
    target.setLevel(logging.INFO)
    return target


# Attribute names that Logger.makeRecord refuses to accept via ``extra`` (it
# raises KeyError "Attempt to overwrite ... in LogRecord"). Derived from a
# blank record so the set tracks the running Python version / record factory.
_LOG_RECORD_ATTRS = frozenset(logging.makeLogRecord({}).__dict__) | {"message", "asctime"}


def log_event(
    logger: logging.Logger,
    event_type: str,
    level: int = logging.INFO,
    request_id: Optional[str] = None,
    **kwargs
) -> None:
    """
    Log a structured event.

    The record's message is ``event_type``; ``event`` plus every keyword
    field are attached as LogRecord extras. A field whose name collides with
    a built-in LogRecord attribute (e.g. ``filename``, ``module``, ``name``,
    ``lineno``) is attached as ``extra_<name>`` instead, because the logging
    module raises KeyError for such extras, and a logging helper must not
    crash the request it is describing.

    Args:
        logger: Logger instance
        event_type: Event name (e.g., "request_received", "first_audio_emitted")
        level: Log level (default: INFO)
        request_id: Override request_id from context (ignored when falsy)
        **kwargs: Event-specific fields
    """
    extra: Dict[str, Any] = {}
    for key, value in {"event": event_type, **kwargs}.items():
        extra[f"extra_{key}" if key in _LOG_RECORD_ATTRS else key] = value

    if request_id:
        extra["request_id"] = request_id

    logger.log(level, event_type, extra=extra)


def create_lifecycle_event(
    event_type: str,
    request_id: str,
    **kwargs
) -> Dict[str, Any]:
    """
    Build a lifecycle event dictionary stamped with the current UTC time.

    Known lifecycle events:
    - request_received: TTS request started
    - auth_validated: API key validated
    - generation_started: Model inference began
    - first_audio_emitted: First audio chunk sent (TTFB)
    - request_completed: Request finished successfully
    - request_failed: Request failed with error

    Keys are ordered event, request_id, timestamp, then the extra fields;
    an extra field with the same name as a core key replaces its value.

    Args:
        event_type: Event name
        request_id: Request identifier
        **kwargs: Event-specific fields

    Returns:
        A new dict ready for logging.
    """
    payload: Dict[str, Any] = dict(
        event=event_type,
        request_id=request_id,
        timestamp=datetime.now(timezone.utc).isoformat(),
    )
    payload.update(kwargs)
    return payload


# Pre-configured logger for the TTS service, created once at import time via
# get_logger (which attaches a stdout JSON handler unless the logger already
# has one). Shared by the log_request_* / log_first_audio_emitted helpers.
tts_logger = get_logger("veena3modal.tts")


def log_request_received(
    request_id: str,
    text_length: int,
    speaker: str,
    stream: bool,
    format: str,
    **kwargs
) -> None:
    """
    Emit the ``request_received`` lifecycle event on the TTS logger.

    Only the text's length is logged, never the text itself.

    Args:
        request_id: Request identifier attached to the event.
        text_length: Length of the input text.
        speaker: Speaker name as supplied by the caller.
        stream: Whether the request is streaming.
        format: Output format string as supplied by the caller.
        **kwargs: Extra event fields, forwarded to log_event.
    """
    event_fields = dict(
        text_length=text_length,
        speaker=speaker,
        stream=stream,
        format=format,
    )
    event_fields.update(kwargs)
    log_event(tts_logger, "request_received", request_id=request_id, **event_fields)


def log_first_audio_emitted(
    request_id: str,
    ttfb_ms: int,
    chunk_size_bytes: int,
    **kwargs
) -> None:
    """
    Emit the ``first_audio_emitted`` lifecycle event (the TTFB marker).

    Args:
        request_id: Request identifier attached to the event.
        ttfb_ms: Time to first byte, in milliseconds.
        chunk_size_bytes: Size of the first audio chunk, in bytes.
        **kwargs: Extra event fields, forwarded to log_event.
    """
    metrics = {"ttfb_ms": ttfb_ms, "chunk_size_bytes": chunk_size_bytes, **kwargs}
    log_event(tts_logger, "first_audio_emitted", request_id=request_id, **metrics)


def log_request_completed(
    request_id: str,
    status_code: int,
    total_duration_ms: int,
    audio_duration_seconds: float,
    rtf: float,
    chunks_sent: int = 1,
    **kwargs
) -> None:
    """
    Emit the ``request_completed`` lifecycle event on the TTS logger.

    Args:
        request_id: Request identifier attached to the event.
        status_code: Status code reported for the request.
        total_duration_ms: Total request duration, in milliseconds.
        audio_duration_seconds: Duration of the produced audio, in seconds.
        rtf: Real-time factor value as computed by the caller.
        chunks_sent: Number of chunks sent (defaults to 1).
        **kwargs: Extra event fields, forwarded to log_event.
    """
    summary = dict(
        status_code=status_code,
        total_duration_ms=total_duration_ms,
        audio_duration_seconds=audio_duration_seconds,
        rtf=rtf,
        chunks_sent=chunks_sent,
    )
    summary.update(kwargs)
    log_event(tts_logger, "request_completed", request_id=request_id, **summary)


def log_request_failed(
    request_id: str,
    status_code: int,
    error_code: str,
    error_message: str,
    **kwargs
) -> None:
    """
    Emit the ``request_failed`` lifecycle event at ERROR level.

    Args:
        request_id: Request identifier attached to the event.
        status_code: Status code reported for the failed request.
        error_code: Machine-readable error code.
        error_message: Human-readable error description.
        **kwargs: Extra event fields, forwarded to log_event.
    """
    failure = {
        "status_code": status_code,
        "error_code": error_code,
        "error_message": error_message,
        **kwargs,
    }
    log_event(
        tts_logger,
        "request_failed",
        level=logging.ERROR,
        request_id=request_id,
        **failure,
    )

