#!/usr/bin/env python3
"""
Properly merge 32K Indic tokens into the BPE tokenizer.json.

The HF tokenizer uses a BPE model with byte_fallback. To add Indic subwords:
1. Add each new token to the BPE vocab
2. Generate synthetic BPE merge rules that compose each token from existing pieces
3. Place these merges AFTER existing merges so they have lower priority
   (existing tokenization for English is preserved)

This produces a tokenizer that:
- Tokenizes English identically to before
- Tokenizes Indic text using the new subword tokens instead of byte fallback
"""

import json
import os
import copy
import time
import logging
import sentencepiece as spm

# Timestamped INFO-level progress logging for the whole script.
logging.basicConfig(level=logging.INFO, format="%(asctime)s | %(levelname)s | %(message)s")
logger = logging.getLogger(__name__)

# Original (English-centric) HF BPE tokenizer.json to be extended.
ORIG_TOKENIZER = "/workspace/.hf_home/hub/models--CohereLabs--cohere-transcribe-03-2026/snapshots/90cf6a1e8427d6ab5e0060f53c095c245a20da4e/tokenizer.json"
# SentencePiece model trained on Indic text.
# NOTE(review): not referenced anywhere in this script (nor are the
# `sentencepiece`/`copy` imports) — confirm before removing.
INDIC_SP_MODEL = "/workspace/training/tokenizer_extension/indic_32k.model"
# JSON list of the new Indic subword strings to graft into the vocab.
NEW_TOKENS_JSON = "/workspace/training/tokenizer_extension/new_indic_tokens.json"
# Output directory for the merged tokenizer artifacts.
OUTPUT_DIR = "/workspace/training/tokenizer_extension/merged_48k"


def byte_to_bpe_token(byte_val):
    """Return the byte-fallback vocab entry (e.g. ``<0x0A>``) for *byte_val*."""
    return "<0x%02X>" % byte_val


def text_to_bpe_bytes(text):
    """Map *text* to the sequence of BPE byte-fallback tokens for its UTF-8 bytes."""
    return [f"<0x{b:02X}>" for b in text.encode("utf-8")]


def find_existing_pieces(token, orig_vocab, max_piece_len=40):
    """Greedily decompose *token* into pieces already present in *orig_vocab*.

    Scans left to right; at each position takes the longest existing vocab
    entry (capped at *max_piece_len* characters) that prefixes the remaining
    text. Any character with no matching vocab piece falls back to its UTF-8
    byte-fallback tokens (``<0xNN>``), mirroring the BPE byte_fallback path.

    Args:
        token: New token string to decompose.
        orig_vocab: Existing vocab; only membership testing is used.
        max_piece_len: Longest candidate substring probed against the vocab
            (previously a hard-coded 40; kept as the default).

    Returns:
        List of piece strings that compose *token* (byte-fallback entries
        stand in for characters the vocab cannot cover directly).
    """
    remaining = token
    pieces = []

    while remaining:
        # Longest-match-first probe against the existing vocab.
        best_match = None
        for length in range(min(len(remaining), max_piece_len), 0, -1):
            candidate = remaining[:length]
            if candidate in orig_vocab:
                best_match = candidate
                break

        if best_match:
            pieces.append(best_match)
            remaining = remaining[len(best_match):]
        else:
            # No vocab piece covers the next character: emit its UTF-8
            # byte-fallback tokens instead and move past it.
            for b in remaining[0].encode("utf-8"):
                pieces.append(f"<0x{b:02X}>")
            remaining = remaining[1:]

    return pieces


def generate_merge_chain(pieces):
    """Build the left-to-right merge chain that assembles *pieces* into one token.

    Produces ``[(p0, p1), (p0+p1, p2), ...]``: each pair merges the
    accumulated prefix with the next piece, so applying the merges in order
    yields the full concatenation.

    Args:
        pieces: Sequence of piece strings; 0 or 1 pieces need no merges.

    Returns:
        List of (left, right) tuples in application order.
    """
    if len(pieces) <= 1:
        return []

    merges = []
    prefix = pieces[0]
    # Accumulate the growing prefix directly instead of re-splicing the piece
    # list each step (the previous `[merged] + current[2:]` form was O(n^2)).
    for piece in pieces[1:]:
        merges.append((prefix, piece))
        prefix += piece

    return merges


