"""
MMS Forced Aligner (Romanized Text)
====================================

Language-agnostic forced alignment using Meta's MMS_FA model via torchaudio.
Works on romanized/Latin-script text. Complements the native CTC aligner.

MMS_FA is trained on 1000+ languages — handles romanized Indic + English equally well.
Used alongside native CTC to produce dual validation scores.

Usage:
    aligner = MMSAligner()
    result = aligner.align("audio.flac", "naaku konni ads gurtuntaayi")
    print(f"Score: {result.alignment_score}")  # 0-1
"""
import torch
import torchaudio
import soundfile as sf
import numpy as np
import time
from dataclasses import dataclass, field
from typing import List, Dict, Optional


@dataclass
class MMSAlignmentResult:
    """Outcome of MMS forced alignment of romanized text against one audio file."""
    audio_path: str
    transcription: str

    alignment_score: float  # normalized alignment quality in [0, 1]
    mean_log_prob: float    # raw mean per-frame log probability

    num_tokens: int = 0
    num_frames: int = 0
    audio_duration_sec: float = 0.0
    processing_time_sec: float = 0.0

    def to_dict(self) -> Dict:
        """Serialize to a JSON-friendly dict.

        The transcription is truncated to 100 characters and float fields
        are rounded (scores to 4 decimals, durations to 3).
        """
        summary = {
            "audio_path": self.audio_path,
            "transcription": self.transcription[:100],
        }
        summary["alignment_score"] = round(self.alignment_score, 4)
        summary["mean_log_prob"] = round(self.mean_log_prob, 4)
        summary["num_tokens"] = self.num_tokens
        summary["num_frames"] = self.num_frames
        summary["audio_duration_sec"] = round(self.audio_duration_sec, 3)
        summary["processing_time_sec"] = round(self.processing_time_sec, 3)
        return summary


