"""Tests for the Tier 1 validator."""
import pytest
from src.validator import validate_transcription, validate_batch, ValidationResult


class TestEmptyAndNoSpeech:
    """Segments whose transcription carries no usable speech content."""

    SEGMENT_ID = "seg1"

    def test_empty_transcription(self):
        # Blank text: flagged empty, scored zero, excluded from the ASR lane.
        payload = dict(transcription="", tagged="")
        outcome = validate_transcription(self.SEGMENT_ID, payload, "te", 5.0)
        assert outcome.is_empty
        assert outcome.quality_score == 0.0
        assert not outcome.asr_eligible

    def test_no_speech(self):
        marker = "[NO_SPEECH]"
        payload = dict(transcription=marker, tagged=marker, detected_language="te")
        outcome = validate_transcription(self.SEGMENT_ID, payload, "te", 5.0)
        assert outcome.is_no_speech
        assert not outcome.asr_eligible

    def test_none_transcription(self):
        # A payload with no keys at all must be treated the same as empty text.
        outcome = validate_transcription(self.SEGMENT_ID, dict(), "te", 5.0)
        assert outcome.is_empty
        assert result.is_empty


class TestLengthRatio:
    """Transcription length judged against the segment's duration."""

    @staticmethod
    def _run(text, duration):
        # Transcription and tagged text are identical so only length matters.
        payload = dict(transcription=text, tagged=text, detected_language="en")
        return validate_transcription("seg1", payload, "en", duration)

    def test_normal_length(self):
        outcome = self._run(
            "hello world this is a test transcription with some words", 5.0
        )
        assert outcome.length_ratio_ok

    def test_suspiciously_long(self):
        # 500 characters for a 2-second clip cannot plausibly be real speech.
        outcome = self._run("a" * 500, 2.0)
        assert not outcome.length_ratio_ok


class TestLanguageMismatch:
    """Detected language compared with the language the segment is expected to be in."""

    @staticmethod
    def _run(detected):
        payload = dict(transcription="test", tagged="test", detected_language=detected)
        return validate_transcription("seg1", payload, "te", 5.0)

    def test_matching_language(self):
        assert not self._run("te").lang_mismatch

    def test_mismatched_language(self):
        outcome = self._run("hi")
        assert outcome.lang_mismatch
        # The mismatch must also surface as a human-readable flag.
        flagged = [flag for flag in outcome.flags if "lang_mismatch" in flag]
        assert any(flagged)


class TestTagConsistency:
    """The tagged text should be the plain transcription plus event tags, nothing more."""

    @staticmethod
    def _run(plain, tagged):
        payload = dict(transcription=plain, tagged=tagged, detected_language="en")
        return validate_transcription("seg1", payload, "en", 5.0)

    def test_consistent_tags(self):
        # Only an event tag was added, so the words still line up.
        assert self._run("hello world", "[laugh] hello world").tag_consistency_ok

    def test_inconsistent_tags(self):
        assert not self._run("hello world", "completely different text").tag_consistency_ok


class TestUNKDensity:
    """Counting of [UNK] and inaudible markers in the transcription."""

    @staticmethod
    def _run(text):
        payload = dict(transcription=text, tagged=text, detected_language="en")
        return validate_transcription("seg1", payload, "en", 5.0)

    def test_no_unk(self):
        outcome = self._run("clean text here")
        assert outcome.num_unk == 0
        assert outcome.num_inaudible == 0

    def test_high_unk(self):
        # Four of the five tokens are [UNK].
        outcome = self._run("[UNK] [UNK] [UNK] word [UNK]")
        assert outcome.num_unk == 4


class TestEventTagCounting:
    """Non-speech event tags are counted from the tagged text."""

    def test_counts_event_tags(self):
        # Three event tags ([laugh], [noise], [cough]) around two words.
        payload = dict(
            transcription="hello world",
            tagged="[laugh] hello [noise] world [cough]",
            detected_language="en",
        )
        outcome = validate_transcription("seg1", payload, "en", 5.0)
        assert outcome.num_event_tags == 3


class TestQualityScore:
    """Aggregate quality score reacts to combined problems."""

    def test_perfect_quality(self):
        payload = dict(transcription="clean text", tagged="clean text", detected_language="en")
        outcome = validate_transcription("seg1", payload, "en", 5.0)
        assert outcome.quality_score >= 0.9

    def test_degraded_quality_with_issues(self):
        # Stacks three problems: [UNK] tokens, tagged text that disagrees,
        # and a detected language ("hi") that differs from the expected "en".
        payload = dict(
            transcription="[UNK] [UNK] x",
            tagged="different text",
            detected_language="hi",
        )
        outcome = validate_transcription("seg1", payload, "en", 5.0)
        assert outcome.quality_score < 0.8


class TestLaneFlags:
    """Eligibility flags for the downstream ASR and clean-TTS lanes."""

    @staticmethod
    def _run(text):
        payload = dict(transcription=text, tagged=text, detected_language="en")
        return validate_transcription("seg1", payload, "en", 5.0)

    def test_asr_eligible(self):
        assert self._run("good text").asr_eligible

    def test_tts_clean_eligible(self):
        assert self._run("good text here").tts_clean_eligible

    def test_tts_not_eligible_with_unk(self):
        # A single [UNK] is enough to disqualify a segment from clean TTS.
        assert not self._run("[UNK] text").tts_clean_eligible


class TestBoundaryScore:
    """Boundary score penalties driven by the trim info's abrupt_start/abrupt_end flags."""

    def test_abrupt_start_penalty(self):
        data = {"transcription": "text", "tagged": "text", "detected_language": "en"}
        trim = {"abrupt_start": True, "abrupt_end": False}
        result = validate_transcription("seg1", data, "en", 5.0, trim)
        # Any abrupt edge should pull the score below 1.0.
        assert result.boundary_score < 1.0

    def test_both_abrupt(self):
        data = {"transcription": "text", "tagged": "text", "detected_language": "en"}
        trim = {"abrupt_start": True, "abrupt_end": True}
        result = validate_transcription("seg1", data, "en", 5.0, trim)
        # Expect at most 0.6. pytest.approx only absorbs float rounding error.
        # A hand-widened bound such as 0.61 would also accept a real
        # regression, for example a score of 0.605.
        assert result.boundary_score <= 0.6 or result.boundary_score == pytest.approx(0.6)


class TestBatchValidation:
    """validate_batch processes several responses with per-segment durations."""

    def test_validate_batch(self):
        good = dict(transcription="hello", tagged="hello", detected_language="en")
        empty = dict(transcription="", tagged="", detected_language="en")
        responses = [
            dict(segment_id="s1", transcription_data=good),
            dict(segment_id="s2", transcription_data=empty),
        ]
        durations = dict(s1=5.0, s2=3.0)

        results = validate_batch(responses, "en", durations, {})

        assert len(results) == 2
        # Results are positional: the real transcription must outscore the blank one.
        first, second = results[0], results[1]
        assert first.quality_score > second.quality_score
