#!/usr/bin/env python3
"""Analyze tokenizer efficiency across 12 languages. CPU-only, no PyTorch."""

import sentencepiece as spm
import pandas as pd
import unicodedata
import random
import re

# Seed the module-level RNG; random.shuffle is used later for word sampling.
random.seed(42)

TOKENIZER_PATH = '/workspace/maya-asr/tokenizers/stage1_prod_bpe/tokenizer.model'
RULE_WIDTH = 70  # width of the '=' separator lines in the report

# Load tokenizer
sp = spm.SentencePieceProcessor()
sp.Load(TOKENIZER_PATH)

vocab_size = sp.GetPieceSize()
# Plain RULE_WIDTH-wide rules, matching the separators printed later in the
# report (the old f"={'='*70}" emitted 71 characters).
print('=' * RULE_WIDTH)
print(f"TOKENIZER ANALYSIS: {TOKENIZER_PATH}")
print(f"Total vocab size: {vocab_size}")
print('=' * RULE_WIDTH + '\n')

# Vocabulary breakdown: classify every piece as Indic, Latin, or other, based on
# the Unicode names of the characters it contains.
indic_count = 0
latin_count = 0
other_count = 0

# Substrings of Unicode character names that mark a character as Indic-script.
indic_scripts = {
    'DEVANAGARI', 'BENGALI', 'TAMIL', 'TELUGU', 'GUJARATI',
    'GURMUKHI', 'KANNADA', 'MALAYALAM', 'ORIYA',
}


def _char_script(ch):
    """Return 'INDIC', 'LATIN' or 'OTHER' for *ch*, or None if it has no Unicode name."""
    # unicodedata.name with a default never raises, so no try/except is needed.
    name = unicodedata.name(ch, '')
    if not name:
        return None
    if any(script in name for script in indic_scripts):
        return 'INDIC'
    if 'LATIN' in name:
        return 'LATIN'
    return 'OTHER'


for i in range(vocab_size):
    # Control / unknown / unused / byte-fallback pieces are special symbols, not
    # text. Their surface forms are typically ASCII (e.g. <unk>, <0x41>), so
    # without this check their letters would be counted as Latin.
    if sp.IsControl(i) or sp.IsUnknown(i) or sp.IsUnused(i) or sp.IsByte(i):
        other_count += 1
        continue
    # Remove the sentencepiece word-boundary marker (U+2581).
    clean = sp.IdToPiece(i).replace('\u2581', '')
    # Scan every character. Any Indic character makes the piece Indic;
    # otherwise any Latin character makes it Latin; anything else (including
    # an empty piece) is counted as other.
    scripts_found = {_char_script(ch) for ch in clean}
    if 'INDIC' in scripts_found:
        indic_count += 1
    elif 'LATIN' in scripts_found:
        latin_count += 1
    else:
        other_count += 1

print("Vocab breakdown by script:")
print(f"  Indic script tokens: {indic_count} ({100*indic_count/vocab_size:.1f}%)")
print(f"  Latin script tokens: {latin_count} ({100*latin_count/vocab_size:.1f}%)")
print(f"  Other (special/num): {other_count} ({100*other_count/vocab_size:.1f}%)")
print()

# Read just the two columns the per-language analysis uses.
print("Loading training data (transcript + language columns only)...")
train_parquet = '/workspace/maya-asr/artifacts/phase3/production_train_final.parquet'
df = pd.read_parquet(train_parquet, columns=['transcript', 'language'])
row_total = len(df)
languages_present = sorted(df['language'].unique())
print(f"Total rows: {row_total}")
print(f"Languages found: {languages_present}")
print()

# Languages analysed, in report order, each paired with its display name.
_LANGUAGE_TABLE = (
    ('hi', 'Hindi'),
    ('bn', 'Bengali'),
    ('ta', 'Tamil'),
    ('te', 'Telugu'),
    ('mr', 'Marathi'),
    ('gu', 'Gujarati'),
    ('kn', 'Kannada'),
    ('ml', 'Malayalam'),
    ('pa', 'Punjabi'),
    ('or', 'Odia'),
    ('as', 'Assamese'),
    ('en', 'English'),
)
LANGUAGES = [code for code, _display in _LANGUAGE_TABLE]
LANG_NAMES = dict(_LANGUAGE_TABLE)

