"""
Unit tests for BiCodec incremental token parsing.

Tests that the new O(1) incremental parser produces identical results
to the old O(n²) decode-all + regex-all approach.
"""

import pytest
from unittest.mock import MagicMock, patch
import re


class TestBiCodecTokenParser:
    """Incremental-parsing tests for BiCodecTokenParser."""

    @pytest.fixture
    def mock_tokenizer(self):
        """Build a MagicMock tokenizer exposing a BiCodec-style vocabulary.

        The vocabulary contains a handful of ordinary/special tokens plus
        the BiCodec global range (IDs 1000-1031) and semantic range
        (IDs 2000-3023).
        """
        vocab = {
            "<|im_start|>": 0,
            "<|im_end|>": 1,
            "hello": 100,
            "world": 101,
            "<|speaker_0|>": 200,
        }
        vocab.update({f"<|bicodec_global_{i}|>": 1000 + i for i in range(32)})
        vocab.update({f"<|bicodec_semantic_{i}|>": 2000 + i for i in range(1024)})

        # Inverse lookup backing the fake decode().
        id_to_token = {token_id: token for token, token_id in vocab.items()}

        def fake_decode(token_ids, skip_special_tokens=False):
            pieces = [id_to_token.get(tid, f"[UNK:{tid}]") for tid in token_ids]
            return "".join(pieces)

        tokenizer = MagicMock()
        tokenizer.get_vocab = lambda: vocab
        tokenizer.decode = fake_decode
        return tokenizer

    def test_parse_semantic_token(self, mock_tokenizer):
        """A semantic-range token ID maps to ("semantic", index)."""
        from veena3modal.core.token_utils import BiCodecTokenParser

        parser = BiCodecTokenParser(mock_tokenizer)

        # ID 2123 decodes to <|bicodec_semantic_123|>.
        assert parser.parse(2123) == ("semantic", 123)

    def test_parse_global_token(self, mock_tokenizer):
        """A global-range token ID maps to ("global", index)."""
        from veena3modal.core.token_utils import BiCodecTokenParser

        parser = BiCodecTokenParser(mock_tokenizer)

        # ID 1007 decodes to <|bicodec_global_7|>.
        assert parser.parse(1007) == ("global", 7)

    def test_parse_non_bicodec_token(self, mock_tokenizer):
        """IDs outside the BiCodec ranges parse to None."""
        from veena3modal.core.token_utils import BiCodecTokenParser

        parser = BiCodecTokenParser(mock_tokenizer)

        # ID 100 is the plain token "hello".
        assert parser.parse(100) is None

    def test_cache_hit(self, mock_tokenizer):
        """Repeated lookups of the same ID are cached and stay consistent."""
        from veena3modal.core.token_utils import BiCodecTokenParser

        parser = BiCodecTokenParser(mock_tokenizer)

        first = parser.parse(2050)   # populates (or hits prewarmed) cache
        second = parser.parse(2050)  # must come back from cache

        assert first == second
        assert first == ("semantic", 50)  # 2050 - 2000 = 50

    def test_incremental_parsing_matches_regex(self, mock_tokenizer):
        """
        CRITICAL TEST: the O(1) incremental parser must agree exactly with
        the old decode-all + regex baseline, so the optimization cannot
        change behavior.
        """
        from veena3modal.core.token_utils import BiCodecTokenParser

        parser = BiCodecTokenParser(mock_tokenizer)

        # Realistic stream: text prefix (<|im_start|> hello world
        # <|speaker_0|>), then 32 global tokens, then 100 semantic tokens.
        token_stream = (
            [0, 100, 101, 200]
            + list(range(1000, 1032))
            + list(range(2000, 2100))
        )

        # Baseline (old O(n^2) path): decode everything, regex the text.
        decoded = mock_tokenizer.decode(token_stream, skip_special_tokens=False)
        baseline_semantic = [
            int(m) for m in re.findall(r"<\|bicodec_semantic_(\d+)\|>", decoded)
        ]
        baseline_global = [
            int(m) for m in re.findall(r"<\|bicodec_global_(\d+)\|>", decoded)
        ]

        # Candidate (new path): one incremental pass over the IDs.
        inc_semantic, inc_global = [], []
        parser.parse_incremental(token_stream, inc_semantic, inc_global)

        assert inc_semantic == baseline_semantic, (
            f"Semantic mismatch: {inc_semantic} vs {baseline_semantic}"
        )
        assert inc_global == baseline_global, (
            f"Global mismatch: {inc_global} vs {baseline_global}"
        )

    def test_incremental_streaming_simulation(self, mock_tokenizer):
        """
        Drive the parser the way streaming would: vLLM hands back a growing
        token list, and only the unseen suffix is fed to parse_incremental.
        """
        from veena3modal.core.token_utils import BiCodecTokenParser

        parser = BiCodecTokenParser(mock_tokenizer)

        # 32 global tokens followed by 50 semantic tokens.
        full_stream = list(range(1000, 1032)) + list(range(2000, 2050))

        semantic_buffer, global_buffer = [], []
        seen = 0
        # Five polls, each exposing up to 20 more tokens of the stream.
        for step in range(5):
            visible = full_stream[: min((step + 1) * 20, len(full_stream))]
            parser.parse_incremental(visible[seen:], semantic_buffer, global_buffer)
            seen = len(visible)

        # Buffers must equal a one-shot parse of the whole stream.
        assert semantic_buffer == list(range(50))  # from IDs 2000-2049
        assert global_buffer == list(range(32))    # from IDs 1000-1031

    def test_empty_token_list(self, mock_tokenizer):
        """An empty chunk leaves both output buffers untouched."""
        from veena3modal.core.token_utils import BiCodecTokenParser

        parser = BiCodecTokenParser(mock_tokenizer)

        semantic_out, global_out = [], []
        parser.parse_incremental([], semantic_out, global_out)

        assert semantic_out == []
        assert global_out == []

    def test_cache_stats(self, mock_tokenizer):
        """get_cache_stats exposes the expected counters."""
        from veena3modal.core.token_utils import BiCodecTokenParser

        parser = BiCodecTokenParser(mock_tokenizer)

        # One semantic, one global, one plain token.
        for token_id in (2000, 1000, 100):
            parser.parse(token_id)

        stats = parser.get_cache_stats()

        for key in ("total_cached", "bicodec_tokens", "non_bicodec_tokens"):
            assert key in stats
        # Prewarm plus the three parses above guarantee cached entries.
        assert stats["total_cached"] >= 3


class TestRegexFallback:
    """Sanity checks that the legacy regex extraction still behaves."""

    def test_extract_bicodec_tokens_from_text(self):
        """The original regex patterns pull out IDs in stream order."""
        # Generated text containing both BiCodec token families.
        text = (
            "<|im_start|>system<|im_end|>"
            "<|bicodec_global_0|><|bicodec_global_1|><|bicodec_global_2|>"
            "<|bicodec_semantic_100|><|bicodec_semantic_101|><|bicodec_semantic_102|>"
        )

        # Same patterns the pre-optimization code used.
        semantic_ids = [
            int(num) for num in re.findall(r"<\|bicodec_semantic_(\d+)\|>", text)
        ]
        global_ids = [
            int(num) for num in re.findall(r"<\|bicodec_global_(\d+)\|>", text)
        ]

        assert semantic_ids == [100, 101, 102]
        assert global_ids == [0, 1, 2]


# Allow running this test module directly (outside a pytest invocation).
if __name__ == "__main__":
    pytest.main([__file__, "-v"])

