#!/usr/bin/env python3
"""
Comprehensive Test Script for All New Features

Tests:
1. Friendly speaker name mappings
2. BiCodec streaming with TTFB measurement
3. ASR validation for streaming
4. Complete application validation
"""

import os
import sys
import time
import requests
import wave
from pathlib import Path

# Put the parent of this script's directory at the front of sys.path so
# modules located there can be imported.
sys.path.insert(0, str(Path(__file__).parent.parent))

# Import OpenAI for ASR (Whisper transcription in the ASR validation test).
import openai

# Base URL of the TTS server under test; every request in this script uses it.
API_BASE_URL = "http://localhost:8000"
# Optional: when unset, the ASR validation test is skipped (reported as pass).
OPENAI_API_KEY = os.getenv("OPENAI_API_KEY")


def log_section(title):
    """Print *title* as a section banner framed by two 80-char '=' rules.

    A blank line precedes the banner and another follows it.
    """
    rule = "=" * 80
    # One print call emitting exactly what three separate prints would.
    print(f"\n{rule}\n  {title}\n{rule}\n")


def test_friendly_speaker_names():
    """Test 1: Friendly speaker name mappings.

    Sends one non-streaming generation request per friendly speaker name and
    counts a case as passing when the server answers HTTP 200. The expected
    internal speaker id is only printed for reference; this test does not
    check it against the response.

    A request that raises (timeout, connection error) is counted as a failed
    case instead of aborting the remaining cases.

    Returns:
        True when at least 75% of the cases pass.
    """
    log_section("TEST 1: Friendly Speaker Name Mappings")

    test_cases = [
        ("Mitra", "lipakshi", "Hello! I am Mitra."),
        ("Aaranya", "reet", "Hi, my name is Aaranya."),
        ("Dhruva", "vardan", "Greetings from Dhruva."),
        ("Aria", "krishna", "This is Aria speaking."),
    ]

    results = []

    for friendly_name, expected_internal, text in test_cases:
        print(f"Testing: {friendly_name} → {expected_internal}")

        try:
            response = requests.post(
                f"{API_BASE_URL}/v1/tts/generate",
                json={
                    "text": text,
                    "speaker": friendly_name,
                    "stream": False,
                },
                timeout=30,
            )
        except requests.RequestException as e:
            # One slow/unreachable request should not discard the other cases.
            print(f"  ❌ Failed! Request error: {e}")
            results.append(False)
            continue

        if response.status_code == 200:
            audio_size = len(response.content)
            print(f"  ✅ Success! Audio size: {audio_size} bytes")
            results.append(True)
        else:
            print(f"  ❌ Failed! Status: {response.status_code}")
            print(f"     Error: {response.text[:200]}")
            results.append(False)

    success_rate = sum(results) / len(results) * 100
    print(f"\n📊 Friendly Names Test: {sum(results)}/{len(results)} passed ({success_rate:.0f}%)")
    return success_rate >= 75


