"""
Alignment Scorer
================

CTC-based scoring for transcription validation.
Provides word-level confidence scores and overall alignment quality.

This module uses Wav2Vec2 CTC models to:
1. Score how well transcription matches audio
2. Provide word-level confidence estimates
3. Calculate overall alignment quality

Usage:
    scorer = AlignmentScorer(language="te")
    result = scorer.score_transcription(audio_path, transcription)
    print(f"Score: {result['alignment_score']}")
    print(f"Word confidences: {result['word_scores']}")
"""
import os
import time
import torch
import numpy as np
import soundfile as sf
from dataclasses import dataclass, field
from typing import List, Dict, Optional, Any, Tuple
from pathlib import Path


@dataclass
class WordScore:
    """Confidence score and estimated timing for a single word.

    Timings come from a proportional, character-count-based split of the
    audio (see AlignmentScorer._align_words), not a true forced alignment,
    so treat start/end times as rough estimates.
    """
    word: str  # the word text exactly as it appears in the transcription
    start_time: float  # estimated start of the word in the audio, seconds
    end_time: float  # estimated end of the word in the audio, seconds
    confidence: float  # mean per-frame max probability over the word's frames (0-1)
    char_probs: List[float] = field(default_factory=list)  # reserved for per-character probabilities; currently never populated

@dataclass
class AlignmentResult:
    """Outcome of scoring one transcription against one audio file.

    On failure, ``method`` is set to "failed" or "error: ..." and all
    score fields are zero.
    """
    audio_path: str
    transcription: str

    # Overall scores
    alignment_score: float  # 0-1, how well text matches audio
    average_confidence: float  # Mean character confidence
    min_confidence: float  # Lowest character confidence

    # Word-level details
    word_scores: List[WordScore] = field(default_factory=list)

    # Timing
    audio_duration_sec: float = 0.0
    processing_time_sec: float = 0.0

    # Metadata
    method: str = "ctc"
    language: str = "te"

    def to_dict(self) -> Dict[str, Any]:
        """Serialize to a JSON-friendly dict with rounded numeric fields."""
        serialized_words = []
        for ws in self.word_scores:
            serialized_words.append({
                "word": ws.word,
                "start_time": round(ws.start_time, 3),
                "end_time": round(ws.end_time, 3),
                "confidence": round(ws.confidence, 4),
            })

        payload = {
            "audio_path": self.audio_path,
            "transcription": self.transcription,
            "alignment_score": round(self.alignment_score, 4),
            "average_confidence": round(self.average_confidence, 4),
            "min_confidence": round(self.min_confidence, 4),
            "word_count": len(self.word_scores),
            "word_scores": serialized_words,
            "audio_duration_sec": round(self.audio_duration_sec, 3),
            "processing_time_sec": round(self.processing_time_sec, 3),
            "method": self.method,
            "language": self.language,
        }
        return payload


