"""
Tier 1 programmatic validation: instant, free, runs on every segment.
Computes quality_score (0-1) and lane flags.
Optional Tier 1.5 GPU validation is flag-gated.
"""
from __future__ import annotations

import logging
import re
import unicodedata
from dataclasses import dataclass, field
from typing import Optional

from .config import LANGUAGE_MAP, AUDIO_EVENT_TAGS, VALIDATOR_VERSION

logger = logging.getLogger(__name__)

# Script name -> inclusive (first, last) code point of its Unicode block, used
# by the script check in validate_transcription.
SCRIPT_RANGES = {
    "Devanagari": (0x0900, 0x097F),
    "Telugu": (0x0C00, 0x0C7F),
    "Tamil": (0x0B80, 0x0BFF),
    "Kannada": (0x0C80, 0x0CFF),
    "Malayalam": (0x0D00, 0x0D7F),
    "Gujarati": (0x0A80, 0x0AFF),
    "Gurmukhi": (0x0A00, 0x0A7F),
    "Bengali": (0x0980, 0x09FF),
    "Assamese": (0x0980, 0x09FF),  # shares Bengali block
    "Odia": (0x0B00, 0x0B7F),
    # Basic Latin (ASCII). Listed for completeness; the script check skips it.
    "Latin": (0x0000, 0x007F),
}

# Any configured audio-event tag wrapped in square brackets, case-insensitive.
# Tag names are escaped so regex metacharacters in a tag match literally
# instead of silently changing what the pattern accepts.
TAG_PATTERN = re.compile(
    r"\[(" + "|".join(re.escape(tag) for tag in AUDIO_EVENT_TAGS) + r")\]",
    re.IGNORECASE,
)
UNK_PATTERN = re.compile(r"\[UNK\]", re.IGNORECASE)
INAUDIBLE_PATTERN = re.compile(r"\[INAUDIBLE\]", re.IGNORECASE)
NO_SPEECH_PATTERN = re.compile(r"\[NO_SPEECH\]", re.IGNORECASE)


@dataclass
class ValidationResult:
    """Tier 1 validation outcome for a single transcribed segment.

    The defaults describe a segment that passed every check; the validator
    overwrites individual fields as checks fail.
    """

    segment_id: str
    # Composite score in [0, 1]: 1.0 minus the accumulated check penalties.
    quality_score: float = 1.0
    # Transcription was missing or whitespace-only.
    is_empty: bool = False
    # Transcription contained a [NO_SPEECH] marker.
    is_no_speech: bool = False
    # Length of the marker-free text divided by audio duration (0.0 when the
    # duration is not positive).
    chars_per_second: float = 0.0
    length_ratio_ok: bool = True
    # False when too little of the text is in the expected language's script.
    script_check_ok: bool = True
    # Detected language differs from the expected one.
    lang_mismatch: bool = False
    # The tagged text with event tags removed matches the plain transcription.
    tag_consistency_ok: bool = True
    num_unk: int = 0
    num_inaudible: int = 0
    num_event_tags: int = 0
    # NOTE(review): not set by validate_transcription; presumably filled by a
    # later validation stage — confirm.
    overlap_suspected: bool = False
    # Starts at 1.0 and is lowered for abrupt segment boundaries.
    boundary_score: float = 1.0
    # Lane eligibility flags.
    asr_eligible: bool = True
    tts_clean_eligible: bool = True
    tts_expressive_eligible: bool = True
    # Machine-readable reasons for failed checks, e.g. "tag_text_mismatch".
    flags: list[str] = field(default_factory=list)
    validator_version: str = VALIDATOR_VERSION

    def __post_init__(self):
        # Callers may still pass flags=None explicitly; normalise to a fresh list.
        if self.flags is None:
            self.flags = []


