"""
IndicMFA Real Validator
=======================

Proper forced alignment using AI4Bharat's IndicMFA models.
Uses MFA (Montreal Forced Aligner) with G2G (grapheme-to-grapheme) dictionaries.

Provides:
- Character-level alignment timestamps
- Alignment quality metrics (log-likelihood, phone duration deviation)
- Per-word timestamps (reconstructed from the character alignments)

Requirements:
- MFA installed via conda/micromamba
- IndicMFA models downloaded from GitHub releases

Usage:
    validator = IndicMFARealValidator(language="te")
    result = validator.validate("audio.flac", "transcription text")
    print(f"Score: {result.alignment_score}")
"""
import csv
import difflib
import math
import os
import shutil
import subprocess
import tempfile
import time
from dataclasses import dataclass, field
from pathlib import Path
from typing import List, Dict, Optional, Tuple


@dataclass
class CharAlignment:
    """One aligned character: its label and the time span it occupies."""
    char: str          # label taken from the aligned TextGrid interval
    start_time: float  # interval start (seconds)
    end_time: float    # interval end (seconds)

    @property
    def duration(self) -> float:
        """Length of the interval, i.e. ``end_time - start_time``."""
        begin, finish = self.start_time, self.end_time
        return finish - begin


@dataclass
class WordAlignment:
    """A transcription word whose time span is derived from its aligned characters."""
    word: str
    start_time: float  # start of the word's first aligned character
    end_time: float    # end of the word's last aligned character
    # Character alignments making up this word (a fresh list per instance).
    char_alignments: List[CharAlignment] = field(default_factory=list)

    @property
    def duration(self) -> float:
        """Elapsed time from ``start_time`` to ``end_time``."""
        begin, finish = self.start_time, self.end_time
        return finish - begin


@dataclass
class IndicMFAResult:
    """Outcome of aligning one audio file against its transcription."""
    audio_path: str
    transcription: str

    # MFA quality metrics (from alignment_analysis.csv)
    overall_log_likelihood: float = 0.0
    speech_log_likelihood: float = 0.0
    phone_duration_deviation: float = 0.0
    snr: float = 0.0

    # Normalized alignment score (0-1)
    alignment_score: float = 0.0

    # Alignments
    char_alignments: List[CharAlignment] = field(default_factory=list)
    word_alignments: List[WordAlignment] = field(default_factory=list)

    # Timing
    audio_duration_sec: float = 0.0
    processing_time_sec: float = 0.0

    # Status
    success: bool = False
    error: Optional[str] = None

    def to_dict(self) -> Dict:
        """
        Serialize to a plain dict.

        Metrics are rounded to 4 decimals and times to 3 decimals. Only
        counts of the character alignments are included; words are listed
        with their timings.
        """
        def metric(value):
            return round(value, 4)

        def seconds(value):
            return round(value, 3)

        words = []
        for aligned in self.word_alignments:
            words.append({
                "word": aligned.word,
                "start_time": seconds(aligned.start_time),
                "end_time": seconds(aligned.end_time),
                "duration": seconds(aligned.duration),
            })

        payload = {}
        payload["audio_path"] = self.audio_path
        payload["transcription"] = self.transcription
        payload["alignment_score"] = metric(self.alignment_score)
        payload["overall_log_likelihood"] = metric(self.overall_log_likelihood)
        payload["speech_log_likelihood"] = metric(self.speech_log_likelihood)
        payload["phone_duration_deviation"] = metric(self.phone_duration_deviation)
        payload["snr"] = metric(self.snr)
        payload["word_count"] = len(self.word_alignments)
        payload["char_count"] = len(self.char_alignments)
        payload["word_alignments"] = words
        payload["audio_duration_sec"] = seconds(self.audio_duration_sec)
        payload["processing_time_sec"] = seconds(self.processing_time_sec)
        payload["success"] = self.success
        payload["error"] = self.error
        return payload


