"""
IndicWav2Vec Validator
======================

Uses AI4Bharat's IndicWav2Vec fairseq models for ASR.
https://github.com/AI4Bharat/indicwav2vec

Supports:
- Language-specific fine-tuned fairseq models
- CTC decoding with character-level dictionary
- Pre-trained on 40 Indian languages
"""
import os
import time
from typing import Optional, List, Dict, Any
from pathlib import Path

from .base import (
    BaseValidator, 
    ValidationResult, 
    WordAlignment,
    ValidatorStatus,
    normalize_language_code
)


class IndicWav2VecValidator(BaseValidator):
    """
    Validator using AI4Bharat's IndicWav2Vec fairseq models.
    
    Uses the official AI4Bharat models from:
    https://github.com/AI4Bharat/indicwav2vec
    
    Per language, two artifacts are needed: a fairseq acoustic checkpoint
    (``<lang>.pt``) and a character dictionary (``dict.ltr.txt``).  Both are
    downloaded on demand into ``models_dir/<lang>/``.
    """
    
    name = "indicwav2vec"
    description = "AI4Bharat IndicWav2Vec - Fairseq ASR for Indian languages"
    
    # fairseq prepends <s>, <pad>, </s>, <unk> to the task dictionary, so
    # the entries of dict.ltr.txt start at output index 4.
    NUM_SPECIAL_TOKENS = 4
    
    # AI4Bharat model download URLs (fairseq format)
    MODEL_URLS = {
        "te": {
            "acoustic": "https://objectstore.e2enetworks.net/indic-superb/aaai_ckpts/models/te/te.pt",
            "dict": "https://objectstore.e2enetworks.net/indic-superb/aaai_ckpts/models/te/dict.ltr.txt",
        },
        "hi": {
            "acoustic": "https://objectstore.e2enetworks.net/indic-superb/aaai_ckpts/models/hi/hi.pt",
            "dict": "https://objectstore.e2enetworks.net/indic-superb/aaai_ckpts/models/hi/dict.ltr.txt",
        },
        "bn": {
            "acoustic": "https://objectstore.e2enetworks.net/indic-superb/aaai_ckpts/models/bn/bn.pt",
            "dict": "https://objectstore.e2enetworks.net/indic-superb/aaai_ckpts/models/bn/dict.ltr.txt",
        },
        "ta": {
            "acoustic": "https://objectstore.e2enetworks.net/indic-superb/aaai_ckpts/models/ta/ta.pt",
            "dict": "https://objectstore.e2enetworks.net/indic-superb/aaai_ckpts/models/ta/dict.ltr.txt",
        },
        "mr": {
            "acoustic": "https://objectstore.e2enetworks.net/indic-superb/aaai_ckpts/models/mr/mr.pt",
            "dict": "https://objectstore.e2enetworks.net/indic-superb/aaai_ckpts/models/mr/dict.ltr.txt",
        },
        "gu": {
            "acoustic": "https://objectstore.e2enetworks.net/indic-superb/aaai_ckpts/models/gu/gu.pt",
            "dict": "https://objectstore.e2enetworks.net/indic-superb/aaai_ckpts/models/gu/dict.ltr.txt",
        },
        "or": {
            "acoustic": "https://objectstore.e2enetworks.net/indic-superb/aaai_ckpts/models/or/or.pt",
            "dict": "https://objectstore.e2enetworks.net/indic-superb/aaai_ckpts/models/or/dict.ltr.txt",
        },
        "kn": {
            "acoustic": "https://indic-asr-public.objectstore.e2enetworks.net/indic-superb/models/acoustic/kannada.pt",
            "dict": "https://indic-asr-public.objectstore.e2enetworks.net/indic-superb/models/acoustic/kannada.dict.txt",
        },
        "ml": {
            "acoustic": "https://indic-asr-public.objectstore.e2enetworks.net/indic-superb/models/acoustic/malayalam.pt",
            "dict": "https://indic-asr-public.objectstore.e2enetworks.net/indic-superb/models/acoustic/malayalam.dict.txt",
        },
    }
    
    def __init__(
        self,
        enabled: bool = True,
        models_dir: str = "./models/indicwav2vec",
        device: str = "auto",
        language: str = "te",
        **kwargs
    ):
        """
        Initialize IndicWav2Vec validator.
        
        Args:
            enabled: Whether validator is active
            models_dir: Directory for AI4Bharat models
            device: "cuda", "cpu", or "auto"
            language: Default language code
        """
        super().__init__(enabled=enabled, **kwargs)
        self.models_dir = Path(models_dir)
        self.device_preference = device
        self.language = normalize_language_code(language)
        self.model = None               # loaded model, or "direct_mode" sentinel
        self.dictionary = None          # Dict[int, str]: output index -> character
        self.device = None              # resolved at setup() time
        self._current_language = None   # language the loaded model belongs to
        
    def _resolve_device(self) -> str:
        """Resolve the configured device preference to a concrete device string."""
        import torch
        if self.device_preference == "auto":
            return "cuda" if torch.cuda.is_available() else "cpu"
        return self.device_preference
    
    def _get_model_paths(self, language: str) -> Dict[str, Path]:
        """Get local paths to the acoustic checkpoint and dictionary for a language."""
        lang_code = normalize_language_code(language)
        lang_dir = self.models_dir / lang_code
        
        return {
            "acoustic": lang_dir / f"{lang_code}.pt",
            "dict": lang_dir / "dict.ltr.txt",
        }
    
    def _download_models(self, language: str) -> bool:
        """Download AI4Bharat models for a language.
        
        Skips files already on disk.  A failed download is removed so a
        partial file is never mistaken for a complete one on retry.
        
        Returns:
            True if all required files are present afterwards.
        """
        lang_code = normalize_language_code(language)
        
        if lang_code not in self.MODEL_URLS:
            print(f"[{self.name}] No AI4Bharat models for {lang_code}")
            return False
        
        try:
            import urllib.request
            
            urls = self.MODEL_URLS[lang_code]
            lang_dir = self.models_dir / lang_code
            lang_dir.mkdir(parents=True, exist_ok=True)
            
            paths = self._get_model_paths(lang_code)
            
            for file_type, url in urls.items():
                out_path = paths[file_type]
                
                if not out_path.exists():
                    print(f"[{self.name}] Downloading {file_type} for {lang_code}...")
                    try:
                        urllib.request.urlretrieve(url, out_path)
                    except Exception:
                        # Clean up any partially-written file before re-raising.
                        if out_path.exists():
                            out_path.unlink()
                        raise
                    print(f"[{self.name}] Downloaded: {out_path}")
            
            return True
            
        except Exception as e:
            print(f"[{self.name}] Download failed: {e}")
            return False
    
    def _load_dictionary(self, dict_path: Path) -> Dict[int, str]:
        """Load the fairseq character dictionary, mapping output index -> symbol.
        
        fairseq ``dict.ltr.txt`` files contain one ``<symbol> <count>`` pair
        per line, where the second column is a corpus *frequency*, not an
        index.  The model's output index for a symbol is its (valid) line
        position offset by the special tokens fairseq prepends when building
        the task dictionary (<s>, <pad>, </s>, <unk>).
        """
        dictionary: Dict[int, str] = {}
        idx = self.NUM_SPECIAL_TOKENS
        with open(dict_path, 'r', encoding='utf-8') as f:
            for line in f:
                parts = line.strip().split()
                if len(parts) >= 2:
                    dictionary[idx] = parts[0]
                    idx += 1
        return dictionary
        
    def setup(self, language: str = None) -> bool:
        """Load the IndicWav2Vec fairseq model for *language* (or the default).
        
        Downloads model files if needed, then builds the model through the
        fairseq task API.  Falls back to :meth:`_setup_direct` when fairseq
        task setup is unavailable or fails.
        
        Returns:
            True on success (model ready for inference), False otherwise.
        """
        try:
            import torch
            # Probe import: verifies the fairseq ASR stack is installed
            # before we attempt the full task-based load.
            from fairseq.models.wav2vec import Wav2VecCtc  # noqa: F401
            
            self.device = self._resolve_device()
            
            print(f"[{self.name}] Loading model on {self.device}...")
            
            lang_code = normalize_language_code(language or self.language)
            self._current_language = lang_code
            
            # Get model paths
            paths = self._get_model_paths(lang_code)
            
            # Download if needed
            if not paths["acoustic"].exists():
                if not self._download_models(lang_code):
                    return False
            
            # Load dictionary
            if paths["dict"].exists():
                self.dictionary = self._load_dictionary(paths["dict"])
                print(f"[{self.name}] Dictionary loaded: {len(self.dictionary)} characters")
            else:
                print(f"[{self.name}] Warning: Dictionary not found")
                self.dictionary = {}
            
            # Load fairseq checkpoint.  weights_only=False is required:
            # fairseq checkpoints pickle config dataclasses, not just tensors.
            print(f"[{self.name}] Loading checkpoint: {paths['acoustic']}")
            ckpt = torch.load(paths["acoustic"], map_location='cpu', weights_only=False)
            
            cfg = ckpt['cfg']
            model_state = ckpt['model']
            
            # Create model using fairseq
            from fairseq.tasks.audio_finetuning import AudioFinetuningTask
            from omegaconf import OmegaConf
            
            # Point the task at the directory containing dict.ltr.txt so it
            # can build its target dictionary.
            task_cfg = OmegaConf.create(cfg['task'])
            task_cfg.data = str(self.models_dir / lang_code)
            
            task = AudioFinetuningTask.setup_task(task_cfg)
            
            model_cfg = OmegaConf.create(cfg['model'])
            self.model = task.build_model(model_cfg)
            # strict=False: fine-tuned checkpoints may carry keys the freshly
            # built graph does not expect (and vice versa).
            self.model.load_state_dict(model_state, strict=False)
            self.model.to(self.device)
            self.model.eval()
            
            print(f"[{self.name}] Model loaded successfully")
            return True
            
        except ImportError as e:
            print(f"[{self.name}] Missing dependency: {e}")
            print(f"[{self.name}] Trying fallback method...")
            return self._setup_direct(language)
        except Exception as e:
            print(f"[{self.name}] Setup error: {e}")
            import traceback
            traceback.print_exc()
            return self._setup_direct(language)
    
    def _setup_direct(self, language: str = None) -> bool:
        """Fallback loading path that defers full model construction.
        
        Loads the checkpoint and dictionary, stores the raw state/config, and
        marks the validator as ``"direct_mode"``; the full fairseq model is
        then built lazily (and cached) on the first :meth:`validate` call.
        """
        try:
            import torch
            # Probe import: direct mode still needs fairseq at inference time,
            # so fail here rather than on the first validate() call.
            from fairseq.models.wav2vec.wav2vec2 import Wav2Vec2Model  # noqa: F401
            
            lang_code = normalize_language_code(language or self.language)
            self._current_language = lang_code
            
            # setup() may have bailed out before resolving the device.
            if self.device is None:
                self.device = self._resolve_device()
            
            paths = self._get_model_paths(lang_code)
            
            if not paths["acoustic"].exists():
                if not self._download_models(lang_code):
                    return False
            
            # Load dictionary
            if paths["dict"].exists():
                self.dictionary = self._load_dictionary(paths["dict"])
                print(f"[{self.name}] Dictionary: {len(self.dictionary)} chars")
            
            # Load checkpoint
            print(f"[{self.name}] Loading checkpoint directly...")
            ckpt = torch.load(paths["acoustic"], map_location='cpu', weights_only=False)
            
            # Keep raw state and config around for the deferred model build.
            self._model_state = ckpt['model']
            self._model_cfg = ckpt['cfg']
            
            print(f"[{self.name}] Direct loading complete")
            self.model = "direct_mode"
            return True
            
        except Exception as e:
            print(f"[{self.name}] Direct setup failed: {e}")
            import traceback
            traceback.print_exc()
            return False
    
    def _load_audio(self, audio_path: str, target_sr: int = 16000):
        """Load audio as a mono float waveform, resampled to *target_sr*.
        
        Returns:
            (waveform, sample_rate): 1-D float tensor and the target rate.
        """
        import soundfile as sf
        import torch
        
        data, sample_rate = sf.read(audio_path)
        waveform = torch.from_numpy(data).float()
        
        # soundfile returns (frames, channels) for multichannel audio;
        # average over the channel axis (dim 1) to get mono.
        if waveform.dim() == 2:
            waveform = waveform.mean(dim=1)
        
        if sample_rate != target_sr:
            import torchaudio
            resampler = torchaudio.transforms.Resample(sample_rate, target_sr)
            waveform = resampler(waveform.unsqueeze(0)).squeeze(0)
            
        return waveform, target_sr
    
    def _decode_ctc(self, emissions, blank_idx: int = 0) -> str:
        """Greedy CTC decode of frame emissions to text.
        
        Collapses consecutive repeats, drops the blank and the remaining
        fairseq special tokens, maps indices through the character
        dictionary, and converts the '|' word-boundary symbol to a space.
        """
        import torch
        
        # Best path; reshape(-1) (not squeeze) so a single-frame emission
        # still yields a list rather than a bare scalar.
        indices = torch.argmax(emissions, dim=-1).reshape(-1).tolist()
        
        # Remove consecutive duplicates and blanks
        prev_idx = None
        chars = []
        for idx in indices:
            if idx != prev_idx and idx != blank_idx:
                if idx in self.dictionary:
                    chars.append(self.dictionary[idx])
                elif idx >= self.NUM_SPECIAL_TOKENS:
                    # Out-of-dictionary, non-special index: keep a visible
                    # placeholder rather than dropping it silently.
                    chars.append(f"<{idx}>")
            prev_idx = idx
        
        # Join characters and convert | to space
        text = ''.join(chars)
        text = text.replace('|', ' ')
        
        return text.strip()
    
    def _get_word_alignments(
        self,
        transcription: str,
        audio_duration: float
    ) -> List[WordAlignment]:
        """Estimate word alignments by distributing *audio_duration*
        proportionally to each word's character length.
        
        This is a heuristic (no forced alignment), so confidence is None.
        """
        words = transcription.split()
        if not words:
            return []
            
        total_chars = sum(len(w) for w in words)
        if total_chars == 0:
            return []
            
        alignments = []
        current_time = 0.0
        
        for word in words:
            word_duration = (len(word) / total_chars) * audio_duration
            alignments.append(WordAlignment(
                word=word,
                start_time=current_time,
                end_time=current_time + word_duration,
                confidence=None
            ))
            current_time += word_duration
            
        return alignments
    
    def validate(
        self,
        audio_path: str,
        reference_text: Optional[str] = None,
        language: str = "te"
    ) -> ValidationResult:
        """
        Process audio through IndicWav2Vec.
        
        Reloads the model when the requested language differs from the one
        currently loaded.  In direct mode, the full fairseq model is built
        once on first use and cached for subsequent calls.
        
        Args:
            audio_path: Path to audio file
            reference_text: Optional reference transcription
            language: Language code
            
        Returns:
            ValidationResult with ASR output and alignments
        """
        if not self.enabled:
            return ValidationResult(
                validator_name=self.name,
                audio_path=audio_path,
                success=False,
                error_message="Validator disabled"
            )
        
        start_time = time.time()
        lang_code = normalize_language_code(language)
        
        # Check if we need to reload model for different language
        if self._current_language != lang_code or self.model is None:
            if not self.setup(lang_code):
                return ValidationResult(
                    validator_name=self.name,
                    audio_path=audio_path,
                    success=False,
                    error_message="Model not loaded"
                )
        
        try:
            import torch
            
            # Load audio
            waveform, sr = self._load_audio(audio_path)
            audio_duration = len(waveform) / sr
            
            if self.model == "direct_mode":
                # Deferred build: construct the fairseq model once and cache
                # it so later calls skip the expensive checkpoint reload.
                ckpt_path = self._get_model_paths(lang_code)["acoustic"]
                models, cfg, task = load_model_ensemble_and_task([str(ckpt_path)])
                model = models[0]
                model.to(self.device)
                model.eval()
                self.model = model
            
            # Run inference
            with torch.no_grad():
                source = waveform.unsqueeze(0).to(self.device)
                # All-False mask: the single-item batch has no padding.
                padding_mask = torch.zeros(source.shape[:2], dtype=torch.bool, device=self.device)
                
                # NOTE(review): some fairseq versions expect
                # get_logits(net_output) after a forward() call rather than
                # raw audio - confirm against the installed fairseq version.
                emissions = self.model.get_logits(source, padding_mask=padding_mask)
                transcription = self._decode_ctc(emissions.cpu())
            
            # Get alignments
            alignments = self._get_word_alignments(transcription, audio_duration)
            
            processing_time = time.time() - start_time
            
            return ValidationResult(
                validator_name=self.name,
                audio_path=audio_path,
                success=True,
                transcription=transcription,
                word_alignments=alignments,
                overall_confidence=None,
                processing_time_sec=processing_time,
                audio_duration_sec=audio_duration,
                raw_output={
                    "language": lang_code,
                    "model": f"indicwav2vec-{lang_code}",
                    "model_type": "fairseq"
                }
            )
            
        except Exception as e:
            import traceback
            traceback.print_exc()
            return ValidationResult(
                validator_name=self.name,
                audio_path=audio_path,
                success=False,
                error_message=str(e),
                processing_time_sec=time.time() - start_time
            )
    
    def cleanup(self):
        """Release the model and free GPU memory."""
        if self.model is not None and self.model != "direct_mode":
            del self.model
        self.model = None
        self.dictionary = None
            
        import torch
        if torch.cuda.is_available():
            torch.cuda.empty_cache()