def validate_transcription(
    segment_id: str,
    transcription_data: dict,
    expected_language: str,
    audio_duration_s: float,
    trim_meta: Optional[dict] = None,
) -> ValidationResult:
    """Run Tier 1 programmatic checks on a single transcription result.

    Args:
        segment_id: Identifier stored on the returned result.
        transcription_data: Model output. The keys ``"transcription"``,
            ``"tagged"`` and ``"detected_language"`` are read; missing keys
            or ``None`` values are treated as empty strings.
        expected_language: Language code looked up in ``LANGUAGE_MAP`` for the
            script check and compared with ``"detected_language"``.
        audio_duration_s: Segment duration in seconds. The chars-per-second
            check is skipped when this is not positive.
        trim_meta: Optional trim metadata. Truthy ``"abrupt_start"`` /
            ``"abrupt_end"`` entries each lower ``boundary_score`` by 0.2.

    Returns:
        A ``ValidationResult`` holding the individual check outcomes, a
        composite ``quality_score`` in [0, 1], and the ASR / TTS lane flags.
        Empty transcriptions score 0.0 and ``[NO_SPEECH]`` segments 0.5; both
        are ineligible for every lane.
    """
    result = ValidationResult(segment_id=segment_id)

    data = transcription_data or {}
    # ``or ""`` also covers keys present with a None value, which would
    # otherwise crash the regex calls below.
    transcription = data.get("transcription") or ""
    tagged = data.get("tagged") or ""
    detected_lang = data.get("detected_language") or ""

    # 1. Empty / NO_SPEECH check
    if not transcription.strip():
        result.is_empty = True
        result.quality_score = 0.0
        result.asr_eligible = False
        result.tts_clean_eligible = False
        result.tts_expressive_eligible = False
        result.flags.append("empty_transcription")
        return result

    if NO_SPEECH_PATTERN.search(transcription):
        result.is_no_speech = True
        result.quality_score = 0.5  # valid but no speech
        result.asr_eligible = False
        result.tts_clean_eligible = False
        result.tts_expressive_eligible = False
        return result

    # 2. Length ratio (chars per second), measured on the text with event
    # tags, [UNK] and [INAUDIBLE] markers removed.
    clean_text = TAG_PATTERN.sub("", transcription).strip()
    clean_text = UNK_PATTERN.sub("", clean_text).strip()
    clean_text = INAUDIBLE_PATTERN.sub("", clean_text).strip()

    if audio_duration_s > 0:
        result.chars_per_second = len(clean_text) / audio_duration_s
        # Typical Indic speech is roughly 2-30 chars/sec; the bounds below are
        # looser (1-50) so only clear outliers are flagged.
        if result.chars_per_second < 1.0 or result.chars_per_second > 50.0:
            result.length_ratio_ok = False
            result.flags.append(f"suspicious_length_ratio:{result.chars_per_second:.1f}")

    # 3. Script check (skipped for unknown languages and Latin-script ones)
    if expected_language in LANGUAGE_MAP:
        _, script_name, _ = LANGUAGE_MAP[expected_language]
        if script_name in SCRIPT_RANGES and script_name != "Latin":
            lo, hi = SCRIPT_RANGES[script_name]
            # Every code point in the script's block counts (vowel signs and
            # other combining marks included); only ASCII letters count as Latin.
            script_chars = sum(1 for c in clean_text if lo <= ord(c) <= hi)
            latin_chars = sum(1 for c in clean_text if unicodedata.category(c).startswith("L") and ord(c) < 0x0080)
            total_alpha = script_chars + latin_chars
            if total_alpha > 0:
                expected_ratio = script_chars / total_alpha
                # Code-mixed is fine, but if <10% expected script and not English, flag
                if expected_ratio < 0.1 and expected_language != "en":
                    result.script_check_ok = False
                    result.flags.append(f"low_expected_script_ratio:{expected_ratio:.2f}")

    # 4. Language mismatch
    if detected_lang and detected_lang != expected_language:
        result.lang_mismatch = True
        result.flags.append(f"lang_mismatch:expected={expected_language},detected={detected_lang}")

    # 5. Tag consistency: the tagged text minus its event tags must equal the
    # plain transcription, ignoring all whitespace (spaces, tabs, newlines) —
    # removing a tag typically leaves a doubled space behind.
    stripped_tagged = TAG_PATTERN.sub("", tagged)
    if "".join(stripped_tagged.split()) != "".join(transcription.split()):
        result.tag_consistency_ok = False
        result.flags.append("tag_text_mismatch")

    # 6. UNK / INAUDIBLE density
    result.num_unk = len(UNK_PATTERN.findall(transcription))
    result.num_inaudible = len(INAUDIBLE_PATTERN.findall(transcription))
    result.num_event_tags = len(TAG_PATTERN.findall(tagged))

    num_markers = result.num_unk + result.num_inaudible
    if num_markers:
        total_words = len(clean_text.split())
        # Markers with no intelligible words left means the whole segment is
        # unknown; without this the segment would go unflagged.
        unk_density = num_markers / total_words if total_words else 1.0
        if unk_density > 0.3:
            result.flags.append(f"high_unk_density:{unk_density:.2f}")

    # 7. Boundary score from trim metadata: -0.2 per abrupt edge.
    if trim_meta:
        if trim_meta.get("abrupt_start", False):
            result.boundary_score -= 0.2
        if trim_meta.get("abrupt_end", False):
            result.boundary_score -= 0.2
        result.boundary_score = max(0.0, result.boundary_score)

    # 8. Composite quality score
    penalties = 0.0
    if not result.length_ratio_ok:
        penalties += 0.15
    if not result.script_check_ok:
        penalties += 0.2
    if result.lang_mismatch:
        penalties += 0.1
    if not result.tag_consistency_ok:
        penalties += 0.1
    if result.num_unk > 2:
        penalties += min(0.2, result.num_unk * 0.05)
    if result.boundary_score < 1.0:
        penalties += (1.0 - result.boundary_score) * 0.15

    # Round away float noise: e.g. 0.2 + 0.1 == 0.30000000000000004, which would
    # leave a score meant to be exactly 0.7 (or 0.5) just below a lane threshold.
    result.quality_score = round(max(0.0, 1.0 - penalties), 6)

    # 9. Lane flags
    result.asr_eligible = result.quality_score >= 0.5 and not result.is_empty
    result.tts_clean_eligible = (
        result.quality_score >= 0.7
        and result.boundary_score >= 0.8
        and not result.lang_mismatch
        and result.num_unk == 0
        and result.num_inaudible == 0
    )
    result.tts_expressive_eligible = (
        result.tts_clean_eligible
        and result.num_event_tags <= 2
        and result.quality_score >= 0.8
    )

    return result


