"""
Vistaar Validator (IndicWhisper)
================================

Uses AI4Bharat's IndicWhisper/Vistaar models - Whisper fine-tuned for Indian languages.
https://github.com/AI4Bharat/vistaar

Supports:
- Language-specific fine-tuned models (Telugu, Hindi, Kannada, etc.)
- Word-level timestamps
- Better accuracy than base Whisper for Indic languages
"""
import os
import time
from typing import Optional, List, Dict, Any
from pathlib import Path

from .base import (
    BaseValidator,
    ValidationResult,
    WordAlignment,
    ValidatorStatus,
    normalize_language_code
)


class VistaarValidator(BaseValidator):
    """
    Validator using AI4Bharat's IndicWhisper/Vistaar models.

    Fine-tuned Whisper models for Indian languages with improved accuracy.
    If a language-specific IndicWhisper checkpoint is not present (and cannot
    be downloaded), the validator falls back to a base Whisper model served
    via faster-whisper.
    """

    name = "vistaar"
    description = "AI4Bharat IndicWhisper/Vistaar - Fine-tuned Whisper for Indian languages"

    # Model download URLs for IndicWhisper (one zip archive per language).
    MODEL_URLS = {
        "te": "https://indicwhisper.objectstore.e2enetworks.net/telugu_models.zip",
        "hi": "https://indicwhisper.objectstore.e2enetworks.net/hindi_models.zip",
        "kn": "https://indicwhisper.objectstore.e2enetworks.net/kannada_models.zip",
        "ta": "https://indicwhisper.objectstore.e2enetworks.net/tamil_models.zip",
        "ml": "https://indicwhisper.objectstore.e2enetworks.net/malayalam_models.zip",
        "bn": "https://indicwhisper.objectstore.e2enetworks.net/bengali_models.zip",
        "gu": "https://indicwhisper.objectstore.e2enetworks.net/gujarati_models.zip",
        "mr": "https://indicwhisper.objectstore.e2enetworks.net/marathi_models.zip",
        "pa": "https://indicwhisper.objectstore.e2enetworks.net/punjabi_models.zip",
        "or": "https://indicwhisper.objectstore.e2enetworks.net/odia_models.zip",
        "ur": "https://indicwhisper.objectstore.e2enetworks.net/urdu_models.zip",
        "sa": "https://indicwhisper.objectstore.e2enetworks.net/sanskrit_models.zip",
    }

    # Language code -> checkpoint directory (relative to models_dir) created
    # by extracting the corresponding zip from MODEL_URLS.
    LANG_TO_MODEL_DIR = {
        "te": "telugu_models/whisper-medium-te_alldata_multigpu",
        "hi": "hindi_models/whisper-medium-hi_alldata_multigpu",
        "kn": "kannada_models/whisper-medium-kn_alldata_multigpu",
        "ta": "tamil_models/whisper-medium-ta_alldata_multigpu",
        "ml": "malayalam_models/whisper-medium-ml_alldata_multigpu",
        "bn": "bengali_models/whisper-medium-bn_alldata_multigpu",
        "gu": "gujarati_models/whisper-medium-gu_alldata_multigpu",
        "mr": "marathi_models/whisper-medium-mr_alldata_multigpu",
        "pa": "punjabi_models/whisper-medium-pa_alldata_multigpu",
        "or": "odia_models/whisper-medium-or_alldata_multigpu",
        "ur": "urdu_models/whisper-medium-ur_alldata_multigpu",
        "sa": "sanskrit_models/whisper-medium-sa_alldata_multigpu",
    }

    # Language hint passed to base Whisper when falling back (currently an
    # identity mapping plus "en"; kept as an explicit table so non-Whisper
    # codes can be remapped later).
    WHISPER_LANG_MAP = {
        "te": "te",
        "hi": "hi",
        "kn": "kn",
        "ta": "ta",
        "ml": "ml",
        "bn": "bn",
        "gu": "gu",
        "mr": "mr",
        "pa": "pa",
        "or": "or",
        "en": "en",
        "ur": "ur",
        "sa": "sa",
    }

    def __init__(
        self,
        enabled: bool = True,
        models_dir: str = "./models/indicwhisper",
        device: str = "auto",
        use_indicwhisper: bool = True,  # Use IndicWhisper if available
        fallback_model: str = "medium",  # Fallback to base Whisper
        **kwargs
    ):
        """
        Initialize Vistaar/IndicWhisper validator.

        Args:
            enabled: Whether validator is active
            models_dir: Directory containing IndicWhisper models
            device: "cuda", "cpu", or "auto"
            use_indicwhisper: Try to use IndicWhisper language-specific models
            fallback_model: Base Whisper model size for fallback
        """
        super().__init__(enabled=enabled, **kwargs)
        self.models_dir = Path(models_dir)
        self.device_preference = device
        self.use_indicwhisper = use_indicwhisper
        self.fallback_model = fallback_model
        self.model = None        # loaded lazily in setup()
        self.processor = None    # only set for the IndicWhisper (HF) path
        self.device = None       # resolved in setup() from device_preference
        self._model_type = None  # "indicwhisper" or "base_whisper"
        self._current_language = None  # language the loaded model was set up for

    def _get_model_path(self, language: str) -> Optional[Path]:
        """Return the local IndicWhisper checkpoint path for *language*.

        A directory counts as a valid checkpoint if it contains either
        pytorch_model.bin or config.json. Returns None when no checkpoint
        exists (or the language has no mapped model directory).
        """
        lang_code = normalize_language_code(language)

        if lang_code not in self.LANG_TO_MODEL_DIR:
            return None

        model_path = self.models_dir / self.LANG_TO_MODEL_DIR[lang_code]

        if model_path.exists():
            for marker in ("pytorch_model.bin", "config.json"):
                if (model_path / marker).exists():
                    return model_path

        return None

    def _download_model(self, language: str) -> bool:
        """Download and extract the IndicWhisper model for *language*.

        Returns True when the model is available afterwards (already cached
        or freshly downloaded), False on any failure. The zip archive is
        always removed, even when extraction fails, so a partial download
        never lingers on disk.
        """
        lang_code = normalize_language_code(language)

        if lang_code not in self.MODEL_URLS:
            print(f"[{self.name}] No IndicWhisper model available for {lang_code}")
            return False

        model_path = self._get_model_path(lang_code)
        if model_path:
            print(f"[{self.name}] Model already exists at {model_path}")
            return True

        zip_path = self.models_dir / f"{lang_code}_models.zip"
        try:
            import urllib.request
            import zipfile

            url = self.MODEL_URLS[lang_code]

            print(f"[{self.name}] Downloading model from {url}...")
            self.models_dir.mkdir(parents=True, exist_ok=True)
            urllib.request.urlretrieve(url, zip_path)

            print(f"[{self.name}] Extracting...")
            with zipfile.ZipFile(zip_path, 'r') as z:
                z.extractall(self.models_dir)

            print(f"[{self.name}] Model downloaded successfully")
            return True

        except Exception as e:
            print(f"[{self.name}] Download failed: {e}")
            return False
        finally:
            # Clean up the archive on both success and failure.
            if zip_path.exists():
                zip_path.unlink()

    def setup(self, language: str = "te") -> bool:
        """Load the IndicWhisper model for *language*, or the fallback Whisper.

        Resolves the compute device, tries the language-specific IndicWhisper
        checkpoint first (downloading it if needed), and falls back to base
        Whisper via faster-whisper. Returns True on success.
        """
        try:
            import torch

            # Resolve "auto" to cuda when a GPU is visible.
            if self.device_preference == "auto":
                self.device = "cuda" if torch.cuda.is_available() else "cpu"
            else:
                self.device = self.device_preference

            print(f"[{self.name}] Loading model on {self.device}...")

            lang_code = normalize_language_code(language)
            self._current_language = lang_code

            # Try IndicWhisper first
            if self.use_indicwhisper:
                model_path = self._get_model_path(lang_code)

                if model_path is None:
                    # Try to download
                    if self._download_model(lang_code):
                        model_path = self._get_model_path(lang_code)

                if model_path:
                    return self._setup_indicwhisper(model_path)

            # Fallback to base Whisper
            print(f"[{self.name}] Using base Whisper {self.fallback_model} as fallback")
            return self._setup_base_whisper()

        except ImportError as e:
            print(f"[{self.name}] Missing dependency: {e}")
            return False
        except Exception as e:
            print(f"[{self.name}] Setup error: {e}")
            import traceback
            traceback.print_exc()
            return False

    def _setup_indicwhisper(self, model_path: Path) -> bool:
        """Load the HuggingFace IndicWhisper checkpoint at *model_path*.

        On any failure, falls back to base Whisper instead of failing hard.
        """
        try:
            from transformers import WhisperForConditionalGeneration, WhisperProcessor

            print(f"[{self.name}] Loading IndicWhisper from {model_path}")

            self.processor = WhisperProcessor.from_pretrained(str(model_path))
            self.model = WhisperForConditionalGeneration.from_pretrained(str(model_path))
            self.model.to(self.device)
            self.model.eval()

            self._model_type = "indicwhisper"
            print(f"[{self.name}] IndicWhisper model loaded successfully")
            return True

        except Exception as e:
            print(f"[{self.name}] IndicWhisper setup failed: {e}")
            return self._setup_base_whisper()

    def _setup_base_whisper(self) -> bool:
        """Load the base Whisper fallback model via faster-whisper."""
        try:
            from faster_whisper import WhisperModel

            self.model = WhisperModel(
                self.fallback_model,
                device=self.device,
                # float16 only makes sense on GPU; int8 keeps CPU memory low.
                compute_type="float16" if self.device == "cuda" else "int8"
            )

            self._model_type = "base_whisper"
            print(f"[{self.name}] Base Whisper {self.fallback_model} loaded")
            return True

        except Exception as e:
            print(f"[{self.name}] Base Whisper setup failed: {e}")
            return False

    def _transcribe_indicwhisper(
        self,
        audio_path: str,
        language: str
    ) -> Dict[str, Any]:
        """Transcribe *audio_path* with the loaded IndicWhisper (HF) model.

        Returns a dict with keys "transcription", "alignments", "language"
        and "duration". Word alignments come from the tokenizer's decoded
        offsets when available, otherwise words are spread uniformly across
        the clip duration.
        """
        import torch
        import soundfile as sf
        import numpy as np

        # Load audio
        audio_data, sr = sf.read(audio_path)

        # Whisper models expect 16 kHz mono input; resample if needed.
        if sr != 16000:
            import torchaudio
            audio_tensor = torch.from_numpy(audio_data).float()
            if len(audio_tensor.shape) == 1:
                audio_tensor = audio_tensor.unsqueeze(0)
            resampler = torchaudio.transforms.Resample(sr, 16000)
            audio_tensor = resampler(audio_tensor)
            audio_data = audio_tensor.squeeze().numpy()
            sr = 16000

        # Extract log-mel features.
        inputs = self.processor(
            audio_data,
            sampling_rate=sr,
            return_tensors="pt"
        )
        input_features = inputs.input_features.to(self.device)

        # Generate with timestamps
        with torch.no_grad():
            predicted_ids = self.model.generate(
                input_features,
                return_timestamps=True,
                language=language,
            )

        # Decode
        transcription = self.processor.batch_decode(
            predicted_ids,
            skip_special_tokens=True
        )[0]

        # Segment-level alignments from decoded offsets, if available.
        # NOTE(review): depending on the transformers version,
        # decode(..., output_offsets=True) returns either a dict with an
        # "offsets" key (each entry carrying a "timestamp" (start, end)
        # tuple in seconds) or an object exposing .offsets — handle both.
        # The previous attribute-only check meant dict results silently
        # produced no alignments at all.
        word_alignments: List[WordAlignment] = []
        try:
            decoded = self.processor.decode(predicted_ids[0], output_offsets=True)
            if isinstance(decoded, dict):
                offsets = decoded.get("offsets") or []
            else:
                offsets = getattr(decoded, "offsets", None) or []
            for offset in offsets:
                timestamp = offset.get("timestamp")
                if timestamp:
                    # Timestamps already in seconds.
                    start_sec, end_sec = timestamp
                    start_sec = start_sec or 0.0
                    end_sec = end_sec or 0.0
                else:
                    # Older layout: raw sample offsets at 16 kHz.
                    start_sec = offset.get('start_offset', 0) / 16000
                    end_sec = offset.get('end_offset', 0) / 16000
                word_alignments.append(WordAlignment(
                    word=offset.get('text', '').strip(),
                    start_time=start_sec,
                    end_time=end_sec,
                    confidence=None
                ))
        except Exception:
            word_alignments = []

        duration = len(audio_data) / sr

        if not word_alignments:
            # Fallback: spread words uniformly across the clip so callers
            # always get a (rough) alignment for non-empty transcriptions.
            words = transcription.split()
            if words:
                step = duration / len(words)
                for i, word in enumerate(words):
                    word_alignments.append(WordAlignment(
                        word=word,
                        start_time=i * step,
                        end_time=(i + 1) * step,
                        confidence=None
                    ))

        return {
            "transcription": transcription.strip(),
            "alignments": word_alignments,
            "language": language,
            "duration": duration
        }

    def _transcribe_base_whisper(
        self,
        audio_path: str,
        language: str
    ) -> Dict[str, Any]:
        """Transcribe *audio_path* with the faster-whisper fallback model.

        Returns a dict with keys "transcription", "alignments", "language",
        "language_probability" and "duration".
        """
        lang_hint = self.WHISPER_LANG_MAP.get(language, language)

        segments, info = self.model.transcribe(
            audio_path,
            language=lang_hint,
            word_timestamps=True,
            beam_size=5
        )

        full_text = ""
        word_alignments: List[WordAlignment] = []

        for segment in segments:
            full_text += segment.text

            if hasattr(segment, 'words') and segment.words:
                for word in segment.words:
                    word_alignments.append(WordAlignment(
                        word=word.word.strip(),
                        start_time=word.start,
                        end_time=word.end,
                        confidence=word.probability if hasattr(word, 'probability') else None
                    ))

        return {
            "transcription": full_text.strip(),
            "alignments": word_alignments,
            "language": info.language,
            "language_probability": info.language_probability,
            "duration": info.duration
        }

    def validate(
        self,
        audio_path: str,
        reference_text: Optional[str] = None,
        language: str = "te"
    ) -> ValidationResult:
        """
        Transcribe audio using IndicWhisper/Vistaar.

        Args:
            audio_path: Path to audio file
            reference_text: Optional reference (not used, for comparison only)
            language: Language code

        Returns:
            ValidationResult with ASR output and word alignments
        """
        if not self.enabled:
            return ValidationResult(
                validator_name=self.name,
                audio_path=audio_path,
                success=False,
                error_message="Validator disabled"
            )

        start_time = time.time()
        lang_code = normalize_language_code(language)

        # Lazy-load, and reload when the requested language changes
        # (IndicWhisper checkpoints are per-language).
        if self._current_language != lang_code or self.model is None:
            if not self.setup(lang_code):
                return ValidationResult(
                    validator_name=self.name,
                    audio_path=audio_path,
                    success=False,
                    error_message="Model not loaded"
                )

        try:
            # Transcribe based on model type
            if self._model_type == "indicwhisper":
                result = self._transcribe_indicwhisper(audio_path, lang_code)
            else:
                result = self._transcribe_base_whisper(audio_path, lang_code)

            # Mean of per-word probabilities where the backend provides them
            # (only faster-whisper does); None otherwise.
            confidences = [
                wa.confidence for wa in result["alignments"]
                if wa.confidence is not None
            ]
            overall_confidence = sum(confidences) / len(confidences) if confidences else None

            processing_time = time.time() - start_time

            return ValidationResult(
                validator_name=self.name,
                audio_path=audio_path,
                success=True,
                transcription=result["transcription"],
                word_alignments=result["alignments"],
                overall_confidence=overall_confidence,
                processing_time_sec=processing_time,
                audio_duration_sec=result.get("duration"),
                raw_output={
                    "language": lang_code,
                    "model_type": self._model_type,
                    "detected_language": result.get("language"),
                    "language_probability": result.get("language_probability"),
                }
            )

        except Exception as e:
            import traceback
            traceback.print_exc()
            return ValidationResult(
                validator_name=self.name,
                audio_path=audio_path,
                success=False,
                error_message=str(e),
                processing_time_sec=time.time() - start_time
            )

    def cleanup(self):
        """Release model/processor references and free cached GPU memory.

        Must never raise: torch may not be importable when setup() never
        ran (or failed on a missing dependency).
        """
        if self.model is not None:
            del self.model
            self.model = None
        if self.processor is not None:
            del self.processor
            self.processor = None

        try:
            import torch
        except ImportError:
            return
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
