"""
MMS Forced Aligner (Romanized Text)
====================================

Language-agnostic forced alignment using Meta's MMS_FA model via torchaudio.
Works on romanized/Latin-script text. Complements the native CTC aligner.

MMS_FA is trained on 1000+ languages — handles romanized Indic + English equally well.
Used alongside native CTC to produce dual validation scores.

Usage:
    aligner = MMSAligner()
    result = aligner.align("audio.flac", "naaku konni ads gurtuntaayi")
    print(f"Score: {result.alignment_score}")  # 0-1
"""
import time
import unicodedata
from dataclasses import dataclass, field
from typing import List, Dict, Optional

import numpy as np
import soundfile as sf
import torch
import torchaudio


@dataclass
class MMSAlignmentResult:
    """Outcome of aligning one romanized transcription against one audio file.

    ``alignment_score`` is the overall confidence in [0, 1] and
    ``mean_log_prob`` the raw mean log-probability behind it. ``word_scores``
    (added in v4) holds one [0, 1] confidence per word, for stricter
    per-word threshold checks.
    """
    audio_path: str
    transcription: str

    alignment_score: float  # overall confidence, 0-1
    mean_log_prob: float    # raw mean log prob

    # v4: one confidence per word, for stricter threshold logic
    word_scores: List[float] = field(default_factory=list)

    num_tokens: int = 0
    num_frames: int = 0
    audio_duration_sec: float = 0.0
    processing_time_sec: float = 0.0

    def to_dict(self) -> Dict:
        """Return a JSON-friendly summary of this result.

        The transcription is cut to its first 100 characters. Scores are
        rounded to 4 decimals and timings to 3 decimals.
        """
        summary: Dict = {
            "audio_path": self.audio_path,
            "transcription": self.transcription[:100],
        }
        for key in ("alignment_score", "mean_log_prob"):
            summary[key] = round(getattr(self, key), 4)
        summary["word_scores"] = [round(w, 4) for w in self.word_scores]
        summary["num_tokens"] = self.num_tokens
        summary["num_frames"] = self.num_frames
        for key in ("audio_duration_sec", "processing_time_sec"):
            summary[key] = round(getattr(self, key), 3)
        return summary


