#!/usr/bin/env python3
"""
Quick validation script for Modal TTS endpoint.

Usage:
    python scripts/validate_modal_endpoint.py
    
    # Or with custom URL
    MODAL_ENDPOINT_URL=https://your-endpoint.modal.run python scripts/validate_modal_endpoint.py
"""

import asyncio
import json
import os
import sys
import tempfile
import time
from typing import Optional

import httpx

# Default endpoint; override with the MODAL_ENDPOINT_URL environment variable.
_DEFAULT_ENDPOINT_URL = "https://mayaresearch--veena3-tts-ttsservice-serve.modal.run"

# An empty or whitespace-only override falls back to the default, and any
# trailing slash is removed so f"{ENDPOINT_URL}/v1/..." never yields "//v1".
ENDPOINT_URL = (
    os.environ.get("MODAL_ENDPOINT_URL", "").strip() or _DEFAULT_ENDPOINT_URL
).rstrip("/")

# Speaker names for generation tests.
# NOTE(review): main() currently exercises only the default speaker.
TEST_SPEAKERS = ["Aarvi", "lipakshi", "vardan"]


def print_header(title: str):
    """Print *title* between two 60-character '=' rules, preceded by a blank line."""
    rule = "=" * 60
    print(f"\n{rule}\n {title}\n{rule}")


def test_health():
    """Query the health endpoint and report whether the model is loaded.

    Sends ``GET {ENDPOINT_URL}/v1/tts/health`` and prints the status code,
    round-trip time and response body.

    Returns:
        True only when the endpoint answers HTTP 200 with a JSON object whose
        ``model_loaded`` field is truthy; False otherwise (including when the
        body is not JSON).

    Network errors raised by httpx (e.g. timeouts) propagate to the caller.
    """
    print_header("Health Check")

    url = f"{ENDPOINT_URL}/v1/tts/health"
    print(f"URL: {url}")

    start = time.perf_counter()
    response = httpx.get(url, timeout=60.0)
    duration = time.perf_counter() - start

    print(f"Status: {response.status_code}")
    print(f"Duration: {duration:.2f}s")

    # Parse the body once, and tolerate non-JSON bodies (e.g. an HTML error
    # page from a proxy) instead of letting JSONDecodeError hide the status.
    try:
        data = response.json()
    except ValueError:
        print(f"Response (non-JSON): {response.text}")
        return False
    print(f"Response: {json.dumps(data, indent=2)}")

    if response.status_code != 200:
        return False

    if isinstance(data, dict) and data.get("model_loaded"):
        print("✅ Model loaded and ready!")
        return True

    print("⚠️ Model not loaded yet")
    return False


def test_metrics():
    """Check that the metrics endpoint responds.

    Sends ``GET {ENDPOINT_URL}/v1/tts/metrics``. On HTTP 200, prints the
    first 10 lines of the text body as a sample.

    Returns:
        True on HTTP 200; False otherwise, after printing the error body.

    Network errors raised by httpx propagate to the caller.
    """
    print_header("Metrics Check")

    url = f"{ENDPOINT_URL}/v1/tts/metrics"
    print(f"URL: {url}")

    response = httpx.get(url, timeout=30.0)
    print(f"Status: {response.status_code}")

    if response.status_code != 200:
        # Surface the server's reason, consistent with the generation tests.
        print(f"Error: {response.text}")
        return False

    # splitlines() also handles "\r\n" and does not emit a trailing empty line.
    print("Sample metrics:")
    for line in response.text.splitlines()[:10]:
        print(f"  {line}")
    print("✅ Metrics endpoint working!")
    return True


def test_generate_validation():
    """Check that the generate endpoint rejects invalid requests.

    Posts an empty ``text`` and then an unknown ``speaker`` to
    ``{ENDPOINT_URL}/v1/tts/generate`` and expects HTTP 400 for each.

    Returns:
        True when both requests are rejected with HTTP 400.

    Raises:
        AssertionError: when a request is not rejected with 400. This is
            raised explicitly rather than with ``assert``, so the check still
            runs under ``python -O``.
    """
    print_header("Request Validation")

    url = f"{ENDPOINT_URL}/v1/tts/generate"
    cases = (
        (
            "Empty text",
            {"text": "", "speaker": "Aarvi"},
            "Empty text should be rejected",
        ),
        (
            "Invalid speaker",
            {"text": "Hello", "speaker": "invalid_speaker_xyz"},
            "Invalid speaker should be rejected",
        ),
    )

    for label, payload, failure_message in cases:
        response = httpx.post(url, json=payload, timeout=30.0)
        print(f"{label}: {response.status_code} (expected 400)")
        if response.status_code != 400:
            raise AssertionError(failure_message)

    print("✅ Validation working!")
    return True