# Helper function to load fairseq models
def load_model_ensemble_and_task(filenames):
    """Load fairseq CTC checkpoints and the task they were trained with.
    
    All checkpoints are assumed to share the configuration stored in the
    first one.  The task's ``data`` path is pointed at the first
    checkpoint's directory, where ``dict.ltr.txt`` is expected to live.
    
    Args:
        filenames: Non-empty list of checkpoint paths.
        
    Returns:
        (models, cfg, task): one built model per checkpoint, the shared
        checkpoint config, and the constructed AudioFinetuningTask.
    """
    import torch
    from fairseq.tasks.audio_finetuning import AudioFinetuningTask
    from omegaconf import OmegaConf
    
    # First checkpoint supplies the shared task/model configuration.
    # weights_only=False is required: fairseq checkpoints pickle config
    # dataclasses, not just tensors.
    state = torch.load(filenames[0], map_location='cpu', weights_only=False)
    cfg = state['cfg']
    
    # Setup task
    task_cfg = OmegaConf.create(cfg['task'])
    task_cfg.data = str(Path(filenames[0]).parent)
    task = AudioFinetuningTask.setup_task(task_cfg)
    
    model_cfg = OmegaConf.create(cfg['model'])
    
    # Build one model per checkpoint (true ensemble support; the common
    # single-checkpoint case behaves exactly as before).
    models = []
    for i, filename in enumerate(filenames):
        if i > 0:
            state = torch.load(filename, map_location='cpu', weights_only=False)
        model = task.build_model(model_cfg)
        # strict=False: fine-tuned checkpoints may carry keys the freshly
        # built graph does not expect.
        model.load_state_dict(state['model'], strict=False)
        models.append(model)
    
    return models, cfg, task
