#!/usr/bin/env python3
"""
Transcription Utilities
=======================

Utility functions for transcribing audio using Gemini models.
Includes proper configuration for different model families.

Usage:
    from transcription_utils import transcribe_audio, TranscriptionConfig
    
    # For Gemini 2.5 models (simple)
    result = transcribe_audio(
        audio_path="audio.flac",
        model="gemini-2.5-flash",
        language="Telugu",
        temperature=0.0
    )
    
    # For Gemini 3 models with temperature=0 (use thinking_budget)
    result = transcribe_audio(
        audio_path="audio.flac",
        model="gemini-3-flash-preview",
        language="Telugu",
        temperature=0.0,
        thinking_budget=300  # Prevents thinking loops
    )
"""
import json
import time
import os
from dataclasses import dataclass
from typing import Optional, Dict, Any, Literal

from google import genai
from google.genai import types

# Model families
GEMINI_3_MODELS = ["gemini-3-pro-preview", "gemini-3-flash-preview"]
GEMINI_25_MODELS = ["gemini-2.5-pro", "gemini-2.5-flash", "gemini-2.5-flash-lite"]

# Thinking levels for Gemini 3
ThinkingLevel = Literal["none", "low", "medium", "high"]


@dataclass
class TranscriptionConfig:
    """Configuration for a single transcription request.

    Attributes:
        model: Gemini model name (see GEMINI_3_MODELS / GEMINI_25_MODELS).
        language: Primary spoken language of the audio.
        temperature: Sampling temperature; 0.0 for deterministic output.
        thinking_level: Gemini 3 only — named thinking effort level.
            Mutually exclusive with thinking_budget.
        thinking_budget: Gemini 3 only — max thinking tokens. An explicit
            0 disables thinking; None means "unset".
    """
    model: str = "gemini-2.5-flash"
    language: str = "Telugu"
    temperature: float = 0.0
    
    # For Gemini 3 models
    thinking_level: Optional[ThinkingLevel] = None  # Use this OR thinking_budget
    thinking_budget: Optional[int] = None  # Max thinking tokens (prevents loops)
    
    def __post_init__(self):
        """Validate mutual exclusivity and apply the Gemini 3 temp=0 default.

        Raises:
            ValueError: If both thinking_level and thinking_budget are set.
        """
        # Compare against None (not truthiness) so an explicit
        # thinking_budget=0 is treated as "set" and properly rejected
        # when combined with thinking_level.
        if self.thinking_level is not None and self.thinking_budget is not None:
            raise ValueError("Cannot use both thinking_level and thinking_budget")
        
        # Auto-set thinking_budget for Gemini 3 with temp=0 to prevent
        # thinking loops — but only when the caller set neither option.
        # (A caller-supplied thinking_budget=0 must NOT be overwritten.)
        if self.model in GEMINI_3_MODELS and self.temperature == 0.0:
            if self.thinking_level is None and self.thinking_budget is None:
                self.thinking_budget = 300  # Prevent thinking loops


def get_system_prompt(language: str) -> str:
    """Build the system instruction for verbatim transcription.

    Args:
        language: Primary spoken language of the audio; interpolated
            into the prompt.

    Returns:
        The full system prompt string describing the four JSON outputs.
    """
    prompt = f"""You are a strict, verbatim transcription engine for Indian languages.

Primary audio language: {language}

TASK:
1. Listen to the audio carefully
2. Transcribe exactly as spoken
3. Produce four outputs in JSON format

STRICT RULES:
- Verbatim only: Include all repetitions, fillers, stammers
- No normalization: Don't correct grammar or pronunciation  
- No inference: Don't add meaning not in audio

OUTPUT FORMAT (JSON):
{{
    "native_transcription": "Native script without punctuation",
    "native_with_punctuation": "Native script with minimal punctuation",
    "code_switch": "Mixed script (Indian in native, English in Latin)",
    "romanized": "Everything in Roman/Latin script"
}}"""
    return prompt


def get_user_prompt() -> str:
    """Return the fixed user-turn instruction sent alongside the audio."""
    instruction = (
        "Transcribe this audio. Return JSON with native_transcription, "
        "native_with_punctuation, code_switch, romanized fields."
    )
    return instruction