def main() -> None:
    """Merge the new Indic tokens into the BPE tokenizer, save, and benchmark.

    Pipeline: load the original tokenizer.json, decompose every new Indic
    token into existing vocab pieces, synthesize merge rules that are appended
    AFTER the original merges (lower priority, so English tokenization is
    preserved), write the merged artifacts to OUTPUT_DIR, then report
    old-vs-new token counts over per-language sample sentences.
    """
    t0 = time.time()
    os.makedirs(OUTPUT_DIR, exist_ok=True)

    # ── Load original tokenizer.json ────────────────────────────────────
    logger.info("Loading original tokenizer.json...")
    with open(ORIG_TOKENIZER, "r", encoding="utf-8") as f:
        tok_json = json.load(f)

    # NOTE: these alias the dicts/lists inside tok_json, so the in-place
    # mutations below update tok_json["model"] as well.
    orig_vocab = tok_json["model"]["vocab"]
    orig_merges = tok_json["model"]["merges"]
    logger.info(f"Original vocab: {len(orig_vocab)}, merges: {len(orig_merges)}")

    # ── Load new Indic tokens ───────────────────────────────────────────
    with open(NEW_TOKENS_JSON, "r", encoding="utf-8") as f:
        new_tokens = json.load(f)
    logger.info(f"New Indic tokens to add: {len(new_tokens)}")

    # ── Generate merges and add to vocab ────────────────────────────────
    new_merges = []
    added_count = 0
    skipped_count = 0
    intermediate_tokens = {}  # track intermediate merge results we also need in vocab

    # New ids are appended after the current highest vocab id.
    next_id = max(orig_vocab.values()) + 1

    for token in new_tokens:
        if token in orig_vocab:
            skipped_count += 1
            continue

        # Decompose token into existing pieces
        pieces = find_existing_pieces(token, orig_vocab)

        if len(pieces) <= 1:
            # Single piece — just add to vocab, no merge needed
            if token not in orig_vocab:
                orig_vocab[token] = next_id
                next_id += 1
                added_count += 1
            continue

        # Generate merge chain
        merge_chain = generate_merge_chain(pieces)

        # Add intermediate merged tokens to vocab if not present.
        # Every merge output — including the final full token — must itself
        # be a vocab entry for the BPE model to apply the merge.
        current = pieces[0]
        for left, right in merge_chain:
            merged = left + right
            if merged not in orig_vocab and merged not in intermediate_tokens:
                intermediate_tokens[merged] = next_id
                next_id += 1
            current = merged  # NOTE(review): `current` is never read afterwards — dead assignment

        # Add all merges (as [left, right] lists to match original format)
        for left, right in merge_chain:
            new_merges.append([left, right])

        added_count += 1

    # Add intermediate tokens to vocab
    for token, token_id in intermediate_tokens.items():
        if token not in orig_vocab:
            orig_vocab[token] = token_id

    # Deduplicate merges (preserve order)
    # Merges may be stored as strings "a b" or lists ["a", "b"]
    if orig_merges and isinstance(orig_merges[0], list):
        existing_merge_set = set(tuple(m) for m in orig_merges)
    else:
        existing_merge_set = set(orig_merges)
    unique_new_merges = []
    seen = set()
    for m in new_merges:
        m_key = m if isinstance(m, str) else tuple(m)
        if m_key not in existing_merge_set and m_key not in seen:
            unique_new_merges.append(m)
            seen.add(m_key)

    logger.info(f"Added {added_count} new tokens, skipped {skipped_count} duplicates")
    logger.info(f"Generated {len(unique_new_merges)} new unique merges")
    logger.info(f"Added {len(intermediate_tokens)} intermediate tokens")
    logger.info(f"Final vocab size: {len(orig_vocab)}")
    logger.info(f"Final merges: {len(orig_merges)} + {len(unique_new_merges)} = {len(orig_merges) + len(unique_new_merges)}")

    # ── Update tokenizer.json ───────────────────────────────────────────
    # New merges go AFTER the originals so they have lower priority and
    # cannot change existing (e.g. English) tokenization.
    tok_json["model"]["vocab"] = orig_vocab
    tok_json["model"]["merges"] = orig_merges + unique_new_merges

    # Update added_tokens list with new token entries
    # NOTE(review): `existing_added` is built but never used — confirm intent
    # (dead code?) before removing.
    existing_added = {t["id"]: t for t in tok_json.get("added_tokens", [])}
    # Don't add regular tokens to added_tokens — they're in the vocab

    # ── Save ────────────────────────────────────────────────────────────
    output_path = os.path.join(OUTPUT_DIR, "tokenizer.json")
    logger.info(f"Writing merged tokenizer to {output_path}...")
    with open(output_path, "w", encoding="utf-8") as f:
        json.dump(tok_json, f, ensure_ascii=False)

    # Copy tokenizer_config.json, update vocab size
    orig_config_path = os.path.join(os.path.dirname(ORIG_TOKENIZER), "tokenizer_config.json")
    with open(orig_config_path, "r") as f:
        tok_config = json.load(f)
    # Remove auto_map that references custom tokenizer class (use generic fast tokenizer)
    if "auto_map" in tok_config:
        del tok_config["auto_map"]
    if "tokenizer_class" in tok_config:
        tok_config["tokenizer_class"] = "PreTrainedTokenizerFast"
    config_out = os.path.join(OUTPUT_DIR, "tokenizer_config.json")
    with open(config_out, "w") as f:
        json.dump(tok_config, f, indent=2, ensure_ascii=False)

    # Copy special_tokens_map
    orig_special = os.path.join(os.path.dirname(ORIG_TOKENIZER), "special_tokens_map.json")
    if os.path.exists(orig_special):
        import shutil
        shutil.copy2(orig_special, os.path.join(OUTPUT_DIR, "special_tokens_map.json"))

    elapsed = time.time() - t0
    logger.info(f"Done in {elapsed:.1f}s")

    # ── Quick benchmark ─────────────────────────────────────────────────
    # Load both tokenizers back through the generic fast-tokenizer class and
    # compare token counts per language.
    logger.info("\n" + "=" * 80)
    logger.info("BENCHMARK: Old vs New tokenizer")
    logger.info("=" * 80)

    from transformers import PreTrainedTokenizerFast

    orig_tok = PreTrainedTokenizerFast(tokenizer_file=ORIG_TOKENIZER)
    new_tok = PreTrainedTokenizerFast(tokenizer_file=output_path)

    logger.info(f"Old vocab: {orig_tok.vocab_size}, New vocab: {new_tok.vocab_size}")

    # Per-language sample sentences: 11 Indic languages, English, and
    # Hindi-English code-mixed text (keys are language codes).
    test_data = {
        "hi": [
            "नमस्ते दुनिया यह एक परीक्षण है",
            "भारत एक विविधताओं से भरा देश है जहां अनेक भाषाएं बोली जाती हैं",
            "आज का समय बहुत ही डिजिटल समय हो गया है",
            "कहना चाहूँगा कि आप सर ना इस्तेमाल करें",
        ],
        "te": [
            "హలో ప్రపంచం ఇది ఒక పరీక్ష",
            "మీది హరిత గారిది రవళి గారిది బాండింగ్ కానివ్వండి",
            "మీరు చెప్పండి అసలు ఇప్పటికీ ఎలా ఉంటుంది",
            "బయటికి కదిలే పని ఉండదు బయటి పనులు మనం చేయలేము",
        ],
        "ta": [
            "வணக்கம் உலகம் இது ஒரு சோதனை",
            "ஆதவன் தமிழ் நேயர்களுக்கு வணக்கம்",
            "சிறந்த ஆளுமைகளை சந்தித்து நாம் நேர்காணல் எடுத்து வருகிறோம்",
            "அரசியல் மேடை நிகழ்ச்சியில் மீண்டும் உங்களை சந்திப்பதில்",
        ],
        "ml": [
            "ഹലോ ലോകം ഇത് ഒരു പരീക്ഷ",
            "ഇനി കളർത്തേണ്ട ശിവരാത്രിയുടെ തത്വമേ",
            "കൂട്ടത്തിലുണ്ടായിരുന്ന യുവാക്കൾ പാഞ്ഞുചെന്നു",
            "അങ്ങനെ കാർലി സുരക്ഷിതമായി തിരിച്ചുകയറുകയാണ്",
        ],
        "bn": [
            "হ্যালো বিশ্ব এটি একটি পরীক্ষা",
            "গ্যাগ কি জিনিস ভাই আসিফ ভাই আমি তো জীবনে প্রথম শুনলাম",
            "এটা যদি কোনো প্রমাণ না হয় তাহলে ভাইয়া একটা কাজ করেন",
            "আপনি কোনটা পছন্দ আপনার সেটা জিজ্ঞেস করতেছি",
        ],
        "gu": [
            "હેલો વિશ્વ આ એક પરીક્ષા છે",
            "પ્રિય શિક્ષક માટે ના એ વખતે તો શું અમારે એવું કંઈ હતું જ નહીં",
            "એમનેમ જ એ બધા સાથે એ રીતે બિહેવ કરતા હતા",
            "સવા વાસુદેવભાઈ આપણે ઉમિયા માતા મંદિર ખરું ને",
        ],
        "kn": [
            "ಹಲೋ ವಿಶ್ವ ಇದು ಒಂದು ಪರೀಕ್ಷೆ",
            "ಕರ್ನಾಟಕ ರಾಜ್ಯದಲ್ಲಿ ಅನೇಕ ಭಾಷೆಗಳನ್ನು ಮಾತನಾಡಲಾಗುತ್ತದೆ",
            "ಇಂದು ನಾವು ಒಂದು ಹೊಸ ವಿಷಯದ ಬಗ್ಗೆ ಮಾತನಾಡೋಣ",
            "ನಮಸ್ಕಾರ ಎಲ್ಲರಿಗೂ ಸ್ವಾಗತ",
        ],
        "pa": [
            "ਹੈਲੋ ਵਿਸ਼ਵ ਇਹ ਇੱਕ ਟੈਸਟ ਹੈ",
            "ਪੰਜਾਬ ਵਿੱਚ ਬਹੁਤ ਸਾਰੀਆਂ ਭਾਸ਼ਾਵਾਂ ਬੋਲੀਆਂ ਜਾਂਦੀਆਂ ਹਨ",
            "ਅੱਜ ਦਾ ਸਮਾਂ ਬਹੁਤ ਹੀ ਡਿਜੀਟਲ ਹੋ ਗਿਆ ਹੈ",
            "ਸਤਿ ਸ੍ਰੀ ਅਕਾਲ ਜੀ ਆਇਆਂ ਨੂੰ",
        ],
        "mr": [
            "हेलो विश्व हा एक चाचणी आहे",
            "महाराष्ट्र राज्यामध्ये अनेक भाषा बोलल्या जातात",
            "आज सकाळी मी बाजारात गेलो होतो",
            "नमस्कार सर्वांना स्वागत आहे",
        ],
        "or": [
            "ହେଲୋ ବିଶ୍ୱ ଏହା ଏକ ପରୀକ୍ଷା",
            "ଓଡ଼ିଶାରେ ବହୁତ ଭାଷା କୁହାଯାଏ",
            "ଆଜି ଆମେ ଏକ ନୂଆ ବିଷୟ ବାରେ ଆଲୋଚନା କରିବା",
            "ନମସ୍କାର ସମସ୍ତଙ୍କୁ ସ୍ୱାଗତ",
        ],
        "as": [
            "হেলো বিশ্ব এইটো এটা পৰীক্ষা",
            "অসমীয়া ভাষা এটা সুন্দৰ ভাষা",
            "আজি আমি বজাৰলৈ গৈছিলোঁ",
            "নমস্কাৰ সকলোকে স্বাগতম",
        ],
        "en": [
            "hello world this is a test",
            "the quick brown fox jumps over the lazy dog",
            "India is a diverse country with many languages",
            "good morning everyone welcome to the show",
        ],
        "hi-en": [
            "bhai yeh video bahut interesting hai please like karo",
            "आज का weather बहुत अच्छा है let's go outside",
            "मैंने आज office में meeting attend की",
            "coding करना बहुत fun है especially Python में",
        ],
    }

    import numpy as np
    print("\n" + "=" * 100)
    print(f"{'Lang':<8} {'Words':>6} {'Chars':>6} │ {'Old tok':>8} {'tok/w':>7} {'tok/ch':>7} │ "
          f"{'New tok':>8} {'tok/w':>7} {'tok/ch':>7} │ {'Reduction':>10}")
    print("=" * 100)

    # Aggregate token counts per language; "reduction" is the relative drop
    # in tokens-per-word going from the old tokenizer to the new one.
    results = {}
    for lang, sentences in test_data.items():
        total_words = 0
        total_chars = 0
        total_old_tokens = 0
        total_new_tokens = 0

        for sent in sentences:
            words = sent.split()
            total_words += len(words)
            total_chars += len(sent)
            total_old_tokens += len(orig_tok.encode(sent))
            total_new_tokens += len(new_tok.encode(sent))

        old_tpw = total_old_tokens / total_words
        new_tpw = total_new_tokens / total_words
        reduction = (1 - new_tpw / old_tpw) * 100

        results[lang] = {
            "old_tok_per_word": old_tpw,
            "new_tok_per_word": new_tpw,
            "reduction_pct": reduction,
        }

        print(f"{lang:<8} {total_words:>6} {total_chars:>6} │ {total_old_tokens:>8} {old_tpw:>7.2f} "
              f"{total_old_tokens/total_chars:>7.2f} │ {total_new_tokens:>8} {new_tpw:>7.2f} "
              f"{total_new_tokens/total_chars:>7.2f} │ {reduction:>9.1f}%")

    print("=" * 100)

    # Summary averages over the Indic languages (English reported separately
    # as a regression check — it should be ~0% change).
    indic_results = {k: v for k, v in results.items() if k not in ("en",)}
    avg_old = np.mean([v["old_tok_per_word"] for v in indic_results.values()])
    avg_new = np.mean([v["new_tok_per_word"] for v in indic_results.values()])
    avg_red = np.mean([v["reduction_pct"] for v in indic_results.values()])
    print(f"\nIndic average: {avg_old:.2f} → {avg_new:.2f} tok/word ({avg_red:.1f}% reduction)")

    if "en" in results:
        en = results["en"]
        print(f"English:       {en['old_tok_per_word']:.2f} → {en['new_tok_per_word']:.2f} tok/word "
              f"({en['reduction_pct']:.1f}% change)")

    # Detailed examples
    print("\n" + "=" * 80)
    print("DETAILED EXAMPLES")
    print("=" * 80)
    for lang in ["hi", "te", "ta", "bn", "en", "hi-en"]:
        sent = test_data[lang][0]
        old_toks = orig_tok.tokenize(sent)
        new_toks = new_tok.tokenize(sent)
        print(f"\n{lang}: \"{sent}\"")
        print(f"  OLD ({len(old_toks):>3}): {old_toks[:20]}{'...' if len(old_toks) > 20 else ''}")
        print(f"  NEW ({len(new_toks):>3}): {new_toks[:20]}{'...' if len(new_toks) > 20 else ''}")

    # Save results
    with open(os.path.join(OUTPUT_DIR, "benchmark_results.json"), "w") as f:
        json.dump(results, f, indent=2)

# Entry point: run the merge + benchmark when executed as a script.
if __name__ == "__main__":
    main()
