Flax 0.9.0 broke nnx rng splitting
Consider the following code:
import jax
import jax.numpy as jnp
from flax import nnx
from functools import partial
class TestStackedLaxScanLinear(nnx.Module):
    """Toy network: upscale (2 -> 4), three stacked 4 -> 4 hidden layers, downscale (4 -> 2).

    The hidden layers are created in one go by vmapping the ``nnx.Linear``
    constructor, and applied sequentially in ``__call__`` with
    ``jax.lax.scan`` over their stacked state.
    """

    def __init__(self, rngs: nnx.Rngs):
        """Create the stacked hidden layers and the up/down projection layers.

        Args:
            rngs: RNG container forwarded to every ``nnx.Linear`` constructor.
                (Declared as an annotation; it was previously written as
                ``rngs=nnx.Rngs``, which made the class object a default value.)
        """

        @partial(nnx.vmap, axis_size=3)  # 3 hidden layers
        def create_hidden_layers(rngs: nnx.Rngs):
            return nnx.Linear(in_features=4, out_features=4, bias_init=nnx.initializers.ones, rngs=rngs)

        self.hidden_layers = create_hidden_layers(rngs)
        self.upscale_layer = nnx.Linear(in_features=2, out_features=4, bias_init=nnx.initializers.ones, rngs=rngs)
        self.downscale_layer = nnx.Linear(in_features=4, out_features=2, bias_init=nnx.initializers.ones, rngs=rngs)

    def __call__(self, x):
        """Run ``x`` through the upscale layer, each hidden layer in turn, then the downscale layer."""
        out = self.upscale_layer(x)

        # Separate the static graph definition from the stacked layer state so
        # the state can be handed to ``jax.lax.scan`` as the scanned-over input.
        layer_def, layer_states = nnx.split(self.hidden_layers)

        def forward(x, layer_state):
            # Rebuild a single hidden layer from the state slice for this step.
            layer = nnx.merge(layer_def, layer_state)
            x = layer(x)
            return x, None  # carry the activations forward; no per-step output

        out, _ = jax.lax.scan(forward, out, layer_states)
        out = self.downscale_layer(out)
        return out
