"""
Unit tests for WebSocket TTS handler.

Tests WebSocket protocol, message parsing, and error handling.
"""

import pytest
import json
from unittest.mock import Mock, AsyncMock, patch, MagicMock
import asyncio

from veena3modal.api.websocket_handler import (
    WSRequest,
    WSMessageType,
    create_error_message,
    create_header_message,
    create_progress_message,
    create_complete_message,
)


class TestWSRequest:
    """Test WSRequest parsing (``from_dict``) and validation (``validate``)."""

    def test_from_dict_minimal(self):
        """Parse request with only ``text``; other fields take their defaults."""
        data = {"text": "Hello world"}
        req = WSRequest.from_dict(data)

        assert req.text == "Hello world"
        assert req.speaker == "male1"  # default
        assert req.temperature == 0.8  # default
        assert req.format == "wav"  # default

    def test_from_dict_full(self):
        """Parse request with all fields and check each one is carried over."""
        data = {
            "text": "Hello world",
            "speaker": "female1",
            "temperature": 0.5,
            "top_k": 30,
            "top_p": 0.9,
            "max_tokens": 2048,
            "repetition_penalty": 1.1,
            "seed": 42,
            "chunking": True,
            "normalize": False,
            "format": "wav",
            "sample_rate": 16000,
        }
        req = WSRequest.from_dict(data)

        assert req.text == "Hello world"
        assert req.speaker == "female1"
        assert req.temperature == 0.5
        assert req.top_k == 30
        assert req.top_p == 0.9
        assert req.max_tokens == 2048
        assert req.repetition_penalty == 1.1
        assert req.seed == 42
        assert req.chunking is True
        assert req.normalize is False
        assert req.format == "wav"
        # NOTE(review): "sample_rate" is sent but not asserted — confirm whether
        # WSRequest exposes it as an attribute before adding a check.

    def test_validate_empty_text(self):
        """Empty text should fail validation."""
        req = WSRequest(text="")
        error = req.validate()
        assert error is not None
        assert "empty" in error.lower()

    def test_validate_whitespace_text(self):
        """Whitespace-only text should fail validation."""
        req = WSRequest(text="   ")
        error = req.validate()
        assert error is not None
        assert "empty" in error.lower()

    def test_validate_text_too_long(self):
        """Text over 50K should fail validation, starting one character past the limit."""
        # 50_001 pins the boundary (50_000 passes, see the test below);
        # 51_000 keeps the original, comfortably-over-the-limit case.
        for length in (50001, 51000):
            req = WSRequest(text="x" * length)
            error = req.validate()
            assert error is not None, length
            assert "too long" in error.lower()

    def test_validate_non_wav_format(self):
        """Non-WAV formats should fail validation for WebSocket."""
        req = WSRequest(text="Hello", format="mp3")
        error = req.validate()
        assert error is not None
        assert "wav" in error.lower()

    def test_validate_valid_request(self):
        """Valid request should pass validation."""
        req = WSRequest(text="Hello world", format="wav")
        error = req.validate()
        assert error is None

    def test_validate_max_length_allowed(self):
        """Text at exactly 50K should pass."""
        req = WSRequest(text="x" * 50000)
        error = req.validate()
        assert error is None


class TestWSMessageType:
    """Pin the wire string behind each WSMessageType member."""

    def test_client_message_types(self):
        """Message types a client sends map to their expected strings."""
        client_side = (
            (WSMessageType.REQUEST, "request"),
            (WSMessageType.CANCEL, "cancel"),
            (WSMessageType.PING, "ping"),
        )
        for member, wire_value in client_side:
            assert member.value == wire_value

    def test_server_message_types(self):
        """Message types the server sends map to their expected strings."""
        server_side = (
            (WSMessageType.AUDIO_CHUNK, "audio_chunk"),
            (WSMessageType.HEADER, "header"),
            (WSMessageType.PROGRESS, "progress"),
            (WSMessageType.COMPLETE, "complete"),
            (WSMessageType.ERROR, "error"),
            (WSMessageType.PONG, "pong"),
        )
        for member, wire_value in server_side:
            assert member.value == wire_value


