#!/usr/bin/env python3
"""Restore a training checkpoint from Cloudflare R2.

Downloads .ckpt/.nemo files and verifies integrity via SHA-256.

Usage:
  # Restore latest step
  python3 scripts/restore_checkpoint.py \
    --model-name maya-asr-stage1 --target-dir /tmp/restored

  # Restore specific step
  python3 scripts/restore_checkpoint.py \
    --model-name maya-asr-stage1 --step 50 --target-dir /tmp/restored

  # List available steps
  python3 scripts/restore_checkpoint.py \
    --model-name maya-asr-stage1 --list
"""

import argparse
import hashlib
import json
import os
import re
import sys
from pathlib import Path

import boto3
from dotenv import load_dotenv

# Load R2 credentials from the repo-root .env file when present
# (this script lives in scripts/, so the root is one level up).
ENV_FILE = Path(__file__).resolve().parent.parent / ".env"
if ENV_FILE.exists():
    load_dotenv(ENV_FILE)


def get_s3_client():
    """Build a boto3 S3 client configured for Cloudflare R2.

    Reads R2_ENDPOINT_URL, R2_ACCESS_KEY_ID and R2_SECRET_ACCESS_KEY from
    the environment; exits with status 1 if any of them is missing/empty.
    """
    creds = {
        "endpoint_url": os.environ.get("R2_ENDPOINT_URL"),
        "aws_access_key_id": os.environ.get("R2_ACCESS_KEY_ID"),
        "aws_secret_access_key": os.environ.get("R2_SECRET_ACCESS_KEY"),
    }
    if not all(creds.values()):
        print("ERROR: R2 credentials not set.", file=sys.stderr)
        sys.exit(1)
    # R2 ignores the AWS region; boto3 still requires one, so use "auto".
    return boto3.client("s3", region_name="auto", **creds)


def list_steps(s3, bucket: str, model_name: str) -> list[int]:
    """Return the sorted step numbers stored under ``model_name/`` on R2."""
    step_pattern = re.compile(r"step_(\d+)")
    found: set[int] = set()
    pages = s3.get_paginator("list_objects_v2").paginate(
        Bucket=bucket, Prefix=f"{model_name}/", Delimiter="/"
    )
    for page in pages:
        # Each CommonPrefix looks like "model/step_00000050/".
        for entry in page.get("CommonPrefixes", []):
            match = step_pattern.search(entry["Prefix"])
            if match is not None:
                found.add(int(match.group(1)))
    return sorted(found)


def _sha256(path: Path) -> str:
    """Compute the hex SHA-256 digest of the file at *path*."""
    digest = hashlib.sha256()
    # Stream in 1 MiB chunks so large checkpoints don't load into memory.
    with path.open("rb") as fh:
        while chunk := fh.read(1 << 20):
            digest.update(chunk)
    return digest.hexdigest()


