#!/usr/bin/env bash
# Push milestone checkpoints (one per MILESTONE_INTERVAL opt steps, set below) to R2.
#
# Path: ptcheckpoints/hybrid-tdt-gemma/MM-DD-YYYY/ckpt-<opt_step>/<files>
#
# Called by cron every 5 minutes. Tracks pushed milestones in a state file (STATE_FILE).

set -euo pipefail

PROJECT_ROOT="$(cd "$(dirname "${BASH_SOURCE[0]}")/.." && pwd)"
cd "$PROJECT_ROOT"

# Spacing, in optimizer steps, between milestones. One upload is attempted for
# every multiple of this value that the current step has reached.
MILESTONE_INTERVAL=20000
# One already-pushed milestone step per line.
# NOTE(review): /tmp is typically cleared on reboot, after which every reached
# milestone would be pushed again - confirm this is acceptable.
STATE_FILE="/tmp/milestone_push_state"
# Date component of the remote path; it is the date of this run, not of the checkpoint.
DATE_PREFIX="$(date +%m-%d-%Y)"
BUCKET="ptcheckpoints"

# Print the first argument that is an existing directory (nothing if none is).
# Arguments: candidate directory paths, in priority order.
# Returns:   always 0, so it is safe in $(...) under `set -e`.
first_existing_dir() {
    local candidate
    for candidate in "$@"; do
        if [[ -d "$candidate" ]]; then
            printf '%s\n' "$candidate"
            return 0
        fi
    done
    return 0
}

# Print the most recently modified path among the arguments that passes
# `test <op>` (nothing if none does).
# Arguments: $1   - a unary test operator such as -d or -f
#            $2.. - candidate paths, normally an unquoted glob from the caller
# Returns:   always 0.
# Works on the glob results directly instead of parsing `ls -t`, so paths with
# whitespace survive. An unmatched glob arrives as the literal pattern, fails
# the test, and is skipped.
newest_matching() {
    local test_op=$1
    local newest="" path
    shift
    for path in "$@"; do
        test "$test_op" "$path" || continue
        if [[ -z "$newest" || "$path" -nt "$newest" ]]; then
            newest=$path
        fi
    done
    if [[ -n "$newest" ]]; then
        printf '%s\n' "$newest"
    fi
    return 0
}