class MMSAligner:
    """
    Language-agnostic forced alignment using torchaudio MMS_FA.

    Tokenization: space → * (word boundary, idx 28), CTC blank = - (idx 0).
    Score normalization: log_prob range ~[-10, 0] mapped to [0, 1].
    """

    def __init__(self, device: str = "auto"):
        """
        Args:
            device: "auto" picks CUDA when available, otherwise CPU; any
                other value is passed through as an explicit torch device
                string (e.g. "cpu", "cuda:1").
        """
        if device == "auto":
            self.device = "cuda" if torch.cuda.is_available() else "cpu"
        else:
            self.device = device

        # Populated lazily by setup() so constructing the aligner is cheap.
        self.model = None
        self.labels = None       # label set from the MMS_FA bundle
        self.label2idx = None    # char -> vocabulary index
        self.sample_rate = None  # model sample rate (16 kHz per the bundle)
        self._setup_complete = False

        # Token indices. Defaults match the MMS_FA vocabulary; both are
        # re-resolved from the actual labels during setup().
        self.BLANK_IDX = 0   # '-' = CTC blank
        self.STAR_IDX = 28   # '*' = word boundary

    def setup(self) -> bool:
        """Lazy load MMS_FA model.

        Returns:
            True if the model is ready (or already was), False on failure.
            Failures are printed, not raised, so callers can degrade to a
            zero-score result instead of crashing a batch.
        """
        if self._setup_complete:
            return True

        try:
            bundle = torchaudio.pipelines.MMS_FA
            self.model = bundle.get_model().to(self.device)
            self.model.eval()
            self.labels = bundle.get_labels()
            self.sample_rate = bundle.sample_rate  # 16000
            self.label2idx = {l: i for i, l in enumerate(self.labels)}
            # Resolve special-token indices from the real label set; the
            # constructor defaults are only a fallback.
            self.BLANK_IDX = self.label2idx.get("-", 0)
            self.STAR_IDX = self.label2idx.get("*", 28)

            print(f"[MMSAligner] Ready on {self.device} "
                  f"({len(self.labels)} labels, SR={self.sample_rate})")
            self._setup_complete = True
            return True

        except Exception as e:
            print(f"[MMSAligner] Setup failed: {e}")
            return False

    def _load_audio(self, audio_path: str) -> torch.Tensor:
        """Load and resample audio to the MMS sample rate (16 kHz).

        Multi-channel audio is downmixed to mono by averaging channels.
        Returns a (1, num_samples) float32 tensor.
        """
        data, sr = sf.read(audio_path)
        if len(data.shape) > 1:
            data = data.mean(axis=1)  # downmix to mono
        waveform = torch.tensor(data, dtype=torch.float32).unsqueeze(0)
        if sr != self.sample_rate:
            waveform = torchaudio.functional.resample(
                waveform, sr, self.sample_rate
            )
        return waveform

    def _text_to_tokens(self, text: str) -> List[int]:
        """
        Convert romanized text to MMS token indices.

        Any whitespace → * (word boundary). Unknown chars skipped.
        Punctuation stripped.
        """
        tokens: List[int] = []
        for ch in text.lower().strip():
            if ch.isspace():
                # Word boundary. Matching any whitespace (space/tab/newline)
                # keeps words separate even in multi-line text; previously a
                # literal-space check silently fused words split by tabs or
                # newlines. Consecutive boundaries collapse to one.
                if tokens and tokens[-1] != self.STAR_IDX:
                    tokens.append(self.STAR_IDX)
            elif ch in self.label2idx and self.label2idx[ch] != self.BLANK_IDX:
                tokens.append(self.label2idx[ch])
            # Punctuation and unknown chars silently skipped
        # Strip leading/trailing boundaries so alignment never begins or
        # ends on a word separator.
        while tokens and tokens[0] == self.STAR_IDX:
            tokens.pop(0)
        while tokens and tokens[-1] == self.STAR_IDX:
            tokens.pop()
        return tokens

    def align(self, audio_path: str, romanized_text: str) -> "MMSAlignmentResult":
        """
        Perform forced alignment of romanized text against audio.

        Args:
            audio_path: Path to audio file
            romanized_text: Latin-script text to align

        Returns:
            MMSAlignmentResult with alignment score. Never raises: setup
            failures, empty token sequences, and alignment errors all
            return a zero-score result (mean_log_prob pinned to -10.0).
        """
        if not self._setup_complete:
            if not self.setup():
                return MMSAlignmentResult(
                    audio_path=audio_path,
                    transcription=romanized_text,
                    alignment_score=0.0,
                    mean_log_prob=-10.0,
                )

        start_time = time.time()

        try:
            waveform = self._load_audio(audio_path)
            audio_dur = waveform.shape[1] / self.sample_rate
            tokens = self._text_to_tokens(romanized_text)

            # No alignable characters (e.g. all punctuation) — nothing to do.
            if not tokens:
                return MMSAlignmentResult(
                    audio_path=audio_path,
                    transcription=romanized_text,
                    alignment_score=0.0,
                    mean_log_prob=-10.0,
                    audio_duration_sec=audio_dur,
                    processing_time_sec=time.time() - start_time,
                )

            with torch.no_grad():
                emission, _ = self.model(waveform.to(self.device))
            # forced_align runs on CPU tensors here.
            emission = emission.cpu()

            token_tensor = torch.tensor([tokens], dtype=torch.int32)

            # scores: per-frame log probabilities of the aligned path.
            _, scores = torchaudio.functional.forced_align(
                emission, token_tensor, blank=self.BLANK_IDX
            )

            mean_lp = scores.mean().item()
            # Normalize: log probs range ~[-10, 0]. Map to [0, 1].
            # Same formula as CTC aligner for consistency: score hits 0.0
            # at mean log prob <= -5 and 1.0 at 0.
            alignment_score = max(0.0, min(1.0, 1.0 + mean_lp / 5.0))

            return MMSAlignmentResult(
                audio_path=audio_path,
                transcription=romanized_text,
                alignment_score=alignment_score,
                mean_log_prob=mean_lp,
                num_tokens=len(tokens),
                num_frames=emission.shape[1],
                audio_duration_sec=audio_dur,
                processing_time_sec=time.time() - start_time,
            )

        except Exception as e:
            # Best-effort batch processing: log the traceback and return a
            # zero-score sentinel rather than aborting the whole run.
            import traceback
            traceback.print_exc()
            return MMSAlignmentResult(
                audio_path=audio_path,
                transcription=romanized_text,
                alignment_score=0.0,
                mean_log_prob=-10.0,
                processing_time_sec=time.time() - start_time,
            )

    def cleanup(self):
        """Release the model and free cached GPU memory; setup() reloads."""
        if self.model is not None:
            del self.model
            self.model = None
        self._setup_complete = False
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
