"""
Build a validation-recovery manifest from:
  - historical transcription rows (`transcription_results.parquet`)
  - current validation archive rows (`golden/redo/dispose` CSVs)

Outputs:
  - recover_manifest.parquet     segment-level target set (tx rows missing validation)
  - recover_video_summary.csv    per-video recovery rollup
  - recover_summary.json         aggregate counts
"""
from __future__ import annotations

import argparse
import json
import time
from pathlib import Path

import duckdb


def parse_args():
    p = argparse.ArgumentParser(description="Build validation-recovery manifest")
    p.add_argument("--tx", default="data/transcription_results.parquet", help="Path to transcription_results parquet")
    p.add_argument("--validation-dir", default="data", help="Dir containing golden/redo/dispose CSVs")
    p.add_argument("--output-dir", default="data", help="Dir to write recovery artifacts")
    return p.parse_args()


def _sql_str(value) -> str:
    """Render *value* as a single-quoted SQL string literal.

    Embedded single quotes are doubled so a path containing ``'`` cannot
    break (or inject into) the statements built below by interpolation.
    """
    return "'" + str(value).replace("'", "''") + "'"


def _create_source_views(con, tx_path: Path, existing_csvs: list[str]) -> None:
    """Create the ``tx`` view (historical transcription rows) and the ``val``
    view (current validation archive; an empty relation when no CSVs exist)."""
    print(f"Loading transcriptions from {tx_path}")
    con.execute(f"""
        CREATE VIEW tx AS
        SELECT
            video_id,
            segment_file,
            expected_language_hint,
            detected_language,
            quality_score,
            asr_eligible,
            tts_clean_eligible,
            tts_expressive_eligible,
            created_at
        FROM read_parquet({_sql_str(tx_path)})
    """)

    if existing_csvs:
        csv_list = ", ".join(_sql_str(p) for p in existing_csvs)
        print(f"Loading current validation archive rows from {len(existing_csvs)} CSV(s)")
        # union_by_name tolerates column-order drift across the three CSV kinds.
        con.execute(f"""
            CREATE VIEW val AS
            SELECT DISTINCT video_id, segment_file
            FROM read_csv_auto([{csv_list}], union_by_name=true, header=true)
        """)
    else:
        print("No validation CSVs found; treating current validation archive as empty")
        # Zero-row view with the same schema so downstream joins still work.
        con.execute("""
            CREATE VIEW val AS
            SELECT
                CAST(NULL AS VARCHAR) AS video_id,
                CAST(NULL AS VARCHAR) AS segment_file
            WHERE 1 = 0
        """)


def _build_recover_manifest(con) -> None:
    """Materialize ``recover_manifest``: every tx row with no matching
    validation row, annotated with per-video rollup counts and a reason
    ('archive_missing' when the video has zero validated segments,
    'partial_gap' otherwise)."""
    con.execute("""
        CREATE TABLE recover_manifest AS
        WITH val_ids AS (
            SELECT DISTINCT video_id, segment_file FROM val
        ),
        video_rollup AS (
            SELECT
                tx.video_id,
                count(*) AS tx_segments,
                count(val_ids.segment_file) AS validated_segments_current
            FROM tx
            LEFT JOIN val_ids
              ON tx.video_id = val_ids.video_id
             AND tx.segment_file = val_ids.segment_file
            GROUP BY tx.video_id
        )
        SELECT
            tx.video_id,
            tx.segment_file,
            tx.expected_language_hint,
            tx.detected_language,
            tx.quality_score,
            tx.asr_eligible,
            tx.tts_clean_eligible,
            tx.tts_expressive_eligible,
            tx.created_at,
            vr.tx_segments,
            vr.validated_segments_current,
            (vr.tx_segments - vr.validated_segments_current) AS missing_segments_for_video,
            CASE
                WHEN COALESCE(vr.validated_segments_current, 0) = 0 THEN 'archive_missing'
                ELSE 'partial_gap'
            END AS recover_reason
        FROM tx
        LEFT JOIN val_ids
          ON tx.video_id = val_ids.video_id
         AND tx.segment_file = val_ids.segment_file
        JOIN video_rollup vr
          ON tx.video_id = vr.video_id
        WHERE val_ids.segment_file IS NULL
    """)


