#!/usr/bin/env python3
"""
Create Final Consolidated Analysis Document
============================================

Consolidates all results into one clean JSON file:
- segment_id
- model responses (temp1_high, temp0_low)
- validation results (MFA, Conformer)
"""
import json
import os
from datetime import datetime

# Source files
# Results of the original run: temperature=1.0, thinking_level=high (see notes below).
ORIGINAL_FILE = "analysis_results/model_analysis_20260204_215804.json"
# Results of the re-run: temperature=0.0, thinking_level=low (partial results file).
TEMP0_LOW_FILE = "analysis_results/temperature_analysis_partial.json"
# Destination for the consolidated document produced by this script.
OUTPUT_FILE = "analysis_results/final_analysis.json"

# Models to include (excluding 2.0)
MODELS = [
    "gemini-3-pro-preview",
    "gemini-3-flash-preview",
    "gemini-2.5-pro",
    "gemini-2.5-flash",
    "gemini-2.5-flash-lite",
]


def load_original_results():
    """Read and parse the original (temp=1, thinking=high) results file."""
    with open(ORIGINAL_FILE, encoding='utf-8') as src:
        data = json.load(src)
    return data


def load_temp0_results():
    """Read and parse the temp0_low (temp=0, thinking=low) results file."""
    with open(TEMP0_LOW_FILE, encoding='utf-8') as src:
        data = json.load(src)
    return data


def _avg_transcription_time(results):
    """Mean 'transcription_time_sec' over entries that report a numeric value.

    Entries missing the key (or holding None) are excluded instead of being
    counted as 0 seconds, which previously dragged the average down.
    Returns 0.0 when no entry reports a time.
    """
    times = [r['transcription_time_sec'] for r in results
             if isinstance(r.get('transcription_time_sec'), (int, float))]
    return sum(times) / len(times) if times else 0.0


def _transcription_fields(seg):
    """Extract the per-config transcription payload from one result row."""
    return {
        "native": seg.get('native'),
        "punctuated": seg.get('punctuated'),
        "code_switch": seg.get('code_switch'),
        "romanized": seg.get('romanized'),
        "time_sec": seg.get('transcription_time_sec'),
    }


def _validation_fields(seg):
    """Extract MFA / Conformer validation results from one result row.

    Returns a (indicmfa, indic_conformer) pair; each element is None when the
    corresponding transcription is absent or empty.
    """
    indicmfa = None
    if seg.get('indicmfa_transcription'):
        indicmfa = {
            "transcription": seg.get('indicmfa_transcription'),
            "confidence": seg.get('indicmfa_confidence'),
            "word_count": seg.get('indicmfa_word_count'),
            "time_sec": seg.get('indicmfa_time_sec'),
        }
    conformer = None
    if seg.get('indic_conformer_transcription'):
        conformer = {
            "transcription": seg.get('indic_conformer_transcription'),
            "confidence": seg.get('indic_conformer_confidence'),
            "time_sec": seg.get('indic_conformer_time_sec'),
        }
    return indicmfa, conformer


