
Migrate VAE example from flax.linen to flax.nnx

Open · sanepunk opened this issue 2 months ago · 2 comments

Description

The current VAE example uses the flax.linen API for model definition and training.
As Flax continues to develop the nnx module as its next-generation neural network API, it would be valuable to provide an updated version of this example using flax.nnx.

This migration will help users:

  • Learn how to implement a VAE using nnx's new modular and explicit state-handling paradigm.
  • Compare the linen (nn) and nnx APIs in a real-world use case.
  • Encourage adoption of nnx in research and production examples.

Proposed Changes

  • Reimplement the model (Encoder, Decoder, and VAE wrapper) using flax.nnx.Module.
  • Replace flax.training.train_state.TrainState with nnx.Optimizer for parameter management.
  • Update training and evaluation loops to use nnx.jit and direct method calls instead of apply().
  • Ensure reproducibility and equivalence with the original linen-based example.

Contribution

I'd be happy to implement this migration and submit a PR. Please let me know if there are any specific guidelines or preferences for the implementation approach.


Motivation

The VAE example is a widely understood benchmark that involves both deterministic and stochastic components, making it an ideal showcase for nnx's design strengths:

  • Explicit randomness (nnx.Rngs)
  • Parameter/state separation
  • Compositional design
  • Compatibility with Optax and other JAX tools

Having this example available in nnx would significantly benefit users exploring or transitioning to the new API.

sanepunk · Nov 03 '25 16:11

@sanepunk thanks for the suggestion, a PR contributing that is welcome!

> Update training and evaluation loops to use nnx.jit and direct method calls instead of apply().

Let's go one step further and use plain jax.jit (instead of nnx.jit) to JIT-compile the train_step.

vfdev-5 · Nov 05 '25 14:11

Just to add to this, I noticed the VAE example here uses nnx: Link

I think this speaks to a broader problem with the example structure. As the nnx API is being introduced, the examples are becoming difficult to navigate, and it's not always clear which API an example is for.

I think it would make sense to split the examples strictly into separate nnx and linen folders. This would make it much easier for users to find the examples they need.

If the maintainers are interested in this, I would be happy to start working on a PR to reorganize them.

MatzeLopi · Nov 09 '25 14:11