Migrate VAE example from flax.linen to flax.nnx
Description
The current VAE example uses the flax.linen API for model definition and training.
As Flax continues to develop the nnx module as its next-generation neural network API, it would be valuable to provide an updated version of this example using flax.nnx.
This migration will:
- Help users learn how to implement a VAE using nnx's new modular and explicit state-handling paradigm.
- Let users compare the nn and nnx APIs in a real-world use case.
- Encourage adoption of nnx in research and production examples.
Proposed Changes
- Reimplement the model (Encoder, Decoder, and VAE wrapper) using flax.nnx.Module.
- Replace flax.training.train_state.TrainState with nnx.Optimizer for parameter management.
- Update training and evaluation loops to use nnx.jit and direct method calls instead of apply().
- Ensure reproducibility and equivalence with the original nn-based example.
Contribution
I'd be happy to implement this migration and submit a PR. Please let me know if there are any specific guidelines or preferences for the implementation approach.
Motivation
The VAE example is a widely understood benchmark that involves both deterministic and stochastic components, making it ideal to showcase nnx's design strengths:
- Explicit randomness (nnx.Rngs)
- Parameter/state separation
- Compositional design
- Compatibility with Optax and other JAX tools
Having this example available in nnx would significantly benefit users exploring or transitioning to the new API.
@sanepunk thanks for the suggestion; a PR contributing this is welcome!
Regarding "Update training and evaluation loops to use nnx.jit and direct method calls instead of apply()":
Let's go one step further and use plain jax.jit for jitting the train_step.
Just to add to this, I noticed the VAE example here uses nnx:
Link
I think this speaks to a broader problem with the example structure. As the nnx API is being introduced, the examples are becoming difficult to navigate, and it's not always clear which API an example is for.
I think it would make sense to split the examples strictly into separate nnx and linen folders. This would make it much easier for users to find the examples they need.
If the maintainers are interested in this, I would be happy to start working on a PR to reorganize them.