#!/usr/bin/env python3
"""
Thinking Budget Analysis
========================

Tests models with temperature=0 and thinking_budget=300 (instead of thinking_level=high)
to prevent thinking loops while maintaining quality.
"""
import os
import sys
import json
import time
from pathlib import Path
from datetime import datetime
from dataclasses import dataclass, asdict
from typing import List, Dict, Any, Optional

sys.path.insert(0, str(Path(__file__).parent))

from src.backend.audio_processor import AudioProcessor
from src.backend.transcription_schema import TranscriptionOutput

# Direct API access for custom thinking config
from google import genai
from google.genai import types

# === CONFIGURATION ===
# YouTube video ID whose pre-extracted audio segments are transcribed.
VIDEO_ID = "pF_BQpHaIdU"
# Primary spoken language; interpolated into the transcription system prompt.
LANGUAGE = "Telugu"
# Directory of extracted audio segments for VIDEO_ID (produced by an earlier step).
SEGMENTS_DIR = f"/tmp/maya3_transcribe/{VIDEO_ID}/extracted/{VIDEO_ID}/segments"
# Where JSON result files are written (created on demand by save_results).
OUTPUT_DIR = "./analysis_results"
# Number of audio segments transcribed by each model.
SAMPLES_PER_MODEL = 20

# Thinking budget in tokens (prevents infinite loops)
THINKING_BUDGET = 300

# Models to test
MODELS_TO_TEST = [
    "gemini-3-pro-preview",
    "gemini-3-flash-preview", 
    "gemini-2.5-pro",
    "gemini-2.5-flash",
    "gemini-2.5-flash-lite",
    "gemini-2.0-flash",
]

# Which models support thinking budget
# (the budget is silently skipped for every other model).
THINKING_MODELS = ["gemini-3-pro-preview", "gemini-3-flash-preview"]


@dataclass
class SegmentResult:
    """Result for a single segment.

    One record per (segment, model) transcription attempt; serialized to
    JSON via dataclasses.asdict in save_results.
    """
    # --- Identity / run configuration ---
    segment_id: str  # original segment name (chunk.original_segment)
    audio_path: str  # path to the audio chunk on disk
    duration_sec: float  # chunk length in seconds
    model: str  # Gemini model name used for this attempt
    temperature: float  # sampling temperature (always 0.0 in this script)
    thinking_budget: Optional[int]  # thinking-token cap; None for non-thinking models
    config_label: str  # label grouping results from one configuration

    # --- Transcription outputs (remain None if the API call failed) ---
    native: Optional[str] = None  # native script, no punctuation
    punctuated: Optional[str] = None  # native script with minimal punctuation
    code_switch: Optional[str] = None  # mixed native/Latin script
    romanized: Optional[str] = None  # everything in Roman/Latin script

    # --- Timing / error bookkeeping ---
    transcription_time_sec: float = 0.0  # API round-trip time in seconds
    transcription_error: Optional[str] = None  # error string if the call failed


def get_transcription_prompt(language: str) -> str:
    """Build the verbatim-transcription system prompt for *language*.

    The prompt instructs the model to transcribe exactly as spoken and to
    emit a JSON object with four parallel transcription fields.
    """
    prompt = f"""You are a strict, verbatim transcription engine for Indian languages.

Primary audio language: {language}

TASK:
1. Listen to the audio carefully
2. Transcribe exactly as spoken
3. Produce four outputs in JSON format

STRICT RULES:
- Verbatim only: Include all repetitions, fillers, stammers
- No normalization: Don't correct grammar or pronunciation
- No inference: Don't add meaning not in audio

OUTPUT (JSON):
- native_transcription: Native script, no punctuation
- native_with_punctuation: Native script with minimal punctuation
- code_switch: Mixed script (Indian in native, English in Latin)
- romanized: Everything in Roman/Latin script"""
    return prompt


def get_audio_segments(segments_dir: str, max_segments: int) -> List:
    """Load audio chunks and deterministically pick up to *max_segments*.

    Uses the same AudioProcessor settings and the same stride-based
    selection as earlier runs, so the chosen segments stay comparable
    across experiments.
    """
    audio_processor = AudioProcessor(
        max_duration_sec=10.0,
        min_duration_sec=2.0
    )

    all_chunks = audio_processor.process_segments_directory(
        segments_dir,
        max_segments=None,
        skip_short=True
    )

    # Keep only chunks within the 2-10 second window.
    usable = [chunk for chunk in all_chunks if 2.0 <= chunk.duration_sec <= 10.0]

    if len(usable) > max_segments:
        # Evenly stride through the pool so the sample spans the whole set.
        stride = len(usable) // max_segments
        picked = [usable[idx * stride] for idx in range(max_segments)]
    else:
        picked = usable[:max_segments]

    print(f"Selected {len(picked)} segments")
    return picked


