#!/usr/bin/env bash
# Resume-from-R2 drill: prove checkpoint backup + restore + RESUME works.
#
# Flow:
#   1. Train 10 steps -> checkpoint at step 10
#   2. Upload checkpoint to R2
#   3. Delete local checkpoint
#   4. Restore from R2
#   5. Resume training from restored checkpoint for 10 more steps
#   6. Verify global_step > original (proves actual resume)
#
# Usage:
#   bash scripts/drill_resume_from_r2.sh

set -euo pipefail

# Resolve the repo root from this script's own location and work from there.
SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)"
PROJECT_ROOT="$(dirname "$SCRIPT_DIR")"
cd "$PROJECT_ROOT"

# Pick an interpreter: an explicit $PYTHON wins, then python, then python3.
PY=""
for interp in "${PYTHON:-}" python python3; do
    [ -n "$interp" ] || continue
    if command -v "$interp" >/dev/null 2>&1; then
        PY="$interp"
        break
    fi
done
if [ -z "$PY" ]; then
    echo "FAIL: No python interpreter found"; exit 1
fi

# Drill parameters: unique model name per run; $$-suffixed scratch paths.
MODEL_NAME="maya-asr-drill-$(date +%s)"
INITIAL_STEPS=10
RESUME_STEPS=20
RESTORE_DIR="/tmp/maya_asr_drill_restore_$$"
HASH_FILE="/tmp/maya_drill_hashes_$$.json"

printf '%s\n' \
    "============================================" \
    "  Resume-from-R2 Drill" \
    "============================================" \
    "  Model:          $MODEL_NAME" \
    "  Initial steps:  $INITIAL_STEPS" \
    "  Resume to:      $RESUME_STEPS" \
    ""

# --- Step 1: Initial training ---
echo "--- Step 1: Training ($INITIAL_STEPS steps) ---"
bash scripts/run_stage1_smoke.sh --max-steps="$INITIAL_STEPS"
echo ""

