"""
CTC Forced Aligner
==================

Proper forced alignment using torchaudio's CTC forced_align.
This is equivalent to MFA but using neural CTC instead of GMM-HMM.

Provides:
- Per-character alignment scores (log-prob)
- Per-word alignment scores  
- Overall alignment quality score
- Identifies low-confidence regions

Usage:
    aligner = CTCForcedAligner(language="te")
    result = aligner.align("audio.flac", "transcription text")
    
    print(f"Alignment Score: {result.alignment_score}")  # 0-1
    print(f"Low confidence words: {result.low_confidence_words}")
"""
import torch
import torchaudio
import numpy as np
import soundfile as sf
from dataclasses import dataclass, field
from typing import List, Dict, Optional, Tuple
from pathlib import Path
import time


@dataclass
class WordAlignment:
    """Frame- and time-level alignment details for a single word."""

    word: str                  # the word as it appears in the transcript
    start_frame: int           # first emission frame assigned to this word
    end_frame: int             # one past the last assigned frame
    start_time: float          # start_frame expressed in seconds
    end_time: float            # end_frame expressed in seconds
    log_prob: float            # raw mean log probability over the word's frames
    confidence: float          # log_prob mapped into a 0-1 score
    char_scores: List[float] = field(default_factory=list)  # individual log-prob samples

    @property
    def is_low_confidence(self) -> bool:
        """True when the word's confidence falls below the 0.5 cutoff."""
        return self.confidence < 0.5

    @property
    def duration(self) -> float:
        """Length of the aligned span in seconds."""
        return self.end_time - self.start_time

@dataclass
class AlignmentResult:
    """Result of forced alignment for one (audio, transcription) pair."""

    audio_path: str
    transcription: str

    # Scores
    alignment_score: float  # 0-1, overall quality
    mean_log_prob: float  # Raw mean log prob

    # Per-word details
    word_alignments: List[WordAlignment] = field(default_factory=list)

    # Quality flags
    low_confidence_words: List[str] = field(default_factory=list)
    low_confidence_ratio: float = 0.0

    # Timing
    audio_duration_sec: float = 0.0
    processing_time_sec: float = 0.0

    def to_dict(self) -> Dict:
        """Serialize to a JSON-friendly dict with rounded numeric fields."""
        per_word = [
            {
                "word": wa.word,
                "start_time": round(wa.start_time, 3),
                "end_time": round(wa.end_time, 3),
                "confidence": round(wa.confidence, 4),
                "log_prob": round(wa.log_prob, 4),
            }
            for wa in self.word_alignments
        ]
        summary = {
            "audio_path": self.audio_path,
            "transcription": self.transcription,
            "alignment_score": round(self.alignment_score, 4),
            "mean_log_prob": round(self.mean_log_prob, 4),
            "word_count": len(self.word_alignments),
            "low_confidence_words": self.low_confidence_words,
            "low_confidence_ratio": round(self.low_confidence_ratio, 4),
            "word_alignments": per_word,
            "audio_duration_sec": round(self.audio_duration_sec, 3),
            "processing_time_sec": round(self.processing_time_sec, 3),
        }
        return summary


