
Add a "broadcasting vmap" helper to custom_batching.

dfm opened this issue · 3 comments

It has come up a few times (most recently in https://github.com/google/jax/issues/23624) that the "vectorized" behavior of pure_callback and ffi_call is confusing. I'm working on improving that, but in the meantime, it seems like it would be useful to provide a broadcasting_vmap similar to the sequential_vmap helper that we currently have for vmapping with a lax.map.

dfm, Sep 13 '24 13:09
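A rough sketch of what such a helper could look like. It is built on `jax.custom_batching.custom_vmap`, which is also what `sequential_vmap` wraps. The name `broadcasting_vmap` is hypothetical: it is the helper proposed here, not an existing JAX API. Under `vmap`, every unbatched argument is broadcast with `jnp.broadcast_to` so that it carries the leading batch axis. The wrapped function is then called once on the full-rank inputs, so it must handle leading batch dimensions itself, the way a NumPy-style host callback does. `host_add` and `add` are illustrative names only.

```python
import jax
import jax.numpy as jnp
import numpy as np
from jax.custom_batching import custom_vmap


def broadcasting_vmap(f):
    """Hypothetical helper: under vmap, broadcast unbatched args and call f once."""
    f = custom_vmap(f)

    @f.def_vmap
    def rule(axis_size, in_batched, *args):
        # custom_vmap puts the batch axis of batched args at position 0; give
        # every unbatched arg the same leading axis so all inputs line up.
        def bcast(batched, x):
            return x if batched else jnp.broadcast_to(x, (axis_size,) + jnp.shape(x))

        args = jax.tree_util.tree_map(bcast, list(in_batched), list(args))
        # Calling the custom_vmap-wrapped f again lets nested vmaps recurse,
        # one batch level at a time.
        out = f(*args)
        return out, jax.tree_util.tree_map(lambda _: True, out)

    return f


seen = []  # shapes the host callback actually received, for illustration


def host_add(x, y):
    # Runs on the host with NumPy arrays and relies on NumPy broadcasting.
    seen.append((x.shape, y.shape))
    return np.asarray(x + y, dtype=x.dtype)


@broadcasting_vmap
def add(x, y):
    out_type = jax.ShapeDtypeStruct(jnp.broadcast_shapes(x.shape, y.shape), x.dtype)
    return jax.pure_callback(host_add, out_type, x, y)


xs = jnp.arange(4, dtype=jnp.float32)
ys = jnp.arange(5, dtype=jnp.float32)

# Single vmap: only x is batched, so y is broadcast to match it.
out1 = jax.vmap(add, in_axes=(0, None))(xs, jnp.array(10.0, dtype=jnp.float32))
print(seen[-1], out1)

# Nested vmaps with different in_axes: both inputs end up with shape (4, 5).
out2 = jax.vmap(jax.vmap(add, in_axes=(None, 0)), in_axes=(0, None))(xs, ys)
print(seen[-1], out2.shape)
```

Because the pure callback is only ever invoked on the already-broadcast inputs, its own batching rule never runs, so this sketch does not depend on how `vectorized` is interpreted.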

It should also respect the in_axes and out_axes of vmap for my use case. I don't construct the vmaps manually; they come from a function that nests vmaps and scans as specified by a signature: https://github.com/Joshuaalbert/DSA2000-Cal/blob/joshs-working-branch/dsa2000_cal/dsa2000_cal/common/jax_utils.py#L395

Joshuaalbert, Sep 13 '24 13:09

What do you mean by "respect the in_axes and out_axes"?

It might make more sense to have this conversation in https://github.com/google/jax/issues/23624 regardless. If the example code I provided over there doesn't do what you want, please explain exactly what behavior you would expect!

dfm, Sep 13 '24 13:09

I mean that this should pass:

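from functools import partial

import jax
import jax.numpy as jnp


@partial(jax.vmap, in_axes=(0, None, None))  # outer vmap maps over x
@partial(jax.vmap, in_axes=(None, 0, None))  # inner vmap maps over y
def cb_vec(x, y, z):
    def add(x, y, z):
        # Requested behavior: the batched inputs x and y arrive broadcast to
        # the full (4, 5) batch shape, while the unbatched z stays a scalar.
        assert x.shape == (4, 5)
        assert y.shape == (4, 5)
        assert z.shape == ()
        return x + y + z
    return jax.pure_callback(add, jax.ShapeDtypeStruct(shape=x.shape, dtype=x.dtype), x, y, z, vectorized=True)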


if __name__ == '__main__':
    x = jnp.arange(4, dtype=jnp.float32)
    y = jnp.arange(5, dtype=jnp.float32)
    z = jnp.array(1, dtype=jnp.float32)

    assert cb_vec(x, y, z).shape == (4, 5)

Joshuaalbert, Sep 13 '24 14:09

Closing in favor of: https://github.com/jax-ml/jax/pull/23881

dfm, Oct 10 '24 14:10