"""
Main transcription pipeline orchestrating the full workflow:
1. Download audio segments from R2
2. Fetch language info from Supabase
3. Process/split audio segments
4. Transcribe using Gemini
5. Output results

Full control over all parameters for testing and production use.
"""
import os
import json
import time
from typing import Optional, List, Dict
from dataclasses import dataclass, field
from datetime import datetime

# Import from src.backend package
from src.backend.config import DEFAULT_SETTINGS, GEMINI_MODELS, get_model_name
from src.backend.r2_storage import download_video_segments
from src.backend.supabase_client import get_video_language
from src.backend.audio_processor import AudioProcessor, AudioChunk, get_segment_stats
from src.backend.audio_polisher import AudioPolisher
from src.backend.gemini_transcriber import GeminiTranscriber, TranscriptionConfig
from src.backend.transcription_schema import TranscriptionResult

# Validation
from src.validators import validate_transcription, quick_validate, cleanup as cleanup_validator


@dataclass
class PipelineConfig:
    """
    Complete configuration for the transcription pipeline.

    All parameters are exposed for full control; only ``video_id`` is
    required. The field groups below mirror the pipeline stages: download,
    language lookup, audio processing, transcription, validation, output.
    """
    # === Video Settings ===
    video_id: str  # YouTube video ID; also names the work subdirectory
    language: Optional[str] = None  # If None, fetched from Supabase
    default_language: str = "Telugu"  # Fallback if Supabase lookup fails

    # === Audio Processing Settings ===
    max_segment_duration_sec: float = 10.0  # Hard limit - segments longer than this are split
    min_segment_duration_sec: float = 2.0   # Skip segments shorter than this (v4: raised, <2s is junk)
    chunk_overlap_sec: float = 0.0          # Overlap when splitting (for word boundaries)

    # === Segment Selection Settings ===
    max_segments: Optional[int] = None      # Limit total segments processed (for testing)
    segment_start_index: int = 0            # Start from this segment index (applied after chunking)
    skip_short_segments: bool = True        # Skip segments < min_duration

    # === Transcription Settings ===
    model: str = "gemini-3-flash-preview"   # Gemini model to use
    thinking_level: Optional[str] = "low"   # minimal, low, medium, high (Gemini 3)
    temperature: float = 0.0                # 0 for deterministic (tested safe with strict prompt)

    # === Validation Settings ===
    validate_transcriptions: bool = True    # Validate native transcriptions
    validation_action: str = "retry"        # "flag" (mark status) or "retry" (re-transcribe with pro)
    retry_model: str = "gemini-3-pro-preview"  # Model for retry attempts
    retry_thinking_level: str = "low"        # Thinking level for retry
    min_snr_db: float = 5.0                 # SNR gate: skip segments below this (saves API credits)

    # === Output Settings ===
    output_dir: str = "./transcriptions"
    save_intermediate: bool = True          # Save results after each batch
    batch_size: int = 10                    # How many to process before saving

    # === Audio Polishing Settings ===
    polish_audio: bool = True               # Polish segment boundaries before transcription

    # === Storage Settings ===
    work_dir: str = field(default_factory=lambda: DEFAULT_SETTINGS["work_dir"])
    cleanup_after: bool = False             # Remove downloaded files after processing
    use_regex_fallback: bool = True         # Try regex if exact R2 match fails


@dataclass
class PipelineResult:
    """Summary and statistics returned by TranscriptionPipeline.run()."""
    video_id: str
    language: str                     # Language actually used (given or fetched from Supabase)
    model: str                        # Primary Gemini model name
    thinking_level: Optional[str]     # Thinking level used, if any

    total_segments_found: int         # Audio files discovered after R2 download
    segments_processed: int           # Distinct source segment ids with results
    chunks_created: int               # Chunks remaining after splitting and pre-tx gates
    transcriptions_completed: int     # Transcription results produced

    total_audio_duration_sec: float   # Sum of duration_sec over all results
    total_processing_time_sec: float  # Wall-clock runtime of the pipeline

    output_file: str                  # Path of the final JSON output file
    errors: List[str]                 # Error messages accumulated during the run

    started_at: str                   # ISO-8601 timestamp when the run began
    completed_at: str                 # ISO-8601 timestamp when the run finished


