#!/usr/bin/env python3
"""Validate the R2 checkpoint upload pipeline end-to-end.

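Creates a small random test file and uploads it to R2. It then checks the remote
object size, downloads the object back, and verifies that its SHA-256 digest
matches. Finally, it deletes the remote test object. Run this before relying
on the pipeline for production checkpoint uploads.

With --dry-run, only the credentials and bucket access are checked.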

Usage:
  python3 scripts/validate_r2_pipeline.py
  python3 scripts/validate_r2_pipeline.py --dry-run
"""

import argparse
import hashlib
import os
import sys
import tempfile
from pathlib import Path

import boto3
from dotenv import load_dotenv

# Load R2_* settings from the .env file two directories above this script
# (presumably the repository root) when such a file exists; otherwise the
# process environment is used unchanged.
_SCRIPT_PATH = Path(__file__).resolve()
ENV_FILE = _SCRIPT_PATH.parent.parent / ".env"
if ENV_FILE.exists():
    load_dotenv(ENV_FILE)


def main():
    """Validate the R2 checkpoint pipeline; exit with status 1 on any failure.

    Reads ``R2_ENDPOINT_URL``, ``R2_ACCESS_KEY_ID`` and ``R2_SECRET_ACCESS_KEY``
    (all required) plus ``R2_BUCKET_CHECKPOINTS`` (default ``"ptcheckpoints"``)
    from the environment. It then checks that the bucket is reachable with
    ``head_bucket``.

    Unless ``--dry-run`` is given, it also performs a full round trip with a
    100 KB random object:

    1. Upload the object.
    2. Compare the remote ``ContentLength`` with the local size.
    3. Download the object and compare SHA-256 digests.
    4. Delete the remote object.

    Local temp files are always removed. On failure, a best-effort delete of
    the remote test object is attempted.

    Integrity checks raise explicit exceptions instead of using ``assert``,
    so they still run under ``python -O``.
    """
    parser = argparse.ArgumentParser(description="Validate R2 checkpoint pipeline")
    parser.add_argument("--dry-run", action="store_true")
    args = parser.parse_args()

    endpoint = os.environ.get("R2_ENDPOINT_URL")
    key_id = os.environ.get("R2_ACCESS_KEY_ID")
    secret = os.environ.get("R2_SECRET_ACCESS_KEY")
    bucket = os.environ.get("R2_BUCKET_CHECKPOINTS", "ptcheckpoints")

    print("============================================")
    print("  R2 Checkpoint Pipeline Validation")
    print("============================================")
    print()

    # Check credentials; name the missing variables so the fix is obvious.
    required = {
        "R2_ENDPOINT_URL": endpoint,
        "R2_ACCESS_KEY_ID": key_id,
        "R2_SECRET_ACCESS_KEY": secret,
    }
    missing = [name for name, value in required.items() if not value]
    if missing:
        print(f"FAIL: R2 credentials not set in .env (missing: {', '.join(missing)})")
        sys.exit(1)
    print("[PASS] R2 credentials present")

    s3 = boto3.client(
        "s3",
        endpoint_url=endpoint,
        aws_access_key_id=key_id,
        aws_secret_access_key=secret,
        region_name="auto",
    )

    # Test connectivity
    try:
        s3.head_bucket(Bucket=bucket)
        print(f"[PASS] Bucket '{bucket}' accessible")
    except Exception as e:
        print(f"FAIL: Cannot access bucket '{bucket}': {e}")
        sys.exit(1)

    if args.dry_run:
        print("\n[DRY RUN] Skipping upload/download test.")
        print("All connectivity checks passed.")
        return

    test_key = "_maya_asr_pipeline_test/validation_test.bin"
    test_content = os.urandom(1024 * 100)  # 100 KB
    local_hash = hashlib.sha256(test_content).hexdigest()

    # Track temp paths so the finally-clause can remove them on every exit
    # path (success, failed check, or exception mid-way).
    local_path = None
    download_path = None
    try:
        # Create test file (inside the try so a failed write is cleaned up too).
        with tempfile.NamedTemporaryFile(delete=False, suffix=".bin") as f:
            local_path = f.name
            f.write(test_content)

        # Upload
        print(f"\nUploading test file ({len(test_content)} bytes)...")
        s3.upload_file(local_path, bucket, test_key)
        print("[PASS] Upload succeeded")

        # Validate remote size (explicit check: `assert` is stripped by -O).
        resp = s3.head_object(Bucket=bucket, Key=test_key)
        remote_size = resp["ContentLength"]
        if remote_size != len(test_content):
            raise RuntimeError(
                f"Size mismatch: local={len(test_content)} remote={remote_size}"
            )
        print(f"[PASS] Remote size matches ({remote_size} bytes)")

        # Download and verify content
        with tempfile.NamedTemporaryFile(delete=False, suffix=".bin") as f:
            download_path = f.name
        s3.download_file(bucket, test_key, download_path)
        downloaded = Path(download_path).read_bytes()
        download_hash = hashlib.sha256(downloaded).hexdigest()
        if local_hash != download_hash:
            raise RuntimeError(
                f"Content mismatch: local={local_hash[:16]}... remote={download_hash[:16]}..."
            )
        print("[PASS] Downloaded content SHA-256 matches")

        # Cleanup remote
        s3.delete_object(Bucket=bucket, Key=test_key)
        print("[PASS] Remote test file cleaned up")

    except Exception as e:
        print(f"\nFAIL: {e}")
        # Best-effort remote cleanup; warn rather than hide a leftover object.
        try:
            s3.delete_object(Bucket=bucket, Key=test_key)
        except Exception as cleanup_err:
            print(
                f"WARNING: could not delete remote test object "
                f"'{test_key}': {cleanup_err}"
            )
        sys.exit(1)
    finally:
        for path in (local_path, download_path):
            if path is not None:
                Path(path).unlink(missing_ok=True)

    print()
    print("============================================")
    print("  R2 Pipeline Validation: ALL PASS")
    print("============================================")
    print()
    print(f"  Endpoint:  {endpoint}")
    print(f"  Bucket:    {bucket}")
    print("  Upload:    OK")
    print("  Download:  OK")
    print("  Integrity: OK (SHA-256 verified)")
    print()
    print("Ready for production checkpoint uploads.")


if __name__ == "__main__":
    # Script entry point; main() returns None on success, i.e. exit status 0.
    sys.exit(main())