# Build the model from an Rngs container seeded with 0 and wrap it in nnx.jit.
model = nnx.jit(TestStackedLaxScanLinear(nnx.Rngs(0)))
# Call the wrapped model on a (2, 2) array of zeros.
model(jnp.zeros((2, 2)))
On flax==0.8.5 this works fine; however, on flax==0.9.0 it fails with the following error:
──────────────────────────────────────────────────────────────── test.py3.11 ─────────────────────────────────────────────────────────────────
Traceback (most recent call last):
File "/workspaces/codespaces-blank/stablehlo-coreml/tests/test_hmm.py", line 30, in <module>
model = nnx.jit(TestStackedLaxScanLinear(nnx.Rngs(0)))
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/codespace/.local/share/hatch/env/virtual/stablehlo-coreml-experimental/XhHBArB2/test.py3.11/lib/python3.11/site-packages/flax/nnx/nnx/object.py", line 79, in __call__
return _graph_node_meta_call(cls, *args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/codespace/.local/share/hatch/env/virtual/stablehlo-coreml-experimental/XhHBArB2/test.py3.11/lib/python3.11/site-packages/flax/nnx/nnx/object.py", line 88, in _graph_node_meta_call
cls._object_meta_construct(node, *args, **kwargs)
File "/home/codespace/.local/share/hatch/env/virtual/stablehlo-coreml-experimental/XhHBArB2/test.py3.11/lib/python3.11/site-packages/flax/nnx/nnx/object.py", line 82, in _object_meta_construct
self.__init__(*args, **kwargs)
File "/workspaces/codespaces-blank/stablehlo-coreml/tests/test_hmm.py", line 11, in __init__
self.hidden_layers = create_hidden_layers(rngs)
^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/codespace/.local/share/hatch/env/virtual/stablehlo-coreml-experimental/XhHBArB2/test.py3.11/lib/python3.11/site-packages/flax/nnx/nnx/graph.py", line 1158, in update_context_manager_wrapper
return f(*args, **kwargs)
^^^^^^^^^^^^^^^^^^
File "/home/codespace/.local/share/hatch/env/virtual/stablehlo-coreml-experimental/XhHBArB2/test.py3.11/lib/python3.11/site-packages/flax/nnx/nnx/transforms/iteration.py", line 339, in vmap_wrapper
pure_args_out, pure_out = vmapped_fn(*pure_args)
^^^^^^^^^^^^^^^^^^^^^^
File "/home/codespace/.local/share/hatch/env/virtual/stablehlo-coreml-experimental/XhHBArB2/test.py3.11/lib/python3.11/site-packages/flax/nnx/nnx/transforms/iteration.py", line 164, in __call__
out = self.f(*args)
^^^^^^^^^^^^^
File "/workspaces/codespaces-blank/stablehlo-coreml/tests/test_hmm.py", line 10, in create_hidden_layers
return nnx.Linear(in_features=4, out_features=4, bias_init=nnx.initializers.ones, rngs=rngs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/codespace/.local/share/hatch/env/virtual/stablehlo-coreml-experimental/XhHBArB2/test.py3.11/lib/python3.11/site-packages/flax/nnx/nnx/object.py", line 79, in __call__
return _graph_node_meta_call(cls, *args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/codespace/.local/share/hatch/env/virtual/stablehlo-coreml-experimental/XhHBArB2/test.py3.11/lib/python3.11/site-packages/flax/nnx/nnx/object.py", line 88, in _graph_node_meta_call
cls._object_meta_construct(node, *args, **kwargs)
File "/home/codespace/.local/share/hatch/env/virtual/stablehlo-coreml-experimental/XhHBArB2/test.py3.11/lib/python3.11/site-packages/flax/nnx/nnx/object.py", line 82, in _object_meta_construct
self.__init__(*args, **kwargs)
File "/home/codespace/.local/share/hatch/env/virtual/stablehlo-coreml-experimental/XhHBArB2/test.py3.11/lib/python3.11/site-packages/flax/nnx/nnx/nn/linear.py", line 346, in __init__
kernel_key = rngs.params()
^^^^^^^^^^^^^
File "/home/codespace/.local/share/hatch/env/virtual/stablehlo-coreml-experimental/XhHBArB2/test.py3.11/lib/python3.11/site-packages/flax/nnx/nnx/rnglib.py", line 84, in __call__
key = jax.random.fold_in(self.key.value, self.count.value)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/codespace/.local/share/hatch/env/virtual/stablehlo-coreml-experimental/XhHBArB2/test.py3.11/lib/python3.11/site-packages/jax/_src/random.py", line 246, in fold_in
key, wrapped = _check_prng_key("fold_in", key)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/codespace/.local/share/hatch/env/virtual/stablehlo-coreml-experimental/XhHBArB2/test.py3.11/lib/python3.11/site-packages/jax/_src/random.py", line 74, in _check_prng_key
if isinstance(key, Array) and dtypes.issubdtype(key.dtype, dtypes.prng_key):
^^^^^^^^^
IndexError: tuple index out of range
--------------------
For simplicity, JAX has removed its internal frames from the traceback of the following exception. Set JAX_TRACEBACK_FILTERING=off to include these.
The most obvious suspect looks to be https://github.com/google/flax/pull/4064
Hey @kasper0406! 0.9.0 implements the changes proposed in #4107, meaning that split_rngs and state_axes are removed from the vmap/scan arguments and the transforms no longer automatically handle RNG state. To get back the previous behavior, the nnx.split_rngs API was introduced; you can use it to fix your example:
@nnx.split_rngs(splits=3)
@nnx.vmap(axis_size=3)  # 3 hidden layers
def create_hidden_layers(rngs: nnx.Rngs):
    """Construct one 4 -> 4 ``nnx.Linear`` layer.

    ``nnx.split_rngs(splits=3)`` splits the RNG state before the vmapped
    call, and ``nnx.vmap(axis_size=3)`` maps this constructor over an axis
    of size 3 to produce the stacked hidden layers.
    """
    hidden_layer = nnx.Linear(in_features=4, out_features=4, bias_init=nnx.initializers.ones, rngs=rngs)
    return hidden_layer
Also, transforms now behave as partials if the function is not provided, so there is no need to use functools.partial.
Ah, got it, thanks! I had a look at the release notes, and saw no breaking changes being called out. Not sure if there's a better place to look.