"""
Heavy parallel repetition scanner across 60M+ transcript rows.
Scans original transcription, native_script_text, and romanized_text for:
  1. Word-level repetition (same word/token repeated N+ times consecutively)
  2. N-gram repetition (same 2-4 word phrase repeated)
  3. Character-level stutter (same char/syllable repeated)
  4. Very short or empty text where it shouldn't be
  5. Extremely long single-word tokens (encoding artifacts)

Uses multiprocessing to saturate the available vCPUs.
Output: parquet with flagged rows + analytics summary.
"""
from __future__ import annotations
import re
import json
import time
import sys
from pathlib import Path
from multiprocessing import Pool, cpu_count
from collections import Counter

import duckdb
import pandas as pd

INPUT = Path("final_data/transcript_variants_clean.parquet")
ORIG = Path("final_data/final_cleaned_segments.parquet")
OUTPUT_FLAGS = Path("final_data/repetition_flags.parquet")
OUTPUT_SUMMARY = Path("final_data/repetition_summary.json")

CONSEC_WORD_THRESHOLD = 5      # flag a word repeated 5+ times in a row
NGRAM_REPEAT_THRESHOLD = 4     # flag an n-gram occurring 4+ times
CHAR_STUTTER_THRESHOLD = 8     # flag a 1-4 char unit repeated 8+ times in a row
MIN_TEXT_LENGTH = 2            # non-empty text shorter than this is flagged
MAX_SINGLE_TOKEN_LENGTH = 200  # longer single tokens are likely encoding artifacts

# Pre-compiled patterns
WORD_SPLIT = re.compile(r'\S+')  # tokenizer: runs of non-whitespace (used with findall)
# A 1-4 char unit followed by CHAR_STUTTER_THRESHOLD - 1 backreference repeats,
# i.e. the unit appears CHAR_STUTTER_THRESHOLD+ times in a row
CHAR_STUTTER_RE = re.compile(r'(.{1,4}?)\1{' + str(CHAR_STUTTER_THRESHOLD - 1) + r',}')
# Regex alternative to detect_consecutive_word_repeats; currently unused, since the
# explicit loop below also strips trailing punctuation before comparing
LONG_REPEAT_RE = re.compile(r'(\b\S+\b)(?:\s+\1){' + str(CONSEC_WORD_THRESHOLD - 1) + r',}', re.IGNORECASE)


def detect_consecutive_word_repeats(text: str) -> list[tuple[str, int]]:
    """Find words repeated N+ times consecutively."""
    if not text:
        return []
    words = WORD_SPLIT.findall(text)
    if len(words) < CONSEC_WORD_THRESHOLD:
        return []
    results = []
    i = 0
    while i < len(words):
        w = words[i].lower().rstrip('.,!?;:')
        count = 1
        while i + count < len(words) and words[i + count].lower().rstrip('.,!?;:') == w:
            count += 1
        if count >= CONSEC_WORD_THRESHOLD and len(w) > 0:
            results.append((w, count))
        i += count
    return results
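# Illustrative behaviour (made-up inputs; the threshold of 5 comes from
# CONSEC_WORD_THRESHOLD above):
#   detect_consecutive_word_repeats('uh uh uh uh uh okay')  -> [('uh', 5)]
#   detect_consecutive_word_repeats('fine fine fine')       -> []  (below threshold)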


def detect_ngram_repeats(text: str, n: int = 3) -> list[tuple[str, int]]:
    """Find n-gram phrases repeated threshold+ times."""
    if not text:
        return []
    words = WORD_SPLIT.findall(text)
    if len(words) < n * NGRAM_REPEAT_THRESHOLD:
        # too short for threshold non-overlapping repeats (overlapping cases in
        # shorter texts are typically caught by the smaller-n / word-level checks)
        return []
    ngrams = []
    for i in range(len(words) - n + 1):
        ng = ' '.join(w.lower().rstrip('.,!?;:') for w in words[i:i+n])
        ngrams.append(ng)
    counts = Counter(ngrams)
    return [(ng, c) for ng, c in counts.most_common(5) if c >= NGRAM_REPEAT_THRESHOLD]
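# Illustrative behaviour (made-up input; the threshold of 4 comes from
# NGRAM_REPEAT_THRESHOLD above):
#   detect_ngram_repeats('thank you thank you thank you thank you', n=2)
#   -> [('thank you', 4)]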


def detect_char_stutter(text: str) -> list[tuple[str, int]]:
    """Find character-level stuttering patterns."""
    if not text or len(text) < CHAR_STUTTER_THRESHOLD:
        return []
    results = []
    for m in CHAR_STUTTER_RE.finditer(text):
        pattern = m.group(1)
        full = m.group(0)
        repeat_count = len(full) // max(len(pattern), 1)
        if repeat_count >= CHAR_STUTTER_THRESHOLD and len(pattern.strip()) > 0:
            results.append((pattern, repeat_count))
    return results
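# Illustrative behaviour (made-up input): a 2-char unit repeated 8 times meets
# CHAR_STUTTER_THRESHOLD:
#   detect_char_stutter('ha' * 8)  -> [('ha', 8)]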


def scan_row(row: dict) -> dict | None:
    """Scan a single row for repetition patterns. Returns flag dict or None."""
    flags = []
    severity = 0

    for field in ['transcription', 'native_script_text', 'romanized_text']:
        text = row.get(field, '') or ''
        stripped = text.strip()
        if not stripped:
            # empty romanized text is only suspicious for non-English, non-fully-roman rows
            if field == 'romanized_text' and row.get('processing_route', '') != 'local_skip_fully_roman':
                if row.get('language_code', '') != 'en':
                    flags.append(f'{field}:empty')
                    severity = max(severity, 3)
            continue

        # Very short text (check 4 in the module docstring)
        if len(stripped) < MIN_TEXT_LENGTH:
            flags.append(f'{field}:too_short:{len(stripped)}chars')
            severity = max(severity, 2)
            continue

        # Consecutive word repeats
        word_reps = detect_consecutive_word_repeats(text)
        for word, count in word_reps:
            flags.append(f'{field}:word_repeat:{word}x{count}')
            severity = max(severity, 5 if count >= 20 else 3 if count >= 10 else 1)

        # N-gram repeats (bigrams and trigrams)
        for n in [2, 3]:
            ng_reps = detect_ngram_repeats(text, n)
            for ng, count in ng_reps:
                flags.append(f'{field}:{n}gram_repeat:{ng}x{count}')
                severity = max(severity, 4 if count >= 10 else 2)

        # Character stutter
        stutter = detect_char_stutter(text)
        for pattern, count in stutter:
            flags.append(f'{field}:char_stutter:{repr(pattern)}x{count}')
            severity = max(severity, 5 if count >= 30 else 3 if count >= 15 else 1)

        # Extremely long single token (encoding artifact)
        words = WORD_SPLIT.findall(text)
        for w in words:
            if len(w) > MAX_SINGLE_TOKEN_LENGTH:
                flags.append(f'{field}:long_token:{len(w)}chars')
                severity = max(severity, 4)
                break

    if not flags:
        return None

    return {
        'row_id': row.get('row_id', ''),
        'video_id': row.get('video_id', ''),
        'language_code': row.get('language_code', ''),
        'processing_route': row.get('processing_route', ''),
        'severity': severity,
        'flag_count': len(flags),
        'flags': json.dumps(flags, ensure_ascii=False),
        'transcription_preview': (row.get('transcription', '') or '')[:200],
        'native_preview': (row.get('native_script_text', '') or '')[:200],
        'roman_preview': (row.get('romanized_text', '') or '')[:200],
    }
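# Illustrative behaviour (hypothetical row; keys mirror the columns selected in
# main(), missing keys default to ''):
#   scan_row({'row_id': 'vid1/seg_000', 'transcription': 'no no no no no no',
#             'native_script_text': 'ok', 'romanized_text': 'ok'})
#   -> {'row_id': 'vid1/seg_000', ..., 'severity': 1, 'flag_count': 1,
#       'flags': '["transcription:word_repeat:nox6"]', ...}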


def process_chunk(args):
    """Process a chunk of rows. Called by multiprocessing pool."""
    chunk_idx, rows = args
    flagged = []
    for row in rows:
        result = scan_row(row)
        if result:
            flagged.append(result)
    return chunk_idx, flagged
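# Each worker receives (chunk_idx, rows) and returns (chunk_idx, flagged), so a
# sub-chunk with no repetition issues simply comes back as (chunk_idx, []).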


def main():
    start = time.time()
    num_workers = max(cpu_count() - 1, 1)

    con = duckdb.connect(':memory:')
    con.execute("SET memory_limit = '8GB'")

    # Count the joined result so num_chunks and the progress figures track
    # exactly what gets scanned (the inner join can drop unmatched rows)
    total_rows = con.execute(f'''
        SELECT count(*)
        FROM read_parquet('{INPUT}') v
        JOIN read_parquet('{ORIG}') o
            ON o.video_id || '/' || o.segment_file = v.row_id
    ''').fetchone()[0]
    print(f'Total rows: {total_rows:,}', flush=True)

    # Stream in chunks of 500K via DuckDB LIMIT/OFFSET over the joined tables
    # (ORDER BY keeps the pagination stable), then process each chunk with
    # multiprocessing
    CHUNK_ROWS = 500_000
    num_chunks = (total_rows + CHUNK_ROWS - 1) // CHUNK_ROWS
    print(f'Scanning with {num_workers} workers, {num_chunks} chunks of {CHUNK_ROWS:,}...', flush=True)

    all_flagged = []
    rows_scanned = 0

    for chunk_i in range(num_chunks):
        offset = chunk_i * CHUNK_ROWS
        chunk_df = con.execute(f'''
            SELECT
                v.row_id, v.video_id, v.language_code, v.processing_route,
                v.native_script_text, v.romanized_text,
                o.transcription
            FROM read_parquet('{INPUT}') v
            JOIN read_parquet('{ORIG}') o
                ON o.video_id || '/' || o.segment_file = v.row_id
            ORDER BY v.row_id
            LIMIT {CHUNK_ROWS} OFFSET {offset}
        ''').df()
        rows = chunk_df.to_dict(orient='records')
        del chunk_df

        # Split into sub-chunks for multiprocessing
        sub_size = max(len(rows) // num_workers, 1000)
        sub_chunks = [(j, rows[j:j+sub_size]) for j in range(0, len(rows), sub_size)]
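        # e.g. (illustrative numbers) 500,000 rows with 15 workers gives
        # sub_size = 33,333 and 16 sub-chunks, the last holding the 5-row remainder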

        with Pool(num_workers) as pool:
            for _, flagged in pool.imap_unordered(process_chunk, sub_chunks):
                all_flagged.extend(flagged)

        rows_scanned += len(rows)
        del rows
        elapsed = time.time() - start
        pct = rows_scanned / total_rows * 100
        print(f'  {pct:.0f}% ({rows_scanned:,}/{total_rows:,}) - {len(all_flagged)} flagged - {elapsed:.0f}s', flush=True)

    elapsed = time.time() - start
    print(f'\nScan complete in {elapsed:.0f}s', flush=True)
    print(f'Total flagged: {len(all_flagged):,} / {total_rows:,} ({len(all_flagged)/total_rows*100:.3f}%)', flush=True)

    if all_flagged:
        flag_df = pd.DataFrame(all_flagged)
        flag_df.to_parquet(OUTPUT_FLAGS, index=False)
        print(f'Wrote {OUTPUT_FLAGS}', flush=True)

        # Summary analytics
        severity_counts = flag_df['severity'].value_counts().sort_index().to_dict()
        lang_counts = flag_df['language_code'].value_counts().to_dict()
        route_counts = flag_df['processing_route'].value_counts().to_dict()

        # Parse flag types
        flag_type_counts = Counter()
        for flags_json in flag_df['flags']:
            for flag in json.loads(flags_json):
                flag_type = ':'.join(flag.split(':')[:2])
                flag_type_counts[flag_type] += 1

        summary = {
            'total_scanned': total_rows,
            'total_flagged': len(all_flagged),
            'flagged_pct': round(len(all_flagged) / total_rows * 100, 4),
            'by_severity': severity_counts,
            'by_language': lang_counts,
            'by_route': route_counts,
            'by_flag_type': dict(flag_type_counts.most_common(30)),
            'elapsed_seconds': round(elapsed, 1),
        }
        with open(OUTPUT_SUMMARY, 'w') as f:
            json.dump(summary, f, indent=2, ensure_ascii=False)
        print(f'Wrote {OUTPUT_SUMMARY}', flush=True)

        # Print summary
        print('\n=== SEVERITY BREAKDOWN ===')
        for sev, cnt in sorted(severity_counts.items()):
            label = {1: 'low', 2: 'medium', 3: 'high', 4: 'very_high', 5: 'critical'}.get(sev, '?')
            print(f'  severity {sev} ({label}): {cnt:,}')

        print('\n=== TOP FLAG TYPES ===')
        for ft, cnt in flag_type_counts.most_common(15):
            print(f'  {cnt:>8,}: {ft}')

        print('\n=== BY LANGUAGE (top 5) ===')
        for lang, cnt in sorted(lang_counts.items(), key=lambda x: -x[1])[:5]:
            print(f'  {lang}: {cnt:,}')
    else:
        print('No flagged rows found.')


if __name__ == '__main__':
    main()