def transcribe_audio(
    audio_path: str,
    model: str = "gemini-2.5-flash",
    language: str = "Telugu",
    temperature: float = 0.0,
    thinking_level: Optional[ThinkingLevel] = None,
    thinking_budget: Optional[int] = None,
    api_key: Optional[str] = None
) -> Dict[str, Any]:
    """
    Transcribe audio file using Gemini models.
    
    Args:
        audio_path: Path to audio file (.flac, .wav, .mp3, .ogg)
        model: Gemini model name
        language: Primary language of audio
        temperature: Sampling temperature (0.0 for deterministic)
        thinking_level: For Gemini 3: "none", "low", "medium", "high"
        thinking_budget: For Gemini 3: Max thinking tokens (use instead of level with temp=0;
            0 disables thinking entirely)
        api_key: Optional API key (defaults to GEMINI_API_KEY env var)
    
    Returns:
        Dict with the parsed transcription fields plus a '_metadata' key,
        or a dict with an 'error' key on failure.
    
    Raises:
        ValueError: If no API key is found, or both thinking_level and
            thinking_budget are supplied.
        FileNotFoundError: If audio_path does not exist.
    
    Example:
        # Gemini 2.5 (simple)
        result = transcribe_audio("audio.flac", model="gemini-2.5-flash")
        
        # Gemini 3 with temp=0 (use thinking_budget to prevent loops)
        result = transcribe_audio(
            "audio.flac", 
            model="gemini-3-flash-preview",
            temperature=0.0,
            thinking_budget=300
        )
    """
    # Validates level/budget exclusivity and auto-sets a budget for
    # Gemini 3 models at temperature 0 (see TranscriptionConfig).
    config = TranscriptionConfig(
        model=model,
        language=language,
        temperature=temperature,
        thinking_level=thinking_level,
        thinking_budget=thinking_budget
    )
    
    # Resolve API key: explicit arg > env var > project config fallback.
    if api_key is None:
        api_key = os.environ.get('GEMINI_API_KEY')
        if not api_key:
            # Fall back to the project config module (presumably loads .env)
            try:
                from src.backend.config import GEMINI_API_KEY
                api_key = GEMINI_API_KEY
            except ImportError:
                raise ValueError("GEMINI_API_KEY not found")
    
    client = genai.Client(api_key=api_key)
    
    # Load audio bytes from disk.
    with open(audio_path, 'rb') as f:
        audio_bytes = f.read()
    
    # Determine MIME type from the file extension (default: flac).
    ext = audio_path.lower().split('.')[-1]
    mime_types = {
        'flac': 'audio/flac',
        'wav': 'audio/wav',
        'mp3': 'audio/mpeg',
        'ogg': 'audio/ogg',
    }
    mime_type = mime_types.get(ext, 'audio/flac')
    
    # Build thinking config for Gemini 3. Compare against None so an
    # explicit thinking_budget=0 (disable thinking) is still sent.
    thinking_config = None
    if config.model in GEMINI_3_MODELS:
        if config.thinking_budget is not None:
            thinking_config = types.ThinkingConfig(thinking_budget=config.thinking_budget)
        elif config.thinking_level is not None:
            thinking_config = types.ThinkingConfig(thinking_level=config.thinking_level.upper())
    
    # Build generation config: JSON-only responses with the system prompt.
    gen_config = types.GenerateContentConfig(
        temperature=config.temperature,
        response_mime_type="application/json",
        system_instruction=[
            types.Part.from_text(text=get_system_prompt(config.language))
        ]
    )
    
    if thinking_config is not None:
        gen_config.thinking_config = thinking_config
    
    # Build content: the audio part followed by the text instruction.
    contents = [
        types.Content(
            role="user",
            parts=[
                types.Part.from_bytes(
                    mime_type=mime_type,
                    data=audio_bytes
                ),
                types.Part.from_text(text=get_user_prompt())
            ]
        )
    ]
    
    # Call API
    start_time = time.time()
    response = None  # Pre-bind so exception handlers can reference it safely
    try:
        response = client.models.generate_content(
            model=config.model,
            contents=contents,
            config=gen_config
        )
        processing_time = time.time() - start_time
        
        if response.text:
            result = json.loads(response.text)
            result['_metadata'] = {
                'model': config.model,
                'temperature': config.temperature,
                'thinking_budget': config.thinking_budget,
                'thinking_level': config.thinking_level,
                'processing_time_sec': round(processing_time, 2),
                'audio_path': audio_path
            }
            return result
        else:
            return {
                "error": "Empty response from model",
                "_metadata": {
                    'model': config.model,
                    'processing_time_sec': round(processing_time, 2)
                }
            }
            
    except json.JSONDecodeError as e:
        return {
            "error": f"JSON parse error: {e}",
            # `response` was pre-bound to None, so this never raises NameError
            "raw_response": response.text if response is not None else None,
            "_metadata": {
                'model': config.model,
                'processing_time_sec': round(time.time() - start_time, 2)
            }
        }
    except Exception as e:
        # Best-effort: surface API/IO failures as an error dict so batch
        # runs keep going instead of crashing.
        return {
            "error": str(e),
            "_metadata": {
                'model': config.model,
                'processing_time_sec': round(time.time() - start_time, 2)
            }
        }


