#!/usr/bin/env python3
"""
Measure TTFB (Time To First Byte) with detailed timing breakdown.

This script exercises the TTS endpoint in both streaming and non-streaming
modes and reports when the first audio bytes actually reach the client.
"""

import requests
import time
import sys

# Configuration
# Endpoint of the TTS generation API under test.
API_URL = "http://localhost:8000/v1/tts/generate"
# Optional key sent as the X-API-Key request header. None means no key
# (local testing); main() replaces it with the first command-line argument.
API_KEY = None

def measure_streaming_ttfb(text="Hello world, this is a test of streaming audio generation.", speaker="lipakshi"):
    """
    Measure TTFB for streaming mode.

    Sends a ``stream=True`` request and records client-side timestamps, all
    relative to the moment the request is issued:

    * T1 - response headers received
    * T2 - first non-empty body chunk received (the TRUE TTFB, ``ttfb``)
    * T3 - stream fully consumed

    If the body starts with ``RIFF`` it is treated as a WAV stream with a
    44-byte header. The arrival time of the first byte past that header is
    also reported as ``first_audio``, because a first chunk that contains
    only the header carries no audible samples.

    Args:
        text: Text to synthesize.
        speaker: Speaker name sent in the request payload.

    Returns:
        dict with ``ttfb``, ``first_audio`` (seconds, or None if no byte past
        the header arrived), ``total_time`` (seconds), ``chunks`` and
        ``bytes``; or None if the request failed, the server returned a
        non-200 status, or the stream contained no data.
    """
    print("\n" + "="*80)
    print("📊 STREAMING MODE - TRUE TTFB MEASUREMENT")
    print("="*80)

    payload = {
        "text": text,
        "speaker": speaker,
        "stream": True,
        "format": "wav"
    }

    headers = {"Content-Type": "application/json"}
    if API_KEY:
        headers["X-API-Key"] = API_KEY

    # Format assumptions behind the header and duration calculations:
    # a canonical 44-byte WAV header followed by 24 kHz, 16-bit mono PCM.
    # NOTE(review): confirm against the server's actual output format.
    wav_header_len = 44
    bytes_per_second = 24000 * 2

    # perf_counter is monotonic and high-resolution. time.time() can jump
    # (e.g. clock adjustments) and is not intended for measuring intervals.
    t0_request_start = time.perf_counter()
    print(f"\n⏰ T0 = {0:.3f}s - Request initiated")

    try:
        # Using the response as a context manager releases the streamed
        # connection on every exit path, including the early error return.
        with requests.post(API_URL, json=payload, headers=headers, stream=True, timeout=30) as response:
            t1_headers_received = time.perf_counter() - t0_request_start
            print(f"⏰ T1 = {t1_headers_received:.3f}s - Headers received (HTTP {response.status_code})")

            if response.status_code != 200:
                print(f"❌ Error: {response.status_code}")
                print(response.text)
                return None

            # HTTP field names are case-insensitive, and some servers send
            # them lowercased, so match the "X-" prefix case-insensitively.
            print(f"\n📋 Response Headers:")
            for key, value in response.headers.items():
                if key.lower().startswith('x-'):
                    print(f"   {key}: {value}")

            t2_first_chunk = None
            t_first_audio = None
            header_len = 0  # set to wav_header_len once a RIFF header is seen
            chunk_count = 0
            total_bytes = 0

            for chunk in response.iter_content(chunk_size=4096):
                if not chunk:
                    continue  # skip empty reads
                elapsed = time.perf_counter() - t0_request_start
                chunk_count += 1
                total_bytes += len(chunk)

                if t2_first_chunk is None:
                    t2_first_chunk = elapsed
                    print(f"\n🎵 T2 = {t2_first_chunk:.3f}s - FIRST AUDIO CHUNK RECEIVED")
                    print(f"   └─ Size: {len(chunk)} bytes")
                    print(f"   └─ TTFB (True): {t2_first_chunk:.3f}s = {t2_first_chunk*1000:.0f}ms")

                    # Check if it's WAV header
                    if chunk.startswith(b'RIFF'):
                        header_len = wav_header_len
                        print(f"   └─ Content: WAV header (first {wav_header_len} bytes)")
                        if len(chunk) > wav_header_len:
                            print(f"   └─ Audio data starts: {len(chunk)-wav_header_len} bytes in first chunk")
                    else:
                        print(f"   └─ Content: Raw audio data")

                # First moment at which at least one byte beyond the header is held.
                if t_first_audio is None and total_bytes > header_len:
                    t_first_audio = elapsed

            t3_complete = time.perf_counter() - t0_request_start

        if t2_first_chunk is None:
            # An empty body has no first byte, so there is no TTFB to report.
            print("❌ Error: stream ended without any data")
            return None

        audio_bytes = max(total_bytes - header_len, 0)
        print(f"\n✅ T3 = {t3_complete:.3f}s - Stream complete")
        print(f"   └─ Total chunks: {chunk_count}")
        print(f"   └─ Total bytes: {total_bytes:,}")
        print(f"   └─ Duration: ~{audio_bytes / bytes_per_second:.2f}s of audio")

        # Calculate TTFB breakdown
        print(f"\n📊 TTFB BREAKDOWN:")
        print(f"   ├─ Headers received:      {t1_headers_received*1000:>7.0f}ms")
        print(f"   ├─ First chunk received:  {t2_first_chunk*1000:>7.0f}ms  ← TRUE TTFB")
        if t_first_audio is not None:
            print(f"   ├─ First audio payload:   {t_first_audio*1000:>7.0f}ms")
        else:
            print(f"   ├─ First audio payload:   {'n/a':>7}  (no data after WAV header)")
        print(f"   └─ Total time:            {t3_complete*1000:>7.0f}ms")

        return {
            'ttfb': t2_first_chunk,
            'first_audio': t_first_audio,
            'total_time': t3_complete,
            'chunks': chunk_count,
            'bytes': total_bytes
        }

    except Exception as e:
        # Diagnostic-script boundary: report and let main() continue with
        # the remaining tests.
        print(f"❌ Error: {e}")
        import traceback
        traceback.print_exc()
        return None


