
Unimplemented "discharged_consts" when discharging while loop effects

Open · brianwa84 opened this issue 2 years ago · 1 comment

Description

The test below, when added to pallas_test.py, yields the following error:

jax/_src/lax/control_flow/loops.py", line 1739, in _while_discharge_rule
    if discharged_consts: raise NotImplementedError(discharged_consts)  # changed this line
NotImplementedError: [array([[0.]], dtype=float32)]

Test case:

  def test_while_loop_discharge_consts(self):
    start = jnp.int32([[0, 10], [15, 20]])
    ntrips = jnp.int32(
        random.uniform(random.PRNGKey(2), (8, 8), minval=5., maxval=10.))

    @functools.partial(
        self.pallas_call,
        grid=(2, 2),
        in_specs=[
            pl.BlockSpec(lambda i, j: (i, j), (1, 1)),  # start
            pl.BlockSpec(lambda i, j: (i, j), (4, 4)),  # ntrips
        ],
        out_shape=jax.ShapeDtypeStruct((50, 2), jnp.float32),
        out_specs=pl.BlockSpec(lambda i, j: (0, 0), (50, 2)))
    def test_fn(start, ntrips, out):
      start = start[0, 0]
      ntrips = ntrips[:, :]
      def body(_, arg):
        i, = arg
        out[i, 1] = jnp.zeros([])
        return i - 1,

      jax.lax.fori_loop(0, ntrips.max(), body, (ntrips.max() - 1 + start,))

    test_fn(start, ntrips)
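For context, here is a minimal plain-JAX sketch (no Pallas, names are illustrative) showing that a `fori_loop` with a traced trip count lowers to a `while_loop` and runs fine outside Pallas; the error above only appears when Pallas discharges the loop's state effects and the while-loop discharge rule encounters closed-over constants:

```python
import jax
import jax.numpy as jnp

def f(n, x):
    # Mirrors the issue's loop shape: the carry counts down from a
    # value derived from the (traced) trip count.
    def body(_, carry):
        i, acc = carry
        return i - 1, acc + i

    # Under jit, `n` is traced, so this fori_loop cannot unroll and
    # lowers to lax.while_loop.
    _, acc = jax.lax.fori_loop(0, n, body, (n - 1 + x, jnp.int32(0)))
    return acc

print(jax.jit(f)(jnp.int32(5), jnp.int32(2)))  # sums 6+5+4+3+2 = 20
```

This suggests the failure is specific to `_while_discharge_rule` raising `NotImplementedError` when the discharged jaxpr has consts, not to the traced-bound `fori_loop` itself.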

What jax/jaxlib version are you using?

Google internal

Which accelerator(s) are you using?

GPU

Additional system info

Google internal

NVIDIA GPU info

A100

— brianwa84, Nov 02 '23

This looks related to #16116; assigning @sharadmv

— jakevdp, Nov 06 '23