class TestMessageCreation:
    """Check the JSON shape produced by each message-builder helper."""

    def test_create_error_message(self):
        """An error message nests code, message and request id under "error"."""
        payload = json.loads(
            create_error_message(
                code="TEST_ERROR",
                message="Test error message",
                request_id="test-123",
            )
        )

        assert payload["event"] == "error"
        err = payload["error"]
        assert err["code"] == "TEST_ERROR"
        assert err["message"] == "Test error message"
        assert err["request_id"] == "test-123"

    def test_create_header_message(self):
        """A header message carries request id, sample rate, format and model version."""
        payload = json.loads(
            create_header_message(
                request_id="test-123",
                sample_rate=16000,
                model_version="1.0.0",
            )
        )

        expected = {
            "event": "header",
            "request_id": "test-123",
            "sample_rate": 16000,
            "format": "wav",
            "model_version": "1.0.0",
        }
        for key, value in expected.items():
            assert payload[key] == value

    def test_create_progress_message(self):
        """A progress message reports chunks, bytes and elapsed milliseconds."""
        payload = json.loads(
            create_progress_message(
                chunks_sent=10,
                bytes_sent=16000,
                elapsed_ms=500,
            )
        )

        expected = {
            "event": "progress",
            "chunks_sent": 10,
            "bytes_sent": 16000,
            "elapsed_ms": 500,
        }
        for key, value in expected.items():
            assert payload[key] == value

    def test_create_complete_message(self):
        """A complete message echoes the request id and the supplied metrics."""
        metrics = {
            "total_bytes": 32000,
            "audio_duration_seconds": 2.0,
            "rtf": 0.5,
        }
        payload = json.loads(
            create_complete_message(
                request_id="test-123",
                metrics=metrics,
            )
        )

        assert payload["event"] == "complete"
        assert payload["request_id"] == "test-123"
        reported = payload["metrics"]
        assert reported["total_bytes"] == 32000
        assert reported["audio_duration_seconds"] == 2.0


class TestWebSocketProtocol:
    """Test the JSON control frames exchanged over the WebSocket.

    Frames are built from ``WSMessageType`` values, so a renamed wire value is
    caught. Request frames are fed through ``WSRequest.from_dict``, so the tests
    exercise the module under test instead of only round-tripping ``json``.
    """

    @pytest.fixture
    def mock_websocket(self):
        """Create a mock WebSocket with async accept/send/receive/close methods."""
        ws = AsyncMock()
        ws.accept = AsyncMock()
        ws.send_text = AsyncMock()
        ws.send_bytes = AsyncMock()
        ws.receive_text = AsyncMock()
        ws.close = AsyncMock()
        ws.headers = {}
        return ws

    @pytest.mark.asyncio
    async def test_ping_pong(self, mock_websocket):
        """Ping and pong frames should round-trip through JSON to their enum members."""
        # mock_websocket is not driven here; this stays a message-level test.
        ping_data = json.loads(json.dumps({"event": WSMessageType.PING.value}))
        assert ping_data["event"] == "ping"
        assert WSMessageType(ping_data["event"]) is WSMessageType.PING

        pong_data = json.loads(json.dumps({"event": WSMessageType.PONG.value}))
        assert pong_data["event"] == "pong"
        assert WSMessageType(pong_data["event"]) is WSMessageType.PONG

    def test_cancel_message_structure(self):
        """A cancel frame should carry the CANCEL wire value under "event"."""
        cancel_msg = json.dumps({"event": WSMessageType.CANCEL.value})
        data = json.loads(cancel_msg)
        assert data["event"] == "cancel"
        assert WSMessageType(data["event"]) is WSMessageType.CANCEL

    def test_request_message_structure(self):
        """A request frame should parse into a valid WSRequest."""
        request_msg = json.dumps({
            "text": "Hello",
            "speaker": "male1",
        })
        data = json.loads(request_msg)
        assert data["text"] == "Hello"
        assert data["speaker"] == "male1"

        req = WSRequest.from_dict(data)
        assert req.text == "Hello"
        assert req.speaker == "male1"
        assert req.validate() is None


