#!/usr/bin/env python3
"""Phase 2: Validate conversion results across all shards.

Checks:
- All queued shards have status files
- Input count == output count + errors (no silent loss)
- All outputs are 16kHz/mono
- Random decode spot-checks

Usage:
  python3 tools/phase2_validate_conversion.py
"""

import io
import json
import random
import sys
import tarfile
import time
from pathlib import Path

import pyarrow.parquet as pq
import soundfile

ARTIFACTS_DIR = Path("/workspace/maya-asr/artifacts/phase2")
QUEUE_PATH = ARTIFACTS_DIR / "conversion_queue.parquet"


def _spot_check_decodes(shard_dirs):
    """Decode one random audio member from up to 10 converted shards.

    For each sampled shard, opens ``audio.tar``, picks one regular-file
    member at random, decodes it, and verifies it is 16 kHz and mono.

    Args:
        shard_dirs: list of Path objects for successfully converted shards.

    Returns:
        (spot_issues, check_count): a list of human-readable issue strings
        and the number of shards actually checked.
    """
    spot_issues = []
    check_count = min(10, len(shard_dirs))
    for shard_dir in random.sample(shard_dirs, check_count):
        tar_path = shard_dir / "audio.tar"
        try:
            with tarfile.open(tar_path, "r") as tf:
                members = [m for m in tf.getmembers() if m.isfile()]
                if not members:
                    continue
                member = random.choice(members)
                extracted = tf.extractfile(member)
                if extracted is None:
                    # extractfile() returns None for non-regular members;
                    # isfile() should rule this out, but report rather than
                    # crash with an opaque AttributeError if it happens.
                    spot_issues.append(f"{shard_dir}: {member.name} not extractable")
                    continue
                data, sr = soundfile.read(io.BytesIO(extracted.read()))
                if sr != 16000:
                    spot_issues.append(f"{shard_dir}: {member.name} sr={sr}")
                if len(data.shape) > 1:
                    spot_issues.append(f"{shard_dir}: {member.name} not mono")
        except Exception as e:
            # An unreadable tar/audio file is a validation finding, not a
            # reason to abort the whole validation run.
            spot_issues.append(f"{shard_dir}: read error: {e}")
    return spot_issues, check_count


def main():
    """Validate conversion results across all queued shards.

    Reads the conversion queue, checks every shard's status file,
    aggregates counts, runs random decode spot-checks, and writes
    ``validation_report.json``.

    Exits with status 1 if the queue is missing, any shard failed,
    any per-shard or aggregate sample counts are inconsistent, or a
    spot-check decode found a problem.
    """
    if not QUEUE_PATH.exists():
        print("ERROR: No conversion queue", file=sys.stderr)
        sys.exit(1)

    queue_df = pq.read_table(QUEUE_PATH).to_pandas()
    total = len(queue_df)
    print(f"Validating {total} shards...")

    t0 = time.time()
    success = 0
    failed = 0
    skipped = 0
    mismatched = 0  # shards whose input/output/error counts don't add up
    total_input = 0
    total_output = 0
    total_errors = 0
    total_hours = 0.0
    issues = []
    # Shard dirs with status "success", collected here so the spot-check
    # phase doesn't have to re-read and re-parse every status JSON.
    converted_shards = []

    for _, row in queue_df.iterrows():
        shard_dir = Path(row["shard_dir"])
        status_file = shard_dir / "shard_conversion_status.json"

        if not status_file.exists():
            issues.append({"shard": str(shard_dir), "issue": "no status file"})
            failed += 1
            continue

        status = json.loads(status_file.read_text())

        if status["status"] == "skipped":
            skipped += 1
            continue
        elif status["status"] == "error":
            issues.append({"shard": str(shard_dir), "issue": status.get("error", "unknown")})
            failed += 1
            continue
        elif status["status"] != "success":
            issues.append({"shard": str(shard_dir), "issue": f"unexpected status: {status['status']}"})
            failed += 1
            continue

        success += 1
        converted_shards.append(shard_dir)
        total_input += status.get("input_count", 0)
        total_output += status.get("output_count", 0)
        total_errors += status.get("error_count", 0)
        total_hours += status.get("output_hours", 0)

        # Verify no silent loss per shard. The aggregate silent_loss check
        # below can mask mismatches that cancel out across shards (+k in
        # one shard, -k in another), so count them here too.
        expected = status.get("input_count", 0) - status.get("error_count", 0)
        if status.get("output_count", 0) != expected:
            mismatched += 1
            issues.append(
                {
                    "shard": str(shard_dir),
                    "issue": f"count mismatch: input={status['input_count']} "
                    f"errors={status['error_count']} output={status['output_count']}",
                }
            )

    # Random spot-checks: decode from 10 random converted shards
    print("\nRunning spot-check decodes...")
    spot_issues, check_count = _spot_check_decodes(converted_shards)

    elapsed = time.time() - t0

    # Write report
    report = {
        "total_shards": total,
        "converted_success": success,
        "converted_failed": failed,
        "skipped": skipped,
        "total_input_samples": total_input,
        "total_output_samples": total_output,
        "total_conversion_errors": total_errors,
        "silent_loss": total_input - total_output - total_errors,
        "total_hours": round(total_hours, 1),
        "spot_check_count": check_count,
        "spot_check_issues": spot_issues,
        "issues": issues[:50],  # cap so the report stays readable
        "elapsed_s": round(elapsed, 1),
    }

    report_path = ARTIFACTS_DIR / "validation_report.json"
    with open(report_path, "w") as f:
        json.dump(report, f, indent=2)

    # Print
    print(f"\n{'='*50}")
    print("Validation Report")
    print(f"{'='*50}")
    print(f"  Total shards:     {total}")
    print(f"  Success:          {success}")
    print(f"  Failed:           {failed}")
    print(f"  Skipped:          {skipped}")
    print(f"  Input samples:    {total_input:,}")
    print(f"  Output samples:   {total_output:,}")
    print(f"  Conv errors:      {total_errors}")
    print(f"  Silent loss:      {report['silent_loss']}")
    print(f"  Total hours:      {total_hours:.1f}h")
    print(f"  Spot-check issues: {len(spot_issues)}")

    if issues:
        print(f"\nIssues ({len(issues)}):")
        for iss in issues[:10]:
            print(f"  {iss['shard']}: {iss['issue']}")

    all_ok = (
        failed == 0
        and mismatched == 0
        and report["silent_loss"] == 0
        and len(spot_issues) == 0
    )
    print(f"\nResult: {'PASS' if all_ok else 'FAIL'}")

    if not all_ok:
        sys.exit(1)


# Run validation only when invoked as a script, not on import.
if __name__ == "__main__":
    main()
