#!/usr/bin/env bash
# Final benchmark: grad_ckpt ON frees memory for larger batch sizes, which should win on net throughput
set -euo pipefail

export PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True   # reduce allocator fragmentation at large batch sizes
export TORCH_NCCL_ASYNC_ERROR_HANDLING=3                  # 3 = SkipCleanUp: fail fast on NCCL errors without waiting on cleanup
export OMP_NUM_THREADS=1       # pin CPU math libs to one thread each;
export MKL_NUM_THREADS=1       # the dataloader workers supply the CPU parallelism
export OPENBLAS_NUM_THREADS=1
export NUMEXPR_NUM_THREADS=1
export NCCL_LAUNCH_MODE=PARALLEL
export TORCHINDUCTOR_CACHE_DIR="/tmp/torch_inductor_cache"   # reuse torch.compile artifacts across runs
export PYTHONPATH="/root/data/Qwen3-ASR-official:${PYTHONPATH:-}"
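
# Optional sanity check (added sketch, not part of the original flow): confirm the
# node exposes the 8 GPUs that --nproc_per_node=8 below assumes; warn, don't abort.
NGPU=$(nvidia-smi --query-gpu=index --format=csv,noheader 2>/dev/null | wc -l || true)
[[ "$NGPU" -eq 8 ]] || echo "WARNING: expected 8 GPUs, found $NGPU"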

MODEL_PATH="/root/data/qwen3_asr_weights"
TRAIN_FILE="/root/gemini-asr/lf_asr/artifacts/phase2/train.parquet"
OUTPUT_DIR="/root/data/qwen3-asr-phase2-bench"
LOG_DIR="/root/data/qwen3-asr-phase2-bench/logs"
STEPS=100

mkdir -p "$OUTPUT_DIR" "$LOG_DIR"
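
# Defensive cleanup (added sketch): if the suite aborts under set -e mid-experiment,
# reap any leftover GPU-monitor subshell so it stops appending to its csv.
# GPID is assigned inside run_exp and is global to the script.
GPID=""
trap '[[ -n "${GPID}" ]] && kill "${GPID}" 2>/dev/null || true' EXIT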

run_exp() {
    local ID="$1"; shift
    echo -e "\n=== $ID ==="
    rm -rf "$OUTPUT_DIR"/checkpoint-* 2>/dev/null || true   # glob must stay outside the quotes to expand

    # Sample per-GPU memory and utilization every 5 s for the duration of the run
    (while true; do nvidia-smi --query-gpu=index,memory.used,utilization.gpu --format=csv,noheader; sleep 5; done) > "$LOG_DIR/${ID}_gpu.csv" 2>/dev/null &
    GPID=$!

    local T0 T1 RET=0
    T0=$(date +%s)
    torchrun --standalone --nproc_per_node=8 finetuning/qwen3_asr_sft_phase2.py \
        --model_path "$MODEL_PATH" --train_file "$TRAIN_FILE" --output_dir "$OUTPUT_DIR" \
        --grad_acc 1 --lr 2e-5 --max_steps "$STEPS" --log_steps 10 \
        --save_steps 999999 --eval_steps 999999 \
        --num_workers 8 --prefetch_factor 4 --pin_memory 1 --persistent_workers 1 \
        --max_open_tars 32 --ddp_static_graph 1 --language_tag_mode auto \
        --max_duration_s 30.0 --seed 42 "$@" \
        2>&1 | tee "$LOG_DIR/${ID}.log" || RET=$?   # capture the pipeline status without tripping set -e
    T1=$(date +%s)
    kill $GPID 2>/dev/null || true
    wait $GPID 2>/dev/null || true

    echo ""
    echo "--- $ID ---"
    if [ $RET -eq 0 ]; then
        echo "Wall: $((T1-T0))s | Steps: $STEPS"
        grep "'train_samples_per_second'" "$LOG_DIR/${ID}.log" | tail -1
        # Monitor csv rows look like "0, 41288 MiB, 97 %"; strip units before aggregating
        MAX_MEM=$(awk -F', ' '{gsub(/ MiB/,"",$2); print $2}' "$LOG_DIR/${ID}_gpu.csv" 2>/dev/null | sort -n | tail -1)
        AVG_UTIL=$(awk -F', ' '{gsub(/ %/,"",$3); sum+=$3; n++} END{if (n) printf "%.1f", sum/n}' "$LOG_DIR/${ID}_gpu.csv" 2>/dev/null)
        echo "Peak mem: ${MAX_MEM:-?} MiB | GPU util: ${AVG_UTIL:-?}%"
    else
        echo "FAILED (exit $RET)"
    fi
    return $RET
}
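
# Example: re-summarize a finished run offline without re-running it, e.g.:
#   grep "'train_samples_per_second'" "$LOG_DIR/E3_bs32_gc.log" | tail -1
#   awk -F', ' '{gsub(/ MiB/,"",$2); print $2}' "$LOG_DIR/E3_bs32_gc_gpu.csv" | sort -n | tail -1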

echo "=== Final Benchmark Suite $(date) ==="

# E3: BS=32 + grad_ckpt ON
run_exp "E3_bs32_gc" --batch_size 32 --gradient_checkpointing 1

# E4: BS=48 + grad_ckpt ON
run_exp "E4_bs48_gc" --batch_size 48 --gradient_checkpointing 1 || echo "(E4 failed)"

# E5: BS=64 + grad_ckpt ON
run_exp "E5_bs64_gc" --batch_size 64 --gradient_checkpointing 1 || echo "(E5 failed)"

# E6: best BS from above (hardcoded to 32 here; update after E3-E5) + compile
#     mode=default (not reduce-overhead, which triggers the CUDAGraph issue)
run_exp "E6_bs32_gc_compile_default" --batch_size 32 --gradient_checkpointing 1 \
    --torch_compile 1 --compile_mode default || echo "(E6 failed)"
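
# Hypothetical E7 (left commented out): the reduce-overhead variant, shown only
# for the flag shape; enable it only once the CUDAGraph issue noted above is
# resolved for this training loop.
# run_exp "E7_bs32_gc_compile_ro" --batch_size 32 --gradient_checkpointing 1 \
#     --torch_compile 1 --compile_mode reduce-overhead || echo "(E7 failed)"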

echo -e "\n=== All done $(date) ==="
