custom batching (vmap)

Open froystig opened this issue 3 years ago • 1 comments

Support custom batching, i.e. the ability to register a custom "vmap rule" for any given function. Example usage would look something like:

from jax import vmap, numpy as jnp
from jax.custom_batching import custom_vmap  # custom_vmap is exposed under jax.custom_batching

@custom_vmap
def vector_dot(u, v):
  assert u.ndim == v.ndim == 1
  return u @ v

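@vector_dot.def_vmap
def vector_dot_vmap_rule(axis_size, in_batched, u, v):
  # in_batched flags which arguments carry a leading batch axis of size axis_size.
  u_batched, v_batched = in_batched
  if u_batched:
    assert u.ndim == 2 and u.shape[0] == axis_size
    print('lhs batched')
  if v_batched:
    assert v.ndim == 2 and v.shape[0] == axis_size
    print('rhs batched')
  if u_batched and v_batched:
    # Both batched: row-wise dot products.
    out = jnp.sum(u * v, axis=1)
  else:
    # Otherwise the batched operand goes on the left of a matrix-vector product.
    out = u @ v if u_batched else v @ u
  # Return the result together with a flag saying whether it is batched.
  return out, u_batched or v_batched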

def f(u, v):
  return jnp.exp(vector_dot(u, v))

x = lambda *shape: jnp.ones(shape)
vmap(f, in_axes=(0, None))(x(4, 3), x(3))  # -> lhs batched
vmap(f, in_axes=(1, None))(x(3, 4), x(3))  # -> lhs batched
vmap(f, in_axes=(None, 0))(x(3), x(4, 3))  # -> rhs batched
vmap(f, in_axes=(0, 0))(x(4, 3), x(4, 3))  # -> lhs batched, rhs batched

This would enable #7199 and would help avoid #8853, among other things (e.g. #12345).
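
One way to sanity-check a rule like the one above is to compare it against vmap's built-in batching of the same computation. The following is a minimal sketch, assuming `custom_vmap` is importable from `jax.custom_batching` (its location in released JAX, not the import path written in the proposal); `reference_dot`, `mat`, `vec`, and `cases` are hypothetical names introduced only for this check.

```python
import jax
import jax.numpy as jnp
from jax.custom_batching import custom_vmap  # assumption: released location of custom_vmap

@custom_vmap
def vector_dot(u, v):
  return u @ v

@vector_dot.def_vmap
def vector_dot_vmap_rule(axis_size, in_batched, u, v):
  u_batched, v_batched = in_batched
  if u_batched and v_batched:
    out = jnp.sum(u * v, axis=1)
  else:
    out = u @ v if u_batched else v @ u
  return out, u_batched or v_batched

def reference_dot(u, v):
  # Hypothetical helper: same math, no custom rule, so vmap batches it itself.
  return jnp.dot(u, v)

mat = jnp.arange(12.0).reshape(4, 3)  # four vectors of length 3
vec = jnp.array([1.0, 2.0, 3.0])

# The same four in_axes configurations as in the example above.
cases = [
    ((0, None), (mat, vec)),
    ((1, None), (mat.T, vec)),
    ((None, 0), (vec, mat)),
    ((0, 0), (mat, mat + 1.0)),
]
results = []
for in_axes, args in cases:
  got = jax.vmap(vector_dot, in_axes=in_axes)(*args)
  want = jax.vmap(reference_dot, in_axes=in_axes)(*args)
  results.append(bool(jnp.allclose(got, want)))
```

If the rule is consistent with the function it wraps, every entry of `results` should be `True`.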

froystig, Dec 30 '21 22:12

A rough update on where the implementation is at present:

  • Batching, forward-mode AD (e.g. jvp, jacfwd), and compilation are all supported.
  • Reverse-mode AD (specifically, linearization by partial evaluation and transposition) is a work in progress.
  • The underlying primitive currently stages out the custom-batched function eagerly. We may want to move to a delayed tracing approach.
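
As a rough illustration of the supported combinations listed above (batching, forward-mode AD, and compilation), here is a small sketch. It assumes the `jax.custom_batching.custom_vmap` entry point from released JAX; `my_sin` and its rule are hypothetical and deliberately trivial, so the custom rule computes the same thing as the function it wraps.

```python
import jax
import jax.numpy as jnp
from jax.custom_batching import custom_vmap  # assumption: released location of custom_vmap

@custom_vmap
def my_sin(x):
  return jnp.sin(x)

@my_sin.def_vmap
def my_sin_vmap_rule(axis_size, in_batched, xs):
  # sin is elementwise, so the output is batched exactly when the input is.
  (x_batched,) = in_batched
  return jnp.sin(xs), x_batched

xs = jnp.linspace(0.0, 1.0, 5)
ts = jnp.ones_like(xs)

# Batching under compilation.
ys = jax.jit(jax.vmap(my_sin))(xs)

# Forward-mode AD on a single point, with no vmap involved.
y0, t0 = jax.jvp(my_sin, (xs[0],), (ts[0],))

# Forward-mode AD under vmap: this differentiates through the custom rule.
def my_sin_jvp(x, t):
  return jax.jvp(my_sin, (x,), (t,))

ys_jvp, tys = jax.vmap(my_sin_jvp)(xs, ts)
```

Reverse-mode calls such as `jax.grad` are left out of the sketch on purpose, since the list above describes them as still in progress.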

We're also thinking about whether to recommend the general custom_vmap function as a direct user-facing API, or whether instead to encourage more structured uses via functions like sequential_vmap or variations thereof.
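
To make the "more structured" option concrete, here is a sketch of `sequential_vmap`, assuming it is available as `jax.custom_batching.sequential_vmap` as in released JAX. The function `normalize` and the input `vs` are hypothetical. The wrapped function only ever sees one unbatched example at a time, and the batch is handled by a loop rather than a hand-written rule.

```python
import jax
import jax.numpy as jnp
from jax.custom_batching import sequential_vmap  # assumption: released location

@sequential_vmap
def normalize(v):
  # Under sequential_vmap the body is applied to one example at a time,
  # so v stays a single 1-D vector even when called under vmap.
  assert v.ndim == 1
  return v / jnp.linalg.norm(v)

vs = jnp.array([[3.0, 4.0], [0.0, 2.0], [1.0, 0.0]])
out = jax.vmap(normalize)(vs)

# Row-by-row reference computed without vmap.
expected = jnp.stack([v / jnp.linalg.norm(v) for v in vs])
```

The trade-off in this style is that the user gives up control over how the batch is vectorized in exchange for not having to write or validate a rule at all.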

froystig, May 12 '22 15:05

Let's look to fix/support #13283 as part of this as well.

froystig, Nov 21 '22 23:11

By the way, I notice that the above example doesn't include axis_name, which in general needs to be passed in as well.
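
For context on why a rule might care about axis_name: code running under vmap can refer to the mapped axis by name, for example through collectives. The sketch below uses only stock `jax.vmap` and `jax.lax.psum`, with no custom rule involved. Whether and how the name would reach a custom rule is not settled in this thread. `normalize_by_batch_total` and the axis name `"batch"` are hypothetical.

```python
import jax
import jax.numpy as jnp

def normalize_by_batch_total(x):
  # psum refers to the mapped axis by the name that was given to vmap.
  return x / jax.lax.psum(x, axis_name="batch")

xs = jnp.array([1.0, 2.0, 5.0])
out = jax.vmap(normalize_by_batch_total, axis_name="batch")(xs)
```

A custom rule that replaces a function like this one would need the axis name in order to reproduce the collective, which is presumably the concern raised above.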

patrick-kidger, Mar 15 '23 00:03