import pytest

from veena3modal.processing.prompt_builder import IndicPromptBuilder


class _DummyTokenizer:
    def encode(self, text: str, add_special_tokens: bool = False):
        # Deterministic placeholder output; enough to validate wiring.
        return [1, 2, 3]


def test_build_prefix_format_and_speaker_mapping():
    """The prefix carries the task header, the raw text, the mapped speaker
    token, and ends at the global-token marker."""
    builder = IndicPromptBuilder(tokenizer=_DummyTokenizer())
    result = builder.build_prefix("lipakshi", "Hello [giggle]")

    expected_header = "<|task_controllable_tts|><|start_content|>"
    assert result.startswith(expected_header)
    assert "Hello [giggle]" in result
    # Speaker name "lipakshi" is expected to map onto slot 0.
    assert "<|speaker_0|>" in result
    assert result.endswith("<|start_global_token|>")


def test_build_prefix_normalizes_legacy_angle_emotions():
    """Legacy angle-bracket emotion tags are rewritten to square-bracket form."""
    builder = IndicPromptBuilder(tokenizer=_DummyTokenizer())
    result = builder.build_prefix("reet", "Hello <laugh>")
    assert "Hello [laughs]" in result


def test_build_prefix_rejects_invalid_speaker():
    """An unknown speaker name must raise ValueError rather than default."""
    builder = IndicPromptBuilder(tokenizer=_DummyTokenizer())
    with pytest.raises(ValueError):
        builder.build_prefix("not_a_speaker", "Hello")


def test_build_prefix_with_globals_injects_exact_32_globals_and_start_semantic_token():
    """With a full set of 32 global ids the prompt embeds every bicodec token
    and terminates at the semantic-token marker."""
    builder = IndicPromptBuilder(tokenizer=_DummyTokenizer())
    ids = list(range(32))

    result = builder.build_prefix_with_globals("lipakshi", "Hello", ids)
    assert "<|start_global_token|>" in result
    # Every one of the 32 ids must appear as its own bicodec global token.
    assert all(f"<|bicodec_global_{gid}|>" in result for gid in ids)
    assert result.endswith("<|start_semantic_token|>")


def test_build_prefix_with_globals_requires_exactly_32_tokens():
    """Fewer than 32 global ids is a contract violation -> ValueError."""
    builder = IndicPromptBuilder(tokenizer=_DummyTokenizer())
    too_few = [0, 1, 2]
    with pytest.raises(ValueError):
        builder.build_prefix_with_globals("lipakshi", "Hello", too_few)


def test_build_prefix_ids_uses_tokenizer_encode():
    """build_prefix_ids must delegate to the injected tokenizer's encode().

    The dummy tokenizer always yields [1, 2, 3], so seeing that exact list
    back proves the wiring without depending on real tokenization.
    """
    builder = IndicPromptBuilder(tokenizer=_DummyTokenizer())
    token_ids = builder.build_prefix_ids("lipakshi", "Hello", validate=False)
    assert token_ids == [1, 2, 3]


