#!/usr/bin/env python3
"""
Tokenizer Extension Pipeline: 16K → 48K (adding 32K Indic subwords)

Steps:
  1. Extract Indic transcripts from training manifest + metadata.parquet (native + code-mixed)
  2. Train a 32K SentencePiece unigram model on the corpus
  3. Merge new tokens into existing 16K tokenizer → 48K
  4. Resize model embeddings + decoder head
  5. Initialize each new embedding as the mean of the original-tokenizer embeddings of that piece
  6. Benchmark tok/word improvement

Usage:
  python extend_tokenizer.py                    # full pipeline
  python extend_tokenizer.py --step extract     # just extract corpus
  python extend_tokenizer.py --step train       # just train SP model
  python extend_tokenizer.py --step merge       # merge tokenizers + resize embeddings
  python extend_tokenizer.py --step benchmark   # just benchmark old vs new tokenizer
"""

import argparse
import json
import logging
import os
import random
import sys
import time
from collections import Counter, defaultdict
from pathlib import Path

import numpy as np
import pandas as pd
import sentencepiece as spm
import torch

logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s | %(levelname)s | %(message)s",
)
logger = logging.getLogger(__name__)

# ── Config ──────────────────────────────────────────────────────────────────
# Languages whose transcripts are pulled into the extension corpus.
INDIC_LANGS = "hi te ta ml bn gu kn pa mr or as".split()

# Inputs: the pretrained snapshot, the training manifest, and the raw data tree
# that holds per-shard metadata.parquet files.
ORIG_MODEL_DIR = (
    "/workspace/.hf_home/hub/models--CohereLabs--cohere-transcribe-03-2026"
    "/snapshots/90cf6a1e8427d6ab5e0060f53c095c245a20da4e"
)
MANIFEST_PATH = "/workspace/maya-asr/manifests/training_manifest_fixed.parquet"
DATA_ROOT = "/workspace/maya-asr/data"

# Every artifact this script produces lives under OUTPUT_DIR.
OUTPUT_DIR = "/workspace/training/tokenizer_extension"
CORPUS_PATH = os.path.join(OUTPUT_DIR, "indic_corpus.txt")
SP_MODEL_PREFIX = os.path.join(OUTPUT_DIR, "indic_32k")
MERGED_TOKENIZER_DIR = os.path.join(OUTPUT_DIR, "merged_48k")

# Vocab sizes: 16384 original + 32000 new = 48384 (the "48K" used in names).
ORIG_VOCAB_SIZE = 16384
NEW_INDIC_TOKENS = 32000  # upper bound on Indic subwords appended
TARGET_VOCAB_SIZE = NEW_INDIC_TOKENS + ORIG_VOCAB_SIZE

# SentencePiece training. The model is trained with headroom (36K) because
# pieces already present in the original vocab are dropped during the merge.
SP_VOCAB_SIZE = 36000
SP_MAX_SENTENCE_LENGTH = 8192
SP_INPUT_SENTENCE_SIZE = 30_000_000  # cap on sentences sampled for training
SP_SEED = 42

# Corpus extraction caps, applied per language to keep the mix balanced.
MAX_LINES_PER_LANG_MANIFEST = 3_000_000  # from the training manifest
MAX_LINES_PER_LANG_METADATA = 500_000  # from metadata.parquet (native + mixed)


# Transcript columns read from each metadata.parquet, in (native, mixed) order.
_METADATA_TEXT_COLUMNS = ("transcription_native", "transcription_mixed")
# At most this many metadata.parquet shards are sampled per language.
_MAX_METADATA_SHARDS_PER_LANG = 100


def _clean_transcript(text):
    """Normalize one transcript to a single corpus line, or return None.

    All whitespace runs — including embedded newlines, which would otherwise
    split one transcript across several corpus lines — collapse to one space.
    Lines of 5 characters or fewer (after collapsing) are rejected.
    """
    line = " ".join(str(text).split())
    return line if len(line) > 5 else None


def _find_metadata_parquets(lang):
    """Return the sorted metadata.parquet paths for *lang* under DATA_ROOT.

    "indicvoices" keeps languages in ``<lang>/`` directories; the other
    datasets use ``lang=<lang>/``. Sorting makes the seeded shard sample
    independent of os.listdir() order, which is not guaranteed.
    """
    paths = []
    for dataset in ("final-export/production/shards", "indicvoices", "indicvoices-r"):
        lang_subdir = lang if dataset == "indicvoices" else f"lang={lang}"
        lang_dir = os.path.join(DATA_ROOT, dataset, lang_subdir)
        if not os.path.isdir(lang_dir):
            continue
        for shard_dir in os.listdir(lang_dir):
            mp = os.path.join(lang_dir, shard_dir, "metadata.parquet")
            if os.path.isfile(mp):
                paths.append(mp)
    return sorted(paths)


def _read_metadata_transcripts(path):
    """Read whichever transcript columns exist in one metadata.parquet.

    Requesting a column that a file lacks can make the combined read fail
    (engine-dependent), so on failure each column is retried on its own.
    Returns a DataFrame holding the readable subset of _METADATA_TEXT_COLUMNS,
    or None if neither column could be read.
    """
    try:
        return pd.read_parquet(path, columns=list(_METADATA_TEXT_COLUMNS))
    except Exception:  # error types differ between parquet engines
        pass
    frames = []
    for column in _METADATA_TEXT_COLUMNS:
        try:
            frames.append(pd.read_parquet(path, columns=[column]))
        except Exception:
            continue
    return pd.concat(frames, axis=1) if frames else None