def transcribe_with_budget(
    client: genai.Client,
    audio_path: str,
    model: str,
    language: str,
    thinking_budget: Optional[int]
) -> Dict[str, Any]:
    """Transcribe one audio file, optionally capping the model's thinking tokens.

    Args:
        client: Initialized Gemini API client.
        audio_path: Path to a .flac or .wav audio file.
        model: Gemini model name.
        language: Primary spoken language, interpolated into the system prompt.
        thinking_budget: Max thinking tokens. Only applied for models listed
            in THINKING_MODELS. 0 is a valid value (disables thinking).

    Returns:
        The parsed JSON dict from the model with a '_processing_time_sec'
        key added, or a dict with an 'error' key on any failure (API error,
        empty response, unparseable JSON). Never raises.
    """
    # Load audio
    with open(audio_path, 'rb') as f:
        audio_bytes = f.read()

    mime_type = 'audio/flac' if audio_path.endswith('.flac') else 'audio/wav'

    # Only thinking-capable models accept a budget. Use an explicit None
    # check (BUGFIX): a truthiness test would silently drop a budget of 0,
    # which is a meaningful setting ("no thinking") in the Gemini API.
    if model in THINKING_MODELS and thinking_budget is not None:
        thinking_config = types.ThinkingConfig(thinking_budget=thinking_budget)
    else:
        thinking_config = None

    # temperature=0 for determinism; JSON response mode so output is
    # machine-parseable. thinking_config=None means "use the model default",
    # so it is safe to pass unconditionally.
    gen_config = types.GenerateContentConfig(
        temperature=0.0,
        response_mime_type="application/json",
        thinking_config=thinking_config,
        system_instruction=[
            types.Part.from_text(text=get_transcription_prompt(language))
        ]
    )

    # Build content: the raw audio plus a short user instruction.
    contents = [
        types.Content(
            role="user",
            parts=[
                types.Part.from_bytes(
                    mime_type=mime_type,
                    data=audio_bytes
                ),
                types.Part.from_text(text="Transcribe this audio. Return JSON with native_transcription, native_with_punctuation, code_switch, romanized fields.")
            ]
        )
    ]

    # Call API
    start_time = time.time()
    try:
        response = client.models.generate_content(
            model=model,
            contents=contents,
            config=gen_config
        )
        processing_time = time.time() - start_time

        if not response.text:
            return {"error": "Empty response", "_processing_time_sec": processing_time}

        try:
            result = json.loads(response.text)
        except json.JSONDecodeError as e:
            # Surface malformed-JSON responses distinctly from transport errors.
            return {"error": f"Invalid JSON response: {e}", "_processing_time_sec": processing_time}

        result['_processing_time_sec'] = processing_time
        return result

    except Exception as e:
        # Best-effort contract: callers inspect the 'error' key instead of
        # catching exceptions, so convert everything else to an error dict.
        return {"error": str(e), "_processing_time_sec": time.time() - start_time}


def run_transcriptions(
    segments: List,
    model: str,
    language: str,
    thinking_budget: Optional[int],
    config_label: str
) -> List[SegmentResult]:
    """Transcribe every segment with one model and collect SegmentResults.

    Errors are captured per segment in SegmentResult.transcription_error
    rather than aborting the run.
    """
    banner = '=' * 60
    print(f"\n{banner}")
    print(f"Model: {model}")
    print(f"Temperature: 0.0, Thinking Budget: {thinking_budget or 'N/A'}")
    print(f"{banner}")

    # Lazy import keeps credential handling next to client construction.
    from src.backend.config import GEMINI_API_KEY
    client = genai.Client(api_key=GEMINI_API_KEY)

    # The budget only applies to thinking-capable models.
    effective_budget = thinking_budget if model in THINKING_MODELS else None

    outcomes: List[SegmentResult] = []

    for position, chunk in enumerate(segments, start=1):
        print(f"[{position}/{len(segments)}] {chunk.original_segment}...", end=" ", flush=True)

        record = SegmentResult(
            segment_id=chunk.original_segment,
            audio_path=chunk.file_path,
            duration_sec=chunk.duration_sec,
            model=model,
            temperature=0.0,
            thinking_budget=effective_budget,
            config_label=config_label
        )

        started = time.time()

        try:
            raw = transcribe_with_budget(
                client,
                chunk.file_path,
                model,
                language,
                effective_budget
            )

            if "error" in raw:
                record.transcription_error = raw['error']
            else:
                record.native = raw.get('native_transcription', '')
                record.punctuated = raw.get('native_with_punctuation', '')
                record.code_switch = raw.get('code_switch', '')
                record.romanized = raw.get('romanized', '')

            # Prefer the API-measured time; fall back to wall-clock here.
            record.transcription_time_sec = raw.get('_processing_time_sec', time.time() - started)
            print(f"OK ({record.transcription_time_sec:.1f}s)")

        except Exception as e:
            record.transcription_error = str(e)
            record.transcription_time_sec = time.time() - started
            print(f"ERROR: {e}")

        outcomes.append(record)
        time.sleep(0.5)  # gentle pacing between API calls

    return outcomes


