#!/usr/bin/env python3
"""
Test script to validate the validators module works correctly.
Runs each validator on sample segments and verifies outputs.
"""
import os
import sys
import json
from pathlib import Path

# Add project root to path
sys.path.insert(0, str(Path(__file__).parent))

from src.validators import ValidatorRunner, ValidationResult
from src.validators.base import normalize_language_code

# Test configuration
# NOTE(review): SEGMENTS_DIR is a hard-coded path produced by an earlier
# pipeline run — presumably the transcription stage; confirm before reuse.
SEGMENTS_DIR = "/tmp/maya3_transcribe/pF_BQpHaIdU/extracted/pF_BQpHaIdU/segments"
LANGUAGE = "te"  # Telugu
NUM_SAMPLES = 3  # Test with first 3 segments

# Sample reference texts from our earlier transcription (for MFA)
# These are from the Gemini transcription we did earlier.
# Keys are segment filenames (SPEAKER_index_start-end.flac); values are
# reference transcriptions in Telugu script (code-mixed with English words).
REFERENCE_TEXTS = {
    "SPEAKER_00_0000_0.03-2.61.flac": "నాకు కొన్ని యాడ్స్ గుర్తుంటాయ్ లైక్ ఎగ్జాంపుల్ కచ్చా మ్యాంగో బైట్",
    "SPEAKER_00_0001_3.05-13.40.flac": "ఆ కచ్చితంగా ఒక అబ్బాయి అయితే మా వాడు ఒకడున్నాడు ఐ విల్ నాట్ నేమ్ హిమ్ బట్ మరేంటంటే మ్యాంగోస్ తీసుకుని జిరాక్స్ షాప్ కి వెళ్ళిపోయాడు అన్నమాట",
    "SPEAKER_00_0002_16.08-28.31.flac": "అదేంటది ప్రాసెస్ అని చెప్పి కరెక్ట్ వాడు రెండు మ్యాంగోస్ తీసుకెళ్లి అంటే క్యూరియాసిటీ అవును అవును తెలియకపోవడం కాదు అక్కడికెళ్ళి అరే ఇది ఇస్తే ఏ నిజంగా ఏమంటాడో చూద్దాం",
}


def get_sample_segments(segments_dir: str, num_samples: int = 3):
    """Return string paths of the first *num_samples* FLAC segments.

    Segments are discovered non-recursively in *segments_dir* and taken in
    sorted filename order so the selection is deterministic across runs.
    """
    flac_paths = sorted(Path(segments_dir).glob("*.flac"))
    selected = flac_paths[:num_samples]
    return list(map(str, selected))


def test_single_validator(runner, audio_path, reference_text, validator_name):
    """Exercise exactly one validator against a single audio segment.

    A fresh ValidatorRunner is constructed with only *validator_name*
    enabled, so validators are tested in isolation from each other.  The
    *runner* argument is accepted for call-site compatibility but unused.

    Returns the aggregate result object from ValidatorRunner.validate().
    """
    banner = "=" * 60
    print(f"\n{banner}")
    print(f"Testing: {validator_name}")
    print(f"Audio: {Path(audio_path).name}")
    print(banner)

    # Enable only the validator under test; all others stay off.
    flags = {
        "enable_indicwav2vec": validator_name == "indicwav2vec",
        "enable_indicmfa": validator_name == "indicmfa",
        "enable_vistaar": validator_name == "vistaar",
        "enable_indic_conformer": validator_name == "indic_conformer",
    }
    isolated_runner = ValidatorRunner(
        language=LANGUAGE,
        device="cuda",  # Use GPU if available
        **flags,
    )

    outcome = isolated_runner.validate(
        audio_path=audio_path,
        reference_text=reference_text,
        language=LANGUAGE,
    )

    isolated_runner.cleanup()
    return outcome


def validate_output(result: "ValidationResult", validator_name: str) -> dict:
    """Assess one validator's output and classify its quality.

    Checks the transcription for emptiness, implausibly short length, low
    character diversity ("garbage"), and an unrealistic words-per-second
    rate, and sanity-checks word-alignment timestamps.

    Args:
        result: the validator's ValidationResult (annotation quoted so the
            function can be imported without the project types in scope).
        validator_name: name used to label the assessment.

    Returns:
        dict with keys: validator, success, issues (list of human-readable
        strings), quality ("good" / "acceptable" / "poor" / "failed").
    """
    assessment = {
        "validator": validator_name,
        "success": result.success,
        "issues": [],
        "quality": "unknown"
    }

    if not result.success:
        assessment["issues"].append(f"Failed: {result.error_message}")
        assessment["quality"] = "failed"
        return assessment

    # Check transcription
    if result.transcription:
        trans = result.transcription.strip()

        # Check for empty output
        if not trans:
            assessment["issues"].append("Empty transcription")

        # Check for reasonable length — heuristic: expect at least ~1
        # character per second of audio.
        elif result.audio_duration_sec and len(trans) < result.audio_duration_sec:
            assessment["issues"].append(f"Very short transcription ({len(trans)} chars for {result.audio_duration_sec:.1f}s audio)")

        # Check for garbage (all same character repeated)
        elif len(set(trans.replace(" ", ""))) < 3:
            assessment["issues"].append("Possibly garbage output (low character diversity)")

        # Words-per-second sanity check.  BUG FIX: only run it when there is
        # actual text — previously an empty transcription also produced a
        # spurious "Unrealistic WPS: 0.0 (too slow)" issue here.
        if trans and result.audio_duration_sec:
            wps = len(trans.split()) / result.audio_duration_sec
            if wps > 10:
                assessment["issues"].append(f"Unrealistic WPS: {wps:.1f} (too fast)")
            elif wps < 0.5:
                assessment["issues"].append(f"Unrealistic WPS: {wps:.1f} (too slow)")
    else:
        assessment["issues"].append("No transcription output")

    # Check word alignments — flag only the first invalid timestamp.
    if result.word_alignments:
        for wa in result.word_alignments:
            if wa.start_time < 0 or wa.end_time < wa.start_time:
                assessment["issues"].append(f"Invalid timestamp for word: {wa.word}")
                break

    # Determine overall quality: a lone WPS anomaly is tolerated.
    if not assessment["issues"]:
        assessment["quality"] = "good"
    elif len(assessment["issues"]) == 1 and "WPS" in assessment["issues"][0]:
        assessment["quality"] = "acceptable"
    else:
        assessment["quality"] = "poor"

    return assessment