def create_consolidated_analysis(original=None, temp0=None, models=None):
    """Create the consolidated analysis document.

    Args:
        original: Parsed original-run results (temp=1, thinking=high).
            Loaded from ORIGINAL_FILE when None, preserving the previous
            zero-argument behavior.
        temp0: Parsed temp0_low results. Loaded from TEMP0_LOW_FILE when None.
        models: Iterable of model names to include. Defaults to MODELS.

    Returns:
        dict with "metadata", "summary" (speed ranking + recommendations) and
        a per-segment "segments" list combining both configs plus the
        MFA / Conformer validation transcriptions.
    """
    if original is None:
        original = load_original_results()
    if temp0 is None:
        temp0 = load_temp0_results()
    models = MODELS if models is None else list(models)

    # Segment order is taken from the first model's results; all models are
    # assumed to cover the same segment set — TODO confirm upstream.
    first_model = models[0]
    segments = [r['segment_id'] for r in original['results_by_model'][first_model]]

    analysis = {
        "metadata": {
            "video_id": original['video_id'],
            "language": original['language'],
            "total_segments": len(segments),
            "models_tested": models,
            "configs_tested": ["temp1_high", "temp0_low"],
            "created_at": datetime.now().isoformat(),
            "notes": {
                "temp1_high": "Temperature=1.0 with thinking_level=high (Gemini 3) or default (2.5)",
                "temp0_low": "Temperature=0.0 with thinking_level=low (Gemini 3) or default (2.5)",
                "thinking_budget": "For Gemini 3 with temp=0, use thinking_budget=300 to prevent loops"
            }
        },
        "summary": {
            "speed_ranking": [],
            "recommendations": {}
        },
        "segments": []
    }

    # Build one consolidated entry per segment.
    for segment_id in segments:
        segment_data = {
            "segment_id": segment_id,
            "audio_path": None,
            "duration_sec": None,
            "models": {},
            "validation": {
                "indicmfa": None,
                "indic_conformer": None
            }
        }

        for model in models:
            # Original (temp1_high)
            orig_results = original['results_by_model'].get(model, [])
            orig_seg = next((r for r in orig_results if r['segment_id'] == segment_id), None)

            # temp0_low
            temp0_results = temp0['results_by_config'].get('temp0_low', {}).get(model, [])
            temp0_seg = next((r for r in temp0_results if r['segment_id'] == segment_id), None)

            model_data = {
                "temp1_high": None,
                "temp0_low": None
            }

            if orig_seg:
                # Audio metadata is shared across models; any hit will do.
                segment_data["audio_path"] = orig_seg.get('audio_path')
                segment_data["duration_sec"] = orig_seg.get('duration_sec')
                model_data["temp1_high"] = _transcription_fields(orig_seg)

                # Validation ran once on the shared audio; only overwrite with
                # non-empty results so a model lacking them doesn't blank out
                # an earlier model's values.
                mfa, conformer = _validation_fields(orig_seg)
                if mfa:
                    segment_data["validation"]["indicmfa"] = mfa
                if conformer:
                    segment_data["validation"]["indic_conformer"] = conformer

            if temp0_seg:
                model_data["temp0_low"] = _transcription_fields(temp0_seg)

            segment_data["models"][model] = model_data

        analysis["segments"].append(segment_data)

    # Summary statistics: average transcription time per model and config.
    speed_data = []
    for model in models:
        orig_results = original['results_by_model'].get(model, [])
        temp0_results = temp0['results_by_config'].get('temp0_low', {}).get(model, [])
        speed_data.append({
            "model": model,
            "temp1_high_avg_sec": round(_avg_transcription_time(orig_results), 2),
            "temp0_low_avg_sec": round(_avg_transcription_time(temp0_results), 2)
        })

    # Rank fastest-first by the temp0_low average.
    speed_data.sort(key=lambda entry: entry['temp0_low_avg_sec'])
    analysis["summary"]["speed_ranking"] = speed_data

    analysis["summary"]["recommendations"] = {
        "high_volume_production": "gemini-2.5-flash-lite (fastest, 1.5-2s avg)",
        "balanced_quality_speed": "gemini-2.5-flash (good quality, 3-4s avg)",
        "maximum_quality": "gemini-3-pro-preview (best quality, use thinking_budget=300 with temp=0)",
        "settings_notes": {
            "gemini_3_models": "Use temp=0 with thinking_budget=300 to prevent loops",
            "gemini_2.5_models": "temp=0 works fine, no thinking_budget needed"
        }
    }

    return analysis


def main():
    """Build the consolidated document, save it, and print a short summary."""
    print("Creating consolidated analysis...")

    analysis = create_consolidated_analysis()

    # Persist the consolidated document as pretty-printed, non-ASCII-escaped JSON.
    with open(OUTPUT_FILE, 'w', encoding='utf-8') as out:
        json.dump(analysis, out, ensure_ascii=False, indent=2)
    print(f"Saved: {OUTPUT_FILE}")

    # Console summary.
    banner = '=' * 60
    print(f"\n{banner}")
    print("SUMMARY")
    print(f"{banner}")

    meta = analysis['metadata']
    summary = analysis['summary']
    print(f"Segments: {meta['total_segments']}")
    print(f"Models: {len(meta['models_tested'])}")

    print("\nSpeed Ranking (temp0_low):")
    for rank, entry in enumerate(summary['speed_ranking'], start=1):
        print(f"  {rank}. {entry['model']}: {entry['temp0_low_avg_sec']}s")

    # settings_notes is a nested dict, so it is skipped in this flat listing.
    print("\nRecommendations:")
    for key, value in summary['recommendations'].items():
        if key != 'settings_notes':
            print(f"  {key}: {value}")


# Run only when executed as a script, not when imported as a module.
if __name__ == "__main__":
    main()
