# LeWM TTS Project

## Overview
Adapting the LeWM (LeWorldModel) JEPA architecture to text-to-speech (TTS) synthesis.
- Paper: https://le-wm.github.io/
- Project folder: `/home/ubuntu/lewm-tts`
- Goal: multi-speaker TTS with 10-20 speakers on this model (scaled from 14M to 24M params)

## Architecture (v3c — Current)
- **TextEncoder**: Byte-level char transformer (4 layers, 256d, 4 heads)
- **AudioEncoder**: 1D CNN + Transformer (4 layers) → mu/logvar → reparameterization
  - **downsample_factor**: configurable (2 or 4), stored in config dict
- **JEPAPredictor**: Transformer decoder (6 layers) with causal self-attn + cross-attn to text + **KV-cache**
- **MelDecoder**: Linear→512 + ConvTranspose1d + ResConvBlock residual blocks + refinement conv
  - **upsample_factor**: matches downsample_factor
- **Vocoder**: Direct Vocos decode (model outputs 100 mels matching `charactr/vocos-mel-24khz`)
- **4 losses**: prediction (MSE), KL (Gaussian reg), reconstruction (L1 mel), multi-resolution spectral loss
- **Speaker embedding**: supports n_speakers > 1 (not yet trained)
- **KL annealing**: ramps up over the first 20% of training
- Mel preprocessing matches Vocos exactly (verified: mel diff = 0.0000)
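The AudioEncoder's mu/logvar → reparameterization step and the KL (Gaussian reg) loss can be sketched as follows. This is a minimal, standard-library-only sketch over flat lists. It assumes a diagonal Gaussian posterior regularized toward a standard normal prior N(0, I), with the KL averaged over elements; the function names are hypothetical, and `model.py` may reduce the KL differently (e.g. sum over latent dims, mean over frames).

```python
import math
import random

def reparameterize(mu, logvar, rng):
    """Sample z = mu + sigma * eps with eps ~ N(0, 1) and sigma = exp(0.5 * logvar)."""
    return [m + math.exp(0.5 * lv) * rng.gauss(0.0, 1.0) for m, lv in zip(mu, logvar)]

def kl_to_standard_normal(mu, logvar):
    """Mean over elements of KL(N(mu, sigma^2) || N(0, 1)) for a diagonal Gaussian."""
    terms = [-0.5 * (1.0 + lv - m * m - math.exp(lv)) for m, lv in zip(mu, logvar)]
    return sum(terms) / len(terms)

rng = random.Random(0)
mu = [0.5, -1.0, 0.0]
logvar = [0.0, -2.0, 0.0]
z = reparameterize(mu, logvar, rng)   # differentiable w.r.t. mu/logvar in an autograd framework
kl = kl_to_standard_normal(mu, logvar)
```

The KL weight scales this term in the total loss; the Key Findings below note that a weight of 0.1 over-regularized the latent.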

## Training Runs
- **v3** (4x downsample, kl=0.1, recon=1.0, spec=0.5): 24M params. Robotic output with noise.
- **v3b** (4x downsample, kl=0.01, recon=2.0, spec=1.5): 24M params. Voice audible but noisy at ep90 (recon=0.56). Killed at ep96.
- **v3c** (2x downsample, kl=0.01, recon=2.0, spec=1.5): 21.1M params. **Currently training** in `output_v3c/`. ~41s/epoch. At ep110, recon≈0.54; the voice comes through, still with noise.
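The loss weights above and the KL annealing schedule (ramp over the first 20% of training) combine roughly as below. This is a minimal sketch: it assumes a linear ramp and a prediction-loss weight of 1.0, neither of which is stated in these notes. `kl_anneal` and `total_loss` are hypothetical names, and the default weights are the v3c values.

```python
def kl_anneal(step, total_steps, warmup_frac=0.2):
    """KL warm-up factor: 0 at step 0, reaching 1 after warmup_frac of training (assumed linear)."""
    warmup_steps = max(1.0, warmup_frac * total_steps)
    return min(1.0, step / warmup_steps)

def total_loss(pred, kl, recon, spec, step, total_steps,
               w_pred=1.0, w_kl=0.01, w_recon=2.0, w_spec=1.5):
    """Weighted sum of the four losses; defaults mirror v3c (w_pred=1.0 is an assumption)."""
    return (w_pred * pred
            + w_kl * kl_anneal(step, total_steps) * kl
            + w_recon * recon
            + w_spec * spec)

# Example: halfway through the KL warm-up of a 1,000-step run (loss values are made up).
loss = total_loss(pred=0.02, kl=5.0, recon=0.54, spec=0.8, step=100, total_steps=1000)
```

Whether `step` counts optimizer steps or epochs is left open here; the ramp behaves the same either way.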

## Key Findings (This Session)
- **KL weight 0.1 is too aggressive**: it crushes the latent space, so the mel decoder can't reconstruct sharp formants.
- **4x downsampling is a bottleneck**: the decoder must hallucinate fine detail for 4 mel frames per latent. 2x is better.
- **Model error is structured** (blurred formants), not random: a model reconstruction at L1=0.51 sounds worse than the GT mel with random noise added at L1=0.84.
- At L1≈0.51, both the noise-corrupted GT mel and the model reconstruction produce a voice with noise, so the decoder IS learning; it just needs lower error.
- **AR synthesis collapses**: the predictor reaches pred_loss=0.02 under teacher forcing (TF), but free-running AR generation converges to silence. Needs scheduled sampling.
- TF prediction quality ≈ reconstruction quality, so the predictor is fine and the bottleneck is the mel decoder.
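The noise-vs-model comparison above (random noise added to a GT mel until it reaches a given L1 error) could be reproduced along these lines. The notes don't describe the exact procedure, so this is a hypothetical sketch: it draws Gaussian noise and rescales it so the mean absolute deviation from the ground truth equals the target L1 exactly. `add_noise_at_l1` is an illustrative name, and the mel is a flattened list of placeholder log-mel values.

```python
import random

def l1(a, b):
    """Mean absolute error between two equal-length sequences."""
    return sum(abs(x - y) for x, y in zip(a, b)) / len(a)

def add_noise_at_l1(gt_mel, target_l1, rng):
    """Add i.i.d. Gaussian noise, rescaled so that l1(noisy, gt_mel) equals target_l1."""
    noise = [rng.gauss(0.0, 1.0) for _ in gt_mel]
    mean_abs = sum(abs(n) for n in noise) / len(noise)
    scale = target_l1 / mean_abs
    return [x + scale * n for x, n in zip(gt_mel, noise)]

rng = random.Random(0)
gt = [rng.uniform(-16.0, 2.0) for _ in range(100 * 50)]  # placeholder 100-mel x 50-frame log-mel, flattened
noisy = add_noise_at_l1(gt, 0.84, rng)
```

One reading of the finding: unstructured noise leaves formant peaks in place on average, while the model's error smears them, so L1 alone is a weak proxy for perceived quality.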

## Files
- `preprocess.py` — segments audio, extracts mels, saves manifest
- `model.py` — all model components, now supports configurable downsample_factor (2 or 4)
- `dataset.py` — TTSDataset, collate_fn, build_dataloader
- `train.py` — training loop, now supports --downsample_factor arg
- `inference.py` — AR generation with Vocos vocoder (single request)
- `inference_engine.py` — Production engine: batched inference, KV-cache, FP16, voice prompts
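The KV-cache in the JEPAPredictor (used for AR generation in `inference_engine.py`) avoids re-projecting every earlier frame at each decode step. Below is a toy, single-head, pure-Python sketch of the idea, assuming plain scaled dot-product causal self-attention; it checks that incremental decoding with a cache reproduces full causal recomputation. All class and function names are hypothetical and unrelated to the project's implementation. In the real predictor, the cross-attention keys/values over the text encoding don't change during generation, so they can presumably be computed once per utterance.

```python
import math
import random

def matvec(W, x):
    """Multiply matrix W (list of rows) by vector x."""
    return [sum(w * xi for w, xi in zip(row, x)) for row in W]

def attend(q, keys, values):
    """Scaled dot-product attention of one query over the given keys/values."""
    scale = 1.0 / math.sqrt(len(q))
    scores = [scale * sum(a * b for a, b in zip(q, k)) for k in keys]
    top = max(scores)
    exps = [math.exp(s - top) for s in scores]  # subtract max for numerical stability
    total = sum(exps)
    weights = [e / total for e in exps]
    return [sum(w * v[d] for w, v in zip(weights, values)) for d in range(len(values[0]))]

class KVCache:
    """Single-head causal self-attention that caches keys/values across decode steps."""

    def __init__(self, Wq, Wk, Wv):
        self.Wq, self.Wk, self.Wv = Wq, Wk, Wv
        self.keys, self.values = [], []

    def step(self, x):
        # Only the newest frame is projected; earlier keys/values come from the cache.
        self.keys.append(matvec(self.Wk, x))
        self.values.append(matvec(self.Wv, x))
        return attend(matvec(self.Wq, x), self.keys, self.values)

def full_causal_attention(xs, Wq, Wk, Wv):
    """Reference: recompute projections and attention over the whole prefix at every position."""
    qs = [matvec(Wq, x) for x in xs]
    ks = [matvec(Wk, x) for x in xs]
    vs = [matvec(Wv, x) for x in xs]
    return [attend(qs[t], ks[: t + 1], vs[: t + 1]) for t in range(len(xs))]

def random_matrix(rows, cols, rng):
    return [[rng.gauss(0.0, 1.0) for _ in range(cols)] for _ in range(rows)]

rng = random.Random(0)
dim = 4
Wq, Wk, Wv = (random_matrix(dim, dim, rng) for _ in range(3))
frames = [[rng.gauss(0.0, 1.0) for _ in range(dim)] for _ in range(5)]

cache = KVCache(Wq, Wk, Wv)
incremental = [cache.step(x) for x in frames]
reference = full_causal_attention(frames, Wq, Wk, Wv)
```

Per step, the cached path does one set of projections instead of t + 1, which is where the AR speedup comes from.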

## Dataset
- **Raw**: `/home/ubuntu/modi_dataset` (2,330 clips, 18.25h, Hindi, 24kHz)
- **100-mel processed**: `/home/ubuntu/lewm-tts/processed_data_100mel` (7,009 segments, 17.88h)
- Mel config: n_fft=1024, hop=256, n_mels=100, sr=24000, f_max=None, power=1.0, clip_val=1e-7
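The frame and latent rates follow directly from the mel config: hop=256 at sr=24000 gives 93.75 mel frames per second, and the AudioEncoder's downsample_factor divides that further. This small calculation (the `rates` name is hypothetical) shows why 4x is coarse: each latent covers about 43 ms of audio, versus about 21 ms at 2x. It also estimates the average segment length from the dataset figures above.

```python
SAMPLE_RATE = 24000  # sr from the mel config
HOP_LENGTH = 256     # hop from the mel config

def rates(downsample_factor, sr=SAMPLE_RATE, hop=HOP_LENGTH):
    """Return (mel frames/s, latents/s, ms of audio per latent) for a given downsample factor."""
    mel_fps = sr / hop
    latent_rate = mel_fps / downsample_factor
    ms_per_latent = 1000.0 * hop * downsample_factor / sr
    return mel_fps, latent_rate, ms_per_latent

# Average segment length in processed_data_100mel: 17.88 h over 7,009 segments.
avg_seconds = 17.88 * 3600 / 7009
avg_mel_frames = avg_seconds * SAMPLE_RATE / HOP_LENGTH

for factor in (2, 4):
    mel_fps, latent_rate, ms = rates(factor)
    print(f"{factor}x: {mel_fps} mel fps, {latent_rate} latents/s, {ms:.1f} ms per latent")
```

An average segment of roughly 9.2 s is about 861 mel frames, i.e. about 430 latents at 2x or 215 at 4x for the predictor to generate.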

## Next Steps
1. Let v3c finish 200 epochs, test quality
2. Add **scheduled sampling** to training (fix AR collapse)
3. If mel quality is sufficient, proceed to multi-speaker training
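For a parallel, teacher-forced predictor, scheduled sampling is commonly done in two passes: run the predictor once on ground-truth inputs, then replace a random subset of the inputs with the first-pass predictions and compute the loss on a second pass. The sketch below covers only the schedule and the input-mixing step, in plain Python over lists of latent frames. It is a hypothetical design, not the project's plan: the linear decay, the floor, the default epochs, and the names `teacher_forcing_prob` and `mix_inputs` are all assumptions.

```python
import random

def teacher_forcing_prob(epoch, start_epoch=100, decay_epochs=50, floor=0.25):
    """Probability of feeding ground truth: 1.0 until start_epoch, then linear decay to floor."""
    if epoch < start_epoch:
        return 1.0
    frac = min(1.0, (epoch - start_epoch) / decay_epochs)
    return 1.0 - frac * (1.0 - floor)

def mix_inputs(gt_inputs, first_pass_preds, tf_prob, rng):
    """Build second-pass predictor inputs.

    gt_inputs[t] is the frame fed at step t under teacher forcing (gt_inputs[0] is the
    start frame). first_pass_preds[t] is the first-pass prediction made at step t, i.e.
    the model's estimate of gt_inputs[t + 1]. Each later position keeps the ground truth
    with probability tf_prob and otherwise takes the model's own prediction.
    """
    mixed = [gt_inputs[0]]  # the start frame is always ground truth
    for t in range(1, len(gt_inputs)):
        if rng.random() < tf_prob:
            mixed.append(gt_inputs[t])
        else:
            mixed.append(first_pass_preds[t - 1])
    return mixed

rng = random.Random(0)
gt = [[0.0], [1.0], [2.0], [3.0]]
preds = [[0.9], [2.1], [2.8], [4.2]]
mixed = mix_inputs(gt, preds, teacher_forcing_prob(epoch=125), rng)
```

In a PyTorch loop, the first-pass predictions would typically be detached before mixing so gradients flow only through the second pass.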