# Find the newest checkpoint dir with a glob instead of parsing `ls`.
# Under `set -euo pipefail`, a no-match `ls | head` fails the assignment
# and kills the script before the friendly FAIL message could print.
CKPT_DIR=""
for d in experiments/stage1_smoke/smoke_run/*/checkpoints; do
    [ -d "$d" ] || continue
    # `-nt` compares mtimes, so the most recently written dir wins.
    if [ -z "$CKPT_DIR" ] || [ "$d" -nt "$CKPT_DIR" ]; then
        CKPT_DIR="$d"
    fi
done
if [ -z "$CKPT_DIR" ]; then
    echo "FAIL: No checkpoint directory found after training"
    exit 1
fi
echo "OK: Checkpoint dir: $CKPT_DIR"

# Record SHA-256 + size of every .ckpt/.nemo file so the post-restore
# check can prove byte-for-byte integrity.
echo ""
echo "--- Recording original file hashes ---"
# Paths are passed via argv (not interpolated into the Python source):
# interpolation breaks on quotes/spaces in paths and is injection-prone.
$PY - "$CKPT_DIR" "$HASH_FILE" <<'PYEOF' 2>&1
import hashlib
import json
import sys
from pathlib import Path

ckpt_dir = Path(sys.argv[1])
files = sorted(f for f in ckpt_dir.iterdir() if f.suffix in ('.ckpt', '.nemo'))
result = {}
for f in files:
    digest = hashlib.sha256(f.read_bytes()).hexdigest()
    result[f.name] = {'size': f.stat().st_size, 'sha256': digest}
    print(f'  {f.name}: {f.stat().st_size} bytes', file=sys.stderr)
with open(sys.argv[2], 'w') as fh:
    json.dump(result, fh)
PYEOF

# --- Step 2: Upload to R2 ---
echo ""
echo "--- Step 2: Uploading to R2 ---"
upload_args=(
    --checkpoint-dir "$CKPT_DIR"
    --model-name "$MODEL_NAME"
    --step "$INITIAL_STEPS"
)
"$PY" scripts/upload_checkpoint.py "${upload_args[@]}"

# --- Step 3: Delete local checkpoint files ---
echo ""
echo "--- Step 3: Deleting local checkpoint files ---"
# The -f guard skips unmatched glob literals (nullglob is not set).
for ckpt_file in "$CKPT_DIR"/*.ckpt "$CKPT_DIR"/*.nemo; do
    [ -f "$ckpt_file" ] || continue
    echo "  Deleting: $(basename "$ckpt_file")"
    rm "$ckpt_file"
done

# Double-check nothing survived (including files in subdirectories).
REMAINING=$(find "$CKPT_DIR" \( -name "*.ckpt" -o -name "*.nemo" \) 2>/dev/null | wc -l)
if [ "$REMAINING" -gt 0 ]; then
    echo "FAIL: $REMAINING checkpoint files still remain"
    exit 1
fi
echo "OK: All local checkpoint files deleted"

# --- Step 4: Restore from R2 ---
echo ""
echo "--- Step 4: Restoring from R2 ---"
mkdir -p "$RESTORE_DIR"
restore_args=(
    --model-name "$MODEL_NAME"
    --step "$INITIAL_STEPS"
    --target-dir "$RESTORE_DIR"
)
"$PY" scripts/restore_checkpoint.py "${restore_args[@]}"

# --- Step 5: Verify restore integrity ---
# Compare each restored file's size and SHA-256 against the record taken
# before deletion; any mismatch fails the drill.
echo ""
echo "--- Step 5: Verifying restored files ---"
# Paths are passed via argv (not interpolated into the Python source):
# interpolation breaks on quotes/spaces in paths and is injection-prone.
$PY - "$HASH_FILE" "$RESTORE_DIR" <<'PYEOF'
import hashlib
import json
import sys
from pathlib import Path

with open(sys.argv[1]) as fh:
    orig = json.load(fh)
restore_dir = Path(sys.argv[2])
all_ok = True
for name, info in orig.items():
    restored = restore_dir / name
    if not restored.exists():
        print(f'  FAIL: {name} not restored')
        all_ok = False
        continue
    # Cheap size check first; only hash when sizes agree.
    if restored.stat().st_size != info['size']:
        print(f'  FAIL: {name} size mismatch')
        all_ok = False
        continue
    if hashlib.sha256(restored.read_bytes()).hexdigest() != info['sha256']:
        print(f'  FAIL: {name} SHA-256 mismatch')
        all_ok = False
        continue
    print(f'  OK: {name} (SHA-256 match)')

if not all_ok:
    print('RESTORE INTEGRITY: FAIL')
    sys.exit(1)
print('RESTORE INTEGRITY: PASS')
PYEOF

# --- Step 6: Resume training from restored checkpoint ---
echo ""
echo "--- Step 6: Resuming training from restored checkpoint (to step $RESUME_STEPS) ---"

# Find the .ckpt file (not .nemo) for resume.
# Glob instead of `ls | head`: under `set -euo pipefail` a no-match `ls`
# fails the assignment and kills the script before the FAIL message below.
# The glob expands in sorted order, matching the old `ls | head -1` pick.
RESUME_CKPT=""
for candidate in "$RESTORE_DIR"/*.ckpt; do
    if [ -f "$candidate" ]; then
        RESUME_CKPT="$candidate"
        break
    fi
done
if [ -z "$RESUME_CKPT" ]; then
    echo "FAIL: No .ckpt file in restore dir for resume"
    exit 1
fi
echo "  Resume checkpoint: $RESUME_CKPT"

# Copy restored checkpoint back to experiment dir so NeMo can find it
cp "$RESUME_CKPT" "$CKPT_DIR/"
RESUME_CKPT_LOCAL="$CKPT_DIR/$(basename "$RESUME_CKPT")"

# Run training with --resume-from-checkpoint, targeting more steps.
# The pipeline is guarded with `if` so a training failure records
# RESUME_EXIT instead of aborting the whole script under `set -e`
# (previously the PIPESTATUS capture was dead code: a failure exited
# before reaching it, skipping R2 cleanup and the final verdict).
if $PY scripts/train_smoke.py \
    --config configs/train/stage1_smoke.yaml \
    --max-steps "$RESUME_STEPS" \
    --smoke \
    --train-manifest data/manifests/smoke_train_extracted.jsonl \
    --val-manifest data/manifests/smoke_val_extracted.jsonl \
    --resume-from-checkpoint "$RESUME_CKPT_LOCAL" \
    2>&1 | tee /tmp/maya_drill_resume_output_$$.txt
then
    RESUME_EXIT=0
else
    # PIPESTATUS[0] is the trainer's status, not tee's.
    RESUME_EXIT=${PIPESTATUS[0]}
fi

# Check that training actually ran and advanced steps
echo ""
echo "--- Step 7: Verifying resume ---"
# `sed -n` instead of `grep -oP`: -P (PCRE) is GNU-only, and under
# `set -o pipefail` grep's exit-1-on-no-match would fail the assignment
# and kill the script before the FAIL branch below. sed exits 0 even
# with zero matches; the last "Steps: N" value wins via `tail -1`.
FINAL_STEPS=$(sed -n 's/.*Steps: \([0-9][0-9]*\).*/\1/p' "/tmp/maya_drill_resume_output_$$.txt" | tail -1)
rm -f "/tmp/maya_drill_resume_output_$$.txt"

