#!/usr/bin/env python3
"""
Model Analysis Script
=====================

Runs comprehensive analysis across multiple Gemini models:
1. Transcribes 20 segments per model
2. Runs validation (IndicMFA + IndicConformer) on each transcription
3. Stores all results for analysis
4. Generates statistics summary

Usage:
    python run_model_analysis.py
"""
import os
import sys
import json
import time
from pathlib import Path
from datetime import datetime
from dataclasses import dataclass, asdict
from typing import List, Dict, Any, Optional

# Add project root to path
sys.path.insert(0, str(Path(__file__).parent))

from src.backend.config import GEMINI_MODELS
from src.backend.audio_processor import AudioProcessor
from src.backend.gemini_transcriber import GeminiTranscriber, TranscriptionConfig

# === CONFIGURATION ===
# Identifier of the source video whose pre-extracted audio segments are analyzed.
VIDEO_ID = "pF_BQpHaIdU"
# Spoken language passed to both the transcriber and the validators.
LANGUAGE = "Telugu"
# Directory of per-segment audio files produced by an earlier extraction step.
SEGMENTS_DIR = f"/tmp/maya3_transcribe/{VIDEO_ID}/extracted/{VIDEO_ID}/segments"
# Destination directory for timestamped JSON result files.
OUTPUT_DIR = "./analysis_results"
# How many audio segments each model transcribes.
SAMPLES_PER_MODEL = 20

# Models to test
MODELS_TO_TEST = [
    "gemini-3-pro-preview",
    "gemini-3-flash-preview",
    "gemini-2.5-pro",
    "gemini-2.5-flash",
    "gemini-2.5-flash-lite",
    "gemini-2.0-flash",
]

# Thinking levels per model (only Gemini 3 supports explicit thinking levels)
# None means no explicit thinking level is sent (the model's default is used).
MODEL_THINKING = {
    "gemini-3-pro-preview": "high",
    "gemini-3-flash-preview": "high",
    "gemini-2.5-pro": None,
    "gemini-2.5-flash": None,
    "gemini-2.5-flash-lite": None,
    "gemini-2.0-flash": None,
}


@dataclass
class SegmentResult:
    """Per-segment, per-model record of transcription and validation outputs.

    Instances are serialized to JSON via ``dataclasses.asdict`` in
    ``save_results``.  All result fields default to ``None``/zero so a record
    stays valid even when transcription or validation fails; the failure is
    recorded in the ``*_error`` fields instead.
    """
    segment_id: str                  # identifier of the original audio segment
    audio_path: str                  # path to the segment's audio file on disk
    duration_sec: float              # segment length in seconds
    model: str                       # Gemini model name used for transcription
    thinking_level: Optional[str]    # explicit thinking level, or None for default
    
    # Transcription outputs
    native: Optional[str] = None       # transcription in the native script
    punctuated: Optional[str] = None   # native text with punctuation applied
    code_switch: Optional[str] = None  # code-switched (mixed-language) rendering
    romanized: Optional[str] = None    # Latin-script romanization
    
    # Timing
    transcription_time_sec: float = 0.0  # wall-clock time of the Gemini call
    
    # Validation results
    indicmfa_transcription: Optional[str] = None
    indicmfa_confidence: Optional[float] = None
    indicmfa_word_count: int = 0       # number of aligned words reported by IndicMFA
    indicmfa_time_sec: float = 0.0
    
    indic_conformer_transcription: Optional[str] = None
    indic_conformer_confidence: Optional[float] = None
    indic_conformer_time_sec: float = 0.0
    
    # Error tracking
    transcription_error: Optional[str] = None  # str(exception) from transcription, if any
    validation_error: Optional[str] = None     # str(exception) from validation, if any


def get_audio_segments(
    segments_dir: str,
    max_segments: int,
    max_duration_sec: float = 10.0,
    min_duration_sec: float = 2.0
) -> List[Dict]:
    """Select up to *max_segments* audio chunks from a segments directory.

    Loads every chunk via :class:`AudioProcessor`, keeps only those whose
    duration lies within ``[min_duration_sec, max_duration_sec]``, and then
    sub-samples the survivors evenly across the file so the selection is
    diverse rather than clustered at the start.

    Args:
        segments_dir: Directory containing the extracted audio segments.
        max_segments: Maximum number of chunks to return.
        max_duration_sec: Upper duration bound (seconds) for a usable chunk.
        min_duration_sec: Lower duration bound (seconds) for a usable chunk.

    Returns:
        List of chunk objects (annotation kept for compatibility; the
        elements are AudioProcessor chunk records, not plain dicts —
        NOTE(review): consider correcting the annotation upstream).
    """
    processor = AudioProcessor(
        max_duration_sec=max_duration_sec,
        min_duration_sec=min_duration_sec
    )
    
    # Fetch everything first; the even-spread sub-sampling happens below.
    chunks = processor.process_segments_directory(
        segments_dir,
        max_segments=None,
        skip_short=True
    )
    
    # Bug fix: this filter previously hard-coded 2.0/10.0, silently ignoring
    # the min_duration_sec/max_duration_sec parameters.
    good_chunks = [
        c for c in chunks
        if min_duration_sec <= c.duration_sec <= max_duration_sec
    ]
    
    # Select diverse samples spread across the file.  The max_segments > 0
    # guard prevents a ZeroDivisionError when max_segments == 0.
    if max_segments > 0 and len(good_chunks) > max_segments:
        step = len(good_chunks) // max_segments
        selected = [good_chunks[i * step] for i in range(max_segments)]
    else:
        selected = good_chunks[:max_segments]
    
    print(f"Selected {len(selected)} segments for analysis")
    return selected


