# Maya ASR

Multilingual Indic automatic speech recognition system built on NVIDIA NeMo. Uses a FastConformer Hybrid TDT+CTC architecture with a Parakeet TDT 1.1B pretrained encoder, fine-tuned on ~150K hours of speech across 12 languages (11 Indic + English).

## Architecture

Maya ASR is based on NeMo's `EncDecHybridRNNTCTCBPEModel` — a FastConformer encoder with dual decoding heads (TDT transducer + auxiliary CTC).

```
Audio (16 kHz) → Mel Spectrogram (80-dim) → FastConformer Encoder (42 layers, 1024-dim)
                                                      │
                                          ┌───────────┴───────────┐
                                          ▼                       ▼
                                   TDT Decoder+Joint         CTC Head
                                   (primary loss)          (auxiliary loss)
```

### Model Components

| Component | Architecture | Parameters |
|---|---|---|
| **Encoder** | 42-layer FastConformer (Parakeet TDT 1.1B pretrained) | ~1,062M |
| **TDT Decoder** | 2-layer RNN (pred_hidden=640) | ~7M |
| **Joint Network** | Joint hidden=640, 5 TDT duration heads | ~42M |
| **CTC Head** | Linear projection over encoder output | ~33M |
| **Language Embedding** | 12 languages → 1024-dim additive bias | ~12K |
| **Total** | | ~1,144M |

### Tokenizer

- 32K BPE (SentencePiece), trained on quality-filtered text from all 12 languages
- Located at `tokenizers/stage1_prod_bpe/`

### Supported Languages

Hindi (hi), Bengali (bn), Tamil (ta), Telugu (te), Marathi (mr), Gujarati (gu), Kannada (kn), Malayalam (ml), Punjabi (pa), Odia (or), Assamese (as), English (en).

## Training

Training is split into two stages:

### Stage 1: Acoustic Model with TDT (current)

Fine-tunes the Parakeet TDT 1.1B encoder on multilingual Indic data with a randomly initialized decoder, joint network, and CTC head sized for the 32K multilingual vocabulary.

**Key optimizations:**

- **TDT loss** — matches the pretrained encoder's original objective (durations `[0,1,2,3,4]`, sigma=0.02, omega=0.1). TDT allows the model to skip multiple frames per token emission, improving both speed and accuracy over standard RNNT.
- **Differential learning rate** — encoder at 2e-5 (preserves pretrained features), decoder/joint/CTC heads at 1e-3 (fast adaptation for new vocabulary).
- **Encoder freeze** — encoder is frozen for the first 5K optimizer steps so randomly initialized heads can warm up without corrupting pretrained weights.
- **CTC warmup** — CTC loss weight ramps from 0 to 0.3 over the first 3K steps, preventing random CTC gradients from interfering with early encoder adaptation.
- **Language embedding** — a learned per-language additive bias (12 languages to 1024-dim) is injected into the encoder output. Initialized near-zero (std=0.02) so it does not disrupt pretrained features.
- **Fused joint** — `fuse_loss_wer: true` with `fused_batch_size: 4` computes the joint tensor in sub-batches to cap peak VRAM.
- **Triple-cap OOM protection** — each micro-batch is constrained by three independent limits: max audio duration (120s), max sample count (per-bucket caps), and max total tokens (400). The token cap is critical because the TDT joint tensor size is `B x T_enc x T_dec x (V + 1 + durations)` and `T_dec` depends on transcript length, not audio length.
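As a back-of-envelope check on why the token cap matters, the worst-case joint tensor for one micro-batch can be sized as follows. This is a sketch: the ~12.5 frames/s encoder rate assumes FastConformer's 8x subsampling of 10 ms feature frames, and `T_dec` is taken at the 400-token cap.

```python
# Worst-case TDT joint tensor size in bf16 (2 bytes/element).
V = 32_000               # BPE vocabulary size
DURATIONS = 5            # TDT duration heads [0, 1, 2, 3, 4]
T_ENC = int(120 * 12.5)  # encoder frames for the 120 s audio cap
T_DEC = 400              # token cap per micro-batch
B = 1

elements = B * T_ENC * T_DEC * (V + 1 + DURATIONS)
print(f"{elements * 2 / 2**30:.1f} GiB")  # prints "35.8 GiB"
```

Without the caps (or the fused-joint sub-batching), this single tensor would dwarf the rest of the VRAM budget, which is why `T_dec` must be bounded by transcript length rather than audio length alone.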

**Training schedule:**

| Phase | Steps | What happens |
|---|---|---|
| 0 – 3K | 0–3,000 | Encoder frozen, CTC weight ramps 0→0.3 |
| 3K – 5K | 3,000–5,000 | Encoder still frozen, CTC at full weight |
| 5K+ | 5,000–200,000 | Encoder unfrozen at 2e-5 LR, full TDT+CTC training |
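The phases above collapse into one step-indexed rule. A hypothetical helper (the actual logic lives in `train_prod.py` and may be structured differently):

```python
def stage1_schedule(step: int,
                    freeze_steps: int = 5000,
                    ctc_warmup_steps: int = 3000,
                    ctc_max_weight: float = 0.3) -> tuple[bool, float]:
    """Return (encoder_frozen, ctc_loss_weight) at a given optimizer step."""
    encoder_frozen = step < freeze_steps
    # Linear ramp 0 -> ctc_max_weight over the warmup window, then flat.
    ctc_weight = ctc_max_weight * min(step / ctc_warmup_steps, 1.0)
    return encoder_frozen, ctc_weight
```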

**Optimizer:** AdamW (betas=0.9/0.98, weight_decay=1e-4) with CosineAnnealing schedule (2K warmup steps, min_lr=1e-5).
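For intuition, linear warmup into cosine annealing looks roughly like this. This is a sketch of the shape of the schedule, not NeMo's `CosineAnnealing` implementation, which may differ in details such as the warmup ramp:

```python
import math

def cosine_lr(step: int, base_lr: float = 2e-5,
              warmup_steps: int = 2000, max_steps: int = 200_000,
              min_lr: float = 1e-5) -> float:
    """Linear warmup to base_lr, then cosine decay down to min_lr."""
    if step < warmup_steps:
        return base_lr * step / warmup_steps
    progress = (step - warmup_steps) / (max_steps - warmup_steps)
    return min_lr + 0.5 * (base_lr - min_lr) * (1 + math.cos(math.pi * progress))
```

With the differential learning rates, the same curve applies to both parameter groups, scaled by their respective base LRs (2e-5 for the encoder, 1e-3 for the heads).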

### Stage 2: Gemma LLM Decoder (planned)

Replaces the RNN decoder with a Gemma LLM for improved language modeling, particularly for low-resource Indic languages.

## Running Training

### Prerequisites

- 8x NVIDIA H200 GPUs (141 GB HBM3e each)
- Python 3.10+
- NeMo toolkit with ASR dependencies

### Setup

```bash
# Install package and dev dependencies
make setup

# For training dependencies (NeMo, transformers, deepspeed)
pip install -e ".[train]"
```

### Production Training (8x H200)

```bash
# 1. Build production data artifacts (manifests, splits, tokenizer, input config)
make prod-artifacts

# 2. Run full readiness check (disk, GPUs, config, artifacts, R2 storage)
make prod-readiness

# 3. Launch training (runs readiness gate, starts checkpoint watcher, then trains)
make prod-launch
```

Or run the training script directly:

```bash
python3 scripts/train_prod.py \
  --config configs/train/stage1_prod_8xh200.yaml \
  --train-parquet artifacts/phase3/production_train_final.parquet \
  --val-manifest data/manifests/stage1_prod_val_v2.jsonl \
  --pretrained-encoder pretrained/parakeet_tdt_1.1b_encoder.pt
```

### Smoke Test (single GPU)

```bash
python3 scripts/train_prod.py \
  --config configs/train/stage1_prod_8xh200.yaml \
  --max-steps 100 --devices 1 --log-every 10 --val-every 50 --smoke
```

### Key Training Arguments

| Argument | Default | Description |
|---|---|---|
| `--encoder-lr` | 2e-5 | Learning rate for pretrained encoder |
| `--head-lr` | 1e-3 | Learning rate for decoder/joint/CTC |
| `--freeze-encoder-steps` | 5000 | Steps to freeze encoder |
| `--ctc-warmup-steps` | 3000 | Steps to ramp CTC loss weight |
| `--max-batch-dur` | 120.0 | Max audio seconds per micro-batch |
| `--max-batch-size` | 16 | Hard cap on samples per micro-batch |
| `--max-tokens-in-batch` | 400 | Max total tokens per batch (joint tensor guard) |
| `--temperature` | 0.3 | Language rebalancing temperature (0=uniform, 1=natural) |
| `--grad-accum` | 4 | Gradient accumulation steps |
| `--no-lang-embed` | false | Disable language embedding |
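The three micro-batch caps compose greedily: a sample joins the current micro-batch only while all three limits hold. A minimal hypothetical sketch (the real `ProductionBatchSampler` additionally does duration bucketing and language rebalancing):

```python
def triple_cap_batches(samples, max_dur=120.0, max_size=16, max_tokens=400):
    """Yield micro-batches of (duration_s, n_tokens) samples under three
    independent caps: total audio seconds, sample count, and total tokens."""
    batch, dur, toks = [], 0.0, 0
    for d, t in samples:
        over = (dur + d > max_dur) or (len(batch) >= max_size) or (toks + t > max_tokens)
        if batch and over:
            yield batch
            batch, dur, toks = [], 0.0, 0
        batch.append((d, t))
        dur += d
        toks += t
    if batch:
        yield batch
```

Any one cap can close a batch on its own, which is what makes the token cap effective against long-transcript samples that are short in audio time.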

## Inference

Maya ASR models are standard NeMo ASR models. After training produces a `.nemo` checkpoint, use NeMo's transcribe API:

```python
import nemo.collections.asr as nemo_asr

# Load the trained model
model = nemo_asr.models.EncDecHybridRNNTCTCBPEModel.restore_from("path/to/maya_asr.nemo")
model.eval()
model.cuda()

# Transcribe audio files
transcriptions = model.transcribe(["audio1.wav", "audio2.wav"])
print(transcriptions)
```

For TDT-specific decoding (using duration-aware greedy search):

```python
from omegaconf import OmegaConf

# The model defaults to TDT greedy decoding as configured in training
# To adjust decoding parameters:
decoding_cfg = OmegaConf.create({
    "strategy": "greedy",
    "model_type": "tdt",
    "durations": [0, 1, 2, 3, 4],
    "greedy": {"max_symbols": 10},
})
model.change_decoding_strategy(decoding_cfg)
transcriptions = model.transcribe(["audio.wav"])
```

To use the auxiliary CTC head instead of TDT (faster but slightly less accurate):

```python
# Switch the active decoder to the CTC head
model.change_decoding_strategy(decoder_type="ctc")
transcriptions = model.transcribe(["audio.wav"])
```

## Data Pipeline

The training data flows through a multi-stage pipeline:

```
Raw audio tars (per-language shards on disk)
  │
  ▼
build_manifest.py → Raw manifest JSONL (path, duration, transcript, language)
  │
  ▼
split_manifest.py → Train/Val splits (deterministic, stratified)
  │
  ▼
build_tokenizer.py → 32K BPE tokenizer (SentencePiece, quality-filtered)
  │
  ▼
Phase 3 Production Parquet (tar_path, tar_offset_data, tar_nbytes, transcript, duration_s, language)
  │
  ▼
TarOffsetReader (os.pread, zero-extraction from existing tars)
  │
  ▼
ProductionBatchSampler (duration bucketing + language rebalancing + triple-cap OOM protection)
  │
  ▼
NeMo training loop
```

**Key data features:**

- **Tar-offset loading** — reads audio directly from tar files using byte offsets (`os.pread`), avoiding tar extraction. Zero disk overhead.
- **Duration bucketing** — groups similar-length clips into buckets `[0-4, 4-8, 8-12, 12-16, 16-20]s` to minimize padding waste.
- **Language temperature rebalancing** (T=0.3) — upweights low-resource languages so they are seen more frequently during training.
- **Compute-cost DDP sharding** — distributes batches across GPUs balanced by compute cost, not just sample count.
- **Token-length-aware batching** — pre-computes token counts per sample and caps total tokens per batch at 400, directly controlling the dominant VRAM consumer (joint tensor `T_dec` dimension).
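The T=0.3 rebalancing raises each language's natural share to the power of the temperature and renormalizes, which compresses the gap between high- and low-resource languages. A hypothetical helper illustrating the math (the real sampler applies this when drawing languages):

```python
def language_weights(hours_per_lang: dict[str, float],
                     temperature: float = 0.3) -> dict[str, float]:
    """Sampling weights p_i proportional to (n_i / N) ** T.
    T=0 is uniform, T=1 matches the natural distribution;
    T=0.3 upweights low-resource languages."""
    total = sum(hours_per_lang.values())
    raw = {lang: (h / total) ** temperature for lang, h in hours_per_lang.items()}
    z = sum(raw.values())
    return {lang: w / z for lang, w in raw.items()}
```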

## Hardware Requirements

### Production Training (Stage 1)

- **GPUs:** 8x NVIDIA H200 (141 GB HBM3e each)
- **Precision:** BF16 mixed
- **VRAM breakdown (per GPU):**
  - Model weights + gradients (bf16): ~4.5 GB
  - Optimizer states (AdamW, fp32 copies): ~13.5 GB
  - Activations + joint tensor: ~75 GB budget (keeps total usage near ~75% of capacity as a safety margin)
- **Effective batch:** 8 GPUs x micro-batch x 4 gradient accumulation steps
- **Training target:** 200K optimizer steps, ~5-7 days
- **Checkpoints:** saved every 5K steps, uploaded to R2 object storage, validated (size match) before local deletion

### Inference

- **Minimum:** 1x GPU with 8+ GB VRAM (model is ~2.3 GB in bf16)
- **Recommended:** 1x GPU with 16+ GB VRAM for comfortable batch inference

## Project Structure

```
configs/
  train/              # Training YAML configs
  data/               # Data pipeline configs
scripts/
  train_prod.py       # Production training script (TDT + DiffLR + lang embed)
  build_manifest.py   # Raw manifest builder
  split_manifest.py   # Train/val splitter
  build_tokenizer.py  # BPE tokenizer builder
  run_stage1_prod.sh  # Production launcher (readiness + watcher + train)
lf_asr/
  data/
    production_sampler.py  # Triple-cap batch sampler with language rebalancing
    tar_offset_reader.py   # Zero-extraction tar audio reader
src/maya_asr/         # Core library
tokenizers/           # Trained BPE tokenizer artifacts
pretrained/           # Pretrained encoder weights
artifacts/            # Data processing outputs (parquet)
experiments/          # Training outputs (gitignored)
tests/
  unit/               # Unit tests (CI-safe, no data dependency)
  integration/        # Integration tests (requires data)
```

## Make Targets

| Target | Description |
|---|---|
| `make setup` | Install package + dev deps |
| `make test` | Run unit tests only |
| `make test-all` | Run unit + integration tests |
| `make lint` | Ruff lint check |
| `make format` | Auto-format + fix |
| `make smoke` | Import check + unit suite |
| `make prod-artifacts` | Build production data artifacts |
| `make prod-readiness` | Production launch gate (disk, GPUs, config, artifacts, R2) |
| `make prod-launch` | Fail-fast production launcher (readiness + watcher + training) |
| `make validate-r2` | End-to-end R2 upload/download/integrity test |

## License

MIT