if [ -z "$FINAL_STEPS" ]; then
    echo "FAIL: Could not extract final step count from training output"
    DRILL_EXIT=1
elif [ "${RESUME_EXIT:-0}" -ne 0 ]; then
    # Training exited non-zero even though it printed a step count.
    echo "FAIL: resume training exited with status ${RESUME_EXIT:-0}"
    DRILL_EXIT=1
elif [ "$FINAL_STEPS" -ge "$RESUME_STEPS" ]; then
    echo "  Initial steps: $INITIAL_STEPS"
    echo "  Final steps:   $FINAL_STEPS (target: $RESUME_STEPS)"
    echo "  RESUME: PASS (global_step advanced)"
    DRILL_EXIT=0
else
    echo "  Initial steps: $INITIAL_STEPS"
    echo "  Final steps:   $FINAL_STEPS (expected >= $RESUME_STEPS)"
    echo "  RESUME: FAIL (global_step did not advance sufficiently)"
    DRILL_EXIT=1
fi

# Cleanup: local scratch first, then the drill's objects in R2.
rm -rf "$RESTORE_DIR"
rm -f "$HASH_FILE"

echo ""
echo "--- Cleanup: removing drill objects from R2 ---"
# Values are passed via argv (not interpolated into the Python source),
# and listing is paginated — unpaginated list_objects_v2 caps at 1000
# keys. Failure here is deliberately non-fatal (best-effort cleanup).
$PY - "$PROJECT_ROOT" "$MODEL_NAME" <<'PYEOF' 2>/dev/null || echo "WARN: R2 cleanup failed (non-critical)"
import os
import sys
from pathlib import Path

import boto3
from dotenv import load_dotenv

project_root, model_name = sys.argv[1], sys.argv[2]
load_dotenv(Path(project_root) / '.env')
s3 = boto3.client('s3',
    endpoint_url=os.environ['R2_ENDPOINT_URL'],
    aws_access_key_id=os.environ['R2_ACCESS_KEY_ID'],
    aws_secret_access_key=os.environ['R2_SECRET_ACCESS_KEY'],
    region_name='auto')
bucket = os.environ.get('R2_BUCKET_CHECKPOINTS', 'ptcheckpoints')
prefix = f'{model_name}/'
paginator = s3.get_paginator('list_objects_v2')
for page in paginator.paginate(Bucket=bucket, Prefix=prefix):
    for obj in page.get('Contents', []):
        s3.delete_object(Bucket=bucket, Key=obj['Key'])
        print(f"  Deleted: {obj['Key']}")
print('OK: R2 drill objects cleaned up')
PYEOF

echo ""
echo "============================================"
# Fail-safe: if DRILL_EXIT was somehow never assigned, report FAIL
# instead of crashing on an unset-variable reference under `set -u`
# (the old unquoted `[ $DRILL_EXIT -eq 0 ]` also tripped SC2086).
DRILL_EXIT=${DRILL_EXIT:-1}
if [ "$DRILL_EXIT" -eq 0 ]; then
    echo "  Resume-from-R2 Drill: PASS"
else
    echo "  Resume-from-R2 Drill: FAIL"
fi
echo "============================================"
exit "$DRILL_EXIT"
