"""
Simple Transcription Validator
==============================

Three checks:
1. Character validation - catch garbage/alien characters
2. Native CTC alignment - verify native script matches audio
3. Romanized MMS alignment - verify romanized text matches audio

Combined scoring: S = 0.45*N + 0.55*R - 0.10*abs(N-R)
  N = native CTC score, R = romanized MMS score
  Accept if S >= 0.70 and min(N,R) >= 0.40, but downgrade to review when
  abs(N-R) > 0.25 (high validator disagreement)
  Otherwise retry if S >= 0.55, else reject

Usage:
    from src.validators.simple_validator import validate_transcription
    
    result = validate_transcription(
        "audio.flac", "నాకు కొన్ని యాడ్స్",
        romanized_text="naaku konni ads", language="te"
    )
    print(result.status)          # accept / review / retry / reject
    print(result.combined_score)  # 0-1
"""
import re
from dataclasses import dataclass
from typing import List, Dict, Optional, Set


# Unicode ranges for Indic scripts (all 12 supported languages)
SCRIPT_RANGES = {
    "te": {  # Telugu
        "name": "Telugu",
        "ranges": [(0x0C00, 0x0C7F)],
        "allow_ascii_punct": True,
        "allow_digits": True,
    },
    "hi": {  # Hindi (Devanagari)
        "name": "Hindi",
        "ranges": [(0x0900, 0x097F)],
        "allow_ascii_punct": True,
        "allow_digits": True,
    },
    "mr": {  # Marathi (Devanagari, same block as Hindi)
        "name": "Marathi",
        "ranges": [(0x0900, 0x097F)],
        "allow_ascii_punct": True,
        "allow_digits": True,
    },
    "ta": {  # Tamil
        "name": "Tamil",
        "ranges": [(0x0B80, 0x0BFF)],
        "allow_ascii_punct": True,
        "allow_digits": True,
    },
    "kn": {  # Kannada
        "name": "Kannada",
        "ranges": [(0x0C80, 0x0CFF)],
        "allow_ascii_punct": True,
        "allow_digits": True,
    },
    "ml": {  # Malayalam
        "name": "Malayalam",
        "ranges": [(0x0D00, 0x0D7F)],
        "allow_ascii_punct": True,
        "allow_digits": True,
    },
    "bn": {  # Bengali
        "name": "Bengali",
        "ranges": [(0x0980, 0x09FF)],
        "allow_ascii_punct": True,
        "allow_digits": True,
    },
    "as": {  # Assamese (same Unicode block as Bengali, slight char differences)
        "name": "Assamese",
        "ranges": [(0x0980, 0x09FF)],
        "allow_ascii_punct": True,
        "allow_digits": True,
    },
    "gu": {  # Gujarati
        "name": "Gujarati",
        "ranges": [(0x0A80, 0x0AFF)],
        "allow_ascii_punct": True,
        "allow_digits": True,
    },
    "pa": {  # Punjabi (Gurmukhi)
        "name": "Punjabi",
        "ranges": [(0x0A00, 0x0A7F)],
        "allow_ascii_punct": True,
        "allow_digits": True,
    },
    "or": {  # Odia
        "name": "Odia",
        "ranges": [(0x0B00, 0x0B7F)],
        "allow_ascii_punct": True,
        "allow_digits": True,
    },
    "en": {  # English (Latin)
        "name": "English",
        "ranges": [(0x0041, 0x005A), (0x0061, 0x007A)],  # A-Z, a-z
        "allow_ascii_punct": True,
        "allow_digits": True,
    },
}

# Common allowed characters across all languages
ALLOWED_COMMON = set(" \t\n.,!?;:'\"()-–—0123456789")

# English characters (for code-mixed text)
ENGLISH_CHARS = set("abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ")


def get_valid_chars(language: str, allow_english: bool = True) -> Set[int]:
    """Return the code points considered valid for ``language``.

    The set is the union of the shared characters (``ALLOWED_COMMON``), every
    code point in the language's script range(s), and, when
    ``allow_english`` is True, the ASCII Latin letters. Unknown language
    codes fall back to the Telugu configuration.
    """
    script = SCRIPT_RANGES.get(language, SCRIPT_RANGES["te"])

    codepoints = {ord(ch) for ch in ALLOWED_COMMON}
    for first, last in script["ranges"]:
        codepoints.update(range(first, last + 1))

    if allow_english:
        codepoints.update(map(ord, ENGLISH_CHARS))

    return codepoints


