"""
Gemini transcription module using Google AI Studio endpoints.
Handles audio transcription with structured output and multiple model support.
"""
import json
import time
from dataclasses import dataclass
from pathlib import Path
from typing import Any, Callable, Dict, List, Optional

from google import genai
from google.genai import types

from .config import GEMINI_API_KEY, GEMINI_MODELS, get_model_name
from .transcription_schema import (
    TranscriptionOutput,
    TranscriptionResult,
    get_transcription_prompt,
    get_user_prompt,
    TRANSCRIPTION_JSON_SCHEMA
)
from .audio_processor import AudioChunk


@dataclass
class TranscriptionConfig:
    """Configuration for a transcription job."""
    # Model alias or full name; resolved through get_model_name() before the API call.
    model: str = "gemini-3-flash-preview"
    # Thinking depth for Gemini 3 models; None disables the thinking config entirely.
    thinking_level: Optional[str] = "high"  # minimal, low, medium, high
    temperature: float = 1.0  # Keep at 1.0 for Gemini 3
    # Primary spoken language of the audio; interpolated into the system prompt.
    language: str = "Telugu"
    # NOTE(review): not referenced anywhere in this module — confirm whether the
    # API call should actually pass a timeout before relying on this value.
    timeout_sec: int = 60


class GeminiTranscriber:
    """
    Handles audio transcription using Gemini models via Google AI Studio.

    Audio is sent as inline bytes along with a JSON response schema, so the
    model's reply parses directly into the structured transcription fields.
    """

    def __init__(self, api_key: Optional[str] = None):
        """
        Initialize the transcriber.

        Args:
            api_key: Gemini API key (defaults to the GEMINI_API_KEY env var)

        Raises:
            ValueError: If no API key is supplied or configured.
        """
        self.api_key = api_key or GEMINI_API_KEY
        if not self.api_key:
            raise ValueError("GEMINI_API_KEY not set")

        self.client = genai.Client(api_key=self.api_key)

    def _get_mime_type(self, file_path: str) -> str:
        """Determine the MIME type from the file extension (FLAC fallback)."""
        mime_types = {
            '.flac': 'audio/flac',
            '.wav': 'audio/wav',
            '.mp3': 'audio/mpeg',
            '.ogg': 'audio/ogg',
            '.m4a': 'audio/mp4',
            '.aac': 'audio/aac'
        }
        return mime_types.get(Path(file_path).suffix.lower(), 'audio/flac')

    def _load_audio_bytes(self, file_path: str) -> bytes:
        """Load the entire audio file as raw bytes."""
        return Path(file_path).read_bytes()

    def _build_config(self, config: TranscriptionConfig) -> types.GenerateContentConfig:
        """
        Build the generation config for the API call.

        Structured JSON output is always requested; a thinking config is added
        only for Gemini 3 models that advertise thinking support.
        """
        model_name = get_model_name(config.model)
        model_info = GEMINI_MODELS.get(model_name, {})

        # Thinking config applies only to Gemini 3 models flagged as supporting it.
        thinking_config = None
        if model_info.get("supports_thinking") and config.thinking_level:
            if model_name.startswith("gemini-3"):
                thinking_config = types.ThinkingConfig(
                    thinking_level=config.thinking_level.upper()
                )

        gen_config = types.GenerateContentConfig(
            temperature=config.temperature,
            response_mime_type="application/json",
            response_json_schema=TRANSCRIPTION_JSON_SCHEMA,
            system_instruction=[
                types.Part.from_text(text=get_transcription_prompt(config.language))
            ]
        )

        if thinking_config:
            gen_config.thinking_config = thinking_config

        return gen_config

    def transcribe_audio(
        self,
        audio_path: str,
        config: TranscriptionConfig
    ) -> Dict[str, Any]:
        """
        Transcribe a single audio file.

        Args:
            audio_path: Path to the audio file
            config: Transcription configuration

        Returns:
            Dictionary with the parsed transcription fields plus metadata keys
            ``_processing_time_sec``, ``_model`` and ``_thinking_level``. On
            failure (API error or empty response) the dict instead carries an
            ``error`` key alongside the metadata — this method never raises.
        """
        model_name = get_model_name(config.model)

        # Load audio as inline bytes (no Files API upload).
        audio_bytes = self._load_audio_bytes(audio_path)
        mime_type = self._get_mime_type(audio_path)

        # Single user turn: audio part followed by the text instruction.
        contents = [
            types.Content(
                role="user",
                parts=[
                    types.Part.from_bytes(
                        mime_type=mime_type,
                        data=audio_bytes
                    ),
                    types.Part.from_text(text=get_user_prompt())
                ]
            )
        ]

        gen_config = self._build_config(config)

        start_time = time.time()
        try:
            response = self.client.models.generate_content(
                model=model_name,
                contents=contents,
                config=gen_config
            )
            processing_time = time.time() - start_time

            # response.text is the JSON payload produced under the schema.
            if response.text:
                result = json.loads(response.text)
                result['_processing_time_sec'] = processing_time
                result['_model'] = model_name
                result['_thinking_level'] = config.thinking_level
                return result
            else:
                return {
                    "error": "Empty response from model",
                    "_processing_time_sec": processing_time,
                    "_model": model_name
                }

        except Exception as e:
            # Deliberate best-effort boundary: batch processing must survive a
            # single failed chunk, so errors are returned, not raised.
            processing_time = time.time() - start_time
            return {
                "error": str(e),
                "_processing_time_sec": processing_time,
                "_model": model_name
            }

    def transcribe_chunk(
        self,
        chunk: AudioChunk,
        config: TranscriptionConfig
    ) -> TranscriptionResult:
        """
        Transcribe an AudioChunk and return structured result.

        Args:
            chunk: AudioChunk to transcribe
            config: Transcription configuration

        Returns:
            TranscriptionResult with full metadata. Errors are embedded as
            ``[ERROR: ...]`` placeholder text in every transcription field.
        """
        raw_result = self.transcribe_audio(chunk.file_path, config)

        if "error" in raw_result:
            # Propagate the failure visibly in each output field so downstream
            # consumers can spot and filter failed chunks.
            transcription = TranscriptionOutput(
                native_transcription=f"[ERROR: {raw_result['error']}]",
                native_with_punctuation=f"[ERROR: {raw_result['error']}]",
                code_switch=f"[ERROR: {raw_result['error']}]",
                romanized=f"[ERROR: {raw_result['error']}]",
                notes=raw_result['error']
            )
        else:
            transcription = TranscriptionOutput(
                native_transcription=raw_result.get('native_transcription', ''),
                native_with_punctuation=raw_result.get('native_with_punctuation', ''),
                code_switch=raw_result.get('code_switch', ''),
                romanized=raw_result.get('romanized', ''),
                confidence=raw_result.get('confidence'),
                notes=raw_result.get('notes')
            )

        return TranscriptionResult(
            segment_id=chunk.original_segment,
            chunk_index=chunk.chunk_index,
            total_chunks=chunk.total_chunks,
            duration_sec=chunk.duration_sec,
            language=config.language,
            transcription=transcription,
            model_used=raw_result.get('_model', config.model),
            thinking_level=raw_result.get('_thinking_level'),
            processing_time_sec=raw_result.get('_processing_time_sec')
        )

    def transcribe_batch(
        self,
        chunks: List[AudioChunk],
        config: TranscriptionConfig,
        max_chunks: Optional[int] = None,
        progress_callback: Optional[Callable[[int, int], None]] = None
    ) -> List[TranscriptionResult]:
        """
        Transcribe multiple audio chunks sequentially.

        Args:
            chunks: List of AudioChunks to transcribe
            config: Transcription configuration
            max_chunks: Maximum chunks to process (for testing); 0 processes none
            progress_callback: Optional callback(current, total) for progress updates

        Returns:
            List of TranscriptionResults, one per processed chunk
        """
        # `is not None` so max_chunks=0 means "process nothing", not "process all".
        if max_chunks is not None:
            chunks = chunks[:max_chunks]

        results = []
        total = len(chunks)

        print(f"[Transcriber] Processing {total} chunks with {config.model}...")

        for i, chunk in enumerate(chunks):
            if progress_callback:
                progress_callback(i + 1, total)
            else:
                print(f"[Transcriber] Processing {i + 1}/{total}: {chunk.original_segment} "
                      f"(chunk {chunk.chunk_index + 1}/{chunk.total_chunks})")

            result = self.transcribe_chunk(chunk, config)
            results.append(result)

            # Brief pause to avoid rate limiting
            if i < total - 1:
                time.sleep(0.5)

        print(f"[Transcriber] Completed {len(results)} transcriptions")
        return results


