#!/usr/bin/env python3
"""
Comprehensive Spark-TTS Validation Script

Tests:
1. Large multilingual text (EN + Telugu + Hindi)
2. ASR validation with word-level matching
3. Chunking behavior analysis
4. Emotion tag testing
5. TTFB measurements
6. Concurrency testing
7. End-to-end robustness

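Usage:
    python scripts/comprehensive_validation.py

Requires the TTS server to be running at http://localhost:8000 and, for the
ASR check, OPENAI_API_KEY to be set in the environment.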
"""

import asyncio
import io
import json
import os
import string
import sys
import tempfile
import time
import wave
from concurrent.futures import ThreadPoolExecutor, as_completed
from pathlib import Path
from typing import Dict, List, Optional, Tuple

import requests

# Add project root to path
sys.path.insert(0, str(Path(__file__).parent.parent))

# OpenAI for ASR (using Whisper API)
import openai

# Constants
# Base URL of the TTS server under test; every HTTP request in this script targets it.
API_BASE_URL = "http://localhost:8000"
# NOTE(review): not referenced anywhere else in this script, and nothing here assigns
# it later; TTS requests are sent without any API key. Confirm before relying on it.
API_KEY = None
# Read once at import time; when unset, transcribe_audio_openai() logs an error and
# returns "" instead of calling the Whisper API.
OPENAI_API_KEY = os.getenv("OPENAI_API_KEY")

# Test texts
# Mixed English / Telugu / Hindi passage used as the large multilingual input.
# Its English sentences form the reference for the English-only WER check, and a
# 700-char prefix of it serves as the "long text" chunking case.
# .strip() removes the leading/trailing newlines introduced by the triple quotes.
MULTILINGUAL_TEXT_LARGE = """
Hello everyone! This is a comprehensive test of the Spark TTS system. We are testing the ability to generate natural speech from text. This system supports multiple languages and emotions.

The quick brown fox jumps over the lazy dog. The five boxing wizards jump quickly. How vexingly quick daft zebras jump. Pack my box with five dozen liquor jugs.

మీరు ఎలా ఉన్నారు? నేను చాలా బాగున్నాను. స్పార్క్ టీటీఎస్ వ్యవస్థ బహుళ భాషలను మద్దతు ఇస్తుంది. ఇది చాలా ఆశ్చర్యకరమైన సాంకేతికత.

आप कैसे हैं? मैं बहुत अच्छा हूं। स्पार्क टीटीएस प्रणाली कई भाषाओं का समर्थन करती है। यह एक अद्भुत तकनीक है।

Thank you for testing this system with us. We appreciate your patience. The future of voice technology is multilingual and expressive.

ధన్యవాదాలు! మీ సహకారానికి కృతజ్ఞతలు. భవిష్యత్తు టెక్నాలజీ చాలా రోమాంచకంగా ఉంది.

धन्यवाद! आपके सहयोग के लिए आभारी हूं। भविष्य की तकनीक बहुत रोमांचक है।

This concludes our comprehensive multilingual test. Have a wonderful day!
""".strip()

# Voice names sent as the "speaker" field of TTS requests; test_concurrency picks
# one per request by index (idx % len(SPEAKERS)). Names are passed through as-is,
# including their capitalisation.
SPEAKERS = ["lipakshi", "vardan", "reet", "Nandini", "krishna", "anika", "adarsh", "Nilay"]
# Bracketed emotion tags. NOTE(review): not referenced by any test in this file;
# the emotion/robustness tests hard-code their tags inline.
EMOTIONS = ["[excited]", "[laughs]", "[curious]", "[giggle]", "[whispers]", "[sighs]", "[angry]", "[screams]", "[sings]"]


class Colors:
    """ANSI SGR escape sequences used to colour and style terminal output."""

    # Bright foreground colours (SGR codes 91-96)
    RED = '\x1b[91m'
    GREEN = '\x1b[92m'
    YELLOW = '\x1b[93m'
    BLUE = '\x1b[94m'
    MAGENTA = '\x1b[95m'
    CYAN = '\x1b[96m'

    # Text attributes: bold on, and reset-all to undo any colour/attribute
    BOLD = '\x1b[1m'
    RESET = '\x1b[0m'


def log(msg: str, color: str = ""):
    """Print *msg* prefixed by the ANSI *color* code and followed by a reset code.

    An empty *color* prints the message uncoloured (only the reset suffix is added).
    """
    line = "{}{}{}".format(color, msg, Colors.RESET)
    print(line)


def log_success(msg: str):
    """Print *msg* in green with a check-mark prefix."""
    decorated = "✅ {}".format(msg)
    log(decorated, Colors.GREEN)


def log_error(msg: str):
    """Print *msg* in red with a cross-mark prefix."""
    decorated = "❌ {}".format(msg)
    log(decorated, Colors.RED)


def log_warning(msg: str):
    """Print *msg* in yellow with a warning-sign prefix."""
    decorated = f"⚠️  {msg}"
    log(decorated, Colors.YELLOW)


def log_info(msg: str):
    """Print *msg* in blue with an information-sign prefix."""
    decorated = f"ℹ️  {msg}"
    log(decorated, Colors.BLUE)


def log_section(title: str):
    """Print *title* in bold cyan, framed above and below by 80-char '=' rules.

    A blank line precedes the top rule and follows the bottom rule.
    """
    rule = "=" * 80
    log("\n" + rule, Colors.CYAN)
    log(title, Colors.CYAN + Colors.BOLD)
    log(rule + "\n", Colors.CYAN)


def get_audio_duration(wav_path: str) -> float:
    """Return the playback length of the WAV file at *wav_path*, in seconds.

    Computed as frame count / sample rate from the file header; the ``wave``
    module raises if the file is not a readable WAV.
    """
    with wave.open(wav_path, 'rb') as reader:
        n_frames, sample_rate = reader.getnframes(), reader.getframerate()
    return n_frames / sample_rate


def transcribe_audio_openai(wav_path: str) -> str:
    """Transcribe the audio file at *wav_path* with OpenAI's "whisper-1" model.

    Returns:
        The transcript with surrounding whitespace stripped, or "" (after
        logging an error) when OPENAI_API_KEY is unset or the API call fails.
    """
    if not OPENAI_API_KEY:
        log_error("OPENAI_API_KEY not set in environment")
        return ""

    client = openai.OpenAI(api_key=OPENAI_API_KEY)

    try:
        with open(wav_path, "rb") as handle:
            result = client.audio.transcriptions.create(
                model="whisper-1",
                file=handle,
                response_format="text",
            )
        return result.strip()
    except Exception as err:
        # Best-effort: an ASR failure is reported, not raised.
        log_error(f"ASR transcription failed: {err}")
        return ""


def calculate_wer(reference: str, hypothesis: str) -> float:
    """
    Calculate Word Error Rate (WER).

    Simple implementation without dependencies. Both strings are lower-cased
    and split on whitespace. Leading/trailing punctuation is stripped from each
    token, and tokens that are pure punctuation are dropped, so "Hello," and
    "hello" count as the same word. The WER is the word-level Levenshtein
    distance (substitutions + deletions + insertions) divided by the number of
    reference words.

    Args:
        reference: Ground-truth text.
        hypothesis: Text to score, e.g. an ASR transcript.

    Returns:
        WER as a fraction. It can exceed 1.0 when the hypothesis contains many
        insertions. Returns 0.0 when the reference has no words.
    """
    def _tokens(text: str) -> List[str]:
        # Punctuation is not a word-level error: ASR output and reference texts
        # routinely differ in commas/periods that would otherwise count as substitutions.
        words = (w.strip(string.punctuation) for w in text.lower().split())
        return [w for w in words if w]

    ref_words = _tokens(reference)
    hyp_words = _tokens(hypothesis)
    if not ref_words:
        return 0.0

    # Word-level Levenshtein distance. Only the previous DP row is kept,
    # so memory is O(len(hyp_words)) instead of a full matrix.
    prev = list(range(len(hyp_words) + 1))
    for i, ref_word in enumerate(ref_words, start=1):
        curr = [i] + [0] * len(hyp_words)
        for j, hyp_word in enumerate(hyp_words, start=1):
            if ref_word == hyp_word:
                curr[j] = prev[j - 1]
            else:
                curr[j] = 1 + min(prev[j], curr[j - 1], prev[j - 1])
        prev = curr

    return prev[-1] / len(ref_words)


def check_server_health() -> bool:
    """Probe GET {API_BASE_URL}/health and report whether the server is up.

    Returns True only for an HTTP 200 whose JSON body could be read (its
    'service' field is logged). Any non-200 status, connection error or
    body-parsing error is logged and yields False.
    """
    try:
        resp = requests.get(f"{API_BASE_URL}/health", timeout=5)
        if resp.status_code != 200:
            log_error(f"Server unhealthy: status {resp.status_code}")
            return False
        service_name = resp.json().get('service', 'unknown')
        log_success(f"Server healthy: {service_name}")
        return True
    except Exception as err:
        log_error(f"Cannot connect to server: {err}")
        return False


def generate_tts(text: str, speaker: str = "lipakshi", stream: bool = False, timeout: int = 120) -> Tuple[Optional[bytes], Dict]:
    """
    Generate TTS audio via POST {API_BASE_URL}/v1/tts/generate.

    Args:
        text: Text to synthesize (may contain inline emotion tags).
        speaker: Voice name sent as the "speaker" field.
        stream: Value of the "stream" field in the payload. The HTTP response
            is still read in full before returning.
        timeout: Request timeout in seconds, passed to ``requests.post``.

    Returns:
        (audio_bytes, headers_dict) on HTTP 200. On a non-200 status or any
        exception while making the request, the problem is logged and
        (None, {}) is returned, so callers must check audio_bytes before use.
    """
    url = f"{API_BASE_URL}/v1/tts/generate"
    headers = {"Content-Type": "application/json"}
    payload = {
        "text": text,
        "speaker": speaker,
        "stream": stream,
    }

    try:
        response = requests.post(url, json=payload, headers=headers, timeout=timeout)

        if response.status_code == 200:
            return response.content, dict(response.headers)

        log_error(f"TTS generation failed: {response.status_code}")
        # Truncate: error pages can be large.
        log_error(f"Response: {response.text[:500]}")
        return None, {}
    except Exception as e:
        # Best-effort harness: report the failure and let the caller mark the case failed.
        log_error(f"TTS request failed: {e}")
        return None, {}


def _english_words(text: str) -> str:
    """Return the ASCII-only alphabetic words of *text*, joined by single spaces.

    Leading/trailing punctuation is stripped from each whitespace-separated
    token *before* the check, so "everyone!" is kept as "everyone" instead of
    being discarded. Tokens that still contain anything other than ASCII
    letters are dropped: Telugu/Devanagari script, digits, inner apostrophes.
    """
    kept = []
    for token in text.split():
        word = token.strip(string.punctuation)
        # bytes.isalpha() is True only for non-empty runs of ASCII letters, so
        # non-Latin scripts (multi-byte UTF-8) are filtered out here.
        if word.encode().isalpha():
            kept.append(word)
    return " ".join(kept)


def test_multilingual_generation():
    """Test 1: Large multilingual text generation with ASR validation.

    Synthesizes MULTILINGUAL_TEXT_LARGE and sanity-checks the WAV duration
    against a words-per-second estimate. It then transcribes the audio with
    OpenAI Whisper and computes WER over the English words only.

    Returns:
        True if generation succeeded, the duration is within 0.3x-3x of the
        estimate, a non-empty transcript came back and English WER < 30%;
        otherwise False.
    """
    log_section("TEST 1: Multilingual Generation + ASR Validation")

    log_info(f"Text length: {len(MULTILINGUAL_TEXT_LARGE)} chars")
    log_info("Languages: English, Telugu, Hindi")

    # Generate audio
    log_info("Generating audio...")
    t_start = time.time()
    audio_bytes, _ = generate_tts(MULTILINGUAL_TEXT_LARGE, speaker="lipakshi", stream=False, timeout=180)
    t_end = time.time()

    if not audio_bytes:
        log_error("Failed to generate audio")
        return False

    generation_time = t_end - t_start
    log_success(f"Generated {len(audio_bytes)} bytes in {generation_time:.2f}s")

    # Save audio in the platform temp dir rather than a hard-coded /tmp, so the
    # script also works where /tmp does not exist. The file is left in place
    # for manual inspection.
    temp_path = os.path.join(tempfile.gettempdir(), "multilingual_test.wav")
    with open(temp_path, "wb") as f:
        f.write(audio_bytes)

    # Check duration (the wave module raises if the payload is not a valid WAV)
    duration = get_audio_duration(temp_path)
    log_info(f"Audio duration: {duration:.2f}s")

    # Rough sanity check: assume ~2.5 spoken words per second and accept
    # anything between 0.3x and 3x of that estimate.
    word_count = len(MULTILINGUAL_TEXT_LARGE.split())
    expected_duration = word_count / 2.5

    if duration > expected_duration * 3:
        log_error(f"Audio duration suspiciously long! {duration:.2f}s vs expected ~{expected_duration:.2f}s")
        return False
    elif duration < expected_duration * 0.3:
        log_error(f"Audio duration suspiciously short! {duration:.2f}s vs expected ~{expected_duration:.2f}s")
        return False
    else:
        log_success(f"Audio duration reasonable ({duration:.2f}s vs expected ~{expected_duration:.2f}s)")

    # ASR validation
    log_info("Transcribing audio with OpenAI Whisper...")
    transcript = transcribe_audio_openai(temp_path)

    if not transcript:
        log_warning("ASR transcription returned empty")
        return False

    log_info(f"Transcript ({len(transcript)} chars): {transcript[:200]}...")

    # Calculate WER over English words only; the Telugu/Hindi passages are
    # filtered out of both sides. NOTE(review): if Whisper romanizes or
    # translates those passages, the extra Latin-script words still count as
    # insertions against the English reference.
    english_words_original = _english_words(MULTILINGUAL_TEXT_LARGE)
    english_words_transcript = _english_words(transcript)

    wer = calculate_wer(english_words_original, english_words_transcript)
    log_info(f"Word Error Rate (English): {wer*100:.1f}%")

    if wer < 0.3:
        log_success(f"WER acceptable: {wer*100:.1f}% < 30%")
        return True
    else:
        log_warning(f"WER high: {wer*100:.1f}% >= 30%")
        return False


def _wav_duration_seconds(audio_bytes: bytes) -> float:
    """Return the playback length of in-memory WAV data, in seconds.

    Uses the frame count and sample rate from the WAV header. If the bytes
    cannot be parsed by the ``wave`` module, or the header reports no frames,
    falls back to a rough estimate that assumes 16-bit mono 16 kHz PCM.
    """
    try:
        with wave.open(io.BytesIO(audio_bytes), "rb") as wf:
            n_frames, rate = wf.getnframes(), wf.getframerate()
        if n_frames > 0 and rate > 0:
            return n_frames / rate
    except (wave.Error, EOFError):
        pass  # not a parseable WAV; use the byte-length estimate below
    return len(audio_bytes) / (2 * 16000)


def test_chunking_behavior():
    """Test 2: Exercise server-side chunking with short, medium and long inputs.

    Each case is generated non-streaming and passes if audio comes back. Only
    success and timing are checked here; how the server split each request
    is not inspected by this client.

    Returns:
        True if every case produced audio, else False.
    """
    log_section("TEST 2: Chunking Behavior Analysis")

    # (label, text) pairs. NOTE(review): the thresholds in the labels
    # (~500 / >600 chars) presumably mirror the server's chunking limit; confirm.
    test_cases = [
        ("Short text (no chunking)", "Hello world!"),
        ("Medium text (~500 chars)", "Hello! " * 70),  # 490 chars
        ("Long text (>600 chars)", MULTILINGUAL_TEXT_LARGE[:700]),
    ]

    results = []

    for name, text in test_cases:
        log_info(f"\nTesting: {name} ({len(text)} chars)")

        t_start = time.time()
        audio_bytes, _ = generate_tts(text, speaker="lipakshi", stream=False, timeout=120)
        elapsed = time.time() - t_start

        if audio_bytes:
            duration = _wav_duration_seconds(audio_bytes)
            log_success(f"Generated in {elapsed:.2f}s, audio ~{duration:.2f}s")
            results.append((name, True, elapsed))
        else:
            log_error("Generation failed")
            results.append((name, False, 0))

    # Summary (the reported time is generation wall-clock, not audio length)
    log_info("\nChunking Test Summary:")
    for name, success, elapsed in results:
        status = "✅ PASS" if success else "❌ FAIL"
        log(f"  {status} - {name}: {elapsed:.2f}s", Colors.GREEN if success else Colors.RED)

    return all(success for _, success, _ in results)


def test_emotions():
    """Test 3: Send emotion-tagged prompts, each with a different speaker.

    A case succeeds when non-empty audio comes back.

    Returns:
        True when at least 75% of the cases succeed.
    """
    log_section("TEST 3: Emotion Tag Testing")

    cases = [
        ("[excited] Hello everyone!", "lipakshi"),
        ("[laughs] That's so funny!", "vardan"),
        ("[whispers] This is a secret.", "reet"),
        ("[angry] I can't believe this!", "krishna"),
    ]

    outcomes = []
    for prompt, voice in cases:
        log_info(f"\nTesting: {prompt[:40]}... with speaker={voice}")

        audio, _ = generate_tts(prompt, speaker=voice, stream=False, timeout=60)
        ok = bool(audio)
        if ok:
            log_success(f"Generated {len(audio)} bytes")
        else:
            log_error("Generation failed")
        outcomes.append(ok)

    success_rate = sum(outcomes) / len(outcomes) * 100
    log_info(f"\nEmotion Test Success Rate: {success_rate:.0f}%")

    return success_rate >= 75


def test_ttfb():
    """Test 4: Measure Time To First Byte.

    Non-streaming: times a full generate_tts() call. The reported "TTFB" is
    therefore end-to-end latency, an upper bound on TTFB. This value decides
    the result.

    Streaming: times from request start until the first non-empty body chunk
    of a 200 response. This is informational only; problems are logged as
    warnings and do not change the return value.

    Returns:
        False if non-streaming generation fails or takes >= 5 s, else True.
    """
    log_section("TEST 4: TTFB Measurement")

    test_text = "Hello! This is a test for measuring time to first byte."

    # Non-streaming TTFB
    log_info("Testing non-streaming TTFB...")
    t_start = time.perf_counter()
    audio_bytes, _ = generate_tts(test_text, speaker="lipakshi", stream=False, timeout=60)
    ttfb_non_stream = time.perf_counter() - t_start

    if audio_bytes:
        log_success(f"Non-streaming TTFB: {ttfb_non_stream*1000:.0f}ms")
    else:
        log_error("Non-streaming generation failed")
        return False

    # Streaming TTFB (if supported)
    log_info("Testing streaming TTFB...")
    url = f"{API_BASE_URL}/v1/tts/generate"
    payload = {"text": test_text, "speaker": "lipakshi", "stream": True}
    try:
        t_start = time.perf_counter()
        # With stream=True the connection stays checked out until the body is
        # fully read or the response is closed. We stop after the first chunk,
        # so the context manager is what releases the connection.
        with requests.post(url, json=payload, stream=True, timeout=60) as response:
            if response.status_code != 200:
                # Don't time an error body as if it were audio.
                log_warning(f"Streaming request failed: status {response.status_code}")
            else:
                first_chunk = next(
                    (chunk for chunk in response.iter_content(chunk_size=1024) if chunk),
                    None,
                )
                ttfb_stream = time.perf_counter() - t_start

                if first_chunk:
                    log_success(f"Streaming TTFB: {ttfb_stream*1000:.0f}ms")
                else:
                    log_warning("Streaming returned no data")
    except Exception as e:
        # Best-effort probe: streaming may be unsupported by the server.
        log_warning(f"Streaming test failed: {e}")

    # Check if TTFB meets targets
    if ttfb_non_stream < 5.0:
        log_success("Non-streaming TTFB meets target (<5s)")
        return True
    log_warning(f"Non-streaming TTFB above target: {ttfb_non_stream:.2f}s")
    return False


def test_concurrency():
    """Test 5: Concurrent request handling.

    For each concurrency level N in [1, 3, 5], fires N simultaneous TTS
    requests (one thread each, speakers rotated through SPEAKERS). It logs the
    success count and the average/max wall-clock duration.

    Returns:
        True only if every request at every level returned non-empty audio.
    """
    log_section("TEST 5: Concurrency Testing")

    test_text = "Hello! This is a concurrent test."
    concurrency_levels = [1, 3, 5]

    def make_request(idx: int) -> Tuple[int, bool, float]:
        """Issue one TTS request; return (idx, succeeded, wall-clock seconds)."""
        t_start = time.time()
        audio_bytes, _ = generate_tts(
            f"{test_text} Request {idx}",
            speaker=SPEAKERS[idx % len(SPEAKERS)],
            stream=False,
            timeout=120,
        )
        # Same success criterion as every other test: non-empty audio. A 200
        # with an empty body is a failure, not a success.
        return idx, bool(audio_bytes), time.time() - t_start

    results = {}

    for n_concurrent in concurrency_levels:
        log_info(f"\nTesting with {n_concurrent} concurrent requests...")

        # map() submits all n requests up front, so they run concurrently.
        with ThreadPoolExecutor(max_workers=n_concurrent) as executor:
            request_results = list(executor.map(make_request, range(n_concurrent)))

        # Analyze results
        success_count = sum(1 for _, success, _ in request_results if success)
        durations = [dur for _, _, dur in request_results]
        avg_duration = sum(durations) / len(durations)
        max_duration = max(durations)

        log_info(f"  Success: {success_count}/{n_concurrent}")
        log_info(f"  Avg duration: {avg_duration:.2f}s")
        log_info(f"  Max duration: {max_duration:.2f}s")

        results[n_concurrent] = {
            "success_rate": success_count / n_concurrent,
            "avg_duration": avg_duration,
            "max_duration": max_duration,
        }

    # Check if all requests succeeded
    all_success = all(r["success_rate"] == 1.0 for r in results.values())

    if all_success:
        log_success("All concurrent requests succeeded")
        return True
    log_warning("Some concurrent requests failed")
    return False


def test_robustness():
    """Test 6: Feed edge-case inputs (odd emotion tags, tiny text, symbols).

    A case passes when non-empty audio comes back; exceptions count as
    failures.

    Returns:
        True when at least 60% of cases pass, since some edge cases may
        legitimately be rejected.
    """
    log_section("TEST 6: Robustness Testing")

    edge_cases = [
        ("Empty emotion tag", "[] Hello world", "lipakshi"),
        ("Multiple emotions", "[excited] [laughs] Hello!", "vardan"),
        ("Very short text", "Hi", "reet"),
        ("Numbers and punctuation", "1234567890!@#$%", "krishna"),
        ("Mixed case emotions", "[EXCITED] hello [Laughs] there", "anika"),
    ]

    outcomes = []
    for label, prompt, voice in edge_cases:
        log_info(f"\nTesting: {label}")
        log_info(f"  Text: {prompt[:60]}")

        try:
            audio, _ = generate_tts(prompt, speaker=voice, stream=False, timeout=60)
            if not audio:
                log_warning("Failed to generate")
                outcomes.append(False)
                continue
            log_success(f"Handled gracefully: {len(audio)} bytes")
            outcomes.append(True)
        except Exception as exc:
            log_error(f"Exception: {exc}")
            outcomes.append(False)

    success_rate = sum(outcomes) / len(outcomes) * 100
    log_info(f"\nRobustness Test Success Rate: {success_rate:.0f}%")

    return success_rate >= 60  # Allow some failures for edge cases


def main():
    """Run the server health check and every validation test, then print a summary.

    Each test function is called in turn. An exception from a test is logged
    with its traceback and recorded as a failure instead of aborting the run.
    A 2 s pause follows each test that returns normally.

    Returns:
        Process exit code. 1 if the health check fails or fewer than 70% of
        tests pass; otherwise 0 (a partial pass of >= 70% still exits 0).
    """
    log_section("🚀 Comprehensive Spark-TTS Validation Suite")

    if not check_server_health():
        log_error("Server is not healthy. Please start the server first.")
        return 1

    suite = [
        ("Multilingual Generation + ASR", test_multilingual_generation),
        ("Chunking Behavior", test_chunking_behavior),
        ("Emotion Tags", test_emotions),
        ("TTFB Measurement", test_ttfb),
        ("Concurrency", test_concurrency),
        ("Robustness", test_robustness),
    ]

    outcomes = []
    for label, run_test in suite:
        try:
            log_info(f"\nStarting: {label}")
            outcome = run_test()
            outcomes.append((label, outcome))

            if outcome:
                log_success(f"{label}: PASSED")
            else:
                log_warning(f"{label}: FAILED")

            # Small pause between tests
            time.sleep(2)
        except Exception as exc:
            log_error(f"{label}: EXCEPTION - {exc}")
            import traceback
            traceback.print_exc()
            outcomes.append((label, False))

    log_section("📊 VALIDATION SUMMARY")

    total = len(outcomes)
    passed = len([1 for _, outcome in outcomes if outcome])

    for label, outcome in outcomes:
        badge = "✅ PASS" if outcome else "❌ FAIL"
        log(f"  {badge} - {label}", Colors.GREEN if outcome else Colors.RED)

    log_info(f"\nOverall: {passed}/{total} tests passed ({passed/total*100:.0f}%)")

    if passed == total:
        log_success("🎉 ALL TESTS PASSED!")
        return 0
    if passed >= total * 0.7:
        log_warning("⚠️  MOST TESTS PASSED")
        return 0
    log_error("❌ VALIDATION FAILED")
    return 1


if __name__ == "__main__":
    # main() returns the process exit code; SystemExit propagates it.
    raise SystemExit(main())