def check_characters(
    text: str,
    language: str = "te",
    allow_english: bool = False  # native transcription: English letters are invalid
) -> Dict:
    """
    Check that ``text`` contains only characters valid for ``language``.

    Valid characters are those in the language's script range(s), the shared
    whitespace/punctuation/digit set, and (only when ``allow_english`` is
    True) ASCII Latin letters. Unknown language codes fall back to Telugu.
    For native transcription keep allow_english=False (the default).

    Returns:
        {
            "valid": True/False,
            "invalid_chars": [...],  # first 10 {"char", "codepoint", "position"}
            "invalid_count": 3,      # total number of invalid characters
            "script_ratio": 0.8,     # native-script chars / (native + ASCII Latin letters)
            "reason": "empty",       # only present for empty/blank input
        }
    """
    if not text or not text.strip():
        return {
            "valid": False,
            "invalid_chars": [],
            "invalid_count": 0,
            "script_ratio": 0,
            "reason": "empty",
        }

    valid_chars = get_valid_chars(language, allow_english)
    native_ranges = SCRIPT_RANGES.get(language, SCRIPT_RANGES["te"])["ranges"]

    invalid_chars = []
    script_count = 0
    total_alpha = 0

    for position, char in enumerate(text):
        cp = ord(char)

        if any(start <= cp <= end for start, end in native_ranges):
            script_count += 1
            total_alpha += 1
            continue

        # English letters count toward the alphabetic total whether or not they
        # are allowed. Otherwise, with allow_english=True, code-mixed text would
        # report a 100% native ratio no matter how much English it contains.
        if char in ENGLISH_CHARS:
            total_alpha += 1

        # Anything outside the allowed set (English too, if disallowed) is invalid.
        if cp not in valid_chars:
            invalid_chars.append({
                "char": char,
                "codepoint": cp,
                "position": position,
            })

    return {
        "valid": len(invalid_chars) == 0,
        "invalid_chars": invalid_chars[:10],  # First 10 only
        "invalid_count": len(invalid_chars),
        "script_ratio": script_count / total_alpha if total_alpha > 0 else 0,
    }


@dataclass
class ValidationResult:
    """Result of transcription validation with dual scoring.

    List-valued fields default to ``None`` rather than a shared mutable list;
    ``to_dict()`` renders a ``None`` list as ``[]``.
    """
    status: str  # "accept", "review", "retry", "reject"

    # Character check
    char_valid: bool = True
    invalid_chars: Optional[List[Dict]] = None
    script_ratio: float = 0.0

    # Native CTC alignment (language-specific wav2vec2)
    alignment_score: float = 0.0       # kept for backward compat
    native_ctc_score: float = 0.0
    low_confidence_words: Optional[List[str]] = None
    low_confidence_ratio: float = 0.0

    # Romanized MMS alignment (language-agnostic torchaudio MMS_FA)
    roman_mms_score: float = 0.0

    # Combined weighted score: S = 0.45*N + 0.55*R - 0.10*abs(N-R)
    combined_score: float = 0.0

    # Human-readable reasons behind the status
    reasons: Optional[List[str]] = None

    def to_dict(self) -> Dict:
        """Return a plain dict of the result.

        Ratios are rounded to 3 decimals, scores to 4, and ``None`` list
        fields become empty lists.
        """
        return {
            "status": self.status,
            "char_valid": self.char_valid,
            "invalid_chars": self.invalid_chars or [],
            "script_ratio": round(self.script_ratio, 3),
            "native_ctc_score": round(self.native_ctc_score, 4),
            "roman_mms_score": round(self.roman_mms_score, 4),
            "combined_score": round(self.combined_score, 4),
            "alignment_score": round(self.alignment_score, 4),
            "low_confidence_words": self.low_confidence_words or [],
            "low_confidence_ratio": round(self.low_confidence_ratio, 3),
            "reasons": self.reasons or []
        }