def measure_non_streaming_ttfb(text="Hello world, this is a test.", speaker="lipakshi"):
    """
    Measure TTFB for non-streaming mode.

    In non-streaming mode the whole audio body is received at once, so the
    client-side measurement is the total response time. If the server sends
    an ``X-TTFB-ms`` header (matched case-insensitively), its value is shown
    as the server-reported generation time. That value is subtracted from the
    client time to estimate overhead.

    Args:
        text: Text to synthesize.
        speaker: Speaker name sent in the request payload.

    Returns:
        dict with ``total_time`` (seconds), ``server_ttfb`` (seconds, or None
        when the header is absent or not numeric), ``bytes`` (body length)
        and ``duration`` (estimated seconds of audio); or None on request
        failure or a non-200 status.
    """
    print("\n" + "="*80)
    print("📊 NON-STREAMING MODE - RESPONSE TIME MEASUREMENT")
    print("="*80)

    payload = {
        "text": text,
        "speaker": speaker,
        "stream": False,
        "format": "wav"
    }

    headers = {"Content-Type": "application/json"}
    if API_KEY:
        headers["X-API-Key"] = API_KEY

    # perf_counter is monotonic; time.time() is not suited to interval timing.
    t0_request_start = time.perf_counter()
    print(f"\n⏰ T0 = {0:.3f}s - Request initiated")

    try:
        response = requests.post(API_URL, json=payload, headers=headers, timeout=30)

        t1_response_complete = time.perf_counter() - t0_request_start
        print(f"⏰ T1 = {t1_response_complete:.3f}s - Full response received (HTTP {response.status_code})")

        if response.status_code != 200:
            print(f"❌ Error: {response.status_code}")
            print(response.text)
            return None

        # HTTP field names are case-insensitive; servers may send them
        # lowercased, so compare the prefix case-insensitively.
        print(f"\n📋 Response Headers:")
        for key, value in response.headers.items():
            if key.lower().startswith('x-'):
                print(f"   {key}: {value}")

        # response.headers is a case-insensitive mapping, so this lookup
        # succeeds however the server capitalised the header name.
        server_ttfb = None
        raw_ttfb = response.headers.get('X-TTFB-ms')
        if raw_ttfb is not None:
            try:
                server_ttfb = float(raw_ttfb) / 1000  # Convert to seconds
            except ValueError:
                print(f"   ⚠️  Ignoring non-numeric X-TTFB-ms value: {raw_ttfb!r}")

        audio = response.content
        audio_bytes = len(audio)
        # Leave the (assumed 44-byte) WAV header out of the duration estimate.
        if audio.startswith(b'RIFF') and audio_bytes >= 44:
            pcm_bytes = audio_bytes - 44
        else:
            pcm_bytes = audio_bytes
        audio_duration = pcm_bytes / (24000 * 2)  # assumes 24kHz, 16-bit, mono

        print(f"\n🎵 Audio received:")
        print(f"   └─ Size: {audio_bytes:,} bytes")
        print(f"   └─ Duration: ~{audio_duration:.2f}s")

        print(f"\n📊 TIMING ANALYSIS:")
        print(f"   ├─ Client measured (total):  {t1_response_complete*1000:>7.0f}ms")
        # Use an explicit None check so a reported 0 ms is still shown.
        if server_ttfb is not None:
            print(f"   ├─ Server reported (gen):    {server_ttfb*1000:>7.0f}ms  ← Generation time")
            network_time = t1_response_complete - server_ttfb
            print(f"   └─ Network overhead:         {network_time*1000:>7.0f}ms")

        print(f"\n   Note: Non-streaming TTFB = full generation time")
        print(f"         (All audio generated before first byte sent)")

        return {
            'total_time': t1_response_complete,
            'server_ttfb': server_ttfb,
            'bytes': audio_bytes,
            'duration': audio_duration
        }

    except Exception as e:
        # Diagnostic-script boundary: report and let main() continue.
        print(f"❌ Error: {e}")
        import traceback
        traceback.print_exc()
        return None