class IndicMFARealValidator:
    """
    Real IndicMFA validator using Montreal Forced Aligner.

    For each utterance, ``validate`` does the following:
    - converts the audio to 16 kHz mono WAV with ffmpeg;
    - writes a character-spaced transcript;
    - runs ``mfa align`` inside a micromamba environment, using the
      IndicMFA acoustic model and G2G dictionary;
    - parses the resulting TextGrid and ``alignment_analysis.csv``.

    Requires:
    - micromamba with MFA environment
    - IndicMFA models for the target language
    - ffmpeg on PATH, plus the ``soundfile`` and ``praatio`` packages
    """

    # Model paths (relative to project root)
    MODELS_DIR = "models/indicmfa"

    # Subprocess timeouts in seconds; override on a subclass or instance if needed.
    MFA_VERSION_TIMEOUT_SEC = 30
    FFMPEG_TIMEOUT_SEC = 30
    MFA_ALIGN_TIMEOUT_SEC = 120

    # Quality-metric columns read from MFA's alignment_analysis.csv.
    METRIC_KEYS = (
        "overall_log_likelihood",
        "speech_log_likelihood",
        "phone_duration_deviation",
        "snr",
    )

    # Available language models (code -> name used in the model file names)
    LANGUAGES = {
        "te": "Telugu",
        "hi": "Hindi",
        "ta": "Tamil",
        "kn": "Kannada",
        "ml": "Malayalam",
        "bn": "Bengali",
        "mr": "Marathi",
        "gu": "Gujarati",
        "pa": "Punjabi",
        "or": "Odia",
        "as": "Assamese",
        "ur": "Urdu",
    }

    def __init__(
        self,
        language: str = "te",
        project_root: Optional[str] = None,
        mfa_env: str = "mfa",
        mamba_root_prefix: Optional[str] = None,
    ):
        """
        Initialize IndicMFA validator.

        No files are checked here. ``setup()`` verifies that MFA and the
        models are present, and ``validate`` calls it automatically.

        Args:
            language: Language code (te, hi, ta, etc.). It is lower-cased and
                cut to its first two characters.
            project_root: Path to project root (contains models/, bin/).
                Defaults to three directory levels above this file.
            mfa_env: Name of micromamba environment with MFA
            mamba_root_prefix: Value for ``MAMBA_ROOT_PREFIX`` when running
                micromamba. Defaults to ``~/micromamba``.
        """
        self.language = language.lower()[:2]
        # Unknown codes fall back to the title-cased code for model file names.
        self.lang_name = self.LANGUAGES.get(self.language, self.language.title())

        if project_root is None:
            project_root = Path(__file__).parent.parent.parent
        self.project_root = Path(project_root)

        self.mfa_env = mfa_env
        self.mamba_root_prefix = mamba_root_prefix
        self.micromamba = self.project_root / "bin" / "micromamba"

        # Model paths
        self.models_dir = self.project_root / self.MODELS_DIR / self.language
        self.acoustic_model = self.models_dir / f"{self.lang_name}_Acoustic_Model.zip"
        self.dictionary = self.models_dir / f"{self.lang_name}_Dictionary_g2g.txt"

        self._setup_complete = False

    def _subprocess_env(self) -> Dict[str, str]:
        """Environment for micromamba calls: the current env plus MAMBA_ROOT_PREFIX."""
        prefix = self.mamba_root_prefix or str(Path.home() / "micromamba")
        return {**os.environ, "MAMBA_ROOT_PREFIX": str(prefix)}

    def _mfa_command(self, *args: str) -> List[str]:
        """Build a ``micromamba run -n <env> mfa <args...>`` command line."""
        return [str(self.micromamba), "run", "-n", self.mfa_env, "mfa", *args]

    @staticmethod
    def _stderr_tail(stderr: Optional[str], limit: int = 500) -> str:
        """Return the last ``limit`` characters of stderr, or 'Unknown error' if empty."""
        return stderr[-limit:] if stderr else "Unknown error"

    @staticmethod
    def _failure(
        audio_path: str,
        transcription: str,
        start_time: float,
        error: str,
        **extra,
    ) -> IndicMFAResult:
        """Build an unsuccessful result carrying ``error`` and the elapsed time."""
        return IndicMFAResult(
            audio_path=audio_path,
            transcription=transcription,
            error=error,
            processing_time_sec=time.time() - start_time,
            **extra,
        )

    def setup(self) -> bool:
        """
        Check that micromamba, MFA and the language models are available.

        Only a successful check is cached, so a failed check is retried on
        the next call.

        Returns:
            True if alignment can run. False otherwise, after printing the reason.
        """
        if self._setup_complete:
            return True

        # Check micromamba
        if not self.micromamba.exists():
            print(f"[IndicMFA] micromamba not found at {self.micromamba}")
            return False

        # Check MFA environment
        try:
            result = subprocess.run(
                self._mfa_command("version"),
                capture_output=True,
                text=True,
                errors="replace",
                timeout=self.MFA_VERSION_TIMEOUT_SEC,
                env=self._subprocess_env(),
            )
        except Exception as e:  # best-effort probe: any failure means "not available"
            print(f"[IndicMFA] Error checking MFA: {e}")
            return False
        if result.returncode != 0:
            print(f"[IndicMFA] MFA not available in {self.mfa_env} environment")
            return False
        print(f"[IndicMFA] MFA version: {result.stdout.strip()}")

        # Check models
        if not self.acoustic_model.exists():
            print(f"[IndicMFA] Acoustic model not found: {self.acoustic_model}")
            print("[IndicMFA] Download from: https://github.com/AI4Bharat/IndicMFA/releases")
            return False

        if not self.dictionary.exists():
            print(f"[IndicMFA] Dictionary not found: {self.dictionary}")
            return False

        print(f"[IndicMFA] Ready for {self.lang_name}")
        self._setup_complete = True
        return True

    def _preprocess_text(self, text: str) -> str:
        """
        Preprocess text for G2G alignment.

        Each character of a word becomes its own space-separated token, and
        words are joined with a double space. For example, ``"ab cd"``
        becomes ``"a b  c d"``. Splitting is per Unicode code point, so
        combining vowel signs and viramas become separate tokens. Word
        boundaries are not taken from MFA's output; they are recovered from
        the original text in ``_reconstruct_words``.
        """
        return "  ".join(" ".join(word) for word in text.split())

    def _reconstruct_words(
        self,
        original_text: str,
        char_alignments: List[CharAlignment]
    ) -> List[WordAlignment]:
        """
        Reconstruct word alignments from character alignments.

        Aligned labels are matched against the characters of
        ``original_text`` with ``difflib.SequenceMatcher``. A character
        missing from the alignments does not shift every following word.
        Such characters include ``<unk>`` tokens dropped by
        ``_parse_textgrid`` and anything MFA's text normalisation may strip.
        Runs of mismatching labels are paired positionally, so when no
        labels match at all this reduces to plain in-order assignment.

        Returns:
            One WordAlignment per word that received at least one aligned
            character, in transcription order.
        """
        words = original_text.split()
        # (word index, character) for every character of the transcription.
        expected = [(w_idx, ch) for w_idx, word in enumerate(words) for ch in word]
        per_word = [[] for _ in words]

        # autojunk=False: the default heuristic treats frequent items as junk
        # once the sequence reaches 200 elements, which would wreck matching
        # on long transcriptions.
        matcher = difflib.SequenceMatcher(
            None,
            [ch for _, ch in expected],
            [ca.char for ca in char_alignments],
            autojunk=False,
        )
        for tag, i1, i2, j1, j2 in matcher.get_opcodes():
            if tag not in ("equal", "replace"):
                # 'delete': expected char has no alignment; 'insert': extra label.
                continue
            for i, j in zip(range(i1, i2), range(j1, j2)):
                per_word[expected[i][0]].append(char_alignments[j])

        return [
            WordAlignment(
                word=word,
                start_time=chars[0].start_time,
                end_time=chars[-1].end_time,
                char_alignments=chars,
            )
            for word, chars in zip(words, per_word)
            if chars
        ]

    def _parse_textgrid(self, textgrid_path: Path) -> List[CharAlignment]:
        """
        Parse an MFA TextGrid into per-character alignments.

        Reads the ``words`` tier, falling back to ``phones`` if it is absent.
        With the character-spaced G2G transcript each entry is presumably a
        single character. Empty labels and ``<unk>`` are skipped, and labels
        are stored stripped of surrounding whitespace.

        Returns:
            The alignments in tier order. Returns an empty list, after
            printing the error, if ``praatio`` is unavailable or the file
            cannot be parsed.
        """
        try:
            from praatio import textgrid

            tg = textgrid.openTextgrid(str(textgrid_path), includeEmptyIntervals=False)

            for tier_name in ("words", "phones"):
                if tier_name not in tg.tierNames:
                    continue
                alignments = []
                for interval in tg.getTier(tier_name).entries:
                    label = interval.label.strip()
                    if label and label != "<unk>":
                        alignments.append(CharAlignment(
                            char=label,
                            start_time=interval.start,
                            end_time=interval.end
                        ))
                return alignments
            return []

        except Exception as e:
            print(f"[IndicMFA] TextGrid parse error: {e}")
            return []

    def _parse_analysis_csv(self, csv_path: Path) -> Dict[str, float]:
        """
        Parse alignment_analysis.csv for quality metrics.

        Only the first data row is read, because the corpus holds a single
        utterance. A metric from ``METRIC_KEYS`` is included only when its
        column holds a finite number. Missing, empty, non-numeric and NaN/inf
        values are left out, so ``_compute_alignment_score`` falls back to its
        worst-case defaults instead of treating them as 0 (which would score
        as perfect).

        Returns:
            Mapping of metric name to value. It is empty if the file cannot
            be read or has no data rows.
        """
        metrics = {}
        try:
            with open(csv_path, "r", encoding="utf-8", newline="") as f:
                row = next(csv.DictReader(f), None)
        except Exception as e:
            print(f"[IndicMFA] CSV parse error: {e}")
            return metrics

        if row is None:
            return metrics

        for key in self.METRIC_KEYS:
            raw = row.get(key)
            if not raw:  # column missing (None) or empty cell
                continue
            try:
                value = float(raw)
            except ValueError:
                print(f"[IndicMFA] CSV parse error: non-numeric {key}={raw!r}")
                continue
            if math.isfinite(value):
                metrics[key] = value
        return metrics

    @staticmethod
    def _finite_or(value, default: float) -> float:
        """Return ``value`` as a float if it is a finite number, else ``default``."""
        try:
            number = float(value)
        except (TypeError, ValueError):
            return default
        return number if math.isfinite(number) else default

    def _compute_alignment_score(self, metrics: Dict) -> float:
        """
        Compute normalized alignment score (0-1) from MFA metrics.

        The score is based on two metrics:
        - speech_log_likelihood: higher (less negative) is better. The range
          -100..0 is mapped linearly to 0..1.
        - phone_duration_deviation: lower is better. The range 0..5 is mapped
          linearly to 1..0.

        Both are clamped to [0, 1] and combined as 0.7 * sll + 0.3 * pdd.
        Missing or non-finite metrics use the worst-case defaults (-100 and
        10), so each contributes 0.
        """
        sll = self._finite_or(metrics.get("speech_log_likelihood"), -100.0)
        pdd = self._finite_or(metrics.get("phone_duration_deviation"), 10.0)

        sll_score = max(0.0, min(1.0, (sll + 100) / 100))
        pdd_score = max(0.0, min(1.0, 1 - pdd / 5))

        return 0.7 * sll_score + 0.3 * pdd_score

    def validate(
        self,
        audio_path: str,
        transcription: str
    ) -> IndicMFAResult:
        """
        Run IndicMFA alignment and return quality metrics.

        Expected failures are reported as ``success=False`` with ``error``
        set rather than raised. These include an empty transcription, MFA not
        set up, ffmpeg or MFA errors, a missing TextGrid, and unexpected
        exceptions during alignment.

        Args:
            audio_path: Path to an audio file readable by ffmpeg.
            transcription: Text to align (whitespace-separated words).

        Returns:
            IndicMFAResult with alignment metrics.
        """
        start_time = time.time()

        if not (transcription and transcription.strip()):
            return self._failure(audio_path, transcription, start_time, "Empty transcription")

        if not self.setup():
            return self._failure(audio_path, transcription, start_time, "MFA not set up")

        with tempfile.TemporaryDirectory() as temp_dir:
            temp_path = Path(temp_dir)
            corpus_dir = temp_path / "corpus"
            output_dir = temp_path / "output"
            corpus_dir.mkdir()
            output_dir.mkdir()

            try:
                # Convert audio to 16 kHz mono WAV for MFA.
                wav_path = corpus_dir / "audio.wav"
                ffmpeg = subprocess.run(
                    ["ffmpeg", "-y", "-i", str(audio_path), "-ar", "16000", "-ac", "1", str(wav_path)],
                    capture_output=True,
                    text=True,
                    errors="replace",
                    timeout=self.FFMPEG_TIMEOUT_SEC,
                )
                if ffmpeg.returncode != 0 or not wav_path.exists():
                    return self._failure(
                        audio_path, transcription, start_time,
                        f"ffmpeg failed: {self._stderr_tail(ffmpeg.stderr)}",
                    )

                # Duration from the header; no need to decode all samples.
                import soundfile as sf
                info = sf.info(str(wav_path))
                audio_duration = info.frames / info.samplerate

                # Write the character-spaced transcript explicitly as UTF-8.
                # Indic scripts cannot be encoded in many locale default
                # encodings (e.g. cp1252).
                transcript_path = corpus_dir / "audio.txt"
                transcript_path.write_text(self._preprocess_text(transcription), encoding="utf-8")

                align = subprocess.run(
                    self._mfa_command(
                        "align",
                        "--clean",
                        "--single_speaker",
                        str(corpus_dir),
                        str(self.dictionary),
                        str(self.acoustic_model),
                        str(output_dir),
                    ),
                    capture_output=True,
                    text=True,
                    errors="replace",
                    timeout=self.MFA_ALIGN_TIMEOUT_SEC,
                    env=self._subprocess_env(),
                )
                if align.returncode != 0:
                    return self._failure(
                        audio_path, transcription, start_time,
                        f"MFA failed: {self._stderr_tail(align.stderr)}",
                    )

                # MFA can exit 0 without aligning the utterance; without a
                # TextGrid there is nothing to report as a success.
                textgrid_path = output_dir / "audio.TextGrid"
                if not textgrid_path.exists():
                    return self._failure(
                        audio_path, transcription, start_time,
                        "MFA produced no TextGrid for this utterance",
                        audio_duration_sec=audio_duration,
                    )

                analysis_path = output_dir / "alignment_analysis.csv"
                char_alignments = self._parse_textgrid(textgrid_path)
                metrics = self._parse_analysis_csv(analysis_path) if analysis_path.exists() else {}
                word_alignments = self._reconstruct_words(transcription, char_alignments)

                return IndicMFAResult(
                    audio_path=audio_path,
                    transcription=transcription,
                    overall_log_likelihood=metrics.get("overall_log_likelihood", 0.0),
                    speech_log_likelihood=metrics.get("speech_log_likelihood", 0.0),
                    phone_duration_deviation=metrics.get("phone_duration_deviation", 0.0),
                    snr=metrics.get("snr", 0.0),
                    alignment_score=self._compute_alignment_score(metrics),
                    char_alignments=char_alignments,
                    word_alignments=word_alignments,
                    audio_duration_sec=audio_duration,
                    processing_time_sec=time.time() - start_time,
                    success=True
                )

            except Exception as e:
                import traceback
                traceback.print_exc()
                return self._failure(audio_path, transcription, start_time, str(e))