def transcribe_segments(
    segments: List[AudioChunk],
    language: str = "Telugu",
    model: str = "gemini-3-flash-preview",
    thinking_level: str = "high",
    max_segments: Optional[int] = None
) -> List[TranscriptionResult]:
    """
    Convenience wrapper: build a config, create a transcriber, run the batch.

    Args:
        segments: Audio chunks to transcribe
        language: Primary language of the audio
        model: Gemini model to use
        thinking_level: Thinking level for Gemini 3 models
        max_segments: Max segments to process (for testing)

    Returns:
        One TranscriptionResult per processed segment
    """
    job_config = TranscriptionConfig(
        model=model,
        thinking_level=thinking_level,
        language=language,
    )
    return GeminiTranscriber().transcribe_batch(
        segments, job_config, max_chunks=max_segments
    )


if __name__ == "__main__":
    # Manual smoke test: transcribe one audio file given on the command line.
    import sys

    if len(sys.argv) > 1:
        audio_path = sys.argv[1]
        language = sys.argv[2] if len(sys.argv) > 2 else "Telugu"

        config = TranscriptionConfig(
            model="gemini-3-flash-preview",
            thinking_level="high",
            language=language
        )

        transcriber = GeminiTranscriber()
        result = transcriber.transcribe_audio(audio_path, config)

        print("\nTranscription Result:")
        print(json.dumps(result, indent=2, ensure_ascii=False))
    else:
        # Previously a silent no-op when run without arguments; fail loudly
        # with a usage hint instead.
        print("Usage: python gemini_transcriber.py AUDIO_PATH [LANGUAGE]",
              file=sys.stderr)
        sys.exit(1)
