"""
Simple Transcription Validator (v4)
====================================

Architecture change: validator now derives its own romanization via uroman
instead of trusting Gemini's creative romanization. This makes validation
deterministic and reproducible.

Four checks:
  Step 0: Structural sanity (no ML, instant) - unicode, tags, length/duration
  Step 1: Character validation - catch garbage/alien characters
  Step 2: Native CTC alignment - verify native script matches audio
  Step 3: Romanized MMS alignment - verify uroman-derived text matches audio

Combined scoring: S = 0.45*N + 0.55*R - 0.10*abs(N-R)
  N = native CTC score, R = romanized MMS score (using uroman, not Gemini)

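Verdict thresholds on S (see compute_combined_score):
  ACCEPT: S >= 0.75 AND min(N, R) >= 0.50 AND abs(N-R) <= 0.25
  REVIEW: S in [0.65, 0.75), or S >= 0.75 where min(N, R) < 0.50 or
          abs(N-R) > 0.25; also when neither aligner produced a score
  RETRY:  S in [0.55, 0.65)
  REJECT: S < 0.55, or a structural/character failure
  If only one aligner yields a score, it is used on its own:
  >= 0.75 accept, >= 0.55 retry, otherwise reject.
Per-word MMS metrics (minimum word score, boundary-word average, fraction
of internal words below 0.60) are reported but do not affect the verdict.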

Usage:
    from src.validators.simple_validator import validate_transcription
    result = validate_transcription("audio.flac", native_text, language="te")
    # result.status: "accept" / "review" / "retry" / "reject"
    # result.combined_score: 0-1
"""
import re
from dataclasses import dataclass, field
from typing import List, Dict, Optional, Set


# Unicode ranges for Indic scripts (all 12 supported languages)
#
# Each entry maps a language code to:
#   "name":   human-readable language name
#   "ranges": list of inclusive (start, end) codepoint ranges whose
#             characters count as "native" for that language
#   "allow_ascii_punct" / "allow_digits": flags that no function in this
#             module reads (a fixed set of ASCII punctuation and digits is
#             always allowed via ALLOWED_COMMON).
#             NOTE(review): confirm whether any external caller uses them.
# Languages that share a script share a block: hi/mr -> Devanagari,
# bn/as -> Bengali. Every lookup in this module falls back to the "te"
# entry for unknown language codes.
SCRIPT_RANGES = {
    "te": {
        "name": "Telugu",
        "ranges": [(0x0C00, 0x0C7F)],
        "allow_ascii_punct": True,
        "allow_digits": True,
    },
    # Devanagari block (shared with "mr")
    "hi": {
        "name": "Hindi",
        "ranges": [(0x0900, 0x097F)],
        "allow_ascii_punct": True,
        "allow_digits": True,
    },
    "mr": {
        "name": "Marathi",
        "ranges": [(0x0900, 0x097F)],
        "allow_ascii_punct": True,
        "allow_digits": True,
    },
    "ta": {
        "name": "Tamil",
        "ranges": [(0x0B80, 0x0BFF)],
        "allow_ascii_punct": True,
        "allow_digits": True,
    },
    "kn": {
        "name": "Kannada",
        "ranges": [(0x0C80, 0x0CFF)],
        "allow_ascii_punct": True,
        "allow_digits": True,
    },
    "ml": {
        "name": "Malayalam",
        "ranges": [(0x0D00, 0x0D7F)],
        "allow_ascii_punct": True,
        "allow_digits": True,
    },
    # Bengali block (shared with "as")
    "bn": {
        "name": "Bengali",
        "ranges": [(0x0980, 0x09FF)],
        "allow_ascii_punct": True,
        "allow_digits": True,
    },
    "as": {
        "name": "Assamese",
        "ranges": [(0x0980, 0x09FF)],
        "allow_ascii_punct": True,
        "allow_digits": True,
    },
    "gu": {
        "name": "Gujarati",
        "ranges": [(0x0A80, 0x0AFF)],
        "allow_ascii_punct": True,
        "allow_digits": True,
    },
    # Gurmukhi block
    "pa": {
        "name": "Punjabi",
        "ranges": [(0x0A00, 0x0A7F)],
        "allow_ascii_punct": True,
        "allow_digits": True,
    },
    "or": {
        "name": "Odia",
        "ranges": [(0x0B00, 0x0B7F)],
        "allow_ascii_punct": True,
        "allow_digits": True,
    },
    # ASCII A-Z and a-z: for English, Latin letters are the "native" script
    "en": {
        "name": "English",
        "ranges": [(0x0041, 0x005A), (0x0061, 0x007A)],
        "allow_ascii_punct": True,
        "allow_digits": True,
    },
}

# Characters accepted for every language regardless of script: space, tab,
# newline, the ASCII punctuation . , ! ? ; : ' " ( ) -, en dash (U+2013),
# em dash (U+2014), danda (U+0964), double danda (U+0965) and ASCII digits.
ALLOWED_COMMON = set(" \t\n.,!?;:'\"()-\u2013\u2014\u0964\u0965" + "0123456789")
# ASCII Latin letters. They may be allowed in code-mixed text, and they are
# counted as "Latin" when measuring how much of a text is in the native script.
ENGLISH_CHARS = set("abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ")

# Valid bracketed tags that Gemini may insert.
# structural_sanity_check compares found tags against this set
# case-sensitively and reports any other tag it finds as "invalid_tag:<tag>".
VALID_TAGS = {
    "[laugh]", "[cough]", "[sigh]", "[breath]", "[singing]",
    "[noise]", "[music]", "[applause]", "[UNK]", "[INAUDIBLE]", "[NO_SPEECH]"
}


def get_valid_chars(language: str, allow_english: bool = True) -> Set[int]:
    """Return the set of codepoints accepted for ``language``.

    The result always contains every character of ALLOWED_COMMON and every
    codepoint in the language's native script ranges. ASCII Latin letters
    (ENGLISH_CHARS) are added when ``allow_english`` is true. Unknown
    language codes use the Telugu ranges.
    """
    config = SCRIPT_RANGES.get(language, SCRIPT_RANGES["te"])
    codepoints: Set[int] = set(map(ord, ALLOWED_COMMON))
    for lo, hi in config["ranges"]:
        codepoints.update(range(lo, hi + 1))
    if allow_english:
        codepoints.update(ord(ch) for ch in ENGLISH_CHARS)
    return codepoints


def check_characters(
    text: str,
    language: str = "te",
    allow_english: bool = False
) -> Dict:
    """Check whether ``text`` contains only characters valid for ``language``.

    Args:
        text: Text to inspect (callers usually strip tags first).
        language: Language code; unknown codes fall back to Telugu ranges.
        allow_english: Treat ASCII Latin letters as valid (code-mixed text).

    Returns:
        Dict with keys:
          "valid": True when no invalid character was found.
          "invalid_chars": up to 10 dicts ``{"char", "codepoint", "position"}``.
          "invalid_count": total number of invalid characters.
          "script_ratio": native-range codepoints divided by
              (native-range codepoints + ASCII Latin letters); 0 when the
              text has neither. The ratio is the same whether or not
              ``allow_english`` is set.
          "reason": present only for empty/whitespace-only input ("empty").
    """
    if not text or not text.strip():
        return {
            "valid": False,
            "invalid_chars": [],
            "invalid_count": 0,
            "script_ratio": 0,
            "reason": "empty",
        }

    valid_chars = get_valid_chars(language, allow_english)
    native_ranges = SCRIPT_RANGES.get(language, SCRIPT_RANGES["te"])["ranges"]

    invalid_chars = []
    script_count = 0
    total_alpha = 0

    for i, char in enumerate(text):
        cp = ord(char)
        if any(start <= cp <= end for start, end in native_ranges):
            script_count += 1
            total_alpha += 1
            continue
        # Count Latin letters whether or not they are allowed. Previously they
        # were counted only when invalid, so with allow_english=True the ratio
        # ignored English entirely and reported 1.0 for mostly-English text.
        if char in ENGLISH_CHARS:
            total_alpha += 1
        if cp not in valid_chars:
            invalid_chars.append({"char": char, "codepoint": cp, "position": i})

    return {
        "valid": len(invalid_chars) == 0,
        "invalid_chars": invalid_chars[:10],
        "invalid_count": len(invalid_chars),
        "script_ratio": script_count / total_alpha if total_alpha > 0 else 0,
    }


# === Step 0: Structural sanity checks (no ML, instant) ===

def structural_sanity_check(
    text: str,
    language: str = "te",
    duration_sec: float = 0.0,
) -> Dict:
    """
    Fast structural checks before any ML model runs.
    Catches obvious garbage without wasting GPU cycles.

    Checks:
    1. Empty/whitespace-only: instant reject ("empty_text").
    2. Marker pass-through: text that is exactly "[NO_SPEECH]" or
       "[INAUDIBLE]" (ignoring surrounding whitespace) passes immediately.
    3. Tag format: every bracketed span must be one of VALID_TAGS.
    4. Unicode block (non-English only): native-script characters must be
       at least 10% of native + ASCII Latin letters. This is deliberately
       lenient because the text is code-mixed.
    5. Length vs duration (only when duration_sec > 0): reject more than
       15 words/sec, and more than 10 words for audio shorter than 1.5 s.

    Bracketed tags are markup, not speech. They are excluded from the
    letter counts in check 4 and the word counts in check 5.

    Args:
        text: Raw transcription, tags included.
        language: Language code; unknown codes fall back to Telugu ranges.
        duration_sec: Audio duration in seconds; 0 disables check 5.

    Returns: {"pass": bool, "reasons": [...]}
    """
    if not text or not text.strip():
        return {"pass": False, "reasons": ["empty_text"]}

    # Special markers pass through
    if text.strip() in ("[NO_SPEECH]", "[INAUDIBLE]"):
        return {"pass": True, "reasons": []}

    reasons = []
    # Any bracketed span counts as a tag, so malformed tags such as
    # "[some noise]" are reported too (a \w+ pattern silently skipped them).
    tag_pattern = r'\[[^\[\]]*\]'

    for tag in re.findall(tag_pattern, text):
        if tag not in VALID_TAGS:
            reasons.append(f"invalid_tag:{tag}")

    # Replace tags with a space so their letters are not counted as Latin
    # and they are not counted as words, without gluing neighbours together.
    body = re.sub(tag_pattern, ' ', text)

    # Unicode block check: non-English text should contain native script
    if language != "en":
        native_ranges = SCRIPT_RANGES.get(language, SCRIPT_RANGES["te"])["ranges"]
        latin_chars = sum(1 for c in body if c in ENGLISH_CHARS)
        native_chars = sum(
            1 for c in body
            if any(start <= ord(c) <= end for start, end in native_ranges)
        )
        total = latin_chars + native_chars
        # v5: transcription is code-mixed, so lower threshold - heavy English mixing is OK
        if total > 0 and native_chars / total < 0.1:
            reasons.append(f"too_few_native_chars:{native_chars}/{total}")

    # Length vs duration sanity (if duration provided)
    if duration_sec > 0:
        word_count = len(body.split())
        words_per_sec = word_count / max(duration_sec, 0.1)
        # Normal speech: 2-4 words/sec. Allow up to 15 for fast speech.
        # >15 words/sec for audio = almost certainly hallucination.
        if words_per_sec > 15:
            reasons.append(f"impossibly_dense:{words_per_sec:.1f}w/s")
        # Very long text for short audio = suspicious
        if duration_sec < 1.5 and word_count > 10:
            reasons.append(f"too_many_words_for_short_audio:{word_count}w/{duration_sec:.1f}s")

    return {"pass": len(reasons) == 0, "reasons": reasons}


@dataclass
class ValidationResult:
    """Outcome of ``validate_transcription``, including both aligner scores.

    ``status`` is one of "accept", "review", "retry", "reject". List-valued
    fields default to ``None``; ``to_dict`` turns those into empty lists and
    rounds the float scores for serialisation.
    """

    status: str  # "accept" | "review" | "retry" | "reject"

    # --- Character check ---
    char_valid: bool = True
    invalid_chars: Optional[List[Dict]] = None
    script_ratio: float = 0.0

    # --- Native CTC alignment (language-specific wav2vec2) ---
    alignment_score: float = 0.0  # validate_transcription sets this equal to native_ctc_score
    native_ctc_score: float = 0.0
    low_confidence_words: Optional[List[str]] = None
    low_confidence_ratio: float = 0.0

    # --- Romanized MMS alignment (uroman-derived, not Gemini) ---
    roman_mms_score: float = 0.0

    # --- Per-word MMS metrics (reported for analysis) ---
    mms_min_word_score: float = 0.0
    mms_boundary_word_avg: float = 0.0
    mms_internal_below_threshold: float = 0.0

    # --- Weighted score: S = 0.45*N + 0.55*R - 0.10*abs(N-R) ---
    combined_score: float = 0.0

    # --- Structural sanity (Step 0) ---
    structural_pass: bool = True
    structural_reasons: Optional[List[str]] = None

    # --- Romanization actually used for MMS alignment ---
    uroman_romanized: str = ""

    # --- Human-readable explanations ---
    reasons: Optional[List[str]] = None

    def to_dict(self) -> Dict:
        """Serialise the headline fields (same key order as always)."""
        out: Dict = {"status": self.status, "char_valid": self.char_valid}
        out["invalid_chars"] = self.invalid_chars or []
        out["script_ratio"] = round(self.script_ratio, 3)
        for key in (
            "native_ctc_score",
            "roman_mms_score",
            "combined_score",
            "alignment_score",
            "mms_min_word_score",
            "mms_boundary_word_avg",
        ):
            out[key] = round(getattr(self, key), 4)
        out["low_confidence_words"] = self.low_confidence_words or []
        out["low_confidence_ratio"] = round(self.low_confidence_ratio, 3)
        out["structural_pass"] = self.structural_pass
        out["reasons"] = self.reasons or []
        return out


# Global aligners (lazy loaded)
_ctc_aligner = None
_ctc_aligner_language: Optional[str] = None  # language _ctc_aligner was built for
_mms_aligner = None


def _get_ctc_aligner(language: str):
    """Return the native CTC aligner for ``language``, loading it lazily.

    CTCForcedAligner is constructed with a specific ``language``. A cached
    aligner built for a different language is therefore released (through
    its ``cleanup()``) and replaced, instead of being reused for the wrong
    language. Only one CTC aligner is kept at a time.
    """
    global _ctc_aligner, _ctc_aligner_language
    if _ctc_aligner is not None and _ctc_aligner_language != language:
        # Release the old aligner before building the replacement.
        stale, _ctc_aligner = _ctc_aligner, None
        stale.cleanup()
    if _ctc_aligner is None:
        from .ctc_forced_aligner import CTCForcedAligner
        _ctc_aligner = CTCForcedAligner(language=language)
        _ctc_aligner_language = language
    return _ctc_aligner


def _get_mms_aligner():
    """Lazy load the romanized MMS aligner (one shared, language-agnostic instance)."""
    global _mms_aligner
    if _mms_aligner is None:
        from .mms_aligner import MMSAligner
        _mms_aligner = MMSAligner()
    return _mms_aligner


def compute_combined_score(
    native_ctc: float,
    roman_mms: float,
) -> tuple:
    """
    Blend the two aligner scores and map the blend to a verdict.

    S = 0.45*N + 0.55*R - 0.10*abs(N-R), where N is the native CTC score and
    R the romanized MMS score. MMS gets the larger weight because it is more
    stable on code-mixed Indic audio. The last term penalises disagreement.

    Verdict (compared on the unrounded S):
      accept  S >= 0.75 and min(N, R) >= 0.50 and abs(N-R) <= 0.25
      review  S >= 0.75 but min(N, R) < 0.50 or abs(N-R) > 0.25,
              or 0.65 <= S < 0.75
      retry   0.55 <= S < 0.65
      reject  S < 0.55

    Returns:
        (S rounded to 4 decimals, verdict string)
    """
    ctc, mms = native_ctc, roman_mms
    gap = abs(ctc - mms)
    score = 0.45 * ctc + 0.55 * mms - 0.10 * gap
    rounded = round(score, 4)

    if score >= 0.75 and min(ctc, mms) >= 0.50:
        return rounded, ("review" if gap > 0.25 else "accept")
    if score >= 0.65:
        return rounded, "review"
    if score >= 0.55:
        return rounded, "retry"
    return rounded, "reject"


def _strip_tags_and_punct(text: str) -> str:
    """Strip tags and punctuation for alignment. Keeps native script + Latin + spaces."""
    cleaned = re.sub(r'\[[\w_]+\]', '', text)
    cleaned = re.sub(r'[\u0964.,!?;:\'"()\-\u2013\u2014]', '', cleaned)
    return re.sub(r'\s+', ' ', cleaned).strip()


def _strip_to_native_only(text: str, language: str) -> str:
    """Reduce ``text`` to its native-script content for CTC alignment.

    Tags and punctuation are removed first via _strip_tags_and_punct. English
    text is returned at that point. For other languages, every character
    outside the language's native Unicode ranges (Latin letters, ASCII digits,
    anything else) is dropped, except spaces and tabs, and whitespace is
    collapsed again. The language-specific CTC model cannot align the English
    words of code-mixed text, so only the native portions are kept.
    """
    base = _strip_tags_and_punct(text)
    if language == "en":
        return base
    spans = SCRIPT_RANGES.get(language, SCRIPT_RANGES["te"])["ranges"]

    def _keep(ch: str) -> bool:
        code = ord(ch)
        return any(lo <= code <= hi for lo, hi in spans) or ch in ' \t'

    return re.sub(r'\s+', ' ', ''.join(filter(_keep, base))).strip()


def _summarize_mms_word_scores(scores) -> tuple:
    """Reduce per-word MMS scores to ``(min_score, boundary_avg, internal_below)``.

    ``boundary_avg`` averages the first and last word (only the first when
    there is a single word). ``internal_below`` is the fraction of the words
    strictly between them that score below 0.60, or 0.0 when there are none.
    An empty or missing ``scores`` yields ``(0.0, 0.0, 0.0)``.
    """
    if not scores:
        return 0.0, 0.0, 0.0
    boundary = [scores[0]] if len(scores) == 1 else [scores[0], scores[-1]]
    internal = scores[1:-1]
    internal_below = (
        sum(1 for s in internal if s < 0.60) / len(internal) if internal else 0.0
    )
    return min(scores), sum(boundary) / len(boundary), internal_below


def _single_validator_verdict(score: float) -> str:
    """Map a lone aligner score to a verdict (no "review" band is used here)."""
    if score >= 0.75:
        return "accept"
    if score >= 0.55:
        return "retry"
    return "reject"


def validate_transcription(
    audio_path: str,
    transcription: str,
    language: str = "te",
    romanized_text: str = "",
    check_audio: bool = True,
    duration_sec: float = 0.0,
) -> ValidationResult:
    """
    Validate a transcription (v4: uroman-based, stricter, structural checks).

    Flow:
      Step 0: Structural sanity checks on the raw text, tags included
              (instant, no ML)
      Step 1: Character validation on the tag/punctuation-stripped text
      Step 2: Native CTC forced alignment on the native-script-only text
      Step 3: Romanized MMS forced alignment (uroman-derived, NOT Gemini's)
      Step 4: Weighted combined score (see compute_combined_score)

    Key v4 change: romanized_text parameter is IGNORED for MMS alignment.
    Validator derives its own romanization via uroman from the native text.
    This makes validation deterministic and independent of Gemini's output.

    Args:
        audio_path: Path to audio file
        transcription: Native script text (may be code-mixed and contain tags)
        language: Language code (te, hi, ta, etc.)
        romanized_text: IGNORED in v4 (kept for API compat). Validator uses uroman.
        check_audio: Whether to run audio alignment checks. When False no
            aligner runs, and the verdict is "accept", or "review" if any
            warning reason was recorded.
        duration_sec: Audio duration (for structural checks, optional)

    Returns:
        ValidationResult. Structural and character failures return "reject"
        immediately. Exceptions raised by the CTC/MMS aligners are recorded
        in ``reasons`` rather than propagated. Per-word MMS metrics are
        reported but do not influence the verdict.
    """
    reasons = []

    # === STEP 0: Structural sanity checks (no ML, instant) ===
    # Run on the RAW transcription: the tag whitelist check and the
    # [NO_SPEECH]/[INAUDIBLE] pass-through need to see the tags, which
    # _strip_tags_and_punct removes.
    structural = structural_sanity_check(transcription, language, duration_sec)
    clean_text = _strip_tags_and_punct(transcription) if structural["pass"] else ""
    if structural["pass"] and not clean_text:
        # Tag/marker-only text passes Step 0 but leaves nothing to check or align.
        structural = {"pass": False, "reasons": ["empty_text"]}
    if not structural["pass"]:
        return ValidationResult(
            status="reject",
            structural_pass=False,
            structural_reasons=structural["reasons"],
            reasons=[f"structural_fail: {', '.join(structural['reasons'])}"]
        )

    # v5: transcription is code-mixed, so strip to native-only for CTC
    native_only = _strip_to_native_only(transcription, language)

    # === STEP 1: Character validation (instant) ===
    # v5: transcription is code-mixed, always allow English chars
    char_check = check_characters(clean_text, language, allow_english=True)

    if not char_check["valid"]:
        return ValidationResult(
            status="reject",
            char_valid=False,
            invalid_chars=char_check["invalid_chars"],
            script_ratio=char_check["script_ratio"],
            reasons=["Invalid/alien characters found"]
        )

    if char_check["script_ratio"] < 0.5:
        reasons.append(f"Too few native chars ({char_check['script_ratio']:.0%})")

    # === STEP 1.5: Derive uroman romanization (deterministic) ===
    # This replaces Gemini's creative romanization with stable, reproducible text
    from src.romanization import romanize_for_alignment
    uroman_roman = romanize_for_alignment(transcription)

    # === STEP 2: Native CTC alignment ===
    native_ctc_score = 0.0
    low_conf_words = []
    low_conf_ratio = 0.0

    if check_audio and native_only:
        try:
            aligner = _get_ctc_aligner(language)
            ctc_result = aligner.align(audio_path, native_only)
            native_ctc_score = ctc_result.alignment_score
            low_conf_words = ctc_result.low_confidence_words
            low_conf_ratio = ctc_result.low_confidence_ratio
        except Exception as e:
            reasons.append(f"CTC alignment failed: {str(e)[:50]}")

    # === STEP 3: Romanized MMS alignment (uroman-derived) ===
    roman_mms_score = 0.0
    mms_min_word = 0.0
    mms_boundary_avg = 0.0
    mms_internal_below = 0.0

    if check_audio and uroman_roman:
        try:
            mms = _get_mms_aligner()
            mms_result = mms.align(audio_path, uroman_roman)
            roman_mms_score = mms_result.alignment_score
            # v4: per-word metrics are reported on the result (not used for the verdict)
            mms_min_word, mms_boundary_avg, mms_internal_below = (
                _summarize_mms_word_scores(getattr(mms_result, "word_scores", None))
            )
        except Exception as e:
            reasons.append(f"MMS alignment failed: {str(e)[:50]}")

    # === STEP 4: Combined scoring ===
    if check_audio and native_ctc_score > 0 and roman_mms_score > 0:
        combined_score, verdict = compute_combined_score(
            native_ctc_score, roman_mms_score
        )
        if verdict == "review" and abs(native_ctc_score - roman_mms_score) > 0.25:
            reasons.append(
                f"Validator disagreement: CTC={native_ctc_score:.2f} "
                f"MMS={roman_mms_score:.2f}"
            )
    elif check_audio and native_ctc_score > 0:
        combined_score = native_ctc_score
        verdict = _single_validator_verdict(native_ctc_score)
        if roman_mms_score == 0 and uroman_roman:
            reasons.append("MMS failed, using CTC only")
    elif check_audio and roman_mms_score > 0:
        combined_score = roman_mms_score
        verdict = _single_validator_verdict(roman_mms_score)
        reasons.append("CTC failed, using MMS only")
    elif check_audio:
        combined_score = 0.0
        verdict = "review"
        reasons.append("Both validators returned 0")
    else:
        combined_score = 0.0
        verdict = "review" if reasons else "accept"

    # Append score-based reasons
    if verdict == "retry":
        reasons.append(
            f"Below threshold: combined={combined_score:.2f} "
            f"(CTC={native_ctc_score:.2f}, MMS={roman_mms_score:.2f})"
        )
    elif verdict == "reject" and combined_score > 0:
        reasons.append(
            f"Poor alignment: combined={combined_score:.2f} "
            f"(CTC={native_ctc_score:.2f}, MMS={roman_mms_score:.2f})"
        )

    return ValidationResult(
        status=verdict,
        char_valid=char_check["valid"],
        invalid_chars=char_check.get("invalid_chars"),
        script_ratio=char_check["script_ratio"],
        alignment_score=native_ctc_score,
        native_ctc_score=native_ctc_score,
        roman_mms_score=roman_mms_score,
        combined_score=combined_score,
        mms_min_word_score=mms_min_word,
        mms_boundary_word_avg=mms_boundary_avg,
        mms_internal_below_threshold=mms_internal_below,
        low_confidence_words=low_conf_words,
        low_confidence_ratio=low_conf_ratio,
        structural_pass=True,
        uroman_romanized=uroman_roman,
        reasons=reasons if reasons else None,
    )


def cleanup():
    """Release all aligner resources (CTC + MMS + uroman).

    The module-level aligner slots are cleared before their ``cleanup()``
    methods are called. A failing release therefore never leaves a
    half-released aligner cached for reuse. Each later release still runs
    when an earlier one raises (try/finally). The first exception then
    propagates to the caller.
    """
    global _ctc_aligner, _mms_aligner
    ctc, _ctc_aligner = _ctc_aligner, None
    mms, _mms_aligner = _mms_aligner, None
    try:
        if ctc is not None:
            ctc.cleanup()
    finally:
        try:
            if mms is not None:
                mms.cleanup()
        finally:
            # Also cleanup uroman cache
            from src.romanization import cleanup as roman_cleanup
            roman_cleanup()


# === Quick validation (no audio check) ===
def quick_validate(transcription: str, language: str = "te") -> Dict:
    """Quick character-only + structural validation for native transcription.

    First runs structural_sanity_check on the raw text. Character validation
    (English letters NOT allowed) then runs on the text with ``[word]`` tags
    and punctuation removed. The tags were already validated structurally;
    left in place, their brackets and Latin letters would be reported as
    invalid characters.

    Returns:
        {"valid": True, "script_ratio": float} on success, otherwise
        {"valid": False, "reason": str}.
    """
    if not transcription or not transcription.strip():
        return {"valid": False, "reason": "Empty transcription"}

    # Structural check
    structural = structural_sanity_check(transcription, language)
    if not structural["pass"]:
        return {"valid": False, "reason": f"Structural: {structural['reasons']}"}

    body = _strip_tags_and_punct(transcription)
    if not body:
        # Tag/marker-only text (e.g. "[NO_SPEECH]") has no speech text to check.
        return {"valid": False, "reason": "Empty transcription"}

    char_check = check_characters(body, language, allow_english=False)

    if not char_check["valid"]:
        invalid = [c['char'] for c in char_check['invalid_chars'][:5]]
        return {"valid": False, "reason": f"Invalid chars: {invalid}"}

    if char_check["script_ratio"] < 0.5:
        return {"valid": False, "reason": f"Too few native chars ({char_check['script_ratio']:.0%})"}

    return {"valid": True, "script_ratio": char_check["script_ratio"]}


if __name__ == "__main__":
    # Command-line entry point:
    #   <audio_path> <transcription...>  -> full validation (Telugu default)
    #   --quick <transcription...>       -> character + structural check only
    import sys

    cli_args = sys.argv[1:]
    if len(cli_args) < 2:
        print("""
Simple Validator v4
====================

Usage:
    # Full validation (with audio check, uroman-based)
    python simple_validator.py <audio_path> <transcription>

    # Quick validation (character + structural check only)
    python simple_validator.py --quick <transcription>
""")
        sys.exit(1)

    first, remainder = cli_args[0], cli_args[1:]
    joined = " ".join(remainder)

    if first == "--quick":
        outcome = quick_validate(joined)
        print(f"Valid: {outcome['valid']}")
        if not outcome['valid']:
            print(f"Reason: {outcome['reason']}")
    else:
        print("Validating (v4: uroman-based)...")
        outcome = validate_transcription(first, joined)
        cleanup()

        print(f"\nStatus: {outcome.status.upper()}")
        print(f"Character valid: {outcome.char_valid}")
        print(f"Script ratio: {outcome.script_ratio:.1%}")
        print(f"CTC score: {outcome.native_ctc_score:.4f}")
        print(f"MMS score: {outcome.roman_mms_score:.4f}")
        print(f"Combined: {outcome.combined_score:.4f}")
        if outcome.uroman_romanized:
            print(f"Uroman: {outcome.uroman_romanized[:80]}")
        if outcome.reasons:
            print("\nReasons:")
            for reason in outcome.reasons:
                print(f"  - {reason}")