@dataclass
class _BatchHealth:
    """Tracks validator health per batch. Catches silent degradation."""
    validated: int = 0
    ctc_success: int = 0       # CTC returned > 0
    mms_success: int = 0       # MMS returned > 0
    dual_success: int = 0      # Both returned > 0
    one_sided_accept: int = 0  # Accepted with only one validator
    accept: int = 0
    review: int = 0
    retry: int = 0
    reject: int = 0
    retry_improved: int = 0
    total_combined: float = 0.0

    # Alert thresholds
    CTC_FAIL_THRESHOLD = 0.20    # Warn if CTC fails on >20% of batch
    MMS_FAIL_THRESHOLD = 0.20    # Warn if MMS fails on >20% of batch
    REJECT_THRESHOLD = 0.30      # Warn if >30% of batch rejected
    ONE_SIDED_THRESHOLD = 0.40   # Warn if >40% accepted on single validator

    def record(self, validation):
        """Record one validation result."""
        self.validated += 1
        ctc_ok = validation.native_ctc_score > 0
        mms_ok = validation.roman_mms_score > 0
        if ctc_ok:
            self.ctc_success += 1
        if mms_ok:
            self.mms_success += 1
        if ctc_ok and mms_ok:
            self.dual_success += 1
        if validation.status == "accept" and not (ctc_ok and mms_ok):
            self.one_sided_accept += 1
        status_attr = validation.status  # accept/review/retry/reject
        if hasattr(self, status_attr):
            setattr(self, status_attr, getattr(self, status_attr) + 1)
        self.total_combined += validation.combined_score

    def report(self, log_fn):
        """Log batch health summary. Warn on threshold breaches."""
        n = max(self.validated, 1)
        avg_s = self.total_combined / n
        ctc_rate = self.ctc_success / n
        mms_rate = self.mms_success / n
        dual_rate = self.dual_success / n
        rej_rate = self.reject / n
        one_sided_rate = self.one_sided_accept / max(self.accept, 1)

        log_fn(f"  Health: {self.validated} validated | "
               f"accept={self.accept} review={self.review} "
               f"retry={self.retry} reject={self.reject} | "
               f"avg_S={avg_s:.3f}")
        log_fn(f"  Validators: CTC={ctc_rate:.0%} MMS={mms_rate:.0%} "
               f"dual={dual_rate:.0%} | "
               f"one-sided accepts={self.one_sided_accept}")
        if self.retry_improved > 0:
            log_fn(f"  Retries improved: {self.retry_improved}")

        # Alert on threshold breaches
        alerts = []
        if ctc_rate < (1 - self.CTC_FAIL_THRESHOLD):
            alerts.append(f"CTC failing on {1-ctc_rate:.0%} of segments")
        if mms_rate < (1 - self.MMS_FAIL_THRESHOLD):
            alerts.append(f"MMS failing on {1-mms_rate:.0%} of segments")
        if rej_rate > self.REJECT_THRESHOLD:
            alerts.append(f"Reject rate {rej_rate:.0%} exceeds {self.REJECT_THRESHOLD:.0%}")
        if self.accept > 0 and one_sided_rate > self.ONE_SIDED_THRESHOLD:
            alerts.append(
                f"One-sided accept rate {one_sided_rate:.0%} "
                f"exceeds {self.ONE_SIDED_THRESHOLD:.0%}"
            )
        for alert in alerts:
            log_fn(f"  ALERT: {alert}", "WARN")
        return alerts