# Convenience function
def indicmfa_validate(
    audio_path: str,
    transcription: str,
    language: str = "te"
) -> Dict:
    """
    Align one file in a single call and return the result as a plain dict.

    This is shorthand for
    ``IndicMFARealValidator(language=language).validate(...).to_dict()``.
    A new validator, and therefore a fresh setup check, is created on every
    call.
    """
    validator = IndicMFARealValidator(language=language)
    return validator.validate(audio_path, transcription).to_dict()


if __name__ == "__main__":
    # CLI entry point: align one file (Telugu by default) and print a summary.
    # Exits with status 1 on bad usage or when alignment fails.
    import sys

    if len(sys.argv) < 3:
        print("Usage: python indicmfa_real.py <audio_path> <transcription>")
        sys.exit(1)

    audio_path = sys.argv[1]
    # Every argument after the audio path is joined back into the
    # transcription, so the text may be passed unquoted.
    transcription = " ".join(sys.argv[2:])

    result = indicmfa_validate(audio_path, transcription)

    print(f"\n{'='*60}")
    print("INDICMFA ALIGNMENT RESULT")
    print(f"{'='*60}")
    print(f"Success: {result['success']}")
    print(f"Alignment Score: {result['alignment_score']:.4f}")
    print(f"Speech Log-Likelihood: {result['speech_log_likelihood']:.2f}")
    print(f"Phone Duration Deviation: {result['phone_duration_deviation']:.2f}")
    print(f"SNR: {result['snr']:.2f}")
    print(f"Processing Time: {result['processing_time_sec']:.2f}s")

    if result.get('error'):
        print(f"Error: {result['error']}")
    else:
        print(f"\nWord Alignments ({result['word_count']} words):")
        for wa in result['word_alignments'][:10]:  # show at most 10 words
            print(f"  {wa['word']:<20} [{wa['start_time']:.2f}-{wa['end_time']:.2f}s] ({wa['duration']:.3f}s)")

    # A non-zero exit status lets shell callers detect failed alignments.
    sys.exit(0 if result["success"] else 1)
