#!/usr/bin/env bash
# Quick smoke test: new worker-decode pipeline with legacy parquet (no buckets needed)
# Fail fast: exit on error / unset variable / failure anywhere in a pipeline.
set -euo pipefail

# Allocator: expandable segments to reduce CUDA memory fragmentation.
export PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True
# NCCL async error handling mode — NOTE(review): valid values are 0-3; confirm 3 is the intended mode.
export TORCH_NCCL_ASYNC_ERROR_HANDLING=3
# Pin all CPU math libraries to a single thread so DataLoader workers
# (8 per rank, see --num_workers below) don't oversubscribe cores.
export OMP_NUM_THREADS=1
export MKL_NUM_THREADS=1
export OPENBLAS_NUM_THREADS=1
export NUMEXPR_NUM_THREADS=1
export NCCL_LAUNCH_MODE=PARALLEL
# Shared inductor cache so compiled artifacts are reused across runs.
export TORCHINDUCTOR_CACHE_DIR="/tmp/torch_inductor_cache"
# Prepend the ASR repo; ${PYTHONPATH:-} keeps `set -u` happy when unset.
export PYTHONPATH="/root/data/Qwen3-ASR-official:${PYTHONPATH:-}"

# --- Configuration ----------------------------------------------------------
MODEL_PATH="/root/data/qwen3_asr_weights"
TRAIN_FILE="/root/gemini-asr/lf_asr/artifacts/phase2/train.parquet"
OUTPUT_DIR="/root/data/qwen3-asr-phase2-quicksmoke"
LOG_DIR="${OUTPUT_DIR}/logs"
STEPS=${1:-50}   # optional arg 1: number of training steps
BS=${2:-32}      # optional arg 2: per-device batch size

mkdir -p "$OUTPUT_DIR" "$LOG_DIR"
# Remove stale checkpoints from previous runs. The glob must stay OUTSIDE
# the quotes so the shell expands it; quoting the whole string (the original
# bug) made rm look for a literal file named 'checkpoint-*' and never
# actually cleaned anything.
rm -rf "$OUTPUT_DIR"/checkpoint-* 2>/dev/null || true

echo "=== Quick Smoke: Worker-Decode Pipeline $(date) ==="
echo "  Steps=$STEPS BS=$BS (GC=1, worker_decode=1, use_indexed_tar=1)"

# Background GPU monitor: sample per-GPU memory and utilization every 5 s
# into a CSV trace; stderr is discarded in case nvidia-smi is unavailable.
(
    while :; do
        nvidia-smi --query-gpu=index,memory.used,utilization.gpu --format=csv,noheader
        sleep 5
    done
) >"$LOG_DIR/gpu_trace.csv" 2>/dev/null &
GPU_MON_PID=$!

# Wall-clock start timestamp (epoch seconds) for the summary below.
T0=$(date +%s)

# Launch 8-GPU DDP training. With `set -e -o pipefail` active, a failing
# torchrun used to abort the whole script at this pipeline — RET was never
# captured and the summary never printed (and when the script did survive,
# `RET=$?` on the next line was always 0). The `|| RET=$?` guard records the
# pipeline status (pipefail makes that torchrun's status, not tee's) while
# letting execution continue to the summary.
RET=0
torchrun --standalone --nproc_per_node=8 \
    finetuning/qwen3_asr_sft_phase2.py \
    --model_path "$MODEL_PATH" \
    --train_file "$TRAIN_FILE" \
    --output_dir "$OUTPUT_DIR" \
    --batch_size "$BS" \
    --grad_acc 1 \
    --lr 2e-5 \
    --max_steps "$STEPS" \
    --log_steps 5 \
    --save_steps 999999 \
    --eval_steps 999999 \
    --num_workers 8 \
    --prefetch_factor 4 \
    --pin_memory 1 \
    --persistent_workers 1 \
    --max_open_tars 64 \
    --use_indexed_tar 1 \
    --worker_decode 1 \
    --gradient_checkpointing 1 \
    --ddp_static_graph 1 \
    --language_tag_mode auto \
    --max_duration_s 30.0 \
    --seed 42 \
    --profiling 1 \
    2>&1 | tee "$LOG_DIR/smoke.log" || RET=$?

T1=$(date +%s)
# Stop the GPU monitor; it may already be gone, so ignore errors, and reap
# it with wait to suppress the shell's "Terminated" job notice.
kill "$GPU_MON_PID" 2>/dev/null || true
wait "$GPU_MON_PID" 2>/dev/null || true

echo ""
echo "=== Quick Smoke Summary ==="
echo "Wall: $((T1-T0))s, Steps: $STEPS, Exit: $RET"

if [[ "$RET" -eq 0 ]]; then
    # grep exits 1 when nothing matches, which would kill the script under
    # `set -e` mid-summary (before the "Done" line); the `|| true` guards
    # make each extraction best-effort.
    grep "\[PROFILING REPORT\]" -A 15 "$LOG_DIR/smoke.log" || true
    grep "'train_samples_per_second'" "$LOG_DIR/smoke.log" | tail -1 || true
    grep "'train_loss'" "$LOG_DIR/smoke.log" | tail -1 || true

    if [[ -f "$LOG_DIR/gpu_trace.csv" ]]; then
        # Peak memory: strip the " MiB" unit, numeric sort, take the max.
        MAX_MEM=$(awk -F', ' '{gsub(/ MiB/,"",$2); print $2}' "$LOG_DIR/gpu_trace.csv" | sort -n | tail -1)
        # Mean utilization; the `if (n)` guard avoids a divide-by-zero when
        # the trace file exists but is empty (e.g. nvidia-smi unavailable).
        AVG_UTIL=$(awk -F', ' '{gsub(/ %/,"",$3); sum+=$3; n++} END{if (n) printf "%.1f", sum/n}' "$LOG_DIR/gpu_trace.csv")
        echo "Peak GPU mem: ${MAX_MEM:-?} MiB | Avg GPU util: ${AVG_UTIL:-?}%"
    fi
fi
echo "=== Done $(date) ==="
