#!/usr/bin/env python3
"""
Test TRUE streaming implementation with detailed TTFB analysis.

This script measures:
1. Time to first byte (network level)
2. Time to first audio chunk
3. Total download time
4. Audio validity (WAV header info) and real-time factor

Expected results with TRUE streaming:
- TTFB: 300-800ms (time to first audio chunk generation)
- Chunks: Received incrementally as generated
- Download: Gradual, not all at once
"""

import requests
import time
import wave
import io

# Base URL of the TTS server under test; requests are posted to
# f"{API_BASE_URL}/v1/tts/generate".
API_BASE_URL = "http://localhost:8000"


def test_true_streaming(text="Hello! This is a test of true streaming.", speaker="Mitra"):
    """Run one streaming TTS request and print timing, transfer and audio stats.

    Posts ``text`` and ``speaker`` to ``{API_BASE_URL}/v1/tts/generate`` with
    ``"stream": True``, reads the response body in chunks of up to 4096 bytes,
    and prints when each chunk arrived relative to the start of the request.
    The collected bytes are then parsed as a WAV file to report its format
    and the real-time factor (total request time / audio duration).

    Args:
        text: Text sent in the ``text`` field of the JSON request body.
        speaker: Value sent in the ``speaker`` field of the JSON request body.

    Returns:
        True when the first chunk arrived in under 2000 ms and more than
        1000 bytes were received. False otherwise, including when the request
        cannot be made, the server answers with an HTTP error status, or any
        other exception is raised while reading the response. A failure to
        parse the body as WAV is only reported; it does not fail the test.
    """
    print("=" * 80)
    print("🚀 Testing TRUE Streaming Implementation")
    print("=" * 80)
    print(f"\nText: {text}")
    print(f"Speaker: {speaker}")
    print(f"Endpoint: {API_BASE_URL}/v1/tts/generate")
    print()

    response = None
    try:
        # perf_counter() is monotonic, so intervals cannot be skewed by
        # system clock adjustments during the request.
        t_request_start = time.perf_counter()

        # The request lives inside the try so that connection errors and
        # timeouts are reported as a failed test instead of crashing the run.
        response = requests.post(
            f"{API_BASE_URL}/v1/tts/generate",
            json={
                "text": text,
                "speaker": speaker,
                "stream": True
            },
            stream=True,  # Critical for streaming
            timeout=60
        )

        t_response_received = time.perf_counter()
        headers_ms = (t_response_received - t_request_start) * 1000
        print(f"⏱️  Response headers received: {headers_ms:.0f}ms")
        print()

        # An HTTP error body must not be measured as if it were audio.
        response.raise_for_status()

        # Track chunk timing
        chunk_times = []
        chunks_received = 0
        total_bytes = 0
        audio_parts = []
        ttfb_ms = None

        print("📦 Receiving chunks:")
        print("-" * 80)

        for chunk in response.iter_content(chunk_size=4096):
            if not chunk:
                continue

            elapsed = (time.perf_counter() - t_request_start) * 1000

            if ttfb_ms is None:
                ttfb_ms = elapsed
                print(f"🎯 TTFB (Time To First Byte): {ttfb_ms:.0f}ms")
                print(f"   └─ This is TRUE streaming - first audio chunk ready!")
                print()

            chunks_received += 1
            chunk_size = len(chunk)
            total_bytes += chunk_size
            # Collect the pieces and join once; repeated bytes += is quadratic.
            audio_parts.append(chunk)
            chunk_times.append(elapsed)

            # Show first 5 chunks and every 10th chunk
            if chunks_received <= 5 or chunks_received % 10 == 0:
                print(f"   Chunk {chunks_received:3d}: {chunk_size:5d} bytes at T+{elapsed:6.0f}ms")

        total_time = (time.perf_counter() - t_request_start) * 1000
        audio_data = b"".join(audio_parts)

        print("-" * 80)
        print()

        # Response headers
        print("📋 Response Headers:")
        for key in ['X-Request-ID', 'X-Engine', 'X-Stream', 'X-Sample-Rate']:
            if key in response.headers:
                print(f"   {key}: {response.headers[key]}")
        print()

        # Timing analysis
        print("⏱️  Timing Analysis:")
        print(f"   Response headers:  {headers_ms:7.0f}ms")
        if ttfb_ms is not None:
            print(f"   First chunk (TTFB): {ttfb_ms:7.0f}ms")
        else:
            print(f"   First chunk (TTFB): NO CHUNKS RECEIVED!")
        print(f"   Total time:        {total_time:7.0f}ms")
        print()

        # Data transfer analysis
        print("📊 Data Transfer:")
        print(f"   Chunks received:   {chunks_received}")
        print(f"   Total bytes:       {total_bytes:,} bytes ({total_bytes/1024:.1f} KB)")
        if chunks_received > 0:
            print(f"   Avg chunk size:    {total_bytes/chunks_received:.0f} bytes")
        else:
            print(f"   ⚠️  WARNING: No chunks received - check server logs for errors!")
        print()

        # Streaming pattern analysis
        if len(chunk_times) >= 2:
            chunk_intervals = [later - earlier for earlier, later in zip(chunk_times, chunk_times[1:])]
            avg_interval = sum(chunk_intervals) / len(chunk_intervals)
            print(f"   Avg chunk interval: {avg_interval:.0f}ms")

            # Check if streaming is incremental (intervals should be relatively consistent)
            # vs buffered (first interval very long, rest very short)
            first_interval = chunk_intervals[0]
            if first_interval > avg_interval * 5:
                print(f"   ⚠️  WARNING: First interval much longer ({first_interval:.0f}ms vs avg {avg_interval:.0f}ms)")
                print(f"      This suggests buffered streaming, not true streaming!")
            else:
                print(f"   ✅ Streaming pattern: Incremental (TRUE streaming)")
        print()

        # Validate audio. This is best-effort reporting: a parse failure is
        # printed but deliberately does not change the test result.
        try:
            audio_io = io.BytesIO(audio_data)
            with wave.open(audio_io, 'rb') as wf:
                channels = wf.getnchannels()
                sample_width = wf.getsampwidth()
                framerate = wf.getframerate()
                frames = wf.getnframes()
                duration = frames / framerate if framerate else 0.0

                print("🎵 Audio Info:")
                print(f"   Channels:      {channels}")
                print(f"   Sample width:  {sample_width} bytes")
                print(f"   Sample rate:   {framerate} Hz")
                print(f"   Frames:        {frames:,}")
                print(f"   Duration:      {duration:.2f}s")
                print()

                # Calculate Real-Time Factor. A header that reports zero
                # frames would otherwise surface as a misleading
                # "division by zero" validation error.
                if duration > 0:
                    rtf = (total_time / 1000) / duration
                    print(f"   RTF (Real-Time Factor): {rtf:.3f}")
                    if rtf < 0.5:
                        print(f"   ✅ Excellent performance!")
                    elif rtf < 1.0:
                        print(f"   ✅ Good performance")
                    else:
                        print(f"   ⚠️  Slower than real-time")
                else:
                    print(f"   ⚠️  RTF unavailable: WAV header reports zero duration")
                print()

        except Exception as e:
            print(f"❌ Audio validation error: {e}")
            print()

        # Success criteria
        print("=" * 80)
        print("📋 Test Results:")
        print("=" * 80)

        success = True

        # Check TTFB
        if ttfb_ms is None:
            print(f"❌ TTFB: NO AUDIO GENERATED (check server logs)")
            success = False
        elif ttfb_ms < 1000:
            print(f"✅ TTFB: {ttfb_ms:.0f}ms (Excellent - <1s)")
        elif ttfb_ms < 2000:
            print(f"⚠️  TTFB: {ttfb_ms:.0f}ms (Good but could be better)")
        else:
            print(f"❌ TTFB: {ttfb_ms:.0f}ms (Too slow - target <1s)")
            success = False

        # Check audio generated
        if total_bytes > 1000:
            print(f"✅ Audio generated: {total_bytes:,} bytes")
        else:
            print(f"❌ Audio too small: {total_bytes} bytes")
            success = False

        # Check chunking (a warning only; it does not fail the test)
        if chunks_received > 5:
            print(f"✅ Chunked delivery: {chunks_received} chunks")
        else:
            print(f"⚠️  Few chunks: {chunks_received} (might be buffered)")

        print()
        if success:
            print("🎉 TRUE STREAMING TEST PASSED!")
        else:
            print("❌ TEST FAILED - See issues above")
        print("=" * 80)

        return success

    except Exception as e:
        print(f"\n❌ Test failed with error: {e}")
        import traceback
        traceback.print_exc()
        return False

    finally:
        # With stream=True the connection stays open until the body is fully
        # consumed or the response is closed; always release it.
        if response is not None:
            response.close()


if __name__ == "__main__":
    # Script entry point: run a short-text and a long-text streaming test,
    # then exit non-zero if either reported failure so that shells and CI
    # jobs can detect it.
    print("\n🔍 Starting TRUE Streaming Validation\n")

    # Test 1: Short text
    print("\n📝 Test 1: Short Text")
    short_text_ok = test_true_streaming(
        text="Hello, this is a streaming test.",
        speaker="Mitra"
    )

    print("\n" + "="*80 + "\n")

    # Test 2: Longer text
    print("📝 Test 2: Longer Text")
    long_text_ok = test_true_streaming(
        text="This is a longer test to see how the streaming performs with more text. We want to verify that audio chunks are generated and streamed incrementally, not all at once.",
        speaker="Aaranya"
    )

    print("\n✅ All tests complete!")
    print("\nCheck server logs for detailed timing breakdown:")
    print("  - Request to init")
    print("  - Init to gen start")
    print("  - Gen start to created")
    print("  - Created to first audio")

    # Previously both results were discarded and the script always exited 0.
    raise SystemExit(0 if short_text_ok and long_text_ok else 1)

