#!/usr/bin/env bash
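# Recycle the validation fleet: push updated code and restart workers.
# Copies a fixed list of files (FILES_TO_PUSH) via scp to running vast.ai
# instances (optionally capped with --max), then SIGKILLs the
# `python -m validations` worker; each instance's onstart restart loop
# relaunches it with the new code within ~10s.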
#
# Usage: ./scripts/recycle_fleet.sh [--max N] [--dry-run]

set -euo pipefail

MAX_PARALLEL=20
MAX_INSTANCES=0  # 0 = all
DRY_RUN=false
SSH_KEY="/home/ubuntu/.ssh/id_ed25519.backup_1769119693"
# A flat string rather than an array: it is word-split into the scp/ssh
# argument lists where it is used.
SSH_OPTS="-o StrictHostKeyChecking=no -o UserKnownHostsFile=/dev/null -o ConnectTimeout=10 -o LogLevel=ERROR"

#######################################
# Parse command-line flags into the globals above.
# Globals:   MAX_INSTANCES, DRY_RUN (written)
# Arguments: the script's arguments ("$@")
# Exits:     1 on an unknown flag or a missing/non-numeric --max value
#######################################
parse_args() {
    while [[ $# -gt 0 ]]; do
        case "$1" in
            --max)
                # Validate here: a non-numeric value would make the later
                # `[ "$MAX_INSTANCES" -gt 0 ]` test fail inside an `if`, which
                # is treated as "no limit" and recycles the whole fleet; a
                # missing value would otherwise abort via `set -u` with an
                # unhelpful "unbound variable" message.
                if [[ ! "${2-}" =~ ^[0-9]+$ ]]; then
                    printf 'Invalid --max value: %s (expected a non-negative integer)\n' "${2-<missing>}" >&2
                    exit 1
                fi
                MAX_INSTANCES="$2"
                shift 2
                ;;
            --dry-run)
                DRY_RUN=true
                shift
                ;;
            *)
                echo "Unknown: $1" >&2
                exit 1
                ;;
        esac
    done
}

parse_args "$@"

FILES_TO_PUSH=(
    "validations/worker.py"
    "validations/config.py"
)

echo "=== Fleet Recycle ==="
echo "Files to push: ${FILES_TO_PUSH[*]}"
echo "Max parallel: ${MAX_PARALLEL}"

# Print the number of non-empty lines in $1. Unlike `echo "$x" | wc -l`,
# this reports 0 (not 1) for an empty string.
count_lines() {
    grep -c . <<< "$1" || true
}

# One "id|ssh_host|ssh_port" line per instance whose actual_status is
# "running" and that has both ssh_host and ssh_port set. vastai's stderr is
# discarded, so report a failed listing explicitly instead of letting
# `set -e` end the script without a message.
if ! INSTANCES=$(vastai show instances --raw 2>/dev/null | python3 -c "
import json, sys
data = json.load(sys.stdin)
running = [d for d in data if d.get('actual_status') == 'running' and d.get('ssh_host') and d.get('ssh_port')]
for d in running:
    print(f'{d[\"id\"]}|{d[\"ssh_host\"]}|{d[\"ssh_port\"]}')
"); then
    echo "ERROR: failed to list instances (vastai show instances --raw)" >&2
    exit 1
fi

TOTAL=$(count_lines "$INSTANCES")
echo "Found $TOTAL running instances with SSH"

if [ "$TOTAL" -eq 0 ]; then
    echo "No running instances with SSH; nothing to recycle"
    exit 0
fi

if [ "$MAX_INSTANCES" -gt 0 ]; then
    # Here-string rather than `echo | head`: under pipefail, head closing the
    # pipe early could SIGPIPE the writer and fail the assignment.
    INSTANCES=$(head -n "$MAX_INSTANCES" <<< "$INSTANCES")
    TOTAL=$(count_lines "$INSTANCES")
    echo "Limited to $TOTAL instances"
fi

if [ "$DRY_RUN" = true ]; then
    echo "[DRY RUN] Would recycle $TOTAL instances"
    exit 0
fi

# Tallies of finished background recycle jobs and the PIDs still pending.
SUCCESS=0
FAILED=0
PIDS=()

#######################################
# Push FILES_TO_PUSH to one instance, then kill its validations worker.
# Globals:   FILES_TO_PUSH, SSH_KEY, SSH_OPTS (read)
# Arguments: $1 instance id (currently unused), $2 ssh host, $3 ssh port
# Returns:   0 if every file was copied and the remote kill command ran;
#            1 if any scp or the ssh session failed.
#######################################
recycle_one() {
    local id="$1" host="$2" port="$3"
    local ssh_dest="root@${host}"
    local target="/app"
    local f
    local -a ssh_opts

    # Split the flat SSH_OPTS string into separate arguments once.
    read -ra ssh_opts <<< "$SSH_OPTS"

    # Push updated files
    for f in "${FILES_TO_PUSH[@]}"; do
        scp "${ssh_opts[@]}" -i "$SSH_KEY" -P "$port" \
            "/home/ubuntu/transcripts/${f}" \
            "${ssh_dest}:${target}/${f}" 2>/dev/null || return 1
    done

    # SIGKILL the python worker — SIGTERM doesn't work on GPU-bound processes in
    # uninterruptible sleep. The onstart restart loop will pick up new code in ~10s.
    #
    # The '[p]ython' patterns still match "python -m validations..." but not the
    # remote shell executing this command string: pkill -f matches full command
    # lines and does not exclude its own parent, so a plain 'python -m ...'
    # pattern SIGKILLs the session itself and the later commands never run.
    # NOTE(review): -f also matches any wrapper process whose arguments contain
    # "python -m validations"; confirm the onstart loop is not launched that way.
    #
    # An ssh failure now counts as a failed recycle: the files may be updated but
    # the worker is still running the old code.
    ssh "${ssh_opts[@]}" -i "$SSH_KEY" -p "$port" "$ssh_dest" \
        "pkill -9 -f '[p]ython -m validations.main' 2>/dev/null; sleep 1; pkill -9 -f '[p]ython -m validations' 2>/dev/null; echo 'killed'" \
        >/dev/null 2>&1 || return 1

    return 0
}

#######################################
# Wait for every PID in PIDS, tally exit statuses into SUCCESS/FAILED,
# then clear PIDS.
# Globals: PIDS (read, reset); SUCCESS, FAILED (incremented)
#######################################
reap_jobs() {
    local pid
    # ${PIDS[@]+...} keeps `set -u` happy on bash < 4.4 when PIDS is empty.
    for pid in ${PIDS[@]+"${PIDS[@]}"}; do
        if wait "$pid" 2>/dev/null; then
            SUCCESS=$((SUCCESS + 1))
        else
            FAILED=$((FAILED + 1))
        fi
    done
    PIDS=()
}

#######################################
# Recycle every "id|host|port" line of INSTANCES in background batches of at
# most MAX_PARALLEL jobs, printing an OK/FAIL line per instance.
# Globals: INSTANCES, TOTAL, MAX_PARALLEL (read);
#          COUNT, PIDS, SUCCESS, FAILED (written)
#######################################
recycle_all() {
    local id host port
    COUNT=0

    # Read line by line instead of `for line in $INSTANCES`, which word-splits
    # and glob-expands; blank lines (e.g. an empty list) are skipped.
    while IFS='|' read -r id host port; do
        [[ -n "$id" ]] || continue
        COUNT=$((COUNT + 1))

        # Exit 1 on failure so reap_jobs' `wait` can tell OK from FAIL. Stdin
        # is /dev/null so ssh/scp cannot consume the loop's input.
        (
            if recycle_one "$id" "$host" "$port"; then
                echo "[${COUNT}/${TOTAL}] OK: instance ${id} (${host}:${port})"
            else
                echo "[${COUNT}/${TOTAL}] FAIL: instance ${id} (${host}:${port})"
                exit 1
            fi
        ) < /dev/null &
        PIDS+=($!)

        # Throttle parallel connections: once a full batch is in flight,
        # wait for the whole batch before starting more.
        if (( ${#PIDS[@]} >= MAX_PARALLEL )); then
            reap_jobs
        fi
    done <<< "$INSTANCES"

    # Wait for remaining
    reap_jobs
}

recycle_all

echo ""
echo "=== Recycle Complete ==="
echo "Total: $TOTAL | Attempted: $COUNT"
echo "Succeeded: $SUCCESS | Failed: $FAILED"
echo "Workers will auto-restart within ~10s of being recycled"