def extract_corpus():
    """Extract Indic transcripts into a single text corpus file.

    Writes one cleaned transcript per line to CORPUS_PATH from two sources:
      1. the training manifest — up to MAX_LINES_PER_LANG_MANIFEST lines per
         language, sampled with SP_SEED;
      2. per-shard metadata.parquet files — native and code-mixed columns from
         at most 100 shards per language (seeded sample), capped at
         MAX_LINES_PER_LANG_METADATA lines per language (one row may push the
         count a single line past the cap). A code-mixed transcript identical
         to its row's native transcript is written only once.

    Returns:
        CORPUS_PATH.
    """
    logger.info("=" * 60)
    logger.info("STEP 1: Extracting Indic transcript corpus")
    logger.info("=" * 60)

    os.makedirs(OUTPUT_DIR, exist_ok=True)
    lang_counts = Counter()
    total_lines = 0

    with open(CORPUS_PATH, "w", encoding="utf-8") as f:

        # ── Source 1: Training manifest (main transcripts) ──────────────
        logger.info("Reading training manifest...")
        df = pd.read_parquet(MANIFEST_PATH, columns=["language", "transcript"])
        for lang in INDIC_LANGS:
            subset = df.loc[df["language"] == lang, "transcript"].dropna()
            # Clean before sampling so the cap counts only usable lines.
            subset = subset.map(_clean_transcript).dropna()
            if len(subset) > MAX_LINES_PER_LANG_MANIFEST:
                subset = subset.sample(n=MAX_LINES_PER_LANG_MANIFEST, random_state=SP_SEED)
            f.writelines(line + "\n" for line in subset)
            lang_counts[lang] += len(subset)
            total_lines += len(subset)
            logger.info(f"  {lang}: {lang_counts[lang]:,} lines from manifest")
        del df

        # ── Source 2: metadata.parquet — native + code-mixed ────────────
        logger.info("Reading metadata.parquet files (native + code-mixed)...")
        for lang in INDIC_LANGS:
            lang_native_count = 0
            lang_mixed_count = 0
            unreadable = 0

            # Sample a subset of shards for efficiency. A dedicated RNG keeps each
            # language's sample reproducible without reseeding global state.
            paths = _find_metadata_parquets(lang)
            if len(paths) > _MAX_METADATA_SHARDS_PER_LANG:
                paths = random.Random(SP_SEED).sample(paths, _MAX_METADATA_SHARDS_PER_LANG)

            for mp in paths:
                if lang_native_count + lang_mixed_count >= MAX_LINES_PER_LANG_METADATA:
                    break
                mdf = _read_metadata_transcripts(mp)
                if mdf is None:
                    unreadable += 1
                    continue
                # One list per text column; a missing column reads as all-None.
                columns = [
                    mdf[col].tolist() if col in mdf.columns else [None] * len(mdf)
                    for col in _METADATA_TEXT_COLUMNS
                ]
                for native_raw, mixed_raw in zip(*columns):
                    if lang_native_count + lang_mixed_count >= MAX_LINES_PER_LANG_METADATA:
                        break
                    native = _clean_transcript(native_raw) if pd.notna(native_raw) else None
                    mixed = _clean_transcript(mixed_raw) if pd.notna(mixed_raw) else None
                    if native is not None:
                        f.write(native + "\n")
                        lang_native_count += 1
                    # The code-mixed text often equals the native text; writing it
                    # again would double-weight that row in the unigram statistics.
                    if mixed is not None and mixed != native:
                        f.write(mixed + "\n")
                        lang_mixed_count += 1

            lang_total = lang_native_count + lang_mixed_count
            lang_counts[lang] += lang_total
            total_lines += lang_total
            if unreadable:
                logger.warning(f"  {lang}: skipped {unreadable} unreadable metadata.parquet file(s)")
            logger.info(f"  {lang}: +{lang_native_count:,} native, +{lang_mixed_count:,} mixed from metadata")

    logger.info(f"\nCorpus written to {CORPUS_PATH}")
    logger.info(f"Total lines: {total_lines:,}")
    for lang in INDIC_LANGS:
        logger.info(f"  {lang}: {lang_counts[lang]:,}")

    # Corpus size
    size_gb = os.path.getsize(CORPUS_PATH) / (1024 ** 3)
    logger.info(f"Corpus size: {size_gb:.2f} GB")
    return CORPUS_PATH


