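"""
Vistaar Validator
=================

Uses AI4Bharat's Vistaar model - a Whisper-based ASR fine-tuned for Indian languages.
https://github.com/AI4Bharat/vistaar

Vistaar provides:
- High-quality ASR for Indian languages
- Word-level timestamps via Whisper's alignment
- Better accuracy than base Whisper for Indic languages

Note: the validator in this module currently loads a stock Whisper checkpoint
(via faster-whisper or openai-whisper) and passes a language hint. The
fine-tuned Vistaar checkpoints are not loaded yet.
"""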
import os
import time
from typing import Optional, List, Dict, Any
from pathlib import Path

from .base import (
    BaseValidator,
    ValidationResult,
    WordAlignment,
    ValidatorStatus,
    normalize_language_code
)


class VistaarValidator(BaseValidator):
    """
    Validator using AI4Bharat's Vistaar (Whisper-based) model.
    
    Fine-tuned Whisper models for Indian languages with improved accuracy.

    NOTE: ``setup`` currently loads a stock Whisper checkpoint named by
    ``model_size`` (through faster-whisper, or openai-whisper as a fallback)
    and relies on a per-call language hint. The Vistaar checkpoints listed in
    ``LANGUAGE_MODELS`` are not loaded by this class yet. For faster-whisper
    they would first need converting to CTranslate2 format.
    """
    
    name = "vistaar"
    description = "AI4Bharat Vistaar - Fine-tuned Whisper for Indian languages"
    
    # Vistaar model variants per language.
    # See: https://github.com/AI4Bharat/vistaar for full list
    # NOTE: not used by the methods of this class; kept as the intended
    # mapping for when the Vistaar checkpoints are wired in.
    LANGUAGE_MODELS = {
        "te": "ai4bharat/vistaar-te-base",   # Telugu
        "hi": "ai4bharat/vistaar-hi-base",   # Hindi
        "kn": "ai4bharat/vistaar-kn-base",   # Kannada
        "ta": "ai4bharat/vistaar-ta-base",   # Tamil
        "ml": "ai4bharat/vistaar-ml-base",   # Malayalam
        "bn": "ai4bharat/vistaar-bn-base",   # Bengali
        "gu": "ai4bharat/vistaar-gu-base",   # Gujarati
        "mr": "ai4bharat/vistaar-mr-base",   # Marathi
        "pa": "ai4bharat/vistaar-pa-base",   # Punjabi
        "or": "ai4bharat/vistaar-or-base",   # Odia
        "en": "openai/whisper-base",          # English fallback
    }
    
    # Maps both ISO 639-1 codes and lowercase English language names to the
    # ISO 639-1 code passed to Whisper as the `language` hint.
    WHISPER_LANG_MAP = {
        "te": "te",
        "hi": "hi",
        "kn": "kn",
        "ta": "ta",
        "ml": "ml",
        "bn": "bn",
        "gu": "gu",
        "mr": "mr",
        "pa": "pa",
        "or": "or",
        "en": "en",
        "telugu": "te",
        "hindi": "hi",
        "kannada": "kn",
        "tamil": "ta",
        "malayalam": "ml",
        "bengali": "bn",
        "gujarati": "gu",
        "marathi": "mr",
        "punjabi": "pa",
        "odia": "or",
        "english": "en",
    }
    
    def __init__(
        self,
        enabled: bool = True,
        model_size: str = "medium",  # Use medium for better multilingual support
        device: str = "auto",
        use_faster_whisper: bool = True,
        compute_type: str = "float16",
        **kwargs
    ):
        """
        Initialize Vistaar validator.
        
        No model is loaded here; loading happens in ``setup``.
        
        Args:
            enabled: Whether validator is active
            model_size: Whisper model size/name (e.g. base, small, medium, large)
            device: "cuda", "cpu", or "auto" (resolved during ``setup``)
            use_faster_whisper: Try faster-whisper first, falling back to
                openai-whisper if it cannot be loaded
            compute_type: faster-whisper compute type used on CUDA. CPU
                always uses "int8".
        """
        super().__init__(enabled=enabled, **kwargs)
        self.model_size = model_size
        self.device_preference = device
        self.use_faster_whisper = use_faster_whisper
        self.compute_type = compute_type
        self.model = None
        self.device = None
        # Which backend actually loaded; set by the _setup_* helpers.
        self._is_faster_whisper = False
    
    def _resolve_device(self) -> str:
        """Turn the device preference into a concrete device ("cuda" or "cpu" for "auto")."""
        if self.device_preference != "auto":
            return self.device_preference
        try:
            import torch
            return "cuda" if torch.cuda.is_available() else "cpu"
        except ImportError:
            pass
        # faster-whisper runs on CTranslate2 and does not need torch, so ask
        # CTranslate2 directly when torch is not installed.
        try:
            import ctranslate2
            return "cuda" if ctranslate2.get_cuda_device_count() > 0 else "cpu"
        except ImportError:
            return "cpu"
        
    def setup(self) -> bool:
        """
        Resolve the target device and load the Whisper model.
        
        Returns:
            True if a model was loaded (faster-whisper or openai-whisper),
            False otherwise. Errors are printed, not raised.
        """
        try:
            self.device = self._resolve_device()
            
            print(f"[{self.name}] Loading model on {self.device}...")
            
            if self.use_faster_whisper:
                return self._setup_faster_whisper()
            return self._setup_whisper()
                
        except ImportError as e:
            print(f"[{self.name}] Missing dependency: {e}")
            print(f"[{self.name}] Install with: pip install faster-whisper or pip install openai-whisper")
            return False
        except Exception as e:
            print(f"[{self.name}] Setup error: {e}")
            return False
    
    def _setup_faster_whisper(self) -> bool:
        """Load the model with faster-whisper; fall back to openai-whisper on any failure."""
        try:
            from faster_whisper import WhisperModel
            
            # Use base Whisper with language hints for now.
            # Vistaar models would need to be converted to CTranslate2 format.
            self.model = WhisperModel(
                self.model_size,
                device=self.device,
                # The configured compute type applies only on CUDA; CPU uses int8.
                compute_type=self.compute_type if self.device == "cuda" else "int8"
            )
            
            self._is_faster_whisper = True
            print(f"[{self.name}] faster-whisper model loaded")
            return True
            
        except Exception as e:
            print(f"[{self.name}] faster-whisper setup failed: {e}")
            print(f"[{self.name}] Falling back to standard whisper...")
            return self._setup_whisper()
    
    def _setup_whisper(self) -> bool:
        """Load the model with openai-whisper. Returns False (after printing) on failure."""
        try:
            import whisper
            
            self.model = whisper.load_model(self.model_size, device=self.device)
            self._is_faster_whisper = False
            print(f"[{self.name}] openai-whisper model loaded")
            return True
            
        except Exception as e:
            print(f"[{self.name}] whisper setup failed: {e}")
            return False
    
    def _transcribe_faster_whisper(
        self,
        audio_path: str,
        language: str
    ) -> Dict[str, Any]:
        """
        Transcribe with faster-whisper.
        
        Returns:
            Dict with "transcription", "alignments" (list of WordAlignment),
            "language", "language_probability" and "duration" as reported by
            faster-whisper's transcription info.
        """
        lang_hint = self.WHISPER_LANG_MAP.get(language, language)
        
        segments, info = self.model.transcribe(
            audio_path,
            language=lang_hint,
            word_timestamps=True,
            beam_size=5
        )
        
        # faster-whisper yields segments lazily; decoding happens while
        # iterating, so errors can surface inside this loop.
        text_parts: List[str] = []
        word_alignments: List[WordAlignment] = []
        for segment in segments:
            text_parts.append(segment.text)
            for word in getattr(segment, "words", None) or []:
                word_alignments.append(WordAlignment(
                    word=word.word.strip(),
                    start_time=word.start,
                    end_time=word.end,
                    confidence=getattr(word, "probability", None)
                ))
        
        return {
            "transcription": "".join(text_parts).strip(),
            "alignments": word_alignments,
            "language": info.language,
            "language_probability": info.language_probability,
            "duration": info.duration
        }
    
    def _transcribe_whisper(
        self,
        audio_path: str,
        language: str
    ) -> Dict[str, Any]:
        """
        Transcribe with openai-whisper.
        
        Returns:
            Dict with "transcription", "alignments" (list of WordAlignment),
            "language", and "duration" (always None for this backend).
        """
        lang_hint = self.WHISPER_LANG_MAP.get(language, language)
        
        result = self.model.transcribe(
            audio_path,
            language=lang_hint,
            word_timestamps=True
        )
        
        word_alignments = [
            WordAlignment(
                word=word["word"].strip(),
                start_time=word["start"],
                end_time=word["end"],
                confidence=word.get("probability")
            )
            for segment in result.get("segments", [])
            for word in segment.get("words", [])
        ]
        
        return {
            "transcription": result["text"].strip(),
            "alignments": word_alignments,
            "language": result.get("language"),
            "duration": None  # Not directly available
        }
    
    def validate(
        self,
        audio_path: str,
        reference_text: Optional[str] = None,
        language: str = "te"
    ) -> ValidationResult:
        """
        Transcribe audio using Vistaar/Whisper.
        
        Args:
            audio_path: Path to audio file
            reference_text: Not used by this validator; accepted for interface
                compatibility.
            language: Language code or name, normalized with
                ``normalize_language_code`` and mapped via ``WHISPER_LANG_MAP``.
            
        Returns:
            ValidationResult with the transcription and word alignments.
            ``overall_confidence`` is the mean word probability, or None if no
            word reported one. Any exception during transcription is returned
            as ``success=False`` with its message.
        """
        if not self.enabled:
            return ValidationResult(
                validator_name=self.name,
                audio_path=audio_path,
                success=False,
                error_message="Validator disabled"
            )
        
        # Also require an actual model: cleanup() sets it back to None.
        if not self.ensure_setup() or self.model is None:
            return ValidationResult(
                validator_name=self.name,
                audio_path=audio_path,
                success=False,
                error_message="Model not loaded"
            )
        
        start_time = time.time()
        lang_code = normalize_language_code(language)
        
        try:
            if self._is_faster_whisper:
                result = self._transcribe_faster_whisper(audio_path, lang_code)
            else:
                result = self._transcribe_whisper(audio_path, lang_code)
            
            alignments = result["alignments"]
            
            # Overall confidence = mean of the per-word probabilities.
            confidences = [
                wa.confidence for wa in alignments
                if wa.confidence is not None
            ]
            overall_confidence = sum(confidences) / len(confidences) if confidences else None
            
            # openai-whisper reports no duration; approximate it with the end
            # of the last recognized word.
            audio_duration = result.get("duration")
            if audio_duration is None and alignments:
                audio_duration = alignments[-1].end_time
            
            return ValidationResult(
                validator_name=self.name,
                audio_path=audio_path,
                success=True,
                transcription=result["transcription"],
                word_alignments=alignments,
                overall_confidence=overall_confidence,
                processing_time_sec=time.time() - start_time,
                audio_duration_sec=audio_duration,
                raw_output={
                    "language": lang_code,
                    "detected_language": result.get("language"),
                    "language_probability": result.get("language_probability"),
                    "model_type": "faster-whisper" if self._is_faster_whisper else "openai-whisper"
                }
            )
            
        except Exception as e:
            return ValidationResult(
                validator_name=self.name,
                audio_path=audio_path,
                success=False,
                error_message=str(e),
                processing_time_sec=time.time() - start_time
            )
    
    def cleanup(self) -> None:
        """Release the model and, when torch with CUDA is available, free cached GPU memory."""
        self.model = None
        
        try:
            import torch
        except ImportError:
            # faster-whisper can run without torch; nothing more to release.
            return
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
