#!/usr/bin/env python3
"""Fast vectorized language-ID fix based on Unicode script detection.

Processes 78M+ rows in minutes using NumPy vectorization instead of per-row Python.
Fixes language labels where the dominant Unicode script contradicts the labeled language.

Rules:
- Unambiguous scripts (Tamil, Telugu, Kannada, Malayalam, Gujarati, Gurmukhi, Odia):
  relabel to the script's primary language.
- Ambiguous scripts:
  - Devanagari: if labeled as non-Devanagari lang, relabel to 'hi' (most common)
  - Bengali: if labeled as non-Bengali lang, relabel to 'bn' (most common)
- Rows labeled 'en' (or any label without a known script) and rows with <6
  chars of text are skipped.
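
Example: a row labeled 'hi' whose transcript is dominantly Tamil script becomes
'ta' (unambiguous script); a row labeled 'ta' dominated by Devanagari becomes
'hi' (ambiguous-script fallback).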
"""

import sys
import time
from collections import Counter
from concurrent.futures import ProcessPoolExecutor

import numpy as np
import pyarrow as pa
import pyarrow.compute as pc
import pyarrow.parquet as pq

# Unicode block ranges per script (inclusive code-point bounds)
SCRIPTS = {
    'devanagari': (0x0900, 0x097F),
    'bengali':    (0x0980, 0x09FF),
    'gurmukhi':   (0x0A00, 0x0A7F),
    'gujarati':   (0x0A80, 0x0AFF),
    'odia':       (0x0B00, 0x0B7F),
    'tamil':      (0x0B80, 0x0BFF),
    'telugu':     (0x0C00, 0x0C7F),
    'kannada':    (0x0C80, 0x0CFF),
    'malayalam':  (0x0D00, 0x0D7F),
}

# Valid scripts per language
LANG_VALID_SCRIPTS = {
    'hi': {'devanagari'}, 'mr': {'devanagari'},
    'bn': {'bengali'}, 'as': {'bengali'},
    'pa': {'gurmukhi'}, 'gu': {'gujarati'}, 'or': {'odia'},
    'ta': {'tamil'}, 'te': {'telugu'}, 'kn': {'kannada'}, 'ml': {'malayalam'},
}

# Unambiguous script -> language mapping
SCRIPT_PRIMARY = {
    'gurmukhi': 'pa', 'gujarati': 'gu', 'odia': 'or',
    'tamil': 'ta', 'telugu': 'te', 'kannada': 'kn', 'malayalam': 'ml',
}

# For ambiguous scripts, map to most common language if mismatch
SCRIPT_FALLBACK = {
    'devanagari': 'hi',  # hi/mr share Devanagari, default to hi
    'bengali': 'bn',      # bn/as share Bengali, default to bn
}
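
# The fallback never overrides a correct pairing: e.g. Devanagari text labeled
# 'mr' passes the valid-script check in process_chunk and is left unchanged.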


def detect_dominant_script(text: str) -> str | None:
    """Detect the dominant Indic script in a text string using NumPy."""
    if not text or len(text) < 6:
        return None
    try:
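        # UTF-32-LE yields exactly one little-endian uint32 per code point, so
        # the string becomes a flat array suitable for vectorized range tests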
        arr = np.frombuffer(text.encode('utf-32-le'), dtype=np.uint32)
    except (UnicodeEncodeError, ValueError):
        return None

    best_script = None
    best_count = 0

    for script, (lo, hi) in SCRIPTS.items():
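        # vectorized count of code points inside this script's Unicode block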
        count = int(np.sum((arr >= lo) & (arr <= hi)))
        if count > best_count:
            best_count = count
            best_script = script

    # Require at least 2 script chars to be confident
    return best_script if best_count >= 2 else None


def process_chunk(args):
    """Process one chunk of parallel language/transcript lists.

    args is (offset, langs, transcripts); offset is added to each local index.
    Returns a list of (offset + i, old_lang, new_lang, dominant_script) tuples.
    """
    offset, langs, transcripts = args
    fixes = []
    for i in range(len(langs)):
        lang = langs[i]
        # the membership test alone also skips 'en', which has no map entry
        if lang not in LANG_VALID_SCRIPTS:
            continue

        text = transcripts[i]
        if text is None or len(text) < 6:
            continue

        dominant = detect_dominant_script(text)
        if dominant is None:
            continue

        valid_scripts = LANG_VALID_SCRIPTS[lang]
        if dominant in valid_scripts:
            continue  # No mismatch

        # Mismatch found - determine correct language
        if dominant in SCRIPT_PRIMARY:
            new_lang = SCRIPT_PRIMARY[dominant]
        elif dominant in SCRIPT_FALLBACK:
            # Ambiguous script: the valid-script check above already ruled out
            # labels that legitimately use it, so fall back to the majority lang
            new_lang = SCRIPT_FALLBACK[dominant]
        else:
            continue

        if new_lang != lang:
            fixes.append((offset + i, lang, new_lang, dominant))

    return fixes


def main():
    input_path = 'artifacts/phase3/production_train.parquet'
    output_path = 'artifacts/phase3/production_train_final.parquet'

    print(f"Loading {input_path}...")
    t0 = time.time()
    table = pq.read_table(input_path)
    print(f"Loaded {len(table):,} rows in {time.time()-t0:.1f}s")

    langs = table.column('language').to_pylist()
    transcripts = table.column('transcript').to_pylist()

    # Restrict to rows whose label maps to a known Indic script; 'en' and any
    # other unmapped labels are skipped automatically
    non_en_indices = [i for i, l in enumerate(langs) if l in LANG_VALID_SCRIPTS]
    NUM_WORKERS = 16
    print(f"Processing {len(non_en_indices):,} candidate rows with {NUM_WORKERS} workers...")

    # Split the work into chunks for parallel processing. Workers receive
    # chunk-local lists and return chunk-local indices; results are mapped
    # back to global row positions via the per-chunk slice of non_en_indices.
    CHUNK_SIZE = 100_000
    all_fixes = []
    fix_log = Counter()

    t1 = time.time()
    with ProcessPoolExecutor(max_workers=NUM_WORKERS) as executor:
        chunk_data = []
        chunk_global_indices = []
        for start in range(0, len(non_en_indices), CHUNK_SIZE):
            end = min(start + CHUNK_SIZE, len(non_en_indices))
            idx_slice = non_en_indices[start:end]
            chunk_langs = [langs[i] for i in idx_slice]
            chunk_texts = [transcripts[i] for i in idx_slice]
            chunk_data.append((0, chunk_langs, chunk_texts))  # offset 0: local indices
            chunk_global_indices.append(idx_slice)

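        # executor.map yields results in submission order, so chunk_i stays
        # aligned with chunk_global_indices when mapping local -> global rows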
        for chunk_i, fixes in enumerate(executor.map(process_chunk, chunk_data, chunksize=1)):
            global_indices = chunk_global_indices[chunk_i]
            for local_idx, old_lang, new_lang, script in fixes:
                global_idx = global_indices[local_idx]
                all_fixes.append((global_idx, new_lang))
                fix_log[(old_lang, new_lang, script)] += 1

            if (chunk_i + 1) % 10 == 0:
                done = min((chunk_i + 1) * CHUNK_SIZE, len(non_en_indices))
                print(f"  Processed {done:,} / {len(non_en_indices):,} rows, {len(all_fixes):,} fixes so far")

    elapsed_detect = time.time() - t1
    print(f"\nScript detection: {elapsed_detect:.1f}s, found {len(all_fixes):,} mismatches")

    # Apply fixes
    if all_fixes:
        print(f"\nApplying {len(all_fixes):,} fixes...")
        # Reuse the already-materialized list instead of re-reading the column;
        # mutating it in place is safe since it has no other consumers from here
        lang_array = langs
        for global_idx, new_lang in all_fixes:
            lang_array[global_idx] = new_lang

        # Rebuild table with fixed language column
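        # (Arrow tables are immutable, so set_column returns a new table)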
        col_idx = table.schema.get_field_index('language')
        table = table.set_column(col_idx, 'language', pa.array(lang_array, type=pa.string()))

    # Write output
    print(f"Writing {output_path}...")
    t2 = time.time()
    pq.write_table(table, output_path, row_group_size=500_000)
    print(f"Written in {time.time()-t2:.1f}s")

    # Report
    print(f"\n{'='*60}")
    print(f"Language-ID Fix Report")
    print(f"{'='*60}")
    print(f"  Total rows:     {len(table):,}")
    print(f"  Rows checked:   {len(non_en_indices):,}")
    print(f"  Fixes applied:  {len(all_fixes):,}")
    print(f"  Total time:     {time.time()-t0:.1f}s")

    if fix_log:
        print(f"\n  Fix breakdown (from -> to [script]):")
        for (old, new, script), count in fix_log.most_common(30):
            print(f"    {old} -> {new} ({script}): {count:,}")

    # Final distribution
    vc = pc.value_counts(table.column('language'))
    print(f"\n  Final language distribution:")
    for lang, cnt in sorted(
        ((v['values'].as_py(), v['counts'].as_py()) for v in vc),
        key=lambda kv: kv[1], reverse=True,
    ):
        print(f"    {lang}: {cnt:,}")

    # Verify zero mismatches
    print(f"\n  Verifying zero mismatches on output...")
    remaining = 0
    # Quick spot-check on 500K samples
    sample_size = min(500_000, len(non_en_indices))
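    # fixed seed keeps the spot-check sample reproducible across runs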
    rng = np.random.RandomState(42)
    check_indices = rng.choice(non_en_indices, size=sample_size, replace=False)
    out_langs = table.column('language').to_pylist()
    out_texts = transcripts  # the transcript column was never modified; reuse it
    for idx in check_indices:
        lang = out_langs[idx]
        if lang not in LANG_VALID_SCRIPTS:
            continue
        dom = detect_dominant_script(out_texts[idx] or "")
        if dom and dom not in LANG_VALID_SCRIPTS.get(lang, set()):
            remaining += 1

    print(f"  Spot-check: {remaining} mismatches in {sample_size:,} samples")
    print(f"  Output: {output_path}")

    return 0 if remaining == 0 else 1


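# The __main__ guard is required so worker processes can import this module
# without re-running main() (ProcessPoolExecutor spawns on some platforms).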
if __name__ == '__main__':
    sys.exit(main())