def train_sentencepiece():
    """Train a 32K SentencePiece model on the Indic corpus.

    Trains a unigram model with SP_VOCAB_SIZE pieces (headroom above the 32K
    the merge keeps) on CORPUS_PATH, writes ``{SP_MODEL_PREFIX}.model`` /
    ``.vocab``, and logs sample segmentations as a sanity check.

    Returns:
        Path to the trained ``.model`` file.

    Raises:
        FileNotFoundError: the corpus has not been extracted yet.
        ValueError: the corpus file is empty.
    """
    logger.info("=" * 60)
    logger.info("STEP 2: Training SentencePiece model (32K Indic vocab)")
    logger.info("=" * 60)

    if not os.path.isfile(CORPUS_PATH):
        raise FileNotFoundError(f"Corpus not found at {CORPUS_PATH}. Run --step extract first.")

    # Count lines in binary mode: cheaper than decoding a multi-GB file and
    # independent of the locale's default text encoding.
    with open(CORPUS_PATH, "rb") as f:
        n_lines = sum(1 for _ in f)
    logger.info(f"Corpus: {n_lines:,} lines")
    if n_lines == 0:
        raise ValueError(f"Corpus at {CORPUS_PATH} is empty. Re-run --step extract.")

    # Train SentencePiece
    # Using unigram model — better for agglutinative Indic languages than BPE
    # character_coverage=0.9999 to capture all Indic scripts
    logger.info(f"Training SentencePiece unigram model with vocab_size={SP_VOCAB_SIZE}...")
    t0 = time.time()

    spm.SentencePieceTrainer.train(
        input=CORPUS_PATH,
        model_prefix=SP_MODEL_PREFIX,
        vocab_size=SP_VOCAB_SIZE,
        model_type="unigram",
        character_coverage=0.9999,
        num_threads=os.cpu_count(),
        input_sentence_size=min(n_lines, SP_INPUT_SENTENCE_SIZE),
        shuffle_input_sentence=True,
        max_sentence_length=SP_MAX_SENTENCE_LENGTH,
        seed_sentencepiece_size=1_000_000,
        split_digits=True,
        byte_fallback=True,
        # Normalization: keep Indic scripts as-is
        normalization_rule_name="identity",
        # Only <unk> (id 0) is reserved; no BOS/EOS/PAD — handled in the merge
        unk_id=0,
        bos_id=-1,
        eos_id=-1,
        pad_id=-1,
        train_extremely_large_corpus=True,
    )

    elapsed = time.time() - t0
    logger.info(f"SentencePiece training completed in {elapsed:.0f}s ({elapsed/60:.1f}min)")
    logger.info(f"Model saved to {SP_MODEL_PREFIX}.model")

    # Quick sanity check
    sp = spm.SentencePieceProcessor()
    sp.Load(f"{SP_MODEL_PREFIX}.model")
    logger.info(f"Trained vocab size: {sp.GetPieceSize()}")

    # Show some sample tokenizations (each sentence is single-script)
    test_texts = {
        "hi": "नमस्ते दुनिया यह एक परीक्षण है",
        "te": "హలో ప్రపంచం ఇది ఒక పరీక్ష",
        "ta": "வணக்கம் உலகம் இது ஒரு சோதனை",
        "ml": "ഹലോ ലോകം ഇത് ഒരു പരീക്ഷ",
        "bn": "হ্যালো বিশ্ব এটি একটি পরীক্ষা",
    }
    for lang, text in test_texts.items():
        pieces = sp.EncodeAsPieces(text)
        logger.info(f"  {lang}: '{text}' → {len(pieces)} tokens: {pieces}")

    return f"{SP_MODEL_PREFIX}.model"


# Unicode blocks used to summarize which script each new piece belongs to.
_SCRIPT_RANGES = (
    (0x0900, 0x097F, "Devanagari (hi/mr)"),
    (0x0980, 0x09FF, "Bengali (bn/as)"),
    (0x0A00, 0x0A7F, "Gurmukhi (pa)"),
    (0x0A80, 0x0AFF, "Gujarati (gu)"),
    (0x0B00, 0x0B7F, "Odia (or)"),
    (0x0B80, 0x0BFF, "Tamil (ta)"),
    (0x0C00, 0x0C7F, "Telugu (te)"),
    (0x0C80, 0x0CFF, "Kannada (kn)"),
    (0x0D00, 0x0D7F, "Malayalam (ml)"),
    (0x0000, 0x007F, "Latin/ASCII"),
)


def _script_label(ch):
    """Name the script of character *ch* by Unicode block ("Other" if unlisted)."""
    cp = ord(ch)
    for lo, hi, label in _SCRIPT_RANGES:
        if lo <= cp <= hi:
            return label
    return "Other"


