#!/usr/bin/env python3
"""
Merge 32K Indic tokens into BPE tokenizer.json — v2 (optimized merge strategy).

Strategy: Build merges bottom-up:
  1. Byte pairs → individual Indic characters (shared across all tokens)
  2. Character pairs → common bigrams (shared across tokens)
  3. Remaining subword merges

This minimizes intermediate tokens by maximizing sharing.
"""

import json
import os
import time
import logging
from collections import Counter, defaultdict

import numpy as np
import sentencepiece as spm

# Timestamped INFO logging so long merge-learning runs show progress timing.
logging.basicConfig(level=logging.INFO, format="%(asctime)s | %(levelname)s | %(message)s")
logger = logging.getLogger(__name__)

# Inputs/outputs (absolute workspace paths):
#   ORIG_TOKENIZER  — base tokenizer.json from the HF hub snapshot cache
#   NEW_TOKENS_JSON — list of new Indic tokens to fold into the vocab
#   OUTPUT_DIR      — where the merged tokenizer + configs are written
ORIG_TOKENIZER = "/workspace/.hf_home/hub/models--CohereLabs--cohere-transcribe-03-2026/snapshots/90cf6a1e8427d6ab5e0060f53c095c245a20da4e/tokenizer.json"
NEW_TOKENS_JSON = "/workspace/training/tokenizer_extension/new_indic_tokens.json"
OUTPUT_DIR = "/workspace/training/tokenizer_extension/merged_48k"


def text_to_byte_tokens(text):
    """Encode *text* as UTF-8 and render each byte as a ``<0xNN>`` BPE token."""
    encoded = text.encode("utf-8")
    return ["<0x%02X>" % byte for byte in encoded]


def merge_pair(tokens, left, right):
    """Collapse every adjacent ``(left, right)`` occurrence into one token.

    Scans left-to-right without overlap: after a merge the element consumed
    as ``right`` cannot start another merge.
    """
    out = []
    last = len(tokens) - 1
    skip_next = False
    for pos, tok in enumerate(tokens):
        if skip_next:
            skip_next = False
            continue
        if tok == left and pos < last and tokens[pos + 1] == right:
            out.append(left + right)
            skip_next = True
        else:
            out.append(tok)
    return out


def get_pair_counts(all_token_seqs):
    """Tally every adjacent token pair across all sequences into a Counter."""
    pair_counter = Counter()
    for seq in all_token_seqs:
        # zip(seq, seq[1:]) yields the adjacent-pair stream for this sequence.
        pair_counter.update(zip(seq, seq[1:]))
    return pair_counter


