"""
IndicMFA Validator
==================

Uses Montreal Forced Aligner with AI4Bharat's Indic acoustic models.
https://github.com/AI4Bharat/IndicMFA

MFA provides:
- Precise word and phone-level alignment
- Alignment confidence scores
- TextGrid output format

Requirements:
- Montreal Forced Aligner (mfa) installed via conda
- AI4Bharat acoustic models downloaded from releases
"""
import os
import time
import tempfile
import shutil
import subprocess
from typing import Optional, List, Dict, Any
from pathlib import Path

from .base import (
    BaseValidator,
    ValidationResult,
    WordAlignment,
    ValidatorStatus,
    normalize_language_code
)


class IndicMFAValidator(BaseValidator):
    """
    Validator using Montreal Forced Aligner with AI4Bharat Indic models.
    
    This validator performs forced alignment given:
    - Audio file
    - Reference transcription text
    
    Outputs word-level timestamps and an alignment score. The score is a
    coverage ratio (aligned words / whitespace-separated reference words);
    per-word confidence is not populated by the TextGrid parsers in this
    class and is always None.
    """
    
    # Identifier used in ValidationResult.validator_name and as the "[name]"
    # prefix of every console message printed by this class.
    name = "indicmfa"
    # Human-readable summary of the validator.
    description = "Montreal Forced Aligner with AI4Bharat Indic acoustic models"
    
    # Language code -> acoustic model name. Lookups use the value returned by
    # normalize_language_code(). The name is resolved against models_dir
    # (as a directory/file, or with a ".zip" suffix); when neither exists the
    # bare name is handed to MFA unchanged.
    # Models should be downloaded from: https://github.com/AI4Bharat/IndicMFA/releases
    ACOUSTIC_MODELS = {
        "te": "telugu_mfa",
        "hi": "hindi_mfa", 
        "kn": "kannada_mfa",
        "ta": "tamil_mfa",
        "ml": "malayalam_mfa",
        "bn": "bengali_mfa",
        "gu": "gujarati_mfa",
        "mr": "marathi_mfa",
        "pa": "punjabi_mfa",
        "or": "odia_mfa",
        "en": "english_mfa",
    }
    
    # Language code -> pronunciation dictionary (lexicon) name. Resolved
    # against models_dir; when no such file exists the bare name is handed to
    # MFA unchanged. The "en" entry has no file extension -- presumably the
    # name of a pretrained MFA dictionary rather than a local file (confirm).
    DICTIONARIES = {
        "te": "telugu_dict.txt",
        "hi": "hindi_dict.txt",
        "kn": "kannada_dict.txt",
        "ta": "tamil_dict.txt",
        "ml": "malayalam_dict.txt",
        "bn": "bengali_dict.txt",
        "gu": "gujarati_dict.txt", 
        "mr": "marathi_dict.txt",
        "pa": "punjabi_dict.txt",
        "or": "odia_dict.txt",
        "en": "english_us_arpa",
    }
    
    def __init__(
        self,
        enabled: bool = True,
        models_dir: Optional[str] = None,
        use_g2p: bool = True,
        beam_size: int = 10,
        **kwargs
    ):
        """
        Store configuration for the validator; no external tool is run here.

        Args:
            enabled: Whether validator is active
            models_dir: Directory searched for acoustic models and
                dictionaries. An empty or missing value selects
                ``~/mfa_models`` under the current user's home directory.
            use_g2p: When true, ``validate`` adds a G2P model option to the
                MFA command line
            beam_size: Value passed to MFA as ``--beam``
            **kwargs: Forwarded unchanged to the base-class initializer
        """
        super().__init__(enabled=enabled, **kwargs)

        if models_dir:
            self.models_dir = models_dir
        else:
            self.models_dir = os.path.expanduser("~/mfa_models")

        self.use_g2p = use_g2p
        self.beam_size = beam_size

        # Becomes True only after setup() has seen a working `mfa` binary.
        self.mfa_available = False
        
    def setup(self) -> bool:
        """
        Probe for the ``mfa`` executable and make sure the models directory exists.

        Runs ``mfa version`` (30 s timeout). On a zero exit status the models
        directory is created if missing, ``mfa_available`` is set to True and
        True is returned. A non-zero exit status, a missing executable, or any
        other exception is reported on stdout and yields False.
        """
        def report(message: str) -> None:
            # Every console line from this method carries the validator name.
            print(f"[{self.name}] {message}")

        install_hint = "Install with: conda install -c conda-forge montreal-forced-aligner"

        try:
            probe = subprocess.run(
                ["mfa", "version"],
                capture_output=True,
                text=True,
                timeout=30
            )

            if probe.returncode == 0:
                report(f"MFA version: {probe.stdout.strip()}")

                # A missing models directory is not fatal: create it and tell
                # the user where to obtain the models.
                if not os.path.exists(self.models_dir):
                    os.makedirs(self.models_dir, exist_ok=True)
                    report(f"Models directory created: {self.models_dir}")
                    report("Download models from: https://github.com/AI4Bharat/IndicMFA/releases")

                self.mfa_available = True
                report("Setup complete")
                return True

            report("MFA not found in PATH")
            report(install_hint)
            return False

        except FileNotFoundError:
            # Raised by subprocess when no `mfa` executable can be located.
            report("MFA not installed")
            report(install_hint)
            return False
        except Exception as exc:
            report(f"Setup error: {exc}")
            return False
    
    def _get_model_path(self, language: str) -> Optional[str]:
        """
        Resolve the acoustic model to hand to MFA for ``language``.

        Returns None when the normalized language code has no entry in
        ``ACOUSTIC_MODELS``. Otherwise returns the first existing path among
        ``<models_dir>/<name>`` and ``<models_dir>/<name>.zip``, or the bare
        model name when neither exists so MFA can resolve it itself.
        """
        model_name = self.ACOUSTIC_MODELS.get(normalize_language_code(language))
        if not model_name:
            return None

        local_model = os.path.join(self.models_dir, model_name)
        for candidate in (local_model, local_model + ".zip"):
            if os.path.exists(candidate):
                return candidate

        # Nothing on disk: let MFA look the name up among its own models.
        return model_name
    
    def _get_dictionary_path(self, language: str) -> Optional[str]:
        """
        Resolve the pronunciation dictionary to hand to MFA for ``language``.

        Returns None when the normalized language code has no entry in
        ``DICTIONARIES``. Otherwise returns ``<models_dir>/<name>`` if that
        path exists, or the bare dictionary name so MFA can resolve it itself.
        """
        code = normalize_language_code(language)
        dict_name = self.DICTIONARIES.get(code)
        if not dict_name:
            return None

        local_copy = os.path.join(self.models_dir, dict_name)
        # Prefer a file in models_dir; otherwise pass the name through to MFA.
        return local_copy if os.path.exists(local_copy) else dict_name
    
    def _parse_textgrid(self, textgrid_path: str) -> List[WordAlignment]:
        """
        Read word intervals from a TextGrid file.

        Uses the optional ``textgrid`` package when it can be imported and
        delegates to ``_parse_textgrid_manual`` otherwise. The tier named
        "words"/"word" (case-insensitive) is preferred; failing that, the
        first tier exposing ``intervals`` is used. Intervals whose label is
        empty or whitespace are skipped. Any other parsing error is printed
        and whatever was collected so far is returned.
        """
        word_spans = []

        try:
            # Optional dependency, hence the function-scope import.
            import textgrid

            grid = textgrid.TextGrid.fromFile(textgrid_path)

            chosen = next(
                (t for t in grid.tiers if t.name.lower() in ['words', 'word']),
                None,
            )
            if chosen is None and len(grid.tiers) > 0:
                # No tier with the expected name: take the first interval tier.
                chosen = next(
                    (t for t in grid.tiers if hasattr(t, 'intervals')),
                    None,
                )

            if not chosen:
                return word_spans

            for span in chosen.intervals:
                label = span.mark
                if not label:
                    continue
                word = label.strip()
                if not word:
                    continue
                word_spans.append(WordAlignment(
                    word=word,
                    start_time=span.minTime,
                    end_time=span.maxTime,
                    confidence=None
                ))

        except ImportError:
            # textgrid package not installed: use the regex-based reader.
            return self._parse_textgrid_manual(textgrid_path)
        except Exception as e:
            print(f"[{self.name}] TextGrid parsing error: {e}")

        return word_spans
    
    def _parse_textgrid_manual(self, textgrid_path: str) -> List[WordAlignment]:
        """
        Parse a long-format TextGrid with regular expressions only.

        The file is first split into tiers (sections introduced by
        ``item [n]:``) so that only ONE tier is read: the tier named
        "words"/"word" (case-insensitive) if present, otherwise the first
        ``IntervalTier``. Without this step the intervals of every tier
        (e.g. a phone tier next to the word tier) would all be returned as
        words. If no tier sections can be recognized at all, the whole file
        is scanned.

        Args:
            textgrid_path: Path to a UTF-8 encoded TextGrid file.

        Returns:
            One WordAlignment per non-blank interval of the selected tier,
            with ``confidence`` set to None. An empty list is returned (and
            the error printed) if the file cannot be read or parsed.
        """
        import re

        alignments = []

        try:
            with open(textgrid_path, 'r', encoding='utf-8') as f:
                content = f.read()

            # "item []:" (the list header) has no index and is not matched,
            # so every element after the first split piece is one tier.
            tier_sections = re.split(r'item\s*\[\d+\]\s*:', content)[1:]

            selected = None
            for section in tier_sections:
                # The tier's own name precedes its intervals, so the first
                # match in the section is the tier name.
                name_match = re.search(r'name\s*=\s*"([^"]*)"', section)
                if name_match and name_match.group(1).lower() in ('words', 'word'):
                    selected = section
                    break

            if selected is None:
                for section in tier_sections:
                    if re.search(r'class\s*=\s*"IntervalTier"', section):
                        selected = section
                        break

            if selected is None:
                # Unrecognized layout: fall back to scanning everything.
                selected = content

            # Interval entries are "xmin = .. xmax = .. text = "..""; the
            # tier-level xmin/xmax pair is not followed by `text` and so is
            # not matched.
            pattern = r'xmin\s*=\s*([\d.]+)\s*xmax\s*=\s*([\d.]+)\s*text\s*=\s*"([^"]*)"'

            for xmin, xmax, text in re.findall(pattern, selected):
                word = text.strip()
                if word:
                    alignments.append(WordAlignment(
                        word=word,
                        start_time=float(xmin),
                        end_time=float(xmax),
                        confidence=None
                    ))

        except Exception as e:
            print(f"[{self.name}] Manual TextGrid parsing error: {e}")

        return alignments
    
    def validate(
        self,
        audio_path: str,
        reference_text: Optional[str] = None,
        language: str = "te"
    ) -> ValidationResult:
        """
        Run MFA forced alignment on a single audio file.

        The audio and its transcript are copied into a throw-away corpus
        directory, ``mfa align`` is run with the model/dictionary resolved
        for ``language``, and the resulting TextGrid is parsed. If that first
        run exits non-zero, a second run with the English default dictionary
        and acoustic model is attempted.

        Args:
            audio_path: Path to audio file
            reference_text: Reference transcription (REQUIRED for MFA);
                empty or whitespace-only text is rejected
            language: Language code (normalized via normalize_language_code)

        Returns:
            ValidationResult. On success it carries the word alignments, an
            ``alignment_score`` equal to aligned words divided by the number
            of whitespace-separated reference words, and the end time of the
            last aligned word as ``audio_duration_sec``. All failures
            (disabled validator, MFA unavailable, missing text, timeout, any
            other exception) are reported as ``success=False`` with an
            ``error_message`` instead of being raised.
        """
        if not self.enabled:
            return self._mfa_failure(audio_path, "Validator disabled")

        if not self.ensure_setup():
            return self._mfa_failure(audio_path, "MFA not available")

        # Whitespace-only text gives MFA nothing to align, same as no text.
        if not reference_text or not reference_text.strip():
            return self._mfa_failure(
                audio_path, "Reference text required for forced alignment"
            )

        start_time = time.time()
        lang_code = normalize_language_code(language)

        # Created inside the try so that a failure while preparing the
        # workspace still reaches the cleanup in `finally`.
        temp_dir = None

        try:
            temp_dir = tempfile.mkdtemp(prefix="mfa_")
            corpus_dir = os.path.join(temp_dir, "corpus")
            output_dir = os.path.join(temp_dir, "output")
            os.makedirs(corpus_dir)
            os.makedirs(output_dir)

            # MFA pairs audio and transcript by shared file stem.
            audio_name = Path(audio_path).stem
            audio_ext = Path(audio_path).suffix
            shutil.copy2(
                audio_path, os.path.join(corpus_dir, f"{audio_name}{audio_ext}")
            )

            transcript_path = os.path.join(corpus_dir, f"{audio_name}.txt")
            with open(transcript_path, 'w', encoding='utf-8') as f:
                f.write(reference_text)

            model_path = self._get_model_path(lang_code)
            dict_path = self._get_dictionary_path(lang_code)

            cmd = [
                "mfa", "align",
                corpus_dir,
                dict_path or "english_us_arpa",  # Fallback for unmapped languages
                model_path or "english_mfa",     # Fallback for unmapped languages
                output_dir,
                "--clean",
                "--overwrite",
                f"--beam={self.beam_size}",
            ]

            if self.use_g2p:
                # NOTE(review): these G2P model names are assumed to be
                # resolvable by MFA; if they are not, the first run fails and
                # the English fallback below is used instead -- confirm.
                cmd.append("--g2p_model_path")
                cmd.append(f"{lang_code}_g2p" if lang_code != "en" else "english_g2p")

            print(f"[{self.name}] Running MFA alignment...")
            result = self._run_mfa_command(cmd)

            used_fallback = result.returncode != 0
            if used_fallback:
                print(f"[{self.name}] MFA failed with custom models, trying defaults...")
                result = self._run_mfa_command([
                    "mfa", "align",
                    corpus_dir,
                    "english_us_arpa",
                    "english_mfa",
                    output_dir,
                    "--clean",
                    "--overwrite"
                ])

            textgrid_path = self._find_textgrid(output_dir, audio_name)
            if textgrid_path is None:
                stderr_excerpt = result.stderr[:500] if result.stderr else 'No output'
                return self._mfa_failure(
                    audio_path, f"MFA failed: {stderr_excerpt}", start_time
                )

            alignments = self._parse_textgrid(textgrid_path)

            # Coverage: how many reference words received an interval.
            ref_words = len(reference_text.split())
            alignment_score = len(alignments) / ref_words if ref_words > 0 else 0

            # Approximation: the end of the last aligned word, not the true
            # length of the audio file.
            audio_duration = alignments[-1].end_time if alignments else 0

            return ValidationResult(
                validator_name=self.name,
                audio_path=audio_path,
                success=True,
                transcription=reference_text,
                word_alignments=alignments,
                alignment_score=alignment_score,
                processing_time_sec=time.time() - start_time,
                audio_duration_sec=audio_duration,
                raw_output={
                    "language": lang_code,
                    # Informational only: this file lives in the temporary
                    # workspace, which is deleted before this method returns.
                    "textgrid_path": textgrid_path,
                    "mfa_output": result.stdout[:500] if result.stdout else None,
                    # True when the English default models produced the
                    # alignment because the language-specific run failed.
                    "used_fallback_models": used_fallback,
                }
            )

        except subprocess.TimeoutExpired:
            return self._mfa_failure(audio_path, "MFA alignment timed out", start_time)
        except Exception as e:
            return self._mfa_failure(audio_path, str(e), start_time)
        finally:
            # Best-effort cleanup; ignore_errors avoids masking the result
            # without swallowing KeyboardInterrupt/SystemExit.
            if temp_dir is not None:
                shutil.rmtree(temp_dir, ignore_errors=True)

    def _mfa_failure(
        self,
        audio_path: str,
        message: str,
        started_at: Optional[float] = None
    ) -> ValidationResult:
        """Build a failed ValidationResult; elapsed time is added when ``started_at`` is given."""
        extra = {}
        if started_at is not None:
            extra["processing_time_sec"] = time.time() - started_at
        return ValidationResult(
            validator_name=self.name,
            audio_path=audio_path,
            success=False,
            error_message=message,
            **extra
        )

    def _run_mfa_command(self, cmd: List[str]) -> "subprocess.CompletedProcess":
        """Run one MFA command with captured text output and a 5 minute timeout."""
        return subprocess.run(
            cmd,
            capture_output=True,
            text=True,
            timeout=300
        )

    def _find_textgrid(self, output_dir: str, audio_name: str) -> Optional[str]:
        """
        Locate the TextGrid MFA wrote for ``audio_name`` under ``output_dir``.

        Preference order: ``<output_dir>/<audio_name>.TextGrid``, then a file
        with that exact name in any subdirectory, then the first
        ``*.TextGrid`` encountered in a sorted walk. Returns None when the
        tree contains no TextGrid at all.
        """
        expected = f"{audio_name}.TextGrid"
        direct = os.path.join(output_dir, expected)
        if os.path.exists(direct):
            return direct

        first_other = None
        for root, dirs, files in os.walk(output_dir):
            dirs.sort()  # deterministic traversal order
            for fname in sorted(files):
                if fname == expected:
                    return os.path.join(root, fname)
                if first_other is None and fname.endswith('.TextGrid'):
                    first_other = os.path.join(root, fname)
        return first_other
