#!/usr/bin/env python3
"""Post-run validation for the SFT encoding pipeline.

Verifies that every DONE shard in the DB has a matching xcodec2_tokens.parquet
in R2, that the segment counts match, and reports totals per dataset.

Usage:
    python scripts/validate_sft_output.py                   # full validation
    python scripts/validate_sft_output.py --dataset final-export  # one dataset
    python scripts/validate_sft_output.py --sample 100      # spot-check 100 shards
    python scripts/validate_sft_output.py --deep 10         # download & decode 10 parquets
"""
from __future__ import annotations

import argparse
import io
import os
import random
import sys
import time
from collections import defaultdict
from concurrent.futures import ThreadPoolExecutor, as_completed
from pathlib import Path

import boto3
import numpy as np
import pyarrow.parquet as pq
from dotenv import load_dotenv

# Make the repo root importable so `codecbench` resolves when this script is
# executed directly (e.g. `python scripts/validate_sft_output.py`).
PROJECT_ROOT = Path(__file__).resolve().parent.parent
if str(PROJECT_ROOT) not in sys.path:
    sys.path.insert(0, str(PROJECT_ROOT))

from codecbench.pipeline.sft_supabase import SFTOrchestrator


def get_s3():
    """Build a boto3 S3 client pointed at the Cloudflare R2 endpoint.

    Endpoint and credentials are read from the R2_* environment variables
    (loaded from .env by main); raises KeyError if any is unset.
    """
    env = os.environ
    return boto3.client(
        "s3",
        endpoint_url=env["R2_ENDPOINT_URL"],
        aws_access_key_id=env["R2_ACCESS_KEY_ID"],
        aws_secret_access_key=env["R2_SECRET_ACCESS_KEY"],
        region_name="auto",
    )


def check_parquet_exists(s3, bucket: str, shard_key: str) -> dict:
    """HEAD the shard's tokens parquet in R2 and report whether it exists.

    Returns a dict keyed by shard_key with either size_mb (object found)
    or a truncated error string (missing object or any S3 failure).
    """
    key = f"{shard_key}xcodec2_tokens.parquet"
    try:
        head = s3.head_object(Bucket=bucket, Key=key)
    except Exception as e:  # best-effort: a 404 here just means "missing"
        return {"shard_key": shard_key, "exists": False, "error": str(e)[:200]}
    return {
        "shard_key": shard_key,
        "exists": True,
        "size_mb": head["ContentLength"] / 1e6,
    }


def deep_check_parquet(s3, bucket: str, shard_row: dict) -> dict:
    """Download one shard's parquet and verify segment count + token integrity.

    Checks three things: the parquet row count matches the DB's
    segments_encoded, how many token blobs are non-empty, and whether each
    blob decodes (as uint16) to exactly token_count values. Any failure
    (download, parse, or decode) yields {"ok": False, "error": ...}.
    """
    shard_key = shard_row["shard_key"]
    key = f"{shard_key}xcodec2_tokens.parquet"
    try:
        blob = s3.get_object(Bucket=bucket, Key=key)["Body"].read()
        table = pq.read_table(io.BytesIO(blob))
        segment_ids = table.column("segment_id").to_pylist()
        tokens = table.column("xcodec2_tokens").to_pylist()
        counts = table.column("token_count").to_pylist()

        n_rows = len(segment_ids)
        expected = shard_row.get("segments_encoded") or 0

        n_non_empty = 0
        n_valid = 0
        for raw, count in zip(tokens, counts):
            if len(raw) > 0:
                n_non_empty += 1
            # frombuffer raises on odd-length buffers; caught below.
            if np.frombuffer(raw, dtype=np.uint16).size == count:
                n_valid += 1

        return {
            "shard_key": shard_key,
            "ok": True,
            "parquet_segments": n_rows,
            "db_segments": expected,
            "segments_match": n_rows == expected,
            "non_empty_tokens": n_non_empty,
            "valid_token_counts": n_valid,
            "size_mb": len(blob) / 1e6,
        }
    except Exception as e:
        return {"shard_key": shard_key, "ok": False, "error": str(e)[:300]}