def test_generate_audio(speaker: str = "Aarvi"):
    """Generate a short WAV clip in one (non-streaming) request.

    Posts a fixed sentence to ``{ENDPOINT_URL}/v1/tts/generate`` with
    ``stream=False`` and ``format="wav"``. It then prints the response headers
    ``x-ttfb-ms``, ``x-rtf`` and ``x-audio-seconds`` (``N/A`` when absent).
    The body is saved as ``test_output_<speaker>.wav`` in the platform's
    temporary directory.

    Args:
        speaker: Speaker name sent in the request.

    Returns:
        True on HTTP 200 with a non-empty body; False otherwise.

    Network errors raised by httpx propagate to the caller.
    """
    print_header(f"Audio Generation (speaker: {speaker})")

    url = f"{ENDPOINT_URL}/v1/tts/generate"
    text = "Hello, this is a test of the text to speech system."

    print(f"Text: {text}")
    print(f"Speaker: {speaker}")

    start = time.perf_counter()
    response = httpx.post(
        url,
        json={
            "text": text,
            "speaker": speaker,
            "stream": False,
            "format": "wav",
        },
        timeout=120.0,
    )
    duration = time.perf_counter() - start

    print(f"Status: {response.status_code}")
    print(f"Duration: {duration:.2f}s")

    if response.status_code != 200:
        print(f"Error: {response.text}")
        return False

    # httpx headers are already case-insensitive; no need to copy into a dict.
    headers = response.headers
    audio = response.content
    print(f"Content-Type: {headers.get('content-type')}")
    print(f"Audio bytes: {len(audio)}")
    print(f"TTFB: {headers.get('x-ttfb-ms', 'N/A')}ms")
    print(f"RTF: {headers.get('x-rtf', 'N/A')}")
    print(f"Audio duration: {headers.get('x-audio-seconds', 'N/A')}s")

    if not audio:
        # A 200 with an empty body means no audio was produced.
        print("❌ Empty audio response")
        return False

    # Use the platform temp dir instead of a hard-coded /tmp (absent on Windows).
    filename = os.path.join(tempfile.gettempdir(), f"test_output_{speaker}.wav")
    with open(filename, "wb") as f:
        f.write(audio)
    print(f"Saved to: {filename}")
    print("✅ Audio generation working!")
    return True


def test_generate_streaming(speaker: str = "Aarvi"):
    """Stream audio from the generate endpoint and measure time to first byte.

    Posts a fixed sentence to ``{ENDPOINT_URL}/v1/tts/generate`` with
    ``stream=True`` and ``format="wav"``. It reads the body as it arrives,
    printing the time to the first non-empty chunk and the chunk/byte totals.
    The assembled audio is saved as ``test_streaming_<speaker>.wav`` in the
    platform's temporary directory.

    Args:
        speaker: Speaker name sent in the request.

    Returns:
        True on HTTP 200 with at least one byte of audio; False otherwise.

    Network errors raised by httpx propagate to the caller.
    """
    print_header(f"Streaming Generation (speaker: {speaker})")

    url = f"{ENDPOINT_URL}/v1/tts/generate"
    text = "This is a test of the streaming text to speech functionality. It should send audio chunks progressively."

    print(f"Text: {text[:50]}...")
    print(f"Speaker: {speaker}")

    start = time.perf_counter()
    first_chunk_time = None
    total_bytes = 0
    chunks = []  # joined once at the end; bytes += in a loop is quadratic

    with httpx.stream(
        "POST",
        url,
        json={
            "text": text,
            "speaker": speaker,
            "stream": True,
            "format": "wav",
        },
        timeout=120.0,
    ) as response:
        print(f"Status: {response.status_code}")

        if response.status_code != 200:
            # A streamed body must be read before .text is available;
            # otherwise httpx raises ResponseNotRead and hides the server error.
            response.read()
            print(f"Error: {response.text}")
            return False

        # No chunk_size: yield data as it arrives. A fixed chunk_size makes
        # httpx buffer that many bytes first, which would inflate the TTFB.
        for chunk in response.iter_bytes():
            if not chunk:
                continue
            if first_chunk_time is None:
                first_chunk_time = time.perf_counter()
                ttfb = (first_chunk_time - start) * 1000
                print(f"First chunk received! TTFB: {ttfb:.0f}ms")
            chunks.append(chunk)
            total_bytes += len(chunk)

    total_time = time.perf_counter() - start

    print(f"Total chunks: {len(chunks)}")
    print(f"Total bytes: {total_bytes}")
    print(f"Total time: {total_time:.2f}s")

    if not chunks:
        print("❌ No audio data received")
        return False

    # Use the platform temp dir instead of a hard-coded /tmp (absent on Windows).
    filename = os.path.join(tempfile.gettempdir(), f"test_streaming_{speaker}.wav")
    with open(filename, "wb") as f:
        f.write(b"".join(chunks))
    print(f"Saved to: {filename}")
    print("✅ Streaming generation working!")
    return True