# Global aligners (lazy loaded)
_ctc_aligner = None
_mms_aligner = None
# Language the cached CTC aligner was built for. The native CTC model is
# language-specific, so a request for a different language must reload it.
_ctc_aligner_language: Optional[str] = None

def _get_ctc_aligner(language: str):
    """Return the native CTC aligner for ``language``, loading it lazily.

    A single aligner is cached at module level. If the cached one was built
    for another language, it is released via its ``cleanup()`` and replaced.
    Previously the first language's model was silently reused for every
    later language.
    """
    global _ctc_aligner, _ctc_aligner_language
    if _ctc_aligner is not None and _ctc_aligner_language != language:
        _ctc_aligner.cleanup()
        _ctc_aligner = None
    if _ctc_aligner is None:
        from .ctc_forced_aligner import CTCForcedAligner
        _ctc_aligner = CTCForcedAligner(language=language)
        _ctc_aligner_language = language
    return _ctc_aligner

def _get_mms_aligner():
    """Return the shared romanized MMS aligner, constructing it on first use."""
    global _mms_aligner
    if _mms_aligner is not None:
        return _mms_aligner
    from .mms_aligner import MMSAligner
    _mms_aligner = MMSAligner()
    return _mms_aligner


def compute_combined_score(
    native_ctc: float,
    roman_mms: float,
) -> tuple:
    """
    Blend the two alignment scores into one verdict.

    Formula: S = 0.45*N + 0.55*R - 0.10*abs(N-R)
    The romanized MMS score carries more weight (0.55) because it is more
    stable on code-mixed Indic audio. The abs(N-R) term penalizes
    disagreement between the two validators.

    Verdicts:
        accept - S >= 0.70 and min(N, R) >= 0.40, with abs(N-R) <= 0.25
        review - same pass condition but abs(N-R) > 0.25
        retry  - otherwise, if S >= 0.55
        reject - everything else

    Returns:
        (combined_score rounded to 4 decimals, verdict)
    """
    gap = abs(native_ctc - roman_mms)
    score = 0.45 * native_ctc + 0.55 * roman_mms - 0.10 * gap
    rounded = round(score, 4)

    if score >= 0.70 and min(native_ctc, roman_mms) >= 0.40:
        # Passing score, but strongly diverging validators are suspicious.
        return rounded, ("review" if gap > 0.25 else "accept")
    if score >= 0.55:
        return rounded, "retry"
    return rounded, "reject"


def _strip_punctuation(text: str) -> str:
    """Remove punctuation before character checks and alignment.

    Deletes ASCII punctuation, en/em dashes, and the Indic danda (।) and
    double danda (॥). Letters, digits and whitespace are kept. Removing a
    stand-alone mark (e.g. "a , b") leaves a gap, so whitespace runs collapse
    to a single space and the result is trimmed. Downstream word lists then
    never contain empty tokens.
    """
    without_punct = re.sub(r'[।॥,.!?;:\'"()\-–—]', '', text)
    return " ".join(without_punct.split())


def _single_validator_verdict(score: float) -> str:
    """Map one validator's score to a verdict when the other produced none.

    Thresholds: >= 0.70 accept, >= 0.50 retry, otherwise reject.
    """
    if score >= 0.70:
        return "accept"
    if score >= 0.50:
        return "retry"
    return "reject"


