#!/usr/bin/env bash
# Stage 1 Smoke Run: Validate FastConformer encoder training pipeline
#
# Preflight automatically ensures all artifacts exist and are fresh:
#   1. Check required Python modules
#   2. Smoke manifest (build if missing)
#   3. Train/val split manifests (rebuild if source is newer)
#   4. Tokenizer (rebuild if train manifest is newer)
#   5. Input cfg (rebuild if train manifest is newer)
#   6. Extracted audio (rebuild if train/val manifests are newer)
#
# Usage:
#   bash scripts/run_stage1_smoke.sh                              # full smoke run
#   bash scripts/run_stage1_smoke.sh --dry-run                    # preflight only
#   bash scripts/run_stage1_smoke.sh --max-steps=10 --devices=2   # multi-GPU shakedown
#   bash scripts/run_stage1_smoke.sh --force-rebuild              # rebuild all artifacts

set -euo pipefail

# Resolve the repo root from this script's location and run from there so
# all relative artifact paths below are stable regardless of invocation cwd.
SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)"
PROJECT_ROOT="$(dirname "$SCRIPT_DIR")"
cd "$PROJECT_ROOT"

# Detect Python interpreter: PYTHON env var > python > python3
if [ -n "${PYTHON:-}" ] && command -v "$PYTHON" &>/dev/null; then
    PY="$PYTHON"
elif command -v python &>/dev/null; then
    PY=python
elif command -v python3 &>/dev/null; then
    PY=python3
else
    # Diagnostics belong on stderr so stdout stays a clean status log.
    echo "FAIL: Neither 'python' nor 'python3' found in PATH" >&2
    exit 1
fi

# --- Artifact paths (all relative to PROJECT_ROOT) ---
MANIFEST_DIR="data/manifests"
CONFIG="configs/train/stage1_smoke.yaml"
MANIFEST="$MANIFEST_DIR/smoke.jsonl"
TRAIN_MANIFEST="$MANIFEST_DIR/smoke_train.jsonl"
VAL_MANIFEST="$MANIFEST_DIR/smoke_val.jsonl"
SPLIT_METADATA="$MANIFEST_DIR/split_metadata.json"
TRAIN_EXTRACTED="$MANIFEST_DIR/smoke_train_extracted.jsonl"
VAL_EXTRACTED="$MANIFEST_DIR/smoke_val_extracted.jsonl"
# Metadata sidecars share the extracted manifests' basename.
TRAIN_EXTRACT_META="${TRAIN_EXTRACTED%.jsonl}.meta.json"
VAL_EXTRACT_META="${VAL_EXTRACTED%.jsonl}.meta.json"
AUDIO_CACHE="data/audio_cache"
TOKENIZER_DIR="tokenizers/smoke_bpe"
TOKENIZER_MODEL="$TOKENIZER_DIR/tokenizer.model"
TOKENIZER_METADATA="$TOKENIZER_DIR/metadata.json"
INPUT_CFG="configs/data/stage1_smoke_input_cfg.yaml"
INPUT_CFG_METADATA="configs/data/input_cfg_metadata.json"
EXP_DIR="experiments/stage1_smoke"

# --- CLI-tunable flags (overridden by the argument parser below) ---
DRY_RUN=false
FORCE_REBUILD=false
MAX_STEPS=""
DEVICES=""

# Parse args.
# parse_args: map CLI flags onto the globals DRY_RUN, FORCE_REBUILD,
# MAX_STEPS and DEVICES. Exits non-zero (on stderr) for malformed or
# unknown flags.
parse_args() {
    local arg
    for arg in "$@"; do
        case $arg in
            --dry-run) DRY_RUN=true ;;
            --force-rebuild) FORCE_REBUILD=true ;;
            --max-steps=*) MAX_STEPS="${arg#*=}" ;;
            --max-steps) echo "ERROR: --max-steps requires a value (e.g. --max-steps=20)" >&2; exit 1 ;;
            --devices=*) DEVICES="${arg#*=}" ;;
            --devices) echo "ERROR: --devices requires a value (e.g. --devices=2)" >&2; exit 1 ;;
            *) echo "Unknown arg: $arg" >&2; exit 1 ;;
        esac
    done
}
parse_args "$@"

echo "============================================"
echo "  Maya ASR - Stage 1 Smoke Run"
echo "============================================"
echo ""

