"""
Main transcription pipeline orchestrating the full workflow:
1. Download audio segments from R2
2. Fetch language info from Supabase
3. Process/split audio segments
4. Transcribe using Gemini
5. Output results

Full control over all parameters for testing and production use.
"""
import os
import json
import time
from typing import Optional, List, Dict
from dataclasses import dataclass, field
from datetime import datetime

# Import from src.backend package
from src.backend.config import DEFAULT_SETTINGS, GEMINI_MODELS, get_model_name
from src.backend.r2_storage import download_video_segments
from src.backend.supabase_client import get_video_language
from src.backend.audio_processor import AudioProcessor, AudioChunk, get_segment_stats
from src.backend.gemini_transcriber import GeminiTranscriber, TranscriptionConfig
from src.backend.transcription_schema import TranscriptionResult

# Validation
from src.validators import validate_transcription, quick_validate, cleanup as cleanup_validator


@dataclass
class PipelineConfig:
    """
    Complete configuration for the transcription pipeline.
    All parameters are exposed for full control.

    Field order matters: this is a dataclass, so the order below defines the
    positional ``__init__`` signature. ``video_id`` is the only required field.

    NOTE(review): the validation settings (``validate_transcriptions``,
    ``validation_action``) are not consulted by ``TranscriptionPipeline.run``
    in this file — confirm whether validation is wired up elsewhere.
    """
    # === Video Settings ===
    video_id: str
    language: Optional[str] = None  # If None, fetched from Supabase
    default_language: str = "Telugu"  # Fallback if Supabase lookup fails
    
    # === Audio Processing Settings ===
    max_segment_duration_sec: float = 10.0  # Hard limit - segments longer than this are split
    min_segment_duration_sec: float = 1.0   # Skip segments shorter than this
    chunk_overlap_sec: float = 0.0          # Overlap when splitting (for word boundaries)
    
    # === Segment Selection Settings ===
    max_segments: Optional[int] = None      # Limit total segments processed (for testing)
    segment_start_index: int = 0            # Start from this segment index
    skip_short_segments: bool = True        # Skip segments < min_duration
    
    # === Transcription Settings ===
    model: str = "gemini-3-flash-preview"   # Gemini model to use
    thinking_level: Optional[str] = "low"   # minimal, low, medium, high (Gemini 3)
    temperature: float = 0.0                # 0 for deterministic output
    
    # === Validation Settings ===
    validate_transcriptions: bool = True    # Validate native transcriptions
    validation_action: str = "flag"         # "flag" (mark status) or "retry" (re-transcribe)
    
    # === Output Settings ===
    output_dir: str = "./transcriptions"
    save_intermediate: bool = True          # Save results after each batch
    batch_size: int = 10                    # How many to process before saving
    
    # === Storage Settings ===
    work_dir: str = field(default_factory=lambda: DEFAULT_SETTINGS["work_dir"])
    cleanup_after: bool = False             # Remove downloaded files after processing
    use_regex_fallback: bool = True         # Try regex if exact R2 match fails


@dataclass 
class PipelineResult:
    """
    Summary of a completed pipeline run.

    Built at the end of ``TranscriptionPipeline.run``; counts may differ
    (e.g. ``chunks_created`` vs ``transcriptions_completed``) if some chunks
    failed to transcribe.
    """
    # Run identity / settings echoed back from the config.
    video_id: str
    language: str
    model: str
    thinking_level: Optional[str]
    
    # Counters: files found in R2, distinct segments transcribed,
    # chunks produced by splitting, and transcription results returned.
    total_segments_found: int
    segments_processed: int
    chunks_created: int
    transcriptions_completed: int
    
    # Total seconds of audio transcribed vs wall-clock processing time.
    total_audio_duration_sec: float
    total_processing_time_sec: float
    
    # Path of the final JSON output and any error strings collected.
    output_file: str
    errors: List[str]
    
    # ISO-8601 timestamps for the run boundaries.
    started_at: str
    completed_at: str


