"""
Pydantic schemas for TTS API requests/responses.

Ported from veena3srv/apps/api/serializers.py with full validation parity.
No Django dependencies - pure Pydantic for Modal deployment.
"""

from __future__ import annotations

import re
from enum import Enum
from typing import Any, Dict, List, Optional, Union
from uuid import UUID

from pydantic import BaseModel, ConfigDict, Field, field_validator, model_validator


# === Constants (mirrored from veena3srv/apps/inference/constants.py) ===
# These are duplicated here to avoid Django import chain; must stay in sync.

# Hard upper bound enforced by TTSGenerateRequest.text (Field max_length).
MAX_TEXT_LENGTH = 50000

# Spark TTS Speaker System - 12 Speakers
# Model: BayAreaBoys/spark_tts_4speaker (HuggingFace) has 12 speaker tokens
# Maps internal speaker name -> speaker token index in the model.
# NOTE(review): key casing is intentionally mixed ("Nandini", "Nilay", "Aarvi",
# ...); lookups elsewhere in this module are case-sensitive, so do not
# normalize these keys.
SPEAKER_MAP = {
    "lipakshi": 0, "vardan": 1, "reet": 2, "Nandini": 3,
    "krishna": 4, "anika": 5, "adarsh": 6, "Nilay": 7,
    "Aarvi": 8, "Asha": 9, "Bittu": 10, "Mira": 11,
}

# Friendly speaker name mappings (user-facing → internal)
# The last four entries map to themselves: those speakers are exposed to
# users under their internal names.
FRIENDLY_SPEAKER_MAP = {
    "Mitra": "lipakshi", "Aaranya": "reet", "Taru": "Nandini",
    "Neer": "Nilay", "Dhruva": "vardan", "Ira": "anika",
    "Veda": "adarsh", "Aria": "krishna",
    "Aarvi": "Aarvi", "Asha": "Asha", "Bittu": "Bittu", "Mira": "Mira",
}

# All internal names, and the full set of names accepted from API callers
# (internal + friendly). Used by resolve_speaker_name() and request validation.
# NOTE(review): names present in both maps (Aarvi/Asha/Bittu/Mira) appear twice
# in ALL_SPEAKER_NAMES; it is only used for membership tests and error messages.
INDIC_SPEAKERS = list(SPEAKER_MAP.keys())
ALL_SPEAKER_NAMES = INDIC_SPEAKERS + list(FRIENDLY_SPEAKER_MAP.keys())

# Emotion tags (Spark TTS bracket format) — the canonical tag set that
# normalize_emotion_tags() converges onto.
INDIC_EMOTION_TAGS = [
    "[angry]", "[curious]", "[excited]", "[giggle]", "[laughs harder]",
    "[laughs]", "[screams]", "[sighs]", "[sings]", "[whispers]"
]

# Legacy emotion tag mapping (<angle> → [bracket]), kept for backward
# compatibility with the old tag syntax.
LEGACY_EMOTION_MAP = {
    "<angry>": "[angry]", "<curious>": "[curious]", "<excited>": "[excited]",
    "<giggle>": "[giggle]", "<laugh_harder>": "[laughs harder]",
    "<laugh>": "[laughs]", "<scream>": "[screams]", "<sigh>": "[sighs]",
    "<sing>": "[sings]", "<whisper>": "[whispers]"
}


# === Helper Functions ===

def resolve_speaker_name(name: str) -> str:
    """
    Resolve a friendly speaker name to its internal name.

    Friendly aliases (e.g. "Mitra") are translated via FRIENDLY_SPEAKER_MAP;
    names that are already internal (e.g. "lipakshi") pass through unchanged.

    Examples:
        resolve_speaker_name("Mitra") -> "lipakshi"
        resolve_speaker_name("lipakshi") -> "lipakshi"

    Raises:
        ValueError: if name is not a valid speaker name
    """
    # Friendly aliases take precedence (EAFP: single dict lookup).
    try:
        return FRIENDLY_SPEAKER_MAP[name]
    except KeyError:
        pass
    if name in INDIC_SPEAKERS:
        return name
    raise ValueError(
        f"Invalid speaker name: {name}. Valid names: {', '.join(ALL_SPEAKER_NAMES)}"
    )


