Add a "broadcasting vmap" helper to custom_batching.
It has come up a few times (most recently in https://github.com/google/jax/issues/23624) that the "vectorized" behavior of pure_callback and ffi_call is confusing. I'm working on improving that, but in the meantime it would be useful to provide a broadcasting_vmap helper, analogous to the existing sequential_vmap helper, which implements vmap with a lax.map.
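A minimal sketch of what such a helper could look like. It assumes the helper is built on jax.custom_batching.custom_vmap, the same mechanism sequential_vmap uses. The name broadcasting_vmap and its exact behavior (broadcast every unbatched argument up to the batch size, then call the wrapped function once) are illustrative assumptions, not an existing JAX API. It also assumes the wrapped function already handles an extra leading batch axis on all of its inputs.

```python
import jax
import jax.numpy as jnp
from jax.custom_batching import custom_vmap


def broadcasting_vmap(f):
    """Hypothetical helper: under vmap, broadcast every unbatched argument to
    the batch size and call ``f`` once on the stacked inputs.

    Assumes ``f`` handles an extra leading batch axis on every input
    (NumPy-style broadcasting semantics).
    """
    f = custom_vmap(f)

    @f.def_vmap
    def rule(axis_size, in_batched, *args):
        # Batched args arrive with their batch axis moved to position 0;
        # give the unbatched ones a matching leading axis of size axis_size.
        args = [
            a if batched else jnp.broadcast_to(a, (axis_size, *jnp.shape(a)))
            for a, batched in zip(args, in_batched)
        ]
        # Calling the wrapped f (not the raw function) lets an enclosing vmap
        # apply this rule again, so nested vmaps keep broadcasting.
        out = f(*args)
        return out, True  # the output always carries the batch axis

    return f


seen_shapes = []


def add3(x, y, z):
    # Record the input shapes each time add3 is traced or called.
    seen_shapes.append((jnp.shape(x), jnp.shape(y), jnp.shape(z)))
    return x + y + z


bcast_add3 = broadcasting_vmap(add3)

# Nested vmaps: the outer one maps over x, the inner one over y.
nested = jax.vmap(
    jax.vmap(bcast_add3, in_axes=(None, 0, None)), in_axes=(0, None, None)
)
x = jnp.arange(4, dtype=jnp.float32)
y = jnp.arange(5, dtype=jnp.float32)
z = jnp.array(1, dtype=jnp.float32)
result = nested(x, y, z)
```

In this sketch every argument ends up with the full batch shape, including one that no vmap maps over (z here). A variant could instead insert size-1 axes for unbatched arguments and leave the final broadcast to the wrapped function.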
For my use case, it should also respect the in_axes and out_axes of vmap. I don't construct the vmaps manually; they come from a function that nests vmaps and scans as specified by a signature: https://github.com/Joshuaalbert/DSA2000-Cal/blob/joshs-working-branch/dsa2000_cal/dsa2000_cal/common/jax_utils.py#L395
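To illustrate why the vmaps are not written by hand, here is a simplified sketch of building a nest of vmaps from a per-argument specification. The helper nest_vmaps and its `mapped` flags are hypothetical and much simpler than the linked code: it only nests vmaps (no scans) and maps every selected argument along axis 0.

```python
from typing import Callable, Sequence

import jax
import jax.numpy as jnp


def nest_vmaps(f: Callable, mapped: Sequence[bool]) -> Callable:
    """Hypothetical helper: wrap f in one vmap per mapped argument.

    The first mapped argument becomes the outermost vmap, so its axis is the
    leading axis of the output (an "outer product" over the mapped inputs).
    Arguments flagged False are passed through unmapped (in_axes=None).
    """
    n = len(mapped)
    # Wrap from the last argument to the first: each new vmap encloses the
    # previous ones, so argument 0 ends up outermost.
    for i in reversed(range(n)):
        if mapped[i]:
            in_axes = tuple(0 if j == i else None for j in range(n))
            f = jax.vmap(f, in_axes=in_axes)
    return f


outer_add = nest_vmaps(lambda x, y, z: x + y + z, mapped=(True, True, False))
x = jnp.arange(4, dtype=jnp.float32)
y = jnp.arange(5, dtype=jnp.float32)
z = jnp.array(1, dtype=jnp.float32)
out = outer_add(x, y, z)  # out[i, j] == x[i] + y[j] + z
```

Any callback inside f only ever sees the batching that these generated vmaps apply, which is why the callback's batching behavior has to follow the in_axes chosen by the helper.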
What do you mean by "respect the in_axes and out_axes"?
It might make more sense to have this conversation in https://github.com/google/jax/issues/23624 regardless. If the example code I provided over there doesn't do what you want, please explain exactly what behavior you would expect!
I mean this should pass.
from functools import partial
import jax
import jax.numpy as jnp

@partial(jax.vmap, in_axes=(0, None, None))  # outer vmap: maps over x
@partial(jax.vmap, in_axes=(None, 0, None))  # inner vmap: maps over y
def cb_vec(x, y, z):
    def add(x, y, z):
        # Expected: the callback sees x and y broadcast to the full (4, 5)
        # batch shape, while the unmapped z stays a scalar.
        assert x.shape == (4, 5)
        assert y.shape == (4, 5)
        assert z.shape == ()
        return x + y + z
    return jax.pure_callback(add, jax.ShapeDtypeStruct(shape=x.shape, dtype=x.dtype), x, y, z, vectorized=True)

if __name__ == '__main__':
    x = jnp.arange(4, dtype=jnp.float32)
    y = jnp.arange(5, dtype=jnp.float32)
    z = jnp.array(1, dtype=jnp.float32)
    assert cb_vec(x, y, z).shape == (4, 5)
Closing in favor of: https://github.com/jax-ml/jax/pull/23881