class TranscriptionPipeline:
    """
    Main pipeline orchestrating the full transcription workflow.

    Workflow (see :meth:`run`):
      1. Download audio segments from R2
      2. Resolve language (explicit config override, else Supabase lookup)
      3. Split/filter audio segments into chunks
      4. Transcribe chunks in batches with Gemini
      5. Save results to a timestamped JSON file

    NOTE(review): ``config.validate_transcriptions`` / ``validation_action``
    are never read in this class, despite the validator imports at the top
    of the module — confirm whether validation happens elsewhere.
    """
    
    def __init__(self, config: PipelineConfig):
        """Initialize pipeline with configuration."""
        self.config = config
        # Accumulates per-chunk transcription results across all batches.
        self.results: List[TranscriptionResult] = []
        # Human-readable error strings gathered during the run.
        self.errors: List[str] = []
        # Wall-clock start (time.time()); set at the top of run().
        self.start_time: Optional[float] = None
        
    def _log(self, message: str, level: str = "INFO") -> None:
        """Log a message with timestamp."""
        timestamp = datetime.now().strftime("%H:%M:%S")
        print(f"[{timestamp}] [{level}] {message}")
        
    def run(self) -> PipelineResult:
        """
        Execute the full pipeline.
        
        Returns:
            PipelineResult with summary and statistics

        Raises:
            Re-raises any exception from the R2 download step (after
            recording it in ``self.errors``); later steps propagate their
            own exceptions unhandled.
        """
        self.start_time = time.time()
        started_at = datetime.now().isoformat()
        
        self._log(f"Starting pipeline for video: {self.config.video_id}")
        self._log(f"Model: {self.config.model}, Thinking: {self.config.thinking_level}")
        
        # === Step 1: Download from R2 ===
        # Download failure is fatal: record the error, log, and re-raise.
        self._log("Step 1: Downloading audio segments from R2...")
        try:
            segments_dir, metadata = download_video_segments(
                self.config.video_id,
                work_dir=self.config.work_dir,
                use_regex_fallback=self.config.use_regex_fallback
            )
        except Exception as e:
            self.errors.append(f"R2 download failed: {e}")
            self._log(f"ERROR: {e}", "ERROR")
            raise
        
        # Get segment stats (file count, total duration, >10s count).
        # .get() defaults guard against missing keys in the stats dict.
        stats = get_segment_stats(segments_dir)
        self._log(f"Found {stats.get('total_files', 0)} audio files, "
                  f"total duration: {stats.get('total_duration_sec', 0):.1f}s")
        if stats.get('over_10s_count', 0) > 0:
            self._log(f"  -> {stats['over_10s_count']} segments over 10s will be split")
        
        # === Step 2: Get Language ===
        # Explicit config language wins; otherwise look it up in Supabase
        # with default_language as the fallback.
        self._log("Step 2: Fetching language info...")
        if self.config.language:
            language = self.config.language
            self._log(f"Using specified language: {language}")
        else:
            language = get_video_language(
                self.config.video_id,
                default=self.config.default_language
            )
            self._log(f"Language from Supabase: {language}")
        
        # === Step 3: Process Audio Segments ===
        # Split long segments, optionally drop short ones, cap the total.
        self._log("Step 3: Processing audio segments...")
        processor = AudioProcessor(
            max_duration_sec=self.config.max_segment_duration_sec,
            min_duration_sec=self.config.min_segment_duration_sec,
            overlap_sec=self.config.chunk_overlap_sec
        )
        
        chunks = processor.process_segments_directory(
            segments_dir,
            max_segments=self.config.max_segments,
            skip_short=self.config.skip_short_segments
        )
        
        # Apply start index (slice applied AFTER max_segments limiting,
        # so the two settings interact — the start offset is taken from
        # the already-capped chunk list).
        if self.config.segment_start_index > 0:
            chunks = chunks[self.config.segment_start_index:]
            self._log(f"Starting from index {self.config.segment_start_index}, "
                      f"{len(chunks)} chunks remaining")
        
        # === Step 4: Transcribe ===
        self._log(f"Step 4: Transcribing {len(chunks)} chunks...")
        
        transcriber = GeminiTranscriber()
        transcription_config = TranscriptionConfig(
            model=self.config.model,
            thinking_level=self.config.thinking_level,
            temperature=self.config.temperature,
            language=language
        )
        
        # Process in batches so intermediate checkpoints can be saved.
        for batch_start in range(0, len(chunks), self.config.batch_size):
            batch_end = min(batch_start + self.config.batch_size, len(chunks))
            batch = chunks[batch_start:batch_end]
            
            self._log(f"Processing batch {batch_start + 1}-{batch_end} of {len(chunks)}")
            
            # NOTE(review): assumes the callback's first argument `c` is a
            # 1-based index within the current batch (hence batch[c-1]) —
            # confirm against GeminiTranscriber.transcribe_batch.
            batch_results = transcriber.transcribe_batch(
                batch,
                transcription_config,
                progress_callback=lambda c, t: self._log(
                    f"  Chunk {batch_start + c}/{len(chunks)}: "
                    f"{batch[c-1].original_segment}"
                )
            )
            
            self.results.extend(batch_results)
            
            # Save intermediate results. NOTE(review): _save_results names
            # files with a second-resolution timestamp, so each checkpoint
            # creates a NEW "_partial" file rather than overwriting — these
            # accumulate in output_dir.
            if self.config.save_intermediate:
                self._save_results(partial=True)
        
        # === Step 5: Save Final Results ===
        self._log("Step 5: Saving results...")
        output_file = self._save_results(partial=False)
        
        # === Cleanup ===
        # Optionally remove the per-video working directory (downloads +
        # processed chunks). Import is local since it's only needed here.
        if self.config.cleanup_after:
            self._log("Cleaning up downloaded files...")
            import shutil
            video_work_dir = os.path.join(self.config.work_dir, self.config.video_id)
            if os.path.exists(video_work_dir):
                shutil.rmtree(video_work_dir)
        
        # === Build Result ===
        total_time = time.time() - self.start_time
        completed_at = datetime.now().isoformat()
        
        total_audio = sum(r.duration_sec for r in self.results)
        
        result = PipelineResult(
            video_id=self.config.video_id,
            language=language,
            model=self.config.model,
            thinking_level=self.config.thinking_level,
            total_segments_found=stats.get('total_files', 0),
            # Distinct source segments, since one segment may yield many chunks.
            segments_processed=len(set(r.segment_id for r in self.results)),
            chunks_created=len(chunks),
            transcriptions_completed=len(self.results),
            total_audio_duration_sec=total_audio,
            total_processing_time_sec=total_time,
            output_file=output_file,
            errors=self.errors,
            started_at=started_at,
            completed_at=completed_at
        )
        
        self._log(f"Pipeline completed in {total_time:.1f}s")
        self._log(f"Transcribed {total_audio:.1f}s of audio")
        self._log(f"Output saved to: {output_file}")
        
        return result
    
    def _save_results(self, partial: bool = False) -> str:
        """
        Save accumulated results to a timestamped JSON file.

        Args:
            partial: When True, appends "_partial" to the filename to mark
                an intermediate checkpoint.

        Returns:
            Absolute-or-relative path of the file written.
        """
        os.makedirs(self.config.output_dir, exist_ok=True)
        
        # Second-resolution timestamp: every call produces a distinct file.
        timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
        suffix = "_partial" if partial else ""
        filename = f"{self.config.video_id}_{self.config.model.replace('-', '_')}_{timestamp}{suffix}.json"
        output_path = os.path.join(self.config.output_dir, filename)
        
        # Convert results to dict.
        # NOTE(review): when results is empty the "language" field falls back
        # to config.default_language, not the language actually resolved in
        # run() — a config-specified language would be misreported here.
        output_data = {
            "video_id": self.config.video_id,
            "model": self.config.model,
            "thinking_level": self.config.thinking_level,
            "language": self.results[0].language if self.results else self.config.default_language,
            "config": {
                "max_segment_duration_sec": self.config.max_segment_duration_sec,
                "min_segment_duration_sec": self.config.min_segment_duration_sec,
                "max_segments": self.config.max_segments,
            },
            "results_count": len(self.results),
            "total_duration_sec": sum(r.duration_sec for r in self.results),
            # model_dump(): results are Pydantic models (TranscriptionResult).
            "results": [r.model_dump() for r in self.results]
        }
        
        # ensure_ascii=False keeps native-script (e.g. Telugu) text readable.
        with open(output_path, 'w', encoding='utf-8') as f:
            json.dump(output_data, f, ensure_ascii=False, indent=2)
        
        return output_path


