#!/usr/bin/env python3
"""
Merge Indic tokens into the BPE tokenizer — FINAL version.

Strategy: extend the BPE model directly. Each new token is decomposed into
pieces that already exist in the original vocabulary (or byte-fallback
tokens). A left-to-right chain of merge rules is then appended, together with
vocab entries for every intermediate merge result, so that the BPE model
composes those pieces into the new token when encoding.

main() writes the extended tokenizer.json and the copied configs to
OUTPUT_DIR. It then runs a round-trip test that compares token counts against
the original tokenizer and checks whether decoding reproduces the input text.
"""

import json
import os
import re
import time
import logging

import numpy as np

logging.basicConfig(level=logging.INFO, format="%(asctime)s | %(levelname)s | %(message)s")
logger = logging.getLogger(__name__)

# Input/output locations. The defaults are the original /workspace paths.
# Each one can be overridden through an environment variable, so the script
# can run on another machine or directory layout without editing the source.
ORIG_TOKENIZER = os.environ.get(
    "INDIC_MERGE_ORIG_TOKENIZER",
    "/workspace/.hf_home/hub/models--CohereLabs--cohere-transcribe-03-2026/snapshots/90cf6a1e8427d6ab5e0060f53c095c245a20da4e/tokenizer.json",
)
NEW_TOKENS_JSON = os.environ.get(
    "INDIC_MERGE_NEW_TOKENS_JSON",
    "/workspace/training/tokenizer_extension/new_indic_tokens.json",
)
OUTPUT_DIR = os.environ.get(
    "INDIC_MERGE_OUTPUT_DIR",
    "/workspace/training/tokenizer_extension/merged_48k",
)

# Matches one byte-fallback token name such as "<0xE0>"; group 1 holds the
# two hex digits (either case).
BYTE_RE = re.compile(r"<0x([0-9A-Fa-f]{2})>")


def byte_concat_to_text(token_name):
    """Render a BPE token name containing ``<0xNN>`` byte pieces as Unicode text.

    Consecutive ``<0xNN>`` pieces are gathered and decoded together as UTF-8.
    Every other character is copied through unchanged. A byte run that is not
    valid UTF-8 (for example, a truncated multi-byte character) is decoded
    with U+FFFD replacement characters.

    Example: '▁<0xE0><0xA4><0xA8>' → '▁न'  (▁ + Devanagari NA)
    """
    out = []
    pending = bytearray()

    def flush():
        # errors="replace" yields exactly the strict result for valid UTF-8,
        # so one call covers both the complete and the partial-sequence case.
        if pending:
            out.append(bytes(pending).decode("utf-8", errors="replace"))
            pending.clear()

    idx = 0
    total = len(token_name)
    while idx < total:
        hit = BYTE_RE.match(token_name, idx)
        if hit is None:
            flush()
            out.append(token_name[idx])
            idx += 1
        else:
            pending.append(int(hit.group(1), 16))
            idx = hit.end()

    flush()
    return "".join(out)


def text_to_byte_tokens(text):
    """Spell *text* as byte-fallback token names, one ``<0xNN>`` per UTF-8 byte.

    Hex digits are upper-case, e.g. ``"न"`` → ``["<0xE0>", "<0xA4>", "<0xA8>"]``.
    """
    encoded = text.encode("utf-8")
    return ["<0x%02X>" % byte for byte in encoded]


def find_existing_pieces(token, orig_vocab):
    """Split *token* greedily, left to right, into pieces found in *orig_vocab*.

    At each position the longest prefix (at most 40 characters) that is a key
    of *orig_vocab* is taken. When no prefix matches, the current character is
    emitted as its UTF-8 bytes in ``<0xNN>`` form, and scanning resumes at the
    next character.

    Returns the list of piece strings; an empty token yields ``[]``.
    """
    out = []
    idx = 0
    end = len(token)
    while idx < end:
        window = min(end - idx, 40)
        match = next(
            (
                token[idx:idx + size]
                for size in range(window, 0, -1)
                if token[idx:idx + size] in orig_vocab
            ),
            None,
        )
        if match is None:
            out.extend(f"<0x{byte:02X}>" for byte in token[idx].encode("utf-8"))
            idx += 1
        else:
            out.append(match)
            idx += len(match)
    return out


def generate_merge_chain(pieces):
    """Return merge rules that fold *pieces* together from left to right.

    Each rule is a ``[left, right]`` list, where ``left`` is the concatenation
    of every piece merged so far. Fewer than two pieces need no merges, so an
    empty list is returned.

    Example: ``["a", "b", "c"]`` → ``[["a", "b"], ["ab", "c"]]``.
    """
    if len(pieces) < 2:
        return []
    seq = list(pieces)
    rules = []
    prefix = seq[0]
    for piece in seq[1:]:
        rules.append([prefix, piece])
        prefix = prefix + piece
    return rules


def _load_json(path):
    """Return the parsed contents of the UTF-8 encoded JSON file at *path*."""
    with open(path, "r", encoding="utf-8") as f:
        return json.load(f)


