#!/usr/bin/env python3
"""
Word Error Rate (WER) Test with Emotions

Tests:
1. Generate audio with emotion tags
2. Transcribe with OpenAI Whisper
3. Calculate WER (insertions, deletions, substitutions)
4. Identify missing vs wrong words
5. Verify no chunks are skipped (no large silent gaps between transcribed segments)

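WER Metrics:
- Insertions (I): Words in the transcription that have no counterpart in the original
- Deletions (D): Words in the original that are missing from the transcription
- Substitutions (S): Original words transcribed as a different word
- WER = (I + D + S) / N * 100%, where N is the number of words in the original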
"""

import os
import sys
import requests
import subprocess
from pathlib import Path
from openai import OpenAI
import re

# OpenAI client used for the Whisper transcription step. The key is read from
# the environment; when it is missing or empty the script stops right away
# (exit status 1) instead of failing later, after audio has been generated.
OPENAI_API_KEY = os.getenv('OPENAI_API_KEY')
if OPENAI_API_KEY:
    client = OpenAI(api_key=OPENAI_API_KEY)
else:
    print("❌ OPENAI_API_KEY not set")
    sys.exit(1)


# Input sent to the TTS service (the "text" field of the request in main()):
# a long, multi-paragraph story with inline emotion tags of the form
# <emotion> ... </emotion>. Tags used: excited, whisper, curious, angry, sigh,
# laugh. The length is presumably meant to exercise server-side text chunking
# (main() reports the X-Text-Chunked response header) — TODO confirm.
TEST_TEXT_WITH_EMOTIONS = """<excited> Once upon a time, in a magical forest filled with ancient trees and mystical creatures, there lived a brave young warrior named Arjun. </excited> 

He had trained for many years under the guidance of wise masters. <whisper> But he carried a secret that no one else knew. </whisper> 

<curious> One day, while exploring the deepest part of the forest, he discovered a hidden cave glowing with ethereal light. </curious> Inside, he found an ancient scroll that spoke of a legendary sword capable of defeating the darkness that threatened his kingdom.

<angry> The evil sorcerer Kaalrath had been terrorizing the villages for months! </angry> His dark magic had corrupted the land, turning fertile fields into barren wastelands. <sigh> The people were losing hope with each passing day. </sigh>

<excited> But Arjun knew this was his chance to save everyone! </excited> He studied the ancient scroll day and night, learning the secrets of the legendary sword. The scroll revealed that the sword was hidden in three pieces across three different realms.

<whisper> The first piece lay in the Realm of Shadows, guarded by creatures that feared the light. </whisper> Arjun ventured into this dark domain, carrying only a small lantern. <laugh> The shadow creatures fled when they saw his unwavering courage! </laugh>

The second piece was in the Realm of Ice, where temperatures dropped to unimaginable lows. <sigh> Every step was a struggle against the freezing winds. </sigh> But Arjun persevered, using his inner fire to keep warm.

<curious> The final piece was the most challenging to obtain. </curious> It rested in the Realm of Storms, where lightning struck constantly and thunder roared like angry beasts. <angry> The storm guardians tested his resolve with their fiercest attacks! </angry>

But Arjun remained steadfast. He dodged lightning bolts, weathered the thunderous roars, and finally claimed the last piece. <excited> With all three pieces united, the legendary sword materialized in his hands, blazing with divine light! </excited>

<whisper> He could feel its ancient power coursing through him. </whisper> Arjun returned to his kingdom and confronted Kaalrath. The final battle was epic, with spells and sword clashes echoing across the land.

<angry> Kaalrath unleashed his darkest magic! </angry> <excited> But the legendary sword cut through every curse and hex! </excited> With one final strike, Arjun vanquished the evil sorcerer.

<laugh> The kingdom erupted in celebration! </laugh> The land began to heal, flowers bloomed again, and rivers flowed with crystal-clear water. <excited> Arjun was hailed as the greatest hero the kingdom had ever known! </excited>

<whisper> And so, peace returned to the realm, and Arjun's legend lived on for generations to come. </whisper>"""

# Expected transcript (for WER calculation): the same story with every
# <emotion> / </emotion> tag removed and all whitespace collapsed to single
# spaces. It is derived from TEST_TEXT_WITH_EMOTIONS rather than kept as a
# hand-copied duplicate, so editing the story can never desynchronize the two.
REFERENCE_TEXT = re.sub(
    r'\s+', ' ', re.sub(r'</?\w+>', ' ', TEST_TEXT_WITH_EMOTIONS)
).strip()


def normalize_text(text):
    """
    Normalize text for word-level comparison.

    Lowercases, deletes apostrophes (so "Arjun's" and "Arjun\u2019s" both become
    "arjuns"), turns every other punctuation character into a word separator
    (so "crystal-clear" and "crystal clear" compare equal, and a dash between
    two words does not fuse them into one token), and collapses whitespace.

    Args:
        text: Raw reference or transcription text.

    Returns:
        Lowercase words separated by single spaces, without leading or
        trailing whitespace ("" for empty or punctuation-only input).
    """
    text = text.lower()
    # Apostrophes sit inside contractions/possessives: drop them without
    # splitting the word.
    text = re.sub(r"['\u2019]", '', text)
    # Any other punctuation separates words. Deleting it outright would glue
    # hyphen- or dash-joined words into a single token that never matches a
    # transcription which splits them.
    text = re.sub(r'[^\w\s]', ' ', text)
    text = re.sub(r'\s+', ' ', text)
    return text.strip()


def calculate_wer(reference, hypothesis):
    """
    Calculate Word Error Rate with a detailed breakdown.

    Both texts are normalized with normalize_text() and aligned with a
    word-level Levenshtein (edit-distance) alignment. Every reference word is
    a hit (correct), a substitution or a deletion; every unmatched hypothesis
    word is an insertion. WER = (S + D + I) / N * 100, with N the number of
    reference words. Time and memory are O(N * M) for M hypothesis words.

    Args:
        reference: Ground-truth text.
        hypothesis: Transcribed text.

    Returns:
        dict with keys:
            total_words: N, number of reference words.
            correct: number of reference words matched exactly (hits).
            deletions / insertions: sets of the distinct deleted / inserted words.
            deleted_words / inserted_words: every deleted / inserted word, in
                text order (duplicates kept).
            substitutions: list of (reference_word, hypothesis_word) pairs,
                in text order.
            num_deletions / num_insertions / num_substitutions: error counts
                (occurrences, not distinct words).
            wer: word error rate in percent (0 for an empty reference; can
                exceed 100 when there are many insertions).
            accuracy: correct / N in percent (0 for an empty reference).
    """
    ref_words = normalize_text(reference).split()
    hyp_words = normalize_text(hypothesis).split()
    n_ref, n_hyp = len(ref_words), len(hyp_words)

    # dist[i][j] = minimum edits turning ref_words[:i] into hyp_words[:j].
    dist = [[0] * (n_hyp + 1) for _ in range(n_ref + 1)]
    for i in range(1, n_ref + 1):
        dist[i][0] = i
    for j in range(1, n_hyp + 1):
        dist[0][j] = j
    for i in range(1, n_ref + 1):
        for j in range(1, n_hyp + 1):
            cost = 0 if ref_words[i - 1] == hyp_words[j - 1] else 1
            dist[i][j] = min(
                dist[i - 1][j - 1] + cost,  # match / substitution
                dist[i - 1][j] + 1,         # deletion
                dist[i][j - 1] + 1,         # insertion
            )

    # Walk back from the bottom-right cell to recover one optimal alignment.
    # Diagonal moves are preferred so word pairs are reported as
    # substitutions rather than as a deletion plus an insertion.
    hits = 0
    deleted, inserted, substituted = [], [], []
    i, j = n_ref, n_hyp
    while i > 0 or j > 0:
        if i > 0 and j > 0:
            same = ref_words[i - 1] == hyp_words[j - 1]
            if dist[i][j] == dist[i - 1][j - 1] + (0 if same else 1):
                if same:
                    hits += 1
                else:
                    substituted.append((ref_words[i - 1], hyp_words[j - 1]))
                i -= 1
                j -= 1
                continue
        if i > 0 and dist[i][j] == dist[i - 1][j] + 1:
            deleted.append(ref_words[i - 1])
            i -= 1
        else:
            inserted.append(hyp_words[j - 1])
            j -= 1
    deleted.reverse()
    inserted.reverse()
    substituted.reverse()

    num_errors = len(substituted) + len(deleted) + len(inserted)
    wer = (num_errors / n_ref) * 100 if n_ref > 0 else 0
    accuracy = (hits / n_ref) * 100 if n_ref > 0 else 0

    return {
        'total_words': n_ref,
        'correct': hits,
        'deletions': set(deleted),
        'insertions': set(inserted),
        'deleted_words': deleted,
        'inserted_words': inserted,
        'substitutions': substituted,
        'num_deletions': len(deleted),
        'num_insertions': len(inserted),
        'num_substitutions': len(substituted),
        'wer': wer,
        'accuracy': accuracy,
    }


def main():
    """
    Run the end-to-end emotion-tag WER test and print a verdict.

    Steps:
      1. POST TEST_TEXT_WITH_EMOTIONS to the TTS endpoint and save the
         response body as a WAV file.
      2. Measure the real audio duration with ffprobe.
      3. Transcribe the audio with OpenAI Whisper (segment timestamps).
      4. Report speech/silence time and gaps > 500 ms between segments
         (a hint that a chunk may have been skipped or padded with silence).
      5. Compute WER against REFERENCE_TEXT and save a detailed text report.
      6. Print a verdict from fixed thresholds (speech ratio, gaps,
         accuracy, deletions).

    Environment:
      TTS_API_KEY  API key sent as the X-API-Key header (required).
      TTS_API_URL  Generation endpoint
                   (default: http://localhost:8000/v1/tts/generate).

    Side effects:
      Writes wer_test_emotions.wav and wer_test_emotions_report.txt to the
      current directory. Calls sys.exit(1) when the TTS key is missing, the
      generation request fails, or ffprobe cannot read the audio duration.
    """
    # Credentials come from the environment (fail-fast, same as
    # OPENAI_API_KEY) so no key is committed to source control.
    API_KEY = os.environ.get('TTS_API_KEY')
    if not API_KEY:
        print("❌ TTS_API_KEY not set")
        sys.exit(1)
    API_URL = os.environ.get('TTS_API_URL', "http://localhost:8000/v1/tts/generate")

    print("=" * 80)
    print("🎭 WER TEST WITH EMOTIONS - COMPREHENSIVE VALIDATION")
    print("=" * 80)
    print()

    print("📝 Test Text:")
    print(f"   Length: {len(TEST_TEXT_WITH_EMOTIONS)} chars")
    print(f"   Reference (no tags): {len(REFERENCE_TEXT)} chars")
    print("   Emotion tags: excited, whisper, curious, angry, sigh, laugh")
    print()

    # Only opening tags match (the pattern has no "/"), so this counts tagged
    # spans rather than individual tag tokens.
    emotion_tags = re.findall(r'<(\w+)>', TEST_TEXT_WITH_EMOTIONS)
    print(f"   Emotion tag count: {len(emotion_tags)}")
    print(f"   Unique emotions: {set(emotion_tags)}")
    print()

    print("Text preview (with emotions):")
    print(TEST_TEXT_WITH_EMOTIONS[:300])
    print("...\n")

    # --- 1. Generate audio ---------------------------------------------------
    print("🎵 Generating audio with emotions...")
    try:
        response = requests.post(
            API_URL,
            headers={"Content-Type": "application/json", "X-API-Key": API_KEY},
            json={
                "text": TEST_TEXT_WITH_EMOTIONS,
                "speaker": "adarsh",
                "stream": False,
                "seed": 42,
            },
            # (connect, read) seconds. Non-streaming generation of a long text
            # can be slow, but with no timeout at all a stalled server would
            # hang the test forever.
            timeout=(10, 600),
        )
    except requests.RequestException as err:
        print(f"❌ Generation request failed: {err}")
        sys.exit(1)

    if response.status_code != 200:
        print(f"❌ Generation failed: {response.status_code}")
        print(response.text)
        sys.exit(1)

    output_file = "wer_test_emotions.wav"
    with open(output_file, 'wb') as f:
        f.write(response.content)

    # Server-reported metrics; 'unknown' when a header is absent.
    chunked = response.headers.get('X-Text-Chunked', 'unknown')
    audio_bytes = response.headers.get('X-Audio-Bytes', 'unknown')
    audio_secs = response.headers.get('X-Audio-Seconds', 'unknown')
    rtf = response.headers.get('X-RTF', 'unknown')

    print("✅ Audio generated!")
    print(f"   Chunked: {chunked}")
    print(f"   Size: {audio_bytes} bytes")
    print(f"   Duration: {audio_secs}s")
    print(f"   RTF: {rtf}")
    print()

    # --- 2. Measure the real duration -----------------------------------------
    # ffprobe missing (OSError), failing (CalledProcessError via check=True)
    # or printing nothing (ValueError from float('')) all get a clear message.
    try:
        probe = subprocess.run(
            ['ffprobe', '-v', 'error', '-show_entries', 'format=duration',
             '-of', 'default=noprint_wrappers=1:nokey=1', output_file],
            capture_output=True, text=True, check=True,
        )
        actual_duration = float(probe.stdout.strip())
    except (OSError, subprocess.CalledProcessError, ValueError) as err:
        print(f"❌ Could not read audio duration with ffprobe: {err}")
        sys.exit(1)

    print(f"   Actual duration (ffprobe): {actual_duration:.1f}s ({actual_duration/60:.1f} min)")
    print()

    # --- 3. Transcribe -----------------------------------------------------------
    # OpenAI documents a 25 MB upload limit for transcription; warn up front so
    # a rejected upload is easy to diagnose.
    audio_size = os.path.getsize(output_file)
    if audio_size > 25 * 1024 * 1024:
        print(f"⚠️  Audio file is {audio_size / (1024 * 1024):.1f} MB; "
              "the Whisper API upload limit is 25 MB")

    print("🎙️  Transcribing with OpenAI Whisper API...")
    print("   (This may take 30-60 seconds)")

    with open(output_file, 'rb') as f:
        transcript = client.audio.transcriptions.create(
            model='whisper-1',
            file=f,
            response_format='verbose_json',
            timestamp_granularities=['segment'],
        )

    print("✅ Transcription complete")
    print()

    # --- 4. Speech / gap analysis ------------------------------------------------
    # `segments` may be None on the response object; treat that as empty.
    segments = transcript.segments or []
    speech_time = sum(seg.end - seg.start for seg in segments)
    # Guard against a zero-length file so the ratio cannot divide by zero.
    speech_ratio = (speech_time / actual_duration) * 100 if actual_duration > 0 else 0.0

    print("📊 SPEECH ANALYSIS:")
    print(f"   Duration: {actual_duration:.1f}s")
    print(f"   Speech: {speech_time:.1f}s ({speech_ratio:.1f}%)")
    print(f"   Silence: {actual_duration - speech_time:.1f}s ({100-speech_ratio:.1f}%)")
    print(f"   Segments: {len(segments)}")

    # A gap > 0.5 s between consecutive segments may mean a chunk was dropped.
    # Each entry: (index of the later segment, gap seconds, tail of the earlier
    # segment's text, head of the later segment's text).
    gaps = []
    for idx, (prev_seg, next_seg) in enumerate(zip(segments, segments[1:]), start=1):
        gap = next_seg.start - prev_seg.end
        if gap > 0.5:
            gaps.append((idx, gap, prev_seg.text[-50:], next_seg.text[:50]))

    if gaps:
        print(f"   Large gaps (>500ms): {len(gaps)} ⚠️")
        for idx, gap_sec, before, after in gaps[:5]:  # Show first 5
            print(f"      Gap #{idx}: {gap_sec:.1f}s")
            print(f"         ...{before.strip()} | {after.strip()}...")
    else:
        print("   Large gaps: NONE ✅")
    print()

    # --- 5. WER -----------------------------------------------------------------
    print("📊 WORD ERROR RATE ANALYSIS:")
    print("=" * 80)

    wer_results = calculate_wer(REFERENCE_TEXT, transcript.text)
    # Substitution details are optional keys of the WER result; read them with
    # .get()/membership tests so a result without them still works.
    substitutions = wer_results.get('substitutions') or []

    print(f"Total reference words: {wer_results['total_words']}")
    print(f"Correctly transcribed: {wer_results['correct']} ({wer_results['accuracy']:.1f}%)")
    print()
    print(f"Deletions (missing): {wer_results['num_deletions']}")
    print(f"Insertions (extra): {wer_results['num_insertions']}")
    if 'num_substitutions' in wer_results:
        print(f"Substitutions (wrong): {wer_results['num_substitutions']}")
    print(f"WER: {wer_results['wer']:.1f}%")
    print()

    # Show missing words (deletions)
    if wer_results['deletions']:
        print(f"🔍 MISSING WORDS ({len(wer_results['deletions'])} words):")
        missing_list = sorted(wer_results['deletions'])[:30]  # Show first 30
        print(f"   {', '.join(missing_list)}")
        if len(wer_results['deletions']) > 30:
            print(f"   ... and {len(wer_results['deletions']) - 30} more")
        print()

    # Show extra words (insertions)
    if wer_results['insertions']:
        print(f"🔍 EXTRA WORDS ({len(wer_results['insertions'])} words):")
        extra_list = sorted(wer_results['insertions'])[:20]  # Show first 20
        print(f"   {', '.join(extra_list)}")
        if len(wer_results['insertions']) > 20:
            print(f"   ... and {len(wer_results['insertions']) - 20} more")
        print()

    # Show wrong words (substitutions), in text order
    if substitutions:
        print(f"🔍 WRONG WORDS ({len(substitutions)} substitutions, expected → heard):")
        shown = [f"{ref} → {hyp}" for ref, hyp in substitutions[:20]]
        print(f"   {', '.join(shown)}")
        if len(substitutions) > 20:
            print(f"   ... and {len(substitutions) - 20} more")
        print()

    # Transcription preview
    print("📝 TRANSCRIBED TEXT:")
    print("=" * 80)
    print(transcript.text[:600])
    if len(transcript.text) > 600:
        print("...")
    print()

    # Save detailed report. Explicit UTF-8: the report contains transcription
    # text, which may include characters the platform default encoding lacks.
    report_file = "wer_test_emotions_report.txt"
    with open(report_file, 'w', encoding='utf-8') as f:
        f.write("=" * 80 + "\n")
        f.write("WER TEST WITH EMOTIONS - DETAILED REPORT\n")
        f.write("=" * 80 + "\n\n")

        f.write("TEST CONFIGURATION:\n")
        f.write(f"Text length: {len(TEST_TEXT_WITH_EMOTIONS)} chars\n")
        f.write(f"Emotion tags: {len(emotion_tags)}\n")
        f.write(f"Unique emotions: {set(emotion_tags)}\n\n")

        f.write("GENERATION RESULTS:\n")
        f.write(f"Chunked: {chunked}\n")
        f.write(f"Audio bytes: {audio_bytes}\n")
        f.write(f"Duration: {actual_duration:.1f}s\n")
        f.write(f"RTF: {rtf}\n\n")

        f.write("SPEECH ANALYSIS:\n")
        f.write(f"Speech time: {speech_time:.1f}s ({speech_ratio:.1f}%)\n")
        f.write(f"Silence: {actual_duration - speech_time:.1f}s ({100-speech_ratio:.1f}%)\n")
        f.write(f"Segments: {len(segments)}\n")
        f.write(f"Large gaps: {len(gaps)}\n\n")

        f.write("WER METRICS:\n")
        f.write(f"Total words: {wer_results['total_words']}\n")
        f.write(f"Correct: {wer_results['correct']} ({wer_results['accuracy']:.1f}%)\n")
        f.write(f"Deletions: {wer_results['num_deletions']}\n")
        f.write(f"Insertions: {wer_results['num_insertions']}\n")
        if 'num_substitutions' in wer_results:
            f.write(f"Substitutions: {wer_results['num_substitutions']}\n")
        f.write(f"WER: {wer_results['wer']:.1f}%\n\n")

        f.write("=" * 80 + "\n")
        f.write("REFERENCE TEXT (Expected)\n")
        f.write("=" * 80 + "\n")
        f.write(REFERENCE_TEXT + "\n\n")

        f.write("=" * 80 + "\n")
        f.write("TRANSCRIBED TEXT (Actual)\n")
        f.write("=" * 80 + "\n")
        f.write(transcript.text + "\n\n")

        if wer_results['deletions']:
            f.write("=" * 80 + "\n")
            f.write(f"MISSING WORDS ({len(wer_results['deletions'])})\n")
            f.write("=" * 80 + "\n")
            f.write(', '.join(sorted(wer_results['deletions'])) + "\n\n")

        if wer_results['insertions']:
            f.write("=" * 80 + "\n")
            f.write(f"EXTRA WORDS ({len(wer_results['insertions'])})\n")
            f.write("=" * 80 + "\n")
            f.write(', '.join(sorted(wer_results['insertions'])) + "\n\n")

        if substitutions:
            f.write("=" * 80 + "\n")
            f.write(f"WRONG WORDS ({len(substitutions)}, expected -> heard)\n")
            f.write("=" * 80 + "\n")
            f.write(', '.join(f"{ref} -> {hyp}" for ref, hyp in substitutions) + "\n\n")

        if segments:
            f.write("=" * 80 + "\n")
            f.write("SEGMENTS WITH TIMESTAMPS\n")
            f.write("=" * 80 + "\n")
            for i, seg in enumerate(segments, start=1):
                f.write(f"\n[{seg.start:.2f}s - {seg.end:.2f}s] Segment {i}:\n")
                f.write(f"{seg.text}\n")

    print(f"💾 Detailed report saved: {report_file}")
    print()

    # --- 6. Final verdict ---------------------------------------------------------
    print("=" * 80)
    print("🏁 FINAL VERDICT")
    print("=" * 80)

    issues = []

    if speech_ratio < 85:
        issues.append(f"Low speech ratio ({speech_ratio:.1f}%)")

    if gaps:
        issues.append(f"{len(gaps)} large gaps detected")

    if wer_results['accuracy'] < 80:
        issues.append(f"Low accuracy ({wer_results['accuracy']:.1f}%)")

    if wer_results['num_deletions'] > wer_results['total_words'] * 0.15:
        issues.append(f"Too many missing words ({wer_results['num_deletions']}/{wer_results['total_words']})")

    if issues:
        print("⚠️  ISSUES DETECTED:")
        for issue in issues:
            print(f"   - {issue}")
        print()
        print("VERDICT: NEEDS INVESTIGATION")
    else:
        print("✅ EXCELLENT PERFORMANCE:")
        print(f"   - Speech ratio: {speech_ratio:.1f}% ✅")
        print("   - No large gaps ✅")
        print(f"   - Accuracy: {wer_results['accuracy']:.1f}% ✅")
        print(f"   - WER: {wer_results['wer']:.1f}% ✅")
        print()
        print("VERDICT: ✅ PRODUCTION READY")

    print()
    print("📁 Files created:")
    print(f"   - {output_file} (audio)")
    print(f"   - {report_file} (detailed analysis)")
    print()


# Script entry point: run the full test only when executed directly, not on import.
if __name__ == '__main__':
    main()