class TranscriptionPipeline:
    """
    Main pipeline orchestrating the full transcription workflow.

    Stages (see run()): download segments from R2, resolve the language,
    split/process audio into chunks, optionally polish + gate chunks,
    transcribe in batches with Gemini (with optional dual-validator scoring
    and pro-model retries), and save JSON results.
    """

    def __init__(self, config: PipelineConfig):
        """Initialize pipeline with configuration."""
        self.config = config
        self.results: List[TranscriptionResult] = []
        self.errors: List[str] = []
        self.start_time: Optional[float] = None
        self._health_alerts: List[str] = []  # Accumulated validator health alerts

    def _log(self, message: str, level: str = "INFO"):
        """Log a message to stdout with an HH:MM:SS timestamp."""
        timestamp = datetime.now().strftime("%H:%M:%S")
        print(f"[{timestamp}] [{level}] {message}")

    def _get_language_code(self, language: str) -> str:
        """Convert a language name to its two-letter code for validation.

        Unknown names fall back to "te" (Telugu), matching the pipeline's
        default language.
        """
        lang_map = {
            "Telugu": "te", "Hindi": "hi", "Tamil": "ta",
            "Kannada": "kn", "Malayalam": "ml", "Bengali": "bn",
            "Marathi": "mr", "Gujarati": "gu", "Punjabi": "pa",
            "Odia": "or", "Assamese": "as", "English": "en",
        }
        return lang_map.get(language, "te")

    def run(self) -> PipelineResult:
        """
        Execute the full pipeline.

        Returns:
            PipelineResult with summary and statistics.

        Raises:
            Exception: re-raises any R2 download failure (the message is
                recorded in self.errors first). Transcription/save errors
                also propagate, but validator resources are released first.
        """
        self.start_time = time.time()
        started_at = datetime.now().isoformat()

        self._log(f"Starting pipeline for video: {self.config.video_id}")
        self._log(f"Model: {self.config.model}, Thinking: {self.config.thinking_level}")

        # === Step 1: Download from R2 ===
        self._log("Step 1: Downloading audio segments from R2...")
        try:
            segments_dir, metadata = download_video_segments(
                self.config.video_id,
                work_dir=self.config.work_dir,
                use_regex_fallback=self.config.use_regex_fallback
            )
        except Exception as e:
            self.errors.append(f"R2 download failed: {e}")
            self._log(f"ERROR: {e}", "ERROR")
            raise

        # Get segment stats (counts/durations of the downloaded files)
        stats = get_segment_stats(segments_dir)
        self._log(f"Found {stats.get('total_files', 0)} audio files, "
                  f"total duration: {stats.get('total_duration_sec', 0):.1f}s")
        if stats.get('over_10s_count', 0) > 0:
            self._log(f"  -> {stats['over_10s_count']} segments over 10s will be split")

        # === Step 2: Get Language ===
        self._log("Step 2: Fetching language info...")
        if self.config.language:
            language = self.config.language
            self._log(f"Using specified language: {language}")
        else:
            language = get_video_language(
                self.config.video_id,
                default=self.config.default_language
            )
            self._log(f"Language from Supabase: {language}")

        # === Step 3: Process Audio Segments ===
        self._log("Step 3: Processing audio segments...")
        processor = AudioProcessor(
            max_duration_sec=self.config.max_segment_duration_sec,
            min_duration_sec=self.config.min_segment_duration_sec,
            overlap_sec=self.config.chunk_overlap_sec
        )

        chunks = processor.process_segments_directory(
            segments_dir,
            max_segments=self.config.max_segments,
            skip_short=self.config.skip_short_segments
        )

        # Apply start index (for resuming / testing from the middle)
        if self.config.segment_start_index > 0:
            chunks = chunks[self.config.segment_start_index:]
            self._log(f"Starting from index {self.config.segment_start_index}, "
                      f"{len(chunks)} chunks remaining")

        # === Step 3.5: Polish Audio + SNR Gate ===
        if self.config.polish_audio:
            chunks = self._polish_chunks(chunks, segments_dir)

        # === Step 4: Transcribe ===
        self._log(f"Step 4: Transcribing {len(chunks)} chunks...")

        transcriber = GeminiTranscriber()
        transcription_config = TranscriptionConfig(
            model=self.config.model,
            thinking_level=self.config.thinking_level,
            temperature=self.config.temperature,
            language=language
        )

        try:
            # Process in batches so intermediate results can be saved.
            for batch_start in range(0, len(chunks), self.config.batch_size):
                batch_end = min(batch_start + self.config.batch_size, len(chunks))
                batch = chunks[batch_start:batch_end]

                self._log(f"Processing batch {batch_start + 1}-{batch_end} of {len(chunks)}")

                # NOTE(review): the callback assumes `c` is 1-based
                # (batch[c-1]) — confirm against transcribe_batch's contract.
                batch_results = transcriber.transcribe_batch(
                    batch,
                    transcription_config,
                    progress_callback=lambda c, t: self._log(
                        f"  Chunk {batch_start + c}/{len(chunks)}: "
                        f"{batch[c-1].original_segment}"
                    )
                )

                # === Validate transcriptions (dual scoring: native CTC + romanized MMS) ===
                if self.config.validate_transcriptions:
                    self._validate_batch(batch, batch_results, language, transcriber)

                self.results.extend(batch_results)

                # Save intermediate results
                if self.config.save_intermediate:
                    self._save_results(partial=True)

            # === Step 5: Save Final Results ===
            self._log("Step 5: Saving results...")
            output_file = self._save_results(partial=False)
        finally:
            # FIX: release CTC/MMS validator resources even if transcription
            # or saving raises mid-run (previously only on the success path).
            if self.config.validate_transcriptions:
                cleanup_validator()

        # === Cleanup downloaded files (optional) ===
        if self.config.cleanup_after:
            self._log("Cleaning up downloaded files...")
            import shutil
            video_work_dir = os.path.join(self.config.work_dir, self.config.video_id)
            if os.path.exists(video_work_dir):
                shutil.rmtree(video_work_dir)

        # === Build Result ===
        total_time = time.time() - self.start_time
        completed_at = datetime.now().isoformat()

        total_audio = sum(r.duration_sec for r in self.results)

        result = PipelineResult(
            video_id=self.config.video_id,
            language=language,
            model=self.config.model,
            thinking_level=self.config.thinking_level,
            total_segments_found=stats.get('total_files', 0),
            segments_processed=len(set(r.segment_id for r in self.results)),
            chunks_created=len(chunks),
            transcriptions_completed=len(self.results),
            total_audio_duration_sec=total_audio,
            total_processing_time_sec=total_time,
            output_file=output_file,
            errors=self.errors,
            started_at=started_at,
            completed_at=completed_at
        )

        self._log(f"Pipeline completed in {total_time:.1f}s")
        self._log(f"Transcribed {total_audio:.1f}s of audio")
        self._log(f"Output saved to: {output_file}")

        return result

    def _polish_chunks(self, chunks, segments_dir):
        """Polish chunk boundaries and apply pre-transcription gates.

        Mutates each chunk's file_path to the polished file when polishing
        modified it, then drops chunks that fail the SNR gate or shrank
        below the minimum duration. Returns the surviving chunks.
        """
        self._log("Step 3.5: Polishing audio boundaries...")
        polisher = AudioPolisher()
        polished_dir = os.path.join(segments_dir, "polished")
        polished_count = 0
        snr_skipped = 0
        surviving_chunks = []
        for chunk in chunks:
            result = polisher.polish(chunk.file_path, output_dir=polished_dir)
            if result.was_modified:
                chunk.file_path = result.output_path
                polished_count += 1
            # SNR gate: skip segments with very low SNR (likely noise/music)
            if result.snr_db < self.config.min_snr_db:
                snr_skipped += 1
                continue
            # Duration gate: polishing might shrink segments below minimum
            if result.polished_duration_ms / 1000.0 < self.config.min_segment_duration_sec:
                snr_skipped += 1  # reuse counter, both are pre-tx filters
                continue
            surviving_chunks.append(chunk)
        self._log(f"Polished {polished_count}/{len(chunks)} chunks "
                  f"({len(chunks) - polished_count} already clean)")
        if snr_skipped > 0:
            self._log(f"SNR gate: skipped {snr_skipped} chunks "
                      f"(SNR < {self.config.min_snr_db}dB)")
        return surviving_chunks

    def _validate_batch(self, batch, batch_results, language, transcriber):
        """Validate one batch of results in place (dual CTC+MMS scoring).

        Sets validation_status/validation_score on each result, retries
        low-scoring segments with the pro model when configured, and logs
        a batch health report. Alerts accumulate in self._health_alerts.
        """
        lang_code = self._get_language_code(language)
        batch_health = _BatchHealth()
        for i, result in enumerate(batch_results):
            if not result.native:
                continue  # nothing transcribed — nothing to validate
            romanized = result.transcription.romanized or ""
            validation = validate_transcription(
                batch[i].file_path,
                result.native,
                romanized_text=romanized,
                language=lang_code,
            )
            result.validation_status = validation.status
            result.validation_score = validation.combined_score

            # Retry with pro model if below accept threshold
            if (validation.status in ("retry", "reject")
                    and self.config.validation_action == "retry"):
                validation = self._retry_transcription(
                    batch[i], result, validation, language, lang_code,
                    transcriber, batch_health
                )
                # Update final validation state (retry may have won)
                result.validation_status = validation.status
                result.validation_score = validation.combined_score

            if validation.status in ("reject", "retry"):
                self._log(
                    f"  Validation: {result.segment_id} -> "
                    f"{validation.status} "
                    f"(S={validation.combined_score:.2f}, "
                    f"CTC={validation.native_ctc_score:.2f}, "
                    f"MMS={validation.roman_mms_score:.2f})",
                    "WARN"
                )

            # Record for batch health monitoring
            batch_health.record(validation)

        # Batch health report + alerts
        alerts = batch_health.report(self._log)
        self._health_alerts.extend(alerts)

    def _retry_transcription(self, chunk, result, validation, language,
                             lang_code, transcriber, batch_health):
        """Re-transcribe one low-scoring chunk with the retry model.

        Keeps whichever of original/retry validates higher: on improvement
        the result's transcription/model fields are updated in place and
        the winning validation is returned; otherwise the original
        validation is returned unchanged.
        """
        # Hoisted out of the per-chunk loop body; Python caches the module
        # so repeated calls are cheap.
        from src.backend.transcription_schema import TranscriptionOutput, SpeakerMeta

        self._log(
            f"  Retrying {result.segment_id} with "
            f"{self.config.retry_model} "
            f"(S={validation.combined_score:.2f})", "WARN"
        )
        retry_config = TranscriptionConfig(
            model=self.config.retry_model,
            thinking_level=self.config.retry_thinking_level,
            temperature=self.config.temperature,
            language=language,
        )
        retry_raw = transcriber.transcribe_audio(chunk.file_path, retry_config)
        if retry_raw.get("error"):
            return validation  # retry itself failed; keep original verdict

        speaker_data = retry_raw.get('speaker')
        speaker = (SpeakerMeta(**speaker_data)
                   if isinstance(speaker_data, dict)
                   else None)
        retry_output = TranscriptionOutput(
            transcription=retry_raw.get('transcription', ''),
            code_switch=retry_raw.get('code_switch', ''),
            romanized=retry_raw.get('romanized', ''),
            tagged=retry_raw.get('tagged', ''),
            speaker=speaker,
        )
        # Re-validate the retry result
        retry_val = validate_transcription(
            chunk.file_path,
            retry_output.transcription,
            romanized_text=retry_output.romanized or "",
            language=lang_code,
        )
        # Keep the better result
        if retry_val.combined_score > validation.combined_score:
            result.transcription = retry_output
            result.model_used = retry_config.model
            result.thinking_level = retry_config.thinking_level
            self._log(
                f"  Retry improved: "
                f"S={retry_val.combined_score:.2f} "
                f"({retry_val.status})"
            )
            batch_health.retry_improved += 1
            return retry_val

        self._log(
            f"  Retry did not improve: "
            f"S={retry_val.combined_score:.2f}"
        )
        return validation

    def _save_results(self, partial: bool = False) -> str:
        """Save accumulated results to a timestamped JSON file.

        Partial saves get a "_partial" suffix; every call produces a new
        file (timestamp in the name). Returns the output path.
        """
        os.makedirs(self.config.output_dir, exist_ok=True)

        timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
        suffix = "_partial" if partial else ""
        filename = f"{self.config.video_id}_{self.config.model.replace('-', '_')}_{timestamp}{suffix}.json"
        output_path = os.path.join(self.config.output_dir, filename)

        # Convert results to dict
        output_data = {
            "video_id": self.config.video_id,
            "model": self.config.model,
            "thinking_level": self.config.thinking_level,
            "language": self.results[0].language if self.results else self.config.default_language,
            "config": {
                "max_segment_duration_sec": self.config.max_segment_duration_sec,
                "min_segment_duration_sec": self.config.min_segment_duration_sec,
                "max_segments": self.config.max_segments,
            },
            "results_count": len(self.results),
            "total_duration_sec": sum(r.duration_sec for r in self.results),
            "results": [r.model_dump() for r in self.results]
        }

        with open(output_path, 'w', encoding='utf-8') as f:
            json.dump(output_data, f, ensure_ascii=False, indent=2)

        return output_path