def main():
    """Entry point: select segments, run every model, save and summarize."""
    header = "=" * 80
    print(header)
    print(f"THINKING BUDGET ANALYSIS (budget={THINKING_BUDGET} tokens)")
    print(header)

    # Guard clause: nothing to do without audio segments.
    segments = get_audio_segments(SEGMENTS_DIR, SAMPLES_PER_MODEL)
    if not segments:
        print("ERROR: No segments found!")
        return

    config_label = f"temp0_budget{THINKING_BUDGET}"
    all_results = {config_label: {}}

    for model in MODELS_TO_TEST:
        all_results[config_label][model] = run_transcriptions(
            segments,
            model=model,
            language=LANGUAGE,
            thinking_budget=THINKING_BUDGET,
            config_label=config_label
        )
        # Checkpoint after every model so a crash doesn't lose finished work.
        save_results(all_results, "budget_analysis_partial.json")

    # Final, timestamped artifact.
    output_path = save_results(all_results, None)

    print_summary(all_results)

    print(f"\n{header}")
    print(f"Complete! Results: {output_path}")
    print(f"{header}")


def save_results(all_results: Dict, filename: Optional[str]) -> str:
    """Serialize all results plus run metadata to OUTPUT_DIR as JSON.

    A None filename produces a timestamped name (final artifact); an
    explicit name is used for intermediate checkpoints. Returns the path
    of the written file.
    """
    os.makedirs(OUTPUT_DIR, exist_ok=True)

    if filename is None:
        stamp = datetime.now().strftime("%Y%m%d_%H%M%S")
        filename = f"budget_analysis_{stamp}.json"

    output_path = os.path.join(OUTPUT_DIR, filename)

    # Dataclasses aren't JSON-serializable; flatten them to plain dicts.
    serialized = {
        config: {
            model: [asdict(record) for record in records]
            for model, records in per_model.items()
        }
        for config, per_model in all_results.items()
    }

    payload = {
        "timestamp": datetime.now().isoformat(),
        "video_id": VIDEO_ID,
        "language": LANGUAGE,
        "thinking_budget": THINKING_BUDGET,
        "samples_per_model": SAMPLES_PER_MODEL,
        "configs_tested": list(all_results.keys()),
        "models_tested": MODELS_TO_TEST,
        "results_by_config": serialized
    }

    with open(output_path, 'w', encoding='utf-8') as f:
        json.dump(payload, f, ensure_ascii=False, indent=2)

    print(f"\nSaved: {output_path}")
    return output_path


def print_summary(all_results: Dict):
    """Print a per-model success count and average latency to stdout.

    Args:
        all_results: Mapping of {config_label: {model_name: [SegmentResult, ...]}}.

    BUGFIX: the denominator of the success ratio was hard-coded as 20; it
    now uses the actual number of results so the summary stays correct when
    SAMPLES_PER_MODEL changes or a run is truncated.
    """
    print(f"\n{'='*80}")
    print("SUMMARY")
    print(f"{'='*80}")

    for config, models in all_results.items():
        print(f"\n--- {config} ---")
        for model, results in models.items():
            # A result counts as a success if any native text came back.
            success = sum(1 for r in results if r.native)
            # max(..., 1) guards against division by zero on empty lists.
            avg_time = sum(r.transcription_time_sec for r in results) / max(len(results), 1)
            print(f"{model:<30} {success}/{len(results)}  avg: {avg_time:.1f}s")


# Script entry point: run the full analysis when executed directly.
if __name__ == "__main__":
    main()
