#!/usr/bin/env python3
"""
Rigorous Tokenizer Benchmark: Gemma-3-270M vs Qwen3-ASR-1.7B
Tests: efficiency, speed, vocabulary coverage, multilingual handling,
       edge cases, roundtrip fidelity, and token distribution analysis.
"""

import time
import json
import sys
import os
import statistics
import unicodedata
from collections import Counter
from transformers import AutoTokenizer

# ─── Configuration ───────────────────────────────────────────────────────────

# Short display name -> Hugging Face Hub model ID. Each ID is passed to
# AutoTokenizer.from_pretrained() by load_tokenizers(); the display names are
# used as labels in every report. Parts of the benchmark index the first two
# entries in insertion order (e.g. the ratio column of test_efficiency), so
# review those tests before adding or removing models.
MODELS = {
    "gemma3-270m": "google/gemma-3-270m",
    "qwen3-asr-1.7b": "Qwen/Qwen3-ASR-1.7B",
}

# ─── Test Data ───────────────────────────────────────────────────────────────

# English texts of varying complexity: plain prose, technical prose, Python
# code, numerals/currency, heavy punctuation, URLs/emails, and one long
# paragraph. test_speed and test_compression_ratio pick entries by exact key
# ("simple", "technical", "code", "long_paragraph"), so renaming a key breaks
# those tests.
ENGLISH_TEXTS = {
    "simple": "The quick brown fox jumps over the lazy dog.",
    "technical": "The transformer architecture uses multi-head self-attention mechanisms with scaled dot-product attention, layer normalization, and residual connections.",
    "code": 'def fibonacci(n: int) -> int:\n    """Return the nth Fibonacci number."""\n    if n <= 1:\n        return n\n    return fibonacci(n - 1) + fibonacci(n - 2)',
    "numbers": "In 2024, the GDP grew by 3.14159% to $21,427,700,000,000. Temperature was -40.5°C.",
    "punctuation_heavy": 'He said, "Wait—no! Don\'t do that!!!" She replied: "Why not?" [silence] ...then (nothing).',
    "urls_emails": "Visit https://www.example.com/path?key=value&foo=bar or email user@domain.co.uk for info.",
    # Adjacent string literals are concatenated into one paragraph (no newlines).
    "long_paragraph": (
        "Artificial intelligence has transformed numerous industries, from healthcare diagnostics "
        "to autonomous vehicle navigation. Large language models, in particular, have demonstrated "
        "remarkable capabilities in natural language understanding, generation, and reasoning. "
        "These models are trained on vast corpora of text data, learning statistical patterns "
        "that enable them to produce coherent and contextually appropriate responses. The "
        "architecture underlying most modern LLMs is the Transformer, which relies on self-attention "
        "mechanisms to capture long-range dependencies in sequential data. Fine-tuning these "
        "pre-trained models on domain-specific datasets has become a standard practice for "
        "achieving state-of-the-art performance on specialized tasks."
    ),
}

