#!/usr/bin/env bash
# Benchmark series: batch-size (BS) sweep + torch.compile, with per-run GPU profiling via nvidia-smi sampling
set -euo pipefail

# CUDA allocator / NCCL behaviour for every torchrun job launched below.
export PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True
export TORCH_NCCL_ASYNC_ERROR_HANDLING=3
# One math-library thread per process: torchrun starts several ranks on this
# host, and oversubscribing CPU threads would distort the timings.
export OMP_NUM_THREADS=1
export MKL_NUM_THREADS=1
export OPENBLAS_NUM_THREADS=1
export NUMEXPR_NUM_THREADS=1
export NCCL_LAUNCH_MODE=PARALLEL
# Fixed Inductor cache location, shared by all runs of this script.
export TORCHINDUCTOR_CACHE_DIR="/tmp/torch_inductor_cache"
# Prepend the Qwen3-ASR sources. The ':' separator is only added when
# PYTHONPATH is already non-empty: an empty path entry (trailing ':') would
# make Python search the current working directory.
export PYTHONPATH="/root/data/Qwen3-ASR-official${PYTHONPATH:+:${PYTHONPATH}}"

# Inputs and outputs shared by every experiment; read-only so nothing below
# can accidentally repoint them mid-series.
readonly MODEL_PATH="/root/data/qwen3_asr_weights"
readonly TRAIN_FILE="/root/gemini-asr/lf_asr/artifacts/phase2/train.parquet"
readonly OUTPUT_DIR="/root/data/qwen3-asr-phase2-bench"
# Logs live under OUTPUT_DIR; derived so the two paths cannot drift apart.
readonly LOG_DIR="${OUTPUT_DIR}/logs"
# Training steps per experiment (passed as --max_steps and used as the
# divisor for the reported average step time).
readonly STEPS=100

mkdir -p "$OUTPUT_DIR" "$LOG_DIR"

run_experiment() {
    local EXP_ID="$1"
    local BS="$2"
    local COMPILE_FLAG="$3"
    local EXTRA_FLAGS="${4:-}"

    echo ""
    echo "========================================"
    echo "  Experiment: $EXP_ID"
    echo "  BS=$BS, compile=$COMPILE_FLAG"
    echo "  Extra: $EXTRA_FLAGS"
    echo "========================================"

    # Clean up from previous run
    rm -rf "$OUTPUT_DIR/checkpoint-*" 2>/dev/null || true

    # GPU snapshot before
    nvidia-smi --query-gpu=index,memory.used,utilization.gpu,power.draw --format=csv,noheader > "$LOG_DIR/${EXP_ID}_gpu_pre.csv"

    # Background GPU monitor (sample every 5s)
    (while true; do
        nvidia-smi --query-gpu=index,memory.used,utilization.gpu,power.draw --format=csv,noheader
        sleep 5
    done) > "$LOG_DIR/${EXP_ID}_gpu_trace.csv" 2>/dev/null &
    GPU_MON_PID=$!

    START_TIME=$(date +%s)

    torchrun --standalone --nproc_per_node=8 \
        finetuning/qwen3_asr_sft_phase2.py \
        --model_path "${MODEL_PATH}" \
        --train_file "${TRAIN_FILE}" \
        --output_dir "${OUTPUT_DIR}" \
        --batch_size "${BS}" \
        --grad_acc 1 \
        --lr 2e-5 \
        --max_steps "${STEPS}" \
        --log_steps 10 \
        --save_steps 999999 \
        --eval_steps 999999 \
        --num_workers 8 \
        --prefetch_factor 4 \
        --pin_memory 1 \
        --persistent_workers 1 \
        --max_open_tars 32 \
        --gradient_checkpointing 0 \
        --ddp_static_graph 1 \
        --language_tag_mode auto \
        --max_duration_s 30.0 \
        --seed 42 \
        ${COMPILE_FLAG} \
        ${EXTRA_FLAGS} \
        2>&1 | tee "$LOG_DIR/${EXP_ID}.log"

    END_TIME=$(date +%s)
    ELAPSED=$((END_TIME - START_TIME))

    # Stop GPU monitor
    kill $GPU_MON_PID 2>/dev/null || true

    # GPU snapshot after
    nvidia-smi --query-gpu=index,memory.used,utilization.gpu,power.draw --format=csv,noheader > "$LOG_DIR/${EXP_ID}_gpu_post.csv"

    # Extract metrics
    echo ""
    echo "--- $EXP_ID Results ---"
    echo "Wall time: ${ELAPSED}s"
    echo "Steps: $STEPS"
    echo "Avg step time: $(echo "scale=3; $ELAPSED / $STEPS" | bc)s"

    # Extract peak memory from GPU trace
    if [ -f "$LOG_DIR/${EXP_ID}_gpu_trace.csv" ]; then
        MAX_MEM=$(awk -F', ' '{print $2}' "$LOG_DIR/${EXP_ID}_gpu_trace.csv" | sed 's/ MiB//' | sort -n | tail -1)
        echo "Peak GPU memory: ${MAX_MEM} MiB"
    fi

    # Extract loss from log
    grep -oP "'loss': [\d.]+" "$LOG_DIR/${EXP_ID}.log" | tail -1
    grep -oP "'train_samples_per_second': [\d.]+" "$LOG_DIR/${EXP_ID}.log" | tail -1
    grep -oP "'train_steps_per_second': [\d.]+" "$LOG_DIR/${EXP_ID}.log" | tail -1
    echo "========================="
}

printf 'Starting benchmark series at %s\n' "$(date)"

# Flag string shared by the torch.compile experiments (word-split by
# run_experiment into separate trainer arguments).
compile_ro="--torch_compile 1 --compile_mode reduce-overhead"

# Uncompiled runs: E1 is the BS=16 baseline, E2 doubles the batch.
run_experiment "E1_bs16_nocompile" 16 ""
run_experiment "E2_bs32_nocompile" 32 ""

# Same two batch sizes with torch.compile in reduce-overhead mode.
run_experiment "E3_bs16_compile_ro" 16 "$compile_ro"
run_experiment "E4_bs32_compile_ro" 32 "$compile_ro"

# E5 probes the memory ceiling at BS=64 without compile; a non-zero status
# from run_experiment is reported here instead of aborting the series.
run_experiment "E5_bs64_nocompile" 64 "" || echo "E5 FAILED (likely OOM)"

printf '\nAll benchmarks complete at %s\n' "$(date)"
printf 'Logs in: %s\n' "$LOG_DIR"
