custom batching (vmap)
Support custom batching, i.e. the ability to register a custom "vmap rule" for any given function. Example usage would look something like:
```python
from jax import vmap, custom_vmap, numpy as jnp

@custom_vmap
def vector_dot(u, v):
  assert u.ndim == v.ndim == 1
  return u @ v

@vector_dot.def_vmap
def vector_dot_vmap_rule(axis_size, in_batched, u, v):
  u_batched, v_batched = in_batched
  if u_batched:
    assert u.ndim == 2 and u.shape[0] == axis_size
    print('lhs batched')
  if v_batched:
    assert v.ndim == 2 and v.shape[0] == axis_size
    print('rhs batched')
  if u_batched and v_batched:
    out = jnp.sum(u * v, axis=1)
  else:
    out = u @ v if u_batched else v @ u
  return out, u_batched or v_batched

def f(u, v):
  return jnp.exp(vector_dot(u, v))

x = lambda *shape: jnp.ones(shape)

vmap(f, in_axes=(0, None))(x(4, 3), x(3))    # -> lhs batched
vmap(f, in_axes=(1, None))(x(3, 4), x(3))    # -> lhs batched
vmap(f, in_axes=(None, 0))(x(3), x(4, 3))    # -> rhs batched
vmap(f, in_axes=(0, 0))(x(4, 3), x(4, 3))    # -> lhs batched, rhs batched
```
This would enable #7199 and would help avoid #8853, among other things (e.g. #12345).
A rough update on where the implementation stands at present:
- Batching, forward-mode AD (e.g. `jvp`, `jacfwd`), and compilation are all supported (see the sketch after this list).
- Reverse-mode AD (specifically linearization by partial evaluation, and transposition) is a work in progress.
- The underlying primitive currently stages out the custom-batched function eagerly. We may want to move to a delayed tracing approach.
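For instance, here is a minimal sketch of the compositions the first item claims, reusing `vector_dot` from the example above (this assumes the proposed `custom_vmap` API behaves as sketched; it is not settled API):

```python
import jax
import jax.numpy as jnp

u, v = jnp.ones(3), jnp.arange(3.0)

# Forward-mode AD through the custom-batched function:
primal, tangent = jax.jvp(vector_dot, (u, v), (u, v))

# Forward-mode Jacobian:
jac = jax.jacfwd(vector_dot)(u, v)

# Compilation:
compiled_out = jax.jit(vector_dot)(u, v)
```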
We're also thinking about whether to recommend the general `custom_vmap` function as a direct user-facing API, or whether instead to encourage more structured uses via functions like `sequential_vmap` or variations thereof.
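As a sketch of the latter idea, assuming the `custom_vmap`/`def_vmap` API proposed above and a single-output function (names and details hypothetical), `sequential_vmap` might register a rule that applies the function to one batch element at a time via `lax.map`:

```python
import jax.numpy as jnp
from jax import custom_vmap, lax

def sequential_vmap(f):
  f = custom_vmap(f)

  @f.def_vmap
  def rule(axis_size, in_batched, *args):
    # Broadcast any unbatched arguments so every argument carries the
    # leading batch axis, then apply f one batch element at a time.
    args = tuple(x if b else jnp.broadcast_to(x, (axis_size, *jnp.shape(x)))
                 for x, b in zip(args, in_batched))
    out = lax.map(lambda xs: f(*xs), args)
    return out, True

  return f
```

Within the rule, `f` is called on unbatched slices, so it runs under `lax.map`'s sequential loop rather than being vectorized.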
Let's look to fix/support #13283 as part of this as well.
By the way, I notice that the above example doesn't include `axis_name`, which in general needs to be passed in as well.
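For context on why: under `vmap`, the mapped function can refer to the batch axis by name via collectives, which a custom rule would need to see to reproduce that behavior. A small illustration with ordinary `vmap`:

```python
import jax
import jax.numpy as jnp
from jax import lax

# Each mapped element is normalized by a sum over the named batch axis.
def normalize(x):
  return x / lax.psum(x, axis_name='i')

print(jax.vmap(normalize, axis_name='i')(jnp.arange(1.0, 5.0)))
# [0.1 0.2 0.3 0.4]
```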