
Update Flax NNX Randomness

Open • 8bitmp3 opened this issue 1 year ago • 4 comments

@cgarciae @IvyZX PTAL

Preview: https://flax--4279.org.readthedocs.build/en/4279/guides/randomness.html

8bitmp3 · Oct 10 '24 00:10


Also updating:

- Here is a list of the main PRNG-related types in Flax NNX:
+ Here are the main PRNG-related types in Flax NNX:

Added a link to the JAX `jax.random` tutorial for new users:

> **Note:** To learn more about random number generation in JAX,
the `jax.random` API, and PRNG-generated sequences, check out
this [JAX PRNG tutorial](https://jax.readthedocs.io/en/latest/random-numbers.html).
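For readers new to JAX, here is a minimal sketch of the explicit-key model the tutorial covers, assuming a recent JAX release that provides `jax.random.key`. The same key always produces the same numbers, so fresh randomness comes from splitting the key. This manual key bookkeeping is what `nnx.Rngs` is meant to take off your hands in Flax NNX.

```python
import jax
import jax.numpy as jnp

# JAX PRNG state is an explicit, immutable key value.
key = jax.random.key(0)

# Split once to get a subkey for sampling and a new carry key.
key, subkey = jax.random.split(key)

a = jax.random.normal(subkey, (3,))
b = jax.random.normal(subkey, (3,))  # reusing a key repeats the same numbers

# Splitting again gives a new subkey and therefore new numbers.
key, subkey2 = jax.random.split(key)
c = jax.random.normal(subkey2, (3,))
```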

Added a link to the JAX *Working with pytrees* guide for new users:

`nnx.reseed` is a function that accepts an arbitrary graph node (this includes
[pytrees](https://jax.readthedocs.io/en/latest/working-with-pytrees.html#working-with-pytrees)
of `nnx.Module`s)...

Added the `nnx.` prefix to API names:

...`nnx.Rngs` object... `nnx.reseed`... `nnx.Dropout`

8bitmp3 · Oct 10 '24 12:10

Rebasing after https://github.com/google/flax/pull/4281

8bitmp3 · Oct 10 '24 21:10

@cgarciae PTAL, thanks!

8bitmp3 · Oct 15 '24 22:10