def test_streaming_with_ttfb():
    """Test 2: BiCodec streaming with TTFB measurement.

    Sends a streaming generation request for speaker "Mitra" and measures:

    * the time to the first non-empty body chunk (TTFB),
    * the total transfer time, the chunk count and the byte count,
    * the audio duration read from the WAV header of the received bytes.

    The received bytes are saved to ``/tmp/streaming_test.wav``; the ASR
    validation test reads that file. Any file from an earlier run is deleted
    first, so a failure here never leaves stale audio for that test.

    Returns:
        True when the server answered HTTP 200, the first chunk arrived in
        under 5 s, some bytes were received and the audio is longer than
        0.5 s. False otherwise, including a non-200 status, an empty body or
        an unparseable WAV.
    """
    log_section("TEST 2: BiCodec Streaming with TTFB")

    test_text = "Hello! This is a test of the streaming functionality with BiCodec."
    output_path = Path('/tmp/streaming_test.wav')

    print(f"Text: {test_text}")
    print("Speaker: Mitra (friendly name)")
    print()

    # Don't let a previous run's audio masquerade as this run's output.
    output_path.unlink(missing_ok=True)

    # perf_counter is monotonic and high-resolution: the right clock for TTFB.
    t_start = time.perf_counter()

    # Using the response as a context manager releases the streamed
    # connection back to the pool even if iteration raises.
    with requests.post(
        f"{API_BASE_URL}/v1/tts/generate",
        json={
            "text": test_text,
            "speaker": "Mitra",
            "stream": True,
        },
        stream=True,  # Important for streaming
        timeout=60,
    ) as response:
        if response.status_code != 200:
            # Bail out before treating an error body as WAV audio.
            print(f"❌ Streaming request failed! Status: {response.status_code}")
            print(f"   Error: {response.text[:200]}")
            print("\n❌ Streaming test FAILED")
            return False

        first_chunk_time = None
        chunks = []
        for chunk in response.iter_content(chunk_size=1024):
            if not chunk:
                continue  # skip keep-alive / empty chunks
            if first_chunk_time is None:
                first_chunk_time = time.perf_counter() - t_start
                print(f"🚀 TTFB: {first_chunk_time*1000:.0f}ms")
            chunks.append(chunk)

        headers = dict(response.headers)

    total_time = time.perf_counter() - t_start
    # Join once instead of `bytes += chunk`, which is quadratic.
    audio_data = b''.join(chunks)
    total_bytes = len(audio_data)

    print("\n📊 Streaming Results:")
    if first_chunk_time is not None:
        print(f"  - TTFB: {first_chunk_time*1000:.0f}ms")
    else:
        print("  - TTFB: n/a (no data received)")
    print(f"  - Total time: {total_time:.2f}s")
    print(f"  - Chunks received: {len(chunks)}")
    print(f"  - Total bytes: {total_bytes}")
    print(f"  - Headers: {headers}")

    # Save audio
    output_path.write_bytes(audio_data)

    # Check audio duration; an empty/garbled body must fail the test rather
    # than crash it.
    duration = 0.0
    try:
        # wave.open only treats `str` as a filename, so convert the Path.
        with wave.open(str(output_path), 'rb') as wf:
            duration = wf.getnframes() / wf.getframerate()
        print(f"  - Audio duration: {duration:.2f}s")
    except (wave.Error, EOFError, ZeroDivisionError) as e:
        print(f"  - Audio duration: unreadable WAV ({e})")

    # Success criteria (HTTP 200 is already guaranteed above)
    success = (
        first_chunk_time is not None
        and first_chunk_time < 5.0  # TTFB < 5s
        and total_bytes > 0
        and duration > 0.5
    )

    if success:
        print("\n✅ Streaming test PASSED")
    else:
        print("\n❌ Streaming test FAILED")

    return success


def test_asr_validation():
    """Test 3: ASR validation for streaming output.

    Transcribes the audio saved by the streaming test with OpenAI Whisper and
    checks how many key words of the spoken prompt appear in the transcript.

    Returns:
        True if OPENAI_API_KEY is unset (the check is skipped) or if at
        least 2 of the 3 key words are found in the transcript. False if the
        audio file is missing, too few key words match, or transcription
        raises.
    """
    log_section("TEST 3: ASR Validation for Streaming")

    if not OPENAI_API_KEY:
        print("⚠️  OPENAI_API_KEY not set, skipping ASR validation")
        return True

    client = openai.OpenAI(api_key=OPENAI_API_KEY)

    # Audio written by the streaming test.
    audio_path = '/tmp/streaming_test.wav'
    if not os.path.exists(audio_path):
        print("❌ Streaming test audio not found")
        return False

    print("Transcribing streaming output with OpenAI Whisper...")

    try:
        with open(audio_path, 'rb') as audio_file:
            transcript = client.audio.transcriptions.create(
                model='whisper-1',
                file=audio_file,
                response_format='text',
            )

        reference = "Hello! This is a test of the streaming functionality with BiCodec."

        print(f"\nExpected: {reference}")
        print(f"Got:      {transcript.strip()}")

        # Case-insensitive substring search for each key word.
        lowered = transcript.lower()
        key_words = ['hello', 'test', 'streaming']
        hits = len([word for word in key_words if word in lowered])

        print(f"\nKey words matched: {hits}/{len(key_words)}")

        passed = hits >= 2
        print("✅ ASR validation PASSED" if passed else "❌ ASR validation FAILED")
        return passed

    except Exception as e:
        print(f"❌ ASR error: {e}")
        return False


