#!/usr/bin/env bash
# 1000-step canary: full production config with checkpoint + eval + resume test
set -euo pipefail

# CUDA allocator / NCCL / inductor runtime knobs.
export PYTORCH_CUDA_ALLOC_CONF="expandable_segments:True"
export TORCH_NCCL_ASYNC_ERROR_HANDLING=3
export NCCL_LAUNCH_MODE=PARALLEL
export TORCHINDUCTOR_CACHE_DIR="/tmp/torch_inductor_cache"

# Pin every CPU math library to one thread: DataLoader workers supply the
# parallelism, and BLAS oversubscription across 8 ranks hurts throughput.
for _threads_var in OMP_NUM_THREADS MKL_NUM_THREADS OPENBLAS_NUM_THREADS NUMEXPR_NUM_THREADS; do
    export "${_threads_var}=1"
done
unset _threads_var

# Make the official Qwen3-ASR checkout importable; preserve any existing path.
export PYTHONPATH="/root/data/Qwen3-ASR-official:${PYTHONPATH:-}"

# --- Paths -------------------------------------------------------------------
MODEL_PATH="/root/data/qwen3_asr_weights"
BUCKET_DIR="/root/gemini-asr/lf_asr/artifacts/phase2/buckets"
BUCKET_CONFIG="${BUCKET_DIR}/bucket_config.json"
VAL_FILE="${BUCKET_DIR}/val_holdout.parquet"
OUTPUT_DIR="/root/data/qwen3-asr-phase2-canary"
LOG_DIR="${OUTPUT_DIR}/logs"

mkdir -p "$OUTPUT_DIR" "$LOG_DIR"
# Start from a clean slate. ${OUTPUT_DIR:?} aborts instead of expanding to a
# catastrophic `rm -rf /checkpoint-*` should the variable ever be emptied by
# a future edit; the redirect/|| true covers the harmless no-match case.
rm -rf "${OUTPUT_DIR:?}"/checkpoint-* 2>/dev/null || true

echo "=== 1K-Step Canary Run $(date) ==="
echo "  Aggressive dynamic BS, no GC, FA2, checkpoint+eval enabled"

# Background GPU sampler: one CSV row per GPU every 5 s into gpu_trace.csv.
(
    while :; do
        nvidia-smi --query-gpu=index,memory.used,utilization.gpu --format=csv,noheader
        sleep 5
    done
) > "$LOG_DIR/gpu_trace.csv" 2>/dev/null &
GPU_MON_PID=$!
# Ensure the sampler dies with the script, on every exit path.
trap 'kill "$GPU_MON_PID" 2>/dev/null || true' EXIT

T0=$(date +%s)

# Run training and capture its status WITHOUT letting `set -e` kill the
# script: with pipefail the pipeline status is torchrun's (tee rarely fails),
# and `|| RET=$?` folds a failure into RET instead of aborting. The previous
# `RET=$?` after the pipeline could never observe a failure — `set -e` would
# already have exited before reaching it, making the summary dead code.
RET=0
torchrun --standalone --nproc_per_node=8 \
    finetuning/qwen3_asr_sft_phase2.py \
    --model_path "$MODEL_PATH" \
    --bucket_dir "$BUCKET_DIR" \
    --bucket_config "$BUCKET_CONFIG" \
    --eval_file "$VAL_FILE" \
    --output_dir "$OUTPUT_DIR" \
    --batch_size 16 \
    --grad_acc 1 \
    --lr 2e-5 \
    --warmup_ratio 0.02 \
    --lr_scheduler_type cosine \
    --max_steps 1000 \
    --log_steps 50 \
    --save_steps 500 \
    --save_total_limit 3 \
    --eval_steps 500 \
    --num_workers 8 \
    --prefetch_factor 4 \
    --pin_memory 1 \
    --persistent_workers 1 \
    --max_open_tars 64 \
    --use_indexed_tar 1 \
    --worker_decode 1 \
    --gradient_checkpointing 0 \
    --attn_implementation flash_attention_2 \
    --ddp_static_graph 1 \
    --max_batch_seq_len 500 \
    --language_tag_mode auto \
    --seed 42 \
    --profiling 1 \
    2>&1 | tee "$LOG_DIR/canary.log" || RET=$?

T1=$(date +%s)

echo ""
echo "=== Canary Summary ==="
echo "Wall time: $((T1-T0))s for 1000 steps"
echo "Exit code: $RET"