def run_pipeline(
    video_id: str,
    language: Optional[str] = None,
    model: str = "gemini-3-flash-preview",
    thinking_level: str = "high",
    max_segments: Optional[int] = None,
    max_duration_sec: float = 10.0,
    output_dir: str = "./transcriptions"
) -> PipelineResult:
    """
    Run the transcription pipeline with the most common knobs exposed.

    Thin wrapper: builds a PipelineConfig from the arguments (all other
    config fields keep their defaults) and executes a TranscriptionPipeline.

    Args:
        video_id: YouTube video ID
        language: Primary language (if None, fetched from Supabase)
        model: Gemini model to use
        thinking_level: Thinking level for Gemini 3
        max_segments: Limit segments for testing
        max_duration_sec: Max segment duration (hard limit)
        output_dir: Directory for output files

    Returns:
        PipelineResult with summary
    """
    pipeline_config = PipelineConfig(
        video_id=video_id,
        language=language,
        model=model,
        thinking_level=thinking_level,
        max_segments=max_segments,
        max_segment_duration_sec=max_duration_sec,
        output_dir=output_dir,
    )
    return TranscriptionPipeline(pipeline_config).run()


def test_models(
    video_id: str,
    language: str,
    models: List[str],
    thinking_levels: Optional[List[str]] = None,
    max_segments: int = 3
) -> Dict[str, PipelineResult]:
    """
    Test multiple models on the same video for comparison.
    
    Runs the full pipeline once per (model, thinking_level) combination and
    collects the results. A failure in one combination is reported but does
    not stop the remaining combinations.

    Args:
        video_id: Video ID to test
        language: Primary language
        models: List of model names to test
        thinking_levels: Thinking levels to try (defaults to ["high"])
        max_segments: Segments per test
        
    Returns:
        Dict mapping "model_thinking" keys to PipelineResult (failed
        combinations are omitted)
    """
    # Fix: the previous default was a mutable list literal (shared across
    # calls). Use a None sentinel and materialize the default per call.
    if thinking_levels is None:
        thinking_levels = ["high"]

    results: Dict[str, PipelineResult] = {}
    
    for model in models:
        for thinking in thinking_levels:
            key = f"{model}_{thinking}"
            print(f"\n{'='*60}")
            print(f"Testing: {key}")
            print(f"{'='*60}")
            
            try:
                result = run_pipeline(
                    video_id=video_id,
                    language=language,
                    model=model,
                    thinking_level=thinking,
                    max_segments=max_segments,
                    output_dir=f"./transcriptions/test_{video_id}"
                )
                results[key] = result
            except Exception as e:
                # Best-effort comparison harness: log and continue so one
                # bad model/level doesn't abort the whole sweep.
                print(f"ERROR testing {key}: {e}")
                
    return results


