#!/usr/bin/env python3
"""
Scoring Test
============

Test alignment scoring on existing transcriptions from final_analysis.json.
Uses CTC-based scoring to evaluate transcription quality.
"""
import json
import os
import sys
from pathlib import Path
from datetime import datetime

sys.path.insert(0, str(Path(__file__).parent))

from src.validators.alignment_scorer import AlignmentScorer
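# Assumed AlignmentScorer interface, inferred from its use below:
#   score_transcription(audio_path, text) -> AlignmentResult exposing
#   .alignment_score, .average_confidence, and .to_dict() (whose dict also
#   carries "min_confidence"); cleanup() presumably frees model resources.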


def load_analysis():
    """Load existing analysis results.

    Expects "metadata" (with "language") and a "segments" list carrying
    segment_id, audio_path, duration_sec, and per-model transcriptions.
    """
    with open('analysis_results/final_analysis.json', 'r', encoding='utf-8') as f:
        return json.load(f)


def run_scoring_test():
    """Score all transcriptions from analysis file."""
    print("=" * 80)
    print("ALIGNMENT SCORING TEST")
    print("=" * 80)
    
    # Load analysis
    analysis = load_analysis()
    segments = analysis['segments']
    
    print(f"\nLoaded {len(segments)} segments")
    print(f"Model: gemini-3-flash-preview (temp0_low)")
    print(f"Language: {analysis['metadata']['language']}")
    
    # Initialize scorer
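    # ("te" = ISO 639-1 code for Telugu.) The scorer is assumed to force-align
    # each transcription against its audio with a CTC model and to derive the
    # alignment score and per-token confidences from that alignment.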
    scorer = AlignmentScorer(language="te")
    
    results = []
    
    for i, seg in enumerate(segments):
        print(f"\n[{i+1}/{len(segments)}] {seg['segment_id']}", end="", flush=True)
        
        audio_path = seg['audio_path']
        
        # Check audio exists
        if not os.path.exists(audio_path):
            print(" - AUDIO NOT FOUND")
            continue
        
        # Get transcription from gemini-3-flash temp0_low
        model_data = seg['models'].get('gemini-3-flash-preview', {})
        config_data = model_data.get('temp0_low', {})
        
        native = config_data.get('native', '')
        punctuated = config_data.get('punctuated', '')
        
        if not native:
            print(" - NO TRANSCRIPTION")
            continue
        
        # Score the native transcription; skip this segment if scoring fails
        # (e.g. unreadable audio) rather than aborting the whole run.
        print(f" ({seg['duration_sec']:.1f}s)...", end=" ", flush=True)

        try:
            score_result = scorer.score_transcription(audio_path, native)
        except Exception as e:
            print(f"SCORING FAILED: {e}")
            continue

        print(f"Score: {score_result.alignment_score:.4f}, Conf: {score_result.average_confidence:.4f}")
        
        # Store result
        results.append({
            "segment_id": seg['segment_id'],
            "audio_path": audio_path,
            "duration_sec": seg['duration_sec'],
            "transcription": native,
            "punctuated": punctuated,
            "scoring": score_result.to_dict()
        })
    
    scorer.cleanup()
    
    return results


def analyze_results(results):
    """Analyze scoring results."""
    print("\n" + "=" * 80)
    print("ANALYSIS")
    print("=" * 80)
    
    scores = [r['scoring']['alignment_score'] for r in results]
    confidences = [r['scoring']['average_confidence'] for r in results]
    min_confs = [r['scoring']['min_confidence'] for r in results]
    
    print(f"\nAlignment Scores:")
    print(f"  Mean: {sum(scores)/len(scores):.4f}")
    print(f"  Min: {min(scores):.4f}")
    print(f"  Max: {max(scores):.4f}")
    
    print(f"\nAverage Confidence:")
    print(f"  Mean: {sum(confidences)/len(confidences):.4f}")
    print(f"  Min: {min(confidences):.4f}")
    print(f"  Max: {max(confidences):.4f}")
    
    print(f"\nMin Confidence (per segment):")
    print(f"  Mean: {sum(min_confs)/len(min_confs):.4f}")
    print(f"  Lowest: {min(min_confs):.4f}")
    
    # Score distribution. Brackets start at 0.0 so scores below 0.5 are still
    # counted; the top bucket is inclusive so an exact 1.0 is not dropped.
    print("\nScore Distribution:")
    brackets = [0.0, 0.5, 0.6, 0.7, 0.8, 0.9, 1.0]
    for i in range(len(brackets) - 1):
        lo, hi = brackets[i], brackets[i + 1]
        if i == len(brackets) - 2:
            count = sum(1 for s in scores if lo <= s <= hi)
        else:
            count = sum(1 for s in scores if lo <= s < hi)
        print(f"  {lo:.1f}-{hi:.1f}: {count} ({count*100/len(scores):.0f}%)")
    
    # Show the best and worst segments
    sorted_results = sorted(results, key=lambda x: x['scoring']['alignment_score'])

    print("\n3 Highest Scores:")
    for r in sorted_results[-3:][::-1]:
        print(f"  {r['scoring']['alignment_score']:.4f}: {r['transcription'][:50]}...")

    print("\n3 Lowest Scores:")
    for r in sorted_results[:3]:
        print(f"  {r['scoring']['alignment_score']:.4f}: {r['transcription'][:50]}...")


def save_results(results):
    """Save results with scores."""
    output = {
        "metadata": {
            "model": "gemini-3-flash-preview",
            "config": "temp0_low",
            "language": "Telugu",
            "segments_scored": len(results),
            "timestamp": datetime.now().isoformat()
        },
        "summary": {
            "mean_alignment_score": sum(r['scoring']['alignment_score'] for r in results) / len(results),
            "mean_confidence": sum(r['scoring']['average_confidence'] for r in results) / len(results),
        },
        "results": results
    }
    
    output_path = "analysis_results/scoring_results.json"
    with open(output_path, 'w', encoding='utf-8') as f:
        json.dump(output, f, ensure_ascii=False, indent=2)
    
    print(f"\nSaved: {output_path}")

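# Illustrative sketch, not called by this test: applies the confidence
# thresholds suggested in the RECOMMENDATIONS printed by main(). The bucket
# names and cutoffs are assumptions, not part of AlignmentScorer's API.
def classify_confidence(score: float) -> str:
    """Map an alignment score to a suggested confidence bucket."""
    if score >= 0.8:
        return "high"
    if score >= 0.6:
        return "medium"
    return "low"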

def main():
    # Run scoring
    results = run_scoring_test()
    
    if not results:
        print("No results to analyze")
        return
    
    # Analyze
    analyze_results(results)
    
    # Save
    save_results(results)
    
    print("\n" + "=" * 80)
    print("RECOMMENDATIONS")
    print("=" * 80)
    scores = [r['scoring']['alignment_score'] for r in results]
    mean_score = sum(scores) / len(scores)
    
    if mean_score > 0.7:
        print(f"""
✅ Mean alignment score: {mean_score:.4f}

This indicates good transcription quality. The CTC-based scoring 
provides useful confidence metrics without needing IndicConformer.

Suggested thresholds:
  - High confidence: score >= 0.8
  - Medium confidence: 0.6 <= score < 0.8
  - Low confidence: score < 0.6

For production:
  - Accept transcriptions with score >= 0.7 automatically
  - Flag transcriptions with score < 0.6 for review
""")
    else:
        print(f"""
⚠️ Mean alignment score: {mean_score:.4f}

Scores are lower than expected. Consider:
  - Checking audio quality
  - Reviewing transcription accuracy
  - Using IndicConformer for comparison on low-score segments
""")


if __name__ == "__main__":
    main()
