#!/usr/bin/env python3
"""
Debug script to isolate ASR word error rate issues.

Hypotheses to test:
1. Is it chunking? (Compare chunked vs non-chunked for same text)
2. Is it Gemini truncating long audio?
3. Is it audio quality degradation with length?
4. Is it the word comparison algorithm being too strict?
"""

import os
import io
import re
import wave
import json
import base64
import httpx
from typing import Tuple, List

# Modal-hosted TTS service base URL; override with MODAL_ENDPOINT_URL.
MODAL_ENDPOINT_URL = os.environ.get("MODAL_ENDPOINT_URL", "https://mayaresearch--veena3-tts-ttsservice-serve.modal.run")
# Gemini API key — required; the script exits early if unset.
GEMINI_KEY = os.environ.get("GEMINI_KEY")
# Gemini model used for ASR transcription; override with GEMINI_MODEL.
GEMINI_MODEL = os.environ.get("GEMINI_MODEL", "models/gemini-2.0-flash")

# Test texts of increasing length
SHORT_TEXT = "Hello, how are you today?"
MEDIUM_TEXT = """The rapid advancement of artificial intelligence has transformed many industries, 
from healthcare to finance. Machine learning algorithms can now analyze vast amounts 
of data to make predictions that were previously impossible."""

LONG_TEXT = """The history of computing is a fascinating journey that spans several centuries. 
From the earliest mechanical calculators designed by Blaise Pascal and Gottfried Wilhelm 
Leibniz in the seventeenth century, to Charles Babbage's visionary Analytical Engine in 
the nineteenth century, the concept of automated computation has captivated human imagination. 
The twentieth century witnessed remarkable breakthroughs, including Alan Turing's theoretical 
framework for computation, the development of electronic computers during World War II, and 
the subsequent evolution from room-sized mainframes to personal computers. The invention of 
the transistor, integrated circuits, and microprocessors revolutionized the field, making 
computers smaller, faster, and more affordable. Today, computing power that once required 
entire buildings can fit in our pockets, and the ongoing development of quantum computing 
promises to unlock computational capabilities that were previously unimaginable. This 
remarkable progress reflects humanity's relentless pursuit of knowledge and innovation."""


def generate_tts(text: str, stream: bool = False, chunking: bool = True) -> Tuple[bytes, dict]:
    """Generate TTS audio.

    Posts *text* to the Modal TTS endpoint and returns the raw response
    body (WAV bytes) together with the response headers as a plain dict.
    Raises httpx.HTTPStatusError on a non-2xx response.
    """
    endpoint = f"{MODAL_ENDPOINT_URL.rstrip('/')}/v1/tts/generate"
    payload = {
        "text": text,
        "speaker": "lipakshi",
        "stream": stream,
        "chunking": chunking,
    }
    # Long timeout: synthesis of long texts can take minutes.
    resp = httpx.post(
        endpoint,
        json=payload,
        headers={"Content-Type": "application/json"},
        timeout=300.0,
    )
    resp.raise_for_status()
    return resp.content, dict(resp.headers)