def main():
    """Run the full validation: DB summary, R2 existence check, optional deep check.

    Prints per-dataset totals from the DB, HEAD-checks every (or a sampled
    subset of) DONE shard's parquet in R2, and optionally downloads N parquets
    for deep verification. The final verdict fails on missing parquets AND on
    deep-check mismatches (previously deep failures were ignored by the verdict).
    """
    ap = argparse.ArgumentParser(description="Validate SFT encoding output")
    ap.add_argument("--dataset", type=str, default=None, help="Filter by dataset name")
    ap.add_argument("--sample", type=int, default=None,
                    help="Spot-check N random shards (HEAD only)")
    ap.add_argument("--deep", type=int, default=0,
                    help="Download and verify N parquets for segment count + token integrity")
    ap.add_argument("--workers", type=int, default=32, help="Parallel workers")
    args = ap.parse_args()

    load_dotenv(PROJECT_ROOT / ".env")
    bucket = os.environ.get("R2_BUCKET_DESTINATION", "finalsftdata")
    s3 = get_s3()
    orch = SFTOrchestrator()

    # Build the DONE-shard query, optionally restricted to one dataset.
    where = "WHERE status = 'DONE'"
    params: tuple = ()
    if args.dataset:
        where += " AND dataset = %s"
        params = (args.dataset,)

    shards = orch._query(
        f"SELECT shard_key, dataset, language, segments_encoded, total_audio_s, output_r2_key "
        f"FROM sft_encoding_shards {where} ORDER BY shard_key",
        params,
    )
    print(f"\n{'='*70}")
    print("  SFT OUTPUT VALIDATION")
    print(f"{'='*70}")
    print(f"  Total DONE shards in DB: {len(shards)}")
    if not shards:
        # Nothing to validate; don't print a misleading PASSED banner.
        print("  No DONE shards found — nothing to validate.")
        print(f"{'='*70}\n")
        return

    # DB-level summary per dataset.
    ds_summary = defaultdict(lambda: {"shards": 0, "segments": 0, "audio_h": 0.0})
    for s in shards:
        ds = s["dataset"]
        ds_summary[ds]["shards"] += 1
        ds_summary[ds]["segments"] += s.get("segments_encoded") or 0
        ds_summary[ds]["audio_h"] += (s.get("total_audio_s") or 0) / 3600

    print(f"\n  {'Dataset':<20} {'Shards':>8} {'Segments':>12} {'Audio (h)':>10}")
    print(f"  {'-'*55}")
    total_segs = 0
    total_audio = 0.0
    for ds in sorted(ds_summary):
        d = ds_summary[ds]
        total_segs += d["segments"]
        total_audio += d["audio_h"]
        print(f"  {ds:<20} {d['shards']:>8,} {d['segments']:>12,} {d['audio_h']:>10,.1f}")
    print(f"  {'-'*55}")
    print(f"  {'TOTAL':<20} {len(shards):>8,} {total_segs:>12,} {total_audio:>10,.1f}")

    # R2 existence check (HEAD only — cheap, parallelized).
    check_shards = shards
    if args.sample:
        check_shards = random.sample(shards, min(args.sample, len(shards)))
    print(f"\n  Checking R2 parquet existence for {len(check_shards)} shards...")

    exists_count = 0
    missing = []
    total_size_mb = 0.0
    with ThreadPoolExecutor(max_workers=args.workers) as pool:
        futures = {
            pool.submit(check_parquet_exists, s3, bucket, s["shard_key"]): s
            for s in check_shards
        }
        for i, fut in enumerate(as_completed(futures)):
            result = fut.result()
            if result["exists"]:
                exists_count += 1
                total_size_mb += result.get("size_mb", 0)
            else:
                missing.append(result)
            if (i + 1) % 500 == 0:
                print(f"    ... checked {i+1}/{len(check_shards)}")

    print(f"\n  R2 parquet check: {exists_count}/{len(check_shards)} exist "
          f"({total_size_mb:,.1f} MB total)")
    if missing:
        print(f"  MISSING parquets: {len(missing)}")
        for m in missing[:20]:
            print(f"    {m['shard_key']}: {m.get('error','?')[:80]}")

    # Deep verification (download + decode). Capped worker count because each
    # task downloads a full parquet.
    deep_mismatch: list = []
    if args.deep > 0:
        deep_shards = random.sample(shards, min(args.deep, len(shards)))
        print(f"\n  Deep-checking {len(deep_shards)} parquets (download + verify)...")
        deep_ok = 0
        with ThreadPoolExecutor(max_workers=min(args.workers, 8)) as pool:
            futures = {
                pool.submit(deep_check_parquet, s3, bucket, s): s
                for s in deep_shards
            }
            for fut in as_completed(futures):
                result = fut.result()
                if result.get("ok") and result.get("segments_match"):
                    deep_ok += 1
                else:
                    deep_mismatch.append(result)
        print(f"  Deep check: {deep_ok}/{len(deep_shards)} passed")
        if deep_mismatch:
            print(f"  Mismatches:")
            for m in deep_mismatch:
                print(f"    {m['shard_key']}: parquet_segs={m.get('parquet_segments')} "
                      f"db_segs={m.get('db_segments')} err={m.get('error','')[:80]}")

    # Final verdict. Fail on missing parquets OR deep-check mismatches
    # (the deep check previously did not affect the verdict).
    print(f"\n{'='*70}")
    all_ok = not missing and not deep_mismatch
    if all_ok:
        print("  VALIDATION PASSED")
        print(f"  {len(shards)} shards, {total_segs:,} segments, {total_audio:,.1f}h audio")
        print(f"  final-export segments: {ds_summary.get('final-export', {}).get('segments', 0):,}")
    else:
        problems = []
        if missing:
            problems.append(f"{len(missing)} missing parquets")
        if deep_mismatch:
            problems.append(f"{len(deep_mismatch)} deep-check failures")
        print(f"  VALIDATION FAILED — {', '.join(problems)}")
    print(f"{'='*70}\n")


# Script entry point: run the full validation when executed directly.
if __name__ == "__main__":
    main()
