#!/bin/bash
# Launch training for Cohere Transcribe finetuning
# Adjust NNODES, NPROC, and the rendezvous endpoint (MASTER_ADDR / MASTER_PORT) as needed for your cluster

# Strict mode: abort on command failure, on use of an unset variable,
# and when any stage of a pipeline fails.
set -o errexit -o nounset -o pipefail

# ============================================================
# Configuration — edit these
# ============================================================
# Each setting keeps a non-empty value already present in the environment
# and otherwise falls back to the default given here.
: "${NNODES:=1}"                           # Number of nodes
: "${NPROC:=8}"                            # GPUs per node
: "${MASTER_ADDR:=localhost}"              # Rendezvous host
: "${MASTER_PORT:=29500}"                  # Rendezvous port
: "${CONFIG:=config.yaml}"                 # Training config file

# ============================================================
# Environment
# ============================================================
# Fixed settings exported to every child process of this script.
export OMP_NUM_THREADS=4 TOKENIZERS_PARALLELISM=false

# Load wandb and other API keys from .env when the file is present.
# allexport is enabled only around the `source`, so every variable the
# file assigns is exported, and is switched off again right afterwards.
if [[ -f /workspace/maya-asr/.env ]]; then
    set -o allexport
    source /workspace/maya-asr/.env
    set +o allexport
    printf '%s\n' "Loaded API keys from /workspace/maya-asr/.env (WANDB_MODE=online)"
fi
# Force wandb online mode (was defaulting to offline without API key).
# This runs after the .env load, so it also overrides any WANDB_MODE set there.
WANDB_MODE=online
export WANDB_MODE

# NCCL tuning
# A non-empty value already exported by the caller wins; the defaults below
# apply otherwise. This lets clusters whose network interface is not eth0
# (or that need different InfiniBand / GPUDirect settings) override them
# without editing this script.
export NCCL_IB_DISABLE="${NCCL_IB_DISABLE:-0}"            # 0 = enable InfiniBand if available
export NCCL_NET_GDR_LEVEL="${NCCL_NET_GDR_LEVEL:-5}"      # GPUDirect RDMA
export NCCL_SOCKET_IFNAME="${NCCL_SOCKET_IFNAME:-eth0}"   # Network interface

# Memory
export PYTORCH_CUDA_ALLOC_CONF="${PYTORCH_CUDA_ALLOC_CONF:-expandable_segments:True}"

# ============================================================
# Pre-flight checks
# ============================================================
# Report the requested cluster shape, then list the name and total memory
# of each GPU that nvidia-smi reports on this node.
printf '%s\n' "=== Pre-flight checks ==="
printf '%s\n' "Nodes: $NNODES, GPUs/node: $NPROC, Total GPUs: $((NNODES * NPROC))"
nvidia-smi --query-gpu=name,memory.total --format=csv,noheader
printf '\n'

# Check data exists
# Fail early with a clear message if the config itself is missing, instead
# of surfacing a Python traceback from the lookup below.
if [[ ! -f "$CONFIG" ]]; then
    echo "ERROR: Config file not found: $CONFIG" >&2
    exit 1
fi

# Read data.mel_shards_dir from the YAML config. The path is handed to
# Python as an argument (sys.argv[1]) rather than spliced into the source
# text, so quotes or other special characters in $CONFIG cannot break or
# alter the Python code.
MEL_DIR=$(python3 -c 'import sys, yaml; print(yaml.safe_load(open(sys.argv[1]))["data"]["mel_shards_dir"])' "$CONFIG")
if [[ ! -d "$MEL_DIR" ]]; then
    echo "ERROR: Mel shards directory not found: $MEL_DIR" >&2
    echo "Run the data preprocessing pipeline first (see data-agent-instructions.md)" >&2
    exit 1
fi

# Count *.tar files anywhere under the shards directory.
SHARD_COUNT=$(find "$MEL_DIR" -name '*.tar' | wc -l)
echo "Found $SHARD_COUNT mel shards in $MEL_DIR"
if (( SHARD_COUNT == 0 )); then
    # Not fatal here: only warn, and leave the decision to the training run.
    echo "WARNING: No *.tar shards found in $MEL_DIR" >&2
fi

# ============================================================
# Launch
# ============================================================
echo ""
echo "=== Launching training ==="
echo "Config: $CONFIG"
echo ""

# Start train.py under torchrun with c10d rendezvous at MASTER_ADDR:MASTER_PORT.
# Every expansion is quoted so that a value containing spaces or glob
# characters (most likely the config path) is passed as a single argument
# instead of being word-split or expanded by the shell.
torchrun \
    --nnodes="$NNODES" \
    --nproc_per_node="$NPROC" \
    --rdzv_backend=c10d \
    --rdzv_endpoint="${MASTER_ADDR}:${MASTER_PORT}" \
    train.py --config "$CONFIG"