def batch_transcribe(
    audio_paths: list,
    model: str = "gemini-2.5-flash",
    language: str = "Telugu",
    temperature: float = 0.0,
    thinking_budget: Optional[int] = None,
    delay_between_calls: float = 0.5
) -> list:
    """
    Transcribe several audio files sequentially, printing per-file progress.
    
    Args:
        audio_paths: List of audio file paths
        model: Gemini model name
        language: Primary language
        temperature: Sampling temperature
        thinking_budget: For Gemini 3 with temp=0
        delay_between_calls: Delay between API calls (seconds)
    
    Returns:
        List of per-file transcription result dicts, in input order
    """
    outcomes = []
    total = len(audio_paths)
    
    for position, audio_file in enumerate(audio_paths, start=1):
        # Progress line stays open so the status can be appended after the call.
        print(f"[{position}/{total}] {os.path.basename(audio_file)}...", end=" ", flush=True)
        
        outcome = transcribe_audio(
            audio_path=audio_file,
            model=model,
            language=language,
            temperature=temperature,
            thinking_budget=thinking_budget
        )
        
        if "error" in outcome:
            print(f"ERROR: {outcome['error'][:50]}")
        else:
            print(f"OK ({outcome['_metadata']['processing_time_sec']}s)")
        
        outcomes.append(outcome)
        # Simple rate limiting between API calls.
        time.sleep(delay_between_calls)
    
    return outcomes


# === CLI Usage ===
if __name__ == "__main__":
    import sys
    
    # No audio path supplied: show usage and exit non-zero.
    if len(sys.argv) < 2:
        print("""
Transcription Utilities
=======================

Usage:
    python transcription_utils.py <audio_path> [options]

Options:
    --model MODEL         Model name (default: gemini-2.5-flash)
    --language LANG       Language (default: Telugu)
    --temperature TEMP    Temperature (default: 0.0)
    --thinking-budget N   Thinking budget for Gemini 3 (default: 300 for temp=0)

Examples:
    # Basic usage
    python transcription_utils.py audio.flac
    
    # With specific model
    python transcription_utils.py audio.flac --model gemini-3-flash-preview
    
    # Gemini 3 with temperature=0
    python transcription_utils.py audio.flac --model gemini-3-pro-preview --temperature 0 --thinking-budget 300

API Usage in Code:
    from transcription_utils import transcribe_audio
    
    # Simple (Gemini 2.5)
    result = transcribe_audio("audio.flac")
    
    # Gemini 3 with temp=0 (use thinking_budget to prevent loops)
    result = transcribe_audio(
        "audio.flac",
        model="gemini-3-flash-preview",
        temperature=0.0,
        thinking_budget=300
    )
""")
        sys.exit(1)
    
    # Parse args — minimal hand-rolled parser; unrecognized tokens are skipped.
    audio_path = sys.argv[1]
    model = "gemini-2.5-flash"
    language = "Telugu"
    temperature = 0.0
    thinking_budget = None
    
    i = 2
    while i < len(sys.argv):
        if sys.argv[i] == "--model" and i + 1 < len(sys.argv):
            model = sys.argv[i + 1]
            i += 2
        elif sys.argv[i] == "--language" and i + 1 < len(sys.argv):
            language = sys.argv[i + 1]
            i += 2
        elif sys.argv[i] == "--temperature" and i + 1 < len(sys.argv):
            temperature = float(sys.argv[i + 1])
            i += 2
        elif sys.argv[i] == "--thinking-budget" and i + 1 < len(sys.argv):
            thinking_budget = int(sys.argv[i + 1])
            i += 2
        else:
            i += 1
    
    # Echo the effective settings before the (potentially slow) API call.
    print(f"Model: {model}")
    print(f"Language: {language}")
    print(f"Temperature: {temperature}")
    # `is not None` so an explicit --thinking-budget 0 is still reported.
    if thinking_budget is not None:
        print(f"Thinking Budget: {thinking_budget}")
    print()
    
    result = transcribe_audio(
        audio_path=audio_path,
        model=model,
        language=language,
        temperature=temperature,
        thinking_budget=thinking_budget
    )
    
    if "error" in result:
        print(f"ERROR: {result['error']}")
    else:
        print(f"Native: {result.get('native_transcription', 'N/A')}")
        print(f"Punctuated: {result.get('native_with_punctuation', 'N/A')}")
        print(f"Code-switch: {result.get('code_switch', 'N/A')}")
        print(f"Romanized: {result.get('romanized', 'N/A')}")
        print(f"\nTime: {result['_metadata']['processing_time_sec']}s")