# Emotion normalization mapping (various formats → [emotion])
# Keys are bare emotion words (no brackets/angles, lowercase); values are the
# canonical Spark TTS bracket tags. Seeded from LEGACY_EMOTION_MAP (keys with
# '<>' stripped), then extended with common spelling/tense variants.
# Private (leading underscore): consumed by normalize_emotion_tags().
_EMOTION_TO_TAG = {
    **{k.strip('<>'): v for k, v in LEGACY_EMOTION_MAP.items()},
    'laughing': '[laughs]', 'laugh': '[laughs]', 'laughs harder': '[laughs harder]',
    'laugh harder': '[laughs harder]', 'laugh_harder': '[laughs harder]',
    'sighs': '[sighs]', 'sigh': '[sighs]', 'giggles': '[giggle]', 'giggle': '[giggle]',
    'angry': '[angry]', 'excited': '[excited]', 'whispers': '[whispers]',
    'whisper': '[whispers]', 'screaming': '[screams]', 'scream': '[screams]',
    'screams': '[screams]', 'singing': '[sings]', 'sing': '[sings]',
    'sings': '[sings]', 'curious': '[curious]',
}


def normalize_emotion_tags(text: str) -> str:
    """
    Normalize emotion tags to Spark TTS bracket format.

    Two passes over the text:
    1. Legacy angle tags are converted to bracket form: <laugh> → [laughs].
       Matching is case-insensitive (<Laugh>, <LAUGH> also work); the
       original pattern only matched lowercase, which made its .lower()
       call dead code and silently skipped mixed-case tags.
    2. Bracket variants are canonicalized: [laughing] → [laughs],
       [giggles] → [giggle], with a singular/plural fallback.

    Unknown angle tags are converted to bracket form anyway; unknown
    bracket tags are kept (lowercased and stripped).

    Args:
        text: Input text possibly containing emotion tags.

    Returns:
        Text with all recognized emotion tags in canonical [bracket] form.
    """
    # --- Pass 1: convert legacy <emotion> tags to [emotion] ---
    angle_pattern = r'<([A-Za-z_]+)>'

    def replace_angle(match: re.Match) -> str:
        emotion = match.group(1).lower()  # pattern has no whitespace; strip unnecessary
        if emotion in _EMOTION_TO_TAG:
            return _EMOTION_TO_TAG[emotion]
        # Defensive fallback: _EMOTION_TO_TAG is seeded from LEGACY_EMOTION_MAP,
        # so this only fires if the two maps drift apart.
        old_tag = f'<{emotion}>'
        if old_tag in LEGACY_EMOTION_MAP:
            return LEGACY_EMOTION_MAP[old_tag]
        return f'[{emotion}]'  # Unknown, convert to bracket anyway

    text = re.sub(angle_pattern, replace_angle, text)

    # --- Pass 2: canonicalize [emotion] variants (also sees pass-1 output) ---
    bracket_pattern = r'\[([^\]]+)\]'

    def replace_bracket(match: re.Match) -> str:
        emotion = match.group(1).lower().strip()
        bracket_tag = f'[{emotion}]'
        if bracket_tag in INDIC_EMOTION_TAGS:
            return bracket_tag  # already canonical
        if emotion in _EMOTION_TO_TAG:
            return _EMOTION_TO_TAG[emotion]
        # Try singular/plural variants before giving up
        if emotion.endswith('s') and emotion[:-1] in _EMOTION_TO_TAG:
            return _EMOTION_TO_TAG[emotion[:-1]]
        if (emotion + 's') in _EMOTION_TO_TAG:
            return _EMOTION_TO_TAG[emotion + 's']
        return bracket_tag  # Unknown tag: keep as-is (lowercased/stripped)

    return re.sub(bracket_pattern, replace_bracket, text)


# === Enums ===

class AudioFormat(str, Enum):
    """Supported audio output formats.

    Subclasses str so members compare equal to their plain string values
    (e.g. AudioFormat.MULAW == "mulaw"), which TTSGenerateRequest relies on
    after use_enum_values stores the raw string on the model.
    """
    WAV = "wav"
    OPUS = "opus"
    MP3 = "mp3"
    # mu-law: TTSGenerateRequest forces sample_rate to 8000 for this format.
    MULAW = "mulaw"
    FLAC = "flac"


class OutputSampleRate(str, Enum):
    """Output sample rate options.

    "48khz" selects the super-resolution output path (see the `output`
    field on TTSGenerateRequest).
    """
    SR_16KHZ = "16khz"
    SR_48KHZ = "48khz"


# === Request Schema ===

