"""
Utility functions for the extended Indic tokenizer.

The extended tokenizer uses BPE merges that compose byte-fallback tokens
(<0xNN>) into Indic subwords. Encoding needs no special handling, but
decoding does: HF's ByteFallback decoder only converts individual byte
tokens back to text, so merged multi-byte tokens surface in the decoded
string as literal <0xNN> patterns that must be post-processed into
Unicode.
"""

import json
import re
from pathlib import Path

from transformers import PreTrainedTokenizerFast

BYTE_PATTERN = re.compile(r"<0x([0-9A-Fa-f]{2})>")


def decode_byte_patterns(text: str) -> str:
    """Convert contiguous <0xNN> byte patterns in text to Unicode.

    Example: '<0xE0><0xA4><0xA8>' → 'न' (Devanagari NA)

    This is needed because the extended tokenizer's BPE merges create
    tokens whose names contain <0xNN> patterns. The standard HF decoder
    only handles individual <0xNN> tokens, not merged ones.
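
    Doctest (the function is pure, so this runs without a tokenizer):

    >>> decode_byte_patterns("<0xE0><0xA4><0xA8>म")
    'नम'
    >>> decode_byte_patterns("plain text")
    'plain text'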
    """
    result = []
    pos = 0
    while pos < len(text):
        m = BYTE_PATTERN.match(text, pos)
        if m:
            # Greedily consume the whole run of contiguous <0xNN> patterns:
            # a multi-byte UTF-8 character must be decoded from its full byte
            # run, so a per-match re.sub would mangle it.
            byte_vals = []
            while m:
                byte_vals.append(int(m.group(1), 16))
                pos = m.end()
                m = BYTE_PATTERN.match(text, pos)
            # Malformed byte runs become U+FFFD instead of raising.
            result.append(bytes(byte_vals).decode("utf-8", errors="replace"))
        else:
            result.append(text[pos])
            pos += 1
    return "".join(result)


def decode_tokens(tokenizer, ids, skip_special_tokens: bool = True) -> str:
    """Decode token IDs to Unicode text with byte-pattern fix.

    Use this instead of tokenizer.decode() for the extended Indic tokenizer.
    """
    raw = tokenizer.decode(ids, skip_special_tokens=skip_special_tokens)
    return decode_byte_patterns(raw)
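
# Usage sketch for decode_tokens (illustrative; the directory path is a
# placeholder for wherever the extended tokenizer lives):
#
#     tok = load_extended_tokenizer("path/to/extended-tokenizer")
#     ids = tok("नमस्ते", add_special_tokens=False)["input_ids"]
#     decode_tokens(tok, ids)  # -> 'नमस्ते'
#     tok.decode(ids)          # may leave literal '<0xE0><0xA4>...' runs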


def load_extended_tokenizer(model_path: str) -> PreTrainedTokenizerFast:
    """Load the extended fast tokenizer without splitting control tokens.

    `from_pretrained()` currently honors `split_special_tokens=true` from the
    saved config, which breaks decoder prompts like
    `<|startofcontext|>...<|nodiarize|>` into dozens of subword pieces.
    For training we need those prompt/control tokens to stay atomic.
    """
    model_dir = Path(model_path)
    tokenizer_file = model_dir / "tokenizer.json"
    special_tokens_file = model_dir / "special_tokens_map.json"

    if not tokenizer_file.exists():
        raise FileNotFoundError(f"Missing tokenizer.json in {model_dir}")
    if not special_tokens_file.exists():
        raise FileNotFoundError(f"Missing special_tokens_map.json in {model_dir}")

    special = json.loads(special_tokens_file.read_text(encoding="utf-8"))

    def _content(value):
        # special_tokens_map.json stores each token either as a plain string
        # or as a serialized AddedToken dict; normalize to the string form.
        return value.get("content") if isinstance(value, dict) else value

    tokenizer = PreTrainedTokenizerFast(
        tokenizer_file=str(tokenizer_file),
        bos_token=_content(special.get("bos_token")),
        eos_token=_content(special.get("eos_token")),
        pad_token=_content(special.get("pad_token")),
        unk_token=_content(special.get("unk_token")),
        additional_special_tokens=[
            _content(t) for t in special.get("additional_special_tokens", [])
        ],
        split_special_tokens=False,
    )

    # Preserve the configured context window if present.
    tokenizer_config = model_dir / "tokenizer_config.json"
    if tokenizer_config.exists():
        cfg = json.loads(tokenizer_config.read_text(encoding="utf-8"))
        tokenizer.model_max_length = cfg.get("model_max_length", tokenizer.model_max_length)

    return tokenizer
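

if __name__ == "__main__":
    # Minimal smoke test; the directory path is a placeholder for wherever
    # the extended tokenizer is checked out.
    tok = load_extended_tokenizer("models/extended-indic-tokenizer")

    # Control tokens should stay atomic (one ID each), not split into
    # subword pieces.
    prompt = "<|startofcontext|>नमस्ते<|nodiarize|>"
    ids = tok(prompt, add_special_tokens=False)["input_ids"]
    print("token count:", len(ids))

    # Round-tripping through decode_tokens should recover the prompt,
    # including the Devanagari text.
    print("decoded:", decode_tokens(tok, ids, skip_special_tokens=False))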