def validate_transcription(
    audio_path: str,
    transcription: str,
    language: str = "te",
    romanized_text: str = "",
    check_audio: bool = True,
) -> ValidationResult:
    """
    Validate a transcription using dual scoring (native CTC + romanized MMS).

    Runs:
    1. Character validation (instant) on native script
    2. Native CTC forced alignment (language-specific wav2vec2)
    3. Romanized MMS forced alignment (language-agnostic, if romanized_text provided)
    4. Weighted combined score: S = 0.45*N + 0.55*R - 0.10*abs(N-R)

    If only one validator yields a positive score, that score alone decides
    (>= 0.70 accept, >= 0.50 retry, else reject). If neither does, the
    verdict is "review". Aligner exceptions are caught and recorded in
    ``reasons`` instead of being raised. Text with a native-script ratio
    below 0.5 is never auto-accepted; it goes to review.

    Args:
        audio_path: Path to audio file
        transcription: Native script text (may include punctuation)
        language: Language code (te, hi, ta, etc.)
        romanized_text: Latin-script romanization (for MMS alignment)
        check_audio: Whether to run audio alignment checks (slower)

    Returns:
        ValidationResult with dual scores and combined verdict
    """
    reasons = []

    # Strip punctuation so neither the character check nor the aligners see it
    clean_text = _strip_punctuation(transcription)
    clean_roman = _strip_punctuation(romanized_text) if romanized_text else ""

    # === STEP 1: Character validation (instant) ===
    # For English language, allow English chars in native check
    allow_eng = language == "en"
    char_check = check_characters(clean_text, language, allow_english=allow_eng)

    if not char_check["valid"]:
        # check_characters() marks empty input (e.g. a transcription made only
        # of punctuation) with reason="empty". Report that instead of blaming
        # alien characters.
        if char_check.get("reason") == "empty":
            reject_reason = "Empty transcription"
        else:
            reject_reason = "Invalid/alien characters found"
        return ValidationResult(
            status="reject",
            char_valid=False,
            invalid_chars=char_check["invalid_chars"],
            script_ratio=char_check["script_ratio"],
            reasons=[reject_reason]
        )

    low_script_ratio = char_check["script_ratio"] < 0.5
    if low_script_ratio:
        reasons.append(f"Too few native chars ({char_check['script_ratio']:.0%})")

    # === STEP 2: Native CTC alignment ===
    native_ctc_score = 0.0
    low_conf_words = []
    low_conf_ratio = 0.0

    if check_audio:
        try:
            aligner = _get_ctc_aligner(language)
            ctc_result = aligner.align(audio_path, clean_text)
            native_ctc_score = ctc_result.alignment_score
            low_conf_words = ctc_result.low_confidence_words
            low_conf_ratio = ctc_result.low_confidence_ratio
        except Exception as e:
            # Best effort: record the failure and fall back to the other validator
            reasons.append(f"CTC alignment failed: {str(e)[:50]}")

    # === STEP 3: Romanized MMS alignment ===
    roman_mms_score = 0.0

    if check_audio and clean_roman:
        try:
            mms = _get_mms_aligner()
            mms_result = mms.align(audio_path, clean_roman)
            roman_mms_score = mms_result.alignment_score
        except Exception as e:
            reasons.append(f"MMS alignment failed: {str(e)[:50]}")

    # === STEP 4: Combined scoring ===
    if check_audio and native_ctc_score > 0 and roman_mms_score > 0:
        # Both validators returned scores - use weighted formula
        combined_score, verdict = compute_combined_score(
            native_ctc_score, roman_mms_score
        )
        if verdict == "review" and abs(native_ctc_score - roman_mms_score) > 0.25:
            reasons.append(
                f"Validator disagreement: CTC={native_ctc_score:.2f} "
                f"MMS={roman_mms_score:.2f}"
            )
    elif check_audio and native_ctc_score > 0:
        # CTC only (MMS failed or no romanized text)
        combined_score = native_ctc_score
        verdict = _single_validator_verdict(native_ctc_score)
        if roman_mms_score == 0 and clean_roman:
            reasons.append("MMS failed, using CTC only")
    elif check_audio and roman_mms_score > 0:
        # MMS only (CTC failed)
        combined_score = roman_mms_score
        verdict = _single_validator_verdict(roman_mms_score)
        reasons.append("CTC failed, using MMS only")
    elif check_audio:
        # Both validators returned 0
        combined_score = 0.0
        verdict = "review"
        reasons.append("Both validators returned 0")
    else:
        # No audio check requested
        combined_score = 0.0
        verdict = "review" if reasons else "accept"

    # Text with hardly any native letters (e.g. digits only) should not be
    # auto-accepted on alignment scores alone. The no-audio path already
    # routes it to review, and the audio path now does the same.
    if verdict == "accept" and low_script_ratio:
        verdict = "review"

    # Append score-based reasons
    if verdict == "retry":
        reasons.append(
            f"Below threshold: combined={combined_score:.2f} "
            f"(CTC={native_ctc_score:.2f}, MMS={roman_mms_score:.2f})"
        )
    elif verdict == "reject" and combined_score > 0:
        reasons.append(
            f"Poor alignment: combined={combined_score:.2f} "
            f"(CTC={native_ctc_score:.2f}, MMS={roman_mms_score:.2f})"
        )

    return ValidationResult(
        status=verdict,
        char_valid=char_check["valid"],
        invalid_chars=char_check.get("invalid_chars"),
        script_ratio=char_check["script_ratio"],
        alignment_score=native_ctc_score,  # backward compat
        native_ctc_score=native_ctc_score,
        roman_mms_score=roman_mms_score,
        combined_score=combined_score,
        low_confidence_words=low_conf_words,
        low_confidence_ratio=low_conf_ratio,
        reasons=reasons if reasons else None,
    )