class AlignmentScorer:
    """
    CTC-based alignment scorer for transcription validation.

    Uses a Wav2Vec2 CTC model to compute frame-level probabilities and
    derives an overall alignment quality score plus per-word confidence
    estimates.

    NOTE(review): word timings use a simple proportional character-count
    split of the audio rather than a true forced alignment, and the overall
    score is the mean per-frame max probability — it measures how "peaked"
    the acoustic model is on this audio, not how well the specific
    transcription text matches it. Treat scores as heuristics.
    """

    # Language code -> HuggingFace model id. Unknown languages fall back
    # to the multilingual "default" checkpoint.
    MODELS = {
        "te": "anuragshas/wav2vec2-large-xlsr-53-telugu",
        "hi": "theainerd/Wav2Vec2-large-xlsr-hindi",
        "ta": "Harveenchadha/wav2vec2-large-xlsr-53-tamil",
        "bn": "ai4bharat/indicwav2vec-hindi",  # Multilingual fallback
        "default": "facebook/wav2vec2-large-xlsr-53",
    }

    def __init__(
        self,
        language: str = "te",
        device: str = "auto",
        cache_dir: str = "./models/alignment"
    ):
        """
        Initialize alignment scorer. Model loading is deferred to setup(),
        which is called lazily on first use.

        Args:
            language: Language code (te, hi, ta, etc.). Lowercased and
                truncated to the first two characters.
            device: Device to use (auto, cuda, cpu). "auto" selects CUDA
                when available.
            cache_dir: Directory for the HuggingFace model cache
                (created if missing).
        """
        self.language = language.lower()[:2]
        self.cache_dir = Path(cache_dir)
        self.cache_dir.mkdir(parents=True, exist_ok=True)

        # Device selection
        if device == "auto":
            self.device = "cuda" if torch.cuda.is_available() else "cpu"
        else:
            self.device = device

        self.model = None
        self.processor = None
        self._setup_complete = False

    def setup(self) -> bool:
        """Load model and processor.

        Idempotent: returns immediately when already set up.

        Returns:
            True on success, False if loading failed (error is printed).
        """
        if self._setup_complete:
            return True

        try:
            from transformers import Wav2Vec2ForCTC, Wav2Vec2Processor

            # Fall back to the multilingual checkpoint for unknown languages.
            model_name = self.MODELS.get(self.language, self.MODELS["default"])
            print(f"[AlignmentScorer] Loading {model_name}...")

            self.processor = Wav2Vec2Processor.from_pretrained(
                model_name,
                cache_dir=self.cache_dir
            )
            self.model = Wav2Vec2ForCTC.from_pretrained(
                model_name,
                cache_dir=self.cache_dir
            )

            self.model.to(self.device)
            self.model.eval()

            print(f"[AlignmentScorer] Ready on {self.device}")
            self._setup_complete = True
            return True

        except Exception as e:
            print(f"[AlignmentScorer] Setup failed: {e}")
            import traceback
            traceback.print_exc()
            return False

    def _load_audio(self, audio_path: str) -> Tuple[np.ndarray, int]:
        """Load audio as mono float samples, resampled to 16 kHz.

        Args:
            audio_path: Path to any format soundfile can read.

        Returns:
            (samples, sample_rate) where sample_rate is always 16000
            after resampling.
        """
        data, sr = sf.read(audio_path)

        # Convert to mono if stereo by averaging channels.
        if data.ndim > 1:
            data = data.mean(axis=1)

        # Resample to the 16 kHz rate Wav2Vec2 expects.
        if sr != 16000:
            import torchaudio
            waveform = torch.from_numpy(data).float().unsqueeze(0)
            resampler = torchaudio.transforms.Resample(sr, 16000)
            waveform = resampler(waveform)
            data = waveform.squeeze().numpy()
            sr = 16000

        return data, sr

    def _get_ctc_logits(self, audio: np.ndarray) -> torch.Tensor:
        """Run the model and return CTC logits on CPU.

        Args:
            audio: Mono 16 kHz samples.

        Returns:
            Logits tensor, shape [1, frames, vocab_size].
        """
        inputs = self.processor(
            audio,
            sampling_rate=16000,
            return_tensors="pt",
            padding=True
        )

        input_values = inputs.input_values.to(self.device)

        with torch.no_grad():
            outputs = self.model(input_values)
            logits = outputs.logits.cpu()

        return logits

    def _compute_frame_probs(self, logits: torch.Tensor) -> np.ndarray:
        """Convert logits to per-frame probability distributions.

        Args:
            logits: [1, frames, vocab_size] tensor from _get_ctc_logits.

        Returns:
            Softmaxed probabilities as a [frames, vocab_size] array.
        """
        probs = torch.softmax(logits, dim=-1)
        return probs.squeeze(0).numpy()  # [frames, vocab_size]

    def _align_words(
        self,
        frame_probs: np.ndarray,
        transcription: str,
        audio_duration: float
    ) -> "Tuple[List[WordScore], float]":
        """
        Align words to frames and compute confidence scores.

        Each word receives a frame span proportional to its character
        count (minimum 1 frame); its confidence is the mean of the
        per-frame max probabilities over that span. Words that fall past
        the last frame get a neutral 0.5 confidence.

        Args:
            frame_probs: [frames, vocab_size] probabilities.
            transcription: Whitespace-separated words to align.
            audio_duration: Audio length in seconds (used for timings).

        Returns:
            (word_scores, overall_score) where overall_score is the mean
            per-frame max probability across ALL frames.
        """
        words = transcription.split()
        if not words:
            return [], 0.0

        num_frames = frame_probs.shape[0]
        frame_duration = audio_duration / num_frames

        # Per-frame confidence = probability of the most likely token.
        frame_confidences = frame_probs.max(axis=1)

        # Proportional word alignment by character count.
        total_chars = sum(len(w) for w in words)
        if total_chars == 0:
            return [], 0.0

        word_scores = []
        current_frame = 0

        for word in words:
            # Frames allotted to this word; truncation may leave a small
            # unassigned tail of frames at the end of the audio.
            word_frames = int((len(word) / total_chars) * num_frames)
            word_frames = max(1, word_frames)

            end_frame = min(current_frame + word_frames, num_frames)

            # Mean confidence over this word's frame span.
            if current_frame < num_frames:
                word_frame_probs = frame_confidences[current_frame:end_frame]
                word_confidence = float(word_frame_probs.mean()) if len(word_frame_probs) > 0 else 0.5
            else:
                # Ran out of frames: neutral default confidence.
                word_confidence = 0.5

            start_time = current_frame * frame_duration
            end_time = end_frame * frame_duration

            word_scores.append(WordScore(
                word=word,
                start_time=start_time,
                end_time=end_time,
                confidence=word_confidence,
                char_probs=[]
            ))

            current_frame = end_frame

        # Overall alignment score: mean frame confidence over all frames.
        overall_score = float(frame_confidences.mean())

        return word_scores, overall_score

    def score_transcription(
        self,
        audio_path: str,
        transcription: str
    ) -> "AlignmentResult":
        """
        Score a transcription against audio.

        Lazily loads the model on first call. Never raises: failures are
        reported through the result's ``method`` field ("failed" when the
        model could not be loaded, "error: ..." for runtime errors).

        Args:
            audio_path: Path to audio file
            transcription: Text to score

        Returns:
            AlignmentResult with scores (all zeros on failure).
        """
        if not self._setup_complete:
            if not self.setup():
                return AlignmentResult(
                    audio_path=audio_path,
                    transcription=transcription,
                    alignment_score=0.0,
                    average_confidence=0.0,
                    min_confidence=0.0,
                    method="failed",
                    language=self.language
                )

        start_time = time.time()

        try:
            # Load audio
            audio, sr = self._load_audio(audio_path)
            audio_duration = len(audio) / sr

            # Get CTC logits
            logits = self._get_ctc_logits(audio)

            # Compute frame probabilities
            frame_probs = self._compute_frame_probs(logits)

            # Align words and compute scores
            word_scores, alignment_score = self._align_words(
                frame_probs, transcription, audio_duration
            )

            # Confidence stats, cast to plain floats so the result (and
            # its to_dict() output) stays JSON-serializable instead of
            # carrying numpy scalars.
            confidences = [ws.confidence for ws in word_scores]
            avg_conf = float(np.mean(confidences)) if confidences else 0.0
            min_conf = float(np.min(confidences)) if confidences else 0.0

            processing_time = time.time() - start_time

            return AlignmentResult(
                audio_path=audio_path,
                transcription=transcription,
                alignment_score=alignment_score,
                average_confidence=avg_conf,
                min_confidence=min_conf,
                word_scores=word_scores,
                audio_duration_sec=audio_duration,
                processing_time_sec=processing_time,
                method="ctc",
                language=self.language
            )

        except Exception as e:
            import traceback
            traceback.print_exc()
            return AlignmentResult(
                audio_path=audio_path,
                transcription=transcription,
                alignment_score=0.0,
                average_confidence=0.0,
                min_confidence=0.0,
                processing_time_sec=time.time() - start_time,
                method=f"error: {str(e)}",
                language=self.language
            )

    def score_batch(
        self,
        items: List[Dict[str, str]]
    ) -> "List[AlignmentResult]":
        """
        Score multiple transcriptions sequentially.

        Args:
            items: List of {"audio_path": ..., "transcription": ...}

        Returns:
            List of AlignmentResult, one per input item, in order.
        """
        results = []
        for item in items:
            result = self.score_transcription(
                item["audio_path"],
                item["transcription"]
            )
            results.append(result)
        return results

    def cleanup(self):
        """Release model resources.

        The scorer remains usable afterwards: the next call to
        score_transcription() reloads the model via setup().
        """
        if self.model is not None:
            del self.model
            self.model = None
        if self.processor is not None:
            del self.processor
            self.processor = None

        # BUGFIX: reset the flag so a later call re-runs setup() instead
        # of invoking the now-released (None) model and crashing.
        self._setup_complete = False

        if torch.cuda.is_available():
            torch.cuda.empty_cache()


