flax
Update Flax NNX Randomness
@cgarciae @IvyZX PTAL
Preview: https://flax--4279.org.readthedocs.build/en/4279/guides/randomness.html
Also updating:
```diff
- Here is a list of the main PRNG-related types in Flax NNX:
+ Here are the main PRNG-related types in Flax NNX:
```
Added a link to the JAX `jax.random` tutorial for new users:
> **Note:** To learn more about random number generation in JAX,
the `jax.random` API, and PRNG-generated sequences, check out
this [JAX PRNG tutorial](https://jax.readthedocs.io/en/latest/random-numbers.html).
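For context, a minimal sketch of the explicit-key pattern that tutorial covers (plain JAX, no Flax; the seed 42 is arbitrary):

```python
import jax
import jax.numpy as jnp

key = jax.random.key(42)             # typed PRNG key created from an integer seed
key, subkey = jax.random.split(key)  # split off a fresh subkey instead of reusing a key
x = jax.random.normal(subkey, (3,))  # consume the subkey to draw samples

# JAX PRNGs are stateless: the same key always yields the same numbers.
x_again = jax.random.normal(subkey, (3,))
```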
Added a link to the JAX *Working with pytrees* guide for new users:
> `nnx.reseed` is a function that accepts an arbitrary graph node (this includes
> [pytrees](https://jax.readthedocs.io/en/latest/working-with-pytrees.html#working-with-pytrees)
> of `nnx.Module`s)...
Added the `nnx.` prefix:
...`nnx.Rngs` object..... `nnx.reseed`.... `nnx.Dropout`
Rebasing after https://github.com/google/flax/pull/4281.
@cgarciae PTAL, thanks!