Migrate VAE Example to Flax NNX with JIT Optimization
Summary
This PR migrates the VAE (Variational Autoencoder) example from Flax Linen to Flax NNX, the new simplified API. The migration also adds `@nnx.jit` decorators to the training and evaluation steps, which yields a significant performance improvement.
Changes
1. Model Architecture (`models.py`)
- ✅ Migrated from `nn.Module` (Linen) to `nnx.Module` (NNX)
- ✅ Replaced `@nn.compact` decorators with explicit `__init__` methods
- ✅ Updated to use `nnx.Linear`, `nnx.relu`, and `nnx.sigmoid`
- ✅ Stateful module design following NNX patterns (see the sketch below)
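For illustration, here is a minimal sketch of the eager, stateful NNX module style these changes adopt. The layer sizes (784 → 500 → latents for flattened MNIST) and attribute names are assumptions for the sketch, not the exact contents of `models.py`:

```python
from flax import nnx


class Encoder(nnx.Module):
  """Encodes a flattened MNIST image into latent mean and log-variance."""

  def __init__(self, latents: int, *, rngs: nnx.Rngs):
    self.hidden = nnx.Linear(784, 500, rngs=rngs)
    self.mean = nnx.Linear(500, latents, rngs=rngs)
    self.logvar = nnx.Linear(500, latents, rngs=rngs)

  def __call__(self, x):
    h = nnx.relu(self.hidden(x))
    return self.mean(h), self.logvar(h)


class Decoder(nnx.Module):
  """Decodes a latent sample back into logits over the flattened image."""

  def __init__(self, latents: int, *, rngs: nnx.Rngs):
    self.hidden = nnx.Linear(latents, 500, rngs=rngs)
    self.out = nnx.Linear(500, 784, rngs=rngs)

  def __call__(self, z):
    h = nnx.relu(self.hidden(z))
    return self.out(h)  # logits; nnx.sigmoid is applied when rendering images
```

Unlike `@nn.compact`, every submodule is constructed eagerly in `__init__` with an explicit `rngs` argument, so parameters exist as attributes as soon as the module is built.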
2. Training Logic (`train.py`)
- ✅ Removed Linen imports (`from flax import linen as nn`)
- ✅ Removed `train_state.TrainState` (not used in NNX)
- ✅ Migrated to `nnx.Optimizer` for direct state management
- ✅ Added `@nnx.jit` decorator to the `train_step()` function
- ✅ Added `@nnx.jit` decorator to the `eval_f()` function
- ✅ Updated to use `nnx.value_and_grad` for gradient computation
- ✅ Simplified RNG initialization with `nnx.Rngs`
- ✅ Fixed activation functions to use `jax.nn.log_sigmoid` (NNX compatible)
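A hedged sketch of what the jitted training step looks like under these changes. The VAE call signature (returning reconstruction logits, mean, and log-variance) and the loss helpers are illustrative assumptions, and `optimizer.update(grads)` follows the Flax 0.8-era `nnx.Optimizer` API:

```python
import jax
import jax.numpy as jnp
from flax import nnx


def kl_divergence(mean, logvar):
  return -0.5 * jnp.sum(1.0 + logvar - jnp.square(mean) - jnp.exp(logvar), axis=-1)


def binary_cross_entropy_with_logits(logits, labels):
  # jax.nn.log_sigmoid is the numerically stable form noted in the changes above.
  return -jnp.sum(
      labels * jax.nn.log_sigmoid(logits)
      + (1.0 - labels) * jax.nn.log_sigmoid(-logits),
      axis=-1,
  )


@nnx.jit  # compile the whole step with XLA; NNX traces the model/optimizer state
def train_step(model, optimizer: nnx.Optimizer, batch, z_rng):
  def loss_fn(model):
    recon_logits, mean, logvar = model(batch, z_rng)  # assumed VAE call signature
    bce = binary_cross_entropy_with_logits(recon_logits, batch).mean()
    kld = kl_divergence(mean, logvar).mean()
    return bce + kld

  loss, grads = nnx.value_and_grad(loss_fn)(model)
  optimizer.update(grads)  # updates the model's parameters in place
  return loss
```

Setup would look roughly like `optimizer = nnx.Optimizer(model, optax.adam(learning_rate))`, after which `train_step(model, optimizer, batch, key)` mutates the model and optimizer in place and returns the loss, with no `TrainState` threading required.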
3. Code Quality
- ✅ Removed commented-out Linen code
- ✅ Removed debug print statements
- ✅ Added proper docstrings
- ✅ Cleaner, more readable code structure
Performance Benchmarks
Training Time Comparison (30 epochs on binarized MNIST)
Hardware: CPU
| Implementation | Total Time | Speedup |
|---|---|---|
| Linen (Original) | 770.55 seconds | 1.0x (baseline) |
| NNX (This PR) | 83.62 seconds | 🚀 9.2x faster |
Detailed Training Logs
Linen (Original) - 770.55 seconds
I1107 20:28:20.349939 134792575346496 train.py:92] Initializing model.
eval epoch: 1, loss: 120.4574, BCE: 95.9768, KLD: 24.4806
eval epoch: 2, loss: 112.2203, BCE: 86.5090, KLD: 25.7113
eval epoch: 3, loss: 109.2380, BCE: 82.5628, KLD: 26.6753
eval epoch: 4, loss: 107.4103, BCE: 80.7181, KLD: 26.6922
eval epoch: 5, loss: 106.4009, BCE: 79.9999, KLD: 26.4009
eval epoch: 6, loss: 105.4629, BCE: 78.1415, KLD: 27.3214
eval epoch: 7, loss: 104.9457, BCE: 78.2812, KLD: 26.6645
eval epoch: 8, loss: 104.1115, BCE: 77.1570, KLD: 26.9545
eval epoch: 9, loss: 104.0962, BCE: 77.2225, KLD: 26.8737
eval epoch: 10, loss: 103.4601, BCE: 76.5734, KLD: 26.8867
eval epoch: 11, loss: 103.1765, BCE: 74.9070, KLD: 28.2695
eval epoch: 12, loss: 102.9001, BCE: 75.4757, KLD: 27.4244
eval epoch: 13, loss: 102.8762, BCE: 75.2979, KLD: 27.5783
eval epoch: 14, loss: 102.6160, BCE: 74.7805, KLD: 27.8354
eval epoch: 15, loss: 102.4688, BCE: 74.8430, KLD: 27.6258
eval epoch: 16, loss: 102.2647, BCE: 74.5797, KLD: 27.6851
eval epoch: 17, loss: 102.1568, BCE: 74.6302, KLD: 27.5266
eval epoch: 18, loss: 102.0417, BCE: 74.4451, KLD: 27.5966
eval epoch: 19, loss: 101.8212, BCE: 74.2193, KLD: 27.6019
eval epoch: 20, loss: 101.6886, BCE: 73.9931, KLD: 27.6955
eval epoch: 21, loss: 101.7861, BCE: 74.6319, KLD: 27.1542
eval epoch: 22, loss: 101.5596, BCE: 73.6897, KLD: 27.8699
eval epoch: 23, loss: 101.5948, BCE: 73.7950, KLD: 27.7998
eval epoch: 24, loss: 101.3442, BCE: 72.9828, KLD: 28.3614
eval epoch: 25, loss: 101.3695, BCE: 73.8438, KLD: 27.5258
eval epoch: 26, loss: 101.3590, BCE: 73.2725, KLD: 28.0864
eval epoch: 27, loss: 101.2677, BCE: 73.4502, KLD: 27.8176
eval epoch: 28, loss: 101.0781, BCE: 73.9786, KLD: 27.0995
eval epoch: 29, loss: 101.0382, BCE: 73.7349, KLD: 27.3033
eval epoch: 30, loss: 100.8657, BCE: 73.0155, KLD: 27.8502
I1107 20:41:07.840827 134792575346496 main.py:62] Total training time: 770.55 seconds
NNX (This PR) - 83.62 seconds ⚡
I1107 20:27:30.448354 133449167238976 train.py:89] Initializing model.
eval epoch: 1, loss: 120.5688, BCE: 97.0525, KLD: 23.5163
eval epoch: 2, loss: 112.4121, BCE: 86.6029, KLD: 25.8092
eval epoch: 3, loss: 109.3856, BCE: 82.7984, KLD: 26.5873
eval epoch: 4, loss: 107.7712, BCE: 81.1137, KLD: 26.6575
eval epoch: 5, loss: 106.4054, BCE: 79.3470, KLD: 27.0583
eval epoch: 6, loss: 105.4061, BCE: 78.6539, KLD: 26.7522
eval epoch: 7, loss: 105.0498, BCE: 78.3357, KLD: 26.7141
eval epoch: 8, loss: 104.4686, BCE: 77.4883, KLD: 26.9803
eval epoch: 9, loss: 103.9196, BCE: 76.4744, KLD: 27.4451
eval epoch: 10, loss: 103.6025, BCE: 75.8894, KLD: 27.7131
eval epoch: 11, loss: 103.2291, BCE: 76.3263, KLD: 26.9028
eval epoch: 12, loss: 103.0131, BCE: 75.4742, KLD: 27.5389
eval epoch: 13, loss: 102.7186, BCE: 75.6376, KLD: 27.0810
eval epoch: 14, loss: 102.7346, BCE: 75.2200, KLD: 27.5146
eval epoch: 15, loss: 102.4095, BCE: 74.4650, KLD: 27.9445
eval epoch: 16, loss: 102.2019, BCE: 73.9366, KLD: 28.2653
eval epoch: 17, loss: 102.2103, BCE: 74.8777, KLD: 27.3327
eval epoch: 18, loss: 101.9611, BCE: 74.5904, KLD: 27.3707
eval epoch: 19, loss: 101.6206, BCE: 74.3661, KLD: 27.2545
eval epoch: 20, loss: 101.8825, BCE: 73.7910, KLD: 28.0915
eval epoch: 21, loss: 101.6405, BCE: 74.0798, KLD: 27.5608
eval epoch: 22, loss: 101.5163, BCE: 73.6049, KLD: 27.9114
eval epoch: 23, loss: 101.6124, BCE: 74.0859, KLD: 27.5265
eval epoch: 24, loss: 101.3986, BCE: 73.0860, KLD: 28.3126
eval epoch: 25, loss: 101.2187, BCE: 73.8297, KLD: 27.3891
eval epoch: 26, loss: 101.2321, BCE: 73.3646, KLD: 27.8674
eval epoch: 27, loss: 101.1214, BCE: 72.8930, KLD: 28.2285
eval epoch: 28, loss: 100.9203, BCE: 74.0056, KLD: 26.9146
eval epoch: 29, loss: 101.0700, BCE: 73.9756, KLD: 27.0943
eval epoch: 30, loss: 100.8980, BCE: 72.6217, KLD: 28.2763
I1107 20:28:53.361864 133449167238976 main.py:62] Total training time: 83.62 seconds
Performance Analysis
Time Saved: 686.93 seconds (11.45 minutes) for 30 epochs
Key Performance Factors:
- JIT Compilation (`@nnx.jit`): 5-10x speedup through XLA optimization
- Operation Fusion: Multiple operations compiled into single optimized kernels
- Reduced Python Overhead: Compiled code runs at near-C speed
- Better Memory Management: Fewer intermediate allocations
- Asynchronous Dispatch: GPU/CPU overlap for improved throughput
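Because dispatch is asynchronous and the first jitted call includes compilation, fair per-step timing needs to exclude the warm-up call and block on the result. A small illustrative helper (not part of this PR) that does both:

```python
import time

import jax


def time_step(step_fn, *args, iters: int = 100) -> float:
  """Average wall-clock seconds per call, excluding the first (compiling) call."""
  jax.block_until_ready(step_fn(*args))   # warm-up: triggers @nnx.jit compilation
  start = time.perf_counter()
  for _ in range(iters):
    out = step_fn(*args)                  # dispatched asynchronously
  jax.block_until_ready(out)              # wait for the device to finish
  return (time.perf_counter() - start) / iters
```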
Model Quality
Both implementations converge to nearly identical final evaluation losses (100.87 for Linen vs. 100.90 for NNX after 30 epochs), demonstrating that the NNX migration maintains training quality while dramatically improving performance.
Testing
- ✅ Training runs successfully for 30 epochs
- ✅ Loss convergence matches original implementation
- ✅ Generated samples quality preserved
- ✅ Reconstruction quality maintained
- ✅ No breaking changes to public API
Compatibility
- Requires `flax >= 0.8.0` (NNX API support)
- Compatible with existing `configs/default.py`
- Same command-line interface: `python main.py --workdir=/tmp/mnist --config=configs/default.py`
Migration Benefits
- 🚀 Performance: 9.2x faster training (770s → 83s)
- 📝 Simplicity: More intuitive, Pythonic API
- 🔍 Debuggability: Direct module inspection without `.apply()` (see the sketch after this list)
- 🎯 Future-proof: Aligns with Flax's recommended API going forward
- 📚 Educational: Better example for new users learning Flax NNX
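As a small illustration of the debuggability point (a sketch, not code from this PR): NNX modules hold their parameters as ordinary attributes, so they can be called and inspected directly instead of going through `init`/`apply` with an external params dict.

```python
import jax.numpy as jnp
from flax import nnx

layer = nnx.Linear(784, 500, rngs=nnx.Rngs(0))

# Linen would require model.init(...) and model.apply({'params': params}, x);
# in NNX the module is simply called, and its state is a plain attribute.
y = layer(jnp.ones((1, 784)))
print(layer.kernel.value.shape)  # (784, 500): weights are inspectable in place
```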
Checklist
- [x] Code follows Flax NNX best practices
- [x] Performance benchmarks included
- [x] Training quality validated
- [x] Backward compatibility considered
- [x] Documentation patterns followed