"""
Live tests for deployed Modal endpoint.

These tests run against a real deployed Modal endpoint.
They require Modal credentials and a running deployment.

Usage:
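    # Set the endpoint URL (tests are skipped automatically when it is unset)
    export MODAL_ENDPOINT_URL="https://your-modal-endpoint.modal.run"

    # Optional: API key used for authenticated requests (defaults to "test-key")
    export MODAL_API_KEY="your-api-key"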
    
    # Run tests
    pytest veena3modal/tests/modal_live/ -v
    
    # Or skip if not deployed
    pytest veena3modal/tests/modal_live/ -v --ignore-glob="*modal_live*"
"""

import os
import time
import json
import pytest
import httpx
from typing import Optional


# Deployment settings are read from the environment once, at import time.
# Test classes decorated with ``skip_if_no_endpoint`` are skipped when
# MODAL_ENDPOINT_URL is unset or empty.
MODAL_ENDPOINT_URL = os.getenv("MODAL_ENDPOINT_URL")
MODAL_API_KEY = os.getenv("MODAL_API_KEY", "test-key")  # For auth testing

skip_if_no_endpoint = pytest.mark.skipif(
    not MODAL_ENDPOINT_URL,
    reason="MODAL_ENDPOINT_URL environment variable not set",
)


def get_endpoint_url(path: str) -> str:
    """Build the full URL for ``path`` on the deployed endpoint.

    The base URL (``MODAL_ENDPOINT_URL``) and ``path`` are joined with exactly
    one slash, so ``"/v1/tts/health"`` and ``"v1/tts/health"`` produce the same
    URL whether or not the base URL ends with a slash.

    Args:
        path: Route on the endpoint, with or without a leading slash. An empty
            string returns the base URL without a trailing slash.

    Returns:
        The absolute URL as a string.

    Raises:
        RuntimeError: If ``MODAL_ENDPOINT_URL`` is unset or empty.
    """
    if not MODAL_ENDPOINT_URL:
        # Fail with an actionable message instead of an AttributeError on None.
        raise RuntimeError("MODAL_ENDPOINT_URL environment variable not set")

    base = MODAL_ENDPOINT_URL.rstrip("/")
    if path and not path.startswith("/"):
        # Without this, the path would be glued directly onto the host name.
        path = f"/{path}"
    return f"{base}{path}"


def make_headers(api_key: Optional[str] = None) -> dict:
    """Return JSON request headers, with Bearer auth when a key is given.

    A missing or empty ``api_key`` yields headers without ``Authorization``.
    """
    if not api_key:
        return {"Content-Type": "application/json"}
    return {
        "Content-Type": "application/json",
        "Authorization": f"Bearer {api_key}",
    }


@skip_if_no_endpoint
class TestHealthEndpoint:
    """Live checks for the /v1/tts/health route of the deployed endpoint."""

    @staticmethod
    def _fetch_health():
        """GET the health route with a 30 second timeout and return the response."""
        return httpx.get(get_endpoint_url("/v1/tts/health"), timeout=30.0)

    def test_health_returns_200(self):
        """The health route answers with HTTP 200."""
        assert self._fetch_health().status_code == 200

    def test_health_response_structure(self):
        """The JSON body has a known status value and all expected fields."""
        body = self._fetch_health().json()

        assert "status" in body
        assert body["status"] in ["healthy", "degraded", "unhealthy"]
        for field in ("model_loaded", "model_version", "uptime_seconds", "gpu_available"):
            assert field in body

    def test_health_headers(self):
        """The response carries the model and app version headers."""
        received = self._fetch_health().headers

        for header_name in ("X-Model-Version", "X-App-Version"):
            assert header_name in received


@skip_if_no_endpoint
class TestMetricsEndpoint:
    """Live checks for the /v1/tts/metrics route of the deployed endpoint."""

    @staticmethod
    def _fetch_metrics():
        """GET the metrics route with a 30 second timeout and return the response."""
        return httpx.get(get_endpoint_url("/v1/tts/metrics"), timeout=30.0)

    def test_metrics_returns_200(self):
        """The metrics route answers with HTTP 200."""
        assert self._fetch_metrics().status_code == 200

    def test_metrics_content_type(self):
        """The content type is one of the accepted text formats."""
        declared = self._fetch_metrics().headers.get("content-type", "")

        assert any(kind in declared for kind in ("text/plain", "text/prometheus"))

    def test_metrics_contains_expected_metrics(self):
        """The body contains the TTS metric prefix or at least a HELP line."""
        body = self._fetch_metrics().text

        # Either marker is enough: the project prefix or a HELP comment line
        assert any(marker in body for marker in ("veena3_tts_", "# HELP"))


@skip_if_no_endpoint
class TestTTSGenerateNonStreaming:
    """Test /v1/tts/generate endpoint (non-streaming) against deployed Modal."""

    @staticmethod
    def _generate(payload: dict, timeout: float):
        """POST ``payload`` as JSON to the generate route and return the response.

        The request is authenticated with ``MODAL_API_KEY``.
        """
        return httpx.post(
            get_endpoint_url("/v1/tts/generate"),
            json=payload,
            headers=make_headers(MODAL_API_KEY),
            timeout=timeout,
        )

    def test_simple_generation(self):
        """Simple TTS generation should return audio."""
        payload = {
            "text": "Hello, this is a test.",
            "speaker": "lipakshi",
            "stream": False,
            "format": "wav",
        }

        response = self._generate(payload, timeout=60.0)

        # 200 (success), 401 (auth rejected) and 503 (service unavailable) are
        # all tolerated, since the outcome depends on the deployment's state.
        assert response.status_code in [200, 401, 503]

        if response.status_code == 200:
            # Verify audio response
            assert response.headers.get("content-type") == "audio/wav"
            assert len(response.content) > 44  # At least WAV header

    def test_response_headers(self):
        """TTS response should include required headers.

        Skipped when generation does not return HTTP 200, because the headers
        under test are only checked on successful responses.
        """
        payload = {
            "text": "Testing headers.",
            "speaker": "reet",
            "stream": False,
        }

        response = self._generate(payload, timeout=60.0)

        if response.status_code != 200:
            # Report a skip instead of passing without asserting anything.
            pytest.skip(
                f"Generation returned HTTP {response.status_code}; "
                "header checks need a 200 response"
            )

        for header_name in (
            "X-Request-ID",
            "X-Model-Version",
            "X-TTFB-ms",
            "X-RTF",
            "X-Audio-Seconds",
        ):
            assert header_name in response.headers

    def test_validation_error_empty_text(self):
        """Empty text should return validation error."""
        payload = {
            "text": "",
            "speaker": "lipakshi",
        }

        response = self._generate(payload, timeout=30.0)

        assert response.status_code == 400

    def test_validation_error_invalid_speaker(self):
        """Invalid speaker should return validation error."""
        payload = {
            "text": "Hello",
            "speaker": "nonexistent_speaker_xyz",
        }

        response = self._generate(payload, timeout=30.0)

        # May be 400 (validation) or could map to default speaker
        assert response.status_code in [200, 400]


@skip_if_no_endpoint
class TestTTSGenerateStreaming:
    """Test /v1/tts/generate endpoint (streaming) against deployed Modal."""

    def test_streaming_returns_chunked_response(self):
        """Streaming should return WAV audio flagged with ``X-Stream: true``.

        Skipped when the endpoint does not return HTTP 200.
        """
        url = get_endpoint_url("/v1/tts/generate")
        payload = {
            "text": "This is a streaming test with a longer sentence for better results.",
            "speaker": "lipakshi",
            "stream": True,
            "format": "wav",
        }

        response = httpx.post(
            url,
            json=payload,
            headers=make_headers(MODAL_API_KEY),
            timeout=120.0,
        )

        if response.status_code != 200:
            # Report a skip instead of passing without asserting anything.
            pytest.skip(
                f"Streaming request returned HTTP {response.status_code}; "
                "header checks need a 200 response"
            )

        assert response.headers.get("content-type") == "audio/wav"
        assert "X-Stream" in response.headers
        assert response.headers.get("X-Stream") == "true"

    def test_streaming_ttfb_measurement(self):
        """Measure TTFB for streaming response.

        The clock starts before the request is sent and stops when the first
        non-empty body chunk arrives. Skipped when the endpoint does not return
        HTTP 200, so an error body is never timed as if it were audio.
        """
        url = get_endpoint_url("/v1/tts/generate")
        payload = {
            "text": "Measuring time to first byte for streaming audio generation.",
            "speaker": "reet",
            "stream": True,
        }

        # perf_counter is monotonic, so the interval cannot be distorted by
        # wall-clock adjustments during the request.
        start_time = time.perf_counter()
        first_chunk_time = None

        with httpx.stream(
            "POST",
            url,
            json=payload,
            headers=make_headers(MODAL_API_KEY),
            timeout=120.0,
        ) as response:
            if response.status_code != 200:
                pytest.skip(
                    f"Streaming request returned HTTP {response.status_code}; "
                    "TTFB is only measured on a 200 response"
                )
            for chunk in response.iter_bytes(chunk_size=1024):
                if chunk:
                    first_chunk_time = time.perf_counter()
                    break  # Only the first chunk matters for TTFB

        assert first_chunk_time is not None, "Streaming response contained no data"

        ttfb_ms = (first_chunk_time - start_time) * 1000
        print(f"\nStreaming TTFB: {ttfb_ms:.0f}ms")
        # TTFB target: < 1200ms for cold start, < 500ms warm.
        # The assertion uses a looser 5000ms ceiling than those targets.
        assert ttfb_ms < 5000, f"TTFB too high: {ttfb_ms}ms"


@skip_if_no_endpoint
class TestPerformance:
    """Performance tests against deployed Modal endpoint."""

    @pytest.mark.slow
    def test_cold_start_time(self):
        """Measure and print the total time of one non-streaming request.

        This only reports timings; it makes no assertion on them.
        """
        # This is hard to test reliably without controlling Modal scaling
        # For now, just measure response time
        url = get_endpoint_url("/v1/tts/generate")
        payload = {
            "text": "Cold start test.",
            "speaker": "lipakshi",
            "stream": False,
        }

        # Monotonic clock: durations are unaffected by wall-clock adjustments
        start = time.perf_counter()
        response = httpx.post(
            url,
            json=payload,
            headers=make_headers(MODAL_API_KEY),
            timeout=120.0,  # Long timeout for cold start
        )
        duration = time.perf_counter() - start

        print(f"\nTotal request time: {duration:.2f}s")

        if response.status_code == 200:
            ttfb_header = response.headers.get("X-TTFB-ms", "0")
            print(f"Server TTFB: {ttfb_header}ms")

    @pytest.mark.slow
    def test_concurrent_requests(self):
        """Send three requests concurrently and require at least one HTTP 200."""
        import asyncio

        url = get_endpoint_url("/v1/tts/generate")
        payload = {
            "text": "Concurrent test sentence.",
            "speaker": "lipakshi",
            "stream": False,
        }

        async def make_request(client, request_id):
            """Issue one POST and return its id, status code and duration."""
            start = time.perf_counter()
            response = await client.post(
                url,
                json=payload,
                headers=make_headers(MODAL_API_KEY),
                timeout=120.0,
            )
            duration = time.perf_counter() - start
            return {
                "id": request_id,
                "status": response.status_code,
                "duration": duration,
            }

        async def run_concurrent(n_requests=3):
            """Run ``n_requests`` requests at once over a single shared client."""
            async with httpx.AsyncClient() as client:
                tasks = [make_request(client, i) for i in range(n_requests)]
                return await asyncio.gather(*tasks)

        results = asyncio.run(run_concurrent(3))

        success_count = sum(1 for r in results if r["status"] == 200)
        avg_duration = sum(r["duration"] for r in results) / len(results)

        print(f"\nConcurrent results: {success_count}/{len(results)} successful")
        print(f"Average duration: {avg_duration:.2f}s")

        # At least some should succeed
        assert success_count > 0, "All concurrent requests failed"


@skip_if_no_endpoint
class TestErrorHandling:
    """Live checks of how the generate route answers bad or unauthenticated requests."""

    def test_missing_api_key(self):
        """A request without an Authorization header gets 200 or a 401 error body."""
        response = httpx.post(
            get_endpoint_url("/v1/tts/generate"),
            json={"text": "Test", "speaker": "lipakshi"},
            headers={"Content-Type": "application/json"},  # No auth header
            timeout=30.0,
        )
        status = response.status_code

        # Either auth is disabled (200) or the request is rejected (401)
        assert status in [200, 401]

        if status != 401:
            return
        assert "error" in response.json()

    def test_invalid_json(self):
        """A malformed JSON body is answered with HTTP 400."""
        malformed_body = "not valid json{"

        response = httpx.post(
            get_endpoint_url("/v1/tts/generate"),
            content=malformed_body,
            headers={"Content-Type": "application/json"},
            timeout=30.0,
        )

        assert response.status_code == 400

    def test_text_too_long(self):
        """Text longer than the 50K character limit is answered with HTTP 400."""
        oversized_text = "x" * 51000  # Over 50K limit

        response = httpx.post(
            get_endpoint_url("/v1/tts/generate"),
            json={"text": oversized_text, "speaker": "lipakshi"},
            headers=make_headers(MODAL_API_KEY),
            timeout=30.0,
        )

        assert response.status_code == 400


@skip_if_no_endpoint
class TestAudioFormats:
    """Live checks of audio format handling on the generate route."""

    @pytest.mark.parametrize("format", ["wav", "opus", "mp3", "flac"])
    def test_non_streaming_formats(self, format):
        """Each format returns 200 or 501; on 200, X-Format names it or 'wav'."""
        request_body = {
            "text": f"Testing {format} format.",
            "speaker": "lipakshi",
            "stream": False,
            "format": format,
        }

        response = httpx.post(
            get_endpoint_url("/v1/tts/generate"),
            json=request_body,
            headers=make_headers(MODAL_API_KEY),
            timeout=60.0,
        )

        # Success, or "not implemented" for formats the server lacks
        assert response.status_code in [200, 501]
        if response.status_code != 200:
            return

        # The server may fall back to WAV when an encoder is unavailable,
        # so accept either the requested format or exactly "wav".
        format_header = response.headers.get("X-Format", "")
        acceptable = format in format_header or format_header == "wav"
        assert acceptable, \
            f"Expected format '{format}' or 'wav' fallback, got '{format_header}'"

    def test_streaming_non_wav_returns_501(self):
        """Streaming in a non-WAV format is answered with 501 Not Implemented."""
        request_body = {
            "text": "Testing streaming opus.",
            "speaker": "lipakshi",
            "stream": True,
            "format": "opus",
        }

        response = httpx.post(
            get_endpoint_url("/v1/tts/generate"),
            json=request_body,
            headers=make_headers(MODAL_API_KEY),
            timeout=30.0,
        )

        assert response.status_code == 501


@skip_if_no_endpoint
class TestWebSocket:
    """Test WebSocket TTS endpoint against deployed Modal.

    Transport problems (failed connects, timeouts, dropped connections) skip
    the test, since they depend on the state of the live deployment. A wrong
    reply from the server fails the test.
    """

    @pytest.fixture(autouse=True)
    def websocket_cooldown(self):
        """Add cooldown between WebSocket tests to avoid connection issues."""
        time.sleep(0.5)  # Half second cooldown before each test
        yield
        time.sleep(0.5)  # Half second cooldown after each test

    @staticmethod
    def _ws_url() -> str:
        """Return the WebSocket URL of the TTS route, including the API key.

        The HTTP(S) scheme of ``MODAL_ENDPOINT_URL`` is swapped for WS(S). When
        ``MODAL_API_KEY`` is set it is appended as the ``api_key`` query
        parameter, percent-encoded so reserved characters cannot break the URL.
        """
        from urllib.parse import quote

        base = MODAL_ENDPOINT_URL.replace("https://", "wss://").replace("http://", "ws://")
        ws_url = f"{base.rstrip('/')}/v1/tts/ws"

        if MODAL_API_KEY:
            ws_url = f"{ws_url}?api_key={quote(MODAL_API_KEY, safe='')}"
        return ws_url

    def _first_reply(self, request, recv_timeout, label, max_retries=3):
        """Send ``request`` as JSON on a fresh connection and return the first reply.

        Only transport operations (connect, send, receive) run inside the retry
        loop, so assertion failures in the calling test are never swallowed.

        Args:
            request: JSON-serialisable message to send.
            recv_timeout: Seconds to wait for the first reply.
            label: Prefix of the skip message used when every attempt fails.
            max_retries: Number of connection attempts before skipping.

        Returns:
            The first message received, as ``str`` or ``bytes``.
        """
        import websockets.sync.client as ws_client

        ws_url = self._ws_url()
        last_error = None

        for attempt in range(max_retries):
            try:
                with ws_client.connect(ws_url, close_timeout=30, open_timeout=15) as ws:
                    ws.send(json.dumps(request))
                    return ws.recv(timeout=recv_timeout)
            except Exception as e:
                # Transport failure against the live deployment: retry, then skip
                last_error = e
                if attempt < max_retries - 1:
                    time.sleep(1)  # Wait before retry

        pytest.skip(f"{label} failed after {max_retries} attempts: {last_error}")

    def test_websocket_connection(self):
        """WebSocket should accept connection."""
        msg = self._first_reply(
            {"text": "Hello", "speaker": "lipakshi"},
            recv_timeout=30.0,
            label="WebSocket connection",
        )

        # Only text frames are inspected; a binary first frame is accepted as-is
        if isinstance(msg, str):
            data = json.loads(msg)
            assert data.get("event") in ["header", "error", "progress", "complete"]

    def test_websocket_ping_pong(self):
        """WebSocket should respond to ping."""
        # Add small delay to avoid connection reuse issues
        time.sleep(0.5)

        msg = self._first_reply(
            {"event": "ping"},
            recv_timeout=15.0,
            label="WebSocket ping test",
        )

        data = json.loads(msg)
        assert data.get("event") == "pong", f"Expected pong, got {data}"

    def test_websocket_streaming_audio(self):
        """WebSocket should stream audio chunks.

        Binary frames are counted as audio chunks and text frames are parsed as
        JSON events. Reading stops on a ``complete`` or ``error`` event, or once
        ``max_wait`` seconds have elapsed.
        """
        import websockets.sync.client as ws_client
        from websockets.exceptions import ConnectionClosed

        ws_url = self._ws_url()

        header_received = False
        audio_chunks = 0
        complete_received = False
        error_msg = None

        try:
            with ws_client.connect(ws_url, close_timeout=120, open_timeout=30) as ws:
                # Send request
                ws.send(json.dumps({
                    "text": "Testing WebSocket streaming audio generation.",
                    "speaker": "lipakshi",
                }))

                # Collect messages with timeout
                max_wait = 120  # seconds
                start = time.monotonic()

                while time.monotonic() - start < max_wait:
                    try:
                        msg = ws.recv(timeout=30.0)
                    except TimeoutError:
                        # One quiet period is tolerated; max_wait bounds the total
                        print("  ⚠️ Recv timeout, continuing...")
                        continue

                    if isinstance(msg, bytes):
                        audio_chunks += 1
                        if audio_chunks == 1:
                            print(f"  📦 First audio chunk: {len(msg)} bytes")
                        continue

                    try:
                        data = json.loads(msg)
                    except json.JSONDecodeError:
                        print(f"  ⚠️ Non-JSON message: {msg[:100]}")
                        continue

                    event = data.get("event")

                    if event == "header":
                        header_received = True
                        print("  📋 Header received")
                    elif event == "complete":
                        complete_received = True
                        print("  ✅ Complete received")
                        break
                    elif event == "error":
                        error_msg = data.get("message", str(data))
                        print(f"  ❌ Error: {error_msg}")
                        break
                    elif event == "progress":
                        pass  # Ignore progress updates
                    else:
                        print(f"  📩 Event: {event}")

        except ConnectionClosed as e:
            if audio_chunks > 0 and header_received:
                # Connection closed after receiving some audio - that's OK
                print(f"  Connection closed after {audio_chunks} chunks")
            else:
                pytest.skip(f"WebSocket connection closed: {e}")
        except Exception as e:
            # No assertions run inside the try block, so only transport or
            # protocol errors can end up here.
            pytest.skip(f"WebSocket streaming test error: {type(e).__name__}: {e}")

        # Validate results
        if error_msg:
            pytest.skip(f"Server returned error: {error_msg}")

        print("\n📊 WebSocket Results:")
        print(f"   Header received: {header_received}")
        print(f"   Audio chunks: {audio_chunks}")
        print(f"   Complete received: {complete_received}")

        # At minimum, we should receive a header or at least one audio chunk
        assert header_received or audio_chunks > 0, "Should receive header or audio chunks"

        if header_received and audio_chunks > 0:
            print(f"✅ WebSocket streaming working: {audio_chunks} audio chunks received")


if __name__ == "__main__":
    # Manual smoke check: query the health route once and print the outcome.
    if not MODAL_ENDPOINT_URL:
        print("Set MODAL_ENDPOINT_URL to run tests")
    else:
        print(f"Testing against: {MODAL_ENDPOINT_URL}")

        health = httpx.get(get_endpoint_url("/v1/tts/health"), timeout=30.0)
        print(f"Health: {health.status_code} - {health.json()}")

