"""
Validation analysis: reads all parquet shards, computes aggregate metrics,
and buckets segments into Golden / Redo / Dispose sets.

Usage:
  python scripts/analyze_validation.py [--shards data/validation_shards] [--output data]

Outputs:
  - Summary metrics to stdout
  - data/golden_segments.csv  — production-ready segments
  - data/redo_segments.csv    — salvageable, needs re-transcription
  - data/dispose_segments.csv — not worth keeping
  - data/video_summary.csv    — per-video rollup with bucket counts
"""
from __future__ import annotations

import argparse
import time
from pathlib import Path

import duckdb


# === Bucketing thresholds ===
# Golden: high confidence across all signals.  A segment is golden only when
# every gate below passes (see the bucketing CASE in main): full LID agreement,
# a high CTC score (NULL passes), a decent Gemini quality score (0 = unscored
# also passes), and a minimum duration.
GOLDEN_LID_AGREE: int = 3        # all of Gemini + MMS + VoxLingua must agree
GOLDEN_CTC_MIN: float = 0.7      # min conformer_multi_ctc_normalized (NULL also passes)
GOLDEN_QUALITY_MIN: float = 0.5  # min gemini_quality_score (0 = unscored also passes)
GOLDEN_DURATION_MIN: float = 2.0  # seconds

# Dispose: clearly bad data — any one of these alone disposes a segment.
DISPOSE_CTC_MAX: float = 0.3      # non-NULL CTC below this is disposed
DISPOSE_DURATION_MAX: float = 1.0  # seconds; shorter segments are disposed

# Redo: everything that's not golden and not dispose


# Display/export order for the three segment buckets.
_BUCKET_NAMES = ("golden", "redo", "dispose")


def _sql_str(value) -> str:
    """Render *value* as a single-quoted SQL string literal.

    Single quotes are doubled so user-supplied paths containing an
    apostrophe can neither break nor inject into the statements we
    assemble with f-strings (DuckDB COPY/read_parquet targets cannot
    be bound as prepared-statement parameters in DDL-style statements).
    """
    return "'" + str(value).replace("'", "''") + "'"


def _print_overall_summary(con) -> None:
    """Section 1: dataset-wide segment/video/hour counts and LID averages."""
    print("\n" + "=" * 70)
    print("  VALIDATION DATASET SUMMARY")
    print("=" * 70)

    row = con.execute("""
        SELECT
            count(*) as total_segments,
            count(DISTINCT video_id) as total_videos,
            round(sum(duration_s) / 3600, 1) as total_hours,
            round(avg(duration_s), 2) as avg_duration_s,
            round(sum(CASE WHEN lid_consensus THEN 1 ELSE 0 END) * 100.0 / count(*), 1) as consensus_pct,
            round(avg(mms_confidence) * 100, 1) as avg_mms_conf,
            round(avg(vox_confidence) * 100, 1) as avg_vox_conf
        FROM val
    """).fetchone()

    total_segs, total_vids, total_hrs, avg_dur, consensus_pct, avg_mms, avg_vox = row
    print(f"\n  Total segments:     {total_segs:>12,}")
    print(f"  Total videos:       {total_vids:>12,}")
    print(f"  Total hours:        {total_hrs:>12,}")
    print(f"  Avg duration:       {avg_dur:>12} s")
    print(f"  LID consensus:      {consensus_pct:>12}%")
    print(f"  Avg MMS confidence: {avg_mms:>12}%")
    print(f"  Avg Vox confidence: {avg_vox:>12}%")


def _print_language_distribution(con) -> None:
    """Section 2: per-language rollup of segments, videos, hours, CTC."""
    print(f"\n{'─' * 70}")
    print("  LANGUAGE DISTRIBUTION (by consensus_lang)")
    print(f"{'─' * 70}")

    lang_rows = con.execute("""
        SELECT
            consensus_lang as lang,
            count(*) as segments,
            count(DISTINCT video_id) as videos,
            round(sum(duration_s) / 3600, 1) as hours,
            round(sum(CASE WHEN lid_consensus THEN 1 ELSE 0 END) * 100.0 / count(*), 1) as consensus_pct,
            round(avg(CASE WHEN conformer_multi_ctc_normalized IS NOT NULL
                       THEN conformer_multi_ctc_normalized END), 3) as avg_ctc
        FROM val
        GROUP BY consensus_lang
        ORDER BY segments DESC
    """).fetchall()

    print(f"\n  {'Lang':<6} {'Segments':>10} {'Videos':>8} {'Hours':>8} {'Consensus%':>11} {'Avg CTC':>8}")
    print(f"  {'─'*6} {'─'*10} {'─'*8} {'─'*8} {'─'*11} {'─'*8}")
    for lang, segs, vids, hrs, cpct, actc in lang_rows:
        # consensus_lang may be NULL or ''; avg_ctc is NULL when no segment
        # in the group has a CTC score.
        lang_display = lang if lang else "(empty)"
        actc_display = f"{actc:.3f}" if actc is not None else "N/A"
        print(f"  {lang_display:<6} {segs:>10,} {vids:>8,} {hrs:>8,} {cpct:>11} {actc_display:>8}")


