"""
BiCodec Token Utilities - Incremental Parsing

Provides O(1) per-token parsing instead of O(n) decode + regex scanning.

The key optimization: instead of:
    tokenizer.decode(all_tokens)  # O(n)
    re.findall(entire_text)       # O(n)

We do:
    for new_token_id in new_tokens:  # Only new tokens
        parsed = token_cache.get(token_id)  # O(1) lookup

This eliminates the O(n²) CPU burn in the streaming hot loop.
"""

from typing import Dict, Optional, Tuple, List
import re


class BiCodecTokenParser:
    """
    Caching parser for BiCodec tokens.

    Converts vLLM token IDs to (type, value) tuples efficiently.

    BiCodec tokens in the vocabulary look like:
        - <|bicodec_semantic_123|> -> ("semantic", 123)
        - <|bicodec_global_7|>     -> ("global", 7)

    Usage:
        parser = BiCodecTokenParser(tokenizer)

        # In hot loop (O(1) per token):
        for token_id in new_token_ids:
            result = parser.parse(token_id)
            if result:
                token_type, value = result
                if token_type == "semantic":
                    semantic_buffer.append(value)
                elif token_type == "global":
                    global_buffer.append(value)
    """

    # Anchored regex patterns: the whole token string must match, so partial
    # matches inside longer decoded text are never accepted.
    SEMANTIC_PATTERN = re.compile(r"^<\|bicodec_semantic_(\d+)\|>$")
    GLOBAL_PATTERN = re.compile(r"^<\|bicodec_global_(\d+)\|>$")

    # Sentinel distinguishing "token_id never seen" from a cached negative
    # result (None means "seen, not a BiCodec token").
    _MISSING = object()

    def __init__(self, tokenizer) -> None:
        """
        Initialize parser with tokenizer.

        Args:
            tokenizer: HuggingFace tokenizer with BiCodec special tokens
        """
        self.tokenizer = tokenizer
        # Cache: token_id -> Optional[Tuple[str, int]]
        # None means "not a BiCodec token" (cached negative)
        self._cache: Dict[int, Optional[Tuple[str, int]]] = {}

        # Pre-warm cache with known BiCodec token ranges if possible
        # This is optional but can speed up first-pass parsing
        self._prewarm_cache()

    def _prewarm_cache(self) -> None:
        """
        Pre-populate cache with BiCodec tokens from vocabulary.

        This avoids lazy cache misses during streaming. Best-effort: any
        failure (e.g. tokenizer without get_vocab) silently falls back to
        lazy per-token caching in parse().
        """
        try:
            vocab = self.tokenizer.get_vocab()
            for token_str, token_id in vocab.items():
                # Cheap prefix test prefilters the vocabulary before the
                # (more expensive) regex match.
                if token_str.startswith("<|bicodec_"):
                    parsed = self._parse_token_string(token_str)
                    if parsed:
                        self._cache[token_id] = parsed
        except Exception:
            # If vocab access fails, we'll rely on lazy caching
            pass

    def _parse_token_string(self, token_str: str) -> Optional[Tuple[str, int]]:
        """
        Parse a token string into (type, value) tuple.

        Args:
            token_str: Token string like "<|bicodec_semantic_123|>"

        Returns:
            ("semantic", 123) or ("global", 7) or None if not BiCodec token
        """
        # Try semantic pattern
        match = self.SEMANTIC_PATTERN.match(token_str)
        if match:
            return ("semantic", int(match.group(1)))

        # Try global pattern
        match = self.GLOBAL_PATTERN.match(token_str)
        if match:
            return ("global", int(match.group(1)))

        return None

    def parse(self, token_id: int) -> Optional[Tuple[str, int]]:
        """
        Parse a single token ID into (type, value) tuple.

        O(1) for cached tokens, O(1) amortized for uncached.

        Args:
            token_id: vLLM token ID

        Returns:
            ("semantic", value) or ("global", value) or None
        """
        # Single dict lookup (instead of `in` test + indexing); the sentinel
        # lets cached negatives (None) hit the fast path too.
        cached = self._cache.get(token_id, self._MISSING)
        if cached is not self._MISSING:
            return cached

        # Cache miss - decode and parse
        try:
            token_str = self.tokenizer.decode([token_id], skip_special_tokens=False)
            # strip() guards against tokenizers that pad decoded specials
            # with whitespace.
            parsed = self._parse_token_string(token_str.strip())
            self._cache[token_id] = parsed
            return parsed
        except Exception:
            # Mark as non-BiCodec to avoid repeated failures
            self._cache[token_id] = None
            return None

    def parse_incremental(
        self,
        new_token_ids: List[int],
        semantic_buffer: List[int],
        global_buffer: List[int],
    ) -> Tuple[List[int], List[int]]:
        """
        Parse new tokens and append to existing buffers.

        This is the main method for incremental streaming:
        - Only processes NEW tokens (not entire history)
        - Modifies buffers in-place for efficiency
        - Returns the updated buffers (same objects)

        Args:
            new_token_ids: Only the new token IDs since last call
            semantic_buffer: Existing semantic token values (modified in-place)
            global_buffer: Existing global token values (modified in-place)

        Returns:
            Tuple of (semantic_buffer, global_buffer) - same objects, modified
        """
        # Hoist attribute/bound-method lookups out of the hot loop.
        parse = self.parse
        append_semantic = semantic_buffer.append
        append_global = global_buffer.append

        for token_id in new_token_ids:
            parsed = parse(token_id)
            if parsed:
                token_type, value = parsed
                if token_type == "semantic":
                    append_semantic(value)
                elif token_type == "global":
                    append_global(value)

        return semantic_buffer, global_buffer

    def get_cache_stats(self) -> Dict[str, int]:
        """
        Return cache statistics for debugging.

        Returns:
            Dict with total cache size plus positive/negative breakdown
        """
        total = len(self._cache)
        bicodec_count = sum(1 for v in self._cache.values() if v is not None)
        return {
            "total_cached": total,
            "bicodec_tokens": bicodec_count,
            "non_bicodec_tokens": total - bicodec_count,
        }


def create_bicodec_parser(tokenizer) -> BiCodecTokenParser:
    """
    Build a ready-to-use BiCodec token parser.

    Thin convenience wrapper around the BiCodecTokenParser constructor, kept
    as a factory so call sites need not import the class directly.

    Args:
        tokenizer: HuggingFace tokenizer with BiCodec special tokens

    Returns:
        A BiCodecTokenParser wrapping the given tokenizer
    """
    parser = BiCodecTokenParser(tokenizer)
    return parser