# Print the `global_step` stored in a checkpoint file, or 0 when it cannot be
# read (python/torch failure, or output that is not a plain non-negative integer).
# Arguments: $1 - path to the checkpoint file
# The path is handed to Python through argv rather than pasted into the source
# text, so quotes in the path cannot break or alter the Python code.
# NOTE(review): weights_only=False unpickles arbitrary objects; only point this
# at checkpoints produced by a trusted training run.
read_global_step() {
    local ckpt_path=$1
    local step
    # stderr is discarded on purpose: a failed load is treated as "nothing to do".
    step=$(python3 -c '
import sys
import torch
ckpt = torch.load(sys.argv[1], map_location="cpu", weights_only=False)
print(ckpt.get("global_step", 0))
' "$ckpt_path" 2>/dev/null) || step=0
    if [[ "$step" =~ ^[0-9]+$ ]]; then
        # Force base 10 so a leading zero is never read as octal later on.
        step=$((10#$step))
    else
        step=0
    fi
    printf '%s\n' "$step"
}

# Succeed when the state file contains the milestone as an exact line.
# Arguments: $1 - milestone step, $2 - state file path
milestone_is_pushed() {
    local milestone=$1 state_file=$2
    [[ -f "$state_file" ]] && grep -qxF -- "$milestone" "$state_file"
}

# Make sure an rclone remote named exactly "r2" exists, creating it from the
# R2_* variables (expected to come from .env) when it does not.
# Returns: 0 when the remote exists or was created, otherwise the status of
#          `rclone config create`.
ensure_r2_remote() {
    local remotes
    # Capture first, then match: piping straight into `grep -q` can kill rclone
    # with SIGPIPE, which `pipefail` would report as a failure.
    remotes=$(rclone listremotes 2>/dev/null) || remotes=""
    # Whole-line match, so a remote such as "backup-r2:" does not count as "r2:".
    if grep -qxF 'r2:' <<<"$remotes"; then
        return 0
    fi
    rclone config create r2 s3 \
        provider Cloudflare \
        access_key_id "$R2_ACCESS_KEY_ID" \
        secret_access_key "$R2_SECRET_ACCESS_KEY" \
        endpoint "$R2_ENDPOINT_URL" \
        acl private \
        no_check_bucket true 2>/dev/null
}

# Copy the *.ckpt and *.nemo files of a directory to an rclone destination,
# showing only the last five lines of rclone's output.
# Arguments: $1 - local source directory, $2 - rclone destination
# Returns:   non-zero when rclone fails (relies on `pipefail`, set above).
upload_checkpoint_dir() {
    local src_dir=$1 dest=$2
    rclone copy "$src_dir" "$dest" \
        --include "*.ckpt" --include "*.nemo" \
        --progress --transfers 4 \
        2>&1 | tail -5
}

# Upload every milestone <= current step that is not yet in STATE_FILE and
# record each success there.
# Arguments: $1 - current optimizer step (non-negative integer)
#            $2 - checkpoint directory to upload
# Returns:   0 when every pending milestone was pushed, 1 if any push failed.
# A failed push is logged and the loop moves on, so one bad upload does not
# block the remaining milestones; the milestone stays pending for the next run.
# NOTE(review): each pending milestone uploads the same current contents of the
# checkpoint directory, so when several are pending the files are copied once
# per milestone.
push_pending_milestones() {
    local current_step=$1 ckpt_dir=$2
    local milestone r2_prefix
    local env_loaded=0 failures=0

    for (( milestone = MILESTONE_INTERVAL; milestone <= current_step; milestone += MILESTONE_INTERVAL )); do
        if milestone_is_pushed "$milestone" "$STATE_FILE"; then
            continue
        fi

        r2_prefix="hybrid-tdt-gemma/${DATE_PREFIX}/ckpt-${milestone}"
        echo "$(date -Iseconds) Pushing milestone: opt_step=$milestone to s3://${BUCKET}/${r2_prefix}/"

        # Credentials are only needed once there is something to push.
        if (( ! env_loaded )); then
            # shellcheck source=/dev/null
            source "$PROJECT_ROOT/.env"
            env_loaded=1
        fi

        # Both steps sit in the `if` condition so that a failure reaches the
        # else branch instead of aborting the script through `set -e`.
        if ensure_r2_remote && upload_checkpoint_dir "$ckpt_dir" "r2:${BUCKET}/${r2_prefix}/"; then
            echo "$milestone" >> "$STATE_FILE"
            echo "$(date -Iseconds) SUCCESS: ckpt-${milestone} pushed to R2"
        else
            echo "$(date -Iseconds) FAILED: ckpt-${milestone}" >&2
            failures=$((failures + 1))
        fi
    done

    if (( failures > 0 )); then
        return 1
    fi
    return 0
}

# Entry point: locate the newest checkpoint, read its step and push whatever
# milestones are pending. Returns 0 when there is nothing to do.
main() {
    local exp_base ckpt_dir latest_ckpt current_step

    # Check both experiment dirs (v2 first, then pretrained).
    exp_base=$(first_existing_dir \
        "experiments/stage1_prod/maya_asr_stage1_tdt_v2" \
        "experiments/stage1_prod/maya_asr_stage1_pretrained")
    [[ -n "$exp_base" ]] || return 0

    # Newest <run>/checkpoints directory inside the experiment.
    ckpt_dir=$(newest_matching -d "$exp_base"/*/checkpoints)
    [[ -n "$ckpt_dir" ]] || return 0

    # Current opt step comes from the newest *-last.ckpt file.
    latest_ckpt=$(newest_matching -f "$ckpt_dir"/*-last.ckpt)
    [[ -n "$latest_ckpt" ]] || return 0

    current_step=$(read_global_step "$latest_ckpt")
    [[ "$current_step" != "0" ]] || return 0

    touch "$STATE_FILE"
    push_pending_milestones "$current_step" "$ckpt_dir"
}

main "$@"