def run_transcriptions(
    segments: List,
    model: str,
    thinking_level: Optional[str],
    language: str
) -> List[SegmentResult]:
    """Transcribe every segment with one Gemini model.

    Builds a :class:`SegmentResult` per segment.  On success the four
    transcription variants and the elapsed time are recorded; on failure the
    exception text goes into ``transcription_error`` and processing continues
    with the next segment.

    Args:
        segments: Audio chunk objects (need ``original_segment``,
            ``file_path`` and ``duration_sec`` attributes).
        model: Gemini model name to use.
        thinking_level: Explicit thinking level, or ``None`` for the default.
        language: Spoken language passed to the transcriber.

    Returns:
        One SegmentResult per input segment, in order.
    """
    print(f"\n{'='*60}")
    print(f"Model: {model} (thinking: {thinking_level or 'default'})")
    print(f"{'='*60}")
    
    transcriber = GeminiTranscriber()
    cfg = TranscriptionConfig(
        model=model,
        thinking_level=thinking_level,
        temperature=1.0,
        language=language
    )
    
    outcomes: List[SegmentResult] = []
    total = len(segments)
    
    for idx, chunk in enumerate(segments, start=1):
        print(f"[{idx}/{total}] {chunk.original_segment}...", end=" ", flush=True)
        
        record = SegmentResult(
            segment_id=chunk.original_segment,
            audio_path=chunk.file_path,
            duration_sec=chunk.duration_sec,
            model=model,
            thinking_level=thinking_level
        )
        
        started = time.time()
        
        try:
            transcription = transcriber.transcribe(chunk, cfg)
        except Exception as exc:
            record.transcription_error = str(exc)
            record.transcription_time_sec = time.time() - started
            print(f"ERROR: {exc}")
        else:
            record.native = transcription.native_text
            record.punctuated = transcription.punctuated_text
            record.code_switch = transcription.code_switch_text
            record.romanized = transcription.romanized_text
            record.transcription_time_sec = time.time() - started
            print(f"OK ({record.transcription_time_sec:.1f}s)")
        
        outcomes.append(record)
        
        # Crude rate limiting between API calls.
        time.sleep(0.5)
    
    return outcomes


def run_validation(results: List[SegmentResult], language: str) -> List[SegmentResult]:
    """Run IndicMFA + IndicConformer validation over transcribed segments.

    Mutates each :class:`SegmentResult` in place with the validators' output;
    segments without a native transcription are skipped.  If the validators
    cannot be initialized at all, the input list is returned untouched.

    Args:
        results: Transcription results to validate (mutated in place).
        language: Language passed to the validators.

    Returns:
        The same list, with validation fields filled in where possible.
    """
    print(f"\n{'='*60}")
    print("Running Validation (IndicMFA + IndicConformer)")
    print(f"{'='*60}")
    
    try:
        # Imported lazily so the script still runs when validators are absent.
        from src.validators import ValidatorRunner
        
        runner = ValidatorRunner(
            enable_indicmfa=True,
            enable_indic_conformer=True,
            language=language,
            device="cuda"
        )
    except Exception as e:
        print(f"Failed to initialize validators: {e}")
        return results
    
    # Robustness fix: cleanup() previously ran only if the loop completed
    # normally; a KeyboardInterrupt (or any uncaught error) would leak the
    # validator resources.  try/finally guarantees release.
    try:
        for i, result in enumerate(results):
            if not result.native:
                continue
                
            print(f"[{i+1}/{len(results)}] Validating {result.segment_id}...", end=" ", flush=True)
            
            start_time = time.time()
            
            try:
                validation = runner.validate(
                    audio_path=result.audio_path,
                    reference_text=result.native,  # Use native transcription as reference
                    language=language
                )
                
                # IndicMFA results
                if "indicmfa" in validation.results:
                    mfa = validation.results["indicmfa"]
                    result.indicmfa_transcription = mfa.transcription
                    result.indicmfa_confidence = mfa.overall_confidence
                    result.indicmfa_word_count = len(mfa.word_alignments) if mfa.word_alignments else 0
                    result.indicmfa_time_sec = mfa.processing_time_sec or 0.0
                
                # IndicConformer results
                if "indic_conformer" in validation.results:
                    conformer = validation.results["indic_conformer"]
                    result.indic_conformer_transcription = conformer.transcription
                    result.indic_conformer_confidence = conformer.overall_confidence
                    result.indic_conformer_time_sec = conformer.processing_time_sec or 0.0
                
                print("OK")
                
            except Exception as e:
                result.validation_error = str(e)
                print(f"ERROR: {e}")
    finally:
        runner.cleanup()
    
    return results


