Add BatchNorm layers to CNN in MNIST tutorial for improved training stability

Open sanepunk opened this issue 7 months ago • 2 comments

What does this PR do?

Integrates nnx.BatchNorm into the CNN model used in docs_nnx/mnist_tutorial.ipynb. Enables batch-norm-aware behavior with .train() and .eval() modes to improve convergence and metrics visualization.

Highlights

  • Added nnx.BatchNorm() after each convolution
  • Updated training loop to call model.train() so running statistics are updated
  • Switched to model.eval() in the evaluation loop for deterministic inference
  • Observed smoother loss and accuracy curves in the notebook’s metrics graphs

Why?

BatchNorm stabilizes and accelerates training by normalizing activations between layers, leading to better gradient flow and faster convergence.
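As a rough illustration of what "normalizing activations" means, the sketch below standardizes each feature over the batch and then applies a learnable scale and bias. It is a simplified version of the standard batch-norm computation, not the internals of nnx.BatchNorm; the function name and sample values are hypothetical, and running statistics are omitted.

```python
import jax.numpy as jnp


def batch_norm_sketch(x, scale, bias, eps=1e-5):
  """Standardize each feature over all non-feature axes, then scale and shift."""
  axes = tuple(range(x.ndim - 1))  # every axis except the last (feature) axis
  mean = x.mean(axis=axes, keepdims=True)
  var = x.var(axis=axes, keepdims=True)
  return scale * (x - mean) / jnp.sqrt(var + eps) + bias


# Two features on very different scales end up with comparable ranges.
x = jnp.array([[1.0, 10.0], [3.0, 30.0], [5.0, 50.0]])
y = batch_norm_sketch(x, scale=jnp.ones(2), bias=jnp.zeros(2))
```

After normalization, each column of `y` has approximately zero mean and unit variance, regardless of the scale of the input features.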

Testing

Ran the MNIST tutorial notebook end-to-end and confirmed:

  • Training loss decreases more smoothly
  • Validation accuracy improves more quickly
  • Metrics plots clearly reflect these improvements

Closes #

sanepunk avatar Jun 07 '25 18:06 sanepunk

Thanks for your pull request! It looks like this may be your first contribution to a Google open source project. Before we can look at your pull request, you'll need to sign a Contributor License Agreement (CLA).

View this failed invocation of the CLA check for more information.

For the most up to date status, view the checks section at the bottom of the pull request.

google-cla[bot] avatar Jun 07 '25 18:06 google-cla[bot]

@IvyZX @jburnim @cgarciae WDYT

8bitmp3 avatar Jun 30 '25 19:06 8bitmp3

Thanks for the update

pccoeit avatar Jul 18 '25 13:07 pccoeit

Can you merge this PR? @8bitmp3 @IvyZX @jburnim @cgarciae

sanepunk avatar Aug 02 '25 17:08 sanepunk

Can you merge this PR? @8bitmp3 @IvyZX @jburnim @cgarciae

I’ve made the necessary changes to fix the GitHub Actions failure; everything checks out now.

sanepunk avatar Aug 02 '25 19:08 sanepunk

@sanepunk looks good! You can fix pre-commit with:

pip install pre-commit
pre-commit run --all-files

cgarciae avatar Aug 03 '25 00:08 cgarciae

Thank you @cgarciae @8bitmp3 @IvyZX @jburnim for helping me with my first PR to Flax.

sanepunk avatar Aug 05 '25 06:08 sanepunk