"""
IndicWav2Vec Validator
======================

Uses AI4Bharat's IndicWav2Vec models for ASR.
https://github.com/AI4Bharat/indicwav2vec

Supports two modes:
1. HuggingFace models (easier setup, community fine-tunes)
2. AI4Bharat fairseq models (requires fairseq==0.10.2)

For fairseq models, install:
    pip install fairseq==0.10.2 omegaconf==2.0.6 hydra-core==1.0.7
"""
import os
import time
from typing import Optional, List, Dict, Any
from pathlib import Path

from .base import (
    BaseValidator, 
    ValidationResult, 
    WordAlignment,
    ValidatorStatus,
    normalize_language_code
)


class IndicWav2VecValidator(BaseValidator):
    """
    Validator using IndicWav2Vec models.

    Default: Uses HuggingFace community models (reliable)
    Optional: Can use AI4Bharat fairseq models (requires specific fairseq version)

    The model is loaded lazily: ``validate`` (re)invokes ``setup`` whenever the
    requested language differs from the currently loaded one, so a single
    instance can serve multiple languages sequentially.
    """

    name = "indicwav2vec"
    description = "IndicWav2Vec - Fine-tuned Wav2Vec2 for Indian languages"

    # HuggingFace community models (reliable, work with transformers)
    HF_MODELS = {
        "te": "Harveenchadha/vakyansh-wav2vec2-telugu-tem-100",
        "hi": "theainerd/Wav2Vec2-large-xlsr-hindi",
        "bn": "ai4bharat/indicwav2vec_v1_bengali",
        "ta": "Harveenchadha/vakyansh-wav2vec2-tamil-tam-250",
        "kn": "Harveenchadha/vakyansh-wav2vec2-kannada-knm-560",
        "ml": "gvs/wav2vec2-large-xlsr-malayalam",
        "gu": "Harveenchadha/vakyansh-wav2vec2-gujarati-gnm-100",
        "mr": "Harveenchadha/vakyansh-wav2vec2-marathi-mrm-100",
    }

    # AI4Bharat fairseq model URLs (for reference; not loaded by this class)
    AI4BHARAT_URLS = {
        "te": "https://objectstore.e2enetworks.net/indic-superb/aaai_ckpts/models/te/te.pt",
        "hi": "https://objectstore.e2enetworks.net/indic-superb/aaai_ckpts/models/hi/hi.pt",
        "bn": "https://objectstore.e2enetworks.net/indic-superb/aaai_ckpts/models/bn/bn.pt",
        "ta": "https://objectstore.e2enetworks.net/indic-superb/aaai_ckpts/models/ta/ta.pt",
        "mr": "https://objectstore.e2enetworks.net/indic-superb/aaai_ckpts/models/mr/mr.pt",
        "gu": "https://objectstore.e2enetworks.net/indic-superb/aaai_ckpts/models/gu/gu.pt",
        "or": "https://objectstore.e2enetworks.net/indic-superb/aaai_ckpts/models/or/or.pt",
        "kn": "https://indic-asr-public.objectstore.e2enetworks.net/indic-superb/models/acoustic/kannada.pt",
        "ml": "https://indic-asr-public.objectstore.e2enetworks.net/indic-superb/models/acoustic/malayalam.pt",
    }

    def __init__(
        self,
        enabled: bool = True,
        device: str = "auto",
        language: str = "te",
        **kwargs
    ):
        """
        Initialize IndicWav2Vec validator.

        Args:
            enabled: Whether validator is active
            device: "cuda", "cpu", or "auto"
            language: Default language code
        """
        super().__init__(enabled=enabled, **kwargs)
        self.device_preference = device
        self.language = normalize_language_code(language)
        self.processor = None           # Wav2Vec2Processor, set by setup()
        self.model = None               # Wav2Vec2ForCTC, set by setup()
        self.device = None              # resolved device string ("cuda"/"cpu")
        self._model_name = None         # HF repo id of the loaded model
        self._current_language = None   # language the loaded model serves

    def setup(self, language: Optional[str] = None) -> bool:
        """Load the Wav2Vec2 model and processor.

        Args:
            language: Language code to load a model for; falls back to the
                instance default when None.

        Returns:
            True if the model and processor loaded successfully, else False.
        """
        try:
            import torch
            from transformers import Wav2Vec2ForCTC, Wav2Vec2Processor

            # Determine device
            if self.device_preference == "auto":
                self.device = "cuda" if torch.cuda.is_available() else "cpu"
            else:
                self.device = self.device_preference

            print(f"[{self.name}] Loading model on {self.device}...")

            lang_code = normalize_language_code(language or self.language)
            self._current_language = lang_code

            # Get HuggingFace model name
            self._model_name = self.HF_MODELS.get(lang_code)

            if not self._model_name:
                # Fallback to Telugu model
                print(f"[{self.name}] No model for {lang_code}, using Telugu")
                self._model_name = self.HF_MODELS["te"]

            print(f"[{self.name}] Using model: {self._model_name}")

            self.processor = Wav2Vec2Processor.from_pretrained(self._model_name)
            self.model = Wav2Vec2ForCTC.from_pretrained(self._model_name)
            self.model.to(self.device)
            self.model.eval()

            print(f"[{self.name}] Model loaded successfully")
            return True

        except Exception as e:
            print(f"[{self.name}] Setup error: {e}")
            import traceback
            traceback.print_exc()
            return False

    def _load_audio(self, audio_path: str, target_sr: int = 16000):
        """Load audio as a mono waveform resampled to target_sr.

        Args:
            audio_path: Path to the audio file (any format soundfile reads).
            target_sr: Desired sample rate in Hz.

        Returns:
            Tuple of (1-D float tensor waveform, target_sr).
        """
        import soundfile as sf
        import torch

        data, sample_rate = sf.read(audio_path)
        waveform = torch.from_numpy(data).float()

        # soundfile returns (frames,) for mono and (frames, channels) for
        # multi-channel audio; normalize both to (channels, frames).  The
        # transpose is required so the mono mixdown below averages channels,
        # not time samples.
        if waveform.dim() == 1:
            waveform = waveform.unsqueeze(0)
        else:
            waveform = waveform.t()

        # Mix down multi-channel audio to mono.
        if waveform.shape[0] > 1:
            waveform = waveform.mean(dim=0, keepdim=True)

        if sample_rate != target_sr:
            import torchaudio
            resampler = torchaudio.transforms.Resample(sample_rate, target_sr)
            waveform = resampler(waveform)

        # squeeze(0) (not squeeze()) so a 1-frame waveform stays 1-D.
        return waveform.squeeze(0), target_sr

    def _get_word_alignments(
        self,
        transcription: str,
        audio_duration: float
    ) -> List[WordAlignment]:
        """Generate approximate word alignments from a transcription.

        Timing is estimated proportionally to word character length, since
        CTC argmax decoding here does not produce frame-level timestamps.

        Args:
            transcription: Whitespace-separated decoded text.
            audio_duration: Total audio length in seconds.

        Returns:
            List of WordAlignment spanning [0, audio_duration]; empty if the
            transcription contains no words.
        """
        words = transcription.split()
        if not words:
            return []

        total_chars = sum(len(w) for w in words)
        if total_chars == 0:
            return []

        alignments = []
        current_time = 0.0

        for word in words:
            # Each word gets a share of the duration proportional to its length.
            word_duration = (len(word) / total_chars) * audio_duration
            alignments.append(WordAlignment(
                word=word,
                start_time=current_time,
                end_time=current_time + word_duration,
                confidence=None
            ))
            current_time += word_duration

        return alignments

    def validate(
        self,
        audio_path: str,
        reference_text: Optional[str] = None,
        language: str = "te"
    ) -> ValidationResult:
        """
        Process audio through IndicWav2Vec.

        Args:
            audio_path: Path to audio file
            reference_text: Optional reference transcription
            language: Language code

        Returns:
            ValidationResult with ASR output and alignments
        """
        if not self.enabled:
            return ValidationResult(
                validator_name=self.name,
                audio_path=audio_path,
                success=False,
                error_message="Validator disabled"
            )

        start_time = time.time()
        lang_code = normalize_language_code(language)

        # Check if we need to reload model for different language
        if self._current_language != lang_code or self.model is None:
            if not self.setup(lang_code):
                return ValidationResult(
                    validator_name=self.name,
                    audio_path=audio_path,
                    success=False,
                    error_message="Model not loaded"
                )

        try:
            import torch

            # Load audio
            waveform, sr = self._load_audio(audio_path)
            audio_duration = len(waveform) / sr

            # Process through model
            inputs = self.processor(
                waveform.numpy(),
                sampling_rate=sr,
                return_tensors="pt",
                padding=True
            )

            input_values = inputs.input_values.to(self.device)

            with torch.no_grad():
                outputs = self.model(input_values)
                logits = outputs.logits

            # Decode transcription via CTC argmax
            predicted_ids = torch.argmax(logits, dim=-1)
            transcription = self.processor.batch_decode(predicted_ids)[0]

            # Get alignments (character-proportional estimate, see helper)
            alignments = self._get_word_alignments(transcription, audio_duration)

            # Confidence = mean per-frame max softmax probability over the
            # full logit sequence (includes blank frames).
            probs = torch.softmax(logits, dim=-1)
            max_probs = probs.max(dim=-1).values
            overall_confidence = max_probs.mean().item()

            processing_time = time.time() - start_time

            return ValidationResult(
                validator_name=self.name,
                audio_path=audio_path,
                success=True,
                transcription=transcription,
                word_alignments=alignments,
                overall_confidence=overall_confidence,
                processing_time_sec=processing_time,
                audio_duration_sec=audio_duration,
                raw_output={
                    "language": lang_code,
                    "model": self._model_name,
                    "model_type": "huggingface"
                }
            )

        except Exception as e:
            import traceback
            traceback.print_exc()
            return ValidationResult(
                validator_name=self.name,
                audio_path=audio_path,
                success=False,
                error_message=str(e),
                processing_time_sec=time.time() - start_time
            )

    def cleanup(self):
        """Release model references and free cached GPU memory."""
        if self.model is not None:
            del self.model
            self.model = None
        if self.processor is not None:
            del self.processor
            self.processor = None

        import torch
        if torch.cuda.is_available():
            torch.cuda.empty_cache()