def validate_batch(
    responses: list[dict],
    expected_language: str,
    audio_durations: dict[str, float],
    trim_metas: dict[str, dict],
    default_duration_s: float = 5.0,
) -> list[ValidationResult]:
    """Validate a batch of transcription responses.

    Args:
        responses: Items carrying ``"segment_id"`` and ``"transcription_data"``
            keys; a missing id becomes ``"unknown"`` and missing data ``{}``.
        expected_language: Language code applied to every segment.
        audio_durations: Segment id -> duration in seconds.
        trim_metas: Segment id -> trim metadata; segments may be absent.
        default_duration_s: Duration used, with a logged warning, when a
            segment has no (or a ``None``) entry in ``audio_durations``.

    Returns:
        One ``ValidationResult`` per response, in input order.
    """
    durations = audio_durations or {}
    trims = trim_metas or {}
    results = []
    for resp in responses:
        seg_id = resp.get("segment_id", "unknown")
        data = resp.get("transcription_data") or {}
        duration = durations.get(seg_id)
        if duration is None:
            # A guessed duration skews chars_per_second, so make it visible.
            logger.warning(
                "No audio duration for segment %s; assuming %.1fs",
                seg_id,
                default_duration_s,
            )
            duration = default_duration_s
        results.append(
            validate_transcription(seg_id, data, expected_language, duration, trims.get(seg_id))
        )
    return results
