nnx.tabulate expects concrete values in unreleased flax>0.12.0

Open DBraun opened this issue 1 month ago • 5 comments

I have the latest flax installed with pip install -U git+https://github.com/google/flax.git (specifically https://github.com/google/flax/commit/74985b29404a8da3ab767fdfbbf8ab5ecb532574). nnx.tabulate is not working because it seems to require shapes to be concrete.

System information

  • OS Platform and Distribution (e.g., Linux Ubuntu 16.04): WSL2 on Windows Pro
  • Flax, jax, jaxlib versions (obtain with pip show flax jax jaxlib):
Name: flax
Version: 0.12.0
Summary: Flax: A neural network library for JAX designed for flexibility
Home-page: https://github.com/google/flax
Author:
Author-email: Flax team <[email protected]>
License:
Location: /home/admin/.local/lib/python3.12/site-packages
Requires: jax, msgpack, numpy, optax, orbax-checkpoint, PyYAML, rich, tensorstore, treescope, typing_extensions
Required-by: clu, evosax, jax-ai-stack, jraphx
---
Name: jax
Version: 0.8.0
Summary: Differentiate, compile, and transform Numpy code.
Home-page: https://github.com/jax-ml/jax
Author: JAX team
Author-email: [email protected]
License: Apache-2.0
Location: /home/admin/.local/lib/python3.12/site-packages
Requires: jaxlib, ml_dtypes, numpy, opt_einsum, scipy
Required-by: chex, clu, evosax, flax, jax-ai-stack, jaxloudnorm, jraphx, optax, orbax-checkpoint, orbax-export
---
Name: jaxlib
Version: 0.8.0
Summary: XLA library for JAX
Home-page: https://github.com/jax-ml/jax
Author: JAX team
Author-email: [email protected]
License: Apache-2.0
Location: /home/admin/.local/lib/python3.12/site-packages
Requires: ml_dtypes, numpy, scipy
Required-by: chex, clu, jax, jraphx, optax, orbax-export
  • Python version: 3.12
  • GPU/TPU model and memory: RTX 4080
  • CUDA version (if applicable): 12

Problem you have encountered:

MRE:

from jax import numpy as jnp
from flax import nnx


class Net(nnx.Module):
    def __init__(self):
        self.rngs = nnx.Rngs(0)

    def __call__(self, x):
        return self.rngs.uniform((x.shape[0], 10))


if __name__ == '__main__':
    net = Net()
    x = jnp.zeros((4, 8))
    print("running forward pass")
    y = net(x)
    print("running tabulate")
    print(nnx.tabulate(net, x, console_kwargs={"width": 200}))
    print("all done")

What you expected to happen:

With flax 0.12.0 from PyPI, the output is

running forward pass
running tabulate
                           Net Summary                           
┏━━━━━━━━━━━━━━┳━━━━━━┳━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━┳━━━━━━━━━━┓
┃ path         ┃ type ┃ inputs       ┃ outputs       ┃ RngState ┃
┡━━━━━━━━━━━━━━╇━━━━━━╇━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━╇━━━━━━━━━━┩
│              │ Net  │ float32[4,8] │ float32[4,10] │ 2 (12 B) │
├──────────────┼──────┼──────────────┼───────────────┼──────────┤
│ rngs/uniform │ Rngs │ - 4          │ float32[4,10] │ 2 (12 B) │
│              │      │ - 10         │               │          │
├──────────────┼──────┼──────────────┼───────────────┼──────────┤
│              │      │              │         Total │ 2 (12 B) │
└──────────────┴──────┴──────────────┴───────────────┴──────────┘
                                                                 
                   Total Parameters: 2 (12 B)                    

all done

Logs, error messages, etc:

Output from MRE above:

running forward pass
running tabulate
Traceback (most recent call last):
  File "/mnt/c/Users/admin/AppData/Roaming/JetBrains/PyCharm2025.2/scratches/scratch_33.py", line 19, in <module>
    print(nnx.tabulate(net, x, console_kwargs={"width": 200}))
          ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/admin/.local/lib/python3.12/site-packages/flax/nnx/summary.py", line 385, in tabulate
    jits[(type(obj), method)].trace(obj, *input_args, **input_kwargs)
  File "/home/admin/.local/lib/python3.12/site-packages/flax/nnx/transforms/compilation.py", line 475, in trace
    traced = self.jitted_fn.trace(*pure_args, **pure_kwargs)
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/admin/.local/lib/python3.12/site-packages/flax/nnx/transforms/compilation.py", line 129, in __call__
    out = self.f(*args, **kwargs)
          ^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/admin/.local/lib/python3.12/site-packages/flax/nnx/summary.py", line 215, in wrapper
    return f(obj, *args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^
  File "/mnt/c/Users/admin/AppData/Roaming/JetBrains/PyCharm2025.2/scratches/scratch_33.py", line 10, in __call__
    return self.rngs.uniform((x.shape[0], 10))
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/admin/.local/lib/python3.12/site-packages/flax/nnx/transforms/compilation.py", line 447, in __call__
    pure_args_out, pure_kwargs_out, pure_out = self.jitted_fn(
                                               ^^^^^^^^^^^^^^^
  File "/home/admin/.local/lib/python3.12/site-packages/flax/nnx/transforms/compilation.py", line 129, in __call__
    out = self.f(*args, **kwargs)
          ^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/admin/.local/lib/python3.12/site-packages/flax/nnx/summary.py", line 215, in wrapper
    return f(obj, *args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/admin/.local/lib/python3.12/site-packages/flax/nnx/rnglib.py", line 62, in rngs_random_method
    return random_f(self(), *args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/admin/.local/lib/python3.12/site-packages/jax/_src/random.py", line 425, in uniform
    shape = core.canonicalize_shape(shape)
            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
TypeError: Shapes must be 1D sequences of concrete values of integer type, got (JitTracer<~int32[]>, JitTracer<~int32[]>).
If using `jit`, try using `static_argnums` or applying `jit` to smaller subfunctions.
The error occurred while tracing the function rngs_random_method at /home/admin/.local/lib/python3.12/site-packages/flax/nnx/rnglib.py:61 for jit. This concrete value was not available in Python because it depends on the value of the argument args[0][0].
The error occurred while tracing the function rngs_random_method at /home/admin/.local/lib/python3.12/site-packages/flax/nnx/rnglib.py:61 for jit. This concrete value was not available in Python because it depends on the value of the argument args[0][1].
--------------------
For simplicity, JAX has removed its internal frames from the traceback of the following exception. Set JAX_TRACEBACK_FILTERING=off to include these.

DBraun avatar Nov 03 '25 14:11 DBraun

@DBraun thanks for reporting this regression. I can confirm the error with the latest flax; the example passes on v0.11.2. I assume that the regression is related to this PR. @samanklesaria please take a look.

vfdev-5 avatar Nov 03 '25 14:11 vfdev-5

PR #4948 added flop counts for each method in a tabulate call. To get the number of flops, we need to compile the method first. But certain methods (like rng.uniform) can't be compiled. We can't jit uniform because it doesn't have a static output shape: its output shape depends on the value of the tuple argument. The error in the code above can be reproduced more simply as follows:

from jax import numpy as jnp
from flax import nnx

x = jnp.zeros((4, 8))
rng = nnx.Rngs(0)
uniform = nnx.jit(nnx.Rngs.uniform)
uniform(rng, (x.shape[0], 10))  # raises the same TypeError as the MRE
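
The same failure can be shown without nnx at all. This is a minimal sketch in plain JAX, assuming the root cause is only that the shape tuple crosses a jit boundary as a traced argument. `sample` is a hypothetical stand-in for Rngs.uniform, not part of flax:

```python
import jax
import jax.numpy as jnp


def sample(key, shape):
    # Hypothetical stand-in for Rngs.uniform: the output shape is taken
    # from the *value* of the `shape` argument.
    return jax.random.uniform(key, shape)


key = jax.random.key(0)
x = jnp.zeros((4, 8))
shape = (x.shape[0], 10)

# Under plain jit, the ints inside `shape` become tracers, so
# jax.random.uniform cannot build a concrete shape from them.
try:
    jax.jit(sample)(key, shape)
    jit_failed = False
except TypeError:
    jit_failed = True

# Marking `shape` as static keeps it a Python tuple during tracing.
out = jax.jit(sample, static_argnums=1)(key, shape)
print(jit_failed, out.shape)
```

Marking the argument static is the remedy the error message itself suggests. Whether tabulate could apply it generically to arbitrary methods is a separate question that this sketch does not answer.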

More generally, #4948 will always fail whenever the object passed to tabulate has methods (or has children with methods) that aren't jittable.

Ideally, we'd be able to support tabulate calls for methods that can't be jitted like this, even if we don't support the calculate_flops option in these situations. We need to handle exceptions raised in the jit call by retrying without jit.
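
A rough sketch of that retry idea, in plain JAX rather than against the flax internals. `call_with_jit_fallback` is a hypothetical helper, not an existing flax or JAX function, and it assumes the wrapped function is pure, so that running it a second time after a failed trace is harmless:

```python
import jax
import jax.numpy as jnp


def call_with_jit_fallback(fn, *args, **kwargs):
    """Hypothetical helper: try a jitted call, fall back to an eager one.

    Returns (result, was_jitted). JAX's concretization errors subclass
    TypeError, so that is what gets caught here; this is deliberately
    broad and would also swallow unrelated TypeErrors.
    """
    try:
        return jax.jit(fn)(*args, **kwargs), True
    except TypeError:
        # The function needs concrete values (e.g. a shape tuple),
        # so run it without jit instead.
        return fn(*args, **kwargs), False


def sample(key, shape):
    # Output shape depends on the value of `shape`: not jittable as-is.
    return jax.random.uniform(key, shape)


def double(x):
    # Ordinary jittable function.
    return x * 2


key = jax.random.key(0)
out_a, jitted_a = call_with_jit_fallback(sample, key, (4, 10))
out_b, jitted_b = call_with_jit_fallback(double, jnp.ones(3))
print(jitted_a, out_a.shape, jitted_b, out_b.shape)
```

In a tabulate-like setting, the `was_jitted` flag could decide whether a flop count is reported for that row. A stateful object such as Rngs would need more care, since a failed trace followed by an eager call runs the method body twice.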

samanklesaria avatar Nov 03 '25 18:11 samanklesaria

Ideally, we'd be able to support tabulate calls for methods that can't be jitted like this.

I may be misunderstanding, but jitting them actually isn't a problem in the unreleased flax.

from jax import numpy as jnp
from flax import nnx


class Net(nnx.Module):
    def __init__(self):
        self.rngs = nnx.Rngs(0)

    def __call__(self, x):
        return self.rngs.uniform((x.shape[0], 10))


if __name__ == '__main__':
    net = Net()
    @nnx.jit
    def forward(m, x):
        return m(x)

    x = jnp.zeros((4, 8))
    print("running jit forward pass")
    y = forward(net, x)
    print("all done")

output:

running jit forward pass
all done

DBraun avatar Nov 04 '25 02:11 DBraun

I may be misunderstanding, but jitting them actually isn't a problem in the unreleased flax.

Jitting forward is fine. But jitting Rngs.uniform is not: it has a fundamentally dynamic shape that depends on its input arguments. In the table produced by tabulate, there will be a row for the call to Rngs.uniform. That row cannot be generated by a jitted call. Does that make sense?
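
To illustrate the difference, here is a sketch in plain JAX (not the actual tabulate code path, and all function names are made up for the example). Inside an outer jit, `x.shape[0]` is still an ordinary Python int, so building the shape tuple there is fine. The trouble starts only when that tuple is passed as an argument into a second jitted function:

```python
import jax
import jax.numpy as jnp

shape_types = []  # records what x.shape[0] looks like during tracing


@jax.jit
def forward(key, x):
    n = x.shape[0]  # static: shapes are known at trace time
    shape_types.append(type(n))
    return jax.random.uniform(key, (n, 10))


# Jitting the sampling call itself: `shape` is now a traced argument.
inner = jax.jit(lambda key, shape: jax.random.uniform(key, shape))


@jax.jit
def forward_nested(key, x):
    return inner(key, (x.shape[0], 10))


key = jax.random.key(0)
x = jnp.zeros((4, 8))

y = forward(key, x)  # works: the shape never crosses a jit boundary

try:
    forward_nested(key, x)
    nested_failed = False
except TypeError:
    nested_failed = True

print(y.shape, shape_types, nested_failed)
```

Here `forward` plays the role of the jitted top-level call, and `inner` plays the role of the separately jitted per-method call that tabulate needs for its row.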

samanklesaria avatar Nov 04 '25 03:11 samanklesaria

It seems that jitting recursively is the problem then, instead of just trying to jit at the top level. However, my knowledge of the flax internals is limited.

DBraun avatar Nov 04 '25 12:11 DBraun