class TTSGenerateRequest(BaseModel):
    """
    Request schema for /v1/tts/generate.
    
    Mirrors validation rules from veena3srv/apps/api/serializers.py.

    Validation order (Pydantic v2): per-field validators run first
    (validate_text, validate_description, validate_sample_rate), then
    cross_field_validation (mode='after') enforces speaker requirements,
    resolves friendly speaker names, and applies the mu-law sample-rate
    override.
    """
    
    # Required
    text: str = Field(
        ...,
        min_length=1,
        max_length=MAX_TEXT_LENGTH,
        description="Text to synthesize (max 50,000 characters)"
    )
    
    # Speaker (required for indic_speakers model)
    speaker: Optional[str] = Field(
        None,
        max_length=50,
        description="Speaker name (internal or friendly)"
    )
    
    # Voice design (not currently supported)
    # NOTE: cross_field_validation rejects both of these fields outright
    # (voiceDesign model is not supported yet); they exist for API parity.
    description: Optional[str] = Field(
        None,
        max_length=2000,
        description="Voice description (for voiceDesign model, NOT currently supported)"
    )
    voice_id: Optional[UUID] = Field(
        None,
        description="UUID of pre-created voice profile (for voiceDesign model)"
    )
    
    # Generation options
    model: Optional[str] = Field(
        None,
        max_length=50,
        description="Model to use (optional). Default: indic_speakers"
    )
    # Bounded to a signed 32-bit int so it survives downstream RNG seeding.
    seed: Optional[int] = Field(
        None,
        ge=0,
        le=2**31 - 1,
        description="Random seed for reproducibility (0-2147483647)"
    )
    stream: bool = Field(
        False,
        description="Stream audio as it's generated"
    )
    format: AudioFormat = Field(
        AudioFormat.WAV,
        description="Audio output format"
    )
    # Restricted to a whitelist by validate_sample_rate; may be coerced to
    # 8000 by cross_field_validation when format is mulaw.
    sample_rate: int = Field(
        16000,
        description="Sample rate in Hz (default: 16000)"
    )
    
    # Advanced parameters (sampling knobs passed through to generation)
    temperature: float = Field(
        0.4,
        ge=0.0,
        le=2.0,
        description="Sampling temperature (0.0-2.0)"
    )
    top_k: int = Field(
        50,
        ge=1,
        le=100,
        description="Top-k sampling (1-100)"
    )
    top_p: float = Field(
        1.0,
        ge=0.0,
        le=1.0,
        description="Nucleus sampling (0.0-1.0)"
    )
    max_tokens: int = Field(
        4096,
        ge=128,
        le=4096,
        description="Maximum BiCodec tokens to generate"
    )
    repetition_penalty: float = Field(
        1.05,
        ge=1.0,
        le=2.0,
        description="Repetition penalty (1.0-2.0)"
    )
    
    # Preprocessing toggles
    normalize: bool = Field(
        True,
        description="Apply text normalization before TTS"
    )
    normalize_verbose: bool = Field(
        False,
        description="Return normalized text in X-Normalized-Text header"
    )
    chunking: bool = Field(
        True,
        description="Enable intelligent text chunking for long inputs"
    )
    
    # Output options
    output: OutputSampleRate = Field(
        OutputSampleRate.SR_16KHZ,
        description="Output sample rate: '16khz' or '48khz' (super-resolution)"
    )
    
    # Internal fields (populated during validation)
    # Annotated leading-underscore attributes are treated by Pydantic v2 as
    # private attributes, not model fields: excluded from (de)serialization,
    # assignable from validators/methods.
    _original_text: Optional[str] = None
    _resolved_speaker: Optional[str] = None
    
    model_config = ConfigDict(use_enum_values=True)  # Serialize enums as strings
    
    @field_validator('text')
    @classmethod
    def validate_text(cls, v: str) -> str:
        """Validate text field: no empty, no control chars.

        Returns the ORIGINAL (unstripped) value; stripping is only used for
        the emptiness check.

        Raises:
            ValueError: if text is whitespace-only or contains control chars.
        """
        stripped = v.strip()
        if not stripped:
            raise ValueError("Text cannot be empty or whitespace only")
        
        # Check for control characters (except newlines/tabs)
        control_chars = [c for c in v if ord(c) < 32 and c not in '\n\t\r']
        if control_chars:
            raise ValueError("Text contains invalid control characters")
        
        return v
    
    @field_validator('description')
    @classmethod
    def validate_description(cls, v: Optional[str]) -> Optional[str]:
        """Validate description field (same control-char rule as text).

        Raises:
            ValueError: if description contains control characters.
        """
        if not v:
            return v
        control_chars = [c for c in v if ord(c) < 32 and c not in '\n\t\r']
        if control_chars:
            raise ValueError("Description contains invalid control characters")
        return v
    
    @field_validator('sample_rate')
    @classmethod
    def validate_sample_rate(cls, v: int) -> int:
        """Validate sample rate is one of the allowed values.

        8000 must remain in this list: cross_field_validation assigns it
        for mulaw output.
        """
        valid_rates = [8000, 16000, 22050, 24000, 44100, 48000]
        if v not in valid_rates:
            raise ValueError(f"Invalid sample rate. Valid options: {valid_rates}")
        return v
    
    @model_validator(mode='after')
    def cross_field_validation(self) -> 'TTSGenerateRequest':
        """
        Cross-field validation: speaker required, speaker resolution, mu-law sample rate.
        
        NOTE: We assume indic_speakers model type (voiceDesign not supported yet).
        """
        has_speaker = bool(self.speaker)
        has_description = bool(self.description)
        has_voice_id = self.voice_id is not None
        
        # For indic_speakers model: speaker is required
        if not has_speaker:
            raise ValueError(
                "Parameter 'speaker' is required for indic_speakers model. "
                f"Must be one of: {', '.join(ALL_SPEAKER_NAMES)}"
            )
        
        # Validate and resolve speaker name
        if self.speaker not in ALL_SPEAKER_NAMES:
            raise ValueError(
                f"Invalid speaker '{self.speaker}'. Valid names: {', '.join(ALL_SPEAKER_NAMES)}"
            )
        
        # Store resolved speaker (private attr; membership was checked above,
        # so resolve_speaker_name cannot raise here)
        self._resolved_speaker = resolve_speaker_name(self.speaker)
        
        # Warn if description/voice_id provided (not supported for indic model)
        if has_description or has_voice_id:
            raise ValueError(
                "Parameters 'description' and 'voice_id' are only valid for voiceDesign model. "
                "Current model type is 'indic_speakers'. Use 'speaker' instead."
            )
        
        # mu-law typically uses 8kHz
        # Comparison works despite use_enum_values because AudioFormat is a
        # str Enum ("mulaw" == AudioFormat.MULAW). In-validator mutation is
        # safe here since validate_assignment is not enabled.
        if self.format == AudioFormat.MULAW and self.sample_rate != 8000:
            self.sample_rate = 8000
        
        return self
    
    def get_resolved_speaker(self) -> str:
        """Return the resolved internal speaker name.

        Falls back to resolving on demand if validation has not populated
        _resolved_speaker; returns "" when no speaker is set.
        """
        if self._resolved_speaker:
            return self._resolved_speaker
        return resolve_speaker_name(self.speaker) if self.speaker else ""
    
    def get_normalized_text(self, normalizer_func=None) -> str:
        """
        Return normalized + emotion-normalized text.
        
        Side effect: stores the pre-normalization text in _original_text
        when normalization runs (used for the X-Normalized-Text header).

        Args:
            normalizer_func: Optional custom normalizer (for testing).
                             If None, returns text with only emotion normalization.
        """
        text = self.text
        
        # Apply text normalization if enabled
        if self.normalize and normalizer_func:
            self._original_text = text
            text = normalizer_func(text)
        
        # Always apply emotion tag normalization
        text = normalize_emotion_tags(text)
        
        return text


