#!/usr/bin/env python3
"""
Temperature & Thinking Level Analysis
=====================================

Tests models with different temperature and thinking level combinations:
- Temperature 0 + Low thinking
- Temperature 0 + High thinking

Skips validation (already computed in a previous run).
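
Run directly; all configuration (video, language, models, sample count)
lives in the constants at the top of this file.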
"""
import os
import sys
import json
import time
from pathlib import Path
from datetime import datetime
from dataclasses import dataclass, asdict
from typing import List, Dict, Optional

sys.path.insert(0, str(Path(__file__).parent))

from src.backend.audio_processor import AudioProcessor
from src.backend.gemini_transcriber import GeminiTranscriber, TranscriptionConfig

# === CONFIGURATION ===
VIDEO_ID = "pF_BQpHaIdU"
LANGUAGE = "Telugu"
SEGMENTS_DIR = f"/tmp/maya3_transcribe/{VIDEO_ID}/extracted/{VIDEO_ID}/segments"
OUTPUT_DIR = "./analysis_results"
SAMPLES_PER_MODEL = 20

# Models to test
MODELS_TO_TEST = [
    "gemini-3-pro-preview",
    "gemini-3-flash-preview", 
    "gemini-2.5-pro",
    "gemini-2.5-flash",
    "gemini-2.5-flash-lite",
    "gemini-2.0-flash",
]

# Test configurations: (temperature, thinking_level, label)
TEST_CONFIGS = [
    (0.0, "low", "temp0_low"),
    (0.0, "high", "temp0_high"),
]
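# NOTE: thinking_level is honored only by gemini-3* models; main() passes
# None for the 2.x models so they keep their default behavior (see below).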


@dataclass
class SegmentResult:
    """Result for a single segment."""
    segment_id: str
    audio_path: str
    duration_sec: float
    model: str
    temperature: float
    thinking_level: Optional[str]
    config_label: str
    
    # Transcription outputs
    native: Optional[str] = None
    punctuated: Optional[str] = None
    code_switch: Optional[str] = None
    romanized: Optional[str] = None
    
    # Timing
    transcription_time_sec: float = 0.0
    transcription_error: Optional[str] = None


def get_audio_segments(segments_dir: str, max_segments: int) -> List:
    """Get the SAME segments as the previous run for comparison."""
    processor = AudioProcessor(
        max_duration_sec=10.0,
        min_duration_sec=2.0
    )
    
    chunks = processor.process_segments_directory(
        segments_dir,
        max_segments=None,
        skip_short=True
    )
    
    # Filter to segments 2-10s
    good_chunks = [c for c in chunks if 2.0 <= c.duration_sec <= 10.0]
    
    # Select same samples as before (spread across file)
    if len(good_chunks) > max_segments:
        step = len(good_chunks) // max_segments
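        # e.g. 100 good chunks with max_segments=20 gives step=5, selecting
        # indices 0, 5, 10, ... -- deterministic, so reruns pick the same audio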
        selected = [good_chunks[i * step] for i in range(max_segments)]
    else:
        selected = good_chunks[:max_segments]
    
    print(f"Selected {len(selected)} segments")
    return selected


def run_transcriptions(
    segments: List,
    model: str,
    temperature: float,
    thinking_level: Optional[str],
    config_label: str,
    language: str
) -> List[SegmentResult]:
    """Run transcriptions with specific settings."""
    print(f"\n{'='*60}")
    print(f"Model: {model}")
    print(f"Temperature: {temperature}, Thinking: {thinking_level or 'default'}")
    print(f"Config: {config_label}")
    print(f"{'='*60}")
    
    transcriber = GeminiTranscriber()
    config = TranscriptionConfig(
        model=model,
        thinking_level=thinking_level,
        temperature=temperature,
        language=language
    )
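    # thinking_level=None leaves the backend's default thinking behavior in place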
    
    results = []
    
    for i, chunk in enumerate(segments):
        print(f"[{i+1}/{len(segments)}] {chunk.original_segment}...", end=" ", flush=True)
        
        result = SegmentResult(
            segment_id=chunk.original_segment,
            audio_path=chunk.file_path,
            duration_sec=chunk.duration_sec,
            model=model,
            temperature=temperature,
            thinking_level=thinking_level,
            config_label=config_label
        )
        
        start_time = time.time()
        
        try:
            transcription_result = transcriber.transcribe_chunk(chunk, config)
            
            result.native = transcription_result.transcription.native_transcription
            result.punctuated = transcription_result.transcription.native_with_punctuation
            result.code_switch = transcription_result.transcription.code_switch
            result.romanized = transcription_result.transcription.romanized
            result.transcription_time_sec = transcription_result.processing_time_sec or (time.time() - start_time)
            
            print(f"OK ({result.transcription_time_sec:.1f}s)")
            
        except Exception as e:
            result.transcription_error = str(e)
            result.transcription_time_sec = time.time() - start_time
            print(f"ERROR: {e}")
        
        results.append(result)
        time.sleep(0.5)  # Rate limiting
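        # At SAMPLES_PER_MODEL=20 the pauses add ~10s per model/config run,
        # a cheap guard against per-minute API quota limits.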
    
    return results


def main():
    """Main analysis function."""
    print("=" * 80)
    print("TEMPERATURE & THINKING LEVEL ANALYSIS")
    print("=" * 80)
    print(f"\nVideo ID: {VIDEO_ID}")
    print(f"Language: {LANGUAGE}")
    print(f"Samples per config: {SAMPLES_PER_MODEL}")
    print(f"Models: {len(MODELS_TO_TEST)}")
    print(f"Configs: {len(TEST_CONFIGS)}")
    print(f"Total transcriptions: {SAMPLES_PER_MODEL * len(MODELS_TO_TEST) * len(TEST_CONFIGS)}")
    
    # Get segments
    print("\n" + "=" * 60)
    print("Loading Audio Segments")
    print("=" * 60)
    
    segments = get_audio_segments(SEGMENTS_DIR, SAMPLES_PER_MODEL)
    
    if not segments:
        print("ERROR: No segments found!")
        return
    
    # Store all results
    all_results = {}
    
    # Run for each config
    for temperature, thinking, config_label in TEST_CONFIGS:
        print(f"\n{'#'*80}")
        print(f"CONFIG: {config_label} (temp={temperature}, thinking={thinking})")
        print(f"{'#'*80}")
        
        config_results = {}
        
        for model in MODELS_TO_TEST:
            # Only Gemini 3 models support an explicit thinking level;
            # other models don't use the parameter, so pass None for them.
            actual_thinking = thinking if model.startswith("gemini-3") else None
            
            results = run_transcriptions(
                segments,
                model=model,
                temperature=temperature,
                thinking_level=actual_thinking,
                config_label=config_label,
                language=LANGUAGE
            )
            
            config_results[model] = results
        
        all_results[config_label] = config_results
        
        # Save intermediate
        save_results(all_results, "temperature_analysis_partial.json")
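        # The partial file is rewritten after every config, so a crash loses
        # at most the config currently in flight.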
    
    # Save final results
    output_path = save_results(all_results, None)
    
    # Print summary
    print_summary(all_results)
    
    print(f"\n{'='*80}")
    print(f"Analysis complete! Results saved to: {output_path}")
    print(f"{'='*80}")


def save_results(all_results: Dict, filename: Optional[str]) -> str:
    """Save results to JSON."""
    os.makedirs(OUTPUT_DIR, exist_ok=True)
    
    if filename is None:
        timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
        filename = f"temperature_analysis_{timestamp}.json"
    
    output_path = os.path.join(OUTPUT_DIR, filename)
    
    # Convert to serializable format
    output_data = {
        "timestamp": datetime.now().isoformat(),
        "video_id": VIDEO_ID,
        "language": LANGUAGE,
        "samples_per_config": SAMPLES_PER_MODEL,
        "configs_tested": list(all_results.keys()),
        "models_tested": MODELS_TO_TEST,
        "results_by_config": {
            config: {
                model: [asdict(r) for r in results]
                for model, results in model_results.items()
            }
            for config, model_results in all_results.items()
        }
    }
    
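    # ensure_ascii=False below keeps the Telugu script human-readable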
    with open(output_path, 'w', encoding='utf-8') as f:
        json.dump(output_data, f, ensure_ascii=False, indent=2)
    
    print(f"\nResults saved to: {output_path}")
    return output_path


def print_summary(all_results: Dict):
    """Print summary statistics."""
    print(f"\n{'='*80}")
    print("SUMMARY")
    print(f"{'='*80}")
    
    for config_label, model_results in all_results.items():
        print(f"\n--- {config_label} ---")
        print(f"{'Model':<30} {'Success':<10} {'Avg Time':<12}")
        print("-" * 55)
        
        for model, results in model_results.items():
            success = sum(1 for r in results if r.native)
            total = len(results)
            avg_time = sum(r.transcription_time_sec for r in results) / max(total, 1)
            
            print(f"{model:<30} {success}/{total:<8} {avg_time:.1f}s")
    
    # Compare first segment across all configs
    print(f"\n{'='*80}")
    print("FIRST SEGMENT COMPARISON")
    print(f"{'='*80}")
    
    first_segment = None
    for config_label, model_results in all_results.items():
        for model, results in model_results.items():
            if results and results[0].native:
                first_segment = results[0].segment_id
                break
        if first_segment:
            break
    
    if first_segment:
        print(f"\nSegment: {first_segment}")
        
        for config_label, model_results in all_results.items():
            print(f"\n--- {config_label} ---")
            for model, results in model_results.items():
                for r in results:
                    if r.segment_id == first_segment and r.native:
                        preview = r.native if len(r.native) <= 80 else r.native[:80] + "..."
                        print(f"{model}:")
                        print(f"  {preview}")
                        break


if __name__ == "__main__":
    main()
