Add BatchNorm layers to CNN in MNIST tutorial for improved training stability
What does this PR do?
Integrates `nnx.BatchNorm` into the CNN model used in `docs_nnx/mnist_tutorial.ipynb`, and switches the training and evaluation loops between `model.train()` and `model.eval()` modes so the new layers behave correctly, improving convergence and the metrics visualization.
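For context, a minimal sketch of what the updated model could look like in NNX style; the layer sizes and names (`CNN`, `conv1`, `bn1`, etc.) mirror the tutorial's structure but are assumptions here, not a copy of the diff:

```python
from functools import partial
from flax import nnx

class CNN(nnx.Module):
  """Tutorial-style CNN with a BatchNorm layer after each convolution (sketch)."""

  def __init__(self, *, rngs: nnx.Rngs):
    self.conv1 = nnx.Conv(1, 32, kernel_size=(3, 3), rngs=rngs)
    self.bn1 = nnx.BatchNorm(32, rngs=rngs)   # normalizes conv1 outputs
    self.conv2 = nnx.Conv(32, 64, kernel_size=(3, 3), rngs=rngs)
    self.bn2 = nnx.BatchNorm(64, rngs=rngs)   # normalizes conv2 outputs
    self.avg_pool = partial(nnx.avg_pool, window_shape=(2, 2), strides=(2, 2))
    self.linear1 = nnx.Linear(3136, 256, rngs=rngs)
    self.linear2 = nnx.Linear(256, 10, rngs=rngs)

  def __call__(self, x):
    x = self.avg_pool(nnx.relu(self.bn1(self.conv1(x))))
    x = self.avg_pool(nnx.relu(self.bn2(self.conv2(x))))
    x = x.reshape(x.shape[0], -1)              # flatten to (batch, features)
    x = nnx.relu(self.linear1(x))
    return self.linear2(x)
```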
Highlights
- Added `nnx.BatchNorm()` after each convolution
- Updated the training loop to call `model.train()` so running statistics are updated (see the loop sketch after this list)
- Switched to `model.eval()` in the evaluation loop for deterministic inference
- Observed smoother loss and accuracy curves in the notebook's metrics graphs
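A rough sketch of how the mode switching is used around the loops. The `train_batches`/`test_batches` iterables are hypothetical placeholders, and the notebook itself routes this through jitted step functions and an optimizer, omitted here for brevity:

```python
import optax
from flax import nnx

model = CNN(rngs=nnx.Rngs(0))

def loss_fn(model, batch):
  logits = model(batch['image'])
  return optax.softmax_cross_entropy_with_integer_labels(
      logits=logits, labels=batch['label']).mean()

# Training: model.train() makes BatchNorm normalize with batch statistics
# and update its running averages.
model.train()
for batch in train_batches:
  grads = nnx.grad(loss_fn)(model, batch)
  # ... apply `grads` with the notebook's optimizer ...

# Evaluation: model.eval() makes BatchNorm use the stored running averages,
# so inference is deterministic and independent of the eval batch.
model.eval()
for batch in test_batches:
  eval_loss = loss_fn(model, batch)
```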
Why?
BatchNorm stabilizes and accelerates training by normalizing activations between layers, leading to better gradient flow and faster convergence.
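Concretely, for each feature channel BatchNorm applies the standard transformation (the textbook formulation, not anything specific to this PR):

$$\hat{x} = \frac{x - \mu_{\mathcal{B}}}{\sqrt{\sigma_{\mathcal{B}}^2 + \epsilon}}, \qquad y = \gamma\,\hat{x} + \beta,$$

where $\mu_{\mathcal{B}}$ and $\sigma_{\mathcal{B}}^2$ are the batch mean and variance and $\gamma$, $\beta$ are learned scale and shift parameters. In `model.train()` mode the batch statistics are used (and running averages are updated); in `model.eval()` mode the stored running averages replace them, which is why the explicit mode switch matters.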
Testing
Ran the MNIST tutorial notebook end-to-end and confirmed:
- Training loss decreases more smoothly
- Validation accuracy improves more quickly
- Metrics plots clearly reflect these improvements
Closes #
@IvyZX @jburnim @cgarciae WDYT
Thanks for the update
Can you merge this PR? @8bitmp3 @IvyZX @jburnim @cgarciae
I’ve made the necessary changes to fix the GitHub Actions failure; everything checks out now.
@sanepunk looks good! You can fix pre-commit with:
pip install pre-commit
pre-commit run --all-files
Thank you @cgarciae @8bitmp3 @IvyZX @jburnim for helping me with my first PR to Flax!