def _export_outputs(con, output_dir: Path) -> tuple[Path, Path]:
    """Write the segment-level parquet manifest and the per-video CSV rollup.

    Returns:
        ``(manifest_path, summary_csv)`` paths of the two written artifacts.
    """
    manifest_path = output_dir / "recover_manifest.parquet"
    con.execute(f"""
        COPY (
            SELECT *
            FROM recover_manifest
            ORDER BY video_id, segment_file
        ) TO {_sql_str(manifest_path)} (FORMAT PARQUET, COMPRESSION ZSTD)
    """)

    summary_csv = output_dir / "recover_video_summary.csv"
    # max() on the rollup columns is safe: they are constant within a video.
    con.execute(f"""
        COPY (
            SELECT
                video_id,
                max(expected_language_hint) AS expected_language_hint,
                max(detected_language) AS sample_detected_language,
                max(recover_reason) AS recover_reason,
                max(tx_segments) AS tx_segments,
                max(validated_segments_current) AS validated_segments_current,
                count(*) AS missing_segments
            FROM recover_manifest
            GROUP BY video_id
            ORDER BY missing_segments DESC, video_id
        ) TO {_sql_str(summary_csv)} (HEADER, DELIMITER ',')
    """)
    return manifest_path, summary_csv


def _fetch_aggregates(con) -> tuple[tuple, tuple]:
    """Return ``(agg, full)`` rows: recovery-level aggregates and overall
    tx segment/video totals."""
    agg = con.execute("""
        WITH video_rollup AS (
            SELECT
                video_id,
                max(tx_segments) AS tx_segments,
                max(validated_segments_current) AS validated_segments_current,
                count(*) AS missing_segments,
                max(recover_reason) AS recover_reason
            FROM recover_manifest
            GROUP BY video_id
        )
        SELECT
            count(*) AS videos_to_recover,
            sum(missing_segments) AS segments_to_recover,
            sum(CASE WHEN recover_reason = 'archive_missing' THEN 1 ELSE 0 END) AS archive_missing_videos,
            sum(CASE WHEN recover_reason = 'partial_gap' THEN 1 ELSE 0 END) AS partial_gap_videos,
            round(avg(missing_segments), 2) AS avg_missing_segments_per_video
        FROM video_rollup
    """).fetchone()

    full = con.execute("""
        SELECT
            count(*) AS total_tx_segments,
            count(DISTINCT video_id) AS total_tx_videos
        FROM tx
    """).fetchone()
    return agg, full


def main():
    """Build the recovery manifest, per-video rollup CSV, and JSON summary.

    Raises:
        SystemExit: when the transcription parquet does not exist.
    """
    args = parse_args()
    tx_path = Path(args.tx)
    validation_dir = Path(args.validation_dir)
    output_dir = Path(args.output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)

    validation_csvs = [
        validation_dir / "golden_segments.csv",
        validation_dir / "redo_segments.csv",
        validation_dir / "dispose_segments.csv",
    ]
    existing_csvs = [str(p) for p in validation_csvs if p.exists()]

    if not tx_path.exists():
        raise SystemExit(f"Missing transcription parquet: {tx_path}")

    con = duckdb.connect()
    try:
        con.execute("SET threads = 8")
        con.execute("SET memory_limit = '16GB'")

        t0 = time.time()
        _create_source_views(con, tx_path, existing_csvs)
        _build_recover_manifest(con)
        manifest_path, summary_csv = _export_outputs(con, output_dir)
        agg, full = _fetch_aggregates(con)

        # `agg` sums/avg are NULL when the manifest is empty; coerce to zeros.
        summary = {
            "total_tx_segments": int(full[0]),
            "total_tx_videos": int(full[1]),
            "videos_to_recover": int(agg[0] or 0),
            "segments_to_recover": int(agg[1] or 0),
            "archive_missing_videos": int(agg[2] or 0),
            "partial_gap_videos": int(agg[3] or 0),
            "avg_missing_segments_per_video": float(agg[4] or 0.0),
            "manifest_path": str(manifest_path),
            "summary_csv": str(summary_csv),
            "elapsed_s": round(time.time() - t0, 2),
        }

        summary_json = output_dir / "recover_summary.json"
        summary_json.write_text(json.dumps(summary, indent=2))

        print(json.dumps(summary, indent=2))
    finally:
        # The original leaked the connection; always release it.
        con.close()


# Script entry point: build and export the recovery-manifest artifacts.
if __name__ == "__main__":
    main()