class CTCForcedAligner:
    """
    CTC-based forced alignment using torchaudio.
    
    This provides MFA-equivalent functionality using neural CTC models.
    """
    
    # Language-specific Wav2Vec2 models for CTC forced alignment
    # Each model trained on the target language for best phoneme coverage.
    # "default" (XLSR-53) is multilingual fallback for languages without dedicated models.
    MODELS = {
        "te": "anuragshas/wav2vec2-large-xlsr-53-telugu",
        "hi": "theainerd/Wav2Vec2-large-xlsr-hindi",
        "mr": "theainerd/Wav2Vec2-large-xlsr-hindi",       # Marathi: Devanagari, use Hindi model
        "ta": "Harveenchadha/wav2vec2-large-xlsr-53-tamil",
        "kn": "Harveenchadha/wav2vec2-large-xlsr-53-kannada",
        "ml": "gvs/wav2vec2-large-xlsr-malayalam",
        "bn": "arijitx/wav2vec2-large-xlsr-bengali",       # Fixed: was pointing to Hindi
        "as": "arijitx/wav2vec2-large-xlsr-bengali",       # Assamese: closest to Bengali
        "gu": "facebook/wav2vec2-large-xlsr-53",            # Gujarati: use multilingual XLSR
        "pa": "facebook/wav2vec2-large-xlsr-53",            # Punjabi: use multilingual XLSR
        "or": "facebook/wav2vec2-large-xlsr-53",            # Odia: use multilingual XLSR
        "en": "facebook/wav2vec2-large-960h-lv60-self",     # English: dedicated English model
        "default": "facebook/wav2vec2-large-xlsr-53",
    }
    
    def __init__(
        self,
        language: str = "te",
        device: str = "auto",
        cache_dir: str = "./models/ctc_aligner"
    ):
        self.language = language.lower()[:2]
        self.cache_dir = Path(cache_dir)
        self.cache_dir.mkdir(parents=True, exist_ok=True)
        
        if device == "auto":
            self.device = "cuda" if torch.cuda.is_available() else "cpu"
        else:
            self.device = device
        
        self.model = None
        self.processor = None
        self.vocab = None
        self._setup_complete = False
    
    def setup(self) -> bool:
        """Load model and processor."""
        if self._setup_complete:
            return True
        
        try:
            from transformers import Wav2Vec2ForCTC, Wav2Vec2Processor
            
            model_name = self.MODELS.get(self.language, self.MODELS["default"])
            print(f"[CTCAligner] Loading {model_name}...")
            
            self.processor = Wav2Vec2Processor.from_pretrained(
                model_name, cache_dir=self.cache_dir
            )
            self.model = Wav2Vec2ForCTC.from_pretrained(
                model_name, cache_dir=self.cache_dir
            )
            
            self.model.to(self.device)
            self.model.eval()
            
            self.vocab = self.processor.tokenizer.get_vocab()
            self.blank_id = self.processor.tokenizer.pad_token_id
            
            print(f"[CTCAligner] Ready on {self.device}")
            self._setup_complete = True
            return True
            
        except Exception as e:
            print(f"[CTCAligner] Setup failed: {e}")
            return False
    
    def _load_audio(self, audio_path: str) -> Tuple[np.ndarray, int]:
        """Load and resample audio to 16kHz."""
        data, sr = sf.read(audio_path)
        
        if len(data.shape) > 1:
            data = data.mean(axis=1)
        
        if sr != 16000:
            waveform = torch.from_numpy(data).float().unsqueeze(0)
            data = torchaudio.functional.resample(waveform, sr, 16000).squeeze().numpy()
            sr = 16000
        
        return data, sr
    
    def _get_emissions(self, audio: np.ndarray) -> torch.Tensor:
        """Get CTC log-probability emissions."""
        inputs = self.processor(
            audio, sampling_rate=16000, return_tensors="pt", padding=True
        )
        
        with torch.no_grad():
            logits = self.model(inputs.input_values.to(self.device)).logits
        
        emissions = torch.log_softmax(logits.cpu(), dim=-1)
        return emissions
    
    def _text_to_tokens(self, text: str) -> tuple:
        """Convert text to token IDs. Returns (tensor, oov_ratio).
        OOV chars are dropped but tracked. High OOV = unreliable score."""
        tokens = []
        total_chars = 0
        oov_chars = 0
        for char in text:
            if char == " ":
                if "|" in self.vocab:
                    tokens.append(self.vocab["|"])
            elif char in self.vocab:
                tokens.append(self.vocab[char])
                total_chars += 1
            else:
                oov_chars += 1
                total_chars += 1
        oov_ratio = oov_chars / max(total_chars, 1)
        if oov_ratio > 0.1:
            print(f"[CTCAligner] WARNING: {oov_ratio:.0%} OOV chars "
                  f"({oov_chars}/{total_chars})")
        return torch.tensor([tokens], dtype=torch.int32), oov_ratio

    def _compute_word_alignments(
        self,
        text: str,
        alignment: torch.Tensor,
        scores: torch.Tensor,
        frame_duration: float
    ) -> List[WordAlignment]:
        """Extract word-level alignments from frame-level results."""
        words = text.split()
        
        # Build character-to-word mapping
        char_to_word = []
        for word_idx, word in enumerate(words):
            for _ in word:
                char_to_word.append(word_idx)
            char_to_word.append(-1)  # Space/separator
        
        # Remove trailing separator
        if char_to_word and char_to_word[-1] == -1:
            char_to_word = char_to_word[:-1]
        
        # Aggregate scores per word
        word_scores = {i: [] for i in range(len(words))}
        word_frames = {i: [] for i in range(len(words))}
        
        alignment_np = alignment[0].numpy()
        scores_np = scores[0].numpy()
        
        char_idx = 0
        for frame_idx, (token_idx, score) in enumerate(zip(alignment_np, scores_np)):
            if token_idx != self.blank_id and char_idx < len(char_to_word):
                word_idx = char_to_word[char_idx]
                if word_idx >= 0:
                    word_scores[word_idx].append(score)
                    word_frames[word_idx].append(frame_idx)
                char_idx += 1
        
        # Create WordAlignment objects
        alignments = []
        for word_idx, word in enumerate(words):
            word_score_list = word_scores.get(word_idx, [])
            frame_list = word_frames.get(word_idx, [])
            
            if word_score_list:
                mean_log_prob = np.mean(word_score_list)
                # Convert log prob to 0-1 confidence
                # log_prob of 0 = prob of 1 (perfect)
                # log_prob of -5 = prob of ~0.007 (bad)
                confidence = np.exp(mean_log_prob)
                confidence = min(1.0, max(0.0, confidence))
                
                start_frame = min(frame_list)
                end_frame = max(frame_list) + 1
            else:
                mean_log_prob = -10.0
                confidence = 0.0
                start_frame = 0
                end_frame = 0
            
            alignments.append(WordAlignment(
                word=word,
                start_frame=start_frame,
                end_frame=end_frame,
                start_time=start_frame * frame_duration,
                end_time=end_frame * frame_duration,
                log_prob=float(mean_log_prob),
                confidence=float(confidence),
                char_scores=word_score_list
            ))
        
        return alignments
    
    def align(
        self,
        audio_path: str,
        transcription: str
    ) -> AlignmentResult:
        """
        Perform forced alignment.
        
        Args:
            audio_path: Path to audio file
            transcription: Text to align
            
        Returns:
            AlignmentResult with word-level alignments and scores
        """
        if not self._setup_complete:
            if not self.setup():
                return AlignmentResult(
                    audio_path=audio_path,
                    transcription=transcription,
                    alignment_score=0.0,
                    mean_log_prob=-10.0
                )
        
        start_time = time.time()
        
        try:
            # Load audio
            audio, sr = self._load_audio(audio_path)
            audio_duration = len(audio) / sr
            
            # Get emissions
            emissions = self._get_emissions(audio)
            num_frames = emissions.shape[1]
            frame_duration = audio_duration / num_frames
            
            # Convert text to tokens (returns tuple with OOV ratio)
            targets, oov_ratio = self._text_to_tokens(transcription)
            
            if targets.shape[1] == 0:
                return AlignmentResult(
                    audio_path=audio_path,
                    transcription=transcription,
                    alignment_score=0.0,
                    mean_log_prob=-10.0,
                    audio_duration_sec=audio_duration,
                    processing_time_sec=time.time() - start_time
                )
            
            # Run forced alignment
            input_lengths = torch.tensor([num_frames])
            target_lengths = torch.tensor([targets.shape[1]])
            
            alignment, scores = torchaudio.functional.forced_align(
                emissions,
                targets,
                input_lengths,
                target_lengths,
                blank=self.blank_id
            )
            
            # Compute word alignments
            word_alignments = self._compute_word_alignments(
                transcription, alignment, scores, frame_duration
            )
            
            # Overall scores
            mean_log_prob = scores.mean().item()
            
            # Convert to 0-1 score
            # Typical range: -0.5 (excellent) to -5.0 (poor)
            # Map to 0-1: score = 1 + (log_prob / 5)
            alignment_score = max(0.0, min(1.0, 1.0 + mean_log_prob / 5.0))
            # Penalize when many chars were OOV (alignment unreliable)
            if oov_ratio > 0.1:
                alignment_score *= (1.0 - oov_ratio * 0.5)
            
            # Find low confidence words
            low_conf_words = [wa.word for wa in word_alignments if wa.is_low_confidence]
            low_conf_ratio = len(low_conf_words) / max(len(word_alignments), 1)
            
            processing_time = time.time() - start_time
            
            return AlignmentResult(
                audio_path=audio_path,
                transcription=transcription,
                alignment_score=alignment_score,
                mean_log_prob=mean_log_prob,
                word_alignments=word_alignments,
                low_confidence_words=low_conf_words,
                low_confidence_ratio=low_conf_ratio,
                audio_duration_sec=audio_duration,
                processing_time_sec=processing_time
            )
            
        except Exception as e:
            import traceback
            traceback.print_exc()
            return AlignmentResult(
                audio_path=audio_path,
                transcription=transcription,
                alignment_score=0.0,
                mean_log_prob=-10.0,
                processing_time_sec=time.time() - start_time
            )
    
    def cleanup(self):
        """Release resources."""
        if self.model is not None:
            del self.model
            self.model = None
        if self.processor is not None:
            del self.processor
            self.processor = None
        
        if torch.cuda.is_available():
            torch.cuda.empty_cache()


