"""
IndicMFA Validator
==================

Montreal Forced Aligner (MFA) with AI4Bharat Indic acoustic models.
https://github.com/AI4Bharat/IndicMFA

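Requirements:
- MFA installed via conda: `conda install -c conda-forge montreal-forced-aligner`
- Otherwise, a Wav2Vec2 CTC fallback is used (requires torch and transformers);
  it currently estimates word timings proportionally to word length rather
  than performing true CTC segmentation

Features:
- Word-level alignments (MFA also produces phone timings, but only the words
  tier is parsed here)
- Language-specific acoustic models for Indian languages
- Timing information parsed from MFA's TextGrid output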
"""
import os
import time
import tempfile
import subprocess
from typing import Optional, List, Dict, Any
from pathlib import Path

from .base import (
    BaseValidator,
    ValidationResult,
    WordAlignment,
    ValidatorStatus,
    normalize_language_code
)


class IndicMFAValidator(BaseValidator):
    """
    Validator using Montreal Forced Aligner with Indic models.

    Produces word-level alignments by forcing a known reference text onto the
    audio, so ``validate`` requires ``reference_text``.

    Back ends, in order of preference:

    * MFA: ``mfa align`` run as a subprocess when the ``mfa`` executable can
      be executed.
    * CTC fallback (``use_ctc_fallback=True``): a Wav2Vec2 CTC model loaded
      through ``transformers``. NOTE: word timings are currently spread over
      the clip in proportion to word length; the CTC logits only feed the
      overall confidence score.
    """

    name = "indicmfa"
    description = "Montreal Forced Aligner with AI4Bharat Indic acoustic models"

    # Hugging Face checkpoint loaded by the CTC fallback unless overridden
    # (Telugu-specific).
    DEFAULT_CTC_MODEL = "anuragshas/wav2vec2-large-xlsr-53-telugu"

    # MFA model names for Indian languages (to be downloaded via mfa).
    # Languages without an entry are aligned with the "en" models.
    MFA_MODELS = {
        "te": {"acoustic": "telugu_mfa", "dictionary": "telugu_mfa"},
        "hi": {"acoustic": "hindi_mfa", "dictionary": "hindi_mfa"},
        "ta": {"acoustic": "tamil_mfa", "dictionary": "tamil_mfa"},
        "bn": {"acoustic": "bengali_mfa", "dictionary": "bengali_mfa"},
        "en": {"acoustic": "english_mfa", "dictionary": "english_mfa"},
    }

    # IndicMFA model URLs (from AI4Bharat releases). Not used by the methods
    # of this class at present.
    INDICMFA_URLS = {
        "te": {
            "acoustic": "https://github.com/AI4Bharat/IndicMFA/releases/download/v1.0/telugu_acoustic_model.zip",
            "dictionary": "https://github.com/AI4Bharat/IndicMFA/releases/download/v1.0/telugu_dictionary.txt",
        },
        "hi": {
            "acoustic": "https://github.com/AI4Bharat/IndicMFA/releases/download/v1.0/hindi_acoustic_model.zip",
            "dictionary": "https://github.com/AI4Bharat/IndicMFA/releases/download/v1.0/hindi_dictionary.txt",
        },
    }

    def __init__(
        self,
        enabled: bool = True,
        models_dir: str = "./models/indicmfa",
        use_ctc_fallback: bool = True,  # Fall back to CTC alignment if MFA not available
        ctc_model_name: str = DEFAULT_CTC_MODEL,
        **kwargs
    ):
        """
        Initialize IndicMFA validator.

        Args:
            enabled: Whether validator is active
            models_dir: Directory for MFA models (stored on the instance; not
                read by the methods of this class)
            use_ctc_fallback: Use the Wav2Vec2 CTC fallback if MFA is unavailable
            ctc_model_name: Hugging Face checkpoint loaded by the CTC fallback
            **kwargs: Forwarded to ``BaseValidator.__init__``
        """
        super().__init__(enabled=enabled, **kwargs)
        self.models_dir = Path(models_dir)
        self.use_ctc_fallback = use_ctc_fallback
        self.ctc_model_name = ctc_model_name
        # Cached result of _check_mfa_available (None = not checked yet).
        self._mfa_available: Optional[bool] = None
        self._ctc_model = None
        self._ctc_processor = None

    def _check_mfa_available(self) -> bool:
        """
        Return whether the ``mfa`` CLI runs successfully, caching the answer.

        Runs ``mfa version`` once; exit status 0 means available. A binary
        that cannot be executed (any ``OSError``) or a 10 s timeout means
        unavailable.
        """
        if self._mfa_available is not None:
            return self._mfa_available

        try:
            result = subprocess.run(
                ["mfa", "version"],
                capture_output=True,
                text=True,
                timeout=10
            )
        except (subprocess.TimeoutExpired, OSError):
            # OSError covers FileNotFoundError (not installed) as well as
            # PermissionError and other exec failures.
            self._mfa_available = False
        else:
            self._mfa_available = result.returncode == 0
            if self._mfa_available:
                print(f"[{self.name}] MFA version: {result.stdout.strip()}")

        return self._mfa_available

    def setup(self) -> bool:
        """
        Prepare a back end: MFA if available, else the CTC fallback if enabled.

        Returns:
            True if some back end is ready, False otherwise.
        """
        print(f"[{self.name}] Checking MFA availability...")

        if self._check_mfa_available():
            print(f"[{self.name}] MFA is available")
            return True

        if self.use_ctc_fallback:
            print(f"[{self.name}] MFA not available, setting up CTC fallback")
            return self._setup_ctc_fallback()

        print(f"[{self.name}] MFA not available and no fallback enabled")
        print(f"[{self.name}] Install MFA: conda install -c conda-forge montreal-forced-aligner")
        return False

    def _setup_ctc_fallback(self) -> bool:
        """
        Load the Wav2Vec2 processor and CTC model named by ``ctc_model_name``.

        Places the model on CUDA when available (else CPU), switches it to
        eval mode and records the device in ``self._device``.

        Returns:
            True on success; False if loading fails (the error and traceback
            are printed).
        """
        try:
            import torch
            from transformers import Wav2Vec2ForCTC, Wav2Vec2Processor

            print(f"[{self.name}] Loading Wav2Vec2 for CTC segmentation...")

            model_name = self.ctc_model_name
            self._ctc_processor = Wav2Vec2Processor.from_pretrained(model_name)
            self._ctc_model = Wav2Vec2ForCTC.from_pretrained(model_name)

            device = "cuda" if torch.cuda.is_available() else "cpu"
            self._ctc_model.to(device)
            self._ctc_model.eval()
            self._device = device

            print(f"[{self.name}] CTC fallback ready on {device}")
            return True

        except Exception as e:
            print(f"[{self.name}] CTC fallback setup failed: {e}")
            import traceback
            traceback.print_exc()
            return False

    def _align_with_mfa(
        self,
        audio_path: str,
        reference_text: str,
        language: str
    ) -> List[WordAlignment]:
        """
        Force-align ``reference_text`` to the audio with ``mfa align``.

        A one-utterance corpus (``<stem>.wav`` + ``<stem>.txt``) is built in a
        temporary directory, MFA writes its TextGrid to a sibling output
        directory, and the words tier is parsed. Everything is deleted on
        return.

        Args:
            audio_path: Path to an audio file readable by ``soundfile``
            reference_text: Transcript written verbatim as the corpus text
            language: Language code; picks the MFA models (English if the
                language has no entry in ``MFA_MODELS``)

        Returns:
            Word alignments, or ``[]`` if MFA fails, times out, or writes no
            TextGrid (the reason is printed).

        Raises:
            Whatever ``soundfile`` raises when the input audio cannot be read.
        """
        import soundfile as sf

        lang_code = normalize_language_code(language)
        models = self.MFA_MODELS.get(lang_code)
        if models is None:
            print(f"[{self.name}] No MFA models for '{lang_code}', using English models")
            models = self.MFA_MODELS["en"]

        with tempfile.TemporaryDirectory() as temp_dir:
            temp_path = Path(temp_dir)
            # Keep MFA's output outside the corpus directory it scans.
            corpus_dir = temp_path / "corpus"
            output_dir = temp_path / "aligned"
            corpus_dir.mkdir()
            output_dir.mkdir()

            audio_name = Path(audio_path).stem

            # Re-encode so MFA always receives a .wav, whatever the input format.
            data, sr = sf.read(audio_path)
            sf.write(str(corpus_dir / f"{audio_name}.wav"), data, sr)

            # Transcript shares the audio's stem. Explicit UTF-8 so Indic
            # scripts do not depend on the platform's default encoding.
            (corpus_dir / f"{audio_name}.txt").write_text(reference_text, encoding="utf-8")

            cmd = [
                "mfa", "align",
                str(corpus_dir),
                models["dictionary"],
                models["acoustic"],
                str(output_dir),
                "--clean"
            ]

            try:
                result = subprocess.run(
                    cmd,
                    capture_output=True,
                    text=True,
                    timeout=300
                )
            except subprocess.TimeoutExpired:
                print(f"[{self.name}] MFA timeout")
                return []
            except Exception as e:
                print(f"[{self.name}] MFA error: {e}")
                return []

            if result.returncode != 0:
                print(f"[{self.name}] MFA error: {result.stderr}")
                return []

            textgrid_path = output_dir / f"{audio_name}.TextGrid"
            if not textgrid_path.exists():
                print(f"[{self.name}] MFA produced no TextGrid for '{audio_name}'")
                return []

            # Parse before the temporary directory is removed.
            return self._parse_textgrid(textgrid_path)

    def _parse_textgrid(self, textgrid_path: Path) -> List[WordAlignment]:
        """
        Extract word alignments from an MFA TextGrid.

        The words tier is found by the exact names ``words``, ``word`` or
        ``Words``. Failing that, the first tier whose name ends in "words"
        (case-insensitive) is used. That covers speaker-prefixed names such as
        ``"<speaker> - words"``; NOTE(review): naming assumed from MFA
        multi-speaker output, so confirm it for the MFA version in use. Blank
        intervals are skipped.

        Returns:
            Word alignments with confidence 1.0, or ``[]`` when praatio is
            missing, no words tier exists, or parsing fails (reason printed).
        """
        try:
            from praatio import textgrid
        except ImportError:
            print(f"[{self.name}] praatio not installed, install with: pip install praatio")
            return []

        try:
            tg = textgrid.openTextgrid(str(textgrid_path), includeEmptyIntervals=False)
            tier_names = list(tg.tierNames)

            words_tier = next(
                (n for n in ("words", "word", "Words") if n in tier_names),
                None,
            )
            if words_tier is None:
                words_tier = next(
                    (n for n in tier_names if n.strip().lower().endswith("words")),
                    None,
                )
            if words_tier is None:
                return []

            alignments = []
            for interval in tg.getTier(words_tier).entries:
                label = interval.label.strip()
                if label:
                    alignments.append(WordAlignment(
                        word=label,
                        start_time=interval.start,
                        end_time=interval.end,
                        confidence=1.0  # MFA alignments are deterministic
                    ))
            return alignments

        except Exception as e:
            print(f"[{self.name}] TextGrid parse error: {e}")
            return []

    def _align_with_ctc(
        self,
        audio_path: str,
        reference_text: str,
        language: str
    ) -> tuple:
        """
        Estimate word timings with the CTC fallback.

        Args:
            audio_path: Path to an audio file readable by ``soundfile``
            reference_text: Reference transcription, split on whitespace
            language: Language code (unused; the loaded CTC checkpoint fixes
                the language)

        Returns:
            ``(alignments, confidence, duration)``:
            - alignments: the words laid end to end across the clip, each
              spanning a share of the duration proportional to its character
              count (word confidence fixed at 0.8);
            - confidence: mean over frames of the max softmax probability of
              the CTC logits;
            - duration: audio length in seconds after resampling.
        """
        import torch
        import soundfile as sf
        import numpy as np

        data, sr = sf.read(audio_path)
        # soundfile returns (frames, channels) for multi-channel files; the
        # resampler and the processor below are given a 1-D mono signal.
        if data.ndim > 1:
            data = data.mean(axis=1)
        data = np.asarray(data, dtype=np.float32)

        if sr != 16000:
            import torchaudio
            waveform = torch.from_numpy(data).unsqueeze(0)  # (1, frames)
            resampler = torchaudio.transforms.Resample(sr, 16000)
            data = resampler(waveform).squeeze(0).numpy()
            sr = 16000

        duration = len(data) / sr

        # CTC logits: used only for the overall confidence below.
        device = getattr(self, '_device', 'cpu')
        inputs = self._ctc_processor(
            data,
            sampling_rate=sr,
            return_tensors="pt",
            padding=True
        )
        input_values = inputs.input_values.to(device)

        with torch.no_grad():
            logits = self._ctc_model(input_values).logits.cpu()

        # Simple proportional alignment (CTC segmentation would be better).
        words = reference_text.split()
        total_chars = sum(len(w) for w in words)
        alignments = []
        current_time = 0.0
        for word in words:
            word_duration = (len(word) / total_chars) * duration if total_chars > 0 else 0
            alignments.append(WordAlignment(
                word=word,
                start_time=current_time,
                end_time=current_time + word_duration,
                confidence=0.8  # Estimated
            ))
            current_time += word_duration

        probs = torch.softmax(logits, dim=-1)
        confidence = probs.max(dim=-1).values.mean().item()

        return alignments, confidence, duration

    def validate(
        self,
        audio_path: str,
        reference_text: Optional[str] = None,
        language: str = "te"
    ) -> ValidationResult:
        """
        Align audio with reference text.

        Args:
            audio_path: Path to audio file
            reference_text: Reference transcription (REQUIRED, must contain
                non-whitespace text)
            language: Language code

        Returns:
            ValidationResult with word alignments. ``success`` is True only
            when at least one word was aligned; otherwise
            ``overall_confidence`` is 0.0 and ``error_message`` explains why.
            Exceptions during alignment are caught and reported as a failed
            result.
        """
        if not self.enabled:
            return ValidationResult(
                validator_name=self.name,
                audio_path=audio_path,
                success=False,
                error_message="Validator disabled"
            )

        if not reference_text or not reference_text.strip():
            return ValidationResult(
                validator_name=self.name,
                audio_path=audio_path,
                success=False,
                error_message="Reference text required for forced alignment"
            )

        if not self.ensure_setup():
            return ValidationResult(
                validator_name=self.name,
                audio_path=audio_path,
                success=False,
                error_message="MFA/CTC not available"
            )

        start_time = time.time()
        lang_code = normalize_language_code(language)

        try:
            if self._check_mfa_available():
                alignments = self._align_with_mfa(audio_path, reference_text, lang_code)
                method = "mfa"
                # MFA reports no score; a successful alignment counts as fully confident.
                confidence = 1.0

                # Duration from the file header instead of decoding all samples again.
                import soundfile as sf
                info = sf.info(audio_path)
                duration = info.frames / info.samplerate
            else:
                alignments, confidence, duration = self._align_with_ctc(
                    audio_path, reference_text, lang_code
                )
                method = "ctc_fallback"

            # Don't report a confident result when nothing was aligned.
            failure_info = {}
            if not alignments:
                confidence = 0.0
                failure_info["error_message"] = f"{method} produced no word alignments"

            processing_time = time.time() - start_time

            return ValidationResult(
                validator_name=self.name,
                audio_path=audio_path,
                success=len(alignments) > 0,
                transcription=reference_text,  # Use reference as transcription
                word_alignments=alignments,
                overall_confidence=confidence,
                alignment_score=confidence,
                processing_time_sec=processing_time,
                audio_duration_sec=duration,
                raw_output={
                    "language": lang_code,
                    "method": method,
                    "word_count": len(alignments)
                },
                **failure_info
            )

        except Exception as e:
            import traceback
            traceback.print_exc()
            return ValidationResult(
                validator_name=self.name,
                audio_path=audio_path,
                success=False,
                error_message=str(e),
                processing_time_sec=time.time() - start_time
            )

    def cleanup(self):
        """
        Release the CTC model/processor and, if torch is installed and CUDA is
        available, empty the CUDA cache.

        Safe to call when torch is not installed (e.g. MFA-only setups).
        """
        if self._ctc_model is not None:
            del self._ctc_model
            self._ctc_model = None
        if self._ctc_processor is not None:
            del self._ctc_processor
            self._ctc_processor = None

        try:
            import torch
        except ImportError:
            return
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
