#!/usr/bin/env python3
"""Upload a training checkpoint to Cloudflare R2.

Uploads .ckpt and .nemo files from an experiment checkpoint directory
to R2 with the naming convention:
  s3://{bucket}/{model_name}/step_{step:08d}/{filename}

Validates the upload by comparing local and remote file sizes.

Usage:
  # Upload a specific checkpoint dir
  python3 scripts/upload_checkpoint.py \
    --checkpoint-dir experiments/stage1_smoke/smoke_run/2026-03-23_05-19-22/checkpoints \
    --model-name maya-asr-stage1 \
    --step 50

  # Validate only (no upload)
  python3 scripts/upload_checkpoint.py \
    --checkpoint-dir experiments/stage1_smoke/smoke_run/2026-03-23_05-19-22/checkpoints \
    --model-name maya-asr-stage1 \
    --step 50 \
    --validate-only

  # Dry run
  python3 scripts/upload_checkpoint.py ... --dry-run
"""

import argparse
import hashlib
import json
import os
import sys
from datetime import datetime, timezone
from pathlib import Path

import boto3
from dotenv import load_dotenv

# Load R2 credentials from the repo-root .env file (one level above this
# script), if present, so the R2_* environment variables read below are set.
ENV_FILE = Path(__file__).resolve().parent.parent / ".env"
if ENV_FILE.exists():
    load_dotenv(ENV_FILE)


def get_s3_client():
    """Return a boto3 S3 client pointed at the Cloudflare R2 endpoint.

    Reads R2_ENDPOINT_URL, R2_ACCESS_KEY_ID, and R2_SECRET_ACCESS_KEY from
    the environment. Exits the process with status 1 if any is missing.
    """
    required = ("R2_ENDPOINT_URL", "R2_ACCESS_KEY_ID", "R2_SECRET_ACCESS_KEY")
    creds = {name: os.environ.get(name) for name in required}

    # Treat unset or empty values as missing credentials.
    if not all(creds.values()):
        print("ERROR: R2 credentials not set. Check .env file.", file=sys.stderr)
        print("  Need: R2_ENDPOINT_URL, R2_ACCESS_KEY_ID, R2_SECRET_ACCESS_KEY", file=sys.stderr)
        sys.exit(1)

    return boto3.client(
        "s3",
        endpoint_url=creds["R2_ENDPOINT_URL"],
        aws_access_key_id=creds["R2_ACCESS_KEY_ID"],
        aws_secret_access_key=creds["R2_SECRET_ACCESS_KEY"],
        region_name="auto",
    )


def upload_file(s3, bucket: str, local_path: Path, remote_key: str, dry_run: bool):
    """Upload one local file to R2; with dry_run, only report what would happen.

    Returns True (the caller's validation pass checks the result separately).
    """
    mb = local_path.stat().st_size / (1024 * 1024)
    action = "[DRY RUN] Would upload" if dry_run else "Uploading"
    print(f"  {action}: {local_path.name} ({mb:.1f} MB) -> {remote_key}")
    if not dry_run:
        s3.upload_file(str(local_path), bucket, remote_key)
    return True


def validate_upload(s3, bucket: str, local_path: Path, remote_key: str) -> bool:
    """Check that remote_key exists on R2 and its size matches the local file.

    Returns True when the remote object exists and the byte counts agree,
    False otherwise (with a FAIL line written to stderr).
    """
    expected = local_path.stat().st_size
    try:
        head = s3.head_object(Bucket=bucket, Key=remote_key)
    except s3.exceptions.ClientError:
        # head_object raises ClientError (404) when the key is absent.
        print(f"  FAIL: {remote_key} not found on R2", file=sys.stderr)
        return False

    actual = head["ContentLength"]
    if actual != expected:
        print(
            f"  FAIL: Size mismatch for {remote_key}: local={expected} remote={actual}",
            file=sys.stderr,
        )
        return False

    print(f"  OK: {remote_key} ({actual / (1024**2):.1f} MB)")
    return True


