#!/usr/bin/env python3
"""Phase 3: Clean transcripts — remove audio tags, fix language labels, build final manifest.

Uses pyarrow for parquet I/O and pandas for the cleaning steps; all work runs in a single process.

Usage:
  python3 tools/phase3_clean_transcripts.py
"""

import json
import os
import re
import sys
import time
from pathlib import Path

import pyarrow as pa
import pyarrow.compute as pc
import pyarrow.parquet as pq

# Pin OpenMP-based libraries to a single thread.
# NOTE(review): this assignment runs after the pyarrow imports above; any
# library that reads OMP_NUM_THREADS only when it is first loaded may already
# have done so — confirm the setting is still effective here.
os.environ.update({"OMP_NUM_THREADS": "1"})

# Directory holding the per-split Phase 3 manifests (inputs and outputs).
PHASE3_DIR = Path("/workspace/maya-asr/artifacts/phase3")

# Audio markup tags to strip, e.g. "[laugh]", "[LAUGHTER]", "[Background_Noise]",
# "<unk>". Matching is case-insensitive (re.IGNORECASE), so every tag is listed
# once in lowercase; uppercase duplicates would be dead alternatives.
TAG_PATTERN = re.compile(
    r"\[(?:laugh(?:ter)?|happy|sad|angry|neutral|whisper|cry|cough|sneeze|gasp|"
    r"breath|sigh|music|noise|inaudible|applause|silence|static|beep|click|"
    r"ring|horn|bark|clap|knock|bell|cheer|crowd|background[_ ]?noise|"
    r"speaker[_ ]?change|overlap|pause|hesitation|filler|um|uh|hmm)\]"
    r"|<(?:inaudible|noise|music|silence|laugh|unk|unclear|foreign)>",
    re.IGNORECASE,
)

# Script detection for language correction
SCRIPT_RANGES = {
    "devanagari": (0x0900, 0x097F),  # hi, mr
    "bengali": (0x0980, 0x09FF),     # bn, as
    "gurmukhi": (0x0A00, 0x0A7F),    # pa
    "gujarati": (0x0A80, 0x0AFF),    # gu
    "odia": (0x0B00, 0x0B7F),        # or
    "tamil": (0x0B80, 0x0BFF),       # ta
    "telugu": (0x0C00, 0x0C7F),      # te
    "kannada": (0x0C80, 0x0CFF),     # kn
    "malayalam": (0x0D00, 0x0D7F),   # ml
    "latin": (0x0041, 0x007A),       # en
}

SCRIPT_TO_LANG = {
    "devanagari": "hi",
    "bengali": "bn",
    "gurmukhi": "pa",
    "gujarati": "gu",
    "odia": "or",
    "tamil": "ta",
    "telugu": "te",
    "kannada": "kn",
    "malayalam": "ml",
    "latin": "en",
}


def detect_dominant_script(text: str, ranges: "dict[str, tuple[int, int]] | None" = None) -> str:
    """Return the name of the script with the most characters in *text*.

    Each character is counted toward the first entry of *ranges* whose
    inclusive ``(lo, hi)`` code-point range contains it. ASCII characters
    that are not letters (digits, whitespace, punctuation) are ignored, so
    symbols such as "_" or "[" that happen to fall inside the Latin range
    (0x41-0x7A) do not vote for "latin".

    Args:
        text: Transcript to inspect.
        ranges: Mapping of script name -> (lo, hi); defaults to SCRIPT_RANGES.

    Returns:
        The script name with the highest count (ties go to the script whose
        first character appears earliest in *text*), or "" when *text* is
        empty or no character falls in any range.
    """
    if not text:
        return ""
    if ranges is None:
        ranges = SCRIPT_RANGES
    counts = {}
    for ch in text:
        if ch.isascii() and not ch.isalpha():
            # Digits, spaces and ASCII punctuation carry no script signal.
            continue
        cp = ord(ch)
        for script, (lo, hi) in ranges.items():
            if lo <= cp <= hi:
                counts[script] = counts.get(script, 0) + 1
                break
    if not counts:
        return ""
    # dict preserves first-seen order, so max() breaks ties by first occurrence.
    return max(counts, key=counts.get)


# Runs of any whitespace, collapsed to a single space after tag removal.
_WHITESPACE_RUN = re.compile(r"\s+")


def clean_transcript(text: str, tag_pattern=None) -> str:
    """Remove audio markup tags from *text* and normalise whitespace.

    Each tag is replaced by a single space rather than deleted outright, so a
    tag wedged between two words ("hello[laugh]world") does not glue them
    into one token. The whitespace collapse then removes any extra spaces.

    Args:
        text: Raw transcript. Empty strings, ``None`` and float NaN all
            yield "".
        tag_pattern: Compiled regex matching the tags to strip; defaults to
            TAG_PATTERN.

    Returns:
        The cleaned transcript, with whitespace runs collapsed to one space
        and leading/trailing whitespace removed.
    """
    # NaN is a truthy float, so `not text` alone would pass it to re.sub,
    # which raises TypeError.
    if not text or (isinstance(text, float) and text != text):
        return ""
    pattern = TAG_PATTERN if tag_pattern is None else tag_pattern
    spaced = pattern.sub(" ", text)
    return _WHITESPACE_RUN.sub(" ", spaced).strip()