class MMSAligner:
    """
    Language-agnostic forced alignment using torchaudio MMS_FA.

    Tokenization: whitespace → * (word boundary, idx 28), CTC blank = - (idx 0).
    Score normalization: a mean log-prob ``lp`` becomes ``1 + lp / 5`` clamped
    to [0, 1]. The range [-5, 0] maps linearly onto [0, 1], and anything at
    or below -5 scores 0.
    """

    # Width of the log-prob range mapped onto [0, 1]: log-prob 0 -> 1.0,
    # log-prob -_CONF_LOGPROB_SPAN (or lower) -> 0.0.
    _CONF_LOGPROB_SPAN = 5.0

    def __init__(self, device: str = "auto"):
        """
        Args:
            device: "auto" picks "cuda" when available, else "cpu". Any other
                value is used as the torch device string as-is.

        The model is not loaded here; setup() does that lazily.
        """
        if device == "auto":
            self.device = "cuda" if torch.cuda.is_available() else "cpu"
        else:
            self.device = device

        self.model = None
        self.labels = None
        self.label2idx = None
        self.sample_rate = None
        self._setup_complete = False

        # Token indices (defaults; refreshed from the real labels in setup())
        self.BLANK_IDX = 0   # '-' = CTC blank
        self.STAR_IDX = 28   # '*' = word boundary

    def setup(self) -> bool:
        """
        Lazily load the MMS_FA model, its labels and its sample rate.

        Returns:
            True if the model is ready (loaded now or earlier). False if
            loading failed; the error is printed, not raised.
        """
        if self._setup_complete:
            return True

        try:
            bundle = torchaudio.pipelines.MMS_FA
            self.model = bundle.get_model().to(self.device)
            self.model.eval()
            self.labels = bundle.get_labels()
            self.sample_rate = bundle.sample_rate  # 16000
            self.label2idx = {label: i for i, label in enumerate(self.labels)}
            self.BLANK_IDX = self.label2idx.get("-", 0)
            self.STAR_IDX = self.label2idx.get("*", 28)

            print(f"[MMSAligner] Ready on {self.device} "
                  f"({len(self.labels)} labels, SR={self.sample_rate})")
            self._setup_complete = True
            return True

        except Exception as e:
            print(f"[MMSAligner] Setup failed: {e}")
            return False

    def _load_audio(self, audio_path: str) -> torch.Tensor:
        """
        Load audio as a (1, num_samples) float32 mono tensor at self.sample_rate.

        Multi-channel audio is down-mixed by averaging the channels.
        """
        data, sr = sf.read(audio_path, dtype="float32")
        if data.ndim > 1:
            data = data.mean(axis=1)
        waveform = torch.as_tensor(data, dtype=torch.float32).unsqueeze(0)
        if sr != self.sample_rate:
            waveform = torchaudio.functional.resample(
                waveform, sr, self.sample_rate
            )
        return waveform

    def _text_to_tokens(self, text: str) -> List[int]:
        """
        Convert romanized text to MMS token indices.

        - The text is NFKD-normalized and lower-cased. Accented Latin letters
          (e.g. "ā", "é") decompose into a base letter plus a combining mark;
          the base letter is kept and the mark, which has no label, is dropped.
        - Any run of whitespace becomes a single * (word boundary). No
          boundary is emitted before the first or after the last character.
        - Characters without a usable label are skipped. This covers
          punctuation, digits, combining marks, a literal '-' (the CTC blank)
          and a literal '*' (boundaries come only from whitespace).
        """
        tokens: List[int] = []
        pending_boundary = False
        for ch in unicodedata.normalize("NFKD", text).lower():
            if ch.isspace():
                pending_boundary = True
                continue
            idx = self.label2idx.get(ch)
            if idx is None or idx in (self.BLANK_IDX, self.STAR_IDX):
                continue
            # Emit the boundary lazily, only once a following character
            # arrives, so no leading or trailing '*' is ever produced.
            if pending_boundary and tokens:
                tokens.append(self.STAR_IDX)
            pending_boundary = False
            tokens.append(idx)
        return tokens

    @classmethod
    def _logprob_to_confidence(cls, log_prob: float) -> float:
        """Map a mean log-prob to [0, 1]: ``1 + log_prob / span``, clamped."""
        return max(0.0, min(1.0, 1.0 + log_prob / cls._CONF_LOGPROB_SPAN))

    def _compute_word_scores(
        self, tokens: List[int], alignment: List[int], scores: List[float]
    ) -> List[float]:
        """
        Compute per-word alignment confidences from frame-level scores.

        Args:
            tokens: Target token ids; words are separated by STAR_IDX.
            alignment: One label id per frame, as returned by ``forced_align``.
                Blank frames and repeated frames of the same token are included.
            scores: Per-frame log-probs, parallel to ``alignment``.

        Returns:
            One confidence in [0, 1] per word, or 0.0 for a word with no frames.

        A target token's span starts at a non-blank frame whose label differs
        from the previous frame's label. Following a blank also counts as a
        new start, which is how CTC separates identical adjacent tokens like
        the "aa" in "naaku". Every frame of a token's span counts toward its
        word. Blank and '*' frames count toward no word.
        """
        token_to_word: Dict[int, int] = {}
        word_idx = 0
        for i, tok in enumerate(tokens):
            if tok == self.STAR_IDX:
                word_idx += 1
            else:
                token_to_word[i] = word_idx

        word_log_probs: List[List[float]] = [[] for _ in range(word_idx + 1)]

        token_cursor = -1              # index into `tokens` of the current span
        prev_label = self.BLANK_IDX
        for label, score in zip(alignment, scores):
            if label != self.BLANK_IDX:
                if label != prev_label:
                    token_cursor += 1  # a new target token starts here
                wi = token_to_word.get(token_cursor)  # None for '*' tokens
                if wi is not None:
                    word_log_probs[wi].append(score)
            prev_label = label

        return [
            self._logprob_to_confidence(sum(lps) / len(lps)) if lps else 0.0
            for lps in word_log_probs
        ]

    @staticmethod
    def _failed_result(
        audio_path: str, romanized_text: str, **fields
    ) -> "MMSAlignmentResult":
        """Build a zero-score result; ``fields`` fills optional metadata."""
        return MMSAlignmentResult(
            audio_path=audio_path,
            transcription=romanized_text,
            alignment_score=0.0,
            mean_log_prob=-10.0,
            **fields,
        )

    def align(self, audio_path: str, romanized_text: str) -> "MMSAlignmentResult":
        """
        Perform forced alignment of romanized text against audio.

        Args:
            audio_path: Path to an audio file readable by soundfile.
            romanized_text: Latin-script text to align.

        Returns:
            MMSAlignmentResult with overall and per-word scores. On failure it
            returns a zero-score result with mean_log_prob=-10.0 instead of
            raising. Failures include model setup, text with no alignable
            characters, audio too short for the token sequence, and any
            exception (whose traceback is printed).
        """
        if not self.setup():  # returns immediately once the model is loaded
            return self._failed_result(audio_path, romanized_text)

        start_time = time.perf_counter()
        audio_dur = 0.0

        try:
            waveform = self._load_audio(audio_path)
            audio_dur = waveform.shape[1] / self.sample_rate
            tokens = self._text_to_tokens(romanized_text)

            if not tokens:
                return self._failed_result(
                    audio_path, romanized_text,
                    audio_duration_sec=audio_dur,
                    processing_time_sec=time.perf_counter() - start_time,
                )

            with torch.no_grad():
                emission, _ = self.model(waveform.to(self.device))
            emission = emission.cpu()
            num_frames = emission.shape[1]

            # CTC needs one frame per token plus a blank between each pair of
            # identical adjacent tokens; forced_align raises otherwise.
            num_repeats = sum(1 for a, b in zip(tokens, tokens[1:]) if a == b)
            if num_frames < len(tokens) + num_repeats:
                print(f"[MMSAligner] Audio too short: {num_frames} frames for "
                      f"{len(tokens)} tokens ({audio_path})")
                return self._failed_result(
                    audio_path, romanized_text,
                    num_tokens=len(tokens),
                    num_frames=num_frames,
                    audio_duration_sec=audio_dur,
                    processing_time_sec=time.perf_counter() - start_time,
                )

            token_tensor = torch.tensor([tokens], dtype=torch.int32)

            alignment, scores = torchaudio.functional.forced_align(
                emission, token_tensor, blank=self.BLANK_IDX
            )

            # NOTE(review): this mean covers every frame, including blank and
            # '*' frames. If the model's '*' column is a constant high log-prob
            # (the reference MMS pipeline appends zeros), '*' frames pull the
            # mean toward 0 and inflate the score. TODO confirm before
            # tightening thresholds.
            mean_lp = scores.mean().item()
            alignment_score = self._logprob_to_confidence(mean_lp)

            # v4: per-word scores for stricter validation (words split on '*')
            word_scores = self._compute_word_scores(
                tokens, alignment[0].tolist(), scores[0].tolist()
            )

            return MMSAlignmentResult(
                audio_path=audio_path,
                transcription=romanized_text,
                alignment_score=alignment_score,
                mean_log_prob=mean_lp,
                word_scores=word_scores,
                num_tokens=len(tokens),
                num_frames=num_frames,
                audio_duration_sec=audio_dur,
                processing_time_sec=time.perf_counter() - start_time,
            )

        except Exception:
            import traceback
            traceback.print_exc()
            return self._failed_result(
                audio_path, romanized_text,
                audio_duration_sec=audio_dur,
                processing_time_sec=time.perf_counter() - start_time,
            )

    def cleanup(self):
        """Release the model and, when CUDA is available, its cached GPU memory."""
        if self.model is not None:
            del self.model
            self.model = None
        self._setup_complete = False
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