# Multilingual texts: ten Indic languages (plus a technical Hindi sample),
# Chinese, Japanese, Arabic, Korean, and two Hindi-English code-mixed
# sentences. test_speed, test_indic_deep_dive and test_compression_ratio pick
# entries by exact key, and test_compression_ratio treats every key containing
# "mixed" as code-mixed text, so keep key names stable.
MULTILINGUAL_TEXTS = {
    "hindi": "भारत एक विविधताओं से भरा देश है। यहाँ की संस्कृति, भाषा और परंपराएँ अद्वितीय हैं।",
    "hindi_technical": "स्वचालित वाक् पहचान प्रणाली ध्वनि संकेतों को पाठ में परिवर्तित करती है।",
    "tamil": "தமிழ் உலகின் மிகப் பழமையான மொழிகளில் ஒன்றாகும்.",
    "bengali": "বাংলাদেশ দক্ষিণ এশিয়ার একটি দেশ। এর রাজধানী ঢাকা।",
    "telugu": "తెలుగు భాషా దక్షిణ భారతదేశంలో ఒక ప్రధాన భాష.",
    "marathi": "महाराष्ट्र हे भारतातील एक राज्य आहे. मुंबई ही त्याची राजधानी आहे.",
    "gujarati": "ગુજરાત ભારતનું એક રાજ્ય છે. તેની રાજધાની ગાંધીનગર છે.",
    "kannada": "ಕರ್ನಾಟಕ ದಕ್ಷಿಣ ಭಾರತದ ಒಂದು ರಾಜ್ಯ. ಬೆಂಗಳೂರು ಅದರ ರಾಜಧಾನಿ.",
    "malayalam": "കേരളം ഇന്ത്യയിലെ ഒരു സംസ്ഥാനമാണ്. തിരുവനന്തപുരം ആണ് തലസ്ഥാനം.",
    "odia": "ଓଡ଼ିଶା ଭାରତର ଏକ ରାଜ୍ୟ। ଭୁବନେଶ୍ୱର ଏହାର ରାଜଧାନୀ।",
    "punjabi": "ਪੰਜਾਬ ਭਾਰਤ ਦਾ ਇੱਕ ਰਾਜ ਹੈ। ਚੰਡੀਗੜ੍ਹ ਇਸ ਦੀ ਰਾਜਧਾਨੀ ਹੈ।",
    "chinese": "人工智能技术正在改变世界各地的产业格局。",
    "japanese": "日本語の自然言語処理は、形態素解析から始まります。",
    "arabic": "الذكاء الاصطناعي يغير العالم بسرعة كبيرة.",
    "korean": "한국어 자연어 처리는 형태소 분석에서 시작됩니다.",
    "mixed_hindi_english": "AI technology ने भारत में healthcare को transform किया है। Machine learning models अब Hindi speech recognition में भी काम करते हैं।",
    "mixed_code_switching": "मैंने yesterday एक new model train किया, accuracy 95.3% आई with batch_size=32.",
}

# Edge cases: empty and whitespace-only input, a single ASCII / Devanagari
# character, long repetition, emoji (including multi-codepoint sequences),
# markup, JSON, math notation, stacked combining marks ("zalgo"), identifier
# styles, numeric formats, literal NUL/control characters, mixed RTL/LTR text,
# and Devanagari conjuncts. Tests that skip empty strings never tokenize the
# "empty" entry.
EDGE_CASES = {
    "empty": "",
    "single_char": "a",
    "single_unicode": "अ",
    "whitespace_only": "   \t\n  ",
    "repeated_char": "aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa",
    "emoji": "🎉🚀🤖💡🔥 AI is amazing! 🌍🎯✨",
    # Multi-codepoint emoji: zero-width-joiner sequences, a flag sequence and a
    # skin-tone modifier.
    "emoji_complex": "👨‍👩‍👧‍👦 Family emoji, 🏳️‍🌈 flag, 👩🏽‍💻 person",
    # Plain text resembling common special-token markup; callers encode it with
    # add_special_tokens=False. NOTE(review): whether a tokenizer maps these
    # strings to its own special IDs depends on its vocabulary — not checked here.
    "special_tokens": "<s> </s> <pad> <unk> <mask> [CLS] [SEP] [PAD]",
    "html_xml": '<div class="container"><p>Hello &amp; welcome</p></div>',
    "json_str": '{"model": "gemma", "params": 270000000, "layers": [12, 24]}',
    "math": "∫₀^∞ e^(-x²) dx = √π/2, ∑_{n=1}^{∞} 1/n² = π²/6",
    # Base letters with many stacked combining diacritics.
    "zalgo": "H̸̡̪̯ͨ͊̽̅̾̎E̮̟͈̣̖̰̤̤ C̷̙̲̝͖OME̶̡̛S",
    "long_word": "Pneumonoultramicroscopicsilicovolcanoconiosis",
    "camelCase": "thisIsACamelCaseVariableNameThatIsVeryLong",
    "snake_case": "this_is_a_snake_case_variable_name_that_is_very_long",
    "numbers_formats": "1,234,567.89 -42 0xFF 0b1010 1e-10 3.14159265358979",
    "newlines_tabs": "line1\nline2\n\nline4\ttabbed\t\there",
    # The escapes produce real U+0000, U+0001 and U+0002 characters.
    "null_and_control": "text\x00with\x01control\x02chars",
    "rtl_mixed": "Hello مرحبا World عالم",
    "devanagari_conjuncts": "क्ष त्र ज्ञ श्र द्ध क्त स्त्र",
}