def main():
    """Restore a checkpoint from R2: download .ckpt/.nemo files and verify them.

    Exits non-zero when no checkpoints exist, required arguments are missing,
    or any downloaded file fails size/SHA-256 verification.
    """
    parser = argparse.ArgumentParser(description="Restore checkpoint from R2")
    parser.add_argument("--model-name", type=str, required=True)
    parser.add_argument("--step", type=int, default=None, help="Step to restore (default: latest)")
    parser.add_argument("--target-dir", type=Path, default=None)
    parser.add_argument("--bucket", type=str, default=None)
    parser.add_argument("--list", action="store_true", default=False, help="List available steps")
    args = parser.parse_args()

    bucket = args.bucket or os.environ.get("R2_BUCKET_CHECKPOINTS", "ptcheckpoints")
    s3 = get_s3_client()

    # List mode: print the available steps and exit.
    if args.list:
        steps = list_steps(s3, bucket, args.model_name)
        if not steps:
            print(f"No checkpoints found for {args.model_name} in {bucket}")
            sys.exit(1)
        print(f"Available steps for {args.model_name}:")
        for step in steps:
            print(f"  step_{step:08d}")
        return

    # Determine which step to restore: explicit --step, or the latest on R2.
    if args.step is not None:
        step = args.step
    else:
        steps = list_steps(s3, bucket, args.model_name)
        if not steps:
            print(f"ERROR: No checkpoints for {args.model_name}", file=sys.stderr)
            sys.exit(1)
        step = steps[-1]
        print(f"Latest step: {step}")

    if args.target_dir is None:
        print("ERROR: --target-dir required for restore", file=sys.stderr)
        sys.exit(1)

    prefix = f"{args.model_name}/step_{step:08d}"
    target = args.target_dir
    target.mkdir(parents=True, exist_ok=True)

    print(f"Restoring {args.model_name} step {step}")
    print(f"  Source: s3://{bucket}/{prefix}/")
    print(f"  Target: {target}")
    print()

    # Fetch metadata first; it carries per-file sizes and SHA-256 digests.
    meta_key = f"{prefix}/upload_metadata.json"
    meta_local = target / "upload_metadata.json"
    try:
        s3.download_file(bucket, meta_key, str(meta_local))
        metadata = json.loads(meta_local.read_text())
        print(f"  Metadata: {len(metadata.get('files', []))} files recorded")
    except Exception:
        # Deliberate best-effort: older uploads may predate the metadata file,
        # so fall back to size-only verification instead of aborting.
        print("  WARN: upload_metadata.json not found, downloading without verification")
        metadata = None

    # List all objects under the prefix. Fix: paginate here (as list_steps
    # already does) so restores with >1000 objects are not silently truncated
    # by a single list_objects_v2 call.
    objects = []
    paginator = s3.get_paginator("list_objects_v2")
    for page in paginator.paginate(Bucket=bucket, Prefix=prefix + "/"):
        objects.extend(page.get("Contents", []))
    ckpt_objects = [o for o in objects if o["Key"].endswith((".ckpt", ".nemo"))]

    if not ckpt_objects:
        print(f"ERROR: No .ckpt/.nemo files at {prefix}/", file=sys.stderr)
        sys.exit(1)

    # Expected sizes/checksums keyed by filename, from the upload metadata.
    expected = {}
    if metadata:
        for f in metadata.get("files", []):
            expected[f["name"]] = {
                "size_bytes": f["size_bytes"],
                "sha256": f.get("sha256"),
            }

    # Download each checkpoint file, then verify size and (when recorded) SHA-256.
    print(f"\n--- Downloading {len(ckpt_objects)} files ---")
    all_ok = True
    restored_files = []
    for obj in ckpt_objects:
        key = obj["Key"]
        filename = key.split("/")[-1]
        local_path = target / filename
        remote_size = obj["Size"]

        # Fix: report the actual filename (the original printed the literal
        # placeholder "(unknown)" and never used `filename` here).
        print(f"  Downloading: {filename} ({remote_size / (1024**2):.1f} MB)")
        s3.download_file(bucket, key, str(local_path))

        # Verify the on-disk size against the remote object listing.
        local_size = local_path.stat().st_size
        if local_size != remote_size:
            print(f"    FAIL: size mismatch local={local_size} remote={remote_size}")
            all_ok = False
            continue

        # Verify SHA-256 when the upload metadata recorded one.
        exp_sha = expected.get(filename, {}).get("sha256")
        if exp_sha:
            if _sha256(local_path) != exp_sha:
                print("    FAIL: SHA-256 mismatch")
                all_ok = False
                continue
            print("    OK: size + SHA-256 verified")
        else:
            print("    OK: size verified (no SHA-256 in metadata)")

        restored_files.append(local_path)

    print()
    if all_ok and restored_files:
        print(f"Restore: PASS ({len(restored_files)} files)")
        print(f"  Target dir: {target}")
        for f in restored_files:
            print(f"  {f.name} ({f.stat().st_size / (1024**2):.1f} MB)")
    else:
        print("Restore: FAIL")
        sys.exit(1)


if __name__ == "__main__":
    main()
