#!/usr/bin/env bash
# Benchmark v2: adjusted after BS=32 OOM
set -euo pipefail

# --- Runtime environment ----------------------------------------------------
# CUDA allocator, NCCL error handling and NCCL kernel launch mode.
export PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True \
    TORCH_NCCL_ASYNC_ERROR_HANDLING=3 \
    NCCL_LAUNCH_MODE=PARALLEL

# Limit the common math-library thread pools to a single thread per process
# (avoids CPU oversubscription when many ranks/workers share the host).
export OMP_NUM_THREADS=1 \
    MKL_NUM_THREADS=1 \
    OPENBLAS_NUM_THREADS=1 \
    NUMEXPR_NUM_THREADS=1

# Inductor compile cache location, and the official Qwen3-ASR sources
# prepended to the Python import path (any existing PYTHONPATH is kept).
export TORCHINDUCTOR_CACHE_DIR="/tmp/torch_inductor_cache" \
    PYTHONPATH="/root/data/Qwen3-ASR-official:${PYTHONPATH:-}"

# --- Paths and run length ---------------------------------------------------
MODEL_PATH="/root/data/qwen3_asr_weights"
TRAIN_FILE="/root/gemini-asr/lf_asr/artifacts/phase2/train.parquet"
OUTPUT_DIR="/root/data/qwen3-asr-phase2-bench"
LOG_DIR="${OUTPUT_DIR}/logs"
STEPS=100 # passed to the trainer as --max_steps for every experiment

mkdir -p "$OUTPUT_DIR" "$LOG_DIR"

#######################################
# Run one 8-process torchrun fine-tuning benchmark and print a short summary.
# While the job runs, a background loop samples nvidia-smi every 5 s into a
# CSV trace. Afterwards the function prints the wall time, the last logged
# throughput/loss lines, and the peak memory / mean utilisation from the trace.
# Globals:
#   OUTPUT_DIR, LOG_DIR, MODEL_PATH, TRAIN_FILE, STEPS (read)
# Arguments:
#   $1    - experiment id; names the log and GPU-trace files
#   $2... - extra CLI flags forwarded verbatim to the training script
# Outputs:
#   Training output is tee'd to stdout and to $LOG_DIR/<id>.log.
#   GPU samples go to $LOG_DIR/<id>_gpu_trace.csv.
# Returns:
#   Exit status of the training pipeline (0 on success).
#######################################
run_experiment() {
    local exp_id="$1"
    shift

    local log_file="$LOG_DIR/${exp_id}.log"
    local trace_file="$LOG_DIR/${exp_id}_gpu_trace.csv"

    echo ""
    echo "========================================"
    echo "  Experiment: $exp_id"
    echo "  Args: $*"
    echo "========================================"

    # Drop checkpoints left over from a previous run. The glob must stay
    # OUTSIDE the quotes or it is never expanded. ${OUTPUT_DIR:?} refuses to
    # run rm -rf against an empty path.
    rm -rf -- "${OUTPUT_DIR:?}"/checkpoint-* 2>/dev/null || true

    # Background GPU monitor: appends nvidia-smi CSV rows every 5 s until killed.
    (while true; do
        nvidia-smi --query-gpu=index,memory.used,utilization.gpu,power.draw --format=csv,noheader
        sleep 5
    done) > "$trace_file" 2>/dev/null &
    local gpu_mon_pid=$!

    local start_time end_time elapsed
    start_time=$(date +%s)

    # Capture the pipeline status explicitly. Without `|| rc=$?`, a failed run
    # under `set -e` would abort the whole script here whenever the caller is
    # not in an `||`/`if` context. That would skip the summary and leave the
    # GPU monitor running forever. With pipefail, rc is the last non-zero
    # status in the pipeline (normally torchrun's).
    local rc=0
    torchrun --standalone --nproc_per_node=8 \
        finetuning/qwen3_asr_sft_phase2.py \
        --model_path "${MODEL_PATH}" \
        --train_file "${TRAIN_FILE}" \
        --output_dir "${OUTPUT_DIR}" \
        --grad_acc 1 \
        --lr 2e-5 \
        --max_steps "${STEPS}" \
        --log_steps 10 \
        --save_steps 999999 \
        --eval_steps 999999 \
        --num_workers 8 \
        --prefetch_factor 4 \
        --pin_memory 1 \
        --persistent_workers 1 \
        --max_open_tars 32 \
        --ddp_static_graph 1 \
        --language_tag_mode auto \
        --max_duration_s 30.0 \
        --seed 42 \
        "$@" \
        2>&1 | tee "$log_file" || rc=$?

    end_time=$(date +%s)
    elapsed=$((end_time - start_time))

    # Stop and reap the monitor. It dies by signal, so its status is ignored.
    kill "$gpu_mon_pid" 2>/dev/null || true
    wait "$gpu_mon_pid" 2>/dev/null || true

    echo ""
    echo "--- $exp_id Summary ---"
    if ((rc != 0)); then
        echo "STATUS: FAILED (exit $rc)"
    else
        echo "STATUS: SUCCESS"
        echo "Wall time: ${elapsed}s for $STEPS steps"
        grep "'train_samples_per_second'" "$log_file" | tail -1 || true
        grep "'train_loss'" "$log_file" | tail -1 || true
    fi
    # -s (non-empty), not -f: the redirection above always creates the file.
    # An empty trace would otherwise make the average a division by zero,
    # which is fatal in gawk and aborts the script under `set -e`.
    if [ -s "$trace_file" ]; then
        local max_mem avg_util
        max_mem=$(awk -F', ' '{gsub(/ MiB/,"",$2); print $2}' "$trace_file" | sort -n | tail -1)
        avg_util=$(awk -F', ' '{gsub(/ %/,"",$3); sum+=$3; n++} END{if (n) printf "%.1f", sum/n; else printf "0.0"}' "$trace_file")
        echo "Peak memory: ${max_mem} MiB | Avg GPU util: ${avg_util}%"
    else
        echo "GPU trace: no samples recorded"
    fi
    echo "========================="
    return "$rc"
}

echo "Benchmark v2 starting at $(date)"

# Each experiment is guarded with `|| echo ...` so that one crash or OOM does
# not abort the rest of the sweep (run_experiment returns non-zero on failure,
# which would otherwise trip `set -e`). The failing exit code is already
# printed in run_experiment's summary.

# E2: BS=16 + torch.compile reduce-overhead
run_experiment "E2_bs16_compile" \
    --batch_size 16 --gradient_checkpointing 0 \
    --torch_compile 1 --compile_mode reduce-overhead || echo "E2 FAILED"

# E3: BS=24 no compile (test if it fits with 20GB headroom)
run_experiment "E3_bs24_nocompile" \
    --batch_size 24 --gradient_checkpointing 0 || echo "E3 OOM"

# E4: BS=32 + gradient_checkpointing ON (trade speed for larger batch)
run_experiment "E4_bs32_gc" \
    --batch_size 32 --gradient_checkpointing 1 || echo "E4 FAILED"

# E5: BS=48 + gradient_checkpointing ON
run_experiment "E5_bs48_gc" \
    --batch_size 48 --gradient_checkpointing 1 || echo "E5 OOM"

# E6: BS=32 + gradient_checkpointing + torch.compile
run_experiment "E6_bs32_gc_compile" \
    --batch_size 32 --gradient_checkpointing 1 \
    --torch_compile 1 --compile_mode reduce-overhead || echo "E6 FAILED"

echo ""
echo "All benchmarks complete at $(date)"