def save_results(all_results: Dict[str, List[SegmentResult]], output_dir: str) -> str:
    """Serialize all per-model results to a timestamped JSON file.

    Args:
        all_results: Mapping of model name -> list of SegmentResult.
        output_dir: Directory to write into (created if missing).

    Returns:
        Path of the JSON file that was written.
    """
    # Ensure the destination directory (and any parents) exists.
    Path(output_dir).mkdir(parents=True, exist_ok=True)
    
    stamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    output_path = os.path.join(output_dir, f"model_analysis_{stamp}.json")
    
    # Dataclasses are not JSON-serializable directly; flatten via asdict.
    payload = {
        "timestamp": datetime.now().isoformat(),
        "video_id": VIDEO_ID,
        "language": LANGUAGE,
        "samples_per_model": SAMPLES_PER_MODEL,
        "models_tested": list(all_results),
        "results_by_model": {
            name: [asdict(entry) for entry in entries]
            for name, entries in all_results.items()
        }
    }
    
    with open(output_path, 'w', encoding='utf-8') as f:
        json.dump(payload, f, ensure_ascii=False, indent=2)
    
    print(f"\nResults saved to: {output_path}")
    return output_path


def print_summary(all_results: Dict[str, List[SegmentResult]]):
    """Print summary statistics."""
    print(f"\n{'='*80}")
    print("ANALYSIS SUMMARY")
    print(f"{'='*80}")
    
    print(f"\n{'Model':<30} {'Success':<10} {'Avg Time':<12} {'Conformer Match'}")
    print("-" * 80)
    
    for model, results in all_results.items():
        success = sum(1 for r in results if r.native)
        total = len(results)
        
        avg_time = sum(r.transcription_time_sec for r in results) / max(total, 1)
        
        # Count how many have conformer output
        conformer_count = sum(1 for r in results if r.indic_conformer_transcription)
        
        print(f"{model:<30} {success}/{total:<8} {avg_time:.1f}s{'':<8} {conformer_count}/{total}")
    
    print(f"\n{'='*80}")
    print("Sample Transcription Comparison (First Segment)")
    print(f"{'='*80}")
    
    # Get first segment that all models processed
    first_segment = None
    for model, results in all_results.items():
        if results and results[0].native:
            first_segment = results[0].segment_id
            break
    
    if first_segment:
        print(f"\nSegment: {first_segment}")
        print("-" * 80)
        
        for model, results in all_results.items():
            for r in results:
                if r.segment_id == first_segment and r.native:
                    print(f"\n{model}:")
                    print(f"  Native:    {r.native[:80]}...")
                    if r.indic_conformer_transcription:
                        print(f"  Conformer: {r.indic_conformer_transcription[:80]}...")
                    break


def main():
    """Run the full pipeline: load segments, transcribe per model, validate.

    Intermediate results are checkpointed to disk after each model so a crash
    mid-run loses at most one model's worth of transcriptions.
    """
    wide = "=" * 80
    narrow = "=" * 60
    
    print(wide)
    print("MODEL ANALYSIS - Gemini Transcription Comparison")
    print(wide)
    print(f"\nVideo ID: {VIDEO_ID}")
    print(f"Language: {LANGUAGE}")
    print(f"Samples per model: {SAMPLES_PER_MODEL}")
    print(f"Models to test: {len(MODELS_TO_TEST)}")
    print(f"Total transcriptions: {SAMPLES_PER_MODEL * len(MODELS_TO_TEST)}")
    
    print("\n" + narrow)
    print("Step 1: Loading Audio Segments")
    print(narrow)
    
    segments = get_audio_segments(
        SEGMENTS_DIR,
        max_segments=SAMPLES_PER_MODEL
    )
    
    if not segments:
        print("ERROR: No segments found!")
        return
    
    # Transcribe with every candidate model, checkpointing after each one.
    all_results: Dict[str, List[SegmentResult]] = {}
    
    for model in MODELS_TO_TEST:
        all_results[model] = run_transcriptions(
            segments,
            model=model,
            thinking_level=MODEL_THINKING.get(model),
            language=LANGUAGE
        )
        save_results(all_results, OUTPUT_DIR)
    
    print("\n" + narrow)
    print("Step 2: Running Validation")
    print(narrow)
    
    for model in list(all_results):
        print(f"\nValidating {model} results...")
        all_results[model] = run_validation(all_results[model], LANGUAGE)
    
    output_path = save_results(all_results, OUTPUT_DIR)
    
    print_summary(all_results)
    
    print(f"\n{wide}")
    print(f"Analysis complete! Results saved to: {output_path}")
    print(f"{wide}")


# Script entry point: run the full analysis when executed directly.
if __name__ == "__main__":
    main()
