#!/usr/bin/env bash
# Training monitor — checks health and auto-resumes on failure
# Schedule: every 10min for first hour, every 30min for hours 1-3, every 1hr after
# NOTE(review): -e is omitted from the strict-mode flags — presumably so a
# failing probe (pgrep/nvidia-smi) doesn't kill the monitor itself; confirm.
set -uo pipefail

# Paths shared by every function below: the monitor reads LOG_FILE (trainer
# output), appends to MONITOR_LOG (its own log), and re-runs LAUNCH_SCRIPT
# on failure.
OUTPUT_DIR="/root/data/qwen3-asr-phase2-out"
LOG_FILE="$OUTPUT_DIR/train.log"
MONITOR_LOG="$OUTPUT_DIR/monitor.log"
LAUNCH_SCRIPT="/root/data/Qwen3-ASR-official/launch_production.sh"

log() { echo "[$(date '+%Y-%m-%d %H:%M:%S')] $*" | tee -a "$MONITOR_LOG"; }

#######################################
# Probe the health of the training run.
# Globals:   LOG_FILE (read), MONITOR_LOG (via log)
# Outputs:   status line via log
# Returns:   0 healthy, 1 process dead (resumable), 2 stale log,
#            3 non-finite loss, 4 GPU ECC errors
#######################################
check_health() {
    local last_mod now age recent_loss gpu_errors
    local progress current_step total_steps last_loss gpu_mem gpu_util

    # 1. Is torchrun running?
    if ! pgrep -f "qwen3_asr_sft_phase2" > /dev/null 2>&1; then
        log "ALERT: Training process NOT running!"
        return 1
    fi

    # 2. Check last log mtime (stale = stuck). `stat -c` is GNU coreutils.
    if [ -f "$LOG_FILE" ]; then
        last_mod=$(stat -c %Y "$LOG_FILE" 2>/dev/null || echo 0)
        now=$(date +%s)
        age=$((now - last_mod))
        if [ "$age" -gt 600 ]; then
            log "WARN: Log file stale for ${age}s (>600s)"
            return 2
        fi
    fi

    # 3. Check for a non-finite recent loss.
    # BUG FIX: the old pattern "'loss': [\d.]+" matched only digits/dots, so
    # the literal text "nan" could never be captured and this alert was
    # unreachable. Capture the raw token after 'loss': instead, then test it.
    if [ -f "$LOG_FILE" ]; then
        recent_loss=$(grep "'loss'" "$LOG_FILE" | tail -1 | grep -oP "(?<='loss': )[^,}]+" || true)
        if [ -n "$recent_loss" ] && echo "$recent_loss" | grep -qiE 'nan|inf'; then
            log "ALERT: Non-finite loss detected: $recent_loss"
            return 3
        fi
    fi

    # 4. Check GPU health (sum uncorrected ECC errors across all GPUs).
    gpu_errors=$(nvidia-smi --query-gpu=ecc.errors.uncorrected.volatile.total --format=csv,noheader 2>/dev/null | awk '{s+=$1} END{print s}')
    if [ "${gpu_errors:-0}" -gt 0 ]; then
        log "ALERT: GPU ECC errors detected: $gpu_errors"
        return 4
    fi

    # 5. Report current progress (single grep, then parameter expansion —
    # the old code ran the same grep twice and piped through cut).
    if [ -f "$LOG_FILE" ]; then
        progress=$(grep -oP '\d+/\d+' "$LOG_FILE" | tail -1)
        current_step=${progress%%/*}
        total_steps=${progress##*/}
        last_loss=$(grep "'loss'" "$LOG_FILE" | tail -1 | grep -oP "'loss': [0-9.]+" | grep -oP "[0-9.]+$")
        gpu_mem=$(nvidia-smi --query-gpu=memory.used --format=csv,noheader 2>/dev/null | head -1)
        gpu_util=$(nvidia-smi --query-gpu=utilization.gpu --format=csv,noheader 2>/dev/null | head -1)
        log "OK: step=${current_step:-?}/${total_steps:-?} loss=${last_loss:-?} gpu_mem=${gpu_mem} gpu_util=${gpu_util}"
    else
        log "OK: Training running but no log file yet"
    fi

    return 0
}

