#!/usr/bin/env bash
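# Bucketed phase-2 SFT smoke test on 8 GPUs (torchrun) with profiling enabled;
# uses flash attention by default.
# Usage: $0 [STEPS=200] [BS=16] [GC=0] [ATTN=flash_attention_2] [WORKERS=8]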
# Strict mode: abort on errors, on unset variables, and on failures anywhere
# inside a pipeline.
set -o errexit -o nounset -o pipefail

# PyTorch / NCCL runtime behaviour.
export PYTORCH_CUDA_ALLOC_CONF="expandable_segments:True"
export TORCH_NCCL_ASYNC_ERROR_HANDLING="3"
export NCCL_LAUNCH_MODE="PARALLEL"

# Limit each threading/BLAS runtime to a single thread per process.
export OMP_NUM_THREADS=1 MKL_NUM_THREADS=1 OPENBLAS_NUM_THREADS=1 NUMEXPR_NUM_THREADS=1

# Shared inductor compile cache, and the official Qwen3-ASR sources prepended
# to any pre-existing PYTHONPATH.
export TORCHINDUCTOR_CACHE_DIR="/tmp/torch_inductor_cache"
export PYTHONPATH="/root/data/Qwen3-ASR-official:${PYTHONPATH:-}"

# Positional overrides, each optional (default in parentheses):
#   $1 value for --max_steps               (200)
#   $2 value for --batch_size              (16)
#   $3 value for --gradient_checkpointing  (0)
#   $4 value for --attn_implementation     (flash_attention_2)
#   $5 value for --num_workers             (8)
STEPS="${1:-200}"
BS="${2:-16}"
GC="${3:-0}"
ATTN="${4:-flash_attention_2}"
WORKERS="${5:-8}"

# Fixed locations: model weights, bucketed data, and this run's outputs.
MODEL_PATH="/root/data/qwen3_asr_weights"
BUCKET_DIR="/root/gemini-asr/lf_asr/artifacts/phase2/buckets"
BUCKET_CONFIG="$BUCKET_DIR/bucket_config.json"
OUTPUT_DIR="/root/data/qwen3-asr-phase2-smoke"
LOG_DIR="$OUTPUT_DIR/logs"

# Make sure the output and log directories exist, then clear checkpoint-*
# directories left over from an earlier run (any rm failure is ignored).
mkdir -p -- "$OUTPUT_DIR" "$LOG_DIR"
rm -rf -- "$OUTPUT_DIR"/checkpoint-* 2>/dev/null || true

echo "=== Bucketed Smoke Test $(date) ==="
echo "  Steps=$STEPS BS=$BS GC=$GC attn=$ATTN workers=$WORKERS"

# Background GPU monitor: samples nvidia-smi every 5 s into gpu_trace.csv.
(while true; do
    nvidia-smi --query-gpu=index,memory.used,utilization.gpu,power.draw --format=csv,noheader
    sleep 5
done) > "$LOG_DIR/gpu_trace.csv" 2>/dev/null &
GPU_MON_PID=$!

# Stop the monitor on every exit path. Under `set -e` the script can abort
# before reaching its explicit kill, which would otherwise leave this loop
# polling nvidia-smi indefinitely.
stop_gpu_monitor() {
    kill "$GPU_MON_PID" 2>/dev/null || true
}
trap stop_gpu_monitor EXIT

T0=$(date +%s)

# Capture the pipeline status explicitly. Under `set -euo pipefail` a failing
# torchrun would otherwise abort the script right here, so RET could never be
# non-zero and the summary below would never be printed.
RET=0
torchrun --standalone --nproc_per_node=8 \
    finetuning/qwen3_asr_sft_phase2.py \
    --model_path "$MODEL_PATH" \
    --bucket_dir "$BUCKET_DIR" \
    --bucket_config "$BUCKET_CONFIG" \
    --output_dir "$OUTPUT_DIR" \
    --batch_size "$BS" \
    --grad_acc 1 \
    --lr 2e-5 \
    --max_steps "$STEPS" \
    --log_steps 10 \
    --save_steps 999999 \
    --eval_steps 999999 \
    --num_workers "$WORKERS" \
    --prefetch_factor 4 \
    --pin_memory 1 \
    --persistent_workers 1 \
    --max_open_tars 64 \
    --use_indexed_tar 1 \
    --worker_decode 1 \
    --gradient_checkpointing "$GC" \
    --attn_implementation "$ATTN" \
    --ddp_static_graph 1 \
    --max_batch_seq_len 500 \
    --language_tag_mode auto \
    --seed 42 \
    --profiling 1 \
    2>&1 | tee "$LOG_DIR/smoke.log" || RET=$?

T1=$(date +%s)
kill "$GPU_MON_PID" 2>/dev/null || true

echo ""
echo "=== Smoke Test Summary ==="
echo "Wall time: $((T1-T0))s for $STEPS steps"
echo "Exit code: $RET"

if (( RET == 0 )); then
    # grep exits 1 when a marker is absent; `|| true` keeps a missing line
    # from aborting the summary under `set -e`.
    grep "\[profile\]" "$LOG_DIR/smoke.log" | tail -5 || true
    grep "\[PROFILING REPORT\]" -A 20 "$LOG_DIR/smoke.log" || true
    grep "'train_samples_per_second'" "$LOG_DIR/smoke.log" | tail -1 || true
    grep "'train_loss'" "$LOG_DIR/smoke.log" | tail -1 || true

    GPU_TRACE="$LOG_DIR/gpu_trace.csv"
    if [[ -f "$GPU_TRACE" ]]; then
        # Fields are split on ", "; column 2 carries a " MiB" suffix and
        # column 3 a " %" suffix, both stripped before use.
        MAX_MEM=$(awk -F', ' '{gsub(/ MiB/,"",$2); print $2}' "$GPU_TRACE" | sort -n | tail -1)
        # Guard n == 0 (empty trace): awk's division by zero is fatal and the
        # failed substitution would abort the script under `set -e`.
        AVG_UTIL=$(awk -F', ' '{gsub(/ %/,"",$3); sum+=$3; n++} END{if (n) printf "%.1f", sum/n}' "$GPU_TRACE")
        echo "Peak GPU memory: ${MAX_MEM:-?} MiB"
        echo "Avg GPU utilization: ${AVG_UTIL:-?}%"
    fi
fi

echo "=== Done $(date) ==="

# Propagate torchrun's status so callers still observe a failed run.
exit "$RET"