def merge_and_resize():
    """Merge the Indic SentencePiece vocab into the original HF tokenizer.

    Despite the name, this step only extends the tokenizer; the model's
    embeddings are resized separately by resize_model_embeddings(). Pieces
    from the Indic model that are not special, and not already in the
    original vocab, are ranked by score. The top NEW_INDIC_TOKENS are appended
    to the original tokenizer.json. The merged tokenizer is saved to
    MERGED_TOKENIZER_DIR and the added pieces to
    OUTPUT_DIR/new_indic_tokens.json.

    Returns:
        (new_tokens, final_vocab_size): the pieces actually added, in id order
        (new_tokens[i] has id ``original HF vocab size + i``), and the size of
        the extended tokenizer.

    Raises:
        RuntimeError: the tokenizer did not assign the expected contiguous ids.
    """
    logger.info("=" * 60)
    logger.info("STEP 3: Merging tokenizers (16K + 32K → 48K)")
    logger.info("=" * 60)

    os.makedirs(MERGED_TOKENIZER_DIR, exist_ok=True)

    # ── Load original tokenizer ─────────────────────────────────────────
    orig_sp = spm.SentencePieceProcessor()
    orig_sp.Load(os.path.join(ORIG_MODEL_DIR, "tokenizer.model"))
    orig_vocab_size = orig_sp.GetPieceSize()
    logger.info(f"Original vocab size: {orig_vocab_size}")

    orig_tokens = {orig_sp.IdToPiece(i) for i in range(orig_vocab_size)}
    logger.info(f"Unique original tokens: {len(orig_tokens)}")

    # ── Load new Indic SP model ─────────────────────────────────────────
    indic_sp = spm.SentencePieceProcessor()
    indic_sp.Load(f"{SP_MODEL_PREFIX}.model")
    indic_vocab_size = indic_sp.GetPieceSize()
    logger.info(f"Indic SP vocab size: {indic_vocab_size}")

    # ── Find new tokens (not in original) ───────────────────────────────
    new_token_scores = {}
    for i in range(indic_vocab_size):
        # Skip <unk>, control symbols, byte-fallback pieces and unused slots,
        # identified by piece type rather than by their "<...>" spelling.
        if (indic_sp.IsUnknown(i) or indic_sp.IsControl(i)
                or indic_sp.IsUnused(i) or indic_sp.IsByte(i)):
            continue
        piece = indic_sp.IdToPiece(i)
        if piece not in orig_tokens:
            new_token_scores[piece] = indic_sp.GetScore(i)
    logger.info(f"New unique Indic tokens (not in original): {len(new_token_scores)}")

    # Sort by score (higher = more useful) and keep the top NEW_INDIC_TOKENS;
    # the sort is stable, so ties keep their original id order.
    new_tokens = sorted(new_token_scores, key=new_token_scores.get, reverse=True)[:NEW_INDIC_TOKENS]
    logger.info(f"Keeping top {len(new_tokens)} new tokens")

    # ── Analyze script distribution of new tokens ───────────────────────
    # Classified by the first non-"▁" character of each piece.
    script_counts = Counter(
        _script_label(clean[0])
        for clean in (token.replace("▁", "") for token in new_tokens)
        if clean
    )
    logger.info("Script distribution of new tokens:")
    for script, count in script_counts.most_common():
        logger.info(f"  {script}: {count:,} ({100*count/len(new_tokens):.1f}%)")

    # ── Extend the HuggingFace tokenizer ────────────────────────────────
    logger.info("Extending HuggingFace tokenizer...")
    from transformers import PreTrainedTokenizerFast

    tokenizer = PreTrainedTokenizerFast(
        tokenizer_file=os.path.join(ORIG_MODEL_DIR, "tokenizer.json")
    )
    orig_hf_vocab_size = len(tokenizer)
    logger.info(f"Original HF tokenizer vocab size: {orig_hf_vocab_size}")
    if orig_hf_vocab_size != orig_vocab_size:
        logger.warning(
            f"HF tokenizer size ({orig_hf_vocab_size}) differs from tokenizer.model "
            f"({orig_vocab_size}); new ids will start at {orig_hf_vocab_size}"
        )

    # add_tokens() allocates ids only for strings not already in the vocab, so
    # drop those first: the returned list must map 1:1 onto the new ids.
    hf_vocab = tokenizer.get_vocab()
    already_present = sum(token in hf_vocab for token in new_tokens)
    if already_present:
        logger.info(f"Skipping {already_present:,} tokens already in the HF vocab")
        new_tokens = [token for token in new_tokens if token not in hf_vocab]

    # NOTE(review): these become HF "added tokens", which are split out of the
    # text before the underlying model runs, and a piece containing "▁" only
    # matches if this tokenizer's normalizer rewrites spaces as "▁". Confirm
    # against tokenizer.json and the benchmark output.
    num_added = tokenizer.add_tokens(new_tokens)
    final_vocab_size = len(tokenizer)
    logger.info(f"Added {num_added} new tokens → final vocab size: {final_vocab_size}")

    # Embedding initialization relies on new_tokens[i] having id
    # orig_hf_vocab_size + i; fail loudly rather than misalign rows silently.
    assigned_ids = tokenizer.convert_tokens_to_ids(new_tokens)
    expected_ids = list(range(orig_hf_vocab_size, orig_hf_vocab_size + len(new_tokens)))
    if num_added != len(new_tokens) or assigned_ids != expected_ids:
        raise RuntimeError(
            f"Expected {len(new_tokens)} new tokens with contiguous ids starting at "
            f"{orig_hf_vocab_size}, but the tokenizer added {num_added}"
        )

    # ── Save the new token list (exactly the tokens added, in id order) ─
    token_list_path = os.path.join(OUTPUT_DIR, "new_indic_tokens.json")
    with open(token_list_path, "w", encoding="utf-8") as f:
        json.dump(new_tokens, f, ensure_ascii=False, indent=2)
    logger.info(f"New token list saved to {token_list_path}")

    # Save extended tokenizer
    tokenizer.save_pretrained(MERGED_TOKENIZER_DIR)
    logger.info(f"Extended tokenizer saved to {MERGED_TOKENIZER_DIR}")

    return new_tokens, final_vocab_size


def _fallback_ids(tokenizer, token):
    """Return the ids the ORIGINAL tokenizer produces for one new piece.

    The SentencePiece word-boundary marker "▁" is turned back into a space and
    the text stripped; a piece consisting only of "▁" is encoded verbatim.
    """
    text = token.replace("▁", " ").strip()
    return tokenizer.encode(text or token, add_special_tokens=False)


