#!/usr/bin/env python3
"""Phase 3: Validate the tar-offset loader on random samples.

Reports the decode success rate, checks the 16 kHz mono invariant, and
measures decode throughput.

Usage:
  python3 tools/phase3_validate_loader.py --samples 10000 --workers 8
"""

import argparse
import json
import os
import sys
import time
from concurrent.futures import ProcessPoolExecutor, as_completed
from pathlib import Path

import numpy as np
import pyarrow.parquet as pq

# Limit OpenMP to a single thread (presumably to avoid oversubscribing cores
# when many worker processes run at once — confirm intent).
# NOTE(review): this assignment runs after `numpy` is imported above; any
# library that reads OMP_NUM_THREADS only when it is loaded may ignore it.
os.environ["OMP_NUM_THREADS"] = "1"

# Directory holding the default manifest and where the validation report is written.
PHASE3_DIR = Path("/workspace/maya-asr/artifacts/phase3")


def validate_sample(args_tuple):
    """Validate one sample: read its bytes at a tar offset, then check sr/mono.

    Args:
        args_tuple: ``(tar_path, offset, nbytes, expected_dur, language)`` as
            built by ``main``. ``expected_dur`` is accepted for interface
            compatibility but is not currently compared with the decoded
            duration.

    Returns:
        dict: Always has ``"ok"`` and ``"language"``. On success it also has
        ``"duration"`` (decoded seconds) and ``"decode_bytes"``; on failure it
        has an ``"error"`` string truncated to 100 characters.
    """
    tar_path, offset, nbytes, _expected_dur, language = args_tuple
    import io
    import soundfile

    try:
        nbytes = int(nbytes)
        offset = int(offset)
        fd = os.open(tar_path, os.O_RDONLY)
        try:
            raw = os.pread(fd, nbytes, offset)
        finally:
            # Release the descriptor even when pread raises (e.g. bad offset);
            # otherwise long-running workers leak one fd per failing sample.
            os.close(fd)

        if len(raw) != nbytes:
            return {"ok": False, "error": "short_read", "language": language}

        data, sr = soundfile.read(io.BytesIO(raw))
        if sr != 16000:
            return {"ok": False, "error": f"sr={sr}", "language": language}
        # A multi-dimensional array means more than one channel.
        if data.ndim > 1:
            return {"ok": False, "error": "not_mono", "language": language}

        return {
            "ok": True,
            "language": language,
            "duration": len(data) / sr,
            "decode_bytes": len(raw),
        }
    except Exception as e:
        # Deliberate best-effort: record the failure per sample instead of
        # aborting the whole run. Prefix the exception type so errors with
        # empty or generic messages stay distinguishable in the histogram.
        return {
            "ok": False,
            "error": f"{type(e).__name__}: {e}"[:100],
            "language": language,
        }


def main():
    """Sample manifest rows, decode them in parallel, and write a report.

    Reads ``tar_path``, ``tar_offset_data``, ``tar_nbytes``, ``duration_s`` and
    ``language`` from the manifest. It validates a reproducible random subset
    (``random_state=42``) with ``validate_sample`` in a process pool and writes
    ``loader_validation.json`` under ``PHASE3_DIR``. It exits with status 0
    when at least one sample was tested and the success fraction is at least
    ``--min-success-rate`` (default 0.9999), otherwise with status 1.
    """
    from collections import Counter

    parser = argparse.ArgumentParser(description="Validate tar-offset loader")
    parser.add_argument("--samples", type=int, default=10000)
    parser.add_argument("--workers", type=int, default=16)
    parser.add_argument("--manifest", type=str, default=str(PHASE3_DIR / "train_manifest.parquet"))
    parser.add_argument(
        "--min-success-rate",
        type=float,
        default=0.9999,
        help="Fraction of samples that must decode for the run to pass.",
    )
    args = parser.parse_args()

    print(f"Loading manifest: {args.manifest}")
    df = pq.read_table(
        args.manifest,
        columns=["tar_path", "tar_offset_data", "tar_nbytes", "duration_s", "language"],
    ).to_pandas()

    print(f"Total rows: {len(df):,}")

    # Reproducible random subset (fixed seed) of at most --samples rows.
    n = min(args.samples, len(df))
    sample = df.sample(n, random_state=42)
    print(f"Validating {n:,} random samples with {args.workers} workers...")

    # Column-wise zip avoids building a pandas Series per row as iterrows() does.
    tasks = [
        (str(path), int(offset), int(nbytes), float(dur), str(lang))
        for path, offset, nbytes, dur, lang in zip(
            sample["tar_path"],
            sample["tar_offset_data"],
            sample["tar_nbytes"],
            sample["duration_s"],
            sample["language"],
        )
    ]

    t0 = time.perf_counter()
    with ProcessPoolExecutor(max_workers=args.workers) as executor:
        # chunksize batches tasks per IPC round-trip; map() keeps input order.
        results = list(executor.map(validate_sample, tasks, chunksize=100))
    elapsed = time.perf_counter() - t0

    # Aggregate totals over successful decodes.
    succeeded = [r for r in results if r["ok"]]
    ok = len(succeeded)
    fail = len(results) - ok
    total_audio = sum(r["duration"] for r in succeeded)
    total_bytes = sum(r["decode_bytes"] for r in succeeded)
    throughput = total_audio / max(elapsed, 0.01)

    # Per-language ok/fail counts, in first-seen order.
    lang_stats = {}
    for r in results:
        counts = lang_stats.setdefault(r["language"], {"ok": 0, "fail": 0})
        counts["ok" if r["ok"] else "fail"] += 1

    errors = Counter(r["error"] for r in results if not r["ok"])

    # A single verdict drives both the printed result and the exit status.
    # n == 0 (empty manifest) is treated as a failure, not a vacuous pass.
    success_rate = ok / n if n else 0.0
    passed = n > 0 and success_rate >= args.min_success_rate

    report = {
        "samples_tested": n,
        "success": ok,
        "failed": fail,
        "success_rate": round(success_rate * 100, 4),
        "elapsed_s": round(elapsed, 1),
        "audio_hours_decoded": round(total_audio / 3600, 2),
        "bytes_decoded": total_bytes,
        "throughput_audio_sec_per_sec": round(throughput, 0),
        "per_language": lang_stats,
        "errors": dict(errors),
        "passed": passed,
    }

    out_path = PHASE3_DIR / "loader_validation.json"
    out_path.parent.mkdir(parents=True, exist_ok=True)
    with open(out_path, "w") as f:
        json.dump(report, f, indent=2)

    rule = "=" * 50
    print(f"\n{rule}")
    print("Loader Validation Report")
    print(rule)
    print(f"  Samples:    {n:,}")
    print(f"  Success:    {ok:,} ({success_rate * 100:.2f}%)")
    print(f"  Failed:     {fail:,}")
    print(f"  Elapsed:    {elapsed:.1f}s")
    print(f"  Throughput: {throughput:.0f}x realtime")
    if errors:
        print("\n  Errors:")
        for err, cnt in errors.most_common():
            print(f"    {err}: {cnt}")
    print("\n  Per-language:")
    for lang in sorted(lang_stats):
        s = lang_stats[lang]
        print(f"    {lang}: {s['ok']}/{s['ok'] + s['fail']} ok")
    if passed and fail:
        verdict = f"PASS ({fail:,} failures within {args.min_success_rate:.4%} threshold)"
    else:
        verdict = "PASS" if passed else "FAIL"
    print(f"\n  Result: {verdict}")
    print(f"  Report: {out_path}")

    sys.exit(0 if passed else 1)


if __name__ == "__main__":
    main()
