"""
Patch all Stage B final-shard metadata.parquet files in R2.

Three fixes applied to every row in every shard:
  1. Normalize Unicode decimal digits → ASCII (all three transcription fields)
  2. Null transcription_native where its script doesn't match segment_language
  3. Clean transcription_romanized: IAST → ASCII, strip leaked Indic/Arabic chars
"""
from __future__ import annotations

import hashlib
import io
import json
import logging
import os
import sys
import time
import unicodedata
from concurrent.futures import ThreadPoolExecutor, as_completed
from dataclasses import dataclass, field
from pathlib import Path

import boto3
import psycopg2
import pyarrow as pa
import pyarrow.parquet as pq
from botocore.config import Config
from dotenv import load_dotenv

ROOT = Path("/home/ubuntu/transcripts")
load_dotenv(ROOT / ".env")
logging.basicConfig(level=logging.INFO, format="%(asctime)s %(levelname)s %(message)s")
logger = logging.getLogger(__name__)

BUCKET = "finalsftdata"
RUN_ID = "production-20260312"
# Assumed R2 layout: <RUN_ID>/language=<code>/<shard_id>/metadata.parquet;
# SHARD_PREFIX scopes the fallback listing in discover_shards_from_r2().
SHARD_PREFIX = f"{RUN_ID}/"

SCRIPT_RANGES: dict[str, list[tuple[int, int]]] = {
    "Devanagari": [(0x0900, 0x097F), (0xA8E0, 0xA8FF)],
    "Bengali": [(0x0980, 0x09FF)],
    "Gurmukhi": [(0x0A00, 0x0A7F)],
    "Gujarati": [(0x0A80, 0x0AFF)],
    "Oriya": [(0x0B00, 0x0B7F)],
    "Tamil": [(0x0B80, 0x0BFF)],
    "Telugu": [(0x0C00, 0x0C7F)],
    "Kannada": [(0x0C80, 0x0CFF)],
    "Malayalam": [(0x0D00, 0x0D7F)],
}

LANG_TO_TARGET: dict[str, str] = {
    "hi": "Devanagari", "mr": "Devanagari", "bn": "Bengali", "as": "Bengali",
    "pa": "Gurmukhi", "gu": "Gujarati", "or": "Oriya", "ta": "Tamil",
    "te": "Telugu", "kn": "Kannada", "ml": "Malayalam",
}

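# Code points in these spans are stripped from romanized output: the contiguous
# run of Indic blocks (Devanagari through Malayalam) plus the basic Arabic block.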
INDIC_BLOCK_MIN = 0x0900
INDIC_BLOCK_MAX = 0x0D7F
ARABIC_BLOCK_MIN = 0x0600
ARABIC_BLOCK_MAX = 0x06FF


def normalize_digits(text: str | None) -> str | None:
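    """Replace every Unicode decimal digit with its ASCII equivalent.

    Example: Devanagari "२०२४" becomes "2024"; ASCII digits map to themselves,
    and None / empty strings are returned unchanged.
    """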
    if not text:
        return text
    out = []
    for ch in text:
        d = unicodedata.decimal(ch, -1)
        if d >= 0:
            out.append(str(d))
        else:
            out.append(ch)
    return "".join(out)


def has_target_script_chars(text: str, target_script: str) -> bool:
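    """True if text contains at least one code point inside the target script's
    Unicode block(s); scripts without a configured range pass permissively."""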
    ranges = SCRIPT_RANGES.get(target_script, [])
    if not ranges:
        return True
    for ch in text:
        cp = ord(ch)
        for lo, hi in ranges:
            if lo <= cp <= hi:
                return True
    return False


def clean_romanized(text: str | None) -> str | None:
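    """NFKD-decompose, drop all combining marks (IAST "ā" -> ASCII "a"), and
    strip any code point leaked from the Indic or Arabic blocks.

    Example: clean_romanized("nāmaste") -> "namaste".
    """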
    if not text:
        return text
    nfkd = unicodedata.normalize("NFKD", text)
    out = []
    for ch in nfkd:
        cat = unicodedata.category(ch)
        if cat.startswith("M"):
            continue
        cp = ord(ch)
        if INDIC_BLOCK_MIN <= cp <= INDIC_BLOCK_MAX:
            continue
        if ARABIC_BLOCK_MIN <= cp <= ARABIC_BLOCK_MAX:
            continue
        out.append(ch)
    return "".join(out)


def sha256_bytes(data: bytes) -> str:
    return hashlib.sha256(data).hexdigest()


@dataclass
class ShardInfo:
    shard_id: str
    language: str
    metadata_key: str
    manifest_key: str


@dataclass
class PatchStats:
    shards_processed: int = 0
    shards_failed: int = 0
    rows_total: int = 0
    digits_patched: int = 0
    native_nulled: int = 0
    roman_cleaned: int = 0
    errors: list[str] = field(default_factory=list)


def make_s3():
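    """Build a boto3 S3 client against the R2 endpoint (R2_* env vars required)."""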
    return boto3.client(
        "s3",
        endpoint_url=os.environ["R2_ENDPOINT_URL"],
        aws_access_key_id=os.environ["R2_ACCESS_KEY_ID"],
        aws_secret_access_key=os.environ["R2_SECRET_ACCESS_KEY"],
        region_name="auto",
        config=Config(max_pool_connections=64, retries={"max_attempts": 3}),
    )


def patch_one_shard(s3, shard: ShardInfo) -> dict:
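    """Download one shard's metadata.parquet, apply the three fixes row by row,
    and, if anything changed, rewrite the parquet in place and refresh the
    manifest's sha256 / size / patched_at fields. Errors are captured per shard
    rather than raised, so one bad shard cannot abort the whole run."""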
    result = {
        "shard_id": shard.shard_id,
        "rows": 0, "digits_patched": 0, "native_nulled": 0, "roman_cleaned": 0,
        "error": None, "new_metadata_sha256": None, "new_metadata_size": None,
    }
    try:
        meta_body = s3.get_object(Bucket=BUCKET, Key=shard.metadata_key)["Body"].read()
        table = pq.read_table(io.BytesIO(meta_body))
        rows = table.to_pylist()
        result["rows"] = len(rows)

        target_script = LANG_TO_TARGET.get(shard.language)
        changed = False

        for row in rows:
            mixed_orig = row.get("transcription_mixed")
            native_orig = row.get("transcription_native")
            roman_orig = row.get("transcription_romanized")

            mixed_new = normalize_digits(mixed_orig)
            native_new = normalize_digits(native_orig)
            roman_new = normalize_digits(roman_orig)

            if mixed_new != mixed_orig or native_new != native_orig or roman_new != roman_orig:
                result["digits_patched"] += 1
                changed = True
            row["transcription_mixed"] = mixed_new
            row["transcription_native"] = native_new
            row["transcription_romanized"] = roman_new

            if target_script and native_new and not has_target_script_chars(native_new, target_script):
                row["transcription_native"] = None  # fix 2: null, not empty string
                result["native_nulled"] += 1
                changed = True

            roman_cleaned = clean_romanized(row["transcription_romanized"])
            if roman_cleaned != row["transcription_romanized"]:
                row["transcription_romanized"] = roman_cleaned
                result["roman_cleaned"] += 1
                changed = True

        if not changed:
            return result

        new_table = pa.Table.from_pylist(rows, schema=table.schema)
        buf = io.BytesIO()
        pq.write_table(new_table, buf, compression="zstd")
        new_meta_bytes = buf.getvalue()

        result["new_metadata_sha256"] = sha256_bytes(new_meta_bytes)
        result["new_metadata_size"] = len(new_meta_bytes)
        s3.put_object(Bucket=BUCKET, Key=shard.metadata_key, Body=new_meta_bytes)

        manifest_body = s3.get_object(Bucket=BUCKET, Key=shard.manifest_key)["Body"].read()
        manifest = json.loads(manifest_body.decode("utf-8"))
        manifest["metadata_sha256"] = result["new_metadata_sha256"]
        manifest["metadata_size_bytes"] = result["new_metadata_size"]
        manifest["patched_at"] = time.strftime("%Y-%m-%dT%H:%M:%SZ", time.gmtime())
        s3.put_object(
            Bucket=BUCKET, Key=shard.manifest_key,
            Body=json.dumps(manifest, indent=2).encode("utf-8"),
        )
    except Exception as exc:
        result["error"] = str(exc)[:300]
    return result


def discover_shards_from_r2(s3) -> list[ShardInfo]:
    """List all shard metadata.parquet keys under the R2 prefix and derive shard info."""
    prefix = SHARD_PREFIX
    shards: list[ShardInfo] = []
    paginator = s3.get_paginator("list_objects_v2")
    for page in paginator.paginate(Bucket=BUCKET, Prefix=prefix):
        for item in page.get("Contents", []):
            key = str(item["Key"])
            if not key.endswith("/metadata.parquet"):
                continue
            parts = key.split("/")
            if len(parts) < 3:
                continue
            shard_dir = parts[-2]
            lang_dir = parts[-3]
            language = lang_dir.split("=", 1)[-1] if "=" in lang_dir else lang_dir
            manifest_key = key.rsplit("/", 1)[0] + "/manifest.json"
            shards.append(ShardInfo(
                shard_id=shard_dir, language=language,
                metadata_key=key, manifest_key=manifest_key,
            ))
    shards.sort(key=lambda s: (s.language, s.shard_id))
    return shards


def main():
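    """Patch every shard, then stamp the DB registry (if one is configured).

    CLI: an optional first argument overrides the thread-pool size (default 48).
    """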
    s3 = make_s3()

    db_url = os.environ.get("DATABASE_URL", "").strip()
    if db_url:
        conn = psycopg2.connect(db_url)
        cur = conn.cursor()
        cur.execute("""
            SELECT shard_id, language, metadata_key, manifest_key
            FROM final_export_shards
            WHERE run_id = %s
            ORDER BY language, shard_id
        """, (RUN_ID,))
        shards = [
            ShardInfo(shard_id=r[0], language=r[1], metadata_key=r[2], manifest_key=r[3])
            for r in cur.fetchall()
        ]
        conn.close()
    else:
        logger.info("No DATABASE_URL; discovering shards from R2 listing...")
        shards = discover_shards_from_r2(s3)
    logger.info("Loaded %d shards to patch", len(shards))

    max_workers = int(sys.argv[1]) if len(sys.argv) > 1 else 48
    stats = PatchStats()

    t0 = time.time()
    with ThreadPoolExecutor(max_workers=max_workers) as pool:
        futures = {pool.submit(patch_one_shard, s3, shard): shard for shard in shards}
        for i, future in enumerate(as_completed(futures), 1):
            shard = futures[future]
            result = future.result()
            stats.rows_total += result["rows"]
            stats.digits_patched += result["digits_patched"]
            stats.native_nulled += result["native_nulled"]
            stats.roman_cleaned += result["roman_cleaned"]
            if result["error"]:
                stats.shards_failed += 1
                stats.errors.append(f'{shard.shard_id}: {result["error"]}')
                logger.warning("[%d/%d] FAIL %s: %s", i, len(shards), shard.shard_id, result["error"][:120])
            else:
                stats.shards_processed += 1
                if i % 100 == 0 or i == len(shards):
                    elapsed = time.time() - t0
                    rate = i / elapsed if elapsed > 0 else 0
                    eta = (len(shards) - i) / rate if rate > 0 else 0
                    logger.info(
                        "[%d/%d] ok  rate=%.1f/s  ETA=%.0fs  digits=%d native_nulled=%d roman=%d",
                        i, len(shards), rate, eta,
                        stats.digits_patched, stats.native_nulled, stats.roman_cleaned,
                    )

    elapsed = time.time() - t0
    logger.info("=== PATCH COMPLETE ===")
    logger.info("shards_processed=%d shards_failed=%d rows=%d elapsed=%.1fs",
                stats.shards_processed, stats.shards_failed, stats.rows_total, elapsed)
    logger.info("digits_patched=%d native_nulled=%d roman_cleaned=%d",
                stats.digits_patched, stats.native_nulled, stats.roman_cleaned)
    for err in stats.errors[:20]:
        logger.error("  %s", err)

    if stats.shards_failed > 0:
        logger.warning("Re-run to retry %d failed shards", stats.shards_failed)

    # Stamp patched_at into the shard registry, but only when a DB is configured
    # (the R2-discovery fallback above runs without DATABASE_URL).
    if db_url:
        conn = psycopg2.connect(db_url)
        conn.autocommit = True
        cur = conn.cursor()
        cur.execute("""
            UPDATE final_export_shards
            SET metadata_json = jsonb_set(
                coalesce(metadata_json, '{}'::jsonb),
                '{metadata_patched_at}',
                to_jsonb(now()::text)
            )
            WHERE run_id = %s
        """, (RUN_ID,))
        logger.info("Updated %d shard rows with patch timestamp", cur.rowcount)
        conn.close()


if __name__ == "__main__":
    main()