def transcribe_with_gemini(audio_bytes: bytes, timeout: float = 180.0) -> str:
    """Transcribe audio using Gemini API.

    Best-effort: returns the transcription string on success, the sentinel
    "[TIMEOUT]" on a read timeout, "[ERROR: ...]" on any other failure, and
    "" when the response JSON lacks the expected candidate structure.
    """
    if not GEMINI_KEY:
        raise ValueError("GEMINI_KEY not set")

    audio_mb = len(audio_bytes) / (1024 * 1024)
    encoded = base64.standard_b64encode(audio_bytes).decode("utf-8")
    print(f"   Sending {audio_mb:.2f}MB to Gemini (timeout: {timeout}s)...")

    request_body = {
        "contents": [{
            "parts": [
                {"text": "Transcribe the following audio exactly as spoken. Return ONLY the transcription, nothing else."},
                {"inline_data": {"mime_type": "audio/wav", "data": encoded}}
            ]
        }],
        # temperature 0 for deterministic transcription
        "generationConfig": {"temperature": 0.0}
    }

    try:
        response = httpx.post(
            f"https://generativelanguage.googleapis.com/v1beta/{GEMINI_MODEL}:generateContent",
            headers={"Content-Type": "application/json"},
            params={"key": GEMINI_KEY},
            json=request_body,
            timeout=timeout,
        )
        response.raise_for_status()
        data = response.json()
        try:
            return data["candidates"][0]["content"]["parts"][0]["text"].strip()
        except (KeyError, IndexError):
            # Response parsed but had no candidate text (e.g. safety block).
            print(f"   ⚠️ Unexpected response: {data}")
            return ""
    except httpx.ReadTimeout:
        print(f"   ⚠️ Gemini API TIMEOUT after {timeout}s for {audio_mb:.2f}MB audio")
        return "[TIMEOUT]"
    except Exception as err:
        # Broad catch is deliberate: this debug script must keep running
        # through individual transcription failures.
        print(f"   ⚠️ Gemini API error: {err}")
        return f"[ERROR: {err}]"


def normalize_text(text: str) -> List[str]:
    """Normalize text for comparison - strip punctuation and lowercase."""
    # Replace every character that is not a word char, whitespace, or an
    # apostrophe (kept so contractions like "don't" survive) with a space.
    cleaned = re.sub(r"[^\w\s']", " ", text)
    # str.split() with no argument collapses runs of whitespace and never
    # produces empty tokens, so no extra filtering is needed.
    return cleaned.lower().split()


def get_audio_duration(audio_bytes: bytes) -> float:
    """Return the duration of a WAV byte blob in seconds (0.0 on failure).

    Best-effort by design: callers only use the value for diagnostics, so
    malformed or non-WAV input yields 0.0 rather than raising.
    """
    try:
        with wave.open(io.BytesIO(audio_bytes), 'rb') as wf:
            rate = wf.getframerate()
            # Guard against a zero sample rate in a corrupt header to
            # avoid ZeroDivisionError.
            return wf.getnframes() / rate if rate else 0.0
    except Exception:
        # Narrowed from a bare `except:` so KeyboardInterrupt/SystemExit
        # are no longer swallowed; wave.Error/EOFError land here.
        return 0.0


def compare_texts(original: str, transcription: str) -> dict:
    """Compare original and transcribed text with multiple metrics.

    Combines set-based vocabulary overlap (order/frequency insensitive)
    with difflib's order-aware sequence similarity over the word lists.
    """
    from difflib import SequenceMatcher

    orig_words = normalize_text(original)
    trans_words = normalize_text(transcription)

    orig_vocab = set(orig_words)
    trans_vocab = set(trans_words)

    # Vocabulary overlap — ignores word order and repetition.
    shared = orig_vocab & trans_vocab
    dropped = orig_vocab - trans_vocab
    added = trans_vocab - orig_vocab

    # Order- and frequency-aware similarity over the full word sequences.
    similarity = SequenceMatcher(None, orig_words, trans_words).ratio()

    missing_ratio = len(dropped) / len(orig_vocab) if orig_vocab else 0

    return {
        "original_word_count": len(orig_words),
        "transcribed_word_count": len(trans_words),
        "unique_original": len(orig_vocab),
        "unique_transcribed": len(trans_vocab),
        "common_words": len(shared),
        "missing_words": len(dropped),
        "extra_words": len(added),
        "missing_ratio": missing_ratio,
        "sequence_similarity": similarity,
        "missing_word_list": sorted(dropped)[:20],  # First 20
        "extra_word_list": sorted(added)[:20],
    }


