#!/usr/bin/env python3
"""
Client-side TTFB measurement for streaming TTS
Measures the time from request sent to first audio byte received
"""

import requests
import time
import sys
import json

def measure_ttfb(api_key: str, test_text: str, speaker: str = "reet", *,
                 url: str = "http://localhost:8000/v1/tts/generate",
                 sample_rate: int = 16000,
                 timeout: float = 60.0):
    """
    Measure Time To First Byte (TTFB) from client perspective.

    This measures:
    1. Request sent
    2. First byte of audio received (true TTFB)
    3. Total audio received

    Args:
        api_key: Value sent in the ``X-API-Key`` header.
        test_text: Text to synthesize.
        speaker: Speaker/voice name forwarded to the server.
        url: TTS generation endpoint to POST to.
        sample_rate: Sample rate (Hz) of the returned audio. Used only to
            estimate audio duration from the byte count, assuming a 44-byte
            WAV header followed by mono int16 PCM.
        timeout: Seconds passed to ``requests`` as its timeout (connect and
            wait-between-bytes), so a stalled server cannot hang the run.

    Returns:
        A dict with ``ttfb_ms``, ``total_time_ms``, ``audio_duration_s``,
        ``total_bytes``, ``chunk_count`` and ``rtf`` on success; ``None`` if
        the request raised, returned a non-200 status, or delivered no bytes.
    """
    payload = {
        "text": test_text,
        "speaker": speaker,
        "temperature": 0.4,
        "seed": 42,
        "format": "wav",
        "stream": True
    }

    headers = {
        "Content-Type": "application/json",
        "X-API-Key": api_key
    }

    # Only add an ellipsis when the preview is actually truncated.
    preview = test_text if len(test_text) <= 80 else test_text[:80] + "..."
    print("📤 Sending request...")
    print(f"   Text: {preview}")
    print(f"   Speaker: {speaker}")
    print("")

    ttfb_ms = None
    total_bytes = 0
    chunk_count = 0

    # perf_counter is monotonic and high-resolution; time.time() can jump
    # with wall-clock adjustments and skew latency measurements.
    t_request_start = time.perf_counter()

    try:
        # The `with` block closes the streamed response (releasing the
        # pooled connection) even if iteration fails part-way.
        with requests.post(url, json=payload, headers=headers,
                           stream=True, timeout=timeout) as response:
            if response.status_code != 200:
                print(f"❌ Error: HTTP {response.status_code}")
                print(response.text)
                return None

            for chunk in response.iter_content(chunk_size=8192):
                if not chunk:
                    continue  # skip empty chunks; they carry no audio
                chunk_count += 1
                total_bytes += len(chunk)

                # Record TTFB on the first non-empty chunk.
                if ttfb_ms is None:
                    ttfb_ms = (time.perf_counter() - t_request_start) * 1000
                    print(f"⚡ FIRST BYTE RECEIVED: {ttfb_ms:.0f}ms (TRUE CLIENT TTFB)")
                    print(f"   Chunk size: {len(chunk)} bytes")
                    print("")
    except requests.RequestException as exc:
        print(f"❌ Request failed: {exc}")
        return None

    total_time_ms = (time.perf_counter() - t_request_start) * 1000

    if ttfb_ms is None:
        # A 200 response with an empty body: no TTFB can be reported.
        print("❌ Error: response contained no audio bytes")
        return None

    # Estimate audio duration from the byte count: 44-byte WAV header,
    # then int16 (2 bytes/sample) mono PCM at `sample_rate`.
    audio_bytes = max(total_bytes - 44, 0)
    audio_samples = audio_bytes // 2
    audio_duration_s = audio_samples / sample_rate

    total_time_s = total_time_ms / 1000
    rtf = audio_duration_s / total_time_s if total_time_s > 0 else 0.0

    print("✅ STREAMING COMPLETE")
    print("")
    print("📊 METRICS:")
    print(f"   ├─ Client TTFB:        {ttfb_ms:>7.0f} ms  ← Time to first audio byte")
    print(f"   ├─ Total time:         {total_time_ms:>7.0f} ms")
    print(f"   ├─ Audio duration:     {audio_duration_s:>7.2f} s")
    print(f"   ├─ Real-Time Factor:   {rtf:>7.2f}x  ← {audio_duration_s:.2f}s audio in {total_time_s:.2f}s")
    print(f"   ├─ Total bytes:        {total_bytes:>7,} bytes")
    print(f"   └─ Chunks received:    {chunk_count:>7,} chunks")
    print("")

    return {
        "ttfb_ms": ttfb_ms,
        "total_time_ms": total_time_ms,
        "audio_duration_s": audio_duration_s,
        "total_bytes": total_bytes,
        "chunk_count": chunk_count,
        "rtf": rtf
    }


def main():
    """Run each TTFB test case against the server and print a summary.

    The API key is read from the first command-line argument; without one,
    a local test key is used.

    Returns:
        Exit status: 0 if every test case produced a result, 1 otherwise.
    """
    # Test texts: a long Hindi case (matches the user's original test case)
    # plus a short English case.
    test_cases = [
        {
            "name": "Hindi Long Text (Original Issue Case)",
            "text": "[excited] आंध्र प्रदेश में हाल ही में हुए चुनावों में टीडीपी, जनसेना और बीजेपी गठबंधन ने बहुत बड़ी बहुमत से जीत हासिल की। [laughs harder] नारा चंद्रबाबू नायडू जी ने मुख्यमंत्री के रूप में कार्यभार संभाला।",
            "speaker": "reet"
        },
        {
            "name": "Short English",
            "text": "Hello, this is a test of the streaming TTS system.",
            "speaker": "reet"
        }
    ]

    # Get API key from argument or use test key
    if len(sys.argv) > 1:
        api_key = sys.argv[1]
    else:
        # Use test key (this is fine for local testing)
        api_key = "test_key_for_local_validation"
        print("ℹ️  Using test API key (pass real key as argument for production testing)")
        print(f"   Usage: {sys.argv[0]} YOUR_API_KEY")
        print("")

    print("=" * 80)
    print("CLIENT-SIDE TTFB MEASUREMENT - Spark TTS Streaming")
    print("=" * 80)
    print("")

    results = []

    for i, test_case in enumerate(test_cases, 1):
        print(f"TEST {i}/{len(test_cases)}: {test_case['name']}")
        print("=" * 80)

        result = measure_ttfb(api_key, test_case['text'], test_case['speaker'])

        # measure_ttfb returns None on any failure.
        if result is not None:
            results.append(result)

        print("")

        # Wait between tests
        if i < len(test_cases):
            print("⏳ Waiting 3 seconds before next test...")
            print("")
            time.sleep(3)

    failed = len(test_cases) - len(results)

    # Summary (averages cover successful tests only)
    if results:
        print("=" * 80)
        print("SUMMARY")
        print("=" * 80)
        avg_ttfb = sum(r['ttfb_ms'] for r in results) / len(results)
        avg_rtf = sum(r['rtf'] for r in results) / len(results)
        print(f"Average TTFB: {avg_ttfb:.0f}ms")
        print(f"Average RTF:  {avg_rtf:.2f}x")
        print("")

    # Only claim success when every test actually produced a result.
    if failed == 0:
        print("✅ All tests completed successfully!")
        return 0

    print(f"❌ {failed} of {len(test_cases)} tests failed")
    return 1


if __name__ == '__main__':
    # Propagate main()'s status so callers/CI can detect failed runs
    # (a None return maps to exit status 0).
    sys.exit(main())