# ASR-style transcript texts (included because Qwen3-ASR is speech-recognition
# focused): filler words and disfluencies, bracketed noise/event tags,
# bracketed time ranges, and a Hindi-English code-mixed transcript.
ASR_TEXTS = {
    "disfluent": "So um I was like uh going to the uh store and um yeah it was like really um crowded.",
    "filler_words": "you know what I mean like basically essentially actually literally",
    "transcript_style": "okay so first we need to uh open the file and then we can um process the data right",
    "noisy_transcript": "I [unintelligible] went to the [noise] store yesterday [laughter] and bought groceries",
    "timestamps": "[0.0-1.5] Hello everyone [1.5-3.0] welcome to the presentation [3.0-5.5] today we will discuss",
    "hindi_transcript": "नमस्ते आज हम बात करेंगे artificial intelligence के बारे में जो कि बहुत important topic है",
}


def load_tokenizers():
    """Load every tokenizer listed in MODELS.

    Returns:
        dict mapping each display name (a key of MODELS) to the object returned
        by ``AutoTokenizer.from_pretrained`` for its Hub ID.

    Terminates the process with exit status 1 on the first tokenizer that
    fails to load (or whose ``vocab_size`` cannot be read).
    """
    loaded = {}
    for short_name, hub_id in MODELS.items():
        print(f"Loading tokenizer: {short_name} ({hub_id})...")
        try:
            # NOTE: trust_remote_code=True lets the Hub repository run its own
            # tokenizer code on this machine.
            loaded[short_name] = AutoTokenizer.from_pretrained(hub_id, trust_remote_code=True)
            print(f"  ✓ Loaded. Vocab size: {loaded[short_name].vocab_size:,}")
        except Exception as err:
            print(f"  ✗ Failed: {err}")
            sys.exit(1)
    return loaded


def print_header(title):
    """Print *title* as a section banner framed by two 80-character '=' rules."""
    rule = "=" * 80
    print(f"\n{rule}\n  {title}\n{rule}")


def print_subheader(title):
    """Print *title* as a minor '--- title ---' marker preceded by a blank line."""
    print("\n--- {} ---".format(title))


# ─── Test 1: Vocabulary Analysis ─────────────────────────────────────────────

def test_vocabulary_analysis(tokenizers):
    """Print vocabulary statistics for each loaded tokenizer.

    For every tokenizer this reports:

    * the size of ``get_vocab()`` next to the ``vocab_size`` attribute;
    * the special tokens, with up to 10 examples sorted so output is reproducible;
    * a count of byte-fallback tokens;
    * min/max/mean/median character length of non-special vocabulary strings;
    * the 15 most frequent Unicode "script" labels over every character in the
      vocabulary.

    Args:
        tokenizers: mapping of display name -> loaded Hugging Face tokenizer.
    """
    print_header("TEST 1: VOCABULARY ANALYSIS")

    for name, tok in tokenizers.items():
        print(f"\n[{name}]")
        vocab = tok.get_vocab()
        print(f"  Vocabulary size:        {len(vocab):,}")
        print(f"  Reported vocab_size:    {tok.vocab_size:,}")

        # Byte-fallback pieces written as "<0x..>" or with a literal "\x.." prefix.
        # NOTE(review): byte-level BPE vocabularies spell raw bytes with printable
        # stand-in characters instead, so this may report 0 for them — confirm per model.
        byte_tokens = sum(1 for t in vocab if t.startswith(("<0x", "\\x")))

        special_tokens = set(getattr(tok, "all_special_tokens", ()))
        print(f"  Special tokens:         {len(special_tokens)}")
        if special_tokens:
            # Sorted because set iteration order of strings changes between runs
            # (string hashing is randomized per process).
            print(f"    Examples: {sorted(special_tokens)[:10]}")
        print(f"  Byte-level tokens:      {byte_tokens}")

        # Character length of the raw vocabulary strings, special tokens excluded.
        token_lengths = [len(t) for t in vocab if t not in special_tokens]
        if token_lengths:
            print(f"  Token length (chars):   min={min(token_lengths)}, max={max(token_lengths)}, "
                  f"mean={statistics.mean(token_lengths):.2f}, median={statistics.median(token_lengths):.1f}")

        # "Script" is approximated by the first word of each character's Unicode
        # name ("LATIN", "DEVANAGARI", "CJK", ...); characters without a name are
        # counted as "UNKNOWN". unicodedata.name() returns the default instead of
        # raising, so no exception handling is needed.
        # NOTE(review): for byte-level BPE vocabularies these are the stand-in
        # characters rather than decoded text, so non-Latin scripts may be undercounted.
        scripts = Counter(
            unicodedata.name(ch, "UNKNOWN").split()[0]
            for token in vocab
            for ch in token
        )
        print("  Top Unicode scripts in vocab:")
        for script, count in scripts.most_common(15):
            print(f"    {script:20s}: {count:>8,}")