def run_pipeline(
    video_id: str,
    language: Optional[str] = None,
    model: str = "gemini-3-flash-preview",
    thinking_level: str = "low",
    temperature: float = 0.0,
    max_segments: Optional[int] = None,
    max_duration_sec: float = 10.0,
    output_dir: str = "./transcriptions",
    validate: bool = True
) -> PipelineResult:
    """
    Run the transcription pipeline with the most common settings.

    Thin convenience wrapper: builds a PipelineConfig from the given
    arguments (everything else stays at its default) and executes a
    TranscriptionPipeline with it.

    Args:
        video_id: YouTube video ID
        language: Primary language (if None, fetched from Supabase)
        model: Gemini model to use
        thinking_level: Thinking level for Gemini 3 (default: low)
        temperature: Generation temperature (default: 0.0 deterministic)
        max_segments: Limit segments for testing
        max_duration_sec: Max segment duration (hard limit)
        output_dir: Directory for output files
        validate: Whether to validate transcriptions

    Returns:
        PipelineResult with summary
    """
    pipeline_config = PipelineConfig(
        video_id=video_id,
        language=language,
        model=model,
        thinking_level=thinking_level,
        temperature=temperature,
        max_segments=max_segments,
        max_segment_duration_sec=max_duration_sec,
        output_dir=output_dir,
        validate_transcriptions=validate,
    )
    return TranscriptionPipeline(pipeline_config).run()


