# LeWM TTS - JEPA-based Hindi TTS

## Key Issues Found (v3/v3b/v3c)
1. Missing EMA target encoder (core JEPA component)
2. Prediction loss was <1% of the total loss, so the predictor collapsed to an identity mapping
3. KL divergence exploded to 2.45 (kl_weight=0.01 too low)
4. v3c's downsample_factor=2 made consecutive embeddings too similar
5. Random start embedding at inference (never trained)
6. No input noise to bridge train/inference gap

## v4 Fixes Applied
- **EMA target encoder**: deepcopy of AudioEncoder, updated with momentum 0.998
- **Rebalanced losses**: pred_weight=10.0 (now ~36% of total), kl_weight=0.05
- **Free-bits KL**: threshold=1.0, prevents both collapse and explosion
- **Learnable start_emb**: prepended to training sequence, receives gradients
- **Input noise**: annealed 0→0.2 over first 30% of training
- **Downsample factor**: back to 4 (more temporal separation)
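Three of the fixes above are small enough to sketch. The snippet uses NumPy stand-ins rather than the actual PyTorch modules; the function names (`ema_update`, `free_bits_kl`, `noise_scale`) and the per-dimension clamping scheme are illustrative assumptions, not this repo's API:

```python
import numpy as np

def ema_update(target, online, momentum=0.998):
    """EMA update of the target encoder's weights (JEPA-style):
    target <- momentum * target + (1 - momentum) * online.
    `target`/`online` are dicts of parameter arrays (a stand-in for state_dicts)."""
    for k in target:
        target[k] = momentum * target[k] + (1.0 - momentum) * online[k]

def free_bits_kl(mu, logvar, threshold=1.0):
    """KL(N(mu, sigma^2) || N(0, 1)) per latent dim, clamped from below at
    `threshold` ("free bits"): dims already under the threshold contribute a
    constant, so the KL term stops pushing them toward posterior collapse."""
    kl_per_dim = 0.5 * (mu**2 + np.exp(logvar) - logvar - 1.0)   # (B, D)
    kl_per_dim = np.maximum(kl_per_dim.mean(axis=0), threshold)  # clamp per dim
    return kl_per_dim.mean()

def noise_scale(step, total_steps, max_noise=0.2, warmup_frac=0.3):
    """Input-noise annealing: ramp 0 -> max_noise over the first
    `warmup_frac` of training, then hold."""
    return max_noise * min(step / (warmup_frac * total_steps), 1.0)
```

With `mu = logvar = 0` the raw KL is exactly 0, so the free-bits floor makes the term return `threshold` instead of collapsing to zero.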

## v4 Results
- MelDecoder was the bottleneck (L1=0.57 even when fed ground-truth embeddings)
- JEPA predictor worked well: cosine_sim=0.97, 6.2x better than the copy baseline
- Reconstruction had a noise overlay; synthesis produced pure noise
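The copy-baseline comparison above can be sketched as follows. How the 6.2x figure was actually computed is not stated here; this version takes the ratio of mean (1 - cosine) distances, using the previous frame's embedding as the trivial "just copy" prediction (an assumption):

```python
import numpy as np

def cosine_sim(a, b, eps=1e-8):
    """Row-wise cosine similarity between two (N, D) arrays."""
    return (a * b).sum(-1) / (
        np.linalg.norm(a, axis=-1) * np.linalg.norm(b, axis=-1) + eps)

def copy_baseline_gain(pred, target, prev):
    """Ratio of the copy baseline's error to the predictor's error, in
    mean (1 - cosine) distance. `prev` holds the previous-frame embeddings;
    a value > 1 means the predictor beats copying the last frame."""
    pred_err = 1.0 - cosine_sim(pred, target).mean()
    copy_err = 1.0 - cosine_sim(prev, target).mean()
    return copy_err / max(pred_err, 1e-8)
```

This kind of sanity check is what catches an identity-mapping predictor: a model that merely copies its input scores a gain of ~1.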

## v5 Architecture (codec-based, simplified)
- EnCodec 24kHz (frozen): audio ↔ 128d embeddings @ 75Hz, near-transparent quality
- Simple linear projections: proj_in(128→256→256), proj_out(256→256→128) — NO downsampler
- TextEncoder(4L,256d) + JEPAPredictor(6L decoder) at 75Hz
- Projections trained first (output_v5/), then frozen for predictor training (output_v5_pred/)
- Predictor reinitialized from scratch with frozen proj_in targets (no EMA — EMA caused divergence)
- Files: model_v5.py, train_v5.py, train_v5_pred.py, dataset_v5.py, inference_v5.py, preprocess_codec.py
- Data: processed_data_codec/ (7009 samples, EnCodec embeddings)
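The v5 projection stack above can be sketched as plain two-layer MLPs bridging the frozen EnCodec space (128d @ 75Hz) and the predictor space (256d). This is a NumPy shape-check, not the repo's model_v5.py; the ReLU activation and the initialization are assumptions:

```python
import numpy as np

rng = np.random.default_rng(0)

def linear(d_in, d_out):
    """Weight/bias pair for one linear layer (stand-in for nn.Linear)."""
    return rng.normal(0.0, (2.0 / d_in) ** 0.5, (d_in, d_out)), np.zeros(d_out)

def mlp2(x, layers):
    """Two linear layers with a ReLU in between (assumed activation)."""
    (w1, b1), (w2, b2) = layers
    h = np.maximum(x @ w1 + b1, 0.0)
    return h @ w2 + b2

# proj_in: 128 -> 256 -> 256 (EnCodec embedding -> predictor space)
proj_in = [linear(128, 256), linear(256, 256)]
# proj_out: 256 -> 256 -> 128 (predictor space -> EnCodec embedding)
proj_out = [linear(256, 256), linear(256, 128)]

# One second of audio at 75 Hz: (75, 128) in, (75, 256) latent, (75, 128) back.
x = np.zeros((75, 128))
x_hat = mlp2(mlp2(x, proj_in), proj_out)
```

Note there is no temporal down/upsampling anywhere in this path: both projections operate frame-wise at 75Hz, which is the point of dropping the downsampler.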

## Key Findings
- Down/upsampler CNN approach: FAILED (noisy roundtrip)
- EMA on small proj_in: FAILED (pred_loss diverged 0.12→0.38)
- Voice prompting works great: encode real audio as prefix, AR-continue from there
- Pure AR at 75Hz collapses after ~50 steps (too many steps for small model)
- Roundtrip (proj_in→proj_out→decode): clean audio, L1=0.27
- TF prediction: improving, L1=0.62 at epoch 45
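The voice-prompting finding above amounts to a short AR loop: encode real audio into a prefix of embeddings, then continue one 75Hz frame at a time. In this sketch `step_fn` is a stand-in for the frozen proj_in -> predictor -> proj_out pipeline, and the function name is hypothetical:

```python
import numpy as np

def ar_continue(step_fn, prompt, n_steps):
    """Voice prompting: seed the sequence with embeddings encoded from real
    audio, then autoregressively append one predicted frame per step.
    `step_fn(seq) -> next_frame` maps the full (T, D) context so far to the
    next (D,) embedding. Returns only the generated continuation."""
    seq = list(prompt)                      # (T_prompt, D) prefix from EnCodec
    for _ in range(n_steps):
        seq.append(step_fn(np.stack(seq)))  # feed full context back in
    return np.stack(seq[len(prompt):])
```

The ~50-step collapse noted above happens inside exactly this loop: each appended frame is fed back as context, so small errors compound with every step.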

## Dataset
7009 segments (17.88 h) of single-speaker Hindi speech, stored in processed_data_100mel/ and processed_data_codec/
