# Cohere Transcribe — Indic Language Extension

Finetuning [CohereLabs/cohere-transcribe-03-2026](https://huggingface.co/CohereLabs/cohere-transcribe-03-2026) (2B Fast-Conformer encoder + Transformer decoder) for 11 Indic languages.

## Overview

| Metric | Value |
|--------|-------|
| Base model | CohereLabs/cohere-transcribe-03-2026 (2.15B params) |
| Languages | English + Hindi, Telugu, Tamil, Malayalam, Bengali, Gujarati, Kannada, Punjabi, Marathi, Odia, Assamese |
| Training data | 73M utterances, 145K hours (cleaned from 74.5M / 147K hours) |
| Tokenizer | Extended from 16K → 59,554 vocab (32K trained Indic subwords, plus BPE merge intermediates) |
| Hardware | 8x NVIDIA H200 (141GB HBM3e each) |
| Training time | ~5 days (2 epochs) |

## Tokenizer Extension

The original 16K BPE tokenizer had zero Indic subwords — every Indic character fell back to 3 UTF-8 bytes (~13 tokens/word for Hindi vs ~1.2 for English).
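The byte-fallback blow-up is easy to reproduce: Devanagari (and the other Indic scripts used here) sits in a UTF-8 range that costs 3 bytes per codepoint, so a byte-level BPE with no Indic merges emits one token per byte. A minimal sketch (pure Python, no tokenizer dependency — real averages land near 13 tok/word because a few byte pairs do get merged):

```python
def byte_fallback_tokens(word: str) -> int:
    """Worst-case token count for a byte-level BPE with no merges
    covering `word`: one token per UTF-8 byte."""
    return len(word.encode("utf-8"))

hindi = "नमस्ते"   # 6 Devanagari codepoints, 3 bytes each in UTF-8
print(byte_fallback_tokens(hindi))    # 18 byte-level tokens for one word
print(byte_fallback_tokens("hello"))  # 5 — ASCII stays 1 byte/char
```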

We trained a 32K SentencePiece unigram model on 57M Indic transcripts, then merged the new subwords into the existing BPE tokenizer via synthetic merge rules.

**Result: 89.5% average token reduction for Indic text** (e.g. Hindi: 11.95 → 1.22 tok/word), with zero impact on English tokenization.

| Language | Old tok/word | New tok/word | Reduction |
|----------|-------------|--------------|-----------|
| Hindi | 11.95 | 1.22 | 89.8% |
| Telugu | 16.58 | 1.54 | 90.7% |
| Tamil | 22.00 | 1.91 | 91.3% |
| Bengali | 13.97 | 1.18 | 91.6% |
| English | 1.17 | 1.17 | 0.0% |
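Merging unigram pieces into a BPE tokenizer requires merge rules the original never learned. One way to synthesize them (an illustrative sketch — the actual `merge_tokenizer_final.py` logic may differ) is to build each new piece left-to-right, registering every intermediate string as an extra vocab entry, which is why the final vocab exceeds a plain 16K + 32K sum:

```python
def synthetic_merges(piece: str, vocab: set, merges: list) -> None:
    """Add left-to-right merge rules so BPE can assemble `piece`
    from single characters; intermediate strings join the vocab too."""
    prefix = piece[0]
    for ch in piece[1:]:
        merged = prefix + ch
        if merged not in vocab:
            vocab.add(merged)            # intermediate (or final) token
            merges.append((prefix, ch))  # synthetic merge rule
        prefix = merged

vocab = set("नमस्ते")   # character-level seeds already in the vocab
merges = []
synthetic_merges("नमस्ते", vocab, merges)
# merges now builds the piece step by step: (न,म), (नम,स), ...
```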

## Data Pipeline

### Data Cleaning (applied to training manifest)
- Removed empty/short transcripts and rows containing `REDACTED` or `<unintelligible>` tags
- Stripped `[singing]`, `[laugh]`, `[noise]` annotations
- Removed segments < 0.5s or > 30s
- Removed segments with very slow speech rates (< 2 chars/sec; likely misaligned)
- Normalized ellipses, stripped zero-width spaces, converted native-script digits to Arabic numerals
- Fixed language tag mismatches (relabeled Telugu-in-wrong-lang, dropped ambiguous Devanagari)
- Deduplicated cross-language and within-shard duplicates
- **Net: 72,998,025 rows / 145,139 hours (2.06% removed)**
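A condensed sketch of the row filters above (field names `text` and `duration` are illustrative; see `clean_manifest.py` for the real pipeline):

```python
import re

DROP_TAGS = re.compile(r"\[(?:singing|laugh|noise)\]")

def keep_row(row: dict) -> bool:
    """Return True if a manifest row survives cleaning."""
    text = DROP_TAGS.sub("", row["text"]).strip()  # strip annotations first
    if not text or "REDACTED" in text or "<unintelligible>" in text:
        return False
    if not 0.5 <= row["duration"] <= 30.0:         # segment length bounds
        return False
    if len(text) / row["duration"] < 2.0:          # < 2 chars/sec: misaligned
        return False
    return True

rows = [
    {"text": "नमस्ते दुनिया", "duration": 1.6},
    {"text": "[noise]", "duration": 2.0},   # empty after tag strip
    {"text": "ok", "duration": 40.0},       # over the 30s cap
]
print([keep_row(r) for r in rows])  # [True, False, False]
```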

### Pre-computed Mel Spectrograms
- 128-channel log-mel, 16kHz, 100fps, per-feature normalized
- Extracted from tar shards to individual `.npy` files for fast random access
- Per-shard parquet indexes for streaming reads
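Per-feature normalization here means each of the 128 mel channels is standardized over time before being written out; at 16 kHz, 100 fps implies a 160-sample hop. A numpy sketch of the normalize-and-save step (computing the log-mel itself is assumed done upstream):

```python
import numpy as np

HOP = 16000 // 100   # 160 samples per frame → 100 fps at 16 kHz

def normalize_and_save(logmel: np.ndarray, path: str) -> np.ndarray:
    """logmel: (128, T). Standardize each mel channel over time,
    then persist as .npy for fast random access during training."""
    mean = logmel.mean(axis=1, keepdims=True)
    std = logmel.std(axis=1, keepdims=True) + 1e-5  # avoid div-by-zero
    feats = ((logmel - mean) / std).astype(np.float32)
    np.save(path, feats)
    return feats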

## Training Configuration

```yaml
model: 59,554 vocab (16K original + 43K Indic subwords + intermediates)
batch: 128 utterances/GPU × 8 GPUs = 1024 effective batch
lr: 2e-4, cosine decay, 2% warmup
epochs: 2
precision: bf16
strategy: DDP
gradient_checkpointing: manual (encoder conformer layers)
label_smoothing: 0.1 (via NLLLoss — model head outputs log-probs)
```
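The schedule in the config — linear warmup over the first 2% of steps, then cosine decay — can be sketched as follows (a hypothetical helper for illustration, not code from `train.py`):

```python
import math

def lr_at(step: int, total_steps: int, peak_lr: float = 2e-4,
          warmup_frac: float = 0.02) -> float:
    """Linear warmup for warmup_frac of training, then cosine decay to 0."""
    warmup = max(1, int(total_steps * warmup_frac))
    if step < warmup:
        return peak_lr * (step + 1) / warmup
    progress = (step - warmup) / max(1, total_steps - warmup)
    return peak_lr * 0.5 * (1.0 + math.cos(math.pi * progress))

# For a 100K-step run, warmup ends at step 2000 at the full 2e-4.
print(lr_at(1999, 100_000))  # 0.0002
```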

### Key Implementation Details
- **Decoder prompt**: 9 atomic control tokens per sample (language-conditioned)
- **Label shift**: Model does NOT shift labels internally — `labels[i] = decoder_input_ids[i+1]`
- **Loss**: NLLLoss (not CrossEntropyLoss) because model head applies log_softmax
- **Weight tying**: Decoder embedding and output head share weights
- **Temperature sampling**: τ=5.0 for language balancing (upsamples Assamese/Odia)
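The bullets above can be sketched in a few lines (numpy, illustrative shapes; the τ-sampling formula `p_l ∝ (n_l/N)^(1/τ)` is the standard temperature-sampling recipe and the repo's exact implementation may differ):

```python
import numpy as np

def shift_for_loss(decoder_input_ids: list):
    """The model does not shift internally, so labels[i] must be
    decoder_input_ids[i+1]; the final position has no target."""
    return decoder_input_ids[:-1], decoder_input_ids[1:]

def smoothed_nll(log_probs: np.ndarray, labels: np.ndarray, eps: float = 0.1):
    """Label-smoothed NLL on log-probs. The head already applies
    log_softmax, so CrossEntropyLoss (which re-applies it) would be
    wrong here. log_probs: (N, V), labels: (N,)."""
    n, _ = log_probs.shape
    nll = -log_probs[np.arange(n), labels]  # (1 - eps) weight on the target
    smooth = -log_probs.mean(axis=1)        # eps spread uniformly over vocab
    return float(((1 - eps) * nll + eps * smooth).mean())

def sampling_weights(hours_per_lang, tau: float = 5.0) -> np.ndarray:
    """Temperature sampling: p_l ∝ (n_l / N) ** (1 / tau).
    tau > 1 flattens the distribution, upsampling low-resource languages."""
    p = np.asarray(hours_per_lang, dtype=np.float64)
    p = (p / p.sum()) ** (1.0 / tau)
    return p / p.sum()
```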

## Repository Structure

```
training/
├── train.py                 # Main trainer (DDP, grad checkpointing, wandb)
├── dataset.py               # Streaming dataset (tar + extracted .npy support)
├── dataset_fast.py          # Fast dataset (direct .npy reads, bucketed batching)
├── config.yaml              # Production training config
├── config_smoke.yaml        # Smoke test config (small subset)
├── evaluate.py              # WER/CER evaluation script
├── tokenizer_utils.py       # Extended tokenizer loader + decode helper
├── smoke_test.py            # Startup validation (alignment, forward/backward, generate)
├── resize_embeddings.py     # Embedding resize + weight tying for extended vocab
├── extend_tokenizer.py      # SentencePiece training + tokenizer extension pipeline
├── merge_tokenizer_final.py # BPE merge with synthetic merge rules
├── push_to_r2.sh            # Checkpoint upload to Cloudflare R2
└── tokenizer_extension/
    └── extended_model/      # Extended checkpoint (59K vocab, bf16)
        ├── model.safetensors
        ├── config.json
        ├── tokenizer.json
        └── modeling_cohere_asr.py

maya-asr/
├── manifests/
│   ├── training_manifest_cleaned.parquet  # 73M rows, cleaned
│   └── training_manifest_smoke.parquet    # 60K rows, smoke test
├── mel_extracted/           # Extracted mel .npy files (from tar shards)
├── mel_extracted_index/     # Per-shard parquet indexes
├── clean_manifest.py        # Data cleanup pipeline
└── build_shard_indexes.py   # Per-shard index builder
```

## Inference

```python
from transformers import AutoModelForSpeechSeq2Seq, AutoProcessor
import torch

model = AutoModelForSpeechSeq2Seq.from_pretrained(
    "path/to/checkpoint/model",
    trust_remote_code=True,
    dtype=torch.bfloat16,
).cuda()

processor = AutoProcessor.from_pretrained(
    "path/to/checkpoint/processor",
    trust_remote_code=True,
)

# For correct Indic decoding, use the tokenizer_utils wrapper:
from tokenizer_utils import load_extended_tokenizer, decode_tokens
tokenizer = load_extended_tokenizer("path/to/checkpoint/processor")

# Transcribe
result = model.transcribe(
    audio_file="audio.wav",
    language="hi",        # hi, te, ta, ml, bn, gu, kn, pa, mr, or, as, en
    punctuation=True,
    processor=processor,
)
print(result["transcription"])
```

## Checkpoints

Checkpoints are saved every 2000 steps to local disk and pushed to R2 at milestones:
- `r2:ptcheckpoints/cohere-transcribe/04-05-2026/ckpt-{step}/`

Each checkpoint contains:
- `model/model.safetensors` — trained weights (bf16)
- `model/config.json` — model architecture
- `processor/tokenizer.json` — extended 59K tokenizer
- `training_state.pt` — optimizer state for resume

## Wandb

Training tracked at: [maya-asr-cohere-transcribe](https://wandb.ai/maya-research-maya-research/maya-asr-cohere-transcribe)