#######################################
# Kill stale trainer processes and relaunch via the production launch script.
# Globals:   OUTPUT_DIR, LAUNCH_SCRIPT (read); MONITOR_LOG (via log)
# Returns:   0 on relaunch, 1 if the project directory is missing
#######################################
auto_resume() {
    local latest_ckpt
    log "Attempting auto-resume..."

    # Kill any stale processes (ignore "no process matched" status)
    pkill -f "qwen3_asr_sft_phase2" 2>/dev/null || true
    sleep 10

    # Find the newest checkpoint.
    # BUG FIX: the old `sort -t- -k2 -n` split on every '-' in the path;
    # OUTPUT_DIR itself contains dashes, so the numeric key was never the
    # step number and checkpoint-9 sorted after checkpoint-100. Version
    # sort (-V) orders the trailing step number correctly.
    latest_ckpt=$(ls -d "$OUTPUT_DIR"/checkpoint-* 2>/dev/null | sort -V | tail -1)
    if [ -n "$latest_ckpt" ]; then
        log "Resuming from: $latest_ckpt"
        # NOTE(review): the checkpoint is only logged here, never passed to
        # the launch script — presumably it auto-detects the latest
        # checkpoint in OUTPUT_DIR; confirm.
    else
        log "No checkpoint found — starting fresh"
    fi

    # Relaunch in background; bail out if the project directory is gone
    # (the old code ignored a failed cd and launched from the wrong cwd).
    cd /root/data/Qwen3-ASR-official || { log "ERROR: cannot cd to project dir"; return 1; }
    nohup bash "$LAUNCH_SCRIPT" >> "$OUTPUT_DIR/train.log" 2>&1 &
    log "Relaunched training (PID: $!)"
    sleep 30  # Wait for startup
}

# Main monitor loop: probe health on an adaptive interval, auto-resume when
# the trainer process dies, and exit once the trainer prints its summary.
log "=== Training Monitor Started ==="
START_TIME=$(date +%s)
CHECK_COUNT=0

while true; do
    elapsed=$(( $(date +%s) - START_TIME ))
    elapsed_h=$(( elapsed / 3600 ))

    # Adaptive interval (matches the schedule in the header comment)
    if [ "$elapsed_h" -lt 1 ]; then
        interval=600    # 10 min
    elif [ "$elapsed_h" -lt 3 ]; then
        interval=1800   # 30 min
    else
        interval=3600   # 1 hour
    fi

    CHECK_COUNT=$((CHECK_COUNT + 1))
    log "--- Check #$CHECK_COUNT (interval=${interval}s, elapsed=${elapsed_h}h) ---"

    # BUG FIX: the old code did `if ! check_health; then exit_code=$?`.
    # Inside that branch, $? is the status of the *negated* condition,
    # which is always 0 — so exit_code was never 1 and auto_resume could
    # never run. Capture the real status immediately after the call.
    check_health
    health_rc=$?
    if [ "$health_rc" -ne 0 ]; then
        log "Health check failed (code=$health_rc)"

        if [ "$health_rc" -eq 1 ]; then
            # Process not running — auto-resume
            auto_resume
        else
            log "Non-resumable issue. Manual intervention needed."
        fi
    fi

    # Check if training completed (HF Trainer prints train_runtime at the end)
    if [ -f "$LOG_FILE" ] && grep -q "train_runtime" "$LOG_FILE" 2>/dev/null; then
        log "=== Training COMPLETED ==="
        grep "train_runtime\|train_loss\|train_samples" "$LOG_FILE" | tail -3 | while IFS= read -r line; do log "  $line"; done
        break
    fi

    sleep "$interval"
done

log "=== Monitor Exiting ==="