from __future__ import annotations

import duckdb
from pathlib import Path


def main() -> None:
    out_dir = Path("final_data/english_mixed_reroute")
    out_dir.mkdir(parents=True, exist_ok=True)

    con = duckdb.connect(":memory:")
    con.execute("SET threads TO 4")
    con.execute("SET memory_limit = '8GB'")

    # Build the reroute subset:
    # - only English-hinted rows that were locally skipped
    # - only rows containing at least one Indic-script character
    # - choose target language primarily from actual script in text
    # - use queue/corrected language only for ambiguous script families
    sql = """
    CREATE OR REPLACE TEMP TABLE english_mixed_subset AS
    WITH base AS (
        SELECT
            video_id || '/' || segment_file AS row_id,
            video_id,
            segment_file AS segment_id,
            queue_language,
            corrected_language,
            gemini_lang,
            transcription AS text,
            CASE
                WHEN regexp_matches(transcription, '[\u0C00-\u0C7F]') THEN 'te'
                WHEN regexp_matches(transcription, '[\u0B80-\u0BFF]') THEN 'ta'
                WHEN regexp_matches(transcription, '[\u0C80-\u0CFF]') THEN 'kn'
                WHEN regexp_matches(transcription, '[\u0D00-\u0D7F]') THEN 'ml'
                WHEN regexp_matches(transcription, '[\u0A80-\u0AFF]') THEN 'gu'
                WHEN regexp_matches(transcription, '[\u0A00-\u0A7F]') THEN 'pa'
                WHEN regexp_matches(transcription, '[\u0B00-\u0B7F]') THEN 'or'
                WHEN regexp_matches(transcription, '[\u0980-\u09FF]') THEN
                    CASE
                        WHEN corrected_language IN ('bn', 'as') THEN corrected_language
                        WHEN queue_language IN ('bn', 'as') THEN queue_language
                        ELSE 'bn'
                    END
                WHEN regexp_matches(transcription, '[\u0900-\u097F]') THEN
                    CASE
                        WHEN corrected_language IN ('hi', 'mr') THEN corrected_language
                        WHEN queue_language IN ('hi', 'mr') THEN queue_language
                        ELSE 'hi'
                    END
                ELSE NULL
            END AS target_language_code
        FROM read_parquet('final_data/final_cleaned_segments_with_variants.parquet')
        WHERE gemini_lang = 'en'
          AND variant_route = 'local_skip_fully_roman'
          AND regexp_matches(transcription, '[\u0900-\u0D7F]')
    )
    SELECT
        row_id,
        video_id,
        segment_id,
        queue_language,
        corrected_language,
        gemini_lang,
        target_language_code,
        text
    FROM base
    WHERE target_language_code IS NOT NULL
    """
    con.execute(sql)

    total = con.execute("SELECT count(*) FROM english_mixed_subset").fetchone()[0]
    print(f"Prepared reroute subset: {total:,} rows")

    breakdown = con.execute(
        """
        SELECT target_language_code, count(*) AS cnt
        FROM english_mixed_subset
        GROUP BY 1 ORDER BY cnt DESC
        """
    ).fetchall()
    print("Target language breakdown:")
    for lang, cnt in breakdown:
        print(f"  {lang}: {cnt:,}")

    # Split evenly into 4 parts for 4 API keys/processes.
    con.execute(
        """
        CREATE OR REPLACE TEMP TABLE english_mixed_partitioned AS
        SELECT
            *,
            (row_number() OVER (ORDER BY row_id) - 1) % 4 AS part_idx
        FROM english_mixed_subset
        """
    )

    for part_idx in range(4):
        out_path = out_dir / f"part_{part_idx}.parquet"
        con.execute(
            f"""
            COPY (
                SELECT row_id, video_id, segment_id, target_language_code, text
                FROM english_mixed_partitioned
                WHERE part_idx = {part_idx}
                ORDER BY row_id
            )
            TO '{out_path.as_posix()}'
            (FORMAT PARQUET, COMPRESSION ZSTD)
            """
        )
        part_count = con.execute(
            f"SELECT count(*) FROM english_mixed_partitioned WHERE part_idx = {part_idx}"
        ).fetchone()[0]
        print(f"  part_{part_idx}: {part_count:,} rows -> {out_path}")


# Script entry point: build the subset and write the parts only when this file
# is executed directly, not when it is imported.
if __name__ == "__main__":
    main()