def _find_vocab_layers(model, vocab_size):
    """Find the vocab-sized token embedding and output projection of *model*.

    Returns ``(embed, output)``, each a ``(name, module)`` pair or None. The
    last match in named_modules() order is used; extra matches trigger a
    warning because they will keep the old vocab size.
    """
    embeds, outputs = [], []
    for name, module in model.named_modules():
        if isinstance(module, torch.nn.Embedding) and module.num_embeddings == vocab_size:
            embeds.append((name, module))
            logger.info(f"Found embedding layer: {name} [{module.num_embeddings} x {module.embedding_dim}]")
        if isinstance(module, torch.nn.Linear) and module.out_features == vocab_size:
            outputs.append((name, module))
            logger.info(f"Found output projection: {name} [{module.in_features} x {module.out_features}]")
    for kind, found in (("embedding", embeds), ("output projection", outputs)):
        if len(found) > 1:
            logger.warning(
                f"{len(found)} {kind} candidates of size {vocab_size}; only "
                f"{found[-1][0]} is resized, the others keep the old size"
            )
    return (embeds[-1] if embeds else None), (outputs[-1] if outputs else None)


def _grow_rows(old_weight, num_rows, id_to_fallback, noise_scale):
    """Return a ``(num_rows, dim)`` copy of 2-D *old_weight* with new rows filled.

    Existing rows are copied unchanged. Each row id in *id_to_fallback* gets
    the mean of the old rows it lists, or Gaussian noise with std
    ``old_weight.std() * noise_scale`` when that list is empty. Any other new
    row gets the mean of all old rows.
    """
    n_old, dim = old_weight.shape
    new_weight = old_weight.new_empty((num_rows, dim))
    new_weight[:n_old] = old_weight
    new_weight[n_old:] = old_weight.mean(dim=0)
    noise_std = old_weight.std() * noise_scale
    for row, fallback in id_to_fallback.items():
        if fallback:
            new_weight[row] = old_weight[fallback].mean(dim=0)
        else:
            new_weight[row] = torch.randn(dim, dtype=old_weight.dtype, device=old_weight.device) * noise_std
    return new_weight


def resize_model_embeddings(new_tokens, final_vocab_size):
    """Resize the model's token embedding and vocab projection to the new vocab.

    Each new token's row is initialized as the mean of the original embedding
    (and output-projection) rows for the ids the ORIGINAL tokenizer splits the
    piece into. Row positions come from the ids the merged tokenizer in
    MERGED_TOKENIZER_DIR actually assigned. Layers are resized in place, which
    keeps their class, settings and any weight tying. The model and the merged
    tokenizer files are written to OUTPUT_DIR/extended_model.

    Args:
        new_tokens: pieces added to the tokenizer by merge_and_resize().
        final_vocab_size: size of the extended tokenizer.

    Returns:
        The save directory, or None if the model could not be resized.
    """
    logger.info("=" * 60)
    logger.info("STEP 4: Resizing model embeddings")
    logger.info("=" * 60)

    import shutil
    from transformers import AutoModelForSpeechSeq2Seq, PreTrainedTokenizerFast

    # Load original model — try the Hub id first, then the local snapshot
    logger.info("Loading original model...")
    try:
        model = AutoModelForSpeechSeq2Seq.from_pretrained(
            "CohereLabs/cohere-transcribe-03-2026",
            trust_remote_code=True,
            dtype=torch.float32,
        )
    except Exception as e:
        logger.warning(f"Hub load failed ({e}); retrying from {ORIG_MODEL_DIR}")
        model = AutoModelForSpeechSeq2Seq.from_pretrained(
            ORIG_MODEL_DIR,
            trust_remote_code=True,
            dtype=torch.float32,
        )

    # Original tokenizer: segments each new piece for initialization.
    # Extended tokenizer: tells us which id each new piece actually received.
    orig_tokenizer = PreTrainedTokenizerFast(
        tokenizer_file=os.path.join(ORIG_MODEL_DIR, "tokenizer.json")
    )
    ext_tokenizer = PreTrainedTokenizerFast(
        tokenizer_file=os.path.join(MERGED_TOKENIZER_DIR, "tokenizer.json")
    )

    orig_vocab = model.config.vocab_size
    logger.info(f"Original model vocab_size: {orig_vocab}")
    logger.info(f"Target vocab_size: {final_vocab_size}")
    if final_vocab_size < orig_vocab:
        logger.error(f"Target vocab_size {final_vocab_size} is smaller than the model's {orig_vocab}!")
        return None

    # ── Find embedding and output projection layers ─────────────────────
    embed_layer, output_layer = _find_vocab_layers(model, orig_vocab)
    if embed_layer is None:
        logger.error("Could not find embedding layer!")
        return None
    if output_layer is None:
        logger.warning("Could not find output projection — may be tied to embeddings")

    # ── Map new tokens to their real ids and original segmentations ─────
    logger.info("Computing byte-fallback initializations for new tokens...")
    id_to_fallback = {}
    skipped = 0
    for token, token_id in zip(new_tokens, ext_tokenizer.convert_tokens_to_ids(new_tokens)):
        # Only rows beyond the original vocab are initialized, so trained rows
        # are never overwritten.
        if token_id is None or not orig_vocab <= token_id < final_vocab_size:
            skipped += 1
            continue
        # Drop original-tokenizer ids for which the model has no embedding row.
        id_to_fallback[token_id] = [i for i in _fallback_ids(orig_tokenizer, token) if i < orig_vocab]
    if skipped:
        logger.warning(
            f"{skipped} new tokens have no id in [{orig_vocab}, {final_vocab_size}); "
            "they were not initialized individually"
        )
    lengths = [len(ids) for ids in id_to_fallback.values()]
    logger.info(f"Byte-fallback init stats: avg_len={sum(lengths) / max(len(lengths), 1):.1f}, "
                f"max_len={max(lengths, default=0)}")

    # ── Resize embedding layer (in place) ───────────────────────────────
    logger.info("Resizing embedding layer...")
    _, embed_mod = embed_layer
    out_mod = output_layer[1] if output_layer is not None else None
    # Detect weight tying before anything is replaced so it can be preserved.
    tied = out_mod is not None and out_mod.weight is embed_mod.weight

    embed_requires_grad = embed_mod.weight.requires_grad
    embed_mod.weight = torch.nn.Parameter(
        _grow_rows(embed_mod.weight.detach(), final_vocab_size, id_to_fallback, noise_scale=1.0),
        requires_grad=embed_requires_grad,
    )
    embed_mod.num_embeddings = final_vocab_size
    logger.info(f"Embedding resized: {orig_vocab} → {final_vocab_size}")

    # ── Resize output projection (decoder head, in place) ──────────────
    if out_mod is not None:
        if tied:
            # Keep input embedding and output projection sharing one Parameter.
            out_mod.weight = embed_mod.weight
        else:
            out_requires_grad = out_mod.weight.requires_grad
            out_mod.weight = torch.nn.Parameter(
                _grow_rows(out_mod.weight.detach(), final_vocab_size, id_to_fallback, noise_scale=0.01),
                requires_grad=out_requires_grad,
            )
        if out_mod.bias is not None:
            old_bias = out_mod.bias.detach()
            # New biases: slightly negative to avoid dominating softmax initially
            new_bias = old_bias.new_full((final_vocab_size,), -2.0)
            new_bias[:orig_vocab] = old_bias
            out_mod.bias = torch.nn.Parameter(new_bias, requires_grad=out_mod.bias.requires_grad)
        out_mod.out_features = final_vocab_size
        logger.info(f"Output projection resized: {orig_vocab} → {final_vocab_size}")

    # ── Update model config ─────────────────────────────────────────────
    model.config.vocab_size = final_vocab_size
    # Update head config if present
    if hasattr(model.config, "head"):
        if isinstance(model.config.head, dict):
            model.config.head["num_classes"] = final_vocab_size
        elif hasattr(model.config.head, "num_classes"):
            model.config.head.num_classes = final_vocab_size

    # ── Save extended model ─────────────────────────────────────────────
    model_save_dir = os.path.join(OUTPUT_DIR, "extended_model")
    os.makedirs(model_save_dir, exist_ok=True)
    model.save_pretrained(model_save_dir)
    logger.info(f"Extended model saved to {model_save_dir}")

    # Also copy the extended tokenizer files into the model dir
    for fname in os.listdir(MERGED_TOKENIZER_DIR):
        src = os.path.join(MERGED_TOKENIZER_DIR, fname)
        if os.path.isfile(src):
            shutil.copy2(src, model_save_dir)
    logger.info(f"Extended tokenizer copied to {model_save_dir}")

    return model_save_dir