def main():
    """Merge the new Indic tokens into the base BPE tokenizer.

    Pipeline:
      1. Load the original tokenizer.json (vocab + merges).
      2. Decompose every new token into a sequence of existing vocab entries
         (single characters where present, byte-level ``<0xNN>`` tokens
         otherwise).
      3. Learn new BPE merges bottom-up: repeatedly merge the most frequent
         adjacent pair across all sequences, maximizing sharing of
         intermediate tokens between the new tokens.
      4. Write the merged tokenizer, adjusted configs, and vocab-size info to
         OUTPUT_DIR, then benchmark old vs. new tokenization.
    """
    t0 = time.time()
    os.makedirs(OUTPUT_DIR, exist_ok=True)

    # ── Load original tokenizer.json ────────────────────────────────────
    logger.info("Loading original tokenizer.json...")
    with open(ORIG_TOKENIZER, "r", encoding="utf-8") as f:
        tok_json = json.load(f)

    orig_vocab = dict(tok_json["model"]["vocab"])  # copy: never mutate tok_json in place
    orig_merges = list(tok_json["model"]["merges"])  # copy
    existing_merge_set = set(tuple(m) for m in orig_merges)

    logger.info(f"Original vocab: {len(orig_vocab)}, merges: {len(orig_merges)}")

    # ── Load new Indic tokens ───────────────────────────────────────────
    with open(NEW_TOKENS_JSON, "r", encoding="utf-8") as f:
        new_tokens = json.load(f)
    logger.info(f"New Indic tokens to add: {len(new_tokens)}")

    # ── Decompose all new tokens into byte sequences ────────────────────
    # Each new token becomes a sequence of tokens the original vocab already
    # has: the character itself when present, otherwise its UTF-8 bytes.
    token_byte_seqs = {}
    for token in new_tokens:
        if token in orig_vocab:
            continue  # already covered by the base vocab — nothing to learn
        byte_seq = []
        for char in token:
            if char in orig_vocab:
                byte_seq.append(char)
            else:
                byte_seq.extend(text_to_byte_tokens(char))
        token_byte_seqs[token] = byte_seq

    logger.info(f"Tokens to merge: {len(token_byte_seqs)}")

    # ── BPE-style bottom-up merge learning ──────────────────────────────
    # Start with byte/char sequences and repeatedly merge the most frequent
    # adjacent pair. Frequent pairs are shared across many target tokens, so
    # this naturally creates shared intermediates (e.g. Indic character
    # tokens) before token-specific merges.

    all_seqs = {token: list(seq) for token, seq in token_byte_seqs.items()}
    new_merges = []  # merges to append to the tokenizer, in learned order
    new_vocab = {}   # new token string -> newly assigned id
    next_id = max(orig_vocab.values()) + 1

    # Progress bookkeeping: every sequence should collapse to 1 token.
    target_remaining = len(all_seqs)
    total_tokens_remaining = sum(len(s) for s in all_seqs.values())

    logger.info(f"Starting: {total_tokens_remaining} total tokens across {len(all_seqs)} sequences")
    logger.info(f"Target: each sequence → 1 token (need to eliminate ~{total_tokens_remaining - target_remaining} tokens via merges)")

    iteration = 0
    max_iterations = 200000  # safety limit against runaway merge learning

    while iteration < max_iterations:
        # Count adjacent pairs over all current sequences.
        pair_counts = get_pair_counts(all_seqs.values())
        if not pair_counts:
            break  # every sequence is length <= 1: fully merged

        # Greedily pick the most frequent pair (classic BPE step).
        best_pair = pair_counts.most_common(1)[0]
        (left, right), count = best_pair

        # NOTE: redundant guard kept for safety — a non-empty pair_counts
        # implies some sequence has length >= 2, so the all(...) clause can
        # never be true here.
        if count < 2 and all(len(s) <= 1 for s in all_seqs.values()):
            break

        # Merge already known to the original tokenizer: apply it (the
        # original tokenizer would) but don't record it as a new merge.
        # This branch deliberately does not count toward `iteration`; it
        # still terminates because each application strictly shrinks the
        # sequences.
        if (left, right) in existing_merge_set:
            merged = left + right
            for token in all_seqs:
                all_seqs[token] = merge_pair(all_seqs[token], left, right)
            # Make sure the merged token has an id somewhere.
            if merged not in orig_vocab and merged not in new_vocab:
                new_vocab[merged] = next_id
                next_id += 1
            continue

        merged = left + right

        # Record the new merge rule.
        new_merges.append([left, right])

        # Add the merged token to the vocab if it's genuinely new.
        if merged not in orig_vocab and merged not in new_vocab:
            new_vocab[merged] = next_id
            next_id += 1

        # Apply the merge to all sequences.
        for token in all_seqs:
            all_seqs[token] = merge_pair(all_seqs[token], left, right)

        iteration += 1

        # Stop once every sequence has collapsed to a single token.
        max_len = max(len(s) for s in all_seqs.values())
        if max_len <= 1:
            break

        if iteration % 5000 == 0:
            total_remaining = sum(len(s) for s in all_seqs.values())
            # BUG FIX: this log line previously printed the hard-coded
            # literal "max_seq_len=29,696"; report the actual current
            # maximum sequence length instead.
            logger.info(f"  iter {iteration}: {len(new_merges)} merges, "
                        f"{len(new_vocab)} new vocab, max_seq_len={max_len}, "
                        f"total_tokens={total_remaining}")

    total_remaining = sum(len(s) for s in all_seqs.values())
    logger.info(f"Merge learning done: {iteration} iterations, {len(new_merges)} new merges, "
                f"{len(new_vocab)} new vocab entries, total_tokens_remaining={total_remaining}")

    # ── Verify all tokens are in vocab ──────────────────────────────────
    missing = 0
    for token, seq in all_seqs.items():
        if len(seq) == 1 and seq[0] == token:
            # Fully merged: ensure the final token itself has an id.
            if token not in orig_vocab and token not in new_vocab:
                new_vocab[token] = next_id
                next_id += 1
        elif len(seq) > 1:
            missing += 1

    if missing > 0:
        logger.warning(f"{missing} tokens not fully merged (will use partial subwords)")

    final_vocab_size = len(orig_vocab) + len(new_vocab)
    logger.info(f"Final vocab size: {len(orig_vocab)} + {len(new_vocab)} = {final_vocab_size}")

    # ── Analyze new vocab ───────────────────────────────────────────────
    # Rough categorization for the log: tokens whose stripped form is pure
    # hex digits are byte-level intermediates (e.g. "<0xE0><0xA4>"),
    # everything else counts as a real subword. Heuristic only — an
    # all-hex-letter Latin subword would be miscounted, which is acceptable
    # for a summary line.
    char_tokens = 0
    subword_tokens = 0
    for token in new_vocab:
        clean = token.replace("▁", "").replace("<0x", "").replace(">", "")
        if all(c in "0123456789ABCDEF" for c in clean):
            char_tokens += 1  # byte-level intermediate
        else:
            subword_tokens += 1

    logger.info(f"New vocab breakdown: {char_tokens} byte-intermediates, {subword_tokens} subword tokens")

    # ── Build final tokenizer.json ──────────────────────────────────────
    final_vocab = dict(orig_vocab)
    final_vocab.update(new_vocab)
    final_merges = orig_merges + new_merges

    tok_json["model"]["vocab"] = final_vocab
    tok_json["model"]["merges"] = final_merges

    # ── Save ────────────────────────────────────────────────────────────
    output_path = os.path.join(OUTPUT_DIR, "tokenizer.json")
    logger.info(f"Writing merged tokenizer to {output_path}...")
    with open(output_path, "w", encoding="utf-8") as f:
        json.dump(tok_json, f, ensure_ascii=False)

    # Copy the companion config files next to the merged tokenizer.
    orig_dir = os.path.dirname(ORIG_TOKENIZER)
    import shutil
    for fname in ["tokenizer_config.json", "special_tokens_map.json"]:
        src = os.path.join(orig_dir, fname)
        if os.path.exists(src):
            if fname == "tokenizer_config.json":
                # Drop the custom-class auto_map so the merged tokenizer
                # loads as a plain fast tokenizer.
                with open(src) as f:
                    cfg = json.load(f)
                cfg.pop("auto_map", None)
                cfg["tokenizer_class"] = "PreTrainedTokenizerFast"
                with open(os.path.join(OUTPUT_DIR, fname), "w") as f:
                    json.dump(cfg, f, indent=2, ensure_ascii=False)
            else:
                shutil.copy2(src, os.path.join(OUTPUT_DIR, fname))

    # Save vocab size info for the downstream embedding resize.
    with open(os.path.join(OUTPUT_DIR, "vocab_info.json"), "w") as f:
        json.dump({
            "original_vocab_size": len(orig_vocab),
            "new_tokens_added": len(new_vocab),
            "final_vocab_size": final_vocab_size,
            "new_merges_added": len(new_merges),
        }, f, indent=2)

    elapsed = time.time() - t0
    logger.info(f"Done in {elapsed:.1f}s")

    # ── Benchmark ───────────────────────────────────────────────────────
    benchmark(output_path)