# Label pairs written in the same script: a script check cannot tell them
# apart, so a mismatch within one of these pairs is never corrected.
_SHARED_SCRIPT_LANGS = ({"hi", "mr"}, {"bn", "as"})

# Cleaned transcripts shorter than this are never relabelled.
_MIN_RELABEL_CHARS = 5

# Rows sampled (random_state=42) for the "Script mismatches (sample)" statistic.
_MISMATCH_SAMPLE_SIZE = 10000


def _strip_lang_prefix(df):
    """Remove a leading "lang=" (left by Phase 2) from df["language"] in place.

    Returns the number of rows that carried the prefix. Missing labels are
    treated as unprefixed (``na=False``) so they cannot put NaN into the
    boolean mask, which would make ``.loc`` raise.
    """
    prefix = "lang="
    has_prefix = df["language"].str.startswith(prefix, na=False)
    if has_prefix.any():
        # Slice off only the leading prefix; str.replace would also delete
        # any "lang=" occurring later inside the label.
        df.loc[has_prefix, "language"] = df.loc[has_prefix, "language"].str[len(prefix):]
    return int(has_prefix.sum())


def _script_mismatches(df):
    """Flag rows whose transcript script contradicts their language label.

    Returns ``(mismatch, expected)``:
      - ``mismatch``: boolean Series. A row is True when its dominant script
        is a detected, non-Latin script; SCRIPT_TO_LANG maps that script to
        a language different from df["language"]; and the two languages are
        not one of the _SHARED_SCRIPT_LANGS pairs.
      - ``expected``: the script-implied language per row (NaN when no
        script was detected).
    """
    dominant = df["transcript"].map(detect_dominant_script)
    expected = dominant.map(SCRIPT_TO_LANG)
    actual = df["language"]

    shared = expected.isin(())  # all-False boolean Series aligned with df
    for group in _SHARED_SCRIPT_LANGS:
        shared |= expected.isin(group) & actual.isin(group)

    mismatch = (
        expected.notna()
        & (dominant != "latin")
        & (expected != actual)
        & ~shared
    )
    return mismatch, expected


def main():
    """Clean the train/dev/test Phase 3 manifests and write cleaned copies.

    For each ``{split}_manifest.parquet`` found in PHASE3_DIR:
      1. strip audio markup tags from ``transcript`` (clean_transcript);
      2. remove the "lang=" prefix from ``language``, then relabel rows whose
         transcript script clearly contradicts their label;
      3. count rows whose transcript is empty after cleaning. They are kept
         in the output.
    The result is written to ``{split}_manifest_clean.parquet``. Missing
    splits are skipped, and per-split statistics are printed.
    """
    print("Phase 3: Cleaning transcripts...")
    t0 = time.time()

    for split in ("train", "dev", "test"):
        src = PHASE3_DIR / f"{split}_manifest.parquet"
        if not src.exists():
            print(f"  SKIP: {src} not found")
            continue

        print(f"\n  Processing {split}...")
        df = pq.read_table(src).to_pandas()
        n_before = len(df)

        # 1. Strip audio markup tags.
        df["transcript"] = df["transcript"].apply(clean_transcript)

        # 2a. Undo the "lang=" label prefix introduced by Phase 2.
        lang_fixes = _strip_lang_prefix(df)

        # 2b. Script-based correction. The mismatch mask is computed once,
        # column-wise, instead of walking every row with iterrows().
        mismatch, expected = _script_mismatches(df)

        # Pre-correction mismatch count on a fixed random sample, for the log.
        sample_idx = df.sample(min(_MISMATCH_SAMPLE_SIZE, len(df)), random_state=42).index
        script_mismatches = int(mismatch.loc[sample_idx].sum())

        # Only transcripts with at least _MIN_RELABEL_CHARS characters are relabelled.
        to_fix = mismatch & (df["transcript"].str.len() >= _MIN_RELABEL_CHARS)
        df.loc[to_fix, "language"] = expected[to_fix]
        corrected = int(to_fix.sum())

        # 3. Empty transcripts are counted for the report but kept in the output.
        n_empty = int((df["transcript"].str.strip() == "").sum())

        print(f"    Rows: {n_before:,}")
        print(f"    Lang prefix fixes: {lang_fixes:,}")
        print(f"    Script mismatches (sample): {script_mismatches}")
        print(f"    Lang corrections: {corrected:,}")
        print(f"    Empty transcripts: {n_empty:,}")

        # Write cleaned manifest (the pandas index carries no data; drop it).
        out_path = PHASE3_DIR / f"{split}_manifest_clean.parquet"
        pq.write_table(pa.Table.from_pandas(df, preserve_index=False), out_path)
        print(f"    Output: {out_path}")

    elapsed = time.time() - t0
    print(f"\nDone in {elapsed:.0f}s")


# Run the cleaning pipeline only when executed as a script, not when imported.
if __name__ == "__main__":
    main()
