"""
Deterministic Romanization Module
==================================

Converts native Indic script text to MMS_FA-compatible Latin romanization
using the uroman library. This replaces Gemini-generated romanization with
a deterministic, reproducible transform.

Why deterministic:
  - Gemini's "romanized as pronounced" is creative and inconsistent across runs
  - uroman produces identical output for identical input, every time
  - MMS_FA was trained with uroman-style romanization, so alignment is optimal

MMS_FA normalization: [a-z' ] only, lowercase, collapsed spaces.
Per: https://docs.pytorch.org/audio/2.8/tutorials/forced_alignment_for_multilingual_data_tutorial.html

Usage:
    from src.romanization import romanize, romanize_for_alignment

    # Basic romanization
    roman = romanize("నాకు కొన్ని యాడ్స్ గుర్తుంటాయి")
    # -> "naaku konni yaadds gurtumttaayi"

    # MMS_FA-ready (stripped to [a-z' ] only)
    aligned = romanize_for_alignment("నాకు కొన్ని, యాడ్స్!")
    # -> "naaku konni yaadds"
"""
import re
from typing import Optional

# Lazily created uroman.Uroman instance (heavy init: ~3s for data files).
# Populated on first use by _get_uroman() and reset to None by cleanup();
# None means "not loaded yet".
_uroman_instance = None


def _get_uroman():
    """Return the shared ``uroman.Uroman`` instance, building it on first use.

    Building the instance loads uroman's data files (slow, a few seconds),
    so the result is cached in the module-level ``_uroman_instance`` and
    reused by later calls. Importing ``uroman`` is also deferred to here so
    that importing this module stays cheap.
    """
    global _uroman_instance
    if _uroman_instance is not None:
        # Fast path: already loaded.
        return _uroman_instance
    import uroman
    _uroman_instance = uroman.Uroman()
    return _uroman_instance


def romanize(text: str) -> str:
    """
    Transliterate ``text`` into Latin script using uroman.

    The transform is deterministic: identical input always yields identical
    output. Indic scripts (Telugu, Hindi, Tamil, etc.) are romanized and
    Latin text passes through.

    Args:
        text: Input in any script (Indic, Latin, or mixed). Empty, ``None``
            or whitespace-only input is accepted.

    Returns:
        The romanized Latin string, or ``""`` when ``text`` has no
        non-whitespace content. Surrounding whitespace is trimmed before
        romanizing.
    """
    stripped = text.strip() if text else ""
    if not stripped:
        # Nothing to romanize; avoid loading uroman at all.
        return ""
    return _get_uroman().romanize_string(stripped)


def romanize_for_alignment(text: str) -> str:
    """
    Romanize and normalize text for MMS_FA forced alignment.

    MMS_FA expects lowercase [a-z' ] only.
    Steps:
      1. Replace bracketed tags ([laugh], [UNK], [INAUDIBLE], [NO SPEECH],
         [background-noise], ...) with a space.
      2. Replace punctuation (ASCII punctuation, single/double danda,
         dashes, ellipsis, straight and curly quotes) with a space.
      3. Romanize via uroman.
      4. Lowercase, then turn every character outside [a-z' ] into a space.
      5. Collapse whitespace runs and strip leading/trailing whitespace.

    Any digits left after romanization fall outside [a-z' ] and are dropped
    in step 4. Spell numbers out upstream if they must be aligned.

    Args:
        text: Native script text (may include punctuation, tags). Empty or
            ``None`` input is accepted.

    Returns:
        Normalized romanized text ready for MMS_FA alignment, or ``""`` when
        nothing alignable remains.
    """
    if not text or not text.strip():
        return ""

    # Step 1: Strip bracketed tags. Any bracketed span counts as a tag so
    # multi-word or hyphenated tags ("[NO SPEECH]", "[background-noise]")
    # don't leak their words into the alignment text. Replace with a space
    # rather than '' so "word[laugh]word" stays two separate words.
    cleaned = re.sub(r'\[[^\[\]]*\]', ' ', text)

    # Step 2: Strip punctuation that would confuse uroman.
    # Keep native script chars + basic Latin + spaces.
    cleaned = re.sub(r'[।॥,.!?;:\'"()\-–—…‘’“”]', ' ', cleaned)

    # Only tags/punctuation were present: nothing to romanize.
    if not cleaned.strip():
        return ""

    # Steps 3-4: Romanize, then MMS_FA normalization - lowercase, [a-z' ] only.
    roman = romanize(cleaned).lower()
    roman = re.sub(r"[^a-z' ]", " ", roman)

    # Step 5: Collapse spaces.
    return re.sub(r'\s+', ' ', roman).strip()


def _normalize_gemini_romanization(text: Optional[str]) -> str:
    """Normalize already-Latin romanization for comparison against uroman output.

    Applies the same kind of cleanup used to build MMS_FA alignment text:
    bracketed tags and punctuation become spaces, the text is lowercased,
    anything outside [a-z' ] becomes a space, and whitespace is collapsed.
    This keeps tags, apostrophes and punctuation from being counted as
    romanization errors.
    """
    if not text:
        return ""
    cleaned = re.sub(r'\[[^\[\]]*\]', ' ', text)
    cleaned = re.sub(r'[।॥,.!?;:\'"()\-–—…‘’“”]', ' ', cleaned)
    cleaned = re.sub(r"[^a-z' ]", " ", cleaned.lower())
    return re.sub(r'\s+', ' ', cleaned).strip()


def compute_romanization_cer(
    native_text: str,
    gemini_romanized: str
) -> float:
    """
    Compute Character Error Rate between uroman-derived and Gemini-provided romanization.

    Low CER (<3%) = Gemini's romanization is consistent with deterministic version.
    High CER (>10%) = Gemini hallucinated or used non-standard romanization.

    Useful for flagging romanization drift without blocking transcriptions.

    Both sides are normalized before comparison. Tags, punctuation and case
    therefore do not contribute to the error rate.

    Args:
        native_text: Native script text (source of truth). Empty/``None`` allowed.
        gemini_romanized: Gemini's romanized output. Empty/``None`` allowed.

    Returns:
        CER as float clamped to [0.0, 1.0] (0.0 = identical, 1.0 = completely
        different). When the reference is empty, returns 0.0 if the
        hypothesis is also empty and 1.0 otherwise.
    """
    # Reference side: deterministic uroman romanization. Blank input yields ""
    # from romanize_for_alignment anyway, so skip the (uroman-loading) call.
    if native_text and native_text.strip():
        ref = romanize_for_alignment(native_text)
    else:
        ref = ""
    hyp = _normalize_gemini_romanization(gemini_romanized)

    if not ref:
        # CER is undefined for an empty reference; any hypothesis is a full miss.
        return 1.0 if hyp else 0.0

    # Character-level Levenshtein distance, keeping only two DP rows.
    prev = list(range(len(hyp) + 1))
    for i, ref_ch in enumerate(ref, start=1):
        curr = [i]
        for j, hyp_ch in enumerate(hyp, start=1):
            cost = 0 if ref_ch == hyp_ch else 1
            curr.append(min(
                curr[j - 1] + 1,     # insertion
                prev[j] + 1,         # deletion
                prev[j - 1] + cost,  # substitution
            ))
        prev = curr

    # Edit distance can exceed len(ref) when hyp is longer; clamp to 1.0.
    return min(1.0, prev[-1] / len(ref))


def cleanup() -> None:
    """Drop the cached uroman instance so its data can be garbage-collected.

    A later call to ``_get_uroman`` will rebuild the instance on demand.
    """
    global _uroman_instance
    _uroman_instance = None