# Convenience function
def forced_align(
    audio_path: str,
    transcription: str,
    language: str = "te"
) -> Dict:
    """Run a one-shot forced alignment and return the result as a dict."""
    worker = CTCForcedAligner(language=language)
    alignment_result = worker.align(audio_path, transcription)
    worker.cleanup()
    return alignment_result.to_dict()


if __name__ == "__main__":
    import sys

    # Need at least an audio path plus one transcription token.
    if len(sys.argv) < 3:
        print("Usage: python ctc_forced_aligner.py <audio_path> <transcription>")
        sys.exit(1)

    audio_file = sys.argv[1]
    text = " ".join(sys.argv[2:])

    report = forced_align(audio_file, text)

    banner = "=" * 60
    print(f"\n{banner}")
    print("FORCED ALIGNMENT RESULT")
    print(banner)
    print(f"Alignment Score: {report['alignment_score']:.4f}")
    print(f"Mean Log Prob: {report['mean_log_prob']:.4f}")
    print(f"Low Confidence Words: {report['low_confidence_words']}")
    print(f"Low Confidence Ratio: {report['low_confidence_ratio']:.2%}")

    print("\nWord Alignments:")
    for wa in report['word_alignments']:
        # 10-segment bar visualizing per-word confidence.
        bar = "█" * int(wa['confidence'] * 10)
        print(f"  {wa['word']:<20} {wa['confidence']:.3f} {bar} [{wa['start_time']:.2f}-{wa['end_time']:.2f}s]")