def run_test(name: str, text: str, stream: bool = False, chunking: bool = True, timeout: float = 180.0):
    """Run a single test and print results.

    Generates TTS audio for *text*, transcribes it with Gemini, compares
    the transcription to the original, and returns the metrics dict from
    compare_texts() (or an all-missing equivalent when transcription
    times out or errors — recognizable by its extra "error" key).
    """
    print(f"\n{'='*60}")
    print(f"TEST: {name}")
    print(f"  Text length: {len(text)} chars, ~{len(text.split())} words")
    print(f"  Stream: {stream}, Chunking: {chunking}")
    print(f"{'='*60}")
    
    # Generate audio
    print("📢 Generating audio...")
    audio_bytes, headers = generate_tts(text, stream=stream, chunking=chunking)
    duration = get_audio_duration(audio_bytes)
    audio_mb = len(audio_bytes) / (1024 * 1024)
    print(f"   Audio: {len(audio_bytes)} bytes ({audio_mb:.2f}MB), {duration:.2f}s")
    print(f"   Headers: X-Text-Chunked={headers.get('x-text-chunked', 'N/A')}")
    
    # Transcribe
    print("🎤 Transcribing with Gemini...")
    transcription = transcribe_with_gemini(audio_bytes, timeout=timeout)
    
    # Handle timeout/error: synthesize an "everything missing" metrics dict.
    if transcription.startswith("[TIMEOUT]") or transcription.startswith("[ERROR"):
        print(f"   ⚠️ Transcription failed: {transcription}")
        # Count words via normalize_text() (was text.lower().split(), which
        # kept punctuation attached) so the fallback counts agree with what
        # compare_texts() reports for successful runs.
        orig_words = normalize_text(text)
        unique_count = len(set(orig_words))
        return {
            "original_word_count": len(orig_words),
            "transcribed_word_count": 0,
            "unique_original": unique_count,
            "unique_transcribed": 0,
            "common_words": 0,
            "missing_words": unique_count,
            "extra_words": 0,
            "missing_ratio": 1.0,
            "sequence_similarity": 0.0,
            "missing_word_list": [],
            "extra_word_list": [],
            "error": transcription,
        }
    
    print(f"   Transcription ({len(transcription)} chars): {transcription[:100]}...")
    
    # Compare
    metrics = compare_texts(text, transcription)
    print(f"\n📊 Metrics:")
    print(f"   Original words: {metrics['original_word_count']} ({metrics['unique_original']} unique)")
    print(f"   Transcribed words: {metrics['transcribed_word_count']} ({metrics['unique_transcribed']} unique)")
    print(f"   Common words: {metrics['common_words']}")
    print(f"   Missing: {metrics['missing_words']} ({metrics['missing_ratio']:.1%})")
    print(f"   Extra: {metrics['extra_words']}")
    print(f"   Sequence similarity: {metrics['sequence_similarity']:.2%}")
    
    if metrics['missing_word_list']:
        print(f"   Missing words sample: {metrics['missing_word_list'][:10]}")
    
    return metrics


def chunk_audio_by_duration(audio_bytes: bytes, max_duration_seconds: float = 25.0) -> List[bytes]:
    """
    Split WAV audio into chunks of max_duration_seconds.
    Returns list of WAV audio byte chunks, each a standalone playable WAV
    (the last chunk may be shorter than max_duration_seconds).
    """
    # (removed an unused `import struct` and unused `n_frames` local)
    with wave.open(io.BytesIO(audio_bytes), 'rb') as wf:
        sample_rate = wf.getframerate()
        n_channels = wf.getnchannels()
        sample_width = wf.getsampwidth()
        
        # Frames (samples per channel) that fit in one chunk.
        frames_per_chunk = int(max_duration_seconds * sample_rate)
        
        chunks = []
        wf.rewind()
        
        while True:
            frames = wf.readframes(frames_per_chunk)
            if not frames:
                break
            
            # Re-wrap the raw frames with a fresh WAV header so each
            # chunk is independently decodable by the ASR service.
            chunk_buffer = io.BytesIO()
            with wave.open(chunk_buffer, 'wb') as chunk_wf:
                chunk_wf.setnchannels(n_channels)
                chunk_wf.setsampwidth(sample_width)
                chunk_wf.setframerate(sample_rate)
                chunk_wf.writeframes(frames)
            
            chunks.append(chunk_buffer.getvalue())
        
        return chunks


