"""
Low-load repetition scanner for the transcript corpus.

Why this exists:
- The first scanner materialized all 60M rows in Python and got OOM-killed.
- This version streams the DuckDB join in small Arrow record batches.
- It keeps CPU/RAM bounded with configurable worker count and batch size.
- It is resumable: each completed batch writes a metadata marker and optional
  parquet of flagged rows.

Scans:
- original `transcription`
- `native_script_text`
- `romanized_text`

Strict flagging rules:
- consecutive repeated words only
- consecutive repeated 2-4 word phrases only
- repeated character chunks / stutter loops
- empty romanized/native outputs where unexpected
- extremely long single tokens
"""
from __future__ import annotations

import argparse
import json
import re
import time
from collections import Counter
from multiprocessing import Pool
from pathlib import Path

import duckdb
import pandas as pd


# Final aggregated outputs (written once at the end of a full scan).
OUTPUT_FLAGS = Path("final_data/repetition_flags.parquet")
OUTPUT_SUMMARY = Path("final_data/repetition_summary.json")
# Per-batch resume markers (.json) and flagged-row parquets live here.
OUTPUT_BATCH_DIR = Path("final_data/repetition_scan_batches")

# Minimum run length of one identical word before it is flagged.
CONSEC_WORD_THRESHOLD = 5
# Minimum back-to-back repeats of a 2-4 word phrase before it is flagged.
CONSEC_NGRAM_REPEAT_THRESHOLD = 3
# Minimum repeats of a short character chunk before it counts as a stutter loop.
CHAR_STUTTER_THRESHOLD = 8
# Any single whitespace-delimited token longer than this is flagged.
MAX_SINGLE_TOKEN_LENGTH = 200

# Tokenizer: runs of non-whitespace characters.
WORD_SPLIT = re.compile(r"\S+")
# Non-greedy 1-4 char group followed by at least CHAR_STUTTER_THRESHOLD - 1
# more copies of itself, i.e. CHAR_STUTTER_THRESHOLD total repetitions.
CHAR_STUTTER_RE = re.compile(r"(.{1,4}?)\1{" + str(CHAR_STUTTER_THRESHOLD - 1) + r",}")


def parse_args() -> argparse.Namespace:
    """Parse the CLI flags controlling parallelism, batching, and output dir."""
    ap = argparse.ArgumentParser(description="Streaming repetition scanner")
    ap.add_argument("--workers", type=int, default=4)
    ap.add_argument("--batch-rows", type=int, default=100000)
    ap.add_argument("--duckdb-threads", type=int, default=2)
    ap.add_argument("--max-batches", type=int, default=0, help="0 = scan all batches")
    ap.add_argument("--output-batch-dir", type=Path, default=OUTPUT_BATCH_DIR)
    return ap.parse_args()


def normalize_token(token: str) -> str:
    """Lowercase *token* and strip trailing punctuation/bracket characters."""
    lowered = token.lower()
    return lowered.rstrip(".,!?;:'\"()[]{}")


def detect_consecutive_word_repeats(text: str) -> list[tuple[str, int]]:
    """Find runs of the same (normalized) word repeated back-to-back.

    Returns one (word, run_length) pair per run of at least
    CONSEC_WORD_THRESHOLD identical consecutive words. Tokens that
    normalize to the empty string (pure punctuation) are never flagged.
    """
    if not text:
        return []
    tokens = WORD_SPLIT.findall(text)
    total = len(tokens)
    if total < CONSEC_WORD_THRESHOLD:
        return []

    runs: list[tuple[str, int]] = []
    pos = 0
    while pos < total:
        current = normalize_token(tokens[pos])
        run_end = pos + 1
        # Extend the run while the normalized token stays the same.
        while run_end < total and normalize_token(tokens[run_end]) == current:
            run_end += 1
        run_len = run_end - pos
        if current and run_len >= CONSEC_WORD_THRESHOLD:
            runs.append((current, run_len))
        pos = run_end
    return runs