def _write_json(path, obj, **dump_kwargs):
    """Write *obj* to *path* as UTF-8 JSON, keeping non-ASCII text verbatim."""
    with open(path, "w", encoding="utf-8") as f:
        json.dump(obj, f, ensure_ascii=False, **dump_kwargs)


def _merge_to_pair(merge):
    """Normalise one tokenizer.json merge rule to a ``(left, right)`` tuple.

    Merges may be stored as ``[left, right]`` lists or as legacy
    ``"left right"`` strings; both forms are accepted.
    """
    if isinstance(merge, str):
        left, right = merge.split(" ", 1)
    else:
        left, right = merge
    return (left, right)


def _next_free_id(tok_json, vocab):
    """Return one past the largest id used by the model vocab or added_tokens.

    Added (special) tokens may carry ids that are not in ``model.vocab``.
    Allocating new ids from the vocab alone could collide with them.
    """
    used = list(vocab.values())
    used.extend(entry["id"] for entry in tok_json.get("added_tokens") or [])
    return max(used, default=-1) + 1


def _extend_vocab_and_merges(new_tokens, vocab, merge_pairs, next_id):
    """Plan the vocab entries and merge rules that let BPE build each new token.

    Returns ``(additions, new_pairs)``:
      * ``additions``: an insertion-ordered ``{token: id}`` dict of entries to
        add to the vocab, with ids starting at *next_id*.
      * ``new_pairs``: the new ``(left, right)`` merge rules, in the order they
        should be appended.

    Every character of the new tokens that the vocab lacks is registered
    first, as a base token. Without this step, ``find_existing_pieces`` spells
    such characters as ``<0xNN>`` byte pieces, and the merged results get names
    like ``<0xE0><0xA4><0xA8>``. The ``tokenizers`` ByteFallback decoder only
    converts tokens that are exactly one ``<0xNN>``, and Replace decoders
    rewrite each token in place rather than splitting it. Those merged tokens
    would therefore decode as literal ``<0x..>`` text instead of Indic script.
    """
    additions = {}

    def register(piece):
        nonlocal next_id
        if piece not in vocab and piece not in additions:
            additions[piece] = next_id
            next_id += 1

    for char in dict.fromkeys(c for token in new_tokens for c in token):
        register(char)

    # Decompose against the original vocab plus the newly registered
    # characters. Intermediate merge results stay invisible here, as in the
    # earlier version of this script.
    known = set(vocab)
    known.update(additions)

    seen_pairs = set(merge_pairs)
    new_pairs = []
    for token in new_tokens:
        pieces = find_existing_pieces(token, known)
        if len(pieces) <= 1:
            register(token)
            continue
        for left, right in generate_merge_chain(pieces):
            register(left + right)
            pair = (left, right)
            if pair not in seen_pairs:
                seen_pairs.add(pair)
                new_pairs.append(pair)
    return additions, new_pairs


def _serialize_merges(new_pairs, legacy):
    """Encode merge pairs in the same format as the original merges list.

    Raises:
        ValueError: in legacy ``"left right"`` format, if a token contains a
            space and therefore cannot be represented.
    """
    if not legacy:
        return [[left, right] for left, right in new_pairs]
    for left, right in new_pairs:
        if " " in left or " " in right:
            raise ValueError(
                f"Cannot store merge ({left!r}, {right!r}) in legacy "
                "'left right' string format: a token contains a space"
            )
    return [f"{left} {right}" for left, right in new_pairs]


def _copy_configs(orig_dir, out_dir):
    """Copy special_tokens_map.json (if present) and a patched tokenizer_config.json.

    The copied config has ``auto_map`` removed and ``tokenizer_class`` set to
    ``PreTrainedTokenizerFast``.
    """
    import shutil

    special = os.path.join(orig_dir, "special_tokens_map.json")
    if os.path.exists(special):
        shutil.copy2(special, os.path.join(out_dir, "special_tokens_map.json"))

    cfg = _load_json(os.path.join(orig_dir, "tokenizer_config.json"))
    cfg.pop("auto_map", None)
    cfg["tokenizer_class"] = "PreTrainedTokenizerFast"
    _write_json(os.path.join(out_dir, "tokenizer_config.json"), cfg, indent=2)