def run_chunked_asr_test(name: str, text: str, max_chunk_duration: float = 25.0):
    """
    Generate audio for text, split into chunks, transcribe each chunk with ASR.
    This isolates whether ASR failure is due to audio length or audio quality.
    Returns a dict with per-chunk results, combined metrics, the total audio
    duration, and the number of chunks produced.
    """
    banner = "=" * 60
    print(f"\n{banner}")
    print(f"CHUNKED ASR TEST: {name}")
    print(f"  Text length: {len(text)} chars, ~{len(text.split())} words")
    print(f"  Max chunk duration: {max_chunk_duration}s")
    print(banner)

    # Generate the full audio once up front.
    print("📢 Generating audio...")
    audio_bytes, headers = generate_tts(text, stream=False, chunking=True)
    total_duration = get_audio_duration(audio_bytes)
    audio_mb = len(audio_bytes) / (1024 * 1024)
    print(f"   Total audio: {len(audio_bytes)} bytes ({audio_mb:.2f}MB), {total_duration:.2f}s")
    print(f"   Headers: X-Text-Chunked={headers.get('x-text-chunked', 'N/A')}")

    # Slice the audio into ASR-sized pieces.
    print(f"\n🔪 Splitting audio into {max_chunk_duration}s chunks...")
    chunks = chunk_audio_by_duration(audio_bytes, max_chunk_duration)
    n_chunks = len(chunks)
    print(f"   Created {n_chunks} audio chunks")

    # Transcribe each piece independently.
    all_transcriptions = []
    chunk_results = []

    for idx, piece in enumerate(chunks, start=1):
        piece_duration = get_audio_duration(piece)
        piece_mb = len(piece) / (1024 * 1024)
        print(f"\n🎤 Chunk {idx}/{n_chunks}: {piece_duration:.2f}s ({piece_mb:.2f}MB)")

        result = transcribe_with_gemini(piece, timeout=90.0)
        record = {"chunk": idx, "duration": piece_duration}

        if result.startswith("[TIMEOUT]") or result.startswith("[ERROR"):
            print(f"   ❌ Transcription failed: {result}")
            record["transcription"] = ""
            record["error"] = result
        else:
            print(f"   ✅ Transcribed: {result[:80]}...")
            all_transcriptions.append(result)
            record["transcription"] = result
            record["word_count"] = len(result.split())

        chunk_results.append(record)

    # Stitch successful chunk transcriptions back together in order.
    combined_transcription = " ".join(all_transcriptions)
    print(f"\n📝 Combined transcription ({len(combined_transcription)} chars):")
    print(f"   {combined_transcription[:200]}...")

    # Score the stitched transcription against the original text.
    metrics = compare_texts(text, combined_transcription)

    print(f"\n📊 Final Metrics (Chunked ASR):")
    print(f"   Original words: {metrics['original_word_count']} ({metrics['unique_original']} unique)")
    print(f"   Transcribed words: {metrics['transcribed_word_count']} ({metrics['unique_transcribed']} unique)")
    print(f"   Common words: {metrics['common_words']}")
    print(f"   Missing: {metrics['missing_words']} ({metrics['missing_ratio']:.1%})")
    print(f"   Sequence similarity: {metrics['sequence_similarity']:.2%}")

    if metrics['missing_word_list']:
        print(f"   Missing words sample: {metrics['missing_word_list'][:15]}")

    return {
        "chunks": chunk_results,
        "combined_metrics": metrics,
        "total_duration": total_duration,
        "num_chunks": n_chunks,
    }