def main():
    """CLI entry point: locate checkpoint files, upload them to R2, validate
    remote sizes, and write an upload manifest next to the checkpoints.

    Exits 1 on missing credentials/dir/files or on validation failure.
    """
    parser = argparse.ArgumentParser(description="Upload checkpoint to R2")
    parser.add_argument("--checkpoint-dir", type=Path, required=True)
    parser.add_argument("--model-name", type=str, required=True)
    parser.add_argument("--step", type=int, required=True)
    parser.add_argument("--bucket", type=str, default=None)
    parser.add_argument("--dry-run", action="store_true", default=False)
    parser.add_argument("--validate-only", action="store_true", default=False)
    parser.add_argument(
        "--files",
        nargs="*",
        default=None,
        help="Specific filenames to upload (default: all .ckpt/.nemo in dir)",
    )
    args = parser.parse_args()

    bucket = args.bucket or os.environ.get("R2_BUCKET_CHECKPOINTS", "ptcheckpoints")
    # Remote layout: s3://{bucket}/{model_name}/step_{step:08d}/{filename}
    # (the old getattr(args, "prefix_format", ...) fallback was dead code:
    # no such CLI option exists, so it always hit the f-string default).
    prefix = f"{args.model_name}/step_{args.step:08d}"

    if not args.checkpoint_dir.exists():
        print(f"ERROR: Checkpoint dir not found: {args.checkpoint_dir}", file=sys.stderr)
        sys.exit(1)

    # Find checkpoint files — either the explicitly requested names or all
    # .ckpt/.nemo files in the directory.
    if args.files:
        requested = [args.checkpoint_dir / name for name in args.files]
        missing = [p for p in requested if not p.exists()]
        if missing:
            # A typo'd --files entry used to be skipped silently; fail loudly
            # so a partial upload is never mistaken for a complete one.
            for p in missing:
                print(f"ERROR: Requested file not found: {p}", file=sys.stderr)
            sys.exit(1)
        ckpt_files = sorted(requested)
    else:
        ckpt_files = sorted(
            p for p in args.checkpoint_dir.iterdir() if p.is_file() and p.suffix in (".ckpt", ".nemo")
        )

    if not ckpt_files:
        print(f"ERROR: No .ckpt or .nemo files in {args.checkpoint_dir}", file=sys.stderr)
        sys.exit(1)

    print(f"Checkpoint upload: {args.model_name} step {args.step}")
    print(f"  Source: {args.checkpoint_dir}")
    print(f"  Target: s3://{bucket}/{prefix}/")
    print(f"  Files:  {len(ckpt_files)}")
    total_mb = sum(f.stat().st_size for f in ckpt_files) / (1024**2)
    print(f"  Size:   {total_mb:.1f} MB")
    print()

    s3 = get_s3_client()

    def _validate_all() -> bool:
        # Compare local vs remote size for every checkpoint file.
        ok = True
        for ckpt in ckpt_files:
            if not validate_upload(s3, bucket, ckpt, f"{prefix}/{ckpt.name}"):
                ok = False
        return ok

    if args.validate_only:
        print("--- Validating existing uploads ---")
        all_ok = _validate_all()
        print()
        if all_ok:
            print("Validation: PASS")
        else:
            print("Validation: FAIL")
            sys.exit(1)
        return

    # Upload
    print("--- Uploading ---")
    for ckpt in ckpt_files:
        upload_file(s3, bucket, ckpt, f"{prefix}/{ckpt.name}", args.dry_run)

    if args.dry_run:
        print("\n[DRY RUN] No files uploaded.")
        return

    # Validate what was just uploaded.
    print("\n--- Validating uploads ---")
    all_ok = _validate_all()

    def _sha256(path: Path) -> str:
        # Stream in 1 MiB chunks so multi-GB checkpoints never load into RAM.
        h = hashlib.sha256()
        with open(path, "rb") as fh:
            for chunk in iter(lambda: fh.read(1 << 20), b""):
                h.update(chunk)
        return h.hexdigest()

    # Write an upload manifest (sizes + sha256 digests) so a later restore
    # can verify integrity; kept locally and mirrored to R2.
    metadata = {
        "model_name": args.model_name,
        "step": args.step,
        "bucket": bucket,
        "prefix": prefix,
        "files": [
            {
                "name": f.name,
                "size_bytes": f.stat().st_size,
                "sha256": _sha256(f),
                "remote_key": f"{prefix}/{f.name}",
            }
            for f in ckpt_files
        ],
        "total_size_mb": round(total_mb, 1),
        "uploaded_at": datetime.now(timezone.utc).isoformat(),
        "validated": all_ok,
    }
    meta_file = args.checkpoint_dir / "upload_metadata.json"
    with open(meta_file, "w") as f:
        json.dump(metadata, f, indent=2)

    # Upload metadata to R2 too (for restore)
    remote_meta_key = f"{prefix}/upload_metadata.json"
    s3.upload_file(str(meta_file), bucket, remote_meta_key)

    print()
    if all_ok:
        print(f"Upload + validation: PASS ({total_mb:.1f} MB)")
        print(f"Metadata: {meta_file}")
    else:
        print("Upload: DONE but validation FAILED")
        sys.exit(1)


# Script entry point: run the uploader when executed directly.
if __name__ == "__main__":
    main()
