"""
IndicConformer Validator
========================

Uses AI4Bharat's IndicConformer models for ASR.

Two model types supported:
1. NeMo-based language-specific models (ai4bharat/indicconformer_stt_*_hybrid_ctc_rnnt_large)
2. HuggingFace AutoModel multilingual (ai4bharat/indic-conformer-600m-multilingual)

Features:
- 600M parameter Conformer architecture
- Hybrid CTC-RNNT decoding
- Support for 22 Indian languages
"""
import os
import time
from typing import Optional, List, Dict, Any
from pathlib import Path

from .base import (
    BaseValidator,
    ValidationResult,
    WordAlignment,
    ValidatorStatus,
    normalize_language_code
)


class IndicConformerValidator(BaseValidator):
    """
    Validator using AI4Bharat's IndicConformer models.

    Supports:
    - Multilingual model (600M params, 22 languages) — HuggingFace AutoModel
      with an ONNX backend; language is selected per inference call.
    - Language-specific NeMo models — one checkpoint per language, gated
      access on HuggingFace.

    Neither model emits word timestamps, so word alignments are synthesized
    by spreading words uniformly across the clip duration.
    """

    name = "indic_conformer"
    description = "AI4Bharat IndicConformer - 600M Multilingual Indian ASR"

    # Language-specific NeMo models (require gated access)
    NEMO_MODELS = {
        "te": "ai4bharat/indicconformer_stt_te_hybrid_ctc_rnnt_large",
        "hi": "ai4bharat/indicconformer_stt_hi_hybrid_ctc_rnnt_large",
        "bn": "ai4bharat/indicconformer_stt_bn_hybrid_ctc_rnnt_large",
        "ta": "ai4bharat/indicconformer_stt_ta_hybrid_ctc_rnnt_large",
        "kn": "ai4bharat/indicconformer_stt_kn_hybrid_ctc_rnnt_large",
        "ml": "ai4bharat/indicconformer_stt_ml_hybrid_ctc_rnnt_large",
        "mr": "ai4bharat/indicconformer_stt_mr_hybrid_ctc_rnnt_large",
        "gu": "ai4bharat/indicconformer_stt_gu_hybrid_ctc_rnnt_large",
        "pa": "ai4bharat/indicconformer_stt_pa_hybrid_ctc_rnnt_large",
        "or": "ai4bharat/indicconformer_stt_or_hybrid_ctc_rnnt_large",
    }

    # Multilingual HuggingFace model (open access)
    MULTILINGUAL_MODEL = "ai4bharat/indic-conformer-600m-multilingual"

    def __init__(
        self,
        enabled: bool = True,
        device: str = "auto",
        language: str = "te",
        use_multilingual: bool = True,  # Prefer multilingual model (no gating)
        decoding: str = "rnnt",  # "ctc" or "rnnt"
        **kwargs
    ):
        """
        Initialize IndicConformer validator.

        Args:
            enabled: Whether validator is active
            device: "cuda", "cpu", or "auto"
            language: Default language code
            use_multilingual: Use multilingual model (recommended, no gating)
            decoding: Decoding strategy - "ctc" or "rnnt"
        """
        super().__init__(enabled=enabled, **kwargs)
        self.device_preference = device
        self.language = normalize_language_code(language)
        self.use_multilingual = use_multilingual
        self.decoding = decoding
        self.model = None              # loaded lazily in setup()
        self.device = None             # resolved from device_preference in setup()
        self._model_name = None        # HF repo id of the loaded model
        self._model_type = None        # "multilingual" or "nemo"
        self._current_language = None  # language the model was set up for

    def setup(self, language: str = None) -> bool:
        """
        Load an IndicConformer model for `language` (or the default).

        Tries the open multilingual model first (when enabled), then falls
        back to the gated language-specific NeMo model.

        Returns:
            True if a model was loaded successfully, False otherwise.
        """
        try:
            import torch

            # Determine device
            if self.device_preference == "auto":
                self.device = "cuda" if torch.cuda.is_available() else "cpu"
            else:
                self.device = self.device_preference

            print(f"[{self.name}] Loading model on {self.device}...")

            lang_code = normalize_language_code(language or self.language)
            self._current_language = lang_code

            # Try multilingual model first (no gating)
            if self.use_multilingual:
                try:
                    return self._setup_multilingual(lang_code)
                except Exception as e:
                    print(f"[{self.name}] Multilingual setup failed: {e}")
                    print(f"[{self.name}] Trying language-specific NeMo model...")

            # Fallback to NeMo language-specific model
            return self._setup_nemo(lang_code)

        except Exception as e:
            print(f"[{self.name}] Setup error: {e}")
            import traceback
            traceback.print_exc()
            return False

    def _setup_multilingual(self, language: str) -> bool:
        """Setup using HuggingFace multilingual model."""
        from transformers import AutoModel
        import warnings

        self._model_name = self.MULTILINGUAL_MODEL
        print(f"[{self.name}] Loading multilingual model: {self._model_name}")

        # Suppress ONNX CUDA warnings (model uses ONNX which may not have CUDA provider)
        with warnings.catch_warnings():
            warnings.filterwarnings("ignore", message=".*CUDAExecutionProvider.*")
            warnings.filterwarnings("ignore", message=".*FRAME_DURATION.*")

            # Load with trust_remote_code for custom model
            self.model = AutoModel.from_pretrained(
                self._model_name,
                trust_remote_code=True
            )

        # Note: Multilingual model uses ONNX runtime, device handling is internal
        self.model.eval()
        self._model_type = "multilingual"

        print(f"[{self.name}] Multilingual model loaded successfully (ONNX backend)")
        return True

    def _setup_nemo(self, language: str) -> bool:
        """Setup using NeMo language-specific model."""
        try:
            import nemo.collections.asr as nemo_asr
        except ImportError:
            print(f"[{self.name}] NeMo not installed. Install with: pip install nemo_toolkit[asr]")
            return False

        # Get model name for language
        self._model_name = self.NEMO_MODELS.get(language)

        if not self._model_name:
            print(f"[{self.name}] No NeMo model for {language}")
            return False

        print(f"[{self.name}] Loading NeMo model: {self._model_name}")

        try:
            # Load from HuggingFace (requires gated access)
            self.model = nemo_asr.models.ASRModel.from_pretrained(
                model_name=self._model_name
            )

            if self.device == "cuda":
                self.model = self.model.cuda()

            self.model.eval()
            self._model_type = "nemo"

            print(f"[{self.name}] NeMo model loaded successfully")
            return True

        except Exception as e:
            print(f"[{self.name}] NeMo model load failed: {e}")
            return False

    def _load_audio(self, audio_path: str, target_sr: int = 16000):
        """
        Load audio as a mono float tensor of shape (1, frames) at target_sr.

        Returns:
            (waveform, target_sr) tuple.
        """
        import soundfile as sf
        import torch
        import torchaudio

        data, sample_rate = sf.read(audio_path)
        waveform = torch.from_numpy(data).float()

        # soundfile returns (frames,) for mono and (frames, channels) for
        # multi-channel audio. Put channels on dim 0 before downmixing.
        # (Averaging dim 0 of a (frames, channels) tensor — as the previous
        # version did — collapses the *time* axis for stereo input.)
        if waveform.dim() == 1:
            waveform = waveform.unsqueeze(0)
        else:
            waveform = waveform.t().contiguous()
            if waveform.shape[0] > 1:
                waveform = waveform.mean(dim=0, keepdim=True)

        # Resample if needed
        if sample_rate != target_sr:
            resampler = torchaudio.transforms.Resample(sample_rate, target_sr)
            waveform = resampler(waveform)

        return waveform, target_sr

    @staticmethod
    def _uniform_alignments(words: List[str], duration: float) -> List[WordAlignment]:
        """
        Synthesize word alignments by spacing words evenly over `duration`.

        The underlying models do not emit timestamps, so this is a uniform
        approximation with no per-word confidence.
        """
        n = max(len(words), 1)
        return [
            WordAlignment(
                word=word,
                start_time=i * duration / n,
                end_time=(i + 1) * duration / n,
                confidence=None
            )
            for i, word in enumerate(words)
        ]

    def _transcribe_multilingual(self, audio_path: str, language: str) -> Dict[str, Any]:
        """Transcribe using multilingual HuggingFace model (ONNX backend)."""
        import torch

        # Load audio
        waveform, sr = self._load_audio(audio_path)
        duration = waveform.shape[1] / sr

        # Note: Multilingual model uses ONNX runtime internally
        # Audio stays on CPU as ONNX handles device placement

        # Transcribe using model's __call__ method
        # Format: model(audio_tensor, language_code, decoding_method)
        with torch.no_grad():
            transcription = self.model(waveform, language, self.decoding)

        # Handle output format
        if isinstance(transcription, (list, tuple)):
            transcription = transcription[0] if transcription else ""

        transcription = str(transcription)

        return {
            "transcription": transcription,
            "alignments": self._uniform_alignments(transcription.split(), duration),
            "duration": duration
        }

    def _transcribe_nemo(self, audio_path: str) -> Dict[str, Any]:
        """Transcribe using NeMo model."""
        # NeMo can transcribe directly from file path
        result = self.model.transcribe([audio_path])[0]

        # Newer NeMo releases return Hypothesis objects rather than plain
        # strings; extract the text field when present.
        if isinstance(result, str):
            transcription = result
        else:
            transcription = getattr(result, "text", str(result))

        # Get duration
        import soundfile as sf
        data, sr = sf.read(audio_path)
        duration = len(data) / sr

        return {
            "transcription": transcription,
            "alignments": self._uniform_alignments(transcription.split(), duration),
            "duration": duration
        }

    def validate(
        self,
        audio_path: str,
        reference_text: Optional[str] = None,
        language: str = "te"
    ) -> ValidationResult:
        """
        Transcribe audio using IndicConformer.

        Args:
            audio_path: Path to audio file
            reference_text: Optional reference (not used for ASR)
            language: Language code

        Returns:
            ValidationResult with ASR output
        """
        if not self.enabled:
            return ValidationResult(
                validator_name=self.name,
                audio_path=audio_path,
                success=False,
                error_message="Validator disabled"
            )

        start_time = time.time()
        lang_code = normalize_language_code(language)

        # Reload only when necessary: the multilingual model takes the
        # language as a per-call inference argument, so only per-language
        # NeMo models need a reload when the requested language changes.
        needs_setup = self.model is None or (
            self._model_type == "nemo" and self._current_language != lang_code
        )
        if needs_setup:
            if not self.setup(lang_code):
                return ValidationResult(
                    validator_name=self.name,
                    audio_path=audio_path,
                    success=False,
                    error_message="Model not loaded"
                )
        else:
            self._current_language = lang_code

        try:
            # Transcribe based on model type
            if self._model_type == "multilingual":
                result = self._transcribe_multilingual(audio_path, lang_code)
            else:
                result = self._transcribe_nemo(audio_path)

            processing_time = time.time() - start_time

            return ValidationResult(
                validator_name=self.name,
                audio_path=audio_path,
                success=True,
                transcription=result["transcription"],
                word_alignments=result["alignments"],
                overall_confidence=None,
                processing_time_sec=processing_time,
                audio_duration_sec=result.get("duration"),
                raw_output={
                    "language": lang_code,
                    "model": self._model_name,
                    "model_type": self._model_type,
                    "decoding": self.decoding
                }
            )

        except Exception as e:
            import traceback
            traceback.print_exc()
            return ValidationResult(
                validator_name=self.name,
                audio_path=audio_path,
                success=False,
                error_message=str(e),
                processing_time_sec=time.time() - start_time
            )

    def cleanup(self):
        """Release the model and free GPU memory; safe if torch is absent."""
        if self.model is not None:
            del self.model
            self.model = None

        # torch is an optional, lazily-imported dependency everywhere else in
        # this class — cleanup must not crash when it was never installed.
        try:
            import torch
        except ImportError:
            return

        if torch.cuda.is_available():
            torch.cuda.empty_cache()