def _set_module(model, name, new_module):
    """Set a module in a model by its dot-separated name."""
    parts = name.split(".")
    parent = model
    for part in parts[:-1]:
        parent = getattr(parent, part)
    setattr(parent, parts[-1], new_module)


def benchmark_tokenizer():
    """Compare old vs new tokenizer efficiency on Indic text.

    Encodes a fixed set of sentences per language with the original and the
    merged tokenizer and prints a tokens/word and tokens/char table. Also
    prints the "Indic" average over every language except "en" (so the
    code-mixed "hi-en" set is included) and per-language token examples.
    The per-language metrics are saved to OUTPUT_DIR/benchmark_results.json.

    Returns:
        Dict mapping language code to its metrics.

    Raises:
        FileNotFoundError: the merged tokenizer has not been built yet.
    """
    logger.info("=" * 60)
    logger.info("STEP 5: Benchmarking tokenizer improvement")
    logger.info("=" * 60)

    from transformers import PreTrainedTokenizerFast

    ext_tokenizer_file = os.path.join(MERGED_TOKENIZER_DIR, "tokenizer.json")
    if not os.path.isfile(ext_tokenizer_file):
        raise FileNotFoundError(
            f"Extended tokenizer not found at {ext_tokenizer_file}. Run --step merge first."
        )

    # Load both tokenizers
    orig_tok = PreTrainedTokenizerFast(tokenizer_file=os.path.join(ORIG_MODEL_DIR, "tokenizer.json"))
    ext_tok = PreTrainedTokenizerFast(tokenizer_file=ext_tokenizer_file)

    logger.info(f"Original vocab: {len(orig_tok)}, Extended vocab: {len(ext_tok)}")

    # ── Test sentences ──────────────────────────────────────────────────
    test_data = {
        "hi": [
            "नमस्ते दुनिया यह एक परीक्षण है",
            "भारत एक विविधताओं से भरा देश है जहां अनेक भाषाएं बोली जाती हैं",
            "आज का समय बहुत ही डिजिटल समय हो गया है",
            "कहना चाहूँगा कि आप सर ना इस्तेमाल करें",
        ],
        "te": [
            "హలో ప్రపంచం ఇది ఒక పరీక్ష",
            "మీది హరిత గారిది రవళి గారిది బాండింగ్ కానివ్వండి",
            "మీరు చెప్పండి అసలు ఇప్పటికీ ఎలా ఉంటుంది",
            "బయటికి కదిలే పని ఉండదు బయటి పనులు మనం చేయలేము",
        ],
        "ta": [
            "வணக்கம் உலகம் இது ஒரு சோதனை",
            "ஆதவன் தமிழ் நேயர்களுக்கு வணக்கம்",
            "சிறந்த ஆளுமைகளை சந்தித்து நாம் நேர்காணல் எடுத்து வருகிறோம்",
            "அரசியல் மேடை நிகழ்ச்சியில் மீண்டும் உங்களை சந்திப்பதில்",
        ],
        "ml": [
            "ഹലോ ലോകം ഇത് ഒരു പരീക്ഷ",
            "ഇനി കളർത്തേണ്ട ശിവരാത്രിയുടെ തത്വമേ",
            "കൂട്ടത്തിലുണ്ടായിരുന്ന യുവാക്കൾ പാഞ്ഞുചെന്നു",
            "അങ്ങനെ കാർലി സുരക്ഷിതമായി തിരിച്ചുകയറുകയാണ്",
        ],
        "bn": [
            "হ্যালো বিশ্ব এটি একটি পরীক্ষা",
            "গ্যাগ কি জিনিস ভাই আসিফ ভাই আমি তো জীবনে প্রথম শুনলাম",
            "এটা যদি কোনো প্রমাণ না হয় তাহলে ভাইয়া একটা কাজ করেন",
            "আপনি কোনটা পছন্দ আপনার সেটা জিজ্ঞেস করতেছি",
        ],
        "gu": [
            "હેલો વિશ્વ આ એક પરીક્ષા છે",
            "પ્રિય શિક્ષક માટે ના એ વખતે તો શું અમારે એવું કંઈ હતું જ નહીં",
            "એમનેમ જ એ બધા સાથે એ રીતે બિહેવ કરતા હતા",
            "સવા વાસુદેવભાઈ આપણે ઉમિયા માતા મંદિર ખરું ને",
        ],
        "kn": [
            "ಹಲೋ ವಿಶ್ವ ಇದು ಒಂದು ಪರೀಕ್ಷೆ",
            "ಕರ್ನಾಟಕ ರಾಜ್ಯದಲ್ಲಿ ಅನೇಕ ಭಾಷೆಗಳನ್ನು ಮಾತನಾಡಲಾಗುತ್ತದೆ",
            "ಇಂದು ನಾವು ಒಂದು ಹೊಸ ವಿಷಯದ ಬಗ್ಗೆ ಮಾತನಾಡೋಣ",
            "ನಮಸ್ಕಾರ ಎಲ್ಲರಿಗೂ ಸ್ವಾಗತ",
        ],
        "pa": [
            "ਹੈਲੋ ਵਿਸ਼ਵ ਇਹ ਇੱਕ ਟੈਸਟ ਹੈ",
            "ਪੰਜਾਬ ਵਿੱਚ ਬਹੁਤ ਸਾਰੀਆਂ ਭਾਸ਼ਾਵਾਂ ਬੋਲੀਆਂ ਜਾਂਦੀਆਂ ਹਨ",
            "ਅੱਜ ਦਾ ਸਮਾਂ ਬਹੁਤ ਹੀ ਡਿਜੀਟਲ ਹੋ ਗਿਆ ਹੈ",
            "ਸਤਿ ਸ੍ਰੀ ਅਕਾਲ ਜੀ ਆਇਆਂ ਨੂੰ",
        ],
        "mr": [
            "हेलो विश्व हा एक चाचणी आहे",
            "महाराष्ट्र राज्यामध्ये अनेक भाषा बोलल्या जातात",
            "आज सकाळी मी बाजारात गेलो होतो",
            "नमस्कार सर्वांना स्वागत आहे",
        ],
        "or": [
            "ହେଲୋ ବିଶ୍ୱ ଏହା ଏକ ପରୀକ୍ଷା",
            "ଓଡ଼ିଶାରେ ବହୁତ ଭାଷା କୁହାଯାଏ",
            "ଆଜି ଆମେ ଏକ ନୂଆ ବିଷୟ ବାରେ ଆଲୋଚନା କରିବା",
            "ନମସ୍କାର ସମସ୍ତଙ୍କୁ ସ୍ୱାଗତ",
        ],
        "as": [
            "হেলো বিশ্ব এইটো এটা পৰীক্ষা",
            "অসমীয়া ভাষা এটা সুন্দৰ ভাষা",
            "আজি আমি বজাৰলৈ গৈছিলোঁ",
            "নমস্কাৰ সকলোকে স্বাগতম",
        ],
        "en": [
            "hello world this is a test",
            "the quick brown fox jumps over the lazy dog",
            "India is a diverse country with many languages",
            "good morning everyone welcome to the show",
        ],
        # Code-mixed examples
        "hi-en": [
            "bhai yeh video bahut interesting hai please like karo",
            "आज का weather बहुत अच्छा है let's go outside",
            "मैंने आज office में meeting attend की",
            "coding करना बहुत fun है especially Python में",
        ],
    }

    # ── Compute metrics ─────────────────────────────────────────────────
    results = {}
    print("\n" + "=" * 100)
    print(f"{'Lang':<8} {'Words':>6} {'Chars':>6} │ {'Old tok':>8} {'tok/w':>7} {'tok/ch':>7} │ "
          f"{'New tok':>8} {'tok/w':>7} {'tok/ch':>7} │ {'Reduction':>10}")
    print("=" * 100)

    for lang, sentences in test_data.items():
        total_words = 0
        total_chars = 0
        total_old_tokens = 0
        total_new_tokens = 0

        for sent in sentences:
            words = sent.split()
            total_words += len(words)
            total_chars += len(sent)
            total_old_tokens += len(orig_tok.encode(sent, add_special_tokens=False))
            total_new_tokens += len(ext_tok.encode(sent, add_special_tokens=False))

        old_tpw = total_old_tokens / total_words
        new_tpw = total_new_tokens / total_words
        old_tpc = total_old_tokens / total_chars
        new_tpc = total_new_tokens / total_chars
        reduction = (1 - new_tpw / old_tpw) * 100

        results[lang] = {
            "words": total_words,
            "chars": total_chars,
            "old_tokens": total_old_tokens,
            "new_tokens": total_new_tokens,
            "old_tok_per_word": old_tpw,
            "new_tok_per_word": new_tpw,
            "old_tok_per_char": old_tpc,
            "new_tok_per_char": new_tpc,
            "reduction_pct": reduction,
        }

        print(f"{lang:<8} {total_words:>6} {total_chars:>6} │ {total_old_tokens:>8} {old_tpw:>7.2f} {old_tpc:>7.2f} │ "
              f"{total_new_tokens:>8} {new_tpw:>7.2f} {new_tpc:>7.2f} │ {reduction:>9.1f}%")

    print("=" * 100)

    # ── Summary stats ───────────────────────────────────────────────────
    # Everything except pure English counts as "Indic" here (hi-en included).
    indic_results = {k: v for k, v in results.items() if k not in ("en",)}
    avg_old_tpw = np.mean([v["old_tok_per_word"] for v in indic_results.values()])
    avg_new_tpw = np.mean([v["new_tok_per_word"] for v in indic_results.values()])
    avg_reduction = np.mean([v["reduction_pct"] for v in indic_results.values()])

    print(f"\nIndic average: old={avg_old_tpw:.2f} tok/word → new={avg_new_tpw:.2f} tok/word "
          f"({avg_reduction:.1f}% reduction)")

    if "en" in results:
        en = results["en"]
        print(f"English:       old={en['old_tok_per_word']:.2f} tok/word → new={en['new_tok_per_word']:.2f} tok/word "
              f"(should be ~unchanged)")

    # ── Save results ────────────────────────────────────────────────────
    results_path = os.path.join(OUTPUT_DIR, "benchmark_results.json")
    with open(results_path, "w", encoding="utf-8") as f:
        json.dump(results, f, indent=2)
    logger.info(f"Benchmark results saved to {results_path}")

    # ── Show detailed token examples ────────────────────────────────────
    print("\n" + "=" * 80)
    print("DETAILED TOKENIZATION EXAMPLES (first sentence per language)")
    print("=" * 80)
    for lang, sentences in test_data.items():
        sent = sentences[0]
        old_tokens = orig_tok.tokenize(sent)
        new_tokens_list = ext_tok.tokenize(sent)
        print(f"\n{lang}: \"{sent}\"")
        print(f"  OLD ({len(old_tokens):>3} tok): {old_tokens[:30]}{'...' if len(old_tokens) > 30 else ''}")
        print(f"  NEW ({len(new_tokens_list):>3} tok): {new_tokens_list[:30]}{'...' if len(new_tokens_list) > 30 else ''}")

    return results


