"""
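Integration tests for the /v1/tts/generate endpoint (M3).

These tests need a CUDA GPU and a loadable TTS model. The model directory is
taken from SPARK_TTS_MODEL_PATH (or MODEL_PATH); if neither is set, a few
hard-coded default locations are checked. When any prerequisite is missing,
or the runtime fails to initialize, the tests are skipped rather than failed.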
"""

import os
import pytest


# Check if we can run integration tests
def _can_run_integration_tests():
    """Check if we have the prerequisites for integration tests."""
    try:
        import torch
        if not torch.cuda.is_available():
            return False, "No GPU available"
    except ImportError:
        return False, "PyTorch not installed"
    
    # Check for model path
    model_path = os.environ.get('SPARK_TTS_MODEL_PATH') or os.environ.get('MODEL_PATH')
    if not model_path:
        # Check default paths
        default_paths = [
            '/models/spark_tts_4speaker',
            '/home/ubuntu/veena3/models/spark_tts_4speaker',
        ]
        for path in default_paths:
            if os.path.exists(path):
                return True, None
        return False, "Model path not found"
    
    if not os.path.exists(model_path):
        return False, f"Model path does not exist: {model_path}"
    
    return True, None


# Probe the prerequisites once, at import time. The result drives both the
# class-level ``skipif`` markers and the ``initialized_client`` fixture;
# SKIP_REASON is None when CAN_RUN is True.
CAN_RUN, SKIP_REASON = _can_run_integration_tests()


@pytest.fixture(scope="module")
def initialized_client():
    """
    Yield a FastAPI ``TestClient`` backed by an initialized TTS runtime.

    The runtime is initialized only if ``tts_runtime.is_initialized()`` reports
    it is not already loaded. It is deliberately left loaded after the module
    finishes, for performance. The client itself is closed on teardown.

    Skips (rather than fails) when the integration prerequisites are missing
    or when runtime initialization raises.
    """
    if not CAN_RUN:
        pytest.skip(f"Integration tests skipped: {SKIP_REASON}")

    # Imported lazily so that collection does not need the app stack when the
    # tests are going to be skipped anyway.
    from fastapi.testclient import TestClient
    from veena3modal.api.fastapi_app import create_app
    from veena3modal.services import tts_runtime

    if not tts_runtime.is_initialized():
        try:
            tts_runtime.initialize_runtime()
        except Exception as e:  # any init failure turns into a skip, not an error
            pytest.skip(f"Could not initialize TTS runtime: {e}")

    client = TestClient(create_app())
    try:
        yield client
    finally:
        # Release the client's transport; the TTS runtime stays loaded.
        client.close()


@pytest.mark.skipif(not CAN_RUN, reason=SKIP_REASON or "Integration test prerequisites not met")
class TestTTSGenerateEndpoint:
    """Test non-streaming /v1/tts/generate requests against the real model."""

    # A minimal RIFF/WAVE header is 44 bytes, so a body longer than that
    # must carry at least some audio data.
    _WAV_HEADER_BYTES = 44

    @staticmethod
    def _assert_ok(response):
        """Assert HTTP 200, putting the server's response body in the failure message.

        pytest reports the body as part of the assertion failure, so no
        print-based debugging is needed.
        """
        assert response.status_code == 200, (
            f"Error response: {response.status_code} - {response.text}"
        )

    def test_simple_generation(self, initialized_client):
        """Basic non-streaming WAV generation returns audio plus metadata headers."""
        response = initialized_client.post(
            "/v1/tts/generate",
            json={
                "text": "Hello, this is a test.",
                "speaker": "lipakshi",
                "stream": False,
                "format": "wav",
            },
        )

        self._assert_ok(response)
        assert response.headers["content-type"] == "audio/wav"
        for header in ("X-Request-ID", "X-Model-Version", "X-Audio-Bytes", "X-TTFB-ms"):
            assert header in response.headers, f"missing response header {header}"

        audio_bytes = response.content
        assert len(audio_bytes) > self._WAV_HEADER_BYTES  # more than just a WAV header
        assert audio_bytes[:4] == b'RIFF'  # WAV magic bytes

    def test_generation_with_emotion(self, initialized_client):
        """Generation with an inline emotion tag produces audio."""
        response = initialized_client.post(
            "/v1/tts/generate",
            json={
                "text": "[laughs] That's so funny!",
                "speaker": "reet",
                "stream": False,
            },
        )

        self._assert_ok(response)
        assert len(response.content) > self._WAV_HEADER_BYTES

    def test_generation_friendly_speaker(self, initialized_client):
        """Generation with a friendly speaker name produces audio."""
        response = initialized_client.post(
            "/v1/tts/generate",
            json={
                "text": "Testing friendly speaker name.",
                "speaker": "Mitra",  # Friendly name for lipakshi
                "stream": False,
            },
        )

        self._assert_ok(response)
        assert len(response.content) > self._WAV_HEADER_BYTES

    def test_generation_with_seed(self, initialized_client):
        """A supplied seed is echoed back in the ``X-Seed`` header."""
        response = initialized_client.post(
            "/v1/tts/generate",
            json={
                "text": "Testing with seed.",
                "speaker": "Nandini",
                "seed": 42,
                "stream": False,
            },
        )

        self._assert_ok(response)
        assert response.headers.get("X-Seed") == "42"
        assert len(response.content) > self._WAV_HEADER_BYTES

    def test_generation_hindi_text(self, initialized_client):
        """Generation with Hindi (Devanagari) text produces audio."""
        response = initialized_client.post(
            "/v1/tts/generate",
            json={
                "text": "नमस्ते, यह एक परीक्षण है।",
                "speaker": "lipakshi",
                "stream": False,
            },
        )

        self._assert_ok(response)
        assert len(response.content) > self._WAV_HEADER_BYTES

    def test_chunking_disabled(self, initialized_client):
        """With chunking disabled, the response reports ``X-Text-Chunked: false``."""
        response = initialized_client.post(
            "/v1/tts/generate",
            json={
                "text": "Short text without chunking.",
                "speaker": "lipakshi",
                "chunking": False,
                "stream": False,
            },
        )

        self._assert_ok(response)
        assert response.headers.get("X-Text-Chunked") == "false"
        assert len(response.content) > self._WAV_HEADER_BYTES


@pytest.mark.skipif(not CAN_RUN, reason=SKIP_REASON or "Integration test prerequisites not met")
class TestTTSGenerateErrors:
    """
    Error handling for /v1/tts/generate.

    This class also contains one success-path check,
    ``test_streaming_returns_audio``, kept here under its original name.
    """

    @staticmethod
    def _generate(client, **payload):
        """POST ``payload`` (keys in keyword order) as JSON to the generate endpoint."""
        return client.post("/v1/tts/generate", json=payload)

    def test_missing_text(self, initialized_client):
        """A request without ``text`` is rejected with 400 and an ``error`` body."""
        resp = self._generate(initialized_client, speaker="lipakshi")

        assert resp.status_code == 400
        assert "error" in resp.json()

    def test_missing_speaker(self, initialized_client):
        """A request without ``speaker`` is rejected with 400 and an ``error`` body."""
        resp = self._generate(initialized_client, text="Hello")

        assert resp.status_code == 400
        assert "error" in resp.json()

    def test_invalid_speaker(self, initialized_client):
        """An unknown speaker name is rejected with 400."""
        resp = self._generate(
            initialized_client,
            text="Hello",
            speaker="invalid_speaker_name",
        )

        assert resp.status_code == 400

    def test_streaming_returns_audio(self, initialized_client):
        """``stream: true`` yields a 200 WAV body flagged by ``X-Stream: true``."""
        resp = self._generate(
            initialized_client,
            text="Hello, this is a streaming test.",
            speaker="lipakshi",
            stream=True,
        )

        assert resp.status_code == 200
        assert resp.headers["content-type"] == "audio/wav"
        assert resp.headers.get("X-Stream") == "true"

        # The body must hold more than a bare 44-byte WAV header and start
        # with the RIFF magic.
        body = resp.content
        assert len(body) > 44
        assert body.startswith(b'RIFF')

    def test_non_wav_format_not_implemented(self, initialized_client):
        """Requesting a non-WAV format (``opus``) returns 501 with the matching code."""
        resp = self._generate(
            initialized_client,
            text="Hello",
            speaker="lipakshi",
            format="opus",
        )

        assert resp.status_code == 501
        error_code = resp.json().get("error", {}).get("code", "")
        assert "FORMAT_NOT_IMPLEMENTED" in error_code


@pytest.mark.skipif(not CAN_RUN, reason=SKIP_REASON or "Integration test prerequisites not met")
class TestTTSStreaming:
    """Behaviour of /v1/tts/generate when ``stream`` is enabled."""

    @staticmethod
    def _generate(client, **payload):
        """POST ``payload`` (keys in keyword order) as JSON to the generate endpoint."""
        return client.post("/v1/tts/generate", json=payload)

    def test_streaming_with_emotion(self, initialized_client):
        """An inline emotion tag still streams audio."""
        resp = self._generate(
            initialized_client,
            text="[laughs] This is funny!",
            speaker="reet",
            stream=True,
        )

        assert resp.status_code == 200
        assert resp.headers.get("X-Stream") == "true"
        assert len(resp.content) > 44

    def test_streaming_with_chunking(self, initialized_client):
        """Long input with chunking enabled streams audio."""
        # 5 x 71-char sentence = 355 chars, well past the > 220-char length
        # expected to trigger text chunking.
        passage = "This is a longer text that should trigger the text chunking mechanism. " * 5

        resp = self._generate(
            initialized_client,
            text=passage,
            speaker="lipakshi",
            stream=True,
            chunking=True,
        )

        assert resp.status_code == 200
        assert resp.headers.get("X-Stream") == "true"
        assert len(resp.content) > 44

    def test_streaming_without_chunking(self, initialized_client):
        """With chunking disabled, the stream reports ``X-Chunking-Enabled: false``."""
        resp = self._generate(
            initialized_client,
            text="Short text without chunking.",
            speaker="lipakshi",
            stream=True,
            chunking=False,
        )

        assert resp.status_code == 200
        assert resp.headers.get("X-Stream") == "true"
        assert resp.headers.get("X-Chunking-Enabled") == "false"

    def test_streaming_hindi_text(self, initialized_client):
        """Hindi (Devanagari) input streams audio."""
        resp = self._generate(
            initialized_client,
            text="नमस्ते, यह स्ट्रीमिंग परीक्षण है।",
            speaker="lipakshi",
            stream=True,
        )

        assert resp.status_code == 200
        assert len(resp.content) > 44

    def test_streaming_non_wav_not_implemented(self, initialized_client):
        """Streaming a non-WAV format (``mp3``) returns 501 with the matching code."""
        resp = self._generate(
            initialized_client,
            text="Hello",
            speaker="lipakshi",
            stream=True,
            format="mp3",
        )

        assert resp.status_code == 501
        error_code = resp.json().get("error", {}).get("code", "")
        assert "FORMAT_NOT_IMPLEMENTED" in error_code