class TestWSRequestEdgeCases:
    """Edge cases for WSRequest construction, parsing and validation."""

    def test_unicode_text(self):
        """Non-ASCII text (Devanagari plus an emoji) should validate cleanly."""
        assert WSRequest(text="नमस्ते दुनिया 🎵").validate() is None

    def test_empty_dict(self):
        """An empty payload falls back to the default text and speaker."""
        parsed = WSRequest.from_dict({})
        assert parsed.text == ""
        assert parsed.speaker == "male1"

    def test_none_values(self):
        """An explicit ``None`` seed comes back as ``None``."""
        parsed = WSRequest.from_dict({"text": "Hello", "seed": None})
        assert parsed.seed is None

    def test_extra_fields_ignored(self):
        """Unknown keys are dropped rather than set as attributes."""
        payload = {"text": "Hello", "unknown_field": "ignored"}
        parsed = WSRequest.from_dict(payload)
        assert parsed.text == "Hello"
        assert not hasattr(parsed, "unknown_field")

    def test_negative_temperature(self):
        """from_dict keeps a negative temperature as-is (validated elsewhere)."""
        parsed = WSRequest.from_dict({"text": "Hello", "temperature": -0.5})
        assert parsed.temperature == -0.5

    def test_large_top_k(self):
        """from_dict keeps a very large top_k as-is."""
        parsed = WSRequest.from_dict({"text": "Hello", "top_k": 10000})
        assert parsed.top_k == 10000


class TestErrorMessages:
    """Test that every message builder emits well-formed JSON."""

    def test_error_message_json_valid(self):
        """Error message should be valid JSON."""
        msg = create_error_message("CODE", "Message", "id")
        # Should not raise
        data = json.loads(msg)
        assert isinstance(data, dict)

    def test_error_message_special_chars(self):
        """Quotes and newlines in the message should survive the JSON round-trip."""
        original = 'Message with "quotes" and \n newlines'
        msg = create_error_message(
            code="ERROR",
            message=original,
            request_id="id",
        )
        data = json.loads(msg)
        assert '"quotes"' in data["error"]["message"]
        # The newline half of the docstring was previously unchecked; an exact
        # comparison covers both quotes and the embedded newline.
        assert data["error"]["message"] == original

    def test_header_message_json_valid(self):
        """Header message should be valid JSON."""
        msg = create_header_message("id", 16000, "1.0")
        data = json.loads(msg)
        assert isinstance(data, dict)

    def test_progress_message_json_valid(self):
        """Progress message should be valid JSON."""
        msg = create_progress_message(10, 1000, 500)
        data = json.loads(msg)
        assert isinstance(data, dict)

    def test_complete_message_json_valid(self):
        """Complete message should be valid JSON."""
        msg = create_complete_message("id", {"key": "value"})
        data = json.loads(msg)
        assert isinstance(data, dict)


class TestWebSocketCloseCodes:
    """Test WebSocket close codes used in handler.

    This file cannot see the handler's close-code constants. Each test therefore
    checks that the code the handler is documented to use is one an endpoint may
    legally send in a Close frame (RFC 6455 section 7.4). The previous
    ``assert N == N`` form could never fail.
    """

    # RFC 6455 §7.4.1: 1004 is reserved; 1005, 1006 and 1015 MUST NOT be sent
    # in a Close frame by an endpoint.
    _NOT_SENDABLE = frozenset({1004, 1005, 1006, 1015})

    @classmethod
    def _assert_sendable(cls, code):
        """Assert that *code* may be sent in a Close frame."""
        # 1000-2999 protocol, 3000-3999 libraries/frameworks, 4000-4999 private use.
        assert 1000 <= code <= 4999, code
        assert code not in cls._NOT_SENDABLE, code

    def test_normal_closure(self):
        """Normal closure should use code 1000."""
        # 1000 = Normal closure
        self._assert_sendable(1000)

    def test_unsupported_data(self):
        """Unsupported data should use code 1003."""
        # 1003 = Unsupported data (e.g., invalid JSON)
        self._assert_sendable(1003)

    def test_policy_violation(self):
        """Policy violation should use code 1008."""
        # 1008 = Policy violation (e.g., auth failure)
        self._assert_sendable(1008)

    def test_internal_error(self):
        """Internal error should use code 1011."""
        # 1011 = Unexpected condition (e.g., model not loaded)
        self._assert_sendable(1011)