# === Response Schemas ===

class TTSGenerateResponse(BaseModel):
    """Response metadata for non-streaming /v1/tts/generate.

    Only the first four fields are always present; the rest are filled in
    when the corresponding data/metric is available.
    """
    
    request_id: str
    model: str
    format: str
    sample_rate: int
    
    # Optional fields
    voice_id: Optional[UUID] = None
    seed: Optional[int] = None
    audio_url: Optional[str] = None
    
    # Metrics
    tokens_prompt: Optional[int] = None
    tokens_generated: Optional[int] = None
    audio_duration_seconds: Optional[float] = None
    audio_bytes: Optional[int] = None
    # ttfb_ms: time-to-first-byte; rtf: real-time factor — TODO confirm units
    ttfb_ms: Optional[int] = None
    rtf: Optional[float] = None
    credits_consumed: Optional[float] = None


class ErrorDetail(BaseModel):
    """Error detail structure (payload nested under ErrorResponse.error)."""
    code: str
    message: str
    # default_factory avoids a shared mutable default across instances
    details: Dict[str, Any] = Field(default_factory=dict)
    request_id: Optional[str] = None
    documentation_url: Optional[str] = None


class ErrorResponse(BaseModel):
    """Standard error response: a single ErrorDetail under the 'error' key."""
    error: ErrorDetail


class HealthResponse(BaseModel):
    """Response for /v1/tts/health endpoint."""
    status: str  # healthy | degraded | unhealthy
    model_loaded: bool
    model_version: str
    uptime_seconds: float
    gpu_available: bool
    app_version: str