def test_models(
    video_id: str,
    language: str,
    models: List[str],
    thinking_levels: Optional[List[str]] = None,
    max_segments: int = 3
) -> Dict[str, PipelineResult]:
    """
    Test multiple models on the same video for comparison.

    Runs the pipeline once per (model, thinking_level) combination; a
    failing combination is logged and skipped rather than aborting the
    sweep.

    Args:
        video_id: Video ID to test
        language: Primary language
        models: List of model names to test
        thinking_levels: Thinking levels to try (default: ["high"])
        max_segments: Segments per test

    Returns:
        Dict mapping "model_thinking" keys to PipelineResult (combinations
        that raised are absent)
    """
    # FIX: avoid a mutable default argument; ["high"] is the effective default.
    if thinking_levels is None:
        thinking_levels = ["high"]

    results: Dict[str, PipelineResult] = {}

    for model in models:
        for thinking in thinking_levels:
            key = f"{model}_{thinking}"
            print(f"\n{'='*60}")
            print(f"Testing: {key}")
            print(f"{'='*60}")

            try:
                result = run_pipeline(
                    video_id=video_id,
                    language=language,
                    model=model,
                    thinking_level=thinking,
                    max_segments=max_segments,
                    output_dir=f"./transcriptions/test_{video_id}"
                )
                results[key] = result
            except Exception as e:
                print(f"ERROR testing {key}: {e}")

    return results


