
Flax 0.9.0 broke nnx rng splitting

Open • kasper0406 opened this issue 1 year ago • 1 comment

Consider the following code:


```python
import jax
import jax.numpy as jnp
from flax import nnx
from functools import partial

class TestStackedLaxScanLinear(nnx.Module):
    def __init__(self, rngs: nnx.Rngs):
        # Create 3 hidden layers whose parameters are stacked along a leading axis.
        @partial(nnx.vmap, axis_size=3)  # 3 hidden layers
        def create_hidden_layers(rngs: nnx.Rngs):
            return nnx.Linear(in_features=4, out_features=4, bias_init=nnx.initializers.ones, rngs=rngs)
        self.hidden_layers = create_hidden_layers(rngs)  # raises IndexError on flax 0.9.0

        self.upscale_layer = nnx.Linear(in_features=2, out_features=4, bias_init=nnx.initializers.ones, rngs=rngs)
        self.downscale_layer = nnx.Linear(in_features=4, out_features=2, bias_init=nnx.initializers.ones, rngs=rngs)

    def __call__(self, x):
        out = self.upscale_layer(x)

        layer_def, layer_states = nnx.split(self.hidden_layers)

        # Apply the hidden layers one after another by scanning over the stacked states.
        def forward(x, layer_state):
            layer = nnx.merge(layer_def, layer_state)
            x = layer(x)
            return x, None
        out, _ = jax.lax.scan(forward, out, layer_states)

        out = self.downscale_layer(out)
        return out

model = nnx.jit(TestStackedLaxScanLinear(nnx.Rngs(0)))
model(jnp.zeros((2, 2)))
```

On flax==0.8.5 this works fine; however, on flax==0.9.0 it fails with the following error:


```
Traceback (most recent call last):
  File "/workspaces/codespaces-blank/stablehlo-coreml/tests/test_hmm.py", line 30, in <module>
    model = nnx.jit(TestStackedLaxScanLinear(nnx.Rngs(0)))
                    ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/codespace/.local/share/hatch/env/virtual/stablehlo-coreml-experimental/XhHBArB2/test.py3.11/lib/python3.11/site-packages/flax/nnx/nnx/object.py", line 79, in __call__
    return _graph_node_meta_call(cls, *args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/codespace/.local/share/hatch/env/virtual/stablehlo-coreml-experimental/XhHBArB2/test.py3.11/lib/python3.11/site-packages/flax/nnx/nnx/object.py", line 88, in _graph_node_meta_call
    cls._object_meta_construct(node, *args, **kwargs)
  File "/home/codespace/.local/share/hatch/env/virtual/stablehlo-coreml-experimental/XhHBArB2/test.py3.11/lib/python3.11/site-packages/flax/nnx/nnx/object.py", line 82, in _object_meta_construct
    self.__init__(*args, **kwargs)
  File "/workspaces/codespaces-blank/stablehlo-coreml/tests/test_hmm.py", line 11, in __init__
    self.hidden_layers = create_hidden_layers(rngs)
                         ^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/codespace/.local/share/hatch/env/virtual/stablehlo-coreml-experimental/XhHBArB2/test.py3.11/lib/python3.11/site-packages/flax/nnx/nnx/graph.py", line 1158, in update_context_manager_wrapper
    return f(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^
  File "/home/codespace/.local/share/hatch/env/virtual/stablehlo-coreml-experimental/XhHBArB2/test.py3.11/lib/python3.11/site-packages/flax/nnx/nnx/transforms/iteration.py", line 339, in vmap_wrapper
    pure_args_out, pure_out = vmapped_fn(*pure_args)
                              ^^^^^^^^^^^^^^^^^^^^^^
  File "/home/codespace/.local/share/hatch/env/virtual/stablehlo-coreml-experimental/XhHBArB2/test.py3.11/lib/python3.11/site-packages/flax/nnx/nnx/transforms/iteration.py", line 164, in __call__
    out = self.f(*args)
          ^^^^^^^^^^^^^
  File "/workspaces/codespaces-blank/stablehlo-coreml/tests/test_hmm.py", line 10, in create_hidden_layers
    return nnx.Linear(in_features=4, out_features=4, bias_init=nnx.initializers.ones, rngs=rngs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/codespace/.local/share/hatch/env/virtual/stablehlo-coreml-experimental/XhHBArB2/test.py3.11/lib/python3.11/site-packages/flax/nnx/nnx/object.py", line 79, in __call__
    return _graph_node_meta_call(cls, *args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/codespace/.local/share/hatch/env/virtual/stablehlo-coreml-experimental/XhHBArB2/test.py3.11/lib/python3.11/site-packages/flax/nnx/nnx/object.py", line 88, in _graph_node_meta_call
    cls._object_meta_construct(node, *args, **kwargs)
  File "/home/codespace/.local/share/hatch/env/virtual/stablehlo-coreml-experimental/XhHBArB2/test.py3.11/lib/python3.11/site-packages/flax/nnx/nnx/object.py", line 82, in _object_meta_construct
    self.__init__(*args, **kwargs)
  File "/home/codespace/.local/share/hatch/env/virtual/stablehlo-coreml-experimental/XhHBArB2/test.py3.11/lib/python3.11/site-packages/flax/nnx/nnx/nn/linear.py", line 346, in __init__
    kernel_key = rngs.params()
                 ^^^^^^^^^^^^^
  File "/home/codespace/.local/share/hatch/env/virtual/stablehlo-coreml-experimental/XhHBArB2/test.py3.11/lib/python3.11/site-packages/flax/nnx/nnx/rnglib.py", line 84, in __call__
    key = jax.random.fold_in(self.key.value, self.count.value)
          ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/codespace/.local/share/hatch/env/virtual/stablehlo-coreml-experimental/XhHBArB2/test.py3.11/lib/python3.11/site-packages/jax/_src/random.py", line 246, in fold_in
    key, wrapped = _check_prng_key("fold_in", key)
                   ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/codespace/.local/share/hatch/env/virtual/stablehlo-coreml-experimental/XhHBArB2/test.py3.11/lib/python3.11/site-packages/jax/_src/random.py", line 74, in _check_prng_key
    if isinstance(key, Array) and dtypes.issubdtype(key.dtype, dtypes.prng_key):
                                                    ^^^^^^^^^
IndexError: tuple index out of range
--------------------
For simplicity, JAX has removed its internal frames from the traceback of the following exception. Set JAX_TRACEBACK_FILTERING=off to include these.
```

The most obvious suspect looks to be https://github.com/google/flax/pull/4064

kasper0406 • Aug 29 '24 07:08

Hey @kasper0406! 0.9.0 implements the changes proposed in #4107: the split_rngs and state_axes arguments were removed from vmap/scan, and these transforms no longer handle RNG state automatically. To recover the previous behavior, the nnx.split_rngs API was introduced; you can use it to fix your example:

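```python
from flax import nnx

# split_rngs splits each RNG stream into 3 keys before vmap runs,
# so each vmapped call (one per hidden layer) receives its own key.
@nnx.split_rngs(splits=3)
@nnx.vmap(axis_size=3)  # 3 hidden layers
def create_hidden_layers(rngs: nnx.Rngs):
  return nnx.Linear(
    in_features=4,
    out_features=4,
    bias_init=nnx.initializers.ones,
    rngs=rngs,
  )
```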
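As a rough mental model (a simplification, not how nnx.split_rngs is actually implemented), split_rngs(splits=3) followed by vmap amounts to splitting one PRNG key into three and vmapping the initializer over those keys, so each layer is initialized from its own key. The pure-JAX sketch below makes that explicit. init_linear is a hypothetical stand-in for nnx.Linear's initializer, and the dict-of-arrays parameter layout is an assumption for illustration only.

```python
import jax
import jax.numpy as jnp

NUM_LAYERS = 3  # matches splits=3 / axis_size=3 in the example above
FEATURES = 4


def init_linear(key):
    """Hypothetical stand-in for initializing one 4->4 Linear layer."""
    # Scaled normal kernel (a simplified stand-in for Linear's default
    # initializer) and a bias of ones, like bias_init=nnx.initializers.ones.
    kernel = jax.random.normal(key, (FEATURES, FEATURES)) / jnp.sqrt(FEATURES)
    bias = jnp.ones((FEATURES,))
    return {"kernel": kernel, "bias": bias}


root_key = jax.random.key(0)
# Roughly what split_rngs(splits=3) provides: one independent key per layer,
# stacked along a new leading axis.
layer_keys = jax.random.split(root_key, NUM_LAYERS)
# What vmap then does: map the initializer over axis 0 of the keys, producing
# parameters with a leading "layer" axis of size 3.
stacked = jax.vmap(init_linear)(layer_keys)

print(stacked["kernel"].shape)  # (3, 4, 4)
print(stacked["bias"].shape)    # (3, 4)
```

Reusing root_key for every layer instead would give all three layers identical kernels, which is why the key has to be split before the vmapped initializer sees it.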

Also, transforms now behave like partials when no function is provided, so there is no need to use functools.partial.
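"Behaving like a partial" is a common Python decorator convention: if the transform is called with only keyword arguments, it returns a decorator that still awaits the function. The sketch below uses a hypothetical my_vmap (not Flax's implementation, and it does not actually vectorize anything) to show why @my_vmap(axis_size=3) and @partial(my_vmap, axis_size=3) end up equivalent.

```python
import functools


def my_vmap(f=None, *, axis_size=None):
    """Hypothetical transform illustrating the 'behave like partial' convention."""
    if f is None:
        # Called as @my_vmap(axis_size=3): no function yet, so return a
        # decorator with the configuration already bound.
        return functools.partial(my_vmap, axis_size=axis_size)

    @functools.wraps(f)
    def wrapper(*args, **kwargs):
        # A real transform would vectorize f here; this sketch only records
        # the configuration alongside f's result.
        return {"axis_size": axis_size, "result": f(*args, **kwargs)}

    return wrapper


@my_vmap(axis_size=3)  # new style: no functools.partial needed
def direct(x):
    return x + 1


@functools.partial(my_vmap, axis_size=3)  # old style still works
def via_partial(x):
    return x + 1


print(direct(1))       # {'axis_size': 3, 'result': 2}
print(via_partial(1))  # {'axis_size': 3, 'result': 2}
```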

cgarciae • Aug 29 '24 10:08

Ah, got it, thanks! I had looked at the release notes and saw no breaking changes called out. I'm not sure if there is a better place to look.

kasper0406 • Aug 29 '24 19:08