def _print_lid_agreement(con) -> None:
    """Section 3: how many of the three LID models agree per segment."""
    print(f"\n{'─' * 70}")
    print("  LID AGREEMENT (Gemini + MMS + VoxLingua)")
    print(f"{'─' * 70}")

    agree_rows = con.execute("""
        SELECT
            lid_agree_count,
            count(*) as cnt,
            round(count(*) * 100.0 / (SELECT count(*) FROM val), 1) as pct,
            round(sum(duration_s) / 3600, 1) as hours
        FROM val
        GROUP BY lid_agree_count
        ORDER BY lid_agree_count DESC
    """).fetchall()

    print(f"\n  {'Agree':>6} {'Segments':>12} {'%':>8} {'Hours':>8}")
    print(f"  {'─'*6} {'─'*12} {'─'*8} {'─'*8}")
    for agree, cnt, pct, hrs in agree_rows:
        label = f"{agree}/3"
        print(f"  {label:>6} {cnt:>12,} {pct:>8} {hrs:>8,}")


def _print_ctc_distribution(con) -> None:
    """Section 4: histogram of conformer CTC normalized scores."""
    print(f"\n{'─' * 70}")
    print("  CONFORMER CTC NORMALIZED DISTRIBUTION")
    print(f"{'─' * 70}")

    # Bucket labels are chosen so alphabetical ORDER BY is also numeric order,
    # with the NULL bucket sorting last.
    ctc_rows = con.execute("""
        SELECT
            CASE
                WHEN conformer_multi_ctc_normalized IS NULL THEN 'NULL (no model)'
                WHEN conformer_multi_ctc_normalized >= 0.9 THEN '0.9-1.0 (excellent)'
                WHEN conformer_multi_ctc_normalized >= 0.7 THEN '0.7-0.9 (good)'
                WHEN conformer_multi_ctc_normalized >= 0.5 THEN '0.5-0.7 (fair)'
                WHEN conformer_multi_ctc_normalized >= 0.3 THEN '0.3-0.5 (poor)'
                ELSE '0.0-0.3 (bad)'
            END as bucket,
            count(*) as cnt,
            round(count(*) * 100.0 / (SELECT count(*) FROM val), 1) as pct
        FROM val
        GROUP BY bucket
        ORDER BY bucket
    """).fetchall()

    print(f"\n  {'Bucket':<25} {'Segments':>12} {'%':>8}")
    print(f"  {'─'*25} {'─'*12} {'─'*8}")
    for bkt, cnt, pct in ctc_rows:
        print(f"  {bkt:<25} {cnt:>12,} {pct:>8}")


def _print_quality_distribution(con) -> None:
    """Section 5: histogram of Gemini quality scores (0.0 = unscored)."""
    print(f"\n{'─' * 70}")
    print("  GEMINI QUALITY SCORE DISTRIBUTION")
    print(f"{'─' * 70}")

    qual_rows = con.execute("""
        SELECT
            CASE
                WHEN gemini_quality_score >= 0.9 THEN '0.9-1.0'
                WHEN gemini_quality_score >= 0.7 THEN '0.7-0.9'
                WHEN gemini_quality_score >= 0.5 THEN '0.5-0.7'
                WHEN gemini_quality_score >= 0.3 THEN '0.3-0.5'
                WHEN gemini_quality_score > 0.0 THEN '0.0-0.3'
                ELSE '0.0 (unscored)'
            END as bucket,
            count(*) as cnt,
            round(count(*) * 100.0 / (SELECT count(*) FROM val), 1) as pct
        FROM val
        GROUP BY bucket
        ORDER BY bucket
    """).fetchall()

    print(f"\n  {'Bucket':<20} {'Segments':>12} {'%':>8}")
    print(f"  {'─'*20} {'─'*12} {'─'*8}")
    for bkt, cnt, pct in qual_rows:
        print(f"  {bkt:<20} {cnt:>12,} {pct:>8}")