def main():
    """
    Run the TTFB test suite and print a comparison summary.

    Three scenarios run in order: short text streaming, short text
    non-streaming, and longer text streaming. There is a 0.5s pause after
    each of the first two. The summary covers whichever scenarios returned a
    result. An optional first command-line argument is stored in the
    module-level API_KEY before any request is made.
    """
    global API_KEY

    cli_args = sys.argv[1:]
    if cli_args:
        API_KEY = cli_args[0]

    divider = "=" * 80
    short_text = "Hello world, this is a test."

    def _section(title):
        # Test banner centred and padded with '=' to the report width.
        print("\n\n" + title.center(80, "="))

    print("🔍 VEENA3 TTS - DETAILED TTFB MEASUREMENT")
    print(divider)
    print(f"Target: {API_URL}")
    print("API Key: " + ("Provided" if API_KEY else "None (local testing)"))

    _section("🧪 TEST 1: Short text, STREAMING MODE")
    streaming_short = measure_streaming_ttfb(text=short_text, speaker="lipakshi")
    time.sleep(0.5)  # let the server settle between runs

    _section("🧪 TEST 2: Short text, NON-STREAMING MODE")
    non_streaming_short = measure_non_streaming_ttfb(text=short_text, speaker="lipakshi")
    time.sleep(0.5)

    _section("🧪 TEST 3: Longer text, STREAMING MODE")
    long_text = "Hello, this is a longer test of the text to speech system. " * 3
    streaming_long = measure_streaming_ttfb(text=long_text, speaker="reet")

    print("\n\n" + divider)
    print("📊 SUMMARY - TTFB COMPARISON")
    print(divider)

    # Each scenario returns None on failure; only successful runs are listed.
    if streaming_short:
        print(f"\n✅ Streaming (short):    TTFB = {streaming_short['ttfb']*1000:>6.0f}ms")
    if non_streaming_short:
        print(f"✅ Non-streaming (short): Time = {non_streaming_short['total_time']*1000:>6.0f}ms (full generation)")
    if streaming_long:
        print(f"✅ Streaming (long):     TTFB = {streaming_long['ttfb']*1000:>6.0f}ms")

    insights = (
        "   • Streaming TTFB = Time to first audio chunk (TRUE TTFB)",
        "   • Non-streaming 'TTFB' = Full generation time (all audio ready)",
        "   • For low latency, use STREAMING mode",
    )
    print("\n💡 KEY INSIGHT:")
    for line in insights:
        print(line)
    print(divider)


if __name__ == "__main__":
    main()

