"""
Unit tests for Pydantic schemas (M2).

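Tests cover:
- Speaker alias resolution
- Emotion tag normalization
- TTSGenerateRequest text validation (max length, unicode, control chars)
- Speaker, format, sample-rate and output param validation
- Bounds on advanced generation parameters (temperature, top_k, top_p, ...)
- Preprocessing toggles and get_normalized_text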
"""

import pytest
from uuid import uuid4

from veena3modal.api.schemas import (
    TTSGenerateRequest,
    AudioFormat,
    OutputSampleRate,
    MAX_TEXT_LENGTH,
    ALL_SPEAKER_NAMES,
    INDIC_SPEAKERS,
    FRIENDLY_SPEAKER_MAP,
    resolve_speaker_name,
    normalize_emotion_tags,
)


class TestSpeakerResolution:
    """Checks for ``resolve_speaker_name``: identity, aliases and rejection."""

    def test_internal_names_resolve_to_themselves(self):
        """Every internal speaker name maps onto itself unchanged."""
        for speaker in INDIC_SPEAKERS:
            resolved = resolve_speaker_name(speaker)
            assert resolved == speaker

    def test_friendly_names_resolve_to_internal(self):
        """Each public-facing alias is translated to its internal speaker id."""
        expected_by_alias = {
            "Mitra": "lipakshi",
            "Aaranya": "reet",
            "Taru": "Nandini",
            "Neer": "Nilay",
            "Dhruva": "vardan",
            "Ira": "anika",
            "Veda": "adarsh",
            "Aria": "krishna",
        }
        for alias, internal in expected_by_alias.items():
            assert resolve_speaker_name(alias) == internal

    def test_new_speakers_resolve_to_themselves(self):
        """The recently added speakers (Aarvi, Asha, Bittu, Mira) are identity-mapped."""
        for speaker in ("Aarvi", "Asha", "Bittu", "Mira"):
            assert resolve_speaker_name(speaker) == speaker

    def test_invalid_speaker_raises_error(self):
        """An unknown name raises ValueError mentioning 'Invalid speaker name'."""
        with pytest.raises(ValueError, match="Invalid speaker name"):
            resolve_speaker_name("InvalidSpeaker")

    def test_case_sensitive_speaker_names(self):
        """Lookup is case-sensitive: "Nandini" is valid, "nandini" is not."""
        assert resolve_speaker_name("Nandini") == "Nandini"
        with pytest.raises(ValueError):
            resolve_speaker_name("nandini")


class TestEmotionTagNormalization:
    """Checks for ``normalize_emotion_tags``."""

    def test_legacy_angle_brackets_converted(self):
        """Legacy ``<emotion>`` markup is rewritten into ``[emotion]`` tags."""
        cases = [
            ("<laugh> ha ha", "[laughs]"),
            ("<sigh> okay fine", "[sighs]"),
            ("<scream>!", "[screams]"),
            ("<whisper> shh", "[whispers]"),
        ]
        for raw_text, expected_tag in cases:
            assert expected_tag in normalize_emotion_tags(raw_text)

    def test_bracket_variants_normalized(self):
        """Singular/plural/gerund bracket variants collapse onto the exact model tag."""
        cases = [
            ("[laugh] ha", "[laughs]"),
            ("[laughing] so funny", "[laughs]"),
            ("[sigh] okay", "[sighs]"),
            ("[singing] la la", "[sings]"),
            ("[giggles] hehe", "[giggle]"),
        ]
        for raw_text, expected_tag in cases:
            assert expected_tag in normalize_emotion_tags(raw_text)

    def test_valid_tags_preserved(self):
        """Tags that are already valid Spark TTS emotion tags pass through intact."""
        valid_tags = (
            "[angry]", "[curious]", "[excited]", "[giggle]",
            "[laughs harder]", "[laughs]", "[screams]",
            "[sighs]", "[sings]", "[whispers]",
        )
        for tag in valid_tags:
            assert tag in normalize_emotion_tags(f"Hello {tag} world")

    def test_multiple_emotions_in_text(self):
        """All emotion tags in a single string are normalized, not only the first."""
        result = normalize_emotion_tags("First <laugh> then <sigh> and finally [whisper]")
        for expected_tag in ("[laughs]", "[sighs]", "[whispers]"):
            assert expected_tag in result

    def test_unknown_emotion_preserved_in_brackets(self):
        """An unrecognised emotion keeps its name but moves to bracket syntax."""
        assert "[unknown_emotion]" in normalize_emotion_tags("<unknown_emotion> test")

    def test_case_insensitive_normalization(self):
        """Upper-case variants are matched the same way as lower-case ones."""
        for raw_text, expected_tag in (("[LAUGH]", "[laughs]"), ("[SIGH]", "[sighs]")):
            assert expected_tag in normalize_emotion_tags(raw_text)


class TestTTSGenerateRequestBasic:
    """Smoke tests for constructing a TTSGenerateRequest."""

    def test_valid_minimal_request(self):
        """Only text and speaker are required; stream and format take their defaults."""
        request = TTSGenerateRequest(text="Hello world", speaker="lipakshi")
        assert (request.text, request.speaker) == ("Hello world", "lipakshi")
        assert request.stream is False
        assert request.format == "wav"

    def test_valid_request_with_all_options(self):
        """Every optional parameter can be supplied together in one request."""
        options = {
            "text": "Hello world",
            "speaker": "Mitra",
            "stream": True,
            "format": AudioFormat.OPUS,
            "temperature": 0.9,
            "top_k": 40,
            "top_p": 0.95,
            "max_tokens": 2048,
            "normalize": True,
            "chunking": True,
            "output": OutputSampleRate.SR_48KHZ,
        }
        request = TTSGenerateRequest(**options)
        assert request.speaker == "Mitra"
        assert request.stream is True
        assert request.format == "opus"
        assert request.output == "48khz"


class TestTTSGenerateRequestTextValidation:
    """Test text field validation."""

    def test_empty_text_rejected(self):
        """Empty text should be rejected."""
        with pytest.raises(ValueError) as exc_info:
            TTSGenerateRequest(text="", speaker="lipakshi")
        # Pydantic uses "String should have at least 1 character" for min_length;
        # a custom validator may report "empty" instead.
        message = str(exc_info.value).lower()
        assert "at least 1" in message or "empty" in message

    def test_whitespace_only_text_rejected(self):
        """Whitespace-only text should be rejected."""
        with pytest.raises(ValueError) as exc_info:
            TTSGenerateRequest(text="   \n\t  ", speaker="lipakshi")
        message = str(exc_info.value).lower()
        assert "empty" in message or "whitespace" in message

    def test_text_max_length_enforced(self):
        """Text exceeding MAX_TEXT_LENGTH should be rejected."""
        long_text = "x" * (MAX_TEXT_LENGTH + 1)
        with pytest.raises(ValueError) as exc_info:
            TTSGenerateRequest(text=long_text, speaker="lipakshi")
        # Derive the expected limit from the imported constant instead of a
        # hard-coded literal, so the test keeps tracking MAX_TEXT_LENGTH if it changes.
        message = str(exc_info.value)
        assert str(MAX_TEXT_LENGTH) in message or "max" in message.lower()

    def test_text_at_max_length_accepted(self):
        """Text exactly at max length should be accepted."""
        exact_text = "x" * MAX_TEXT_LENGTH
        req = TTSGenerateRequest(text=exact_text, speaker="lipakshi")
        assert len(req.text) == MAX_TEXT_LENGTH

    def test_control_characters_rejected(self):
        """Text with control characters (here a NUL byte) should be rejected."""
        with pytest.raises(ValueError) as exc_info:
            TTSGenerateRequest(text="Hello\x00world", speaker="lipakshi")
        assert "control" in str(exc_info.value).lower()

    def test_newlines_and_tabs_allowed(self):
        """Newlines and tabs are whitespace, not forbidden control characters."""
        req = TTSGenerateRequest(
            text="Hello\nworld\twith\rnewlines",
            speaker="lipakshi"
        )
        assert "\n" in req.text
        assert "\t" in req.text


class TestTTSGenerateRequestUnicode:
    """Non-ASCII input must be accepted by request validation."""

    @staticmethod
    def _accepted_text(text):
        """Build a request for ``text`` (speaker "lipakshi") and return its stored text."""
        return TTSGenerateRequest(text=text, speaker="lipakshi").text

    def test_hindi_text_accepted(self):
        """Devanagari (Hindi) input is stored verbatim."""
        hindi_text = "नमस्ते दुनिया, यह एक परीक्षण है।"
        assert self._accepted_text(hindi_text) == hindi_text

    def test_telugu_text_accepted(self):
        """Telugu input is stored verbatim."""
        telugu_text = "హలో ప్రపంచం, ఇది ఒక పరీక్ష."
        assert self._accepted_text(telugu_text) == telugu_text

    def test_mixed_script_text_accepted(self):
        """English interleaved with Hindi is stored verbatim."""
        mixed_text = "Hello, नमस्ते world दुनिया"
        assert self._accepted_text(mixed_text) == mixed_text

    def test_emojis_accepted_in_text(self):
        """Emoji pass validation (normalization removes them later)."""
        assert "👋" in self._accepted_text("Hello 👋 world 🌍")

    def test_special_unicode_accepted(self):
        """Typographic punctuation such as curly quotes and em-dashes is allowed."""
        # U+201C/U+201D curly double quotes and U+2014 em-dash.
        special_text = "\u201cHello\u201d \u2014 it's 'working'"
        assert "\u201c" in self._accepted_text(special_text)


class TestTTSGenerateRequestSpeakerValidation:
    """Rules for the ``speaker`` field and voiceDesign-only parameters."""

    def test_speaker_required(self):
        """Omitting the speaker is an error for the indic_speakers model."""
        with pytest.raises(ValueError) as excinfo:
            TTSGenerateRequest(text="Hello world")
        message = str(excinfo.value).lower()
        assert "speaker" in message

    def test_invalid_speaker_rejected(self):
        """An unknown speaker name is rejected with an 'Invalid speaker' message."""
        with pytest.raises(ValueError, match="Invalid speaker"):
            TTSGenerateRequest(text="Hello", speaker="NotASpeaker")

    def test_all_valid_speakers_accepted(self):
        """Every advertised speaker name constructs a request and is kept as given."""
        for name in ALL_SPEAKER_NAMES:
            request = TTSGenerateRequest(text="Hello", speaker=name)
            assert request.speaker == name

    def test_speaker_resolved_correctly(self):
        """get_resolved_speaker() maps a friendly alias to the internal name."""
        request = TTSGenerateRequest(text="Hello", speaker="Mitra")
        resolved = request.get_resolved_speaker()
        assert resolved == "lipakshi"

    def test_description_rejected_for_indic_model(self):
        """``description`` belongs to voiceDesign and is refused here."""
        with pytest.raises(ValueError) as excinfo:
            TTSGenerateRequest(
                text="Hello",
                speaker="lipakshi",
                description="A cheerful voice",
            )
        message = str(excinfo.value)
        assert "voiceDesign" in message or "description" in message.lower()

    def test_voice_id_rejected_for_indic_model(self):
        """``voice_id`` belongs to voiceDesign and is refused here."""
        with pytest.raises(ValueError) as excinfo:
            TTSGenerateRequest(
                text="Hello",
                speaker="lipakshi",
                voice_id=uuid4(),
            )
        message = str(excinfo.value)
        assert "voiceDesign" in message or "voice_id" in message.lower()


class TestTTSGenerateRequestFormatValidation:
    """Rules for ``format``, ``sample_rate`` and ``output``."""

    def test_all_formats_accepted(self):
        """Each supported audio format string is accepted and stored."""
        supported = ("wav", "opus", "mp3", "mulaw", "flac")
        for audio_format in supported:
            request = TTSGenerateRequest(text="Hello", speaker="lipakshi", format=audio_format)
            assert request.format == audio_format

    def test_invalid_format_rejected(self):
        """An unsupported format such as "aac" is rejected."""
        with pytest.raises(ValueError):
            TTSGenerateRequest(text="Hello", speaker="lipakshi", format="aac")

    def test_mulaw_forces_8khz(self):
        """Choosing mu-law overrides any requested sample rate with 8000 Hz."""
        kwargs = {
            "text": "Hello",
            "speaker": "lipakshi",
            "format": AudioFormat.MULAW,
            "sample_rate": 16000,  # expected to be overridden
        }
        assert TTSGenerateRequest(**kwargs).sample_rate == 8000

    def test_valid_sample_rates_accepted(self):
        """All standard sample rates are accepted and stored unchanged."""
        for hz in (8000, 16000, 22050, 24000, 44100, 48000):
            request = TTSGenerateRequest(text="Hello", speaker="lipakshi", sample_rate=hz)
            assert request.sample_rate == hz

    def test_invalid_sample_rate_rejected(self):
        """A non-standard sample rate is rejected with 'Invalid sample rate'."""
        with pytest.raises(ValueError, match="Invalid sample rate"):
            TTSGenerateRequest(text="Hello", speaker="lipakshi", sample_rate=12345)

    def test_output_sample_rate_options(self):
        """Both output presets ("16khz", "48khz") are accepted."""
        for preset in ("16khz", "48khz"):
            request = TTSGenerateRequest(text="Hello", speaker="lipakshi", output=preset)
            assert request.output == preset


class TestTTSGenerateRequestAdvancedParams:
    """Test advanced generation parameters (inclusive range validation)."""

    @staticmethod
    def _assert_bounds(field, low, high, too_low, too_high):
        """Assert that ``field`` accepts ``low``/``high`` and rejects ``too_low``/``too_high``.

        Accepted boundary values must also be stored on the request unchanged,
        so a validator that silently clamps or rewrites them is caught.
        """
        for value in (low, high):
            req = TTSGenerateRequest(text="Hi", speaker="lipakshi", **{field: value})
            assert getattr(req, field) == value, f"{field}={value!r} not stored as given"
        for value in (too_low, too_high):
            with pytest.raises(ValueError):
                TTSGenerateRequest(text="Hi", speaker="lipakshi", **{field: value})

    def test_temperature_bounds(self):
        """Temperature should be bounded 0.0-2.0."""
        self._assert_bounds("temperature", 0.0, 2.0, -0.1, 2.1)

    def test_top_k_bounds(self):
        """top_k should be bounded 1-100."""
        self._assert_bounds("top_k", 1, 100, 0, 101)

    def test_top_p_bounds(self):
        """top_p should be bounded 0.0-1.0."""
        self._assert_bounds("top_p", 0.0, 1.0, -0.1, 1.1)

    def test_max_tokens_bounds(self):
        """max_tokens should be bounded 128-4096."""
        self._assert_bounds("max_tokens", 128, 4096, 127, 4097)

    def test_repetition_penalty_bounds(self):
        """repetition_penalty should be bounded 1.0-2.0."""
        self._assert_bounds("repetition_penalty", 1.0, 2.0, 0.9, 2.1)

    def test_seed_bounds(self):
        """seed should be bounded 0 to 2^31-1."""
        self._assert_bounds("seed", 0, 2**31 - 1, -1, 2**31)


class TestTTSGenerateRequestPreprocessingToggles:
    """Boolean preprocessing switches are accepted and stored as given."""

    def test_normalize_toggle(self):
        """``normalize`` round-trips both True and False."""
        for flag in (True, False):
            request = TTSGenerateRequest(text="Hi", speaker="lipakshi", normalize=flag)
            assert request.normalize is flag

    def test_chunking_toggle(self):
        """``chunking`` round-trips both True and False."""
        for flag in (True, False):
            request = TTSGenerateRequest(text="Hi", speaker="lipakshi", chunking=flag)
            assert request.chunking is flag

    def test_normalize_verbose_toggle(self):
        """``normalize_verbose=True`` is accepted and stored."""
        request = TTSGenerateRequest(text="Hi", speaker="lipakshi", normalize_verbose=True)
        assert request.normalize_verbose is True


class TestGetNormalizedText:
    """Behaviour of ``TTSGenerateRequest.get_normalized_text``."""

    def test_emotion_normalization_always_applied(self):
        """Emotion tags are rewritten even when text normalization is switched off."""
        # normalize=False only disables text normalization, not emotion-tag rewriting.
        request = TTSGenerateRequest(
            text="Hello <laugh> world",
            speaker="lipakshi",
            normalize=False,
        )
        result = request.get_normalized_text()
        assert "[laughs]" in result
        assert "<laugh>" not in result

    def test_text_normalization_with_custom_normalizer(self):
        """A caller-supplied normalizer is applied when normalize=True."""
        request = TTSGenerateRequest(text="Hello 123", speaker="lipakshi", normalize=True)
        # Stand-in normalizer: upper-casing makes its effect easy to detect.
        result = request.get_normalized_text(normalizer_func=lambda text: text.upper())
        assert "HELLO" in result

    def test_text_normalization_skipped_when_disabled(self):
        """With normalize=False the supplied normalizer is never invoked."""
        request = TTSGenerateRequest(text="Hello 123", speaker="lipakshi", normalize=False)

        def forbidden_normalizer(text):
            # Any call means the normalize=False switch was ignored.
            raise AssertionError("Normalizer should not be called")

        result = request.get_normalized_text(normalizer_func=forbidden_normalizer)
        # Digits remain as digits because no number expansion took place.
        assert "123" in result

