#!/usr/bin/env bash
# Stage 1: Pipeline correctness + profiling smoke test (50 steps)
# Profiles: data loading, forward, backward, optimizer, allreduce
# Strict mode: abort on command failure (-e), on use of an unset variable (-u),
# and when any stage of a pipeline fails (pipefail).
set -euo pipefail

# --- Runtime environment exported to the training processes -----------------
export PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True
export TORCH_NCCL_ASYNC_ERROR_HANDLING=3

# Limit each CPU threading backend to a single thread per process
# (presumably to avoid oversubscription across ranks and dataloader workers).
export OMP_NUM_THREADS=1
export MKL_NUM_THREADS=1
export OPENBLAS_NUM_THREADS=1
export NUMEXPR_NUM_THREADS=1

export NCCL_LAUNCH_MODE=PARALLEL
# Note: NCCL_ALGO=Tree incompatible with AllGather int8 in NCCL 2.26
export TORCHINDUCTOR_CACHE_DIR="/tmp/torch_inductor_cache"

# Prepend /root/data/Qwen3-ASR-official to the Python import path.
# The ":+" expansion emits the ":" separator only when PYTHONPATH is already
# set and non-empty. An unconditional separator would leave a trailing ":"
# (an empty entry), which Python resolves to the current working directory.
export PYTHONPATH="/root/data/Qwen3-ASR-official${PYTHONPATH:+:${PYTHONPATH}}"

# --- User-overridable settings ----------------------------------------------
# Each variable keeps its value from the environment when set and non-empty;
# otherwise the default on the right-hand side is used.
MODEL_PATH="${MODEL_PATH:-/root/data/qwen3_asr_weights}"
TRAIN_FILE="${TRAIN_FILE:-/root/gemini-asr/lf_asr/artifacts/phase2/train.parquet}"
EVAL_FILE="${EVAL_FILE:-/root/gemini-asr/lf_asr/artifacts/phase2/dev.parquet}"
OUTPUT_DIR="${OUTPUT_DIR:-/root/data/qwen3-asr-phase2-smoke}"
BATCH_SIZE="${BATCH_SIZE:-16}"
MAX_STEPS="${MAX_STEPS:-50}"
# Logs default to a "logs" subdirectory of OUTPUT_DIR, so overriding OUTPUT_DIR
# alone relocates the logs as well. With the default OUTPUT_DIR this resolves
# to /root/data/qwen3-asr-phase2-smoke/logs.
LOG_DIR="${LOG_DIR:-${OUTPUT_DIR}/logs}"

# Create both directories up front; "--" guards against paths starting with "-".
mkdir -p -- "$OUTPUT_DIR" "$LOG_DIR"

# Print a summary of the effective settings before launching.
# The GPU count is the number of lines printed by `nvidia-smi -L`. It is kept
# inline as an argument so a missing or failing nvidia-smi does not abort the
# script under `set -e`.
printf '%s\n' \
    "=== Qwen3-ASR Phase2 Smoke Test ===" \
    "Model:      ${MODEL_PATH}" \
    "Train file: ${TRAIN_FILE}" \
    "Output:     ${OUTPUT_DIR}" \
    "Batch size: ${BATCH_SIZE}" \
    "Max steps:  ${MAX_STEPS}" \
    "GPUs:       $(nvidia-smi -L | wc -l)" \
    "====================================="

# Launch the training script under torchrun on a single node and mirror its
# combined stdout/stderr to a timestamped log file in LOG_DIR.
#
# Environment:
#   NPROC_PER_NODE  Number of worker processes to start (default: 8).
# Arguments:
#   "$@"            Extra command-line arguments, appended after the defaults
#                   below and passed to the training script unchanged.
#
# The script path is relative, so this must be run from the directory that
# contains finetuning/.
train_args=(
    # Model, data and output locations
    --model_path "${MODEL_PATH}"
    --train_file "${TRAIN_FILE}"
    --eval_file "${EVAL_FILE}"
    --train_split train
    --eval_split dev
    --output_dir "${OUTPUT_DIR}"

    # Optimisation schedule
    --batch_size "${BATCH_SIZE}"
    --grad_acc 1
    --lr 2e-5
    --max_steps "${MAX_STEPS}"
    --log_steps 5
    # Intervals far above the default MAX_STEPS (50); presumably this keeps
    # checkpointing and evaluation from running during the smoke test.
    # Confirm against the training script.
    --save_steps 999999
    --eval_steps 999999

    # Data loading
    --num_workers 8
    --prefetch_factor 4
    --pin_memory 1
    --persistent_workers 1
    --max_open_tars 32

    # Model / DDP options
    --gradient_checkpointing 0
    --ddp_static_graph 1
    --language_tag_mode auto
    --max_duration_s 30.0
    --seed 42
)

# With `pipefail` (set at the top of the file) a torchrun failure is not
# masked by tee's exit status.
torchrun --standalone --nproc_per_node="${NPROC_PER_NODE:-8}" \
    finetuning/qwen3_asr_sft_phase2.py \
    "${train_args[@]}" \
    "$@" \
    2>&1 | tee "${LOG_DIR}/smoke_$(date +%Y%m%d_%H%M%S).log"
