#!/usr/bin/env bash
# Profile a smoke training run: capture GPU utilization, memory, and step timing.
#
# Usage:
#   bash scripts/profile_smoke_run.sh --max-steps=20 --devices=1

set -euo pipefail

# Run from the project root (the parent of this script's directory) so the
# relative paths used below resolve the same way from any working directory.
SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)"
PROJECT_ROOT="$(dirname "$SCRIPT_DIR")"
cd "$PROJECT_ROOT"

mkdir -p reports

# Take the timestamp once: calling `date` separately for each file can straddle
# a second boundary and leave the two logs of one run with different suffixes.
RUN_STAMP="$(date +%Y%m%d_%H%M%S)"
PROFILE_LOG="reports/gpu_profile_${RUN_STAMP}.csv"
TRAINING_LOG="reports/training_profile_${RUN_STAMP}.txt"

echo "Profiling smoke training run..."
echo "  GPU log: $PROFILE_LOG"
echo "  Training log: $TRAINING_LOG"
echo ""

# Start nvidia-smi sampling in background (1 sample/sec).
# stderr is discarded on purpose: the sampler is best-effort and must not
# pollute the training output.
nvidia-smi --query-gpu=timestamp,index,utilization.gpu,memory.used,memory.total,power.draw \
    --format=csv,nounits -l 1 > "$PROFILE_LOG" 2>/dev/null &
SMI_PID=$!

# Stop the background sampler and reap it. Both steps tolerate a sampler that
# has already exited, so calling this more than once is harmless.
_stop_gpu_sampler() {
    kill "$SMI_PID" 2>/dev/null || true
    wait "$SMI_PID" 2>/dev/null || true
}
# `nvidia-smi -l` loops until killed, so make sure it is stopped even if the
# script aborts before reaching the explicit stop below.
trap _stop_gpu_sampler EXIT

# Run training.
# With `set -e` and `pipefail` a failing pipeline would terminate the script on
# the spot, so the status is captured through `||`. PIPESTATUS[0] is the
# training script's own exit code, independent of how `tee` fared.
TRAIN_EXIT=0
bash scripts/run_stage1_smoke.sh "$@" 2>&1 | tee "$TRAINING_LOG" \
    || TRAIN_EXIT=${PIPESTATUS[0]}

# Stop nvidia-smi
_stop_gpu_sampler
trap - EXIT

if (( TRAIN_EXIT != 0 )); then
    echo "Training failed (exit $TRAIN_EXIT). GPU profile still saved." >&2
    exit "$TRAIN_EXIT"
fi

# Summarize GPU profile

# Print sample count, average/peak GPU utilization and peak/average GPU memory
# for an nvidia-smi CSV log.
# Arguments: $1 - path to the CSV file
# Outputs:   summary lines on stdout; a single "No GPU samples captured." line
#            when the file is missing, unreadable, or has no complete data row
# Returns:   exit status of python3
_summarize_gpu_profile() {
    local profile_csv=$1
    # The path is passed as an argument rather than pasted into the Python
    # source, so quotes or backslashes in it cannot break the program.
    python3 - "$profile_csv" <<'PY'
import csv
import sys


def to_number(text):
    """Return text as a float, or None for values such as '[N/A]'."""
    try:
        return float(text)
    except (TypeError, ValueError):
        return None


def load_rows(path):
    """Read complete data rows, with keys and values stripped of whitespace."""
    rows = []
    try:
        with open(path, newline='', errors='replace') as handle:
            for raw in csv.DictReader(handle, skipinitialspace=True):
                # DictReader pads short rows with None values and files extra
                # fields under a None key. Either means a malformed line, e.g.
                # the last one cut off when the sampler was killed.
                if None in raw or None in raw.values():
                    continue
                rows.append({k.strip(): v.strip() for k, v in raw.items()})
    except OSError:
        return []
    return rows


def numeric_column(rows, name, unit):
    """Collect numeric values of a column, named with or without its unit."""
    values = []
    for row in rows:
        number = to_number(row.get(f'{name} [{unit}]', row.get(name)))
        if number is not None:
            values.append(number)
    return values


rows = load_rows(sys.argv[1])
if not rows:
    print('  No GPU samples captured.')
    sys.exit(0)

utils = numeric_column(rows, 'utilization.gpu', '%')
mems = numeric_column(rows, 'memory.used', 'MiB')
mem_totals = numeric_column(rows, 'memory.total', 'MiB')

print(f'  Samples: {len(rows)}')
if utils:
    print(f'  Avg GPU util: {sum(utils)/len(utils):.1f}%')
    print(f'  Peak GPU util: {max(utils):.0f}%')
else:
    print('  GPU util: n/a')
if mems:
    peak_mem = max(mems)
    peak_total = max(mem_totals) if mem_totals else 0.0
    if peak_total > 0:
        print(f'  Peak GPU memory: {peak_mem:.0f} MiB / {peak_total:.0f} MiB ({100*peak_mem/peak_total:.1f}%)')
    else:
        # No usable total: report the peak without a percentage.
        print(f'  Peak GPU memory: {peak_mem:.0f} MiB')
    print(f'  Avg GPU memory: {sum(mems)/len(mems):.0f} MiB')
else:
    print('  GPU memory: n/a')
PY
}

echo ""
echo "--- GPU Profile Summary ---"
_summarize_gpu_profile "${PROFILE_LOG:-}"

# Extract step timing from training log

# Print the first match of a PCRE pattern found in a log file.
# Arguments: $1 - pattern for `grep -oP` (use \K to drop the label prefix)
#            $2 - path to the log file
# Outputs:   the matched text on stdout, or nothing when there is no match
# Returns:   always 0. A missing metric is not an error, and a non-zero status
#            here would abort the script under `set -e`/`pipefail` after an
#            otherwise successful run.
_first_log_value() {
    local pattern=$1 log_file=$2 value=""
    [[ -r "$log_file" ]] || return 0
    # `|| true` covers both "no match" (grep exits 1) and grep being cut off by
    # SIGPIPE when `head` closes the pipe early.
    value="$(grep -oP -m 1 -- "$pattern" "$log_file" | head -n 1)" || true
    if [[ -n "$value" ]]; then
        printf '%s\n' "$value"
    fi
}

echo ""
echo "--- Step Timing Summary ---"
avg_sec_per_step="$(_first_log_value 'Avg sec/step: \K[\d.]+' "${TRAINING_LOG:-}")"
steps_completed="$(_first_log_value 'Steps: \K\d+' "${TRAINING_LOG:-}")"
elapsed_time="$(_first_log_value 'Elapsed: \K[\d.]+s' "${TRAINING_LOG:-}")"

if [[ -n "$avg_sec_per_step" ]]; then
    echo "  Avg sec/step: ${avg_sec_per_step}"
fi
if [[ -n "$steps_completed" ]]; then
    echo "  Steps completed: ${steps_completed}"
fi
if [[ -n "$elapsed_time" ]]; then
    echo "  Elapsed: ${elapsed_time}"
fi
if [[ -z "${avg_sec_per_step}${steps_completed}${elapsed_time}" ]]; then
    echo "  No step timing found in training log."
fi