def _round_trip_report(orig_path, new_path):
    """Print token counts and decode round-trip status per sample language.

    Returns True when every sample decodes back to its input text.
    """
    from transformers import PreTrainedTokenizerFast

    orig_tok = PreTrainedTokenizerFast(tokenizer_file=orig_path)
    new_tok = PreTrainedTokenizerFast(tokenizer_file=new_path)

    tests = {
        "hi": "नमस्ते दुनिया यह एक परीक्षण है",
        "te": "హలో ప్రపంచం ఇది ఒక పరీక్ష",
        "ta": "வணக்கம் உலகம் இது ஒரு சோதனை",
        "ml": "ഹലോ ലോകം ഇത് ഒരു പരീക്ഷ",
        "bn": "হ্যালো বিশ্ব এটি একটি পরীক্ষা",
        "gu": "હેલો વિશ્વ આ એક પરીક્ષા છે",
        "kn": "ಹಲೋ ವಿಶ್ವ ಇದು ಒಂದು ಪರೀಕ್ಷೆ",
        "pa": "ਹੈਲੋ ਵਿਸ਼ਵ ਇਹ ਇੱਕ ਟੈਸਟ ਹੈ",
        "mr": "हेलो विश्व हा एक चाचणी आहे",
        "or": "ହେଲୋ ବିଶ୍ୱ ଏହା ଏକ ପରୀକ୍ଷା",
        "as": "হেলো বিশ্ব এইটো এটা পৰীক্ষা",
        "en": "hello world this is a test",
        "hi-en": "आज का weather बहुत अच्छा है",
    }

    print(f"\n{'Lang':<8} {'New tok':>8} {'Old tok':>8} {'Red%':>7} {'Round-trip':>12}")
    print("=" * 55)
    all_pass = True
    for lang, text in tests.items():
        new_ids = new_tok.encode(text)
        old_ids = orig_tok.encode(text)
        decoded = new_tok.decode(new_ids, skip_special_tokens=True).strip()
        match = decoded == text
        all_pass = all_pass and match
        # Guard against an empty baseline encoding (avoids ZeroDivisionError).
        red = (1 - len(new_ids) / len(old_ids)) * 100 if old_ids else 0.0
        status = "PASS" if match else "FAIL"
        print(f"{lang:<8} {len(new_ids):>8} {len(old_ids):>8} {red:>6.1f}% {status:>12}")
        if not match:
            print(f"  Expected: {text}")
            print(f"  Got:      {decoded}")
    print("=" * 55)
    print(f"Overall: {'ALL PASS' if all_pass else 'SOME FAILURES'}")
    return all_pass


def main():
    """Build the extended tokenizer in OUTPUT_DIR and run a round-trip check.

    Steps:
      1. Load ORIG_TOKENIZER and the NEW_TOKENS_JSON token list.
      2. Add vocab entries and left-to-right BPE merge chains for each new
         token.
      3. Write tokenizer.json, the copied configs and vocab_info.json.
      4. Compare encode lengths and decode round-trips against the original.
    """
    t0 = time.time()
    os.makedirs(OUTPUT_DIR, exist_ok=True)

    # ── Load ────────────────────────────────────────────────────────────
    logger.info("Loading original tokenizer.json...")
    tok_json = _load_json(ORIG_TOKENIZER)
    vocab = tok_json["model"]["vocab"]
    orig_merges = tok_json["model"]["merges"]
    # Captured before `vocab` is extended in place below.
    original_vocab_size = len(vocab)
    logger.info(f"Original vocab: {original_vocab_size}, merges: {len(orig_merges)}")

    # Deduplicate (order-preserving) and drop empty or already-known tokens.
    # An empty string would otherwise be inserted into the vocab as a token.
    new_tokens = [t for t in dict.fromkeys(_load_json(NEW_TOKENS_JSON)) if t and t not in vocab]
    logger.info(f"New Indic tokens: {len(new_tokens)}")

    # ── Build vocab entries + merges ────────────────────────────────────
    legacy_merges = bool(orig_merges) and isinstance(orig_merges[0], str)
    merge_pairs = [_merge_to_pair(m) for m in orig_merges]
    additions, new_pairs = _extend_vocab_and_merges(
        new_tokens, vocab, merge_pairs, _next_free_id(tok_json, vocab)
    )
    vocab.update(additions)

    final_vocab_size = len(vocab)
    logger.info(f"New merges: {len(new_pairs)}, new vocab entries: {len(additions)}")
    logger.info(f"Final vocab size: {final_vocab_size}")

    # ── Update tokenizer.json ───────────────────────────────────────────
    tok_json["model"]["vocab"] = vocab
    # Keep the original merge serialization; the format must not be mixed
    # within one list.
    tok_json["model"]["merges"] = orig_merges + _serialize_merges(new_pairs, legacy_merges)
    # tok_json["decoder"] is intentionally left unchanged. Every new token is
    # now a plain Unicode string (no merged <0xNN> runs), which any decoder
    # passes through. Single byte tokens for other characters decode exactly
    # as they did with the original tokenizer.

    # ── Save ────────────────────────────────────────────────────────────
    output_path = os.path.join(OUTPUT_DIR, "tokenizer.json")
    logger.info(f"Writing to {output_path}...")
    _write_json(output_path, tok_json)

    _copy_configs(os.path.dirname(ORIG_TOKENIZER), OUTPUT_DIR)

    _write_json(
        os.path.join(OUTPUT_DIR, "vocab_info.json"),
        {
            "original_vocab_size": original_vocab_size,
            "new_tokens_added": len(additions),
            "final_vocab_size": final_vocab_size,
            "new_merges_added": len(new_pairs),
        },
        indent=2,
    )

    elapsed = time.time() - t0
    logger.info(f"Built in {elapsed:.1f}s")

    # ── Round-trip test ─────────────────────────────────────────────────
    logger.info("\nROUND-TRIP DECODE TEST")
    _round_trip_report(ORIG_TOKENIZER, output_path)


# Run the build and round-trip check only when executed as a script, not on import.
if __name__ == "__main__":
    main()