def test_complete_validation():
    """Test 4: Complete application validation.

    Runs four quick end-to-end checks: health endpoint, non-streaming
    generation, emotion tags and long-text chunking. A check that raises is
    reported and counted as a failure without stopping the remaining checks.

    Returns:
        True when at least 75% of the checks pass.
    """
    log_section("TEST 4: Complete Application Validation")

    def check_health():
        # Bounded timeout: without one, a stalled server hangs the suite.
        return requests.get(f"{API_BASE_URL}/health", timeout=5).status_code == 200

    tests = [
        ("Health endpoint", check_health),
        ("Non-streaming generation", test_non_streaming),
        ("Emotion tags", test_emotion),
        ("Long text chunking", test_chunking),
    ]

    results = []

    for test_name, test_func in tests:
        print(f"Testing: {test_name}")
        try:
            result = test_func()
            results.append((test_name, result))
            status = "✅ PASS" if result else "❌ FAIL"
            print(f"  {status}")
        except Exception as e:
            print(f"  ❌ ERROR: {e}")
            results.append((test_name, False))

    passed = sum(1 for _, result in results if result)
    total = len(results)

    print(f"\n📊 Complete Validation: {passed}/{total} tests passed")

    return passed >= total * 0.75


def test_non_streaming():
    """Quick non-streaming test.

    Returns True when a short non-streaming request for speaker "Taru"
    answers HTTP 200 with a body larger than 1000 bytes.
    """
    payload = {"text": "Quick test", "speaker": "Taru", "stream": False}
    resp = requests.post(f"{API_BASE_URL}/v1/tts/generate", json=payload, timeout=10)
    if resp.status_code != 200:
        return False
    # A tiny body would mean no real audio came back.
    return len(resp.content) > 1000


def test_emotion():
    """Quick emotion test: text with an ``[excited]`` tag for speaker "Ira" must get HTTP 200."""
    endpoint = f"{API_BASE_URL}/v1/tts/generate"
    body = {"text": "[excited] Hello!", "speaker": "Ira", "stream": False}
    reply = requests.post(endpoint, json=body, timeout=10)
    return reply.status_code == 200


def test_chunking():
    """Quick chunking test: 700 characters of text for speaker "Veda" must get HTTP 200."""
    repeats = 100
    long_text = "Hello! " * repeats  # 7 chars x 100 = 700 chars
    body = {"text": long_text, "speaker": "Veda", "stream": False}
    reply = requests.post(f"{API_BASE_URL}/v1/tts/generate", json=body, timeout=30)
    return reply.status_code == 200


def main():
    """Run the full validation suite against the local TTS server.

    Returns 1 immediately if the health endpoint is unreachable or does not
    answer HTTP 200. Otherwise it runs every test, waiting 2 s between
    consecutive tests, and prints a summary. A test that raises has its
    traceback printed and counts as a failure.

    Returns:
        Process exit code: 0 when at least 75% of the tests passed, else 1.
    """
    import traceback

    log_section("🚀 Complete Feature Validation Suite")

    # Check server health first
    try:
        r = requests.get(f"{API_BASE_URL}/health", timeout=5)
        if r.status_code != 200:
            print("❌ Server not healthy!")
            return 1
        print("✅ Server is healthy\n")
    except Exception as e:
        print(f"❌ Cannot connect to server: {e}")
        return 1

    # Run all tests
    results = []

    tests = [
        ("Friendly Speaker Names", test_friendly_speaker_names),
        ("Streaming with TTFB", test_streaming_with_ttfb),
        ("ASR Validation", test_asr_validation),
        ("Complete Validation", test_complete_validation),
    ]

    for index, (test_name, test_func) in enumerate(tests):
        if index:
            # Small delay between tests; no pointless wait after the last one.
            time.sleep(2)
        try:
            result = test_func()
            results.append((test_name, result))
        except Exception as e:
            print(f"❌ {test_name} exception: {e}")
            traceback.print_exc()
            results.append((test_name, False))

    # Final summary
    log_section("📊 FINAL RESULTS")

    for test_name, result in results:
        status = "✅ PASS" if result else "❌ FAIL"
        print(f"  {status} - {test_name}")

    passed = sum(1 for _, result in results if result)
    total = len(results)

    print(f"\n🎯 Overall: {passed}/{total} tests passed ({passed/total*100:.0f}%)")

    if passed == total:
        print("🎉 ALL TESTS PASSED!")
        return 0
    elif passed >= total * 0.75:
        print("⚠️  MOST TESTS PASSED")
        return 0
    else:
        print("❌ VALIDATION FAILED")
        return 1


if __name__ == "__main__":
    # Use main()'s return value as the process exit status.
    raise SystemExit(main())

