Ojas Patil
## What does this PR do?

Integrates `nnx.BatchNorm` into the CNN model used in `docs_nnx/mnist_tutorial.ipynb`. Enables batch-norm-aware behavior with `.train()` and `.eval()` modes to improve convergence and metrics visualization.

###...
# Migrate VAE Example to Flax NNX with JIT Optimization

## Summary

This PR migrates the VAE (Variational Autoencoder) example from **Flax Linen** to **Flax NNX**, the new simplified API....
#### Description

The current [VAE example](https://github.com/google/flax/tree/main/examples/vae) uses the `flax.linen` API for model definition and training. As Flax continues to develop the `nnx` module as its next-generation neural network API, it...