flax
[nnx] explicit Variables
What does this PR do?
- The `Variable.value` attribute, which held the unprocessed value, is now named `.raw_value`.
- Replaces `Variable.get_value` and `Variable.set_value` with a `.value` property. `Module` `__getattr__` and `__setattr__` no longer extract the inner `value` of Variables.
- Removes the `Module.variables` and `State.variables` helpers.
Basic example now looks like this:
```python
import jax
import jax.numpy as jnp
from flax import nnx

class Linear(nnx.Module):
    def __init__(self, din: int, dout: int, *, rngs: nnx.Rngs):
        key = rngs.params()
        self.w: nnx.Param[jax.Array] = nnx.Param(jax.random.uniform(key, (din, dout)))
        self.b: nnx.Param[jax.Array] = nnx.Param(jnp.zeros((dout,)))

    def __call__(self, x: jax.Array):
        return x @ self.w.value + self.b.value  # explicit `.value` access
```
Codecov Report
Attention: Patch coverage is 95.66474%, with 15 lines in your changes missing coverage. Please review.
Project coverage is 58.85%. Comparing base (acba0bf) to head (0d347d0).
```diff
@@           Coverage Diff            @@
##             main    #3720    +/-   ##
==========================================
- Coverage   59.05%   58.85%   -0.21%
==========================================
  Files         103      103
  Lines       12440    12318     -122
==========================================
- Hits         7347     7250      -97
+ Misses       5093     5068      -25
```
Woohoo!!! Awesome!!