# --- Step 0: Runtime info + dependency check ---
echo "--- Preflight: Runtime & Dependencies ---"
echo "  interpreter: $(command -v "$PY")"
# Informational only: stderr is already discarded, so also neutralize the
# exit status — under 'set -o pipefail' a failing info script would
# otherwise abort the whole run without any visible message.
$PY scripts/print_runtime_info.py 2>/dev/null | sed 's/^/  /' || true

# Modules needed by the preflight tooling itself (training deps checked later).
MISSING_MODULES=""
for mod in pyarrow yaml sentencepiece omegaconf; do
    if ! $PY -c "import $mod" 2>/dev/null; then
        MISSING_MODULES="$MISSING_MODULES $mod"
    fi
done

if [ -n "$MISSING_MODULES" ]; then
    # Diagnostics go to stderr so stdout remains a clean status log.
    echo "FAIL: Missing required Python modules:$MISSING_MODULES" >&2
    echo "" >&2
    echo "Fix with:" >&2
    echo "  make setup" >&2
    echo "  # or: $PY -m pip install -e \".[dev]\"" >&2
    exit 1
fi
echo "OK: Required modules (pyarrow, yaml, sentencepiece, omegaconf)"

# Check training deps (fail fast if not dry-run).
# Record availability of each heavyweight training dependency; a dry run
# may proceed without them, a real run aborts immediately.
TORCH_AVAILABLE=false
NEMO_AVAILABLE=false
if $PY -c "import torch" 2>/dev/null; then TORCH_AVAILABLE=true; fi
if $PY -c "import nemo" 2>/dev/null; then NEMO_AVAILABLE=true; fi

