#!/usr/bin/env python3
"""
Test script to validate all validators work correctly.
Runs each validator on sample segments and verifies outputs.
"""
import os
import sys
import json
from pathlib import Path

# Add project root to path
sys.path.insert(0, str(Path(__file__).parent))

from src.validators import ValidatorRunner, ValidationResult

# Test configuration
SEGMENTS_DIR = "/tmp/maya3_transcribe/pF_BQpHaIdU/extracted/pF_BQpHaIdU/segments"
LANGUAGE = "te"  # Telugu
NUM_SAMPLES = 3  # Test with first 3 segments
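
# Pick the compute device up front so the script also runs on CPU-only
# machines. Assumes torch is importable (the validators require it anyway);
# fall back to CPU if it is not.
try:
    import torch
    DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
except ImportError:
    DEVICE = "cpu"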

# Sample reference texts from our earlier transcription (for MFA)
REFERENCE_TEXTS = {
    "SPEAKER_00_0000_0.03-2.61.flac": "నాకు కొన్ని యాడ్స్ గుర్తుంటాయ్ లైక్ ఎగ్జాంపుల్ కచ్చా మ్యాంగో బైట్",
    "SPEAKER_00_0001_3.05-13.40.flac": "ఆ కచ్చితంగా ఒక అబ్బాయి అయితే మా వాడు ఒకడున్నాడు ఐ విల్ నాట్ నేమ్ హిమ్ బట్ మరేంటంటే మ్యాంగోస్ తీసుకుని జిరాక్స్ షాప్ కి వెళ్ళిపోయాడు అన్నమాట",
    "SPEAKER_00_0002_16.08-28.31.flac": "అదేంటది ప్రాసెస్ అని చెప్పి కరెక్ట్ వాడు రెండు మ్యాంగోస్ తీసుకెళ్లి అంటే క్యూరియాసిటీ అవును అవును తెలియకపోవడం కాదు అక్కడికెళ్ళి అరే ఇది ఇస్తే ఏ నిజంగా ఏమంటాడో చూద్దాం",
}


def get_sample_segments(segments_dir: str, num_samples: int = 3):
    """Get first N audio segments for testing."""
    segment_files = sorted(Path(segments_dir).glob("*.flac"))[:num_samples]
    return [str(f) for f in segment_files]


def validate_output(result: ValidationResult, validator_name: str) -> dict:
    """Validate the output from a validator and return assessment."""
    assessment = {
        "validator": validator_name,
        "success": result.success,
        "issues": [],
        "quality": "unknown"
    }
    
    if not result.success:
        assessment["issues"].append(f"Failed: {result.error_message}")
        assessment["quality"] = "failed"
        return assessment
    
    # Check transcription
    if result.transcription:
        trans = result.transcription.strip()
        
        if not trans:
            assessment["issues"].append("Empty transcription")
        elif len(set(trans.replace(" ", ""))) < 3:
            assessment["issues"].append("Possibly garbage output (low character diversity)")
    else:
        assessment["issues"].append("No transcription output")
    
    # Determine overall quality
    if not assessment["issues"]:
        assessment["quality"] = "good"
    elif len(assessment["issues"]) == 1 and "WPS" in str(assessment["issues"]):
        assessment["quality"] = "acceptable"
    else:
        assessment["quality"] = "poor"
    
    return assessment


def main():
    print("=" * 70)
    print("VALIDATORS TEST SCRIPT - All Models")
    print("=" * 70)
    
    # Check segments exist
    if not os.path.exists(SEGMENTS_DIR):
        print(f"ERROR: Segments directory not found: {SEGMENTS_DIR}")
        print("Run the transcription pipeline first to download segments.")
        sys.exit(1)
    
    # Get sample segments
    sample_segments = get_sample_segments(SEGMENTS_DIR, NUM_SAMPLES)
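    if not sample_segments:
        print(f"ERROR: No .flac segments found in {SEGMENTS_DIR}")
        sys.exit(1)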
    print(f"\nTesting with {len(sample_segments)} sample segments")
    print(f"Language: {LANGUAGE}")
    
    # Initialize runner with all validators
    print("\n" + "=" * 70)
    print("Initializing Validators...")
    print("=" * 70)
    
    runner = ValidatorRunner(
        # IndicWav2Vec removed due to fairseq version incompatibility
        enable_vistaar=True,
        enable_indicmfa=True,  # Will use CTC fallback if MFA not installed
        enable_indic_conformer=True,  # Will try to load
        language=LANGUAGE,
        device="cuda"
    )
    
    # Show validator info
    print("\nValidator Status:")
    for name, validator in runner.validators.items():
        print(f"  {name}: {validator.description}")
    
    all_results = {}
    
    # Process each segment
    for audio_path in sample_segments:
        filename = Path(audio_path).name
        ref_text = REFERENCE_TEXTS.get(filename)
        
        print(f"\n\n{'#' * 70}")
        print(f"# AUDIO: {filename}")
        if ref_text:
            print(f"# Reference: {ref_text[:60]}{'...' if len(ref_text) > 60 else ''}")
        print(f"{'#' * 70}")
        
        # Run all validators
        result = runner.validate(
            audio_path=audio_path,
            reference_text=ref_text,
            language=LANGUAGE
        )
        
        all_results[filename] = {}
        
        # Process results
        for validator_name, vr in result.results.items():
            print(f"\n--- {validator_name} ---")
            
            if vr.success:
                print(f"  ✓ Success")
                print(f"  Transcription: {vr.transcription[:80] if vr.transcription else 'None'}...")
                print(f"  Audio duration: {vr.audio_duration_sec:.2f}s" if vr.audio_duration_sec else "  Duration: N/A")
                print(f"  Processing time: {vr.processing_time_sec:.2f}s" if vr.processing_time_sec else "  Time: N/A")
                print(f"  Confidence: {vr.overall_confidence:.3f}" if vr.overall_confidence else "  Confidence: N/A")
                print(f"  Word alignments: {len(vr.word_alignments)} words")
                
                if vr.raw_output:
                    model_type = vr.raw_output.get('model_type', 'unknown')
                    model = vr.raw_output.get('model', 'unknown')
                    print(f"  Model: {model} ({model_type})")
            else:
                print(f"  ✗ Failed: {vr.error_message}")
            
            # Validate output
            assessment = validate_output(vr, validator_name)
            
            all_results[filename][validator_name] = {
                "success": vr.success,
                "transcription": vr.transcription,
                "confidence": vr.overall_confidence,
                "word_count": len(vr.word_alignments),
                "quality": assessment['quality'],
                "issues": assessment['issues'],
                "processing_time": vr.processing_time_sec,
                "model_info": vr.raw_output
            }
    
    # Cleanup
    runner.cleanup()
    
    # Summary
    print("\n\n" + "=" * 70)
    print("SUMMARY")
    print("=" * 70)
    
    validators_tested = set()
    for filename, validators in all_results.items():
        print(f"\n{filename}:")
        for validator_name, result in validators.items():
            validators_tested.add(validator_name)
            status = "✓" if result.get("success") else "✗"
            quality = result.get("quality", "error")
            print(f"  {status} {validator_name}: {quality}")
            if result.get("transcription"):
                trans = result["transcription"]
                print(f"      -> {trans[:60]}...")
    
    # Validator summary
    print(f"\n\nValidators tested: {len(validators_tested)}")
    for v in sorted(validators_tested):
        print(f"  - {v}")
    
    # Save results
    output_path = "./validation_results/all_validators_test.json"
    os.makedirs(os.path.dirname(output_path), exist_ok=True)
    with open(output_path, 'w', encoding='utf-8') as f:
        # default=str guards against non-JSON-serializable objects that some
        # validators may place in raw_output (e.g. numpy types).
        json.dump(all_results, f, ensure_ascii=False, indent=2, default=str)
    print(f"\nResults saved to: {output_path}")


if __name__ == "__main__":
    main()
