"""
IndicConformer Validator
========================

Uses AI4Bharat's IndicConformer ASR model.
https://github.com/AI4Bharat/IndicConformerASR
https://huggingface.co/ai4bharat/indic-conformer-600m-multilingual

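IndicConformer provides:
- State-of-the-art ASR for Indian languages
- 600M parameter multilingual model
- Support for 22 Indian languages
- Based on NVIDIA NeMo framework

If the model cannot be loaded with NeMo, this validator falls back to
HuggingFace Transformers and, as a last resort, to the Hindi-only
``ai4bharat/indicwav2vec-hindi`` checkpoint.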
"""
import os
import time
from typing import Optional, List, Dict, Any
from pathlib import Path

from .base import (
    BaseValidator,
    ValidationResult,
    WordAlignment,
    ValidatorStatus,
    normalize_language_code
)


class IndicConformerValidator(BaseValidator):
    """
    Validator using AI4Bharat's IndicConformer model.

    A Conformer-based ASR model trained on Indian languages. It is loaded
    through NVIDIA NeMo when possible. Otherwise the validator falls back to
    HuggingFace Transformers: first with the same model ID, then with the
    Hindi-only ``FALLBACK_MODEL_ID`` checkpoint.

    The ID that was actually loaded is exposed as ``loaded_model_id`` and is
    reported in every result's ``raw_output["model_id"]``.
    """

    name = "indic_conformer"
    description = "AI4Bharat IndicConformer 600M - Multilingual Indian ASR"

    # HuggingFace model ID
    MODEL_ID = "ai4bharat/indic-conformer-600m-multilingual"

    # Last-resort Transformers checkpoint, used when MODEL_ID cannot be loaded
    # in Transformers format. It is Hindi-only, whatever language is requested.
    FALLBACK_MODEL_ID = "ai4bharat/indicwav2vec-hindi"

    # Supported language codes. Codes outside this list only trigger a warning
    # in validate(); they are not rejected.
    SUPPORTED_LANGUAGES = [
        "as",  # Assamese
        "bn",  # Bengali
        "brx", # Bodo
        "doi", # Dogri
        "en",  # English  NOTE(review): not a scheduled Indian language; confirm model support
        "gu",  # Gujarati
        "hi",  # Hindi
        "kn",  # Kannada
        "kok", # Konkani
        "ks",  # Kashmiri
        "mai", # Maithili
        "ml",  # Malayalam
        "mni", # Manipuri
        "mr",  # Marathi
        "ne",  # Nepali
        "or",  # Odia
        "pa",  # Punjabi
        "sa",  # Sanskrit
        "sat", # Santali
        "sd",  # Sindhi
        "ta",  # Tamil
        "te",  # Telugu
        "ur",  # Urdu
    ]

    def __init__(
        self,
        enabled: bool = True,
        device: str = "auto",
        batch_size: int = 1,
        **kwargs
    ):
        """
        Initialize IndicConformer validator.

        Args:
            enabled: Whether validator is active
            device: "cuda", "cuda:N", "cpu", or "auto" (CUDA if available)
            batch_size: Batch size passed to NeMo's ``transcribe``
        """
        super().__init__(enabled=enabled, **kwargs)
        self.device_preference = device
        self.batch_size = batch_size
        self.model = None
        self.processor = None
        self.device = None
        self._use_nemo = False
        self._use_transformers = False
        # Model ID that was actually loaded (MODEL_ID or FALLBACK_MODEL_ID).
        self.loaded_model_id: Optional[str] = None

    def _reset_backend_state(self) -> None:
        """Forget any loaded model and which backend produced it."""
        self.model = None
        self.processor = None
        self._use_nemo = False
        self._use_transformers = False
        self.loaded_model_id = None

    def setup(self) -> bool:
        """
        Load IndicConformer, trying NeMo first and Transformers second.

        Returns:
            True if some model was loaded, False otherwise. Errors are printed,
            not raised.
        """
        # Start from a clean state. A repeated setup() must never route
        # inference to a backend that did not load on this attempt.
        self._reset_backend_state()
        try:
            import torch

            # Determine device
            if self.device_preference == "auto":
                self.device = "cuda" if torch.cuda.is_available() else "cpu"
            else:
                self.device = self.device_preference

            print(f"[{self.name}] Loading model on {self.device}...")

            # Try NeMo first (native format), then fall back to Transformers.
            if self._try_nemo_setup() or self._try_transformers_setup():
                return True

            print(f"[{self.name}] Could not load model with NeMo or Transformers")
            return False

        except Exception as e:
            print(f"[{self.name}] Setup error: {e}")
            return False

    def _try_nemo_setup(self) -> bool:
        """
        Try loading MODEL_ID with NVIDIA NeMo.

        Returns:
            True on success. False if NeMo is not installed or loading fails;
            in that case ``self.model`` is left untouched.
        """
        try:
            import torch
            import nemo.collections.asr as nemo_asr
        except ImportError:
            print(f"[{self.name}] NeMo not installed")
            print(f"[{self.name}] Install with: pip install nemo_toolkit[asr]")
            return False

        try:
            print(f"[{self.name}] Trying NeMo...")

            # ASRModel.from_pretrained instantiates whichever ASR class the
            # checkpoint declares (CTC, RNNT or hybrid) instead of forcing the
            # CTC-BPE class. map_location keeps the weights off the GPU when the
            # CPU was requested.
            model = nemo_asr.models.ASRModel.from_pretrained(
                self.MODEL_ID,
                map_location=torch.device(self.device),
            )
            # Move explicitly. This covers "cpu" as well as "cuda" / "cuda:N".
            model = model.to(self.device)
            model.eval()
        except Exception as e:
            print(f"[{self.name}] NeMo setup failed: {e}")
            return False

        self.model = model
        self._use_nemo = True
        self.loaded_model_id = self.MODEL_ID
        print(f"[{self.name}] NeMo model loaded successfully")
        return True

    def _try_transformers_setup(self) -> bool:
        """
        Try loading with HuggingFace Transformers.

        First tries MODEL_ID. If that is not available in Transformers format,
        falls back to the Hindi-only FALLBACK_MODEL_ID.

        Returns:
            True on success. False if Transformers is not installed or both
            loads fail.
        """
        try:
            from transformers import (
                AutoModelForCTC,
                AutoProcessor,
                Wav2Vec2Processor
            )
        except ImportError:
            print(f"[{self.name}] Transformers not installed")
            return False

        print(f"[{self.name}] Trying Transformers...")

        try:
            try:
                processor = AutoProcessor.from_pretrained(self.MODEL_ID)
                model = AutoModelForCTC.from_pretrained(self.MODEL_ID)
                model_id = self.MODEL_ID
            except Exception as primary_error:
                # The model might not be published in Transformers format.
                print(
                    f"[{self.name}] {self.MODEL_ID} not loadable with "
                    f"Transformers: {primary_error}"
                )
                model_id = self.FALLBACK_MODEL_ID
                print(f"[{self.name}] Trying fallback model: {model_id}")
                processor = Wav2Vec2Processor.from_pretrained(model_id)
                model = AutoModelForCTC.from_pretrained(model_id)

            model.to(self.device)
            model.eval()
        except Exception as e:
            print(f"[{self.name}] Transformers setup failed: {e}")
            return False

        self.processor = processor
        self.model = model
        self._use_transformers = True
        self.loaded_model_id = model_id
        print(f"[{self.name}] Transformers model loaded successfully")
        return True

    def _load_audio(self, audio_path: str, target_sr: int = 16000):
        """
        Load an audio file as a mono, 1-D waveform at ``target_sr``.

        Args:
            audio_path: Path to an audio file readable by torchaudio
            target_sr: Desired sample rate in Hz

        Returns:
            (waveform, sample_rate), where waveform is 1-D (num_samples,) and
            sample_rate == target_sr.
        """
        import torchaudio

        # torchaudio.load returns (channels, frames) by default.
        waveform, sample_rate = torchaudio.load(audio_path)

        # Downmix first, so that only a single channel has to be resampled.
        if waveform.shape[0] > 1:
            waveform = waveform.mean(dim=0, keepdim=True)

        # Resample if needed
        if sample_rate != target_sr:
            resampler = torchaudio.transforms.Resample(sample_rate, target_sr)
            waveform = resampler(waveform)
            sample_rate = target_sr

        # squeeze(0) rather than squeeze(): a 1-sample clip stays 1-D.
        return waveform.squeeze(0), sample_rate

    def _audio_duration(self, audio_path: str) -> float:
        """
        Return the duration of ``audio_path`` in seconds.

        Uses torchaudio's metadata reader when this torchaudio build provides
        one; otherwise it decodes the file.
        """
        import torchaudio

        info_fn = getattr(torchaudio, "info", None)
        if info_fn is not None:
            info = info_fn(audio_path)
            return info.num_frames / info.sample_rate
        waveform, sr = self._load_audio(audio_path)
        return waveform.shape[-1] / sr

    def _transcribe_nemo(self, audio_path: str) -> Dict[str, Any]:
        """
        Transcribe using the NeMo model.

        Word timestamps are not produced; that would need a separate
        alignment step.
        """
        output = self.model.transcribe([audio_path], batch_size=self.batch_size)
        # Older NeMo RNNT/hybrid models return (best_hypotheses, all_hypotheses).
        if isinstance(output, tuple):
            output = output[0]
        best = output[0]
        # NeMo 2.x returns Hypothesis objects (text in `.text`); 1.x returns str.
        transcription = getattr(best, "text", best)
        return {
            "transcription": str(transcription),
            "alignments": []
        }

    def _transcribe_transformers(
        self,
        audio_path: str,
        language: str
    ) -> Dict[str, Any]:
        """
        Transcribe using the Transformers CTC model.

        Args:
            audio_path: Path to audio file
            language: Normalized language code. Currently unused, because the
                loaded CTC model is not language-conditioned.

        Returns:
            Dict with "transcription", evenly spaced "alignments", a heuristic
            "confidence" and the audio "duration" in seconds.
        """
        import torch

        waveform, sr = self._load_audio(audio_path)
        audio_duration = len(waveform) / sr

        inputs = self.processor(
            waveform.numpy(),
            sampling_rate=sr,
            return_tensors="pt",
            padding=True
        )
        input_values = inputs.input_values.to(self.device)

        with torch.no_grad():
            logits = self.model(input_values).logits

        # Greedy CTC decode; batch_decode collapses repeats and blank tokens.
        predicted_ids = torch.argmax(logits, dim=-1)
        transcription = self.processor.batch_decode(predicted_ids)[0]

        # Heuristic confidence: the per-frame top-token probability, averaged
        # over all frames (blank frames included).
        confidence = torch.softmax(logits, dim=-1).amax(dim=-1).mean().item()

        # No real timestamps are available, so each word gets an equal share
        # of the audio duration.
        words = transcription.split()
        alignments = []
        if words:
            word_duration = audio_duration / len(words)
            alignments = [
                WordAlignment(
                    word=word,
                    start_time=i * word_duration,
                    end_time=(i + 1) * word_duration,
                    confidence=None
                )
                for i, word in enumerate(words)
            ]

        return {
            "transcription": transcription,
            "alignments": alignments,
            "confidence": confidence,
            "duration": audio_duration
        }

    def validate(
        self,
        audio_path: str,
        reference_text: Optional[str] = None,
        language: str = "te"
    ) -> ValidationResult:
        """
        Transcribe audio using IndicConformer.

        Args:
            audio_path: Path to audio file
            reference_text: Optional reference transcription (currently unused)
            language: Language code

        Returns:
            ValidationResult with the ASR output. Failures are reported through
            ``success=False`` and ``error_message``, not raised.
        """
        if not self.enabled:
            return ValidationResult(
                validator_name=self.name,
                audio_path=audio_path,
                success=False,
                error_message="Validator disabled"
            )

        # Also guard against a model that was released by cleanup().
        if not self.ensure_setup() or self.model is None:
            return ValidationResult(
                validator_name=self.name,
                audio_path=audio_path,
                success=False,
                error_message="Model not loaded"
            )

        start_time = time.time()
        lang_code = normalize_language_code(language)

        # Check language support
        if lang_code not in self.SUPPORTED_LANGUAGES:
            print(f"[{self.name}] Warning: {lang_code} may not be fully supported")
        if self.loaded_model_id == self.FALLBACK_MODEL_ID and lang_code != "hi":
            print(
                f"[{self.name}] Warning: fallback model {self.FALLBACK_MODEL_ID} "
                f"is Hindi-only; {lang_code} output may be unreliable"
            )

        try:
            if self._use_nemo:
                result = self._transcribe_nemo(audio_path)
            else:
                result = self._transcribe_transformers(audio_path, lang_code)

            processing_time = time.time() - start_time

            # The Transformers path reports the duration; NeMo does not.
            audio_duration = result.get("duration")
            if audio_duration is None:
                audio_duration = self._audio_duration(audio_path)

            return ValidationResult(
                validator_name=self.name,
                audio_path=audio_path,
                success=True,
                transcription=result["transcription"],
                word_alignments=result.get("alignments", []),
                overall_confidence=result.get("confidence"),
                processing_time_sec=processing_time,
                audio_duration_sec=audio_duration,
                raw_output={
                    "language": lang_code,
                    "model_type": "nemo" if self._use_nemo else "transformers",
                    "model_id": self.loaded_model_id or self.MODEL_ID
                }
            )

        except Exception as e:
            return ValidationResult(
                validator_name=self.name,
                audio_path=audio_path,
                success=False,
                error_message=str(e),
                processing_time_sec=time.time() - start_time
            )

    def cleanup(self):
        """
        Release the model and processor and free cached CUDA memory.

        Best-effort: this does nothing beyond dropping references if torch is
        not installed.
        """
        self._reset_backend_state()

        try:
            import torch
        except ImportError:
            return
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