def _load_merge_outputs():
    """Reload the merge step's outputs: the added-token list and extended vocab size."""
    token_list_path = os.path.join(OUTPUT_DIR, "new_indic_tokens.json")
    with open(token_list_path, encoding="utf-8") as f:
        new_tokens = json.load(f)
    from transformers import PreTrainedTokenizerFast
    tokenizer = PreTrainedTokenizerFast(
        tokenizer_file=os.path.join(MERGED_TOKENIZER_DIR, "tokenizer.json")
    )
    return new_tokens, len(tokenizer)


def main():
    """Parse ``--step`` and run the requested pipeline stage(s).

    ``all`` runs extract → train → merge (tokenizer merge + embedding resize)
    → benchmark. ``resize`` runs only the embedding resize, reusing the token
    list and merged tokenizer saved by an earlier merge (e.g. once the model
    weights are available).
    """
    parser = argparse.ArgumentParser(description="Extend tokenizer from 16K to 48K with Indic subwords")
    parser.add_argument("--step", choices=["extract", "train", "merge", "resize", "benchmark", "all"],
                        default="all", help="Which step to run")
    args = parser.parse_args()

    if args.step in ("extract", "all"):
        extract_corpus()

    if args.step in ("train", "all"):
        train_sentencepiece()

    if args.step in ("merge", "all"):
        new_tokens, final_vocab_size = merge_and_resize()
        # Embedding resize requires model weights — run separately if weights available
        try:
            resize_model_embeddings(new_tokens, final_vocab_size)
        except OSError as e:  # FileNotFoundError is a subclass
            logger.warning(f"Model weights not available for embedding resize: {e}")
            logger.info("Tokenizer merge complete. Run embedding resize after downloading weights:")
            logger.info("  python3 extend_tokenizer.py --step resize")

    if args.step == "resize":
        # Standalone resize from the artifacts a previous merge step saved.
        new_tokens, final_vocab_size = _load_merge_outputs()
        resize_model_embeddings(new_tokens, final_vocab_size)

    if args.step in ("benchmark", "all"):
        benchmark_tokenizer()

    logger.info("Done!")


if __name__ == "__main__":
    main()