#######################################
# Print the checkpoint directory with the highest numeric step suffix.
# The previous `ls | sort -t- -k2 -n` keyed on the 2nd '-'-separated field
# of the FULL path ("asr" in .../qwen3-asr-phase2-canary/checkpoint-N), so
# the numeric sort was a no-op and the lexical fallback ranked
# checkpoint-500 above checkpoint-1000 — resuming from the wrong checkpoint.
# Arguments: $1 - directory containing checkpoint-<step> subdirectories
# Outputs:   best path on stdout (empty if none found); always returns 0
#######################################
find_latest_checkpoint() {
    local root=$1 best="" best_step=-1 dir step
    for dir in "$root"/checkpoint-*/; do
        [ -d "$dir" ] || continue          # glob may stay a literal pattern
        dir=${dir%/}
        step=${dir##*-}
        if [[ "$step" =~ ^[0-9]+$ ]] && (( step > best_step )); then
            best_step=$step
            best=$dir
        fi
    done
    printf '%s' "$best"
}

# ${RET:-1} treats an unset RET as failure; identical in-file behavior.
if [ "${RET:-1}" -eq 0 ]; then
    # Profiling. Both greps are best-effort: under `set -e`/pipefail a
    # missing pattern would otherwise abort the whole summary.
    grep "\[PROFILING REPORT\]" -A 20 "$LOG_DIR/canary.log" || true
    grep "'train_loss'" "$LOG_DIR/canary.log" | tail -1 || true

    # Checkpoint verification
    echo ""
    echo "=== Checkpoint Verification ==="
    ls -lh "$OUTPUT_DIR"/checkpoint-*/model.safetensors 2>/dev/null || echo "No checkpoints found!"

    # GPU stats
    if [ -f "$LOG_DIR/gpu_trace.csv" ]; then
        MAX_MEM=$(awk -F', ' '{gsub(/ MiB/,"",$2); print $2}' "$LOG_DIR/gpu_trace.csv" | sort -n | tail -1)
        # Guard n==0: awk would die on divide-by-zero for an empty trace file.
        AVG_UTIL=$(awk -F', ' '{gsub(/ %/,"",$3); sum+=$3; n++} END{if (n) printf "%.1f", sum/n}' "$LOG_DIR/gpu_trace.csv")
        echo "Peak GPU memory: ${MAX_MEM:-?} MiB"
        echo "Avg GPU utilization: ${AVG_UTIL:-?}%"
    fi

    # Resume test: run 10 more steps from checkpoint
    echo ""
    echo "=== Resume Test (10 steps from latest checkpoint) ==="
    LATEST_CKPT=$(find_latest_checkpoint "$OUTPUT_DIR")
    if [ -n "$LATEST_CKPT" ]; then
        echo "Resuming from: $LATEST_CKPT"
        # Same `|| RESUME_RET=$?` idiom as the main run: a plain `$?` after
        # the pipeline is unreachable on failure under `set -e`.
        RESUME_RET=0
        torchrun --standalone --nproc_per_node=8 \
            finetuning/qwen3_asr_sft_phase2.py \
            --model_path "$MODEL_PATH" \
            --bucket_dir "$BUCKET_DIR" \
            --bucket_config "$BUCKET_CONFIG" \
            --output_dir "$OUTPUT_DIR" \
            --batch_size 16 --grad_acc 1 \
            --lr 2e-5 --warmup_ratio 0.02 --lr_scheduler_type cosine \
            --max_steps 1010 \
            --log_steps 5 \
            --save_steps 999999 --eval_steps 999999 \
            --num_workers 8 --prefetch_factor 4 \
            --pin_memory 1 --persistent_workers 1 \
            --max_open_tars 64 --use_indexed_tar 1 --worker_decode 1 \
            --gradient_checkpointing 0 \
            --attn_implementation flash_attention_2 \
            --ddp_static_graph 1 \
            --max_batch_seq_len 500 \
            --language_tag_mode auto --seed 42 \
            --resume 1 \
            2>&1 | tee "$LOG_DIR/resume_test.log" || RESUME_RET=$?
        echo "Resume exit code: $RESUME_RET"
        grep "'train_loss'" "$LOG_DIR/resume_test.log" | tail -1 || true
    else
        echo "No checkpoint to resume from!"
    fi
fi

echo "=== Done $(date) ==="