if [ "$DRY_RUN" = false ]; then
    for probe in "torch:$TORCH_AVAILABLE" "nemo:$NEMO_AVAILABLE"; do
        dep_name=${probe%%:*}
        dep_ok=${probe##*:}
        if [ "$dep_ok" = false ]; then
            echo "FAIL: $dep_name not installed (required for training)"
            echo "  Fix: pip install -e '.[train,dev]'"
            exit 1
        fi
    done
fi

echo ""
echo "--- Preflight: Artifact Check ---"

if [ "$FORCE_REBUILD" = true ]; then
    echo "  (--force-rebuild: will rebuild split + tokenizer + input_cfg + extracted train/val)"
fi

# --- Step 1: Ensure smoke manifest exists ---
# The source manifest is only built when absent; freshness of downstream
# artifacts is handled by the dedicated checks below.
[ -f "$MANIFEST" ] || {
    echo "  Building smoke manifest..."
    $PY scripts/build_manifest.py --languages en hi --max-shards 1 --output "$MANIFEST"
}
MANIFEST_LINES=$(wc -l < "$MANIFEST")
echo "OK: Manifest        $MANIFEST ($MANIFEST_LINES rows)"

# --- Step 2: Ensure train/val split manifests are fresh ---
# Decide first *why* a rebuild is needed (the reason doubles as the log
# line), then rebuild once at the end if any trigger fired.
SPLIT_REASON=""
if [ "$FORCE_REBUILD" = true ]; then
    SPLIT_REASON="  Rebuilding split (--force-rebuild)..."
elif [ ! -f "$TRAIN_MANIFEST" ] || [ ! -f "$VAL_MANIFEST" ]; then
    SPLIT_REASON="  Split manifests missing, building..."
elif [ "$MANIFEST" -nt "$TRAIN_MANIFEST" ]; then
    SPLIT_REASON="  Rebuilding split (source manifest is newer)..."
elif [ ! -f "$SPLIT_METADATA" ]; then
    SPLIT_REASON="  Rebuilding split (split metadata missing)..."
else
    # Timestamps look fine — fall back to comparing content hashes.
    CURRENT_SHA=$($PY -c "
from maya_asr.config import file_sha256
from pathlib import Path
print(file_sha256(Path('$MANIFEST')))
")
    STORED_SHA=$($PY -c "
import json
meta = json.load(open('$SPLIT_METADATA'))
print(meta.get('source_manifest_sha256', ''))
")
    if [ "$CURRENT_SHA" != "$STORED_SHA" ]; then
        SPLIT_REASON="  Rebuilding split (source manifest content changed)..."
    fi
fi

if [ -n "$SPLIT_REASON" ]; then
    echo "$SPLIT_REASON"
    $PY scripts/split_manifest.py \
        --input "$MANIFEST" \
        --train-output "$TRAIN_MANIFEST" \
        --val-output "$VAL_MANIFEST" \
        --val-ratio 0.01 --seed 42
fi
TRAIN_LINES=$(wc -l < "$TRAIN_MANIFEST")
VAL_LINES=$(wc -l < "$VAL_MANIFEST")
echo "OK: Train manifest  $TRAIN_MANIFEST ($TRAIN_LINES rows)"
echo "OK: Val manifest    $VAL_MANIFEST ($VAL_LINES rows)"

# --- Step 3: Ensure tokenizer is fresh ---
# Decide first *why* a rebuild is needed (the reason doubles as the log
# line), then retrain the tokenizer once if any trigger fired.
TOK_REASON=""
if [ "$FORCE_REBUILD" = true ]; then
    TOK_REASON="  Rebuilding tokenizer (--force-rebuild)..."
elif [ ! -f "$TOKENIZER_MODEL" ]; then
    TOK_REASON="  Tokenizer missing, building..."
elif [ "$TRAIN_MANIFEST" -nt "$TOKENIZER_MODEL" ]; then
    TOK_REASON="  Rebuilding tokenizer (train manifest is newer)..."
elif [ ! -f "$TOKENIZER_METADATA" ]; then
    TOK_REASON="  Rebuilding tokenizer (metadata missing)..."
else
    # Timestamps look fine — fall back to comparing content hashes.
    TRAIN_SHA=$($PY -c "
from maya_asr.config import file_sha256
from pathlib import Path
print(file_sha256(Path('$TRAIN_MANIFEST')))
")
    TOK_SHA=$($PY -c "
import json
meta = json.load(open('$TOKENIZER_METADATA'))
print(meta.get('source_manifest_sha256', ''))
")
    if [ "$TRAIN_SHA" != "$TOK_SHA" ]; then
        TOK_REASON="  Rebuilding tokenizer (train manifest content changed)..."
    fi
fi

if [ -n "$TOK_REASON" ]; then
    echo "$TOK_REASON"
    $PY scripts/build_tokenizer.py \
        --manifest "$TRAIN_MANIFEST" \
        --output-dir "$TOKENIZER_DIR" \
        --vocab-size 512
fi
echo "OK: Tokenizer       $TOKENIZER_MODEL"

# --- Step 4: Ensure input_cfg is fresh ---
# Freshness cascade: forced rebuild > file missing > mtime newer > metadata
# missing > content-hash mismatch. Unlike the split/tokenizer metadata, the
# input_cfg metadata stores one hash per *absolute* source-manifest path,
# so the lookup below is keyed by the resolved path.
NEED_INPUT_CFG=false
if [ "$FORCE_REBUILD" = true ]; then
    NEED_INPUT_CFG=true
    echo "  Rebuilding input_cfg (--force-rebuild)..."
elif [ ! -f "$INPUT_CFG" ]; then
    NEED_INPUT_CFG=true
    echo "  Input cfg missing, building..."
elif [ "$TRAIN_MANIFEST" -nt "$INPUT_CFG" ]; then
    NEED_INPUT_CFG=true
    echo "  Rebuilding input_cfg (train manifest is newer)..."
elif [ ! -f "$INPUT_CFG_METADATA" ]; then
    NEED_INPUT_CFG=true
    echo "  Rebuilding input_cfg (metadata missing)..."
else
    # Timestamps look fine — compare the manifest's current hash with the
    # hash recorded at generation time, keyed by its absolute path.
    TRAIN_SHA=$($PY -c "
from maya_asr.config import file_sha256
from pathlib import Path
print(file_sha256(Path('$TRAIN_MANIFEST')))
")
    TRAIN_ABS=$($PY -c "from pathlib import Path; print(Path('$TRAIN_MANIFEST').resolve())")
    CFG_SHA=$($PY -c "
import json
meta = json.load(open('$INPUT_CFG_METADATA'))
hashes = meta.get('source_manifest_sha256', {})
print(hashes.get('$TRAIN_ABS', ''))
")
    # A missing key yields an empty CFG_SHA, which also triggers a rebuild.
    if [ "$TRAIN_SHA" != "$CFG_SHA" ]; then
        NEED_INPUT_CFG=true
        echo "  Rebuilding input_cfg (train manifest content changed)..."
    fi
fi

if [ "$NEED_INPUT_CFG" = true ]; then
    $PY scripts/generate_input_cfg.py \
        --manifests "$TRAIN_MANIFEST" \
        --output "$INPUT_CFG"
fi
echo "OK: Input cfg       $INPUT_CFG"

# --- Step 5: Ensure extracted audio manifests are fresh (train) ---
# Decide first *why* a rebuild is needed (the reason doubles as the log
# line), then re-extract once if any trigger fired.
EXTRACT_TRAIN_REASON=""
if [ "$FORCE_REBUILD" = true ]; then
    EXTRACT_TRAIN_REASON="  Rebuilding extracted train audio (--force-rebuild)..."
elif [ ! -f "$TRAIN_EXTRACTED" ]; then
    EXTRACT_TRAIN_REASON="  Extracted train manifest missing, building..."
elif [ "$TRAIN_MANIFEST" -nt "$TRAIN_EXTRACTED" ]; then
    EXTRACT_TRAIN_REASON="  Rebuilding extracted train audio (source is newer)..."
elif [ ! -f "$TRAIN_EXTRACT_META" ]; then
    EXTRACT_TRAIN_REASON="  Rebuilding extracted train audio (metadata missing)..."
else
    # Timestamps look fine — fall back to comparing content hashes.
    TRAIN_SHA=$($PY -c "
from maya_asr.config import file_sha256
from pathlib import Path
print(file_sha256(Path('$TRAIN_MANIFEST')))
")
    EXT_SHA=$($PY -c "
import json
meta = json.load(open('$TRAIN_EXTRACT_META'))
print(meta.get('source_manifest_sha256', ''))
")
    if [ "$TRAIN_SHA" != "$EXT_SHA" ]; then
        EXTRACT_TRAIN_REASON="  Rebuilding extracted train audio (train manifest content changed)..."
    fi
fi

if [ -n "$EXTRACT_TRAIN_REASON" ]; then
    echo "$EXTRACT_TRAIN_REASON"
    $PY scripts/extract_smoke_audio.py \
        --input "$TRAIN_MANIFEST" \
        --output "$TRAIN_EXTRACTED" \
        --audio-dir "$AUDIO_CACHE" \
        --max-rows 200
fi
TRAIN_EXT_LINES=$(wc -l < "$TRAIN_EXTRACTED")
echo "OK: Extracted train $TRAIN_EXTRACTED ($TRAIN_EXT_LINES rows)"

# --- Step 5b: Ensure extracted audio manifests are fresh (val) ---
# Decide first *why* a rebuild is needed (the reason doubles as the log
# line), then re-extract once if any trigger fired.
EXTRACT_VAL_REASON=""
if [ "$FORCE_REBUILD" = true ]; then
    EXTRACT_VAL_REASON="  Rebuilding extracted val audio (--force-rebuild)..."
elif [ ! -f "$VAL_EXTRACTED" ]; then
    EXTRACT_VAL_REASON="  Extracted val manifest missing, building..."
elif [ "$VAL_MANIFEST" -nt "$VAL_EXTRACTED" ]; then
    EXTRACT_VAL_REASON="  Rebuilding extracted val audio (source is newer)..."
elif [ ! -f "$VAL_EXTRACT_META" ]; then
    EXTRACT_VAL_REASON="  Rebuilding extracted val audio (metadata missing)..."
else
    # Timestamps look fine — fall back to comparing content hashes.
    VAL_SHA=$($PY -c "
from maya_asr.config import file_sha256
from pathlib import Path
print(file_sha256(Path('$VAL_MANIFEST')))
")
    EXT_VAL_SHA=$($PY -c "
import json
meta = json.load(open('$VAL_EXTRACT_META'))
print(meta.get('source_manifest_sha256', ''))
")
    if [ "$VAL_SHA" != "$EXT_VAL_SHA" ]; then
        EXTRACT_VAL_REASON="  Rebuilding extracted val audio (val manifest content changed)..."
    fi
fi

if [ -n "$EXTRACT_VAL_REASON" ]; then
    echo "$EXTRACT_VAL_REASON"
    $PY scripts/extract_smoke_audio.py \
        --input "$VAL_MANIFEST" \
        --output "$VAL_EXTRACTED" \
        --audio-dir "$AUDIO_CACHE" \
        --max-rows 50
fi
VAL_EXT_LINES=$(wc -l < "$VAL_EXTRACTED")
echo "OK: Extracted val   $VAL_EXTRACTED ($VAL_EXT_LINES rows)"

# --- Step 6: Check config ---
[ -f "$CONFIG" ] || { echo "FAIL: Config not found at $CONFIG"; exit 1; }
echo "OK: Config          $CONFIG"

# --- Step 7: Experiment dir ---
mkdir -p "$EXP_DIR"
echo "OK: Experiment dir  $EXP_DIR"

echo ""
echo "--- Resolved Artifact Paths ---"
# Print "  <label padded to 18 cols><absolute path>" — same table layout as
# the other preflight output.
show_path() { printf '  %-18s%s\n' "$1" "$(realpath "$2")"; }
show_path "manifest:" "$MANIFEST"
show_path "train_manifest:" "$TRAIN_MANIFEST"
show_path "val_manifest:" "$VAL_MANIFEST"
show_path "train_extracted:" "$TRAIN_EXTRACTED"
show_path "val_extracted:" "$VAL_EXTRACTED"
show_path "tokenizer_dir:" "$TOKENIZER_DIR"
show_path "input_cfg:" "$INPUT_CFG"
show_path "config:" "$CONFIG"
show_path "exp_dir:" "$EXP_DIR"

# --- Step 8: Validate config ---
# Structural sanity check of the training YAML before anything expensive
# starts: required sections, model components, tokenizer files on disk,
# and the extracted manifests the training job will read.
# NOTE(review): shell values are spliced into the Python source inside
# single quotes, so a path containing a single quote would break this
# snippet — fine for the hard-coded paths above, but confirm if these
# ever become user-supplied.
echo ""
echo "--- Config Validation ---"
$PY -c "
from pathlib import Path
from omegaconf import OmegaConf

config = OmegaConf.load('$CONFIG')

required = ['trainer', 'model', 'exp_manager']
for section in required:
    assert section in config, f'Missing section: {section}'
    print(f'  OK: {section} section present')

t = config.trainer
override = '$MAX_STEPS'
effective_steps = int(override) if override else t.max_steps
print(f'  OK: max_steps={effective_steps} (config={t.max_steps}, override={override or \"none\"})')
print(f'  OK: devices={t.devices}')
print(f'  OK: precision={t.precision}')

m = config.model
for component in ['preprocessor', 'encoder', 'decoder', 'train_ds', 'validation_ds', 'optim']:
    assert component in m, f'Missing model component: {component}'
    print(f'  OK: model.{component} present')

enc = m.encoder
print(f'  OK: encoder ({enc.n_layers} layers, {enc.d_model} dim)')

tok = m.tokenizer
tok_dir = Path(tok.dir)
assert tok_dir.exists(), f'Tokenizer dir not found: {tok_dir}'
assert (tok_dir / 'tokenizer.model').exists(), f'tokenizer.model not found in {tok_dir}'
print(f'  OK: tokenizer dir {tok_dir} (model file present)')

# Validate extracted manifests exist (used at training time)
train_ext = Path('$TRAIN_EXTRACTED')
val_ext = Path('$VAL_EXTRACTED')
assert train_ext.exists(), f'Extracted train manifest not found: {train_ext}'
assert val_ext.exists(), f'Extracted val manifest not found: {val_ext}'
print(f'  OK: extracted train manifest {train_ext}')
print(f'  OK: extracted val manifest {val_ext}')

print()
print('Config validation: PASS')
"

# --- Step 9: Run or dry-run ---
echo ""
if [ "$DRY_RUN" = true ]; then
    echo "--- DRY RUN MODE ---"
    echo "All preflight checks passed. Config validated."
    if [ "$TORCH_AVAILABLE" = true ] && [ "$NEMO_AVAILABLE" = true ]; then
        echo "Training deps available. Ready to train."
        echo ""
        echo "Run:"
        echo "  bash scripts/run_stage1_smoke.sh --max-steps=20"
    else
        echo "Training deps missing. Install first:"
        echo ""
        echo "  pip install -e '.[train,dev]'"
    fi
    exit 0
fi

# torch/nemo already validated above (fail-fast if missing and not dry-run)

# Build the training command as an array instead of a whitespace-joined
# string: unquoted $TRAIN_ARGS expansion (SC2086) would split or glob any
# path containing spaces or wildcard characters.
TRAIN_ARGS=(--config "$CONFIG" --smoke)
TRAIN_ARGS+=(--train-manifest "$TRAIN_EXTRACTED" --val-manifest "$VAL_EXTRACTED")
if [ -n "$MAX_STEPS" ]; then
    TRAIN_ARGS+=(--max-steps "$MAX_STEPS")
    echo "--- Starting Smoke Training (--max-steps=$MAX_STEPS) ---"
else
    echo "--- Starting Smoke Training ---"
fi
if [ -n "$DEVICES" ]; then
    TRAIN_ARGS+=(--devices "$DEVICES")
    echo "  devices=$DEVICES"
fi
echo ""

$PY scripts/train_smoke.py "${TRAIN_ARGS[@]}"

echo ""
echo "============================================"
echo "  Smoke run finished"
echo "============================================"