def detect_consecutive_ngram_repeats(text: str, n: int) -> list[tuple[str, int]]:
    """Flag n-word phrases that repeat immediately back-to-back.

    Only adjacent repetitions count (e.g. "go go go go" as two-word loops),
    so a phrase recurring in separate clauses — normal grammar, not a
    hallucination — is never flagged. Returns (phrase, repeat_count) pairs
    with repeat_count >= CONSEC_NGRAM_REPEAT_THRESHOLD.
    """
    if not text:
        return []
    raw_tokens = WORD_SPLIT.findall(text)
    if len(raw_tokens) < n * CONSEC_NGRAM_REPEAT_THRESHOLD:
        return []

    tokens = [normalize_token(w) for w in raw_tokens]
    found: list[tuple[str, int]] = []
    last_start = len(tokens) - n

    start = 0
    while start <= last_start:
        window = tokens[start : start + n]
        # Skip windows containing punctuation-only tokens.
        if "" in window:
            start += 1
            continue
        repeats = 1
        nxt = start + n
        while nxt + n <= len(tokens) and tokens[nxt : nxt + n] == window:
            repeats += 1
            nxt += n
        if repeats >= CONSEC_NGRAM_REPEAT_THRESHOLD:
            found.append((" ".join(window), repeats))
            start = nxt
        else:
            start += 1
    return found


def detect_char_stutter(text: str) -> list[tuple[str, int]]:
    """Find short character chunks (1-4 chars) looped many times in a row.

    Returns (chunk, repeat_count) pairs. Whitespace-only chunks are
    ignored. The regex itself already enforces CHAR_STUTTER_THRESHOLD
    repetitions; the explicit count check is kept as a second guard.
    """
    if not text or len(text) < CHAR_STUTTER_THRESHOLD:
        return []
    hits: list[tuple[str, int]] = []
    for m in CHAR_STUTTER_RE.finditer(text):
        chunk = m.group(1)
        span = m.group(0)
        repeats = len(span) // max(len(chunk), 1)
        if chunk.strip() and repeats >= CHAR_STUTTER_THRESHOLD:
            hits.append((chunk, repeats))
    return hits


def scan_row(row: dict) -> dict | None:
    """Run every repetition detector over one joined transcript row.

    Checks `transcription`, `native_script_text`, and `romanized_text`,
    accumulating flag strings plus a 0-5 severity score. Returns a flat
    record (ids, severity, JSON-encoded flags, 200-char previews) or
    None when the row is clean.
    """
    flags: list[str] = []
    severity = 0

    def _bump(level: int) -> None:
        # Severity is the max across all detectors and fields.
        nonlocal severity
        if level > severity:
            severity = level

    def _preview(key: str) -> str:
        return (row.get(key, "") or "")[:200]

    for field in ("transcription", "native_script_text", "romanized_text"):
        text = row.get(field, "") or ""

        if not text.strip():
            # An empty romanized output is only suspicious for non-English rows.
            if field == "romanized_text" and row.get("language_code") != "en":
                flags.append(f"{field}:empty")
                _bump(3)
            continue

        for word, count in detect_consecutive_word_repeats(text):
            flags.append(f"{field}:word_repeat:{word}x{count}")
            _bump(5 if count >= 20 else 3 if count >= 10 else 1)

        for n in (2, 3, 4):
            for ngram, count in detect_consecutive_ngram_repeats(text, n):
                flags.append(f"{field}:consecutive_{n}gram_repeat:{ngram}x{count}")
                _bump(5 if count >= 8 else 4 if count >= 5 else 2)

        for pattern, count in detect_char_stutter(text):
            flags.append(f"{field}:char_stutter:{repr(pattern)}x{count}")
            _bump(5 if count >= 30 else 3 if count >= 15 else 1)

        # One oversized token per field is enough to flag it.
        for token in WORD_SPLIT.findall(text):
            if len(token) > MAX_SINGLE_TOKEN_LENGTH:
                flags.append(f"{field}:long_token:{len(token)}chars")
                _bump(4)
                break

    if not flags:
        return None

    return {
        "row_id": row.get("row_id", ""),
        "video_id": row.get("video_id", ""),
        "language_code": row.get("language_code", ""),
        "processing_route": row.get("processing_route", ""),
        "severity": severity,
        "flag_count": len(flags),
        "flags": json.dumps(flags, ensure_ascii=False),
        "transcription_preview": _preview("transcription"),
        "native_preview": _preview("native_script_text"),
        "roman_preview": _preview("romanized_text"),
    }


def process_rows(rows: list[dict]) -> list[dict]:
    """Scan each row and keep only the non-None flag records.

    Runs inside worker processes; must stay picklable (top-level def).
    """
    return [hit for hit in map(scan_row, rows) if hit is not None]


