#!/usr/bin/env bash
# Production training launch — Qwen3-ASR Phase 2
# Expected duration: ~4.5 days
# Run from the directory that contains finetuning/ (the trainer path is relative).
set -euo pipefail

# --- Runtime environment (exported, so inherited by every rank torchrun spawns) ---

# Let the CUDA caching allocator use expandable segments to limit fragmentation
# when batch shapes vary between steps (e.g. bucketed, variable-length audio).
export PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True
# Tear ranks down on asynchronous NCCL errors/timeouts instead of hanging.
# NOTE(review): 3 = "tear down without communicator cleanup" per the PyTorch
# ProcessGroupNCCL env-var docs — confirm against the installed torch version.
export TORCH_NCCL_ASYNC_ERROR_HANDLING=3
# One math-library thread per process: 8 ranks, each with DataLoader workers,
# would otherwise oversubscribe the node's CPUs.
export OMP_NUM_THREADS=1
export MKL_NUM_THREADS=1
export OPENBLAS_NUM_THREADS=1
export NUMEXPR_NUM_THREADS=1
# NCCL kernel launch mode (see the NCCL environment-variable documentation).
export NCCL_LAUNCH_MODE=PARALLEL
# Node-local cache for TorchInductor compiled kernels, if the trainer uses torch.compile.
export TORCHINDUCTOR_CACHE_DIR="/tmp/torch_inductor_cache"
# Prepend the Qwen3-ASR-official checkout so the trainer can import its modules;
# any pre-existing PYTHONPATH entries are kept after it.
export PYTHONPATH="/root/data/Qwen3-ASR-official:${PYTHONPATH:-}"
# Weights & Biases mode: defaults to online, override with e.g. WANDB_MODE=offline.
export WANDB_MODE="${WANDB_MODE:-online}"
# Trainer output is piped through tee into train.log, so Python would otherwise
# block-buffer stdout: the log lags behind, and whatever is still buffered is
# lost when a rank is killed (e.g. by the NCCL teardown configured above).
export PYTHONUNBUFFERED=1

# Inputs and outputs of the run.
readonly MODEL_PATH="/root/data/qwen3_asr_weights"
readonly BUCKET_DIR="/root/gemini-asr/lf_asr/artifacts/phase2/buckets"
readonly BUCKET_CONFIG="${BUCKET_DIR}/bucket_config.json"
readonly OUTPUT_DIR="/root/data/qwen3-asr-phase2-out"

# Fail fast on missing inputs. This is a multi-day job, and without these
# checks a bad path is only reported from inside the training ranks after
# torchrun has already started them.
if [[ ! -e "$MODEL_PATH" ]]; then
  printf 'error: model path not found: %s\n' "$MODEL_PATH" >&2
  exit 1
fi
if [[ ! -d "$BUCKET_DIR" ]]; then
  printf 'error: bucket directory not found: %s\n' "$BUCKET_DIR" >&2
  exit 1
fi
if [[ ! -f "$BUCKET_CONFIG" ]]; then
  printf 'error: bucket config not found: %s\n' "$BUCKET_CONFIG" >&2
  exit 1
fi

mkdir -p -- "$OUTPUT_DIR"

# Startup banner on stdout (printed before the trainer's output is tee'd to the log).
printf '%s\n' \
  "============================================" \
  "  Qwen3-ASR Phase 2 — Production Training" \
  "  $(date)" \
  "  Output: $OUTPUT_DIR" \
  "============================================"

# Trainer entry point, relative to the current working directory: this script
# must be launched from the directory that contains finetuning/.
TRAIN_SCRIPT="finetuning/qwen3_asr_sft_phase2.py"
if [[ ! -f "$TRAIN_SCRIPT" ]]; then
  printf 'error: %s not found under %s (run from the directory containing finetuning/)\n' \
    "$TRAIN_SCRIPT" "$PWD" >&2
  exit 1
fi

# Trainer arguments, one flag per line.
# NOTE(review): --profiling 1 is enabled for a production run; confirm its
# overhead is intended.
train_args=(
  --model_path "$MODEL_PATH"
  --bucket_dir "$BUCKET_DIR"
  --bucket_config "$BUCKET_CONFIG"
  --output_dir "$OUTPUT_DIR"
  --batch_size 16
  --grad_acc 1
  --lr 2e-5
  --warmup_ratio 0.02
  --lr_scheduler_type cosine
  --epochs 1
  --log_steps 50
  --save_steps 5000
  --save_total_limit 5
  --num_workers 4
  --prefetch_factor 4
  --pin_memory 1
  --persistent_workers 1
  --max_open_tars 64
  --use_indexed_tar 1
  --worker_decode 1
  --gradient_checkpointing 0
  --attn_implementation flash_attention_2
  --ddp_static_graph 1
  --max_batch_seq_len 700
  --language_tag_mode auto
  --seed 42
  --profiling 1
  --wandb_project maya-asr
  --wandb_run_name qwen3-asr-1.7b
  --resume 0
)
launch_cmd=(torchrun --standalone --nproc_per_node=8 "$TRAIN_SCRIPT" "${train_args[@]}")
LOG_FILE="$OUTPUT_DIR/train.log"

# Append (tee -a) instead of truncating, so a relaunch after a crash keeps the
# failed attempt's log. Each launch is marked with a dated header and the exact
# shell-quoted command line.
{
  printf '\n===== launch: %s =====\n' "$(date)"
  printf '%q ' "${launch_cmd[@]}"
  printf '\n'
} | tee -a "$LOG_FILE"

# Under `set -o pipefail` a torchrun failure fails the script even though tee succeeds.
"${launch_cmd[@]}" 2>&1 | tee -a "$LOG_FILE"