def main() -> None:
    """Run the ASR-debug suite: three full-audio tests of increasing text
    length, then a chunked-audio ASR test, and print a comparison summary
    that attributes word loss to either TTS quality or Gemini API limits."""
    print("=" * 60)
    print("ASR DEBUG: Isolating Word Error Rate Issues")
    print("=" * 60)
    
    # Every test needs Gemini for transcription — bail out early if unset.
    if not GEMINI_KEY:
        print("ERROR: GEMINI_KEY not set")
        return
    
    # List of (label, metrics-dict) pairs; summary below indexes positionally.
    results = []
    
    # Test 1: Short text (should have near-perfect transcription)
    results.append(("Short text", run_test("Short text (baseline)", SHORT_TEXT)))
    
    # Test 2: Medium text
    results.append(("Medium text", run_test("Medium text", MEDIUM_TEXT)))
    
    # Test 3: Long text WITH chunking - full audio to Gemini
    results.append(("Long + chunking", run_test("Long text WITH chunking", LONG_TEXT, chunking=True)))
    
    # Test 4: Long text - CHUNKED ASR (split audio into <25s chunks before ASR)
    print("\n" + "="*60)
    print("🔬 CRITICAL TEST: Chunked Audio ASR")
    print("   This isolates whether the issue is Gemini vs TTS quality")
    print("="*60)
    chunked_result = run_chunked_asr_test("Long text (chunked ASR)", LONG_TEXT, max_chunk_duration=25.0)
    results.append(("Chunked ASR", chunked_result["combined_metrics"]))
    
    # Summary
    print("\n" + "=" * 60)
    print("SUMMARY")
    print("=" * 60)
    print(f"{'Test':<20} {'Missing %':<12} {'Seq. Sim.':<12} {'Words':<15}")
    print("-" * 60)
    for name, m in results:
        print(f"{name:<20} {m['missing_ratio']:>10.1%}   {m['sequence_similarity']:>10.1%}   {m['original_word_count']:>3} -> {m['transcribed_word_count']:<3}")
    
    print("\n" + "=" * 60)
    print("CONCLUSIONS")
    print("=" * 60)
    
    # Key comparison: Full audio to Gemini vs Chunked audio to Gemini
    # (results[2] is the long full-audio test, results[3] the chunked-ASR one).
    full_audio_m = results[2][1] if len(results) > 2 else None
    chunked_audio_m = results[3][1] if len(results) > 3 else None
    
    if full_audio_m and chunked_audio_m:
        # "error" key only exists on failed runs; .get() handles the normal case.
        if full_audio_m.get('error') and chunked_audio_m['missing_ratio'] < 0.3:
            print("✅ TTS QUALITY IS GOOD!")
            print("   Full audio times out on Gemini, but chunked audio transcribes well.")
            print("   The issue is GEMINI API LIMITATIONS, not TTS quality.")
        elif chunked_audio_m['missing_ratio'] < 0.3:
            print("✅ TTS QUALITY IS GOOD!")
            print(f"   Chunked ASR achieved {1-chunked_audio_m['missing_ratio']:.0%} word capture.")
        elif chunked_audio_m['missing_ratio'] > 0.5:
            print("⚠️ POTENTIAL TTS QUALITY ISSUE")
            print(f"   Even chunked audio has {chunked_audio_m['missing_ratio']:.0%} missing words.")
            print("   Investigate crossfade, pronunciation, or model issues.")
    
    # Check short vs chunked to verify baseline
    short_m = results[0][1] if results else None
    if short_m and chunked_audio_m:
        if short_m['missing_ratio'] < 0.1 and chunked_audio_m['missing_ratio'] < 0.3:
            print("\n✅ FINAL VERDICT: TTS quality is excellent.")
            print("   Short audio: 100% accurate")
            print(f"   Long audio (chunked ASR): {1-chunked_audio_m['missing_ratio']:.0%} accurate")
            print("   Any issues with full-audio ASR are Gemini API limitations.")


# Script entry point (no-op when imported as a module).
if __name__ == "__main__":
    main()