# Per-language analysis: sample distinct words from each language's transcripts
# and measure how many SentencePiece pieces each word encodes to.
MAX_UNIQUE_WORDS = 5000  # stop collecting once this many distinct words are seen
SAMPLE_SIZE = 100        # words actually encoded per language
N_EXAMPLES = 3           # tokenizations kept for the examples section


def _is_countable_word(word):
    """Return True if *word* has >= 2 characters and is not made only of punctuation."""
    if len(word) < 2:
        return False
    # Unicode general categories starting with 'P' are punctuation (Po, Pd, Ps, ...).
    return not all(unicodedata.category(ch).startswith('P') for ch in word)


def _collect_unique_words(transcripts, limit=MAX_UNIQUE_WORDS):
    """Return a set of up to *limit* distinct countable words from *transcripts*.

    Transcripts are scanned in order and scanning stops at the limit, so the
    pool is drawn from the leading transcripts only.
    """
    words = set()
    for transcript in transcripts:
        for w in transcript.split():
            if _is_countable_word(w):
                words.add(w)
                if len(words) >= limit:
                    return words
    return words


print(f"{'Language':<12} {'Avg tok/word':>12} {'Min':>5} {'Max':>5} {'Words sampled':>14}")
print(f"{'-'*12} {'-'*12} {'-'*5} {'-'*5} {'-'*14}")

details = {}

for lang in LANGUAGES:
    display = LANG_NAMES.get(lang, lang)
    subset = df[df['language'] == lang]
    if len(subset) == 0:
        print(f"{display:<12} {'(no data)':>12}")
        continue

    # Sort before shuffling. Iteration order of a set of str depends on the
    # per-process hash seed (PYTHONHASHSEED), so without sorting the seeded
    # shuffle would still pick different words on every run.
    word_list = sorted(_collect_unique_words(subset['transcript'].dropna().values))
    if not word_list:
        # Guard against ZeroDivisionError / min() of an empty sequence below.
        print(f"{display:<12} {'(no words)':>12}")
        continue
    random.shuffle(word_list)
    sample_words = word_list[:SAMPLE_SIZE]

    token_counts = []
    examples = []
    for word in sample_words:
        pieces = sp.EncodeAsPieces(word)
        n = len(pieces)
        token_counts.append(n)
        if len(examples) < N_EXAMPLES:
            examples.append((word, pieces, n))

    avg = sum(token_counts) / len(token_counts)
    mn = min(token_counts)
    mx = max(token_counts)

    print(f"{display:<12} {avg:>12.2f} {mn:>5} {mx:>5} {len(sample_words):>14}")
    details[lang] = {
        'avg': avg, 'min': mn, 'max': mx,
        'examples': examples, 'count': len(sample_words),
    }

# Report: example tokenizations for each analysed language, then the overall
# summary and a ranking by average tokens per word.
_rule = '=' * 70

print('\n' + _rule)
print("EXAMPLE TOKENIZATIONS (3 per language)")
print(_rule)

for code in (c for c in LANGUAGES if c in details):
    entry = details[code]
    label = LANG_NAMES.get(code, code)
    print(f"\n--- {label} ({code}) | avg={entry['avg']:.2f} tok/word ---")
    for word, pieces, n in entry['examples']:
        print(f"  \"{word}\" -> [{n} tokens] {' | '.join(pieces)}")

print('\n' + _rule)
print("SUMMARY")
print(_rule)
print(f"Vocab size: {vocab_size}")
print(f"Indic tokens: {indic_count}, Latin tokens: {latin_count}, Other: {other_count}")

# Ascending average tokens/word (stable sort keeps analysis order on ties);
# the report treats lower as better.
print("\nLanguages ranked by tokenization efficiency (lower = better):")
ranking = sorted(details, key=lambda code: details[code]['avg'])
for position, code in enumerate(ranking, start=1):
    print(f"  {position:>2}. {LANG_NAMES.get(code, code):<12} {details[code]['avg']:.2f} tokens/word")