def test_websocket():
    """Exercise the WebSocket endpoint end to end.

    Connects to ``{ENDPOINT_URL}/v1/tts/ws``, mapping https->wss and
    http->ws, and sends one JSON request. It then reads messages until a JSON
    message with ``"event": "complete"`` arrives. Binary messages are
    collected as audio and JSON ``progress`` events are printed.

    Returns:
        None when the optional ``websockets`` package is not installed (skip).
        True when a ``complete`` event is received. False on a connection,
        receive or decode error, or when the stream ends without ``complete``.
    """
    print_header("WebSocket Test")

    try:
        import websockets.sync.client as ws_client
    except ImportError:
        print("⚠️ websockets not installed, skipping WebSocket test")
        return None

    ws_url = ENDPOINT_URL.replace("https://", "wss://").replace("http://", "ws://")
    ws_url = f"{ws_url}/v1/tts/ws"

    print(f"URL: {ws_url}")

    chunks = []  # binary audio messages; joined/summed once at the end
    completed = False

    try:
        with ws_client.connect(ws_url, close_timeout=30) as ws:
            ws.send(json.dumps({
                "text": "Hello via WebSocket!",
                "speaker": "Aarvi",
            }))

            # The first message is expected to be a JSON header. If the
            # server sends audio first instead, keep it rather than drop it.
            msg = ws.recv(timeout=60.0)
            if isinstance(msg, str):
                data = json.loads(msg)
                print(f"Header received: {data.get('event')}")
            else:
                chunks.append(msg)

            while True:
                try:
                    msg = ws.recv(timeout=60.0)
                except Exception as e:
                    # Timeout or close before "complete": report it, and let
                    # the completed flag below mark the test as failed.
                    print(f"Error receiving: {e}")
                    break

                if isinstance(msg, bytes):
                    chunks.append(msg)
                    continue

                data = json.loads(msg)
                event = data.get("event")
                if event == "complete":
                    print(f"Complete: {data.get('metrics', {})}")
                    completed = True
                    break
                if event == "progress":
                    print(f"Progress: chunks={data.get('chunks_sent')}")
    except Exception as e:
        print(f"❌ WebSocket error: {e}")
        return False

    print(f"Total chunks: {len(chunks)}")
    print(f"Total bytes: {sum(len(c) for c in chunks)}")

    if not completed:
        print("❌ WebSocket stream ended without a 'complete' event")
        return False

    print("✅ WebSocket working!")
    return True


def main():
    """Run every validation check, print a summary, and exit the process.

    The health, metrics and request-validation checks always run. The
    generation checks (non-streaming, streaming, WebSocket) run only when the
    health check returned a truthy result; otherwise they are recorded as
    skipped. Each result is True (pass), False (fail) or None (skip). An
    exception escaping a check is printed and recorded as a failure.

    Exits with status 1 if any check failed, else 0.
    """

    def run(label, check):
        # Execute one check, converting an escaped exception into a failure.
        try:
            return check()
        except Exception as exc:
            print(f"❌ {label} failed: {exc}")
            return False

    # (result key, label used in the failure message, check function)
    baseline_checks = (
        ("health", "Health check", test_health),
        ("metrics", "Metrics check", test_metrics),
        ("validation", "Validation check", test_generate_validation),
    )
    generation_checks = (
        ("generate", "Generation", test_generate_audio),
        ("streaming", "Streaming", test_generate_streaming),
        ("websocket", "WebSocket", test_websocket),
    )

    print(f"\n🔍 Validating Modal Endpoint: {ENDPOINT_URL}\n")

    outcomes = {}
    for key, label, check in baseline_checks:
        outcomes[key] = run(label, check)

    # Generation only makes sense once the health check reports the model loaded.
    if not outcomes.get("health"):
        print("\n⚠️ Skipping generation tests - model not loaded")
        for key, _label, _check in generation_checks:
            outcomes[key] = None
    else:
        for key, label, check in generation_checks:
            outcomes[key] = run(label, check)

    print_header("Summary")
    for name, outcome in outcomes.items():
        if outcome is True:
            verdict = "✅ PASS"
        elif outcome is False:
            verdict = "❌ FAIL"
        else:
            verdict = "⏸️ SKIP"
        print(f"  {name}: {verdict}")

    failed = [name for name, outcome in outcomes.items() if outcome is False]
    if not failed:
        print("\n✅ All tests passed!")
        sys.exit(0)
    print(f"\n❌ {len(failed)} test(s) failed")
    sys.exit(1)


# Run the validation checks only when executed as a script, not on import.
if __name__ == "__main__":
    main()