if __name__ == "__main__":
    import argparse

    # CLI entry point: transcribe one video and print a run summary.
    arg_parser = argparse.ArgumentParser(description="Transcription Pipeline")
    arg_parser.add_argument("video_id", help="YouTube video ID")
    arg_parser.add_argument("--language", "-l", help="Primary language (default: auto-detect)")
    arg_parser.add_argument("--model", "-m", default="gemini-3-flash-preview",
                            help="Gemini model to use")
    arg_parser.add_argument("--thinking", "-t", default="high",
                            choices=["minimal", "low", "medium", "high"],
                            help="Thinking level for Gemini 3")
    arg_parser.add_argument("--max-segments", "-n", type=int,
                            help="Maximum segments to process")
    arg_parser.add_argument("--max-duration", "-d", type=float, default=10.0,
                            help="Maximum segment duration in seconds")
    arg_parser.add_argument("--output", "-o", default="./transcriptions",
                            help="Output directory")

    opts = arg_parser.parse_args()

    summary = run_pipeline(
        video_id=opts.video_id,
        language=opts.language,
        model=opts.model,
        thinking_level=opts.thinking,
        max_segments=opts.max_segments,
        max_duration_sec=opts.max_duration,
        output_dir=opts.output
    )

    # Human-readable recap of the PipelineResult.
    divider = '=' * 60
    print(f"\n{divider}")
    print("Pipeline Summary")
    print(f"{divider}")
    print(f"Video ID: {summary.video_id}")
    print(f"Language: {summary.language}")
    print(f"Model: {summary.model} (thinking: {summary.thinking_level})")
    print(f"Segments processed: {summary.segments_processed}")
    print(f"Chunks transcribed: {summary.transcriptions_completed}")
    print(f"Total audio: {summary.total_audio_duration_sec:.1f}s")
    print(f"Total time: {summary.total_processing_time_sec:.1f}s")
    print(f"Output: {summary.output_file}")
    if summary.errors:
        print(f"Errors: {summary.errors}")