if __name__ == "__main__":
    import argparse

    # CLI entry point: parse options, run the pipeline, print a summary.
    cli = argparse.ArgumentParser(description="Transcription Pipeline")
    cli.add_argument("video_id", help="YouTube video ID")
    cli.add_argument("--language", "-l", help="Primary language (default: auto-detect)")
    cli.add_argument("--model", "-m", default="gemini-3-flash-preview",
                     help="Gemini model to use")
    cli.add_argument("--thinking", "-t", default="low",
                     choices=["minimal", "low", "medium", "high"],
                     help="Thinking level for Gemini 3 (default: low)")
    cli.add_argument("--temperature", type=float, default=0.0,
                     help="Generation temperature (default: 0.0 deterministic)")
    cli.add_argument("--no-validate", action="store_true",
                     help="Skip transcription validation")
    cli.add_argument("--max-segments", "-n", type=int,
                     help="Maximum segments to process")
    cli.add_argument("--max-duration", "-d", type=float, default=10.0,
                     help="Maximum segment duration in seconds")
    cli.add_argument("--output", "-o", default="./transcriptions",
                     help="Output directory")

    opts = cli.parse_args()

    summary = run_pipeline(
        video_id=opts.video_id,
        language=opts.language,
        model=opts.model,
        thinking_level=opts.thinking,
        temperature=opts.temperature,
        max_segments=opts.max_segments,
        max_duration_sec=opts.max_duration,
        output_dir=opts.output,
        validate=not opts.no_validate
    )

    # Human-readable run summary.
    divider = "=" * 60
    print(f"\n{divider}")
    print("Pipeline Summary")
    print(divider)
    for line in (
        f"Video ID: {summary.video_id}",
        f"Language: {summary.language}",
        f"Model: {summary.model} (thinking: {summary.thinking_level})",
        f"Segments processed: {summary.segments_processed}",
        f"Chunks transcribed: {summary.transcriptions_completed}",
        f"Total audio: {summary.total_audio_duration_sec:.1f}s",
        f"Total time: {summary.total_processing_time_sec:.1f}s",
        f"Output: {summary.output_file}",
    ):
        print(line)
    if summary.errors:
        print(f"Errors: {summary.errors}")