# ─── Test 2: Tokenization Efficiency ─────────────────────────────────────────

def test_efficiency(tokenizers):
    """Compare token counts for every non-empty test text, grouped by category.

    Each row shows, per tokenizer, the token count with characters-per-token in
    parentheses. It then shows the ratio of the first tokenizer's count to the
    second's (below 1 means the first model needs fewer tokens), followed by
    the text's character length. The report ends with per-model total, mean and
    median token counts.

    Args:
        tokenizers: mapping of display name -> tokenizer. The ratio column uses
            the first two entries in iteration order and shows "n/a" when fewer
            than two are given.

    Returns:
        dict ``{model_name: {"<group>/<text_name>": token_count}}``.
    """
    print_header("TEST 2: TOKENIZATION EFFICIENCY (tokens per text)")

    all_text_groups = {
        "English": ENGLISH_TEXTS,
        "Multilingual": MULTILINGUAL_TEXTS,
        "Edge Cases": EDGE_CASES,
        "ASR-Specific": ASR_TEXTS,
    }

    names = list(tokenizers)
    results = {name: {} for name in names}
    col = 16  # width of each per-model column (header and cells)

    for group_name, texts in all_text_groups.items():
        print_subheader(group_name)
        print(f"  {'Text':<25s}", end="")
        for name in names:
            print(f"  {name:>{col}s}", end="")
        print(f"  {'Ratio':>10s}  {'Chars':>6s}")
        print(f"  {'-'*25}", end="")
        for _ in names:
            print(f"  {'-'*col}", end="")
        print(f"  {'-'*10}  {'-'*6}")

        for text_name, text in texts.items():
            if not text:
                continue  # nothing to tokenize (e.g. EDGE_CASES["empty"])
            char_count = len(text)
            token_counts = {
                name: len(tok.encode(text, add_special_tokens=False))
                for name, tok in tokenizers.items()
            }

            print(f"  {text_name:<25s}", end="")
            for name in names:
                count = token_counts[name]
                cpt = char_count / count if count > 0 else 0  # chars per token
                # Build the whole cell first so it is padded to the header width.
                cell = f"{count} ({cpt:.1f})"
                print(f"  {cell:>{col}s}", end="")

            if len(names) < 2:
                ratio_cell = "n/a"
            else:
                second = token_counts[names[1]]
                ratio = token_counts[names[0]] / second if second > 0 else float("inf")
                ratio_cell = f"{ratio:.3f}x"
            print(f"  {ratio_cell:>10s}  {char_count:>6d}")

            for name in names:
                results[name][f"{group_name}/{text_name}"] = token_counts[name]

    # Summary statistics
    print_subheader("Efficiency Summary")
    for name in names:
        counts = [v for v in results[name].values() if v > 0]
        if counts:
            print(f"  [{name}] Total tokens across all tests: {sum(counts):,}, "
                  f"Mean: {statistics.mean(counts):.1f}, Median: {statistics.median(counts):.1f}")

    return results


# ─── Test 3: Speed Benchmark ─────────────────────────────────────────────────

def test_speed(tokenizers):
    """Benchmark per-call encode and decode latency on representative texts.

    For each text and tokenizer the function runs ``N_WARMUP`` untimed
    encode+decode rounds, then ``N_ITER`` timed encodes and ``N_ITER`` timed
    decodes. It reports median latencies in milliseconds and encoded tokens per
    second (token count / median encode time).

    The Roundtrip column shows "✓" when a ``skip_special_tokens`` decode
    reproduces the text, ignoring leading and trailing whitespace, and "≈"
    otherwise.

    Args:
        tokenizers: mapping of display name -> tokenizer.
    """
    print_header("TEST 3: ENCODING/DECODING SPEED")

    # Character counts are printed from len(text) so labels always match the data.
    test_texts = {
        "short": ENGLISH_TEXTS["simple"],
        "medium": ENGLISH_TEXTS["technical"],
        "long": ENGLISH_TEXTS["long_paragraph"],
        "hindi": MULTILINGUAL_TEXTS["hindi"],
        "code": ENGLISH_TEXTS["code"],
        "mixed hi-en": MULTILINGUAL_TEXTS["mixed_hindi_english"],
    }

    N_WARMUP = 50
    N_ITER = 500

    for label, text in test_texts.items():
        print_subheader(f"Text: {label} ({len(text)} chars)")
        print(f"  {'Model':<20s} {'Encode(ms)':>12s} {'Decode(ms)':>12s} {'Tok/s(enc)':>12s} {'Roundtrip':>10s}")

        for name, tok in tokenizers.items():
            # Warmup: populate caches before anything is timed.
            for _ in range(N_WARMUP):
                ids = tok.encode(text, add_special_tokens=False)
                tok.decode(ids)

            encode_times = []
            for _ in range(N_ITER):
                start = time.perf_counter()
                ids = tok.encode(text, add_special_tokens=False)
                encode_times.append(time.perf_counter() - start)

            # Decodes the ids produced by the last timed encode.
            decode_times = []
            for _ in range(N_ITER):
                start = time.perf_counter()
                tok.decode(ids)
                decode_times.append(time.perf_counter() - start)

            # Medians are robust to occasional scheduler/GC spikes.
            enc_median_s = statistics.median(encode_times)
            dec_median_s = statistics.median(decode_times)
            tokens_per_sec = len(ids) / enc_median_s if enc_median_s > 0 else 0

            decoded = tok.decode(ids, skip_special_tokens=True)
            roundtrip_ok = "✓" if decoded.strip() == text.strip() else "≈"

            print(f"  {name:<20s} {enc_median_s * 1000:>12.4f} {dec_median_s * 1000:>12.4f} "
                  f"{tokens_per_sec:>12,.0f} {roundtrip_ok:>10s}")


# ─── Test 4: Roundtrip Fidelity ──────────────────────────────────────────────

def test_roundtrip_fidelity(tokenizers):
    """Check that encode -> decode reproduces every non-empty test text.

    The comparison collapses all whitespace runs to single spaces, so only
    non-whitespace changes count as mismatches. A tokenizer that raises is
    recorded as a mismatch with an "ERROR: ..." entry instead of aborting the
    run. The function prints up to ``max_shown`` mismatches per tokenizer.

    Args:
        tokenizers: mapping of display name -> tokenizer.
    """
    print_header("TEST 4: ROUNDTRIP FIDELITY (encode → decode)")

    # A flat list (rather than merging dicts) keeps every text even if two groups
    # ever reuse a key. Empty texts are excluded so the reported totals count
    # only texts that were actually tested.
    cases = [
        (text_name, text)
        for group in (ENGLISH_TEXTS, MULTILINGUAL_TEXTS, EDGE_CASES, ASR_TEXTS)
        for text_name, text in group.items()
        if text
    ]
    max_shown = 10

    for name, tok in tokenizers.items():
        print_subheader(f"Model: {name}")
        failures = []
        for text_name, text in cases:
            try:
                ids = tok.encode(text, add_special_tokens=False)
                decoded = tok.decode(ids, skip_special_tokens=True)
            except Exception as e:  # report crashes as mismatches and keep going
                failures.append((text_name, text[:60], f"ERROR: {e}"))
                continue
            # Normalize whitespace for comparison
            if " ".join(text.split()) != " ".join(decoded.split()):
                failures.append((text_name, text[:60], decoded[:60]))

        if failures:
            print(f"  Roundtrip mismatches: {len(failures)}/{len(cases)}")
            if len(failures) > max_shown:
                print(f"  (showing first {max_shown})")
            for tname, orig, dec in failures[:max_shown]:
                print(f"    [{tname}]")
                print(f"      Original: {repr(orig)}")
                print(f"      Decoded:  {repr(dec)}")
        else:
            print(f"  All {len(cases)} texts roundtrip perfectly! ✓")


# ─── Test 5: Batch Throughput ─────────────────────────────────────────────────

def test_batch_throughput(tokenizers):
    """Measure encode throughput over a fixed workload of 100 texts.

    The workload is every non-empty English, multilingual and ASR text, tiled
    and truncated to exactly 100 items. Texts are encoded one ``tok.encode``
    call at a time, not through the tokenizer's batched ``__call__``, so this
    measures sequential per-call throughput summed over the workload.

    The function reports the median of ``N_ITER`` timed passes, tokens per
    pass, and chars/s and tokens/s.

    Args:
        tokenizers: mapping of display name -> tokenizer.
    """
    print_header("TEST 5: BATCH THROUGHPUT")

    BATCH_SIZE = 100
    N_ITER = 20

    base = [
        text
        for texts in (ENGLISH_TEXTS, MULTILINGUAL_TEXTS, ASR_TEXTS)
        for text in texts.values()
        if text
    ]
    if not base:
        # Repeating an empty list could never reach BATCH_SIZE.
        print("  No texts available; skipping.")
        return
    # Tile the texts in order and truncate to exactly BATCH_SIZE items.
    batch = (base * (BATCH_SIZE // len(base) + 1))[:BATCH_SIZE]

    total_chars = sum(len(t) for t in batch)
    print(f"  Batch size: {len(batch)} texts, {total_chars:,} total chars")

    for name, tok in tokenizers.items():
        # Untimed warmup passes. They also yield the per-batch token count, which
        # keeps that bookkeeping out of the timed region (encoding is deterministic).
        tokens_per_batch = 0
        for _ in range(3):
            tokens_per_batch = sum(len(tok.encode(text, add_special_tokens=False)) for text in batch)

        times = []
        for _ in range(N_ITER):
            start = time.perf_counter()
            for text in batch:
                tok.encode(text, add_special_tokens=False)
            times.append(time.perf_counter() - start)

        med_time = statistics.median(times)
        print(f"  [{name}] Median batch time: {med_time*1000:.2f}ms, "
              f"Tokens/batch: {tokens_per_batch:,}, "
              f"Throughput: {total_chars/med_time:,.0f} chars/s, "
              f"{tokens_per_batch/med_time:,.0f} tokens/s")


# ─── Test 6: Token Distribution Analysis ─────────────────────────────────────

def test_token_distribution(tokenizers):
    """Report token-frequency and token-length statistics over a combined corpus.

    The corpus is every non-empty English, multilingual and ASR text joined by
    single spaces. For each tokenizer the function prints:

    * the total token count;
    * how many distinct IDs were used, relative to ``vocab_size``;
    * the 15 most frequent tokens;
    * the min/max/mean decoded length of a token.

    Per-token strings come from decoding each ID on its own.
    NOTE(review): for byte-level tokenizers a piece holding part of a multi-byte
    character may decode to a replacement character, so lengths for non-Latin
    text are approximate.

    Args:
        tokenizers: mapping of display name -> tokenizer.
    """
    print_header("TEST 6: TOKEN DISTRIBUTION ANALYSIS")

    # join() avoids quadratic string concatenation and a spurious trailing space.
    corpus = " ".join(
        text
        for texts in (ENGLISH_TEXTS, MULTILINGUAL_TEXTS, ASR_TEXTS)
        for text in texts.values()
        if text
    )

    for name, tok in tokenizers.items():
        print_subheader(f"Model: {name}")
        ids = tok.encode(corpus, add_special_tokens=False)
        if not ids:
            print("  No tokens produced; skipping.")
            continue

        token_freq = Counter(ids)
        # Decode each distinct ID once and reuse it for both reports below.
        piece = {tid: tok.decode([tid]) for tid in token_freq}

        print(f"  Total tokens: {len(ids):,}")
        print(f"  Unique tokens used: {len(token_freq):,} / {tok.vocab_size:,} ({100*len(token_freq)/tok.vocab_size:.2f}%)")
        print("  Top 15 most frequent tokens:")
        for tid, count in token_freq.most_common(15):
            print(f"    {repr(piece[tid]):>25s}: {count:>4d} ({100*count/len(ids):.1f}%)")

        # Token length distribution (in characters), weighted by occurrence.
        token_char_lengths = [len(piece[tid]) for tid in ids]
        print(f"  Token character lengths: min={min(token_char_lengths)}, max={max(token_char_lengths)}, "
              f"mean={statistics.mean(token_char_lengths):.2f}")


# ─── Test 7: Indic Language Deep Dive ─────────────────────────────────────────

def test_indic_deep_dive(tokenizers):
    """Compare tokenization of Indic and Hindi-English code-mixed texts.

    For each selected text the function prints each tokenizer's token count and
    characters per token. The Winner column names the tokenizer with the fewest
    tokens, or "tie" when several share the minimum. It then prints the full
    per-token decoding of the Hindi sample for each tokenizer.

    Args:
        tokenizers: mapping of display name -> tokenizer.
    """
    print_header("TEST 7: INDIC LANGUAGE TOKENIZATION DEEP DIVE")

    indic_keys = {"hindi", "hindi_technical", "tamil", "bengali", "telugu",
                  "marathi", "gujarati", "kannada", "malayalam", "odia", "punjabi",
                  "mixed_hindi_english", "mixed_code_switching"}
    indic_texts = {k: v for k, v in MULTILINGUAL_TEXTS.items() if k in indic_keys}
    names = list(tokenizers)

    # Two header lines: model names above their (Tokens, Ch/Tok) column pairs,
    # so the columns can be told apart.
    print(f"\n  {'':<22s}", end="")
    for name in names:
        print(f"  {name:>16s}", end="")
    print()
    print(f"  {'Language':<22s}", end="")
    for _ in names:
        print(f"  {'Tokens':>8s} {'Ch/Tok':>7s}", end="")
    print(f"  {'Winner':>12s}")

    for text_name, text in indic_texts.items():
        char_count = len(text)
        print(f"  {text_name:<22s}", end="")
        token_counts = {}
        for name, tok in tokenizers.items():
            n_tokens = len(tok.encode(text, add_special_tokens=False))
            token_counts[name] = n_tokens
            cpt = char_count / n_tokens if n_tokens > 0 else 0
            print(f"  {n_tokens:>8d} {cpt:>7.2f}", end="")

        # Fewest tokens wins. Equal counts are reported as a tie rather than
        # credited to whichever model happens to be listed first.
        best = min(token_counts.values())
        leaders = [n for n in names if token_counts[n] == best]
        winner = leaders[0] if len(leaders) == 1 else "tie"
        print(f"  {winner:>12s}")

    # Show actual tokenization for Hindi
    print_subheader("Hindi Tokenization Detail")
    hindi = MULTILINGUAL_TEXTS["hindi"]
    for name, tok in tokenizers.items():
        ids = tok.encode(hindi, add_special_tokens=False)
        tokens = [tok.decode([tid]) for tid in ids]
        print(f"\n  [{name}] ({len(ids)} tokens)")
        print(f"  Text: {hindi}")
        print(f"  Tokens: {tokens}")


# ─── Test 8: Compression Ratio ───────────────────────────────────────────────

def test_compression_ratio(tokenizers):
    """Compare UTF-8 bytes per token for each text category.

    Each category's non-empty texts are joined with single spaces and encoded
    once. Bytes per token is ``len(utf8_bytes) / len(ids)``, or 0 if no tokens
    are produced. A higher value means fewer tokens for the same input. The
    Better column names the tokenizer with the highest value, or "tie" when
    several share it.

    Args:
        tokenizers: mapping of display name -> tokenizer.
    """
    print_header("TEST 8: COMPRESSION RATIO (bytes per token)")

    all_groups = {
        "English": ENGLISH_TEXTS,
        "Indic": {k: v for k, v in MULTILINGUAL_TEXTS.items()
                  if k in ["hindi", "tamil", "bengali", "telugu", "marathi"]},
        "CJK": {k: v for k, v in MULTILINGUAL_TEXTS.items()
                if k in ["chinese", "japanese", "korean"]},
        "Code": {"code": ENGLISH_TEXTS["code"]},
        "Mixed": {k: v for k, v in MULTILINGUAL_TEXTS.items()
                  if "mixed" in k},
        "ASR": ASR_TEXTS,
    }
    names = list(tokenizers)

    print(f"\n  {'Category':<15s}", end="")
    for name in names:
        print(f"  {name + ' (B/tok)':>20s}", end="")
    print(f"  {'Better':>15s}")

    for group_name, texts in all_groups.items():
        combined = " ".join(t for t in texts.values() if t)
        byte_count = len(combined.encode("utf-8"))

        print(f"  {group_name:<15s}", end="")
        bpt = {}
        for name, tok in tokenizers.items():
            n_tokens = len(tok.encode(combined, add_special_tokens=False))
            bpt[name] = byte_count / n_tokens if n_tokens > 0 else 0
            print(f"  {bpt[name]:>20.3f}", end="")

        # Every model divides the same byte_count, so equal values mean equal
        # token counts and exact float equality is a valid tie test.
        best = max(bpt.values())
        leaders = [n for n in names if bpt[n] == best]
        winner = leaders[0] if len(leaders) == 1 else "tie"
        print(f"  {winner:>15s}")


# ─── Test 9: Special / Unusual Input Handling ─────────────────────────────────

def test_special_inputs(tokenizers):
    """Probe each tokenizer with unusual inputs and report how it copes.

    Each probe is encoded (without special tokens) and decoded with
    ``skip_special_tokens=True``. The status marks are:

    * "✓" — the probe ran cleanly;
    * "≈ ... (lossy)" — the stripped decode differs from the stripped input;
      probes in ``lossy_exempt`` are never marked lossy;
    * "✗" — the tokenizer raised, followed by the first 60 characters of the
      error.

    Args:
        tokenizers: mapping of display name -> tokenizer.
    """
    print_header("TEST 9: SPECIAL INPUT HANDLING")

    probes = {
        "empty string": "",
        "single space": " ",
        "BOM": "\ufeff",
        "zero-width space": "\u200b",
        "zero-width joiner": "\u200d",
        "right-to-left mark": "\u200f",
        "null byte": "\x00",
        "very long repeat": "a" * 10000,
        "alternating scripts": "हa ि b क c",
        "combining diacritics": "a\u0301 e\u0300 o\u0302",  # á è ô
        "surrogate-safe emoji": "😀" * 50,
    }
    # Probes consisting only of invisible/format characters; a decode that
    # differs from the input is not reported as lossy for these.
    lossy_exempt = frozenset({
        "null byte", "BOM", "zero-width space", "zero-width joiner", "right-to-left mark",
    })

    for model_name, tok in tokenizers.items():
        print_subheader(f"Model: {model_name}")
        for probe_name, probe in probes.items():
            try:
                encoded = tok.encode(probe, add_special_tokens=False)
                roundtrip = tok.decode(encoded, skip_special_tokens=True)
                summary = f"{len(encoded)} tokens"
                differs = bool(probe) and roundtrip.strip() != probe.strip()
                if differs and probe_name not in lossy_exempt:
                    mark, summary = "≈", summary + " (lossy)"
                else:
                    mark = "✓"
            except Exception as exc:
                mark, summary = "✗", str(exc)[:60]

            print(f"  {mark} {probe_name:<25s}: {summary}")


# ─── Test 10: Subword Segmentation Quality ───────────────────────────────────

def test_subword_quality(tokenizers):
    """Print how each tokenizer splits a fixed list of English and Devanagari words.

    Each token ID is decoded individually, so every printed piece is the text
    that ID contributes. The pieces are shown as a Python list so segmentation
    can be judged by eye.

    Args:
        tokenizers: mapping of display name -> tokenizer.
    """
    print_header("TEST 10: SUBWORD SEGMENTATION QUALITY")

    # Words where we can judge segmentation quality
    probe_words = [
        "unhappiness", "unbelievable", "internationalization",
        "preprocessing", "tokenization", "transformer",
        "भारतीय", "स्वतंत्रता", "अंतरराष्ट्रीय",
        "counterproductive", "misunderstanding",
        "autobiography", "electromagnetic",
    ]

    def segment(tokenizer, word):
        # Decode one ID at a time to expose each subword piece.
        piece_ids = tokenizer.encode(word, add_special_tokens=False)
        return [tokenizer.decode([piece_id]) for piece_id in piece_ids]

    for model_name, tok in tokenizers.items():
        print_subheader(f"Model: {model_name}")
        for row in (f"  {word:<30s} → {segment(tok, word)}" for word in probe_words):
            print(row)


# ─── Main ────────────────────────────────────────────────────────────────────

def main():
    """Load the configured tokenizers and run every benchmark suite in order."""
    banner = "=" * 80
    print(banner)
    print("  TOKENIZER BENCHMARK: Gemma-3-270M vs Qwen3-ASR-1.7B")
    print(banner)

    tokenizers = load_tokenizers()

    # Each suite prints its own report; return values are not used here.
    suites = (
        test_vocabulary_analysis,
        test_efficiency,
        test_speed,
        test_roundtrip_fidelity,
        test_batch_throughput,
        test_token_distribution,
        test_indic_deep_dive,
        test_compression_ratio,
        test_special_inputs,
        test_subword_quality,
    )
    for suite in suites:
        suite(tokenizers)

    print_header("BENCHMARK COMPLETE")
    print(f"\nAll {len(suites)} test suites finished.")


# Run the full benchmark only when executed as a script, not when imported.
if __name__ == "__main__":
    main()
