#!/usr/bin/env bash
# Resilient production training — auto-restarts on transient CUDA/NCCL crashes
# Saves a checkpoint every 2000 steps; on crash, resumes from the latest checkpoint
set -uo pipefail

export PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True
export TORCH_NCCL_ASYNC_ERROR_HANDLING=3
export OMP_NUM_THREADS=1
export MKL_NUM_THREADS=1
export OPENBLAS_NUM_THREADS=1
export NUMEXPR_NUM_THREADS=1
export NCCL_LAUNCH_MODE=PARALLEL
export TORCHINDUCTOR_CACHE_DIR="/tmp/torch_inductor_cache"
export PYTHONPATH="/root/data/Qwen3-ASR-official:${PYTHONPATH:-}"
export WANDB_MODE="${WANDB_MODE:-online}"
export WANDB_RESUME="${WANDB_RESUME:-allow}"
export WANDB_RUN_ID="${WANDB_RUN_ID:-}"
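# Note: with WANDB_RESUME=allow, restarts attach to the same W&B run only if
# WANDB_RUN_ID is pinned to a fixed, non-empty value before launch; left unset,
# each restart starts a new run.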
export WANDB_INIT_TIMEOUT=300
# Increase NCCL timeout to 30 min (default 10 min) — helps with checkpoint I/O stalls
export NCCL_TIMEOUT=1800

MODEL_PATH="/root/data/qwen3_asr_weights"
BUCKET_DIR="/root/gemini-asr/lf_asr/artifacts/phase2/buckets"
BUCKET_CONFIG="${BUCKET_DIR}/bucket_config.json"
OUTPUT_DIR="/root/data/qwen3-asr-phase2-out"
MAX_BATCH_SEQ_LEN="${MAX_BATCH_SEQ_LEN:-700}"
IGNORE_DATA_SKIP="${IGNORE_DATA_SKIP:-1}"

MAX_RESTARTS=20
RESTART_COOLDOWN=30  # seconds between restarts

mkdir -p "$OUTPUT_DIR"

log() { echo "[$(date '+%Y-%m-%d %H:%M:%S')] $*" | tee -a "$OUTPUT_DIR/resilient.log"; }
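
# Return the newest checkpoint directory by step number. sort -V (version sort)
# compares the numeric suffix properly (checkpoint-9000 < checkpoint-10000),
# unlike a plain lexicographic or dash-delimited field sort.
latest_checkpoint() {
    ls -d "$OUTPUT_DIR"/checkpoint-* 2>/dev/null | sort -V | tail -1
}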

run_training() {
    local resume_flag="$1"
    local resume_from="${2:-}"
    local -a resume_args

    if [ "$resume_flag" = "1" ] && [ -n "$resume_from" ]; then
        resume_args=(--resume_from "$resume_from" --resume 0)
    else
        resume_args=(--resume "$resume_flag")
    fi
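    # Assumption: qwen3_asr_sft_phase2.py interprets --resume_from as an explicit
    # checkpoint path (with --resume 0 disabling auto-detection), and --resume 1
    # as "auto-resume from the latest checkpoint in the output dir".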

    torchrun --standalone --nproc_per_node=8 \
        finetuning/qwen3_asr_sft_phase2.py \
        --model_path "$MODEL_PATH" \
        --bucket_dir "$BUCKET_DIR" \
        --bucket_config "$BUCKET_CONFIG" \
        --output_dir "$OUTPUT_DIR" \
        --batch_size 16 \
        --grad_acc 1 \
        --lr 2e-5 \
        --warmup_ratio 0.02 \
        --lr_scheduler_type cosine \
        --epochs 1 \
        --max_steps 612448 \
        --log_steps 50 \
        --save_steps 2000 \
        --save_total_limit 5 \
        --num_workers 4 \
        --prefetch_factor 4 \
        --pin_memory 1 \
        --persistent_workers 1 \
        --max_open_tars 64 \
        --use_indexed_tar 1 \
        --worker_decode 1 \
        --gradient_checkpointing 0 \
        --attn_implementation flash_attention_2 \
        --ddp_static_graph 1 \
        --max_batch_seq_len "$MAX_BATCH_SEQ_LEN" \
        --language_tag_mode auto \
        --seed 42 \
        --profiling 1 \
        --wandb_project maya-asr \
        --wandb_run_name qwen3-asr-1.7b \
        --ignore_data_skip "$IGNORE_DATA_SKIP" \
        "${resume_args[@]}" \
        2>&1 | tee -a "$OUTPUT_DIR/train.log"

    # Return torchrun's exit status (PIPESTATUS[0]), not tee's.
    return "${PIPESTATUS[0]}"
}

log "============================================"
log "  Qwen3-ASR Phase 2 — Resilient Training"
log "  Max restarts: $MAX_RESTARTS"
log "  Save steps: 2000 (for faster resume)"
log "  max_batch_seq_len: $MAX_BATCH_SEQ_LEN"
log "  ignore_data_skip: $IGNORE_DATA_SKIP"
log "============================================"

restart_count=0
# First run: start fresh
resume=0

# Check if there's already a checkpoint (previous crashed run)
latest_ckpt=$(latest_checkpoint)
if [ -n "$latest_ckpt" ]; then
    log "Found existing checkpoint: $latest_ckpt — will resume"
    resume=1
fi

while [ $restart_count -le $MAX_RESTARTS ]; do
    log "--- Launch #$((restart_count + 1)) (resume=$resume, resume_from=${latest_ckpt:-none}) ---"

    run_training "$resume" "${latest_ckpt:-}"
    exit_code=$?

    # Genuine completion: exit code 0 and a train_runtime entry in the training log
    if [ $exit_code -eq 0 ]; then
        if grep -q "train_runtime" "$OUTPUT_DIR/train.log" 2>/dev/null; then
            log "=== Training COMPLETED successfully ==="
            exit 0
        fi
        log "torchrun exited 0 but no train_runtime found — possible issue"
    fi

    log "Training crashed with exit code $exit_code"

    # Check for checkpoint to resume from
    latest_ckpt=$(latest_checkpoint)
    if [ -n "$latest_ckpt" ]; then
        ckpt_step=$(basename "$latest_ckpt" | grep -oE '[0-9]+$')
        log "Latest checkpoint: $latest_ckpt (step $ckpt_step)"
        resume=1
    else
        log "No checkpoint found — will restart from scratch"
        resume=0
    fi

    restart_count=$((restart_count + 1))

    if [ $restart_count -gt $MAX_RESTARTS ]; then
        log "FATAL: Exceeded max restarts ($MAX_RESTARTS). Giving up."
        exit 1
    fi

    # Kill any stale GPU processes
    pkill -9 -f "qwen3_asr_sft_phase2" 2>/dev/null || true
    sleep 5

    # Best-effort GPU reset to clear sticky CUDA error state (requires root and
    # idle GPUs; failure is ignored)
    nvidia-smi -r 2>/dev/null || true

    log "Cooling down for ${RESTART_COOLDOWN}s before restart..."
    sleep $RESTART_COOLDOWN
done