def benchmark(new_tokenizer_path):
    """Compare old vs new tokenizer.

    Loads the original tokenizer and the merged one at *new_tokenizer_path*,
    encodes a fixed multilingual test set with both, and prints per-language
    token counts, tokens/word, tokens/char, and the percentage reduction.
    Results are also written to OUTPUT_DIR/benchmark_results.json.
    """
    logger.info("\n" + "=" * 80)
    logger.info("BENCHMARK: Old vs New tokenizer")
    logger.info("=" * 80)

    from transformers import PreTrainedTokenizerFast

    orig_tok = PreTrainedTokenizerFast(tokenizer_file=ORIG_TOKENIZER)
    new_tok = PreTrainedTokenizerFast(tokenizer_file=new_tokenizer_path)

    logger.info(f"Old vocab: {orig_tok.vocab_size}, New vocab: {new_tok.vocab_size}")

    # Fixed test sentences keyed by language code; "hi-en" is code-mixed
    # Hindi/English, "en" is the pure-English control.
    test_data = {
        "hi": [
            "नमस्ते दुनिया यह एक परीक्षण है",
            "भारत एक विविधताओं से भरा देश है जहां अनेक भाषाएं बोली जाती हैं",
            "आज का समय बहुत ही डिजिटल समय हो गया है",
            "कहना चाहूँगा कि आप सर ना इस्तेमाल करें",
        ],
        "te": [
            "హలో ప్రపంచం ఇది ఒక పరీక్ష",
            "మీది హరిత గారిది రవళి గారిది బాండింగ్ కానివ్వండి",
            "మీరు చెప్పండి అసలు ఇప్పటికీ ఎలా ఉంటుంది",
            "బయటికి కదిలే పని ఉండదు బయటి పనులు మనం చేయలేము",
        ],
        "ta": [
            "வணக்கம் உலகம் இது ஒரு சோதனை",
            "ஆதவன் தமிழ் நேயர்களுக்கு வணக்கம்",
            "சிறந்த ஆளுமைகளை சந்தித்து நாம் நேர்காணல் எடுத்து வருகிறோம்",
            "அரசியல் மேடை நிகழ்ச்சியில் மீண்டும் உங்களை சந்திப்பதில்",
        ],
        "ml": [
            "ഹലോ ലോകം ഇത് ഒരു പരീക്ഷ",
            "ഇനി കളർത്തേണ്ട ശിവരാത്രിയുടെ തത്വമേ",
            "കൂട്ടത്തിലുണ്ടായിരുന്ന യുവാക്കൾ പാഞ്ഞുചെന്നു",
            "അങ്ങനെ കാർലി സുരക്ഷിതമായി തിരിച്ചുകയറുകയാണ്",
        ],
        "bn": [
            "হ্যালো বিশ্ব এটি একটি পরীক্ষা",
            "গ্যাগ কি জিনিস ভাই আসিফ ভাই আমি তো জীবনে প্রথম শুনলাম",
            "এটা যদি কোনো প্রমাণ না হয় তাহলে ভাইয়া একটা কাজ করেন",
            "আপনি কোনটা পছন্দ আপনার সেটা জিজ্ঞেস করতেছি",
        ],
        "gu": [
            "હેલો વિશ્વ આ એક પરીક્ષા છે",
            "પ્રિય શિક્ષક માટે ના એ વખતે તો શું અમારે એવું કંઈ હતું જ નહીં",
            "એમનેમ જ એ બધા સાથે એ રીતે બિહેવ કરતા હતા",
            "સવા વાસુદેવભાઈ આપણે ઉમિયા માતા મંદિર ખરું ને",
        ],
        "kn": [
            "ಹಲೋ ವಿಶ್ವ ಇದು ಒಂದು ಪರೀಕ್ಷೆ",
            "ಕರ್ನಾಟಕ ರಾಜ್ಯದಲ್ಲಿ ಅನೇಕ ಭಾಷೆಗಳನ್ನು ಮಾತನಾಡಲಾಗುತ್ತದೆ",
            "ಇಂದು ನಾವು ಒಂದು ಹೊಸ ವಿಷಯದ ಬಗ್ಗೆ ಮಾತನಾಡೋಣ",
            "ನಮಸ್ಕಾರ ಎಲ್ಲರಿಗೂ ಸ್ವಾಗತ",
        ],
        "pa": [
            "ਹੈਲੋ ਵਿਸ਼ਵ ਇਹ ਇੱਕ ਟੈਸਟ ਹੈ",
            "ਪੰਜਾਬ ਵਿੱਚ ਬਹੁਤ ਸਾਰੀਆਂ ਭਾਸ਼ਾਵਾਂ ਬੋਲੀਆਂ ਜਾਂਦੀਆਂ ਹਨ",
            "ਅੱਜ ਦਾ ਸਮਾਂ ਬਹੁਤ ਹੀ ਡਿਜੀਟਲ ਹੋ ਗਿਆ ਹੈ",
            "ਸਤਿ ਸ੍ਰੀ ਅਕਾਲ ਜੀ ਆਇਆਂ ਨੂੰ",
        ],
        "mr": [
            "हेलो विश्व हा एक चाचणी आहे",
            "महाराष्ट्र राज्यामध्ये अनेक भाषा बोलल्या जातात",
            "आज सकाळी मी बाजारात गेलो होतो",
            "नमस्कार सर्वांना स्वागत आहे",
        ],
        "or": [
            "ହେଲୋ ବିଶ୍ୱ ଏହା ଏକ ପରୀକ୍ଷା",
            "ଓଡ଼ିଶାରେ ବହୁତ ଭାଷା କୁହାଯାଏ",
            "ଆଜି ଆମେ ଏକ ନୂଆ ବିଷୟ ବାରେ ଆଲୋଚନା କରିବା",
            "ନମସ୍କାର ସମସ୍ତଙ୍କୁ ସ୍ୱାଗତ",
        ],
        "as": [
            "হেলো বিশ্ব এইটো এটা পৰীক্ষা",
            "অসমীয়া ভাষা এটা সুন্দৰ ভাষা",
            "আজি আমি বজাৰলৈ গৈছিলোঁ",
            "নমস্কাৰ সকলোকে স্বাগতম",
        ],
        "en": [
            "hello world this is a test",
            "the quick brown fox jumps over the lazy dog",
            "India is a diverse country with many languages",
            "good morning everyone welcome to the show",
        ],
        "hi-en": [
            "bhai yeh video bahut interesting hai please like karo",
            "आज का weather बहुत अच्छा है let's go outside",
            "मैंने आज office में meeting attend की",
            "coding करना बहुत fun है especially Python में",
        ],
    }

    print("\n" + "=" * 100)
    print(f"{'Lang':<8} {'Words':>6} {'Chars':>6} │ {'Old tok':>8} {'tok/w':>7} {'tok/ch':>7} │ "
          f"{'New tok':>8} {'tok/w':>7} {'tok/ch':>7} │ {'Reduction':>10}")
    print("=" * 100)

    results = {}
    for lang, sentences in test_data.items():
        # Accumulators: total words, chars, old-token count, new-token count.
        tw = tc = tot = tnt = 0
        for sent in sentences:
            tw += len(sent.split())
            tc += len(sent)
            tot += len(orig_tok.encode(sent))
            tnt += len(new_tok.encode(sent))

        # Tokens-per-word before/after; positive `red` = fewer tokens now.
        old_tpw = tot / tw
        new_tpw = tnt / tw
        red = (1 - new_tpw / old_tpw) * 100
        results[lang] = {"old_tpw": old_tpw, "new_tpw": new_tpw, "reduction": red}

        print(f"{lang:<8} {tw:>6} {tc:>6} │ {tot:>8} {old_tpw:>7.2f} {tot/tc:>7.2f} │ "
              f"{tnt:>8} {new_tpw:>7.2f} {tnt/tc:>7.2f} │ {red:>9.1f}%")

    print("=" * 100)

    # Summary: average over Indic languages only (excludes "en"; note that
    # code-mixed "hi-en" is counted as Indic here).
    indic = {k: v for k, v in results.items() if k != "en"}
    avg_old = np.mean([v["old_tpw"] for v in indic.values()])
    avg_new = np.mean([v["new_tpw"] for v in indic.values()])
    avg_red = np.mean([v["reduction"] for v in indic.values()])
    print(f"\nIndic average: {avg_old:.2f} → {avg_new:.2f} tok/word ({avg_red:.1f}% reduction)")
    if "en" in results:
        en = results["en"]
        print(f"English:       {en['old_tpw']:.2f} → {en['new_tpw']:.2f} tok/word ({en['reduction']:.1f}% change)")

    # Detailed examples: show the first 15 tokens of one sentence per language.
    print("\n" + "=" * 80)
    print("DETAILED EXAMPLES")
    print("=" * 80)
    for lang in ["hi", "te", "ta", "bn", "en"]:
        sent = test_data[lang][0]
        old = orig_tok.tokenize(sent)
        new = new_tok.tokenize(sent)
        print(f"\n{lang}: \"{sent}\"")
        print(f"  OLD ({len(old):>3}): {old[:15]}{'...' if len(old) > 15 else ''}")
        print(f"  NEW ({len(new):>3}): {new[:15]}{'...' if len(new) > 15 else ''}")

    # NOTE(review): results always land in the module-level OUTPUT_DIR even
    # though the tokenizer path is a parameter — confirm that is intended.
    with open(os.path.join(OUTPUT_DIR, "benchmark_results.json"), "w") as f:
        json.dump(results, f, indent=2)


# Script entry point: run the merge pipeline (and its benchmark) when
# executed directly; importing this module runs nothing.
if __name__ == "__main__":
    main()
