Unimplemented handling of "discharged_consts" when discharging while-loop effects
Description
The test below, when added to pallas_test.py, fails with the following error:
jax/_src/lax/control_flow/loops.py", line 1739, in _while_discharge_rule
if discharged_consts: raise NotImplementedError(discharged_consts) # changed this line
NotImplementedError: [array([[0.]], dtype=float32)]
Test case:
def test_while_loop_discharge_consts(self):
  start = jnp.int32([[0, 10], [15, 20]])
  ntrips = jnp.int32(
      random.uniform(random.PRNGKey(2), (8, 8), minval=5., maxval=10.))

  @functools.partial(
      self.pallas_call,
      grid=(2, 2),
      in_specs=[
          pl.BlockSpec(lambda i, j: (i, j), (1, 1)),  # start
          pl.BlockSpec(lambda i, j: (i, j), (4, 4)),  # ntrips
      ],
      out_shape=jax.ShapeDtypeStruct((50, 2), jnp.float32),
      out_specs=pl.BlockSpec(lambda i, j: (0, 0), (50, 2)))
  def test_fn(start, ntrips, out):
    start = start[0, 0]
    ntrips = ntrips[:, :]

    def body(_, arg):
      i, = arg
      out[i, 1] = jnp.zeros([])
      return i - 1,

    jax.lax.fori_loop(0, ntrips.max(), body, (ntrips.max() - 1 + start,))

  test_fn(start, ntrips)
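Discharging state effects roughly means rewriting Ref reads and writes as ordinary value passing, so a buffer that the loop body mutates ends up threaded through the loop. The sketch below hand-writes that value-passing form for a single kernel invocation, which can help when checking what the kernel is expected to compute. It is an illustration only, not the jaxpr that `_while_discharge_rule` actually produces, and the name `reference_step` is hypothetical.

```python
import jax
import jax.numpy as jnp


def reference_step(start_block, ntrips_block, out):
    """Hypothetical pure-JAX counterpart of one kernel invocation.

    Instead of mutating `out` as a Ref captured by the loop body, `out` is
    carried through the loop as an ordinary array value.
    """
    start = start_block[0, 0]
    ntrips_max = ntrips_block.max()

    def body(_, carry):
        i, out = carry
        # Functional counterpart of `out[i, 1] = jnp.zeros([])`.
        out = out.at[i, 1].set(0.0)
        return i - 1, out

    # Same trip count and initial index as the kernel's fori_loop.
    _, out = jax.lax.fori_loop(
        0, ntrips_max, body, (ntrips_max - 1 + start, out))
    return out


# With start = 10 and max(ntrips) = 3, rows 12, 11, 10 of column 1 are zeroed.
out = reference_step(jnp.int32([[10]]),
                     jnp.full((4, 4), 3, jnp.int32),
                     jnp.ones((50, 2), jnp.float32))
```

Here the loop counter and the buffer form one carry tuple, which is the shape of loop that discharge has to arrive at when the body writes to a captured Ref.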
What jax/jaxlib version are you using?
Google internal
Which accelerator(s) are you using?
GPU
Additional system info
Google internal
NVIDIA GPU info
A100
This looks related to #16116; assigning @sharadmv