def split_rows(rows: list[dict], num_parts: int) -> list[list[dict]]:
    """Chunk *rows* into roughly equal slices for the worker pool.

    Uses floor division for the chunk size, so slightly more than
    *num_parts* chunks may come back when the length does not divide
    evenly; Pool.map handles the extra chunks fine.
    """
    if not rows:
        return []
    chunk = max(len(rows) // num_parts, 1)
    starts = range(0, len(rows), chunk)
    return [rows[s : s + chunk] for s in starts]


def batch_meta_path(batch_dir: Path, batch_idx: int) -> Path:
    """Path of the metadata marker for one batch (its existence enables resume)."""
    return batch_dir.joinpath(f"batch_{batch_idx:06d}.json")


def batch_flags_path(batch_dir: Path, batch_idx: int) -> Path:
    """Path of the parquet holding one batch's flagged rows (only written if any)."""
    return batch_dir.joinpath(f"batch_{batch_idx:06d}.parquet")


def aggregate_outputs(batch_dir: Path, total_scanned: int, elapsed: float, args: argparse.Namespace) -> None:
    """Merge per-batch markers and parquets into the final flags file + summary.

    Reads every batch_*.json marker under *batch_dir* (and each matching
    batch_*.parquet when one exists), writes the combined flagged rows to
    OUTPUT_FLAGS, a JSON summary to OUTPUT_SUMMARY, and prints a report.

    Args:
        batch_dir: directory holding batch_*.json / batch_*.parquet files.
        total_scanned: row count tallied by the caller during this run;
            used only as a consistency check against the on-disk markers
            (the summary always uses the metadata totals).
        elapsed: wall-clock seconds, recorded in the summary.
        args: parsed CLI namespace; workers / duckdb_threads / batch_rows
            are recorded in the summary for reproducibility.
    """
    meta_files = sorted(batch_dir.glob("batch_*.json"))
    if not meta_files:
        print("No batch metadata found to aggregate.", flush=True)
        return

    all_meta = [json.loads(path.read_text()) for path in meta_files]
    total_flagged = sum(item["flagged_rows"] for item in all_meta)
    processed_rows = sum(item["rows_scanned"] for item in all_meta)

    # Fix: total_scanned was previously accepted but never used. Surface any
    # drift between this run's tally and the on-disk batch markers (expected
    # when aggregating markers left over from earlier/partial runs).
    if total_scanned != processed_rows:
        print(
            f"NOTE: this run counted {total_scanned:,} rows, but on-disk batch "
            f"metadata sums to {processed_rows:,}; summary uses the metadata total.",
            flush=True,
        )

    parquet_files = sorted(batch_dir.glob("batch_*.parquet"))
    if parquet_files:
        flags_df = pd.concat((pd.read_parquet(path) for path in parquet_files), ignore_index=True)
        flags_df.to_parquet(OUTPUT_FLAGS, index=False)

        severity_counts = flags_df["severity"].value_counts().sort_index().to_dict()
        lang_counts = flags_df["language_code"].value_counts().to_dict()
        route_counts = flags_df["processing_route"].value_counts().to_dict()
        flag_type_counts: Counter[str] = Counter()

        # Flags are JSON lists of "field:type:detail" strings; bucket on the
        # first two components so the per-detail variety doesn't explode counts.
        for flags_json in flags_df["flags"]:
            for flag in json.loads(flags_json):
                flag_type_counts[":".join(flag.split(":")[:2])] += 1
    else:
        severity_counts = {}
        lang_counts = {}
        route_counts = {}
        flag_type_counts = Counter()

    summary = {
        "total_scanned": processed_rows,
        "total_flagged": total_flagged,
        "flagged_pct": round(total_flagged / max(processed_rows, 1) * 100, 6),
        "workers": args.workers,
        "duckdb_threads": args.duckdb_threads,
        "batch_rows": args.batch_rows,
        "elapsed_seconds": round(elapsed, 1),
        "batch_count": len(meta_files),
        "by_severity": severity_counts,
        "by_language": lang_counts,
        "by_route": route_counts,
        "by_flag_type": dict(flag_type_counts.most_common(30)),
    }
    OUTPUT_SUMMARY.write_text(json.dumps(summary, indent=2, ensure_ascii=False))

    print(f"Wrote {OUTPUT_FLAGS}", flush=True)
    print(f"Wrote {OUTPUT_SUMMARY}", flush=True)
    print(f"Aggregated {len(meta_files)} batches, {processed_rows:,} rows, {total_flagged:,} flags", flush=True)
    if severity_counts:
        print("\n=== Severity ===", flush=True)
        for severity, count in sorted(severity_counts.items()):
            print(f"  {severity}: {count:,}", flush=True)
        print("\n=== Top Flag Types ===", flush=True)
        for flag_type, count in flag_type_counts.most_common(15):
            print(f"  {count:>8,}: {flag_type}", flush=True)


def main() -> None:
    """Stream the DuckDB join in Arrow record batches and fan rows out to workers.

    Resumable: a batch whose metadata marker already exists is skipped,
    though its recorded row count still feeds the progress totals. After
    streaming, all batch outputs are aggregated into the final parquet
    plus a JSON summary.
    """
    args = parse_args()
    start = time.time()
    args.output_batch_dir.mkdir(parents=True, exist_ok=True)

    con = duckdb.connect(":memory:")
    # Fix: these were f-strings; the memory_limit statement has no placeholder.
    con.execute("SET memory_limit = '8GB'")
    con.execute(f"SET threads TO {args.duckdb_threads}")

    total_rows = con.execute(
        "SELECT count(*) FROM read_parquet('final_data/transcript_variants_clean.parquet')"
    ).fetchone()[0]
    # Fix: guard the progress-percentage division against an empty input file.
    pct_denom = max(total_rows, 1)
    print(f"Total rows: {total_rows:,}", flush=True)
    print(
        f"Streaming with workers={args.workers}, duckdb_threads={args.duckdb_threads}, "
        f"batch_rows={args.batch_rows:,}",
        flush=True,
    )

    # Join the variants back to the original transcriptions; row_id is
    # "<video_id>/<segment_file>" on the variants side.
    query = """
        SELECT
            v.row_id,
            v.video_id,
            v.language_code,
            v.processing_route,
            v.native_script_text,
            v.romanized_text,
            o.transcription
        FROM read_parquet('final_data/transcript_variants_clean.parquet') v
        JOIN read_parquet('final_data/final_cleaned_segments.parquet') o
            ON o.video_id || '/' || o.segment_file = v.row_id
    """

    # Arrow record-batch streaming keeps memory bounded (no full materialization).
    reader = con.execute(query).fetch_record_batch(rows_per_batch=args.batch_rows)

    rows_scanned = 0
    batch_idx = 0

    with Pool(args.workers) as pool:
        for batch in reader:
            batch_idx += 1
            meta_path = batch_meta_path(args.output_batch_dir, batch_idx)
            flags_path = batch_flags_path(args.output_batch_dir, batch_idx)

            # Resume: an existing marker means this batch completed earlier.
            if meta_path.exists():
                meta = json.loads(meta_path.read_text())
                rows_scanned += meta["rows_scanned"]
                print(
                    f"  batch {batch_idx}: resume-skip {rows_scanned:,}/{total_rows:,} "
                    f"({rows_scanned / pct_denom * 100:.1f}%)",
                    flush=True,
                )
                if args.max_batches and batch_idx >= args.max_batches:
                    break
                continue

            batch_rows = batch.to_pylist()
            sub_batches = split_rows(batch_rows, args.workers)
            flagged_parts = pool.map(process_rows, sub_batches)
            batch_flagged: list[dict] = []
            for flagged in flagged_parts:
                batch_flagged.extend(flagged)

            rows_scanned += len(batch_rows)
            if batch_flagged:
                pd.DataFrame(batch_flagged).to_parquet(flags_path, index=False)
            # The marker is written last so a crash mid-batch leaves no marker
            # and the batch is re-scanned on the next run.
            meta_path.write_text(
                json.dumps(
                    {
                        "batch_idx": batch_idx,
                        "rows_scanned": len(batch_rows),
                        "flagged_rows": len(batch_flagged),
                        "flags_file": flags_path.name if batch_flagged else None,
                    },
                    ensure_ascii=False,
                )
            )
            elapsed = time.time() - start
            print(
                f"  batch {batch_idx}: scanned {rows_scanned:,}/{total_rows:,} "
                f"({rows_scanned / pct_denom * 100:.1f}%) | "
                f"flagged {len(batch_flagged):,} in batch | {elapsed:.0f}s",
                flush=True,
            )

            if args.max_batches and batch_idx >= args.max_batches:
                break

    elapsed = time.time() - start
    print(f"\nScan complete in {elapsed:.0f}s", flush=True)
    aggregate_outputs(args.output_batch_dir, rows_scanned, elapsed, args)


# Script entry point: run the streaming scan end-to-end.
if __name__ == "__main__":
    main()