def cleanup():
    """Free both lazily loaded aligners (native CTC and romanized MMS).

    Each live aligner's own ``cleanup()`` runs before its module-level
    reference is cleared, so the next validation call loads a fresh one.
    """
    global _ctc_aligner, _mms_aligner

    ctc = _ctc_aligner
    if ctc is not None:
        ctc.cleanup()
        _ctc_aligner = None

    mms = _mms_aligner
    if mms is not None:
        mms.cleanup()
        _mms_aligner = None


# === Quick validation (no audio check) ===
def quick_validate(transcription: str, language: str = "te") -> Dict:
    """
    Character-only validation for a NATIVE transcription (no audio needed).

    The text fails when it is empty or blank, when it contains any character
    outside the target script and shared set (English letters included), or
    when fewer than half of its counted letters are native script.

    Args:
        transcription: Native script text to validate
        language: Language code (te, hi, ta, etc.)

    Returns:
        {"valid": False, "reason": "..."} on failure, or
        {"valid": True, "script_ratio": <float>} on success
    """
    if not transcription or not transcription.strip():
        return {"valid": False, "reason": "Empty transcription"}

    check = check_characters(transcription, language, allow_english=False)
    ratio = check["script_ratio"]

    if not check["valid"]:
        # Show at most five offending characters in the message
        shown = [entry['char'] for entry in check['invalid_chars'][:5]]
        return {"valid": False, "reason": f"Invalid chars: {shown}"}

    if ratio < 0.5:
        return {"valid": False, "reason": f"Too few native chars ({ratio:.0%})"}

    return {"valid": True, "script_ratio": ratio}


if __name__ == "__main__":
    # Command-line entry point: full validation (audio + text) or --quick
    # character-only validation.
    import sys
    
    if len(sys.argv) < 3:
        print("""
Simple Validator
================

Usage:
    # Full validation (with audio check)
    python simple_validator.py <audio_path> <transcription>
    
    # Quick validation (character check only)
    python simple_validator.py --quick <transcription>

Examples:
    python simple_validator.py audio.flac "నాకు కొన్ని యాడ్స్ గుర్తుంటాయి"
    python simple_validator.py --quick "నాకు కొన్ని యాడ్స్"
""")
        sys.exit(1)
    
    if sys.argv[1] == "--quick":
        text = " ".join(sys.argv[2:])
        result = quick_validate(text)
        print(f"Valid: {result['valid']}")
        if not result['valid']:
            print(f"Reason: {result['reason']}")
    else:
        audio_path = sys.argv[1]
        transcription = " ".join(sys.argv[2:])
        
        print("Validating...")
        try:
            result = validate_transcription(audio_path, transcription)
        finally:
            # Release aligner resources even if validation raises
            cleanup()
        
        print(f"\nStatus: {result.status.upper()}")
        print(f"Character valid: {result.char_valid}")
        print(f"Script ratio: {result.script_ratio:.1%}")
        print(f"Alignment score: {result.alignment_score:.4f}")
        # The verdict is derived from the combined score, so show it too
        print(f"Combined score: {result.combined_score:.4f}")
        print(f"Low conf words: {result.low_confidence_ratio:.1%}")
        
        if result.reasons:
            print("\nReasons for review:")
            for r in result.reasons:
                print(f"  - {r}")
        
        if result.low_confidence_words:
            print(f"\nLow confidence words: {result.low_confidence_words[:5]}")
