## Big Picture

Current training is a **fully local, preprocessed, DDP RNNT/TDT finetuning pipeline**. We are **not** streaming from R2 anymore, and we are **not** building batches from raw shards on the fly. The flow is:

1. Raw shard metadata was merged into a manifest.
2. Text was cleaned and language tags were prepended.
3. A corrected, deduplicated `v2` train manifest was produced.
4. At runtime, each GPU rank loads the full train manifest.
5. Each sample reads an **extracted FLAC** directly from disk.
6. A custom sampler groups samples by duration, temperature-rebalances languages, and packs batches to a max audio duration.
7. DDP splits those packed batches across 8 GPUs.
8. Gradient accumulation of 4 micro-batches gives the final optimizer step.

The current code paths that define this are mainly `config.py`, `dataset.py`, `train.py`, and historically `build_manifest.py`.

## 1. What The Training Job Reads Today

The active paths are:

```25:35:/alloc/parakeet_tdt_finetune/config.py
# Paths
BASE_DIR = Path("/alloc/parakeet_tdt_finetune")
DATA_DIR = _env_path("DATA_DIR", "/alloc/asr_data")
MASTER_MANIFEST_PATH = _env_path("MASTER_MANIFEST_PATH", DATA_DIR / "master_manifest.parquet")
TRAIN_MANIFEST_PATH = _env_path("TRAIN_MANIFEST_PATH", DATA_DIR / "train_manifest_v2.parquet")
VAL_MANIFEST_PATH = _env_path("VAL_MANIFEST_PATH", DATA_DIR / "val_manifest_v2.parquet")
EVAL_EVERY_STEPS = _env_int("EVAL_EVERY_STEPS", 0)
EVAL_SAMPLES_PER_LANG = _env_int("EVAL_SAMPLES_PER_LANG", 10)
TOKENIZER_DIR = str(BASE_DIR / "trained_tokenizer_12k_compare")
CHECKPOINT_DIR = _env_path("CHECKPOINT_DIR", "/alloc/checkpoints_parakeet")
WANDB_LOG_DIR = _env_path("WANDB_LOG_DIR", "/alloc/wandb_logs")
```

So the current run uses:

- `train_manifest_v2.parquet`
- `trained_tokenizer_12k_compare`
- no active validation, because `EVAL_EVERY_STEPS=0`
- checkpoints every `5000` optimizer steps
Important nuance: although the manifest still carries a `tar_path` column, the runtime no longer reads tar files. It uses `dirname(tar_path)` only to recover the shard directory where the extracted FLACs now live.

## 2. What A Manifest Row Represents

A row in the final training manifest is already "training-ready". It contains:

- `segment_id`
- `shard_id`
- `text`
- `duration`
- `lang`
- `tar_member_name`
- `tar_path`
- `source`

The meaning is:

- `text` is already cleaned and already includes the `<|lang|>` prefix
- `duration` is the filtered segment duration in seconds
- `tar_member_name` is the FLAC filename for that segment
- `tar_path` is a legacy pointer to the old tar location; we only use its parent directory

So by training time, the job is not deciding what text to use. That decision was already made upstream.

## 3. How The Manifest Was Built

The repo’s base manifest builder is `build_manifest.py`. It reads shard-level `metadata.parquet` and `audio_index.parquet`, filters rows, cleans text, and writes a single parquet. Core logic:

```62:123:/alloc/parakeet_tdt_finetune/build_manifest.py
try:
    meta_cols = [
        "segment_id",
        "transcription_native",
        "duration_s",
        "tx_quality_score",
        "segment_language",
    ]
    meta = pq.read_table(meta_path, columns=meta_cols)
    idx = pq.read_table(index_path, columns=["segment_id", "tar_member_name"])
    meta_df = meta.to_pandas()
    idx_df = idx.to_pandas()
    df = meta_df.merge(idx_df, on="segment_id", how="inner")

    # Quality filters
    df = df[
        (df["duration_s"] >= MIN_DURATION)
        & (df["duration_s"] <= MAX_DURATION)
        & (df["tx_quality_score"] >= MIN_QUALITY)
        & (df["transcription_native"].notna())
        & (df["transcription_native"].str.strip().str.len() > 0)
    ].copy()
    ...
    # Clean text + prepend language tag
    df["text"] = df.apply(
        lambda row: _make_text(
            row["transcription_native"],
            row["segment_language"] or lang_from_path,
        ),
        axis=1,
    )
    ...
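    # Illustration (not repo code): _make_text presumably runs the clean_text
    # pass shown below and then prepends the language tag, e.g.
    #   "[music] नमस्ते, दुनिया!"  ->  "<|hi|> नमस्ते दुनिया"
    # (bracket tag stripped, punctuation stripped, whitespace collapsed)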
    df["duration"] = df["duration_s"]
    df["tar_path"] = tar_path
```

Text cleaning itself is done once, offline:

```24:37:/alloc/parakeet_tdt_finetune/build_manifest.py
# ── Text cleaning (same logic as dataset.py, applied once at manifest time) ──
_TAG_RE = re.compile(r"\[.*?\]|", re.IGNORECASE)
_PUNCT_RE = re.compile(
    r"[।॥.,!?;:\"'\-—()…%/+=$*@#&_^>|\\<>`~。\u0964\u0965]+"
)

def clean_text(text):
    if not text:
        return ""
    text = _TAG_RE.sub("", text)
    text = _PUNCT_RE.sub("", text)
    text = re.sub(r"\s+", " ", text).strip()
    return text
```

That means:

- bracket tags like `[laugh]`, `[music]`, `[singing]` are removed
- punctuation across Indic and Latin punctuation sets is removed
- whitespace is collapsed
- then the language tag is prepended, e.g. `<|hi|> ...`

### Important extra step: `v2` manifests

Current training does **not** use the raw output of `build_manifest.py` directly. It uses `train_manifest_v2.parquet` and `val_manifest_v2.parquet`, which were further postprocessed to:

- correct a large set of wrong language tags
- keep row counts the same
- preserve cleaned text
- deduplicate/split cleanly for training and validation

That `v2` correction logic lives in the data artifacts, not in the current repo code, but the training config points to those corrected outputs.

## 4. How A Single Sample Is Loaded At Runtime

Current dataset code:

```31:78:/alloc/parakeet_tdt_finetune/dataset.py
def __init__(self, manifest_path, tokenizer, sample_rate=cfg.TARGET_SAMPLE_RATE):
    self.tokenizer = tokenizer
    self.sample_rate = sample_rate
    log.info(f"Loading manifest from {manifest_path}...")
    table = pq.read_table(str(manifest_path))
    self.df = table.to_pandas()
    log.info(f"Loaded {len(self.df):,} segments")
...
def __getitem__(self, idx):
    row = self.df.iloc[idx]
    shard_dir = os.path.dirname(row["tar_path"])
    audio_path = os.path.join(shard_dir, row["tar_member_name"])
    try:
        waveform, sr = torchaudio.load(audio_path)
    except Exception as e:
        log.warning(
            f"Bad audio: shard={row.get('shard_id','?')} "
            f"file={row['tar_member_name']}: {e}"
        )
        return None
    ...
    text = row["text"]
    tokens = torch.tensor(
        self.tokenizer.text_to_ids(text),
        dtype=torch.long,
    )
```

### What actually happens in `__getitem__`

For sample `idx`:

1. Read one row from the in-memory pandas manifest.
2. Compute `shard_dir = dirname(tar_path)`.
3. Compute `audio_path = shard_dir / tar_member_name`.
4. `torchaudio.load(audio_path)` reads the extracted FLAC.
5. If stereo, average channels to mono.
6. If the sample rate is not `16000`, resample to `16000`.
7. Tokenize `row["text"]` with the active SentencePiece tokenizer.
8. Truncate the token sequence to `MAX_TOKEN_LEN=512`.
9. Return:
   - `signal`
   - `tokens`
   - `duration`
   - `lang`

### Corrupt audio behavior

If audio decode fails:

- it returns `None`
- the collate function drops the sample
- no fake silence/text pair is introduced

That is the correct behavior and avoids poisoning training.

## 5. How The Sampler Decides What Goes Into A Batch

This is the most important part for speed and language balance.

### 5.1 Duration bucketing

The configured bucket boundaries are:

- `3.0`
- `5.0`
- `8.0`
- `12.0`
- `16.0`

So durations are effectively grouped into 6 ranges:

- bucket 0: `< 3s`
- bucket 1: `3-5s`
- bucket 2: `5-8s`
- bucket 3: `8-12s`
- bucket 4: `12-16s`
- bucket 5: `16-20s`

(With `bisect_right`, a duration exactly on a boundary falls into the higher bucket.) This is done via:

```103:126:/alloc/parakeet_tdt_finetune/dataset.py
def _assign_bucket(self, dur):
    return bisect.bisect_right(self.boundaries, dur)
...
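# Illustration (not repo code): with boundaries = [3.0, 5.0, 8.0, 12.0, 16.0],
# bisect_right sends a duration equal to a boundary into the higher bucket:
#   _assign_bucket(2.0)  -> 0   (< 3 s)
#   _assign_bucket(3.0)  -> 1   (boundary value goes right)
#   _assign_bucket(10.0) -> 3   (8-12 s)
#   _assign_bucket(19.5) -> 5   (16-20 s)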
buckets = defaultdict(list)
for i, dur in enumerate(self.durations):
    buckets[self._assign_bucket(dur)].append(i)
```

Why this matters:

- short clips get batched with short clips
- long clips get batched with long clips
- padding waste is much lower than with random batching

### 5.2 Temperature language balancing

This is the multilingual balancing step. Code:

```106:123:/alloc/parakeet_tdt_finetune/dataset.py
rng = random.Random(self.seed + self.epoch)
lang_counts = defaultdict(int)
for lang in self.languages:
    lang_counts[lang] += 1
total = len(self.durations)

if self.temperature <= 0:
    lang_w = {l: 1.0 for l in lang_counts}
else:
    lang_w = {l: math.pow(c, self.temperature) for l, c in lang_counts.items()}
ws = sum(lang_w.values())
lang_props = {l: w / ws for l, w in lang_w.items()}

sample_weights = []
for lang in self.languages:
    actual = lang_counts[lang] / total
    target = lang_props[lang]
    sample_weights.append(target / actual if actual > 0 else 0)
```

Current `TEMPERATURE=0.3`. What that means:

- natural language frequencies are flattened
- high-resource languages like English get down-weighted
- low-resource languages like Assamese/Odia get up-weighted

Mechanically:

- if a language is overrepresented, its `sample_weight < 1`
- if underrepresented, `sample_weight > 1`

### 5.3 Weighted random ordering inside each bucket

Inside each duration bucket, it does:

```129:143:/alloc/parakeet_tdt_finetune/dataset.py
for b_id, indices in buckets.items():
    keyed = [
        (rng.random() ** (1.0 / max(sample_weights[i], 1e-12)), i)
        for i in indices
    ]
    keyed.sort(reverse=True)
    sorted_idx = [i for _, i in keyed]

    batch, batch_dur = [], 0.0
    for idx in sorted_idx:
        d = self.durations[idx]
        if batch_dur + d > self.max_dur and batch:
            all_batches.append(batch)
            batch, batch_dur = [], 0.0
        batch.append(idx)
        batch_dur += d
```

That `U^(1/w)` trick is a weighted random-without-replacement ranking:

- a larger `w` pushes a sample earlier in the ordering
- a smaller `w` pushes it later
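To make 5.2 and 5.3 concrete, here is a self-contained sketch using made-up counts (900 `en` segments, 100 `as` segments; not the real corpus distribution) that mirrors the repo's math:

```python
import math
import random
from collections import Counter

# Toy corpus: hypothetical counts, for illustration only.
languages = ["en"] * 900 + ["as"] * 100
counts = Counter(languages)
total = len(languages)

T = 0.3  # same TEMPERATURE as the current run
lang_w = {l: math.pow(c, T) for l, c in counts.items()}
ws = sum(lang_w.values())
target = {l: w / ws for l, w in lang_w.items()}

# Per-sample weight = target proportion / natural proportion.
weights = {l: target[l] / (counts[l] / total) for l in counts}
# -> weights["en"] ~= 0.73 (down-weighted), weights["as"] ~= 3.41 (up-weighted)

# Efraimidis-Spirakis ordering: key = u ** (1/w), sorted descending, yields
# a weighted random permutation without replacement, so high-weight samples
# tend to land earlier.
rng = random.Random(0)
keyed = sorted(
    ((rng.random() ** (1.0 / weights[l]), i) for i, l in enumerate(languages)),
    reverse=True,
)
order = [i for _, i in keyed]
```

With `T=0.3` the 9:1 imbalance shrinks to roughly 2:1 in the target proportions, which is exactly the "flattening" described above.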
So low-resource-language samples are more likely to be consumed earlier as batches fill up.

### 5.4 Batch packing

Within each bucket:

- start an empty batch
- keep adding utterances until the total raw duration would exceed `MAX_BATCH_DURATION=90`
- then start a new batch

So batches are **audio-duration-capped**, not sample-count-capped. Approximate sample counts per micro-batch:

- `1-3s`: roughly `30-45`
- `3-5s`: roughly `18-22`
- `5-8s`: roughly `11-15`
- `8-12s`: roughly `7-10`
- `12-16s`: roughly `5-7`
- `16-20s`: roughly `4-5`

## 6. How Batches Are Distributed Across 8 GPUs

After batch formation, the sampler does two more things.

### 6.1 Sort by compute cost

```145:152:/alloc/parakeet_tdt_finetune/dataset.py
def _batch_cost(b):
    max_dur = max(self.durations[i] for i in b)
    return len(b) * max_dur

all_batches.sort(key=_batch_cost, reverse=True)
usable = (len(all_batches) // self.world_size) * self.world_size
all_batches = all_batches[:usable]
rank_batches = all_batches[self.rank:usable:self.world_size]
```

This is subtle and important. It sorts batches by approximate compute cost:

- `batch_cost = number_of_samples * max_duration_in_batch`

Then each rank gets every `world_size`-th batch. Why:

- it spreads heavy and light batches more evenly across GPUs
- it reduces DDP straggler effects
- it keeps step times closer across ranks

### 6.2 Trim for grad-accumulation alignment

```154:160:/alloc/parakeet_tdt_finetune/dataset.py
if self.grad_accum > 1:
    trim = len(rank_batches) - (len(rank_batches) % self.grad_accum)
    rank_batches = rank_batches[:trim]
rng.shuffle(rank_batches)
self._num_batches = len(rank_batches)
```

Because `GRAD_ACCUM=4`, each rank must have a number of micro-batches divisible by 4. Otherwise one rank would finish an optimizer step while another rank did not, which would break DDP synchronization. So it trims the tail to a multiple of 4, then shuffles the rank-local batch order.
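The cost sort, rank striding, and grad-accum trim can be sketched end-to-end with toy numbers (hypothetical batch shapes, not the real batch list):

```python
import random

world_size, grad_accum = 8, 4

# 70 fake batches, each represented as (num_samples, max_duration).
rng = random.Random(0)
batches = [(rng.randint(4, 40), rng.uniform(2.0, 20.0)) for _ in range(70)]

# 1. Sort heaviest-first by approximate compute cost.
batches.sort(key=lambda b: b[0] * b[1], reverse=True)

# 2. Drop the tail so every rank gets the same count: 70 -> 64.
usable = (len(batches) // world_size) * world_size
batches = batches[:usable]

# 3. Stride across ranks, interleaving heavy and light batches.
rank_batches = {r: batches[r:usable:world_size] for r in range(world_size)}

# 4. Trim each rank's list to a multiple of grad_accum (8 is already aligned).
for r in range(world_size):
    rb = rank_batches[r]
    rank_batches[r] = rb[: len(rb) - (len(rb) % grad_accum)]
```

Because the list is cost-sorted before striding, rank 0 gets batches 0, 8, 16, … and rank 7 gets 7, 15, 23, …, so every rank sees a similar mix of expensive and cheap micro-batches.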
## 7. What The DataLoader Actually Emits

The DataLoader uses `batch_sampler`, not a normal sampler:

```443:460:/alloc/parakeet_tdt_finetune/train.py
dataset = LocalShardDataset(
    manifest_path=cfg.TRAIN_MANIFEST_PATH,
    tokenizer=model.tokenizer,
)
...
train_dl = DataLoader(
    dataset,
    batch_sampler=sampler,
    collate_fn=collate_asr_batch,
    num_workers=cfg.NUM_WORKERS,
    pin_memory=True,
    persistent_workers=True,
)
```

Current runtime settings:

- `NUM_WORKERS=12` **per rank**
- `8` GPU ranks
- so up to `96` DataLoader workers in total

The collate function does:

```173:193:/alloc/parakeet_tdt_finetune/dataset.py
def collate_asr_batch(batch):
    """Collate for NeMo RNNT: (signal, signal_len, tokens, token_len)."""
    batch = [item for item in batch if item is not None]
    if not batch:
        return (
            torch.zeros(1, 1),
            torch.tensor([1], dtype=torch.long),
            torch.zeros(1, 1, dtype=torch.long),
            torch.tensor([0], dtype=torch.long),
        )
    ...
    return signal_padded, signal_lens, token_padded, token_lens
```

So for every micro-batch it returns:

- `signal_padded`: `[B, T_audio]`
- `signal_lens`: `[B]`
- `token_padded`: `[B, T_text]`
- `token_lens`: `[B]`

This is exactly what NeMo RNNT expects.

## 8. Effective Batch Size And What The Progress Bar Means

This is another place people get confused.

### Micro-batch

Per GPU: max `90s` of raw audio.

### Optimizer step

Because there are:

- `8` GPUs
- `GRAD_ACCUM=4`

the effective audio per optimizer step is:

- `90 * 8 * 4 = 2880 seconds`

That is **48 minutes of audio per optimizer step**.

### Progress bar vs global step

The progress bar counts **micro-batches**, not optimizer steps. So:

- if the progress bar advances by `4`, the global step advances by `1`
- checkpointing every `5000` global steps means roughly every `20000` progress-bar batches

This is why the progress bar can look far ahead of checkpoint step numbers.
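The step arithmetic above, spelled out as a checkable snippet (constants taken from the config values quoted in this document):

```python
# Effective-batch arithmetic for the current run.
MAX_BATCH_DURATION = 90   # seconds of raw audio per micro-batch, per GPU
NUM_GPUS = 8
GRAD_ACCUM = 4
CHECKPOINT_EVERY = 5000   # optimizer steps between checkpoints

audio_per_step_s = MAX_BATCH_DURATION * NUM_GPUS * GRAD_ACCUM   # 2880 s
audio_per_step_min = audio_per_step_s / 60                      # 48.0 min
microbatches_per_ckpt = CHECKPOINT_EVERY * GRAD_ACCUM           # 20000 per rank
```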
## 9. Epoch Meaning In This Pipeline

An "epoch" here means:

- one full pass through the sampler’s constructed batch list
- not a classic fixed-size random-shuffled dataset epoch in the small-dataset sense

Right now, one epoch is around `572,849` micro-batches. Since the corpus is huge, the run is mostly **step-budget-driven**:

- `MAX_STEPS=600000`
- epoch counts are just a side effect of how many packed batches exist

## 10. Model-Side Training Schedule

The data pipeline feeds into a two-phase optimization schedule. From `config.py`:

- `PHASE1_LR = 3e-4`
- `PHASE2_LR = 1e-4`
- `UNFREEZE_STEP = 40000`
- `WARMUP_STEPS = 5000`
- `PHASE2_WARMUP = 3000`

In `train.py`, the LR is overridden per batch by `EncoderUnfreezeCallback`. So:

- phase 1: encoder frozen; decoder/joint learn first
- phase 2: encoder unfrozen; the full model trains

This is separate from batching, but it matters because the same batch pipeline feeds two different training regimes.

## 11. What Is Disabled Right Now

The current code still contains a `PerLanguageWERCallback` class, but **it is not attached** to the trainer callbacks anymore, and validation is effectively disabled. The active callbacks are:

```472:485:/alloc/parakeet_tdt_finetune/train.py
callbacks = [
    EncoderUnfreezeCallback(),
    TrainingMetricsCallback(),
    EpochBoundarySyncCallback(),
    LearningRateMonitor(logging_interval="step"),
    StepOnlyModelCheckpoint(
        dirpath=str(cfg.CHECKPOINT_DIR),
        filename="parakeet-tdt-{step}",
        every_n_train_steps=cfg.CHECKPOINT_EVERY,
        save_top_k=-1,
        save_last=False,
        every_n_epochs=0,
    ),
]
```

So currently there is:

- no validation loop
- no custom eval callback
- only training + logging + checkpointing

## 12. Important Quirks / Non-Obvious Details

A few subtle things to know:

- The `train.py` header comment still says "audio read from tar files". That is stale: the current `dataset.py` reads extracted FLACs directly.
- `tar_path` in the manifest is also effectively a legacy field now.
The file itself may not exist anymore, because the tars were deleted after extraction; the dataset only uses `dirname(tar_path)` to locate the shard directory.
- `source` exists in the manifest, but **source is not used for sampling**. Only `lang` and `duration` affect sampling.
- The whole train manifest is loaded into pandas at startup. On this machine that is acceptable because RAM is huge.
- The WER reference/predicted lines you see in logs are NeMo’s internal logging during training, not the custom validation path.

## Short Version

If I compress the entire pipeline into one paragraph:

We train from `train_manifest_v2.parquet`, which already contains cleaned, language-tagged transcripts and pointers to extracted FLAC files on local NVMe. `LocalShardDataset` loads the full manifest into memory, and each `__getitem__` reads one FLAC, converts it to mono 16kHz, tokenizes the pre-cleaned text, and returns audio + token tensors. `DurationBucketBatchSampler` groups samples into 6 duration buckets, temperature-rebalances languages with `T=0.3`, greedily packs each micro-batch to `90s` of audio, sorts batches by compute cost, shards them deterministically across 8 DDP ranks, trims to a multiple of `GRAD_ACCUM=4`, and shuffles per epoch. The DataLoader pads within-batch audio and token lengths. Each optimizer step therefore sees `2880s` of audio in total across 8 GPUs and 4 micro-batches, while checkpoints are saved every 5000 optimizer steps.

If you want, I can also give you:

1. a **visual diagram** of the pipeline, or
2. a **worked example** tracing one single sample from raw shard metadata all the way into a padded DDP micro-batch.