from __future__ import annotations

import json
import time
from pathlib import Path

import duckdb


# Pipeline directory layout; the .as_posix() strings below are embedded
# directly into DuckDB SQL, so they must stay plain POSIX paths.
ROOT = Path("/home/ubuntu/transcripts")
DATA_DIR = ROOT / "data"
FINAL_DIR = ROOT / "final_data"

# Per-segment language-detection parquet shards (hive-partitioned glob).
SEGMENT_MAP_GLOB = (DATA_DIR / "phase1_incremental" / "segment_map_v1" / "**" / "*.parquet").as_posix()
# Video-level rollup restricted to the kept subset (input to refinement).
VIDEO_ROLLUP_KEPT = (FINAL_DIR / "video_rollup_final_kept_subset.parquet").as_posix()
# Final keep/drop classification per video (video_id, recommended_action).
CLASSIFICATION_FINAL = (DATA_DIR / "video_tts_classification_final.csv").as_posix()
# Videos dropped at the channel level.
DROPPED_VIDEOS = (DATA_DIR / "video_tts_dropped_by_channel.csv").as_posix()
# Processing queue: maps each video to the language it was queued under.
VIDEO_QUEUE = (DATA_DIR / "video_queue.csv.gz").as_posix()
# Raw YouTube metadata (default language fields, channel and title).
YOUTUBE_META = (DATA_DIR / "youtube_video_metadata_all.csv").as_posix()

# Language codes the dataset targets; a kept video whose dominant detected
# language falls outside this set is flagged as "non-target" downstream.
TARGET_LANGS = ("en", "hi", "te", "ta", "kn", "ml", "gu", "pa", "bn", "or", "mr", "as")


def write_json(path: Path, payload: dict) -> None:
    """Serialize *payload* as pretty-printed, key-sorted JSON to *path*.

    Writes UTF-8 explicitly (``write_text`` otherwise falls back to the
    locale-dependent default encoding) and keeps non-ASCII text readable
    with ``ensure_ascii=False`` instead of ``\\uXXXX`` escapes.  A trailing
    newline is appended so the file ends like a POSIX text file.
    """
    text = json.dumps(payload, indent=2, sort_keys=True, ensure_ascii=False)
    path.write_text(text + "\n", encoding="utf-8")


def fetchone_dict(con: duckdb.DuckDBPyConnection, query: str) -> dict:
    """Execute *query* and return its first row as a column-name -> value dict.

    Returns an empty dict when the query produces no rows.
    """
    result = con.execute(query)
    first_row = result.fetchone()
    if first_row is None:
        return {}
    column_names = (desc[0] for desc in result.description)
    return {name: value for name, value in zip(column_names, first_row)}


def pct(numerator: int, denominator: int) -> float:
    """Return *numerator* as a percentage of *denominator*, rounded to 6 dp.

    A zero denominator yields 0.0 instead of raising ZeroDivisionError.
    """
    return round(numerator * 100.0 / denominator, 6) if denominator else 0.0


def _num(value):
    """Coalesce a SQL NULL aggregate (surfaced in Python as None) to 0.

    DuckDB's sum() over zero rows returns NULL; feeding that None into
    pct() or a subtraction would raise TypeError.
    """
    return 0 if value is None else value


def _rollup_summary(con: duckdb.DuckDBPyConnection, relation: str) -> dict:
    """Aggregate the standard segment counters over a video-rollup relation.

    *relation* must be a trusted view/table name or read_parquet() call —
    it is interpolated into the SQL verbatim.  Factored out because three
    rollups share this identical column list.
    """
    return fetchone_dict(
        con,
        f"""
        SELECT
            count(*) AS videos,
            sum(total_segments) AS total_segments,
            sum(final_validated_segments) AS final_validated_segments,
            sum(final_missing_segments) AS final_missing_segments,
            sum(final_golden_segments) AS golden_segments,
            sum(final_redo_segments) AS redo_segments,
            sum(final_dispose_segments) AS dispose_segments
        FROM {relation}
        """,
    )