# === Convenience function ===
def score_transcription(
    audio_path: str,
    transcription: str,
    language: str = "te"
) -> Dict[str, Any]:
    """
    One-shot scoring: build a scorer, score once, release resources.

    Args:
        audio_path: Path to audio file
        transcription: Text to score
        language: Language code

    Returns:
        Dict with scores (see AlignmentResult.to_dict)
    """
    scorer = AlignmentScorer(language=language)
    scored = scorer.score_transcription(audio_path, transcription)
    scorer.cleanup()
    return scored.to_dict()


if __name__ == "__main__":
    import sys

    # CLI: first arg is the audio path, the rest form the transcription.
    if len(sys.argv) < 3:
        print("Usage: python alignment_scorer.py <audio_path> <transcription>")
        sys.exit(1)

    audio_path, transcription = sys.argv[1], " ".join(sys.argv[2:])

    scores = score_transcription(audio_path, transcription)

    print(f"\nAlignment Score: {scores['alignment_score']:.4f}")
    print(f"Average Confidence: {scores['average_confidence']:.4f}")
    print(f"Min Confidence: {scores['min_confidence']:.4f}")
    print(f"Word Count: {scores['word_count']}")
    print(f"Processing Time: {scores['processing_time_sec']:.3f}s")

    # Show at most the first 10 word-level scores.
    print("\nWord Scores:")
    for entry in scores['word_scores'][:10]:
        print(f"  {entry['word']}: {entry['confidence']:.4f} [{entry['start_time']:.2f}-{entry['end_time']:.2f}s]")