def _materialize_buckets(con) -> None:
    """Section 6: create table `bucketed` = val + golden/redo/dispose label.

    Dispose rules are checked first, then golden; everything else is redo.
    Thresholds come from the module-level constants.
    """
    print(f"\n{'=' * 70}")
    print("  BUCKETING: Golden / Redo / Dispose")
    print(f"{'=' * 70}")

    # Materialize bucket column once so downstream queries stay cheap.
    con.execute(f"""
        CREATE TABLE bucketed AS
        SELECT *,
            CASE
                -- DISPOSE: clearly bad
                WHEN lid_consensus = false AND lid_agree_count < 2 THEN 'dispose'
                WHEN conformer_multi_ctc_normalized IS NOT NULL
                     AND conformer_multi_ctc_normalized < {DISPOSE_CTC_MAX} THEN 'dispose'
                WHEN duration_s < {DISPOSE_DURATION_MAX} THEN 'dispose'
                -- GOLDEN: high confidence across all signals
                WHEN lid_agree_count >= {GOLDEN_LID_AGREE}
                     AND (conformer_multi_ctc_normalized >= {GOLDEN_CTC_MIN}
                          OR conformer_multi_ctc_normalized IS NULL)
                     AND (gemini_quality_score >= {GOLDEN_QUALITY_MIN}
                          OR gemini_quality_score = 0)
                     AND duration_s >= {GOLDEN_DURATION_MIN} THEN 'golden'
                -- REDO: everything else
                ELSE 'redo'
            END as bucket
        FROM val
    """)

    bucket_rows = con.execute("""
        SELECT
            bucket,
            count(*) as segments,
            count(DISTINCT video_id) as videos,
            round(sum(duration_s) / 3600, 1) as hours,
            round(count(*) * 100.0 / (SELECT count(*) FROM bucketed), 1) as pct
        FROM bucketed
        GROUP BY bucket
        ORDER BY segments DESC
    """).fetchall()

    print(f"\n  {'Bucket':<10} {'Segments':>12} {'Videos':>10} {'Hours':>10} {'%':>8}")
    print(f"  {'─'*10} {'─'*12} {'─'*10} {'─'*10} {'─'*8}")
    for bkt, segs, vids, hrs, pct in bucket_rows:
        print(f"  {bkt:<10} {segs:>12,} {vids:>10,} {hrs:>10,} {pct:>8}")


def _print_bucket_languages(con) -> None:
    """Section 7: language rollup inside each bucket."""
    print(f"\n{'─' * 70}")
    print("  PER-BUCKET LANGUAGE BREAKDOWN")
    print(f"{'─' * 70}")

    for bucket_name in _BUCKET_NAMES:
        # Bucket name is bound as a parameter rather than interpolated.
        bl_rows = con.execute(
            """
            SELECT
                consensus_lang as lang,
                count(*) as segments,
                round(sum(duration_s) / 3600, 1) as hours
            FROM bucketed
            WHERE bucket = ?
            GROUP BY consensus_lang
            ORDER BY segments DESC
            """,
            [bucket_name],
        ).fetchall()

        print(f"\n  [{bucket_name.upper()}]")
        print(f"  {'Lang':<6} {'Segments':>10} {'Hours':>8}")
        print(f"  {'─'*6} {'─'*10} {'─'*8}")
        for lang, segs, hrs in bl_rows:
            lang_display = lang if lang else "(empty)"
            print(f"  {lang_display:<6} {segs:>10,} {hrs:>8,}")


def _print_dispose_reasons(con) -> None:
    """Section 8: attribute each disposed segment to the rule that caught it.

    The CASE order mirrors the bucketing CASE in _materialize_buckets so a
    segment is credited to the first rule that actually disposed it.  The
    reason labels hard-code the current threshold values — keep them in sync
    with DISPOSE_CTC_MAX / DISPOSE_DURATION_MAX.
    """
    print(f"\n{'─' * 70}")
    print("  DISPOSE REASONS")
    print(f"{'─' * 70}")

    dispose_rows = con.execute(f"""
        SELECT
            CASE
                WHEN lid_consensus = false AND lid_agree_count < 2 THEN 'LID disagreement (0-1/3)'
                WHEN conformer_multi_ctc_normalized IS NOT NULL
                     AND conformer_multi_ctc_normalized < {DISPOSE_CTC_MAX} THEN 'Low CTC score (<0.3)'
                WHEN duration_s < {DISPOSE_DURATION_MAX} THEN 'Too short (<1s)'
                ELSE 'Other'
            END as reason,
            count(*) as cnt,
            round(count(*) * 100.0 / (SELECT count(*) FROM bucketed WHERE bucket = 'dispose'), 1) as pct
        FROM bucketed
        WHERE bucket = 'dispose'
        GROUP BY reason
        ORDER BY cnt DESC
    """).fetchall()

    print(f"\n  {'Reason':<30} {'Count':>10} {'%':>8}")
    print(f"  {'─'*30} {'─'*10} {'─'*8}")
    for reason, cnt, pct in dispose_rows:
        print(f"  {reason:<30} {cnt:>10,} {pct:>8}")