def main() -> None:
    """Refine the kept video set by Gemini-detected dominant language.

    Joins segment-level language detections with the classification,
    queue, and YouTube metadata; flags kept videos whose dominant detected
    language is outside TARGET_LANGS; exports CSV reports plus strict and
    conservative refined rollup parquets; and writes an aggregate JSON
    summary to FINAL_DIR.
    """
    FINAL_DIR.mkdir(parents=True, exist_ok=True)
    temp_dir = FINAL_DIR / "duckdb_tmp"
    temp_dir.mkdir(parents=True, exist_ok=True)

    # Single SQL IN-list derived from TARGET_LANGS — previously the same
    # language list was hand-duplicated in four places inside the SQL,
    # which risked silent drift from the module constant.
    target_list = ", ".join(f"'{lang}'" for lang in TARGET_LANGS)
    # Full (pre-kept-subset) rollup, used for the all-transcribed baselines;
    # derived from FINAL_DIR instead of a hard-coded absolute path.
    video_rollup_final = (FINAL_DIR / "video_rollup_final.parquet").as_posix()

    con = duckdb.connect()
    try:
        con.execute("SET threads = 1")
        con.execute("SET memory_limit = '12GB'")
        con.execute("SET preserve_insertion_order = false")
        con.execute(f"SET temp_directory = '{temp_dir.as_posix()}'")

        started = time.time()

        # --- Input views -------------------------------------------------
        con.execute(f"""
            CREATE OR REPLACE VIEW selected_videos AS
            SELECT DISTINCT video_id, recommended_action
            FROM read_csv_auto('{CLASSIFICATION_FINAL}', header=true)
        """)
        con.execute(f"""
            CREATE OR REPLACE VIEW dropped_videos AS
            SELECT DISTINCT video_id
            FROM read_csv_auto('{DROPPED_VIDEOS}', header=true)
        """)
        con.execute(f"""
            CREATE OR REPLACE VIEW queue_videos AS
            SELECT video_id, language AS queue_language
            FROM read_csv_auto('{VIDEO_QUEUE}', header=true)
        """)
        con.execute(f"""
            CREATE OR REPLACE VIEW youtube_meta AS
            SELECT
                video_id,
                regexp_extract(lower(coalesce(default_audio_language, '')), '^([a-z]+)', 1) AS youtube_audio_language,
                regexp_extract(lower(coalesce(default_language, '')), '^([a-z]+)', 1) AS youtube_default_language,
                channel_id,
                channel_title,
                title
            FROM read_csv_auto('{YOUTUBE_META}', header=true)
        """)
        con.execute(f"""
            CREATE OR REPLACE VIEW kept_video_rollup AS
            SELECT * FROM read_parquet('{VIDEO_ROLLUP_KEPT}')
        """)
        # Per-segment language: prefer the detected language, then the
        # expected hint, then the queue language.
        con.execute(f"""
            CREATE OR REPLACE VIEW kept_segments AS
            SELECT
                s.video_id,
                coalesce(q.queue_language, s.queue_language) AS queue_language,
                coalesce(nullif(s.tx_detected_language, ''), nullif(s.expected_language_hint, ''), coalesce(q.queue_language, s.queue_language)) AS gemini_lang
            FROM read_parquet('{SEGMENT_MAP_GLOB}', hive_partitioning=true, union_by_name=true) s
            JOIN selected_videos sv USING (video_id)
            LEFT JOIN queue_videos q USING (video_id)
        """)
        con.execute("""
            CREATE OR REPLACE VIEW kept_video_lang_counts AS
            SELECT
                video_id,
                gemini_lang,
                count(*) AS segments
            FROM kept_segments
            GROUP BY video_id, gemini_lang
        """)
        con.execute("""
            CREATE OR REPLACE VIEW kept_video_lang_ranked AS
            SELECT
                *,
                row_number() OVER (PARTITION BY video_id ORDER BY segments DESC, gemini_lang) AS rn
            FROM kept_video_lang_counts
        """)
        # --- Per-video language profile ----------------------------------
        con.execute(f"""
            CREATE OR REPLACE VIEW kept_video_language_profile AS
            WITH video_totals AS (
                SELECT
                    video_id,
                    count(*) AS total_segments,
                    count(DISTINCT gemini_lang) AS distinct_detected_languages,
                    count(*) FILTER (WHERE gemini_lang NOT IN ({target_list}) AND gemini_lang <> '') AS unexpected_language_segments,
                    count(DISTINCT CASE
                        WHEN gemini_lang NOT IN ({target_list}) AND gemini_lang <> ''
                        THEN gemini_lang
                    END) AS distinct_unexpected_languages
                FROM kept_segments
                GROUP BY video_id
            ),
            dominant AS (
                SELECT
                    video_id,
                    gemini_lang AS dominant_gemini_language,
                    segments AS dominant_gemini_segments
                FROM kept_video_lang_ranked
                WHERE rn = 1
            ),
            threshold_counts AS (
                SELECT
                    video_id,
                    count(*) FILTER (WHERE segments >= 2) AS detected_languages_2plus,
                    count(*) FILTER (WHERE segments >= 3) AS detected_languages_3plus,
                    string_agg(gemini_lang || ':' || segments, ', ' ORDER BY segments DESC, gemini_lang)
                        FILTER (WHERE rn <= 5) AS top_detected_languages
                FROM kept_video_lang_ranked
                GROUP BY video_id
            )
            SELECT
                vt.video_id,
                sv.recommended_action,
                q.queue_language,
                ym.youtube_audio_language,
                ym.youtube_default_language,
                ym.channel_id,
                ym.channel_title,
                ym.title,
                vt.total_segments,
                vt.distinct_detected_languages,
                coalesce(tc.detected_languages_2plus, 0) AS detected_languages_2plus,
                coalesce(tc.detected_languages_3plus, 0) AS detected_languages_3plus,
                vt.unexpected_language_segments,
                vt.distinct_unexpected_languages,
                d.dominant_gemini_language,
                d.dominant_gemini_segments,
                round(100.0 * d.dominant_gemini_segments / vt.total_segments, 6) AS dominant_gemini_share_pct,
                d.dominant_gemini_language NOT IN ({target_list}) AS dominant_outside_target,
                d.dominant_gemini_language NOT IN ({target_list})
                    AND (100.0 * d.dominant_gemini_segments / vt.total_segments) >= 80.0 AS dominant_outside_target_ge80,
                d.dominant_gemini_language = q.queue_language AS dominant_matches_queue_language,
                d.dominant_gemini_language = ym.youtube_audio_language AS dominant_matches_youtube_audio_language,
                d.dominant_gemini_language = ym.youtube_default_language AS dominant_matches_youtube_default_language,
                tc.top_detected_languages
            FROM video_totals vt
            JOIN dominant d USING (video_id)
            LEFT JOIN threshold_counts tc USING (video_id)
            JOIN selected_videos sv USING (video_id)
            LEFT JOIN queue_videos q USING (video_id)
            LEFT JOIN youtube_meta ym USING (video_id)
        """)

        # --- CSV reports of flagged videos -------------------------------
        non_target_all_csv = (FINAL_DIR / "gemini_dominant_non_target_videos_all.csv").as_posix()
        non_target_ge80_csv = (FINAL_DIR / "gemini_dominant_non_target_videos_ge80.csv").as_posix()
        con.execute(f"""
            COPY (
                SELECT *
                FROM kept_video_language_profile
                WHERE dominant_outside_target
                ORDER BY total_segments DESC, video_id
            ) TO '{non_target_all_csv}' (HEADER, DELIMITER ',')
        """)
        con.execute(f"""
            COPY (
                SELECT *
                FROM kept_video_language_profile
                WHERE dominant_outside_target_ge80
                ORDER BY total_segments DESC, video_id
            ) TO '{non_target_ge80_csv}' (HEADER, DELIMITER ',')
        """)

        # --- Refined keep sets -------------------------------------------
        con.execute("""
            CREATE OR REPLACE VIEW refined_keep_strict AS
            SELECT *
            FROM kept_video_rollup
            WHERE video_id NOT IN (
                SELECT video_id FROM kept_video_language_profile WHERE dominant_outside_target
            )
        """)
        con.execute("""
            CREATE OR REPLACE VIEW refined_keep_conservative AS
            SELECT *
            FROM kept_video_rollup
            WHERE video_id NOT IN (
                SELECT video_id FROM kept_video_language_profile WHERE dominant_outside_target_ge80
            )
        """)

        strict_parquet = (FINAL_DIR / "video_rollup_gemini_refined_strict.parquet").as_posix()
        conservative_parquet = (FINAL_DIR / "video_rollup_gemini_refined_conservative.parquet").as_posix()
        con.execute(f"""
            COPY (
                SELECT rk.*, p.dominant_gemini_language, p.dominant_gemini_share_pct, p.top_detected_languages
                FROM refined_keep_strict rk
                LEFT JOIN kept_video_language_profile p USING (video_id)
                ORDER BY rk.video_id
            ) TO '{strict_parquet}' (FORMAT PARQUET, COMPRESSION ZSTD)
        """)
        con.execute(f"""
            COPY (
                SELECT rk.*, p.dominant_gemini_language, p.dominant_gemini_share_pct, p.top_detected_languages
                FROM refined_keep_conservative rk
                LEFT JOIN kept_video_language_profile p USING (video_id)
                ORDER BY rk.video_id
            ) TO '{conservative_parquet}' (FORMAT PARQUET, COMPRESSION ZSTD)
        """)

        # --- Summaries ----------------------------------------------------
        overlap_summary = fetchone_dict(
            con,
            """
            SELECT
                (SELECT count(*) FROM selected_videos) AS selected_videos,
                (SELECT count(*) FROM dropped_videos) AS dropped_videos,
                (SELECT count(*) FROM kept_video_language_profile WHERE dominant_outside_target) AS dominant_non_target_videos,
                (SELECT count(*) FROM kept_video_language_profile WHERE dominant_outside_target_ge80) AS dominant_non_target_videos_ge80,
                (SELECT count(*) FROM kept_video_language_profile p JOIN dropped_videos d USING (video_id) WHERE dominant_outside_target) AS overlap_non_target_with_dropped,
                (SELECT count(*) FROM kept_video_language_profile p JOIN dropped_videos d USING (video_id) WHERE dominant_outside_target_ge80) AS overlap_non_target_ge80_with_dropped
            """
        )

        dropped_transcript_summary = fetchone_dict(
            con,
            f"""
            SELECT
                count(*) AS transcribed_videos,
                sum(total_segments) AS transcribed_segments
            FROM read_parquet('{video_rollup_final}')
            WHERE video_id IN (SELECT video_id FROM dropped_videos)
            """
        )

        kept_summary = _rollup_summary(con, "kept_video_rollup")
        strict_summary = _rollup_summary(con, "refined_keep_strict")
        conservative_summary = _rollup_summary(con, "refined_keep_conservative")

        all_transcribed = fetchone_dict(
            con,
            f"""
            SELECT
                count(*) AS videos,
                sum(total_segments) AS total_segments,
                sum(final_validated_segments) AS final_validated_segments,
                sum(final_missing_segments) AS final_missing_segments
            FROM read_parquet('{video_rollup_final}')
            """
        )

        fresh_summary = {
            "generated_at_epoch_s": round(time.time(), 3),
            "elapsed_s": round(time.time() - started, 2),
            "overlap": overlap_summary,
            "all_transcribed": all_transcribed,
            "dropped_set_only": {
                **dropped_transcript_summary,
                # _num(): sum() over zero matching rows is NULL/None.
                "videos_pct_of_all": pct(_num(dropped_transcript_summary["transcribed_videos"]), _num(all_transcribed["videos"])),
                "segments_pct_of_all": pct(_num(dropped_transcript_summary["transcribed_segments"]), _num(all_transcribed["total_segments"])),
            },
            "kept_set": {
                **kept_summary,
                "coverage_pct": pct(_num(kept_summary["final_validated_segments"]), _num(kept_summary["total_segments"])),
            },
            "refined_strict_remove_all_dominant_non_target": {
                **strict_summary,
                "coverage_pct": pct(_num(strict_summary["final_validated_segments"]), _num(strict_summary["total_segments"])),
                "removed_videos": _num(kept_summary["videos"]) - _num(strict_summary["videos"]),
                "removed_segments": _num(kept_summary["total_segments"]) - _num(strict_summary["total_segments"]),
            },
            "refined_conservative_remove_dominant_non_target_ge80": {
                **conservative_summary,
                "coverage_pct": pct(_num(conservative_summary["final_validated_segments"]), _num(conservative_summary["total_segments"])),
                "removed_videos": _num(kept_summary["videos"]) - _num(conservative_summary["videos"]),
                "removed_segments": _num(kept_summary["total_segments"]) - _num(conservative_summary["total_segments"]),
            },
        }

        write_json(FINAL_DIR / "gemini_refined_dataset_summary.json", fresh_summary)
        print(json.dumps(fresh_summary, indent=2, sort_keys=True))
    finally:
        # Always release the connection (and its temp spill directory),
        # even if a query fails mid-run.
        con.close()


if __name__ == "__main__":
    main()
