# LeWM TTS

A JEPA-based (Joint-Embedding Predictive Architecture) text-to-speech (TTS) system adapted from [LeWorldModel](https://le-wm.github.io/).

## Architecture

A 24M-parameter model with four core components, followed by a pretrained vocoder:

```
Text (Hindi/UTF-8 bytes) --> TextEncoder --> text embeddings ---+
                                                                |--> JEPAPredictor --> predicted embeddings
Audio (mel spectrogram)  --> AudioEncoder --> audio embeddings --+           |
                                                                            v
                                                                       MelDecoder --> mel spectrogram --> Vocos --> waveform
```

| Component | Description |
|-----------|-------------|
| **TextEncoder** | Byte-level character transformer (4 layers, 256d, 4 heads). Encodes Hindi text via UTF-8 bytes — no tokenizer needed. |
| **AudioEncoder** | 1D CNN (stride-4 downsample) + Transformer (4 layers). Outputs mu/logvar for Gaussian latent space with reparameterization. |
| **JEPAPredictor** | Transformer decoder (6 layers) with causal self-attention + cross-attention to text. Predicts next audio embedding autoregressively. Has KV-cache for fast inference. |
| **MelDecoder** | Linear projection to 512d, two transposed convolutions (4x total upsample), residual conv blocks (ResConvBlock), and refinement layers. Outputs a 100-mel spectrogram. |
| **Vocoder** | [Vocos](https://github.com/gemelo-ai/vocos) `charactr/vocos-mel-24khz` — directly decodes 100-mel to 24kHz waveform. No Griffin-Lim. |
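
To illustrate the byte-level input described for the TextEncoder, the sketch below converts a Hindi string into UTF-8 byte IDs. The function names `text_to_byte_ids` and `byte_ids_to_text` are hypothetical and not taken from `model.py`; the sketch assumes each byte value (0-255) is used directly as a token ID.

```python
def text_to_byte_ids(text: str) -> list[int]:
    """Encode text as a list of UTF-8 byte values, each in the range 0..255."""
    return list(text.encode("utf-8"))


def byte_ids_to_text(ids: list[int]) -> str:
    """Inverse mapping, handy when debugging model inputs."""
    return bytes(ids).decode("utf-8")


# "namaste" in Devanagari, written with escapes so the code points are explicit.
namaste = "\u0928\u092e\u0938\u094d\u0924\u0947"
ids = text_to_byte_ids(namaste)

print(len(namaste), len(ids))  # 6 code points become 18 bytes
print(ids[:3])                 # [224, 164, 168], the three bytes of the first letter
```

Under this scheme the vocabulary is fixed at 256 entries, at the cost of longer sequences: each Devanagari code point occupies three bytes.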

### Speaker Embedding (planned)
The model includes a learned speaker-embedding table whose output is added to the text embeddings, giving built-in multi-speaker support. It has not been trained yet; the target is multi-speaker TTS with 10-20 speakers.

## Training Losses

4 losses, all active during training:

1. **Prediction Loss (MSE)** — next-embedding prediction: `MSE(predicted[t], target[t+1])`
2. **KL Loss** — Gaussian regularizer: `KL(q(z|x) || N(0, I))` with annealing over first 20% of training
3. **Reconstruction Loss (L1)** — mel decoder output vs ground truth mel spectrogram
4. **Multi-Resolution Spectral Loss** — STFT at 3 resolutions (64/128/256 FFT sizes), combines spectral convergence + log magnitude L1. Forces sharp spectral detail instead of blurry averages.
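
The first two losses can be sketched in plain Python as follows. This is a toy illustration on scalar sequences, not the code in `model.py` or `train.py`: the function names are hypothetical, and the linear shape of the KL ramp is an assumption, since the list above only states that annealing covers the first 20% of training.

```python
import math


def shifted_mse(predicted, target):
    """MSE between predicted[t] and target[t+1] on toy scalar sequences."""
    pairs = list(zip(predicted[:-1], target[1:]))
    return sum((p - y) ** 2 for p, y in pairs) / len(pairs)


def kl_to_standard_normal(mu, logvar):
    """Closed-form KL(N(mu, exp(logvar)) || N(0, I)), summed over dimensions."""
    return sum(
        -0.5 * (1.0 + lv - m * m - math.exp(lv))
        for m, lv in zip(mu, logvar)
    )


def kl_weight(step, total_steps, max_weight=1.0, warmup_fraction=0.2):
    """KL weight at a given step: linear ramp (assumed shape) over the warmup."""
    warmup_steps = max(1, int(total_steps * warmup_fraction))
    return max_weight * min(1.0, step / warmup_steps)


print(shifted_mse([0.0, 0.0, 0.0], [5.0, 1.0, 3.0]))  # 5.0
print(kl_to_standard_normal([1.0], [0.0]))            # 0.5
print(kl_weight(100, 1000))                           # 0.5
```

The KL term is zero when `mu = 0` and `logvar = 0`, so ramping its weight up gradually lets the encoder learn useful embeddings before the regularizer pulls them toward the prior.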

## Dataset

- **Source**: 20 hours of studio-quality single-speaker Hindi audio
- **Preprocessing**: silence-based segmentation into 3-12s chunks, mel extraction
- **Processed**: 7,009 segments, 17.88 hours
- **Mel config**: 100 mels, n_fft=1024, hop=256, sr=24000, f_max=None (full bandwidth), log-clamped at 1e-7 — matches Vocos `mel-24khz` exactly
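
The numbers above imply the following frame rates and segment sizes. This is a back-of-the-envelope sketch: it assumes the AudioEncoder's stride-4 downsampling applies along the time axis, and the `mel_frames` helper (a hypothetical name) ignores STFT padding, so exact frame counts in `preprocess.py` may differ by a frame or so.

```python
SAMPLE_RATE = 24_000   # Hz, from the mel config
HOP_LENGTH = 256       # samples between consecutive mel frames
ENCODER_STRIDE = 4     # AudioEncoder downsampling factor (assumed to act on time)

mel_frames_per_second = SAMPLE_RATE / HOP_LENGTH                   # 93.75
latent_steps_per_second = mel_frames_per_second / ENCODER_STRIDE   # 23.4375


def mel_frames(duration_s: float) -> int:
    """Approximate mel frame count for a clip, ignoring STFT padding."""
    return int(duration_s * SAMPLE_RATE) // HOP_LENGTH


# Segment length bounds from the preprocessing step (3-12 s chunks).
print(mel_frames(3.0), mel_frames(12.0))   # 281 1125

# Mean segment duration from the processed totals (17.88 h over 7,009 segments).
mean_segment_s = 17.88 * 3600 / 7009
print(round(mean_segment_s, 2))            # 9.18
```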

## Files

| File | Purpose |
|------|---------|
| `model.py` | All model components: TextEncoder, AudioEncoder, JEPAPredictor, MelDecoder, MultiResolutionSpectralLoss, LeWMTTS |
| `train.py` | Training loop with KL annealing, gradient clipping, TensorBoard logging, checkpointing |
| `inference.py` | Single-request inference — AR synthesis, teacher-forced reconstruction |
| `inference_engine.py` | Production batched inference engine — KV-cache, FP16, voice prompts, benchmarking |
| `preprocess.py` | Data preprocessing — audio segmentation, mel extraction, manifest creation |
| `dataset.py` | PyTorch dataset + dataloader with dynamic batching and padding |

## Inference

### Modes
- **Synthesize**: Pure autoregressive generation from text. An audio prompt (voice cloning) is recommended for best results.
- **Reconstruct**: Encode real mel -> decode back (tests encoder/decoder quality)
- **Predict (teacher-forced)**: Encode mel -> predictor -> decode (tests full pipeline without AR drift)

### Performance (A100 80GB)
- Single request with KV-cache: 9x real-time
- Batched (batch=64): 359x real-time throughput
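
To make these figures concrete, the sketch below converts them into wall-clock time. It assumes that "Nx real-time" means N seconds of audio generated per second of wall-clock time, and that the batched figure is aggregate throughput across all 64 streams; neither definition is stated above, and the helper name is hypothetical.

```python
def wall_time_for(audio_seconds: float, real_time_factor: float) -> float:
    """Wall-clock seconds needed to generate the given amount of audio."""
    return audio_seconds / real_time_factor


# Single request: a 10 s utterance at 9x real-time.
print(round(wall_time_for(10.0, 9.0), 2))   # 1.11

# Batched: aggregate 359x over 64 streams, expressed per stream.
per_stream_factor = 359 / 64
print(round(per_stream_factor, 2))          # 5.61
```

Under these assumptions, batching lowers the per-stream speed somewhat while raising total throughput by roughly 40x.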

### Voice Prompting
The inference engine supports audio prompts — a reference clip is encoded and used to seed autoregressive generation, effectively providing voice cloning without explicit speaker conditioning.
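
The seeding idea can be sketched with a toy autoregressive loop. Everything here is hypothetical: `generate_with_prompt` and `toy_step` are illustrative names, `step_fn` stands in for the JEPAPredictor, and text cross-attention and the KV-cache are omitted. Whether the real engine returns or discards the prompt frames is not stated above; this sketch returns only the newly generated steps.

```python
from typing import Callable, List, Sequence

Embedding = List[float]


def generate_with_prompt(
    step_fn: Callable[[Sequence[Embedding]], Embedding],
    prompt_embeddings: Sequence[Embedding],
    num_new_steps: int,
) -> List[Embedding]:
    """Seed the history with prompt embeddings, then extend it step by step.

    step_fn maps the history so far to the next embedding. The prompt is
    copied, so the caller's list is left unchanged.
    """
    history = [list(e) for e in prompt_embeddings]
    for _ in range(num_new_steps):
        history.append(step_fn(history))
    return history[len(prompt_embeddings):]


def toy_step(history: Sequence[Embedding]) -> Embedding:
    """Toy predictor: halve the most recent embedding."""
    return [0.5 * x for x in history[-1]]


prompt = [[1.0, 2.0], [4.0, 8.0]]
new_steps = generate_with_prompt(toy_step, prompt, num_new_steps=3)
print(new_steps)  # [[2.0, 4.0], [1.0, 2.0], [0.5, 1.0]]
```

Because every new step is conditioned on the history, the prompt embeddings influence all generated frames without any dedicated speaker input.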

## Training History

### v1 — 80 mels + Griffin-Lim + Vocos
- Simple MelDecoder (2x transposed conv), 80-mel output
- Vocoder: Griffin-Lim -> Vocos refinement pipeline
- Result: teacher-forced quality was decent, but Griffin-Lim (GL) artifacts gave the audio a "rustic" sound

### v2 — DAC-based
- Experimental DAC codec variant
- Checkpoints preserved at `v2/output/checkpoints/`

### v3 — 100 mels + upgraded decoder + spectral loss (current)
- Direct 100-mel output matching Vocos natively (no GL middleman)
- Upgraded MelDecoder with ResConvBlock residual layers, 512 hidden channels
- Multi-resolution spectral loss for sharp mel output
- 24M parameters, training on `processed_data_100mel/`
- Output: `output_v3/`

## Usage

### Preprocessing
```bash
python3 preprocess.py --input /path/to/dataset --output /path/to/processed_data
```

### Training
```bash
python3 train.py \
  --data_dir processed_data_100mel \
  --output_dir output_v3 \
  --epochs 200 \
  --batch_size 256 \
  --spectral_weight 0.5
```

### Inference
```bash
# Synthesize from text
python3 inference.py --checkpoint output_v3/checkpoints/best.pt --text "your text" --mode synthesize

# Teacher-forced reconstruction
python3 inference.py --checkpoint output_v3/checkpoints/best.pt --mel path/to/mel.pt --mode predict
```

### Batched Inference
```bash
python3 inference_engine.py --checkpoint output_v3/checkpoints/best.pt --benchmark
```

## Reference
- LeWM paper: https://le-wm.github.io/