def _export_csvs(con, out: Path) -> None:
    """Section 9: write one CSV per bucket plus the per-video rollup into *out*."""
    print(f"\n{'=' * 70}")
    print("  EXPORTING CSVs")
    print(f"{'=' * 70}")

    export_cols = """
        video_id, segment_file, consensus_lang, duration_s,
        lid_consensus, lid_agree_count,
        mms_lang_iso1, mms_confidence,
        vox_lang_iso1, vox_confidence,
        conformer_multi_ctc_normalized,
        wav2vec_ctc_normalized,
        gemini_quality_score, gemini_lang
    """

    for bucket_name in _BUCKET_NAMES:
        csv_path = out / f"{bucket_name}_segments.csv"
        # COPY targets cannot be bound parameters, so quote-escape the path.
        con.execute(f"""
            COPY (
                SELECT {export_cols}
                FROM bucketed
                WHERE bucket = '{bucket_name}'
                ORDER BY video_id, segment_file
            ) TO {_sql_str(csv_path)} (HEADER, DELIMITER ',')
        """)
        row_count = con.execute(
            "SELECT count(*) FROM bucketed WHERE bucket = ?", [bucket_name]
        ).fetchone()[0]
        size_mb = csv_path.stat().st_size / 1e6
        print(f"  {csv_path}: {row_count:,} rows, {size_mb:.1f} MB")

    # Per-video summary: dominant language, bucket counts, mean scores.
    video_csv = out / "video_summary.csv"
    con.execute(f"""
        COPY (
            SELECT
                video_id,
                mode(consensus_lang) as dominant_lang,
                count(*) as total_segments,
                round(sum(duration_s), 1) as total_duration_s,
                sum(CASE WHEN bucket = 'golden' THEN 1 ELSE 0 END) as golden_segs,
                sum(CASE WHEN bucket = 'redo' THEN 1 ELSE 0 END) as redo_segs,
                sum(CASE WHEN bucket = 'dispose' THEN 1 ELSE 0 END) as dispose_segs,
                round(avg(CASE WHEN conformer_multi_ctc_normalized IS NOT NULL
                          THEN conformer_multi_ctc_normalized END), 3) as avg_ctc,
                round(avg(gemini_quality_score), 3) as avg_quality,
                round(sum(CASE WHEN lid_consensus THEN 1 ELSE 0 END) * 100.0 / count(*), 1) as consensus_pct
            FROM bucketed
            GROUP BY video_id
            ORDER BY golden_segs DESC
        ) TO {_sql_str(video_csv)} (HEADER, DELIMITER ',')
    """)
    vid_count = con.execute("SELECT count(DISTINCT video_id) FROM bucketed").fetchone()[0]
    size_mb = video_csv.stat().st_size / 1e6
    print(f"  {video_csv}: {vid_count:,} videos, {size_mb:.1f} MB")


def main():
    """Run the full validation analysis.

    Parses CLI args, loads all parquet shards into a DuckDB view, prints
    the report sections to stdout, and exports the bucket + video CSVs
    into the output directory.
    """
    parser = argparse.ArgumentParser(description="Validation metrics + Golden/Redo/Dispose bucketing")
    parser.add_argument("--shards", type=str, default="data/validation_shards")
    parser.add_argument("--output", type=str, default="data")
    args = parser.parse_args()

    shard_glob = f"{args.shards}/**/*.parquet"
    out = Path(args.output)
    out.mkdir(parents=True, exist_ok=True)

    con = duckdb.connect()
    try:
        con.execute("SET threads = 8")
        con.execute("SET memory_limit = '16GB'")

        t0 = time.time()
        print(f"Loading shards from {shard_glob} ...")
        # Path is quote-escaped: a shard dir containing an apostrophe must
        # not break the statement.
        con.execute(f"""
            CREATE VIEW val AS
            SELECT * FROM read_parquet({_sql_str(shard_glob)}, hive_partitioning=false)
        """)

        _print_overall_summary(con)
        _print_language_distribution(con)
        _print_lid_agreement(con)
        _print_ctc_distribution(con)
        _print_quality_distribution(con)
        _materialize_buckets(con)
        _print_bucket_languages(con)
        _print_dispose_reasons(con)
        _export_csvs(con, out)

        elapsed = time.time() - t0
        print(f"\n{'=' * 70}")
        print(f"  Analysis complete in {elapsed:.0f}s")
        print(f"{'=' * 70}")
    finally:
        # Close the connection even if a query fails partway through.
        con.close()

# Script entry point: run the analysis only when executed directly,
# not when imported as a module.
if __name__ == "__main__":
    main()
