#!/usr/bin/env bash
# Auto-resume training wrapper.
# Automatically restarts training from the latest checkpoint on failure.
# Maintains same wandb run ID across restarts.
# Saves sampler state so resumed runs skip already-trained batches.
#
# Usage:
#   bash scripts/auto_resume_train.sh [extra train_prod.py args...]

# NOTE: deliberately no `-e` — the retry loop below must survive a crashed
# training run so it can relaunch from the latest checkpoint.
set -uo pipefail

PROJECT_ROOT="$(cd "$(dirname "${BASH_SOURCE[0]}")/.." && pwd)"
cd "$PROJECT_ROOT" || { echo "ERROR: cannot cd to $PROJECT_ROOT" >&2; exit 1; }

# Fail fast if credentials are missing instead of silently launching an
# unauthenticated run that only errors out minutes later.
if [ ! -f .env ]; then
    echo "ERROR: .env not found in $PROJECT_ROOT (needs WANDB_API_KEY)" >&2
    exit 1
fi
# shellcheck disable=SC1091
source .env
export WANDB_API_KEY
export WANDB_PROJECT=maya-asr
export PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True

# State file for wandb run ID persistence
readonly WANDB_STATE_FILE="/tmp/maya_asr_wandb_run_id"
readonly LOG_FILE="$PROJECT_ROOT/prod_training_tdt_v2.log"
readonly MAX_RETRIES=10
readonly RETRY_DELAY=30  # seconds between retries

# Default training args (can be overridden via CLI)
DEFAULT_ARGS=(
    --config configs/train/stage1_prod_8xh200.yaml
    --pretrained-encoder pretrained/parakeet_tdt_1.1b_encoder.pt
    --max-batch-dur 120
    --max-batch-size 24
    --max-tokens-in-batch 400
)

# Find latest checkpoint.
# Globals:   none (searches a fixed experiment directory)
# Outputs:   path of the checkpoint with the highest global_step on stdout,
#            or nothing if no usable checkpoint exists.
# Returns:   0 always; callers test the output string, not the status.
find_latest_checkpoint() {
    local exp_base="experiments/stage1_prod/maya_asr_stage1_tdt_v2"
    local latest_ckpt=""
    local latest_step=0
    local ckpt step

    # NUL-delimited find/read: safe for checkpoint paths containing spaces
    # or other shell metacharacters (the old `for x in $(find …)` was not).
    # Process substitution (not a pipe) keeps the loop in the current shell
    # so latest_ckpt/latest_step survive past `done`.
    while IFS= read -r -d '' ckpt; do
        # Pass the path via the environment instead of interpolating it
        # into the Python source, so quotes in the path cannot break/inject.
        step=$(CKPT_PATH="$ckpt" python3 -c "
import os, torch
m = torch.load(os.environ['CKPT_PATH'], map_location='cpu', weights_only=False)
print(m.get('global_step', 0))
" 2>/dev/null || echo "0")
        # Treat anything non-numeric (partial python failure, empty output)
        # as step 0 so the -gt comparison below cannot error out.
        case "$step" in ''|*[!0-9]*) step=0 ;; esac
        if [ "$step" -gt "$latest_step" ]; then
            latest_step=$step
            latest_ckpt=$ckpt
        fi
    done < <(find "$exp_base" -name "*-last.ckpt" -print0 2>/dev/null)

    if [ -n "$latest_ckpt" ] && [ "$latest_step" -gt 0 ]; then
        echo "$latest_ckpt"
    fi
}

# Get or create persistent wandb run ID.
# Globals:   WANDB_STATE_FILE (read; written on first call)
# Outputs:   the run ID on stdout — freshly minted and saved to the state
#            file if none exists yet, otherwise read back from it.
get_wandb_run_id() {
    local run_id
    if [ ! -f "$WANDB_STATE_FILE" ]; then
        # First launch: mint a timestamped ID and persist it so every
        # subsequent restart reuses the same wandb run.
        run_id="tdt-v2-$(date +%Y%m%d-%H%M%S)"
        echo "$run_id" > "$WANDB_STATE_FILE"
        echo "$run_id"
    else
        cat "$WANDB_STATE_FILE"
    fi
}

# Banner + wandb session pinning: every restart of this wrapper reports
# into the same persistent wandb run.
banner='============================================================'
printf '%s\n  Maya ASR — Auto-Resume Training Wrapper\n%s\n\n' "$banner" "$banner"

WANDB_RUN_ID="$(get_wandb_run_id)"
export WANDB_RUN_ID
export WANDB_RESUME=allow
export WANDB_RUN_NAME="tdt-parakeet-finetune-stage1-v2"
printf 'Wandb run ID: %s (persistent across restarts)\n\n' "$WANDB_RUN_ID"

# Retry loop: launch training, and on failure relaunch from the newest
# checkpoint, up to MAX_RETRIES attempts.
for attempt in $(seq 1 "$MAX_RETRIES"); do
    echo "--- Attempt $attempt/$MAX_RETRIES ---"

    # Find latest checkpoint for resume (empty string => fresh start).
    RESUME_CKPT=$(find_latest_checkpoint)
    RESUME_ARGS=()
    if [ -n "$RESUME_CKPT" ]; then
        # Path goes in via the environment, not string interpolation, so
        # quotes in the checkpoint path cannot break the Python snippet.
        RESUME_STEP=$(CKPT_PATH="$RESUME_CKPT" python3 -c "
import os, torch
m = torch.load(os.environ['CKPT_PATH'], map_location='cpu', weights_only=False)
print(m.get('global_step', 0))
" 2>/dev/null || echo "0")
        echo "Resuming from: $RESUME_CKPT (step $RESUME_STEP)"
        RESUME_ARGS=(--resume-from-checkpoint "$RESUME_CKPT")
    else
        echo "No checkpoint found, starting from scratch"
    fi

    echo "Starting training at $(date -Iseconds)"

    # Launch training. The log is overwritten each attempt so the crash
    # classification below only inspects the most recent run.
    PYTHONUNBUFFERED=1 python3 -u scripts/train_prod.py \
        "${DEFAULT_ARGS[@]}" \
        "${RESUME_ARGS[@]}" \
        "$@" \
        > "$LOG_FILE" 2>&1
    EXIT_CODE=$?

    echo "Training exited with code $EXIT_CODE at $(date -Iseconds)"

    if [ "$EXIT_CODE" -eq 0 ]; then
        echo "Training completed successfully!"
        break
    fi

    # Classify the crash so the operator knows whether a retry is sensible.
    if grep -qE "OutOfMemoryError|CUDA out of memory" "$LOG_FILE" 2>/dev/null; then
        echo "CRASH: OOM detected. Will retry with same checkpoint."
    elif grep -qE "NCCL.*timeout" "$LOG_FILE" 2>/dev/null; then
        echo "CRASH: NCCL timeout (likely caused by OOM on another rank). Will retry."
    else
        echo "CRASH: Unknown error. Check $LOG_FILE"
    fi

    if [ "$attempt" -lt "$MAX_RETRIES" ]; then
        # Announce, THEN sleep — previously the sleep ran before the
        # message, so "Retrying in 30s" printed right as the retry began.
        # The delay gives the CUDA driver time to release GPU memory.
        echo "Retrying in ${RETRY_DELAY}s..."
        echo
        sleep "$RETRY_DELAY"
    else
        # No point sleeping on the final attempt; give up immediately.
        echo "Max retries ($MAX_RETRIES) reached. Giving up."
        exit 1
    fi
done