def main():
    """Run each enabled validator over sample segments and report quality.

    Requires SEGMENTS_DIR to exist (produced by the transcription
    pipeline).  Prints per-validator results and a final summary, then
    writes a JSON report to ./validation_results/test_run.json.
    """
    print("="*60)
    print("VALIDATORS TEST SCRIPT")
    print("="*60)

    # Segments are produced by an earlier pipeline stage; bail out early
    # with instructions if they are missing.
    if not os.path.exists(SEGMENTS_DIR):
        print(f"ERROR: Segments directory not found: {SEGMENTS_DIR}")
        print("Run the transcription pipeline first to download segments.")
        return

    # Get sample segments
    sample_segments = get_sample_segments(SEGMENTS_DIR, NUM_SAMPLES)
    print(f"\nTesting with {len(sample_segments)} sample segments")

    # Validators to test in this run
    validators_to_test = [
        "indicwav2vec",
        "vistaar",
        # "indicmfa",  # Requires MFA installation - skip for now
        # "indic_conformer",  # May require NeMo - try later
    ]

    all_results = {}

    for audio_path in sample_segments:
        filename = Path(audio_path).name
        ref_text = REFERENCE_TEXTS.get(filename, None)

        print(f"\n\n{'#'*60}")
        # BUG FIX: previously printed the literal text "(unknown)" instead
        # of the segment filename.
        print(f"# AUDIO: {filename}")
        print(f"# Reference: {ref_text[:50] if ref_text else 'None'}...")
        print(f"{'#'*60}")

        all_results[filename] = {}

        for validator_name in validators_to_test:
            try:
                result = test_single_validator(
                    runner=None,
                    audio_path=audio_path,
                    reference_text=ref_text,
                    validator_name=validator_name
                )

                # Get the actual validation result for this validator
                if validator_name in result.results:
                    vr = result.results[validator_name]

                    # Print results
                    print(f"\nResult from {validator_name}:")
                    print(f"  Success: {vr.success}")

                    if vr.success:
                        print(f"  Transcription: {vr.transcription[:100] if vr.transcription else 'None'}...")
                        print(f"  Audio duration: {vr.audio_duration_sec:.2f}s" if vr.audio_duration_sec else "  Audio duration: N/A")
                        print(f"  Processing time: {vr.processing_time_sec:.2f}s" if vr.processing_time_sec else "  Processing time: N/A")
                        print(f"  Confidence: {vr.overall_confidence:.3f}" if vr.overall_confidence else "  Confidence: N/A")
                        print(f"  Word alignments: {len(vr.word_alignments)} words")

                        # Show first few word alignments
                        if vr.word_alignments:
                            print("  Sample alignments:")
                            for wa in vr.word_alignments[:5]:
                                print(f"    {wa.start_time:.2f}-{wa.end_time:.2f}s: '{wa.word}'")
                    else:
                        print(f"  Error: {vr.error_message}")

                    # Validate output quality
                    assessment = validate_output(vr, validator_name)
                    print(f"\n  Quality Assessment: {assessment['quality'].upper()}")
                    if assessment['issues']:
                        print(f"  Issues: {assessment['issues']}")

                    all_results[filename][validator_name] = {
                        "success": vr.success,
                        "transcription": vr.transcription,
                        "confidence": vr.overall_confidence,
                        "word_count": len(vr.word_alignments),
                        "quality": assessment['quality'],
                        "issues": assessment['issues']
                    }
                else:
                    print(f"\n  {validator_name}: Not in results")

            except Exception as e:
                # Keep going on per-validator failure; record the error so
                # it still shows up in the saved summary.
                print(f"\n  ERROR with {validator_name}: {e}")
                import traceback
                traceback.print_exc()
                all_results[filename][validator_name] = {
                    "success": False,
                    "error": str(e)
                }

    # Summary
    print("\n\n" + "="*60)
    print("SUMMARY")
    print("="*60)

    for filename, validators in all_results.items():
        # BUG FIX: previously printed the literal text "(unknown)" instead
        # of the segment filename.
        print(f"\n{filename}:")
        for validator_name, result in validators.items():
            status = "✓" if result.get("success") else "✗"
            quality = result.get("quality", "error")
            print(f"  {status} {validator_name}: {quality}")
            if result.get("transcription"):
                print(f"      -> {result['transcription'][:60]}...")

    # Persist the run so results can be compared across runs.
    output_path = "./validation_results/test_run.json"
    os.makedirs(os.path.dirname(output_path), exist_ok=True)
    with open(output_path, 'w', encoding='utf-8') as f:
        json.dump(all_results, f, ensure_ascii=False, indent=2)
    print(f"\nResults saved to: {output_path}")


# Script entry point: only run when executed directly, not when imported.
if __name__ == "__main__":
    main()
