# Overnight 8-Hour TODO: Zero-Padding High-Throughput ASR Training Pipeline

## Mission
Transform the existing 75.2M-sample ASR training pipeline from 123 samples/sec (58.8% GPU util, 48% padding) to 400-600+ samples/sec (90%+ GPU util, <18% padding) via three surgical fixes: indexed tar access, duration bucketing, and worker-side audio decoding.

## Time Budget (8 hours)
| Phase | Task | Time |
|-------|------|------|
| 1 | Build tar offset indices for all 5559 tars | 30 min |
| 2 | Generate per-bucket parquet manifests | 20 min |
| 3 | Rewrite training script (IndexedTarReader, BucketedDataset, worker-side decode) | 3 hours |
| 4 | Bucket BS calibration (per-bucket memory test) | 45 min |
| 5 | Validation run (200 steps, all buckets) | 1 hour |
| 6 | Profiling + comparison report | 30 min |
| Buffer | Debugging, retries | 2 hours |

## Environment
```bash
export OMP_NUM_THREADS=1
export MKL_NUM_THREADS=1
export OPENBLAS_NUM_THREADS=1
export NUMEXPR_NUM_THREADS=1
export TORCH_NUM_THREADS=1
export PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True
export NCCL_LAUNCH_MODE=PARALLEL
export PYTHONPATH="/root/data/Qwen3-ASR-official:${PYTHONPATH:-}"
```

Worker init (in every subprocess):
```python
import torch
torch.set_num_threads(1)
torch.set_num_interop_threads(1)
```

## Paths
- Training script: `/root/data/Qwen3-ASR-official/finetuning/qwen3_asr_sft_phase2.py`
- Model weights: `/root/data/qwen3_asr_weights/`
- Train manifest: `/root/gemini-asr/lf_asr/artifacts/phase2/train.parquet`
- Dev manifest: `/root/gemini-asr/lf_asr/artifacts/phase2/dev.parquet`
- Audio data: `/alloc/finalsftdata/` (5559 shards, each with `audio_16k.tar`)
- Output dir: `/root/data/qwen3-asr-phase2-out/`
- Bucket artifacts: `/root/gemini-asr/lf_asr/artifacts/phase2/buckets/`

---

## PHASE 1: Build Tar Offset Indices (30 min)

### What
For every `audio_16k.tar`, build a JSON index mapping `member_name → (byte_offset, byte_size)`. This eliminates the 1.4-second `tarfile.getmembers()` scan on every tar open.
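
Before indexing all 5559 tars, it is worth a quick sanity check that seek+read via `TarInfo.offset_data`/`size` reproduces the member bytes exactly. A minimal self-contained check (builds a throwaway tar in a temp dir; member names are made up):

```python
"""Sanity check: index-based seek+read must match the original tar payloads."""
import io, os, tarfile, tempfile

def check_roundtrip() -> bool:
    with tempfile.TemporaryDirectory() as d:
        tar_path = os.path.join(d, "audio_16k.tar")
        payloads = {"./a.flac": b"x" * 1000, "./b.flac": b"y" * 3333}
        with tarfile.open(tar_path, "w") as tf:
            for name, data in payloads.items():
                info = tarfile.TarInfo(name=name)
                info.size = len(data)
                tf.addfile(info, io.BytesIO(data))

        # Build the index the same way build_tar_indices.py does
        with tarfile.open(tar_path, "r:") as tf:
            index = {m.name: {"offset": m.offset_data, "size": m.size}
                     for m in tf.getmembers() if m.isfile()}

        # Replay every member via plain seek+read and compare byte-for-byte
        with open(tar_path, "rb") as fh:
            for name, data in payloads.items():
                entry = index[name]
                fh.seek(entry["offset"])
                assert fh.read(entry["size"]) == data
    return True

print(check_roundtrip())  # True
```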

### Script: `tools/build_tar_indices.py`

```python
"""Build byte-offset index files for all audio_16k.tar files."""
import os, json, tarfile
from concurrent.futures import ProcessPoolExecutor, as_completed
import time

os.environ["OMP_NUM_THREADS"] = "1"

def build_index_for_tar(tar_path: str) -> dict:
    """Open tar, extract member offsets, write index JSON alongside tar."""
    index_path = tar_path + ".index.json"

    # Skip if already built
    if os.path.exists(index_path):
        return {"tar": tar_path, "status": "skipped", "members": 0}

    try:
        with tarfile.open(tar_path, "r:") as tf:
            members = tf.getmembers()
            index = {}
            for m in members:
                if m.isfile():
                    # Store both offset_data and size for direct seek+read
                    index[m.name] = {"offset": m.offset_data, "size": m.size}
                    # Also index without the "./" prefix for lookup flexibility.
                    # Use removeprefix (Python 3.9+): lstrip("./") strips characters, not a prefix.
                    bare = m.name.removeprefix("./")
                    if bare != m.name:
                        index[bare] = {"offset": m.offset_data, "size": m.size}

        with open(index_path, "w") as f:
            json.dump(index, f)

        return {"tar": tar_path, "status": "ok", "members": len(members)}
    except Exception as e:
        return {"tar": tar_path, "status": f"error: {e}", "members": 0}

def main():
    import pyarrow.parquet as pq

    # Get all unique tar paths from manifest
    print("Loading manifest to find all tar paths...")
    df = pq.read_table("/root/gemini-asr/lf_asr/artifacts/phase2/train.parquet",
                        columns=["tar_path"]).to_pandas()
    tar_paths = sorted(df["tar_path"].unique())
    print(f"Found {len(tar_paths)} unique tar files")

    # Also add dev/test tar paths
    for split in ["dev", "test"]:
        sp = pq.read_table(f"/root/gemini-asr/lf_asr/artifacts/phase2/{split}.parquet",
                           columns=["tar_path"]).to_pandas()
        tar_paths = sorted(set(tar_paths) | set(sp["tar_path"].unique()))
    print(f"Total unique tars (train+dev+test): {len(tar_paths)}")

    # Check how many already have indices
    existing = sum(1 for p in tar_paths if os.path.exists(p + ".index.json"))
    print(f"Already indexed: {existing}, need to build: {len(tar_paths) - existing}")

    # Build indices in parallel
    workers = 64
    done = 0
    errors = 0
    t0 = time.time()

    with ProcessPoolExecutor(max_workers=workers) as pool:
        futures = {pool.submit(build_index_for_tar, p): p for p in tar_paths}
        for f in as_completed(futures):
            result = f.result()
            done += 1
            if "error" in result["status"]:
                errors += 1
                print(f"  ERROR: {result['tar']}: {result['status']}")
            if done % 500 == 0:
                elapsed = time.time() - t0
                rate = done / elapsed
                eta = (len(tar_paths) - done) / rate
                print(f"  {done}/{len(tar_paths)} ({rate:.1f}/s, ETA {eta:.0f}s) errors={errors}")

    elapsed = time.time() - t0
    print(f"\nDone: {done} tars indexed in {elapsed:.1f}s, {errors} errors")

    # Validate: every tar_path in manifest has an index
    missing = [p for p in tar_paths if not os.path.exists(p + ".index.json")]
    if missing:
        print(f"WARNING: {len(missing)} tars missing indices!")
        for p in missing[:10]:
            print(f"  {p}")
    else:
        print("All tars have index files.")

if __name__ == "__main__":
    main()
```

### Run:
```bash
cd /root/data/Qwen3-ASR-official
python3 tools/build_tar_indices.py
```

### Validate:
```bash
# Should show 5559 index files, ~6.5GB total
find /alloc/finalsftdata -name "*.index.json" | wc -l
find /alloc/finalsftdata -name "*.index.json" -print0 | du -ch --files0-from=- | tail -1
```

---

## PHASE 2: Generate Per-Bucket Parquet Manifests (20 min)

### What
Split `train.parquet` into 7 duration-bucket parquets, globally shuffled within each bucket. This enables same-duration batching with zero cross-bucket padding.
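
The padding win can be estimated directly from durations before touching any audio: random batches pad every clip to the batch max, while duration-sorted (bucketed) batches only pad within a narrow band. A stdlib sketch with toy uniform durations (not the real manifest; sorting is the idealized limit of bucketing):

```python
"""Estimate padding waste: fraction of the padded batch tensor carrying no audio."""
import random

def padding_waste(durations, batch_size):
    """Mean fraction of padded frames across fixed-size batches."""
    waste_num = waste_den = 0.0
    for i in range(0, len(durations) - batch_size + 1, batch_size):
        batch = durations[i:i + batch_size]
        padded = max(batch) * len(batch)  # every clip padded to batch max
        waste_num += padded - sum(batch)
        waste_den += padded
    return waste_num / waste_den

random.seed(0)
durs = [random.uniform(1.0, 30.0) for _ in range(100_000)]

unbucketed = padding_waste(durs, 32)           # random batching
bucketed = padding_waste(sorted(durs), 32)     # perfect duration bucketing
print(f"random batches:  {unbucketed:.1%} padding")
print(f"sorted/bucketed: {bucketed:.1%} padding")
```

With uniform 1-30s clips this lands near the ~48% waste measured on the real pipeline, and collapses to low single digits once batches share a duration band.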

### Bucket Design
| Bucket ID | Duration Range | Expected % of Samples |
|-----------|---------------|----------------------|
| b_0_3 | 0-3s | ~12.6% |
| b_3_5 | 3-5s | ~24.8% |
| b_5_7 | 5-7s | ~13.7% |
| b_7_10 | 7-10s | ~35.3% |
| b_10_15 | 10-15s | ~13.1% |
| b_15_20 | 15-20s | ~0.5% |
| b_20_30 | 20-30s | ~0.01% |

Samples under 1s fall into b_0_3. Samples of 30s or more fall outside every bucket mask and are dropped by the manifest script; confirm from its output that this count is negligible.
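
At runtime a clip's bucket can be found in O(log n) with `bisect` over the upper edges. A sketch of the lookup (the training script may assign buckets differently; edges mirror the manifest script's `[lo, hi)` ranges):

```python
"""Map a clip duration to its bucket id via binary search on upper edges."""
from bisect import bisect_right

BUCKET_EDGES = [3.0, 5.0, 7.0, 10.0, 15.0, 20.0, 30.0]
BUCKET_IDS = ["b_0_3", "b_3_5", "b_5_7", "b_7_10", "b_10_15", "b_15_20", "b_20_30"]

def bucket_for(duration_s):
    """Return the bucket id, or None for clips of 30s or more (outside all masks)."""
    i = bisect_right(BUCKET_EDGES, duration_s)  # first edge strictly above duration
    return BUCKET_IDS[i] if i < len(BUCKET_IDS) else None

print(bucket_for(2.4))   # b_0_3
print(bucket_for(3.0))   # b_3_5  (buckets are [lo, hi), so 3.0 rolls up)
print(bucket_for(31.0))  # None
```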

### Script: `tools/build_bucket_manifests.py`

```python
"""Split train.parquet into per-bucket parquets, shuffled within each bucket."""
import os, sys
os.environ["OMP_NUM_THREADS"] = "1"

import pyarrow.parquet as pq
import pandas as pd
import numpy as np
from pathlib import Path

BUCKETS = [
    ("b_0_3",   0.0,  3.0),
    ("b_3_5",   3.0,  5.0),
    ("b_5_7",   5.0,  7.0),
    ("b_7_10",  7.0, 10.0),
    ("b_10_15", 10.0, 15.0),
    ("b_15_20", 15.0, 20.0),
    ("b_20_30", 20.0, 30.0),
]

SEED = 42
ARTIFACTS = Path("/root/gemini-asr/lf_asr/artifacts/phase2")
BUCKET_DIR = ARTIFACTS / "buckets"

def main():
    BUCKET_DIR.mkdir(parents=True, exist_ok=True)

    print("Loading train.parquet...")
    df = pd.read_parquet(ARTIFACTS / "train.parquet")
    print(f"  {len(df):,} rows")

    # Add tar_index_path column
    df["tar_index_path"] = df["tar_path"] + ".index.json"

    # Filter out empty transcripts (they get skipped at training time anyway)
    non_empty = df["transcript"].notna() & df["transcript"].str.strip().ne("")
    print(f"  Non-empty transcripts: {non_empty.sum():,} ({non_empty.sum()/len(df)*100:.1f}%)")
    df = df[non_empty].copy()

    # Assign buckets
    stats = []
    rng = np.random.RandomState(SEED)

    for bucket_id, lo, hi in BUCKETS:
        mask = (df["duration_s"] >= lo) & (df["duration_s"] < hi)
        bdf = df[mask].copy()

        # Global shuffle within bucket
        bdf = bdf.sample(frac=1.0, random_state=rng).reset_index(drop=True)

        # Add bucket metadata
        bdf["bucket_id"] = bucket_id
        bdf["bucket_max_duration"] = hi

        out_path = BUCKET_DIR / f"{bucket_id}.parquet"
        bdf.to_parquet(out_path, index=False)

        hours = bdf["duration_s"].sum() / 3600
        stats.append({
            "bucket_id": bucket_id,
            "range": f"{lo}-{hi}s",
            "samples": len(bdf),
            "hours": hours,
            "pct_samples": len(bdf) / len(df) * 100,
            "pct_hours": hours,  # placeholder; normalized to % after the loop
            "mean_duration": bdf["duration_s"].mean(),
            "max_duration": bdf["duration_s"].max(),
        })
        print(f"  {bucket_id}: {len(bdf):>10,} samples, {hours:>8,.0f}h, "
              f"mean={bdf['duration_s'].mean():.1f}s, max={bdf['duration_s'].max():.1f}s")

    total_hours = sum(s["hours"] for s in stats)
    for s in stats:
        s["pct_hours"] = s["hours"] / total_hours * 100

    # Write bucket config
    import json
    config = {
        "buckets": stats,
        "total_samples": len(df),
        "total_hours": total_hours,
        "seed": SEED,
    }
    with open(BUCKET_DIR / "bucket_config.json", "w") as f:
        json.dump(config, f, indent=2)

    print(f"\nTotal: {len(df):,} samples, {total_hours:,.0f}h across {len(BUCKETS)} buckets")
    print(f"Bucket parquets written to: {BUCKET_DIR}")

if __name__ == "__main__":
    main()
```

### Run:
```bash
python3 tools/build_bucket_manifests.py
```

---

## PHASE 3: Rewrite Training Script (3 hours)

### Changes to `finetuning/qwen3_asr_sft_phase2.py`

#### 3a. Replace `TarAudioLRUCache` with `IndexedTarReader`

The key change: instead of `tarfile.open()` + `getmembers()` (which scans the entire archive), use the pre-built offset indices with a plain `open()` + `seek()` + `read()`.

```python
import os, json
from collections import OrderedDict

class IndexedTarReader:
    """Direct byte-offset tar reader using pre-built index files.

    ~400x faster than tarfile.extractfile() because:
    - No tarfile.getmembers() scan (1.4s → 0ms)
    - Plain file handle seek+read instead of tarfile extraction
    - LRU cache of open file handles
    """

    def __init__(self, max_open_files: int = 64):
        self.max_open_files = max_open_files
        self._file_cache = OrderedDict()  # tar_path -> file handle
        self._index_cache = {}  # tar_path -> {member_name: {offset, size}}

    def _load_index(self, tar_path: str) -> dict:
        if tar_path in self._index_cache:
            return self._index_cache[tar_path]

        index_path = tar_path + ".index.json"
        with open(index_path, "r") as f:
            index = json.load(f)
        self._index_cache[tar_path] = index
        return index

    def _get_file(self, tar_path: str):
        real_path = os.path.realpath(tar_path)
        if real_path in self._file_cache:
            self._file_cache.move_to_end(real_path)
            return self._file_cache[real_path]

        fh = open(real_path, "rb")
        self._file_cache[real_path] = fh
        if len(self._file_cache) > self.max_open_files:
            _, old_fh = self._file_cache.popitem(last=False)
            old_fh.close()
        return fh

    def read_member(self, tar_path: str, member_name: str) -> bytes:
        index = self._load_index(tar_path)

        # Try the exact name, then with/without a "./" prefix.
        # Use removeprefix: lstrip("./") strips characters, not a prefix.
        entry = index.get(member_name)
        if entry is None:
            bare = member_name.removeprefix("./")
            entry = index.get(bare) or index.get("./" + bare)
        if entry is None:
            raise FileNotFoundError(f"Member {member_name} not found in index for {tar_path}")

        fh = self._get_file(tar_path)
        fh.seek(entry["offset"])
        return fh.read(entry["size"])

    def close(self):
        for fh in self._file_cache.values():
            fh.close()
        self._file_cache.clear()
        self._index_cache.clear()
```

#### 3b. Move Audio Decoding to Dataset Workers

Currently: workers yield `{tar_path, tar_member_name, transcript}` → collator decodes audio in main process.
After: workers yield `{waveform, transcript, language, duration_s}` → collator only pads + tokenizes.

In `ParquetTarIterableDataset.__iter__()`, after yielding the sample dict, add audio decoding:

```python
# In the dataset __iter__, AFTER filtering.
# Requires io, numpy as np, soundfile as sf, torch, and torchaudio at module top.
# Decode audio in the worker process (parallelized across num_workers).
raw_bytes = self._tar_reader.read_member(tar_path, tar_member_name)
wav, sr = sf.read(io.BytesIO(raw_bytes), dtype="float32", always_2d=False)
wav = np.asarray(wav, dtype=np.float32)
if wav.ndim == 2:
    wav = wav.mean(axis=1)
# Should already be 16kHz, but safety check
if sr != 16000:
    wav_t = torch.from_numpy(wav).unsqueeze(0)
    wav_t = torchaudio.functional.resample(wav_t, sr, 16000)
    wav = wav_t.squeeze(0).numpy()

yield {
    "waveform": wav,
    "transcript": transcript,
    "language": language,
    "prompt": prompt,
    "duration_s": duration_s,
}
```

Each worker gets its own `IndexedTarReader` instance (initialized in worker_init_fn).

The collator becomes thin:
```python
def __call__(self, features):
    # features now contain decoded waveforms, not tar paths
    audios = [f["waveform"] for f in features if f["waveform"] is not None]
    # ... rest is just processor call (mel + tokenize) + label masking
```
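
The padding half of the thin collator can be sketched in plain NumPy; the real collator would hand the padded batch to the Qwen3 processor, and the names here are illustrative:

```python
"""Pad a batch of variable-length waveforms and build the 0/1 attention mask."""
import numpy as np

def pad_waveforms(waveforms):
    """Stack 1-D float32 waveforms into (B, T_max) plus an int64 mask."""
    lengths = [len(w) for w in waveforms]
    t_max = max(lengths)
    batch = np.zeros((len(waveforms), t_max), dtype=np.float32)
    mask = np.zeros((len(waveforms), t_max), dtype=np.int64)
    for i, (w, n) in enumerate(zip(waveforms, lengths)):
        batch[i, :n] = w   # real samples up front
        mask[i, :n] = 1    # 1 = audio, 0 = padding
    return batch, mask

# Toy 16kHz waveforms of 1s, 2s, 3s
wavs = [np.ones(16000 * s, dtype=np.float32) for s in (1, 2, 3)]
batch, mask = pad_waveforms(wavs)
print(batch.shape)        # (3, 48000)
print(mask.sum(axis=1))   # [16000 32000 48000]
```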

#### 3c. Implement `BucketedIterableDataset`

Replace the single `ParquetTarIterableDataset` with a bucket-aware version:

```python
class BucketedIterableDataset(IterableDataset):
    """Reads from per-bucket parquet files, yields same-duration batches."""

    def __init__(self, bucket_dir, bucket_config, ...):
        self.buckets = []  # list of (bucket_id, parquet_path, batch_size, weight)
        for b in bucket_config["buckets"]:
            path = bucket_dir / f"{b['bucket_id']}.parquet"
            self.buckets.append({
                "id": b["bucket_id"],
                "path": str(path),
                "weight": b["hours"],  # proportional to audio hours
                "batch_size": b.get("batch_size", 16),  # from calibration
            })

    def __iter__(self):
        # Each worker handles all buckets but different row groups
        # Weighted round-robin: draw from bucket proportional to weight
        ...
```
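
The weighted draw in `__iter__` can be prototyped independently of the parquet plumbing. A stdlib sketch, assuming each bucket exposes an endless batch iterator (names and the two-bucket setup are hypothetical):

```python
"""Weighted round-robin over bucket iterators: buckets drawn proportional to weight."""
import itertools, random

def weighted_bucket_stream(bucket_iters, weights, seed=42):
    """Yield (bucket_id, batch); each bucket chosen with probability ∝ its weight."""
    rng = random.Random(seed)
    ids = list(bucket_iters)
    w = [weights[b] for b in ids]
    while True:
        b = rng.choices(ids, weights=w, k=1)[0]
        yield b, next(bucket_iters[b])

# Stand-in "batch iterators" and hour weights
iters = {"b_0_3": itertools.count(0), "b_7_10": itertools.count(1000)}
weights = {"b_0_3": 1.0, "b_7_10": 3.0}

stream = weighted_bucket_stream(iters, weights)
draws = [b for b, _ in itertools.islice(stream, 10_000)]
frac = draws.count("b_7_10") / len(draws)
print(f"b_7_10 drawn {frac:.0%} of the time (expected ~75%)")
```

In the real dataset, `weights` come from the `hours` field of `bucket_config.json`, which keeps sampling proportional to audio hours (success criterion 5).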

#### 3d. Dynamic Batch Size from `bucket_config.json`

The training script reads per-bucket batch sizes from the calibration file:

```json
{
  "b_0_3":  {"batch_size": 80,  "grad_acc": 1},
  "b_3_5":  {"batch_size": 48,  "grad_acc": 1},
  "b_5_7":  {"batch_size": 32,  "grad_acc": 1},
  "b_7_10": {"batch_size": 24,  "grad_acc": 2},
  "b_10_15": {"batch_size": 16, "grad_acc": 3},
  "b_15_20": {"batch_size": 8,  "grad_acc": 6},
  "b_20_30": {"batch_size": 8,  "grad_acc": 6}
}
```

These are initial estimates. Phase 4 calibrates them empirically.
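
A quick diagnostic on any version of this file: the audio seconds fed per optimizer step implied by `batch_size × grad_acc × mean_duration`, which flags buckets that under- or over-fill the GPU. A sketch with the estimates above inlined (mean durations are illustrative midpoints, not values from `bucket_config.json`):

```python
"""Audio seconds per optimizer step implied by per-bucket batch_size × grad_acc."""
import json

calibrated = json.loads("""
{
  "b_0_3":  {"batch_size": 80, "grad_acc": 1},
  "b_3_5":  {"batch_size": 48, "grad_acc": 1},
  "b_5_7":  {"batch_size": 32, "grad_acc": 1},
  "b_7_10": {"batch_size": 24, "grad_acc": 2},
  "b_10_15": {"batch_size": 16, "grad_acc": 3},
  "b_15_20": {"batch_size": 8, "grad_acc": 6},
  "b_20_30": {"batch_size": 8, "grad_acc": 6}
}
""")

# Illustrative mid-bucket durations; real means live in bucket_config.json
mean_dur = {"b_0_3": 2.0, "b_3_5": 4.0, "b_5_7": 6.0, "b_7_10": 8.5,
            "b_10_15": 12.5, "b_15_20": 17.5, "b_20_30": 25.0}

for bucket, cfg in calibrated.items():
    sec = cfg["batch_size"] * cfg["grad_acc"] * mean_dur[bucket]
    print(f"{bucket:>8}: {sec:6.0f} audio-sec per optimizer step")
```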

---

## PHASE 4: Bucket Batch Size Calibration (45 min)

### Script: `tools/calibrate_bucket_bs.py`

For each bucket, run 5 training steps at the estimated BS. Monitor `torch.cuda.max_memory_allocated()`. Binary search for the max BS that fits in 130GB (leaving 13GB headroom on 143GB H200).
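
The search itself is a few lines once the GPU probe is abstracted behind a callable; in the real script, `fits(bs)` would run the 5 trial steps and compare `torch.cuda.max_memory_allocated()` against the cap. A sketch with a fake linear memory model standing in for the probe:

```python
"""Binary-search the largest batch size whose trial steps fit under the memory cap."""

def max_fitting_bs(fits, lo=8, hi=128):
    """Largest bs in [lo, hi] with fits(bs) True; assumes fits is monotone in bs.

    Returns None if even `lo` does not fit.
    """
    if not fits(lo):
        return None
    while lo < hi:
        mid = (lo + hi + 1) // 2
        if fits(mid):
            lo = mid          # mid fits: search upward
        else:
            hi = mid - 1      # mid exceeds the cap / OOMs: search downward
    return lo

# Stand-in probe: pretend memory grows linearly in bs, with a 130GB cap
fake_mem_gb = lambda bs: 20 + 2.1 * bs
print(max_fitting_bs(lambda bs: fake_mem_gb(bs) <= 130))  # 52
```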

```bash
# Run on single GPU first for speed
for bucket in b_0_3 b_3_5 b_5_7 b_7_10 b_10_15 b_15_20 b_20_30; do
    CUDA_VISIBLE_DEVICES=0 python3 tools/calibrate_bucket_bs.py \
        --bucket $bucket \
        --start_bs 128 \
        --min_bs 8 \
        --max_memory_gb 130 \
        --steps 5
done
```

Then verify on 8-GPU DDP (5 steps per bucket):
```bash
torchrun --nproc_per_node=8 tools/calibrate_bucket_bs.py --mode verify_all
```

Write results to `artifacts/phase2/buckets/calibrated_bs.json`.

---

## PHASE 5: Validation Run (1 hour)

### 5a. Correctness Test (50 steps, each bucket)
```bash
torchrun --nproc_per_node=8 finetuning/qwen3_asr_sft_phase2.py \
    --model_path /root/data/qwen3_asr_weights \
    --bucket_dir /root/gemini-asr/lf_asr/artifacts/phase2/buckets \
    --bucket_config /root/gemini-asr/lf_asr/artifacts/phase2/buckets/calibrated_bs.json \
    --output_dir /root/data/qwen3-asr-phase2-out \
    --max_steps 50 \
    --gradient_checkpointing 1 \
    --log_steps 5 \
    --save_steps 999999 \
    --lr 2e-5 --warmup_ratio 0.02 \
    --num_workers 16 --prefetch_factor 4 --max_open_tars 64 \
    --language_tag_mode auto
```

Check:
- [ ] No OOM on any bucket
- [ ] Loss is decreasing
- [ ] All 7 buckets are being sampled
- [ ] GPU utilization >85%
- [ ] No NCCL timeouts

### 5b. Throughput Benchmark (200 steps)
Same command with `--max_steps 200`. Extract:
- Avg samples/sec (global)
- Per-bucket step times
- GPU memory per bucket
- GPU utilization
- Padding efficiency (computed from feature_attention_mask)
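
Padding efficiency falls out of the collated mask in one line per batch, assuming a 0/1 `feature_attention_mask` of shape (B, T) as in the thin collator:

```python
"""Padding waste of a batch, computed from its 0/1 feature_attention_mask."""
import numpy as np

def padding_waste(feature_attention_mask):
    """Fraction of the (B, T) tensor that is padding (mask == 0)."""
    return 1.0 - feature_attention_mask.mean()

# Toy batch: frame lengths 100, 250, 400 padded to 400
mask = np.zeros((3, 400), dtype=np.int64)
for i, n in enumerate((100, 250, 400)):
    mask[i, :n] = 1
print(f"{padding_waste(mask):.1%} padding")  # (1200 - 750) / 1200 = 37.5%
```

Averaging this over the 200 benchmark steps gives the "Padding waste" row of the comparison table.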

### 5c. Compare Against Baseline

| Metric | Old (E1) | New (bucketed) | Delta |
|--------|----------|----------------|-------|
| Samples/sec | 123.2 | ??? | ??? |
| GPU util | 58.8% | ??? | ??? |
| Padding waste | ~48% | ??? | ??? |
| Peak memory | 123.8GB | ??? | ??? |
| Stability | 100 steps clean | ??? | ??? |

---

## PHASE 6: Profiling Report (30 min)

Run 50 steps with PyTorch profiler enabled:
```bash
torchrun --nproc_per_node=8 finetuning/qwen3_asr_sft_phase2.py \
    ... (same as above) \
    --profile 1 --profile_steps 20
```

Report breakdown:
- Data loading time (tar read + FLAC decode)
- Mel spectrogram computation
- Forward pass (audio encoder + text decoder)
- Backward pass
- Optimizer step + allreduce
- Idle/sync time

---

## Success Criteria

1. **Throughput**: >300 samples/sec (>2.4x improvement over 123.2 baseline)
2. **GPU Utilization**: >85% (up from 58.8%)
3. **Padding Waste**: <20% (down from 48%)
4. **Stability**: 200 steps clean, no OOM, no NCCL timeout
5. **All buckets sampled**: proportional to audio hours
6. **Loss convergence**: comparable or better than unbucketed baseline

## Production Launch Command (after validation)

```bash
torchrun --standalone --nproc_per_node=8 \
    finetuning/qwen3_asr_sft_phase2.py \
    --model_path /root/data/qwen3_asr_weights \
    --bucket_dir /root/gemini-asr/lf_asr/artifacts/phase2/buckets \
    --bucket_config /root/gemini-asr/lf_asr/artifacts/phase2/buckets/calibrated_bs.json \
    --eval_file /root/gemini-asr/lf_asr/artifacts/phase2/dev.parquet \
    --output_dir /root/data/qwen3-asr-phase2-out \
    --epochs 1 \
    --gradient_checkpointing 1 \
    --lr 2e-5 --warmup_ratio 0.02 --lr_scheduler_type cosine \
    --log_steps 50 \
    --save_steps 5000 --save_total_limit 10 \
    --eval_steps 20000 \
    --num_workers 16 --prefetch_factor 4 --max_open_tars 64 \
    --language_tag_mode auto \
    --seed 42
```
