nnx.jit() cannot specify backend or device due to out_shardings
System information
- OS Platform and Distribution (e.g., Linux Ubuntu 16.04): Linux, but OS-agnostic
- Flax, jax, jaxlib versions (obtain with `pip show flax jax jaxlib`):
  - flax 0.10.6
  - jax 0.6.0
  - jaxlib 0.6.0 (cuda12)
- Python version: 3.12
- GPU/TPU model and memory: not relevant
- CUDA version (if applicable): 12.6
Problem you have encountered:
nnx.jit() does not fully support the backend=... or device=... parameters the way vanilla jax.jit() does, because jax.jit requires out_shardings to be unspecified whenever a backend or device is given:

```
ValueError: If backend or device is specified on jit, then out_shardings should not be specified.
```
What you expected to happen:
nnx.jit(fn, backend=...) or nnx.jit(fn, device=...) should work seamlessly, just like vanilla jax.jit(). The minimal reproduction below should run and print a [2, 4] array.
Logs, error messages, etc:
This is probably because nnx's jit wrapper always passes a 3-tuple (jax_in_shardings, kwargs_shardings, jax_out_shardings) of shardings to jax.jit, as determined by the output of JitFn.
See https://github.com/google/flax/blob/main/flax/nnx/transforms/compilation.py#L379
Steps to reproduce:
A minimal reproduction:
```python
import flax.nnx as nnx
import jax.numpy as jnp

model = nnx.Linear(3, 4, rngs=nnx.Rngs(0))

@nnx.jit(backend='cpu')
def foo(model: nnx.Linear, x):
    return model(x)

batch_size: int = 2
y = foo(model, jnp.ones([batch_size, 3]))
assert y.shape == (batch_size, 4)
print(y)
```
Expected: no error.
Actual: `ValueError: If backend or device is specified on jit, then out_shardings should not be specified.`
Hey @wookayin, thanks for reporting this. I think it makes sense to support this; we'd have to tweak nnx.jit a little bit, but it should be possible to avoid defining out_shardings under the hood when the user provides neither in_shardings nor out_shardings.
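A minimal sketch of that idea (this is not the actual Flax implementation, and the names here are hypothetical): only forward sharding arguments to jax.jit that the user actually specified, so that backend/device can coexist with default shardings. A sentinel stands in for jax's internal "unspecified" marker.

```python
# Hypothetical sketch: build the keyword arguments that nnx.jit would pass
# to jax.jit, omitting out_shardings (and in_shardings) when the user did
# not specify them, so that backend=/device= no longer trigger the
# "out_shardings should not be specified" ValueError.

UNSPECIFIED = object()  # stand-in for jax's internal "unspecified" sentinel

def build_jit_kwargs(backend=None, device=None,
                     in_shardings=UNSPECIFIED, out_shardings=UNSPECIFIED):
    kwargs = {}
    if backend is not None:
        kwargs["backend"] = backend
    if device is not None:
        kwargs["device"] = device
    # Only forward shardings the user explicitly provided; jax.jit rejects
    # out_shardings when combined with backend or device.
    if in_shardings is not UNSPECIFIED:
        kwargs["in_shardings"] = in_shardings
    if out_shardings is not UNSPECIFIED:
        kwargs["out_shardings"] = out_shardings
    return kwargs

print(build_jit_kwargs(backend="cpu"))
```

With this shape of logic, `nnx.jit(fn, backend='cpu')` would hand jax.jit only `backend='cpu'` and no sharding arguments, matching vanilla jax.jit behavior.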
Hi @cgarciae, I would like to work on this.