Functionality to chunk `vmap`.
Occasionally I run into the problem that batch sizes are too large for GPU memory, and with the current public JAX API I have to commit to one of two extremes: using vmap, or using the much more limited scan. (I'm still relatively new to JAX, so please correct me if I'm wrong.) Could there perhaps be a chunk argument to the vmap function that limits how much is passed to the vmapped function at any one time?
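For concreteness, a minimal sketch of the two extremes I mean (my_fun and the shapes are made up): jax.vmap materializes the whole batch at once, while a sequential jax.lax.map (built on scan) processes one element at a time.

    import jax
    import jax.numpy as jnp

    def my_fun(x):  # hypothetical per-example computation
        return jnp.sum(jnp.sin(x) ** 2)

    xs = jnp.ones((100_000, 1024))  # made-up batch of inputs

    # Fully parallel: fast, but all intermediates for the whole batch live in memory at once.
    out_parallel = jax.vmap(my_fun)(xs)

    # Fully sequential: memory-friendly, but no batching at all.
    out_sequential = jax.lax.map(my_fun, xs)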
I tried implementing this and have the following mock-up. At the moment it is based on a num_chunks parameter, but this could be made more flexible, e.g. specified in terms of chunk_size. It splits or repeats the inputs of the vmapped function according to the given in_axes, pads all input arrays uniformly with zeros along the batch dimension, and then for-loops the chunks through a conventional vmap. Finally, the results are sliced back to their canonical sizes and concatenated.
I haven't implemented all the functionality of the normal vmap, but it works in a somewhat narrower scope. Could something like this be supported by the public JAX API?
from functools import partial

import jax
import jax.numpy as jnp


def pad_along_axis(array: jnp.ndarray, axis_length: int, axis: int = 0, *args, **kwargs) -> jnp.ndarray:
    # Zero-pad `array` along `axis` until it has length `axis_length` on that axis.
    target_size = axis_length - jnp.shape(array)[axis]
    padding = [(0, 0)] * jnp.ndim(array)
    padding[axis] = (0, target_size)
    return jnp.pad(array, padding, *args, **kwargs)


def chunked_vmap(fun, num_chunks: int = 1, in_axes=0, out_axes=0, axis_name=None, axis_size=None):
    # TODO: Compatibility with flattened in_axes. Implementation for out_axes, axis_name, axis_size.
    # Note: num_chunks == 1 is equivalent to just using `vmap_fun`.
    vmap_fun = jax.vmap(fun, in_axes, out_axes, axis_name, axis_size)
    # Leaf structure of the split inputs: ([chunk_a, chunk_b, ...], [size_a, size_b, ...])
    splitted_treedef = jax.tree_structure(([1] * num_chunks,) * 2)

    def split_fun(arg, ax):
        # Operates on pytree leaves; returns (chunks, unpadded batch sizes).
        if ax is None:
            # Unmapped argument: repeat as-is; size 0 so it never dominates the max below.
            return [arg] * num_chunks, [0] * num_chunks
        chunks = jnp.array_split(arg, num_chunks, axis=ax)
        leading_size = jnp.shape(chunks[0])[ax]
        batch_dims = jax.tree_map(lambda a: jnp.shape(a)[ax], chunks)
        # Pad every chunk up to the size of the first (largest) chunk so all chunks share a shape.
        padded_chunks = jax.tree_map(partial(pad_along_axis, axis_length=leading_size, axis=ax), chunks)
        return padded_chunks, batch_dims

    def vmap_f(*args, **kwargs):  # TODO: Incorporate kwargs?
        splitted = jax.tree_map(split_fun, args, in_axes)
        # Transpose the args-structured tree of (chunks, sizes) into (list of chunked args, list of sizes).
        input_chunks, canonical_sizes = jax.tree_transpose(
            jax.tree_structure(args), splitted_treedef, splitted
        )
        out_sizes = [max(jax.tree_leaves(s)) for s in canonical_sizes]
        # TODO: use jax.lax.scan? Note the dynamic shapes of jax.lax.slice and that in_axes is not yet supported there.
        # Run the ordinary vmap on each padded chunk, then slice off the padding again.
        results = [jax.lax.slice(vmap_fun(*c), (0,), (s,)) for c, s in zip(input_chunks, out_sizes)]
        # TODO: collect all outputs immediately, or use a generator with `yield`?
        out = jax.tree_map(lambda *a: jnp.concatenate(a), *results)
        return out

    return vmap_f
def myfun(a, b, c):
    return jnp.square(a) * c + jnp.squeeze(b['val'])


v = jnp.arange(100)
args = (v, {'val': v.reshape(1, 1, -1, 100)}, 0.4)
in_axes = (0, {'val': -1}, None)

vmap_fun = jax.vmap(myfun, in_axes=in_axes)
out = vmap_fun(*args)

for num_chunks in [1, 2, 5, 10, 50]:
    chunk_out = chunked_vmap(myfun, num_chunks, in_axes=in_axes)(*args)
    assert jnp.isclose(out, chunk_out).all()  # Runs fine
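As an aside on the chunk_size variant mentioned above, deriving num_chunks from a desired chunk_size is just a ceiling division over the mapped axis length (the helper name here is made up):

    import math

    def num_chunks_for(batch_size: int, chunk_size: int) -> int:
        # Enough chunks so that no chunk exceeds `chunk_size` elements.
        return math.ceil(batch_size / chunk_size)

    # e.g. for the 100-element example above, chunks of at most 8 elements:
    assert num_chunks_for(100, 8) == 13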
Take a look at jax.experimental.maps.xmap, which is designed for exactly this sort of thing:
https://jax.readthedocs.io/en/latest/notebooks/xmap_tutorial.html
https://jax.readthedocs.io/en/latest/_autosummary/jax.experimental.maps.xmap.html
You can chunk vmap by making use of SerialLoop.
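For reference, roughly what that looks like (a sketch based on the linked xmap tutorial; xmap was experimental and has since been deprecated, so the exact API may differ in newer JAX releases):

    import jax.numpy as jnp
    from jax.experimental.maps import xmap, SerialLoop

    def my_fun(x):  # hypothetical per-example function
        return jnp.sin(x).sum()

    xs = jnp.ones((1000, 128))

    # Name the leading axis 'batch' and lower it onto a serial loop of 10 iterations,
    # so only 100 of the 1000 rows are processed at a time.
    chunked = xmap(my_fun,
                   in_axes=['batch', ...],
                   out_axes=['batch', ...],
                   axis_resources={'batch': SerialLoop(10)})
    out = chunked(xs)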
This here might be something you could use: https://netket.readthedocs.io/en/stable/api/_generated/jax/netket.jax.vmap_chunked.html#netket.jax.vmap_chunked
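For readers following that link, the call looks roughly like this (a hedged sketch based on the linked NetKet docs; check them for the exact signature):

    import jax.numpy as jnp
    import netket as nk

    def my_fun(x):  # hypothetical per-example function
        return jnp.sin(x).sum()

    xs = jnp.ones((100_000, 128))

    # vmap over axis 0, but only evaluate `chunk_size` rows at a time.
    chunked_fun = nk.jax.vmap_chunked(my_fun, in_axes=0, chunk_size=1024)
    out = chunked_fun(xs)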
Thanks, that is an excellent reference. I am of the opinion, though, that JAX could implement this by default, given how generally useful it is.
Hi, more than a year later ;P. Are there any plans to implement this by default in JAX? Or would the authors be open to a PR?
@mattjj we were just talking about this :)
If you do get to it, I would love it if you also added support for vjp over 'chunked' axes, like we did in https://netket.readthedocs.io/en/stable/api/_generated/jax/netket.jax.vjp_chunked.html
+1 for adding this feature, it would be super useful in several projects for me.
Discussion https://github.com/google/jax/discussions/18398 asks for this as well.
+1 upvote for implementing this
It would be nice if JAX's compiler could automatically convert computations from parallel to sequential when necessary, given a known memory constraint. Has this possibility been discussed anywhere?
jax.lax.map now has a batch_size argument that will chunk the computation and internally use vmap to operate over each batch in parallel. See #19614.
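A minimal sketch of that usage (the function and shapes are made up):

    import jax
    import jax.numpy as jnp

    def my_fun(x):  # hypothetical per-example function
        return jnp.sin(x).sum()

    xs = jnp.ones((100_000, 128))

    # Sequentially loops over batches of 1024 rows; each batch is vmapped internally.
    out = jax.lax.map(my_fun, xs, batch_size=1024)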
I tried using JAX as an alternative to numba, since it is much easier to work with, and I can switch between CPU and GPU using a flag. The biggest problem I am facing now is that vmap tries to put all of the data onto the GPU at the same time. I am already using vmap to apply a scalar function to a 2D array, so I was wondering if and how map could help me batch the data so that I stop getting out-of-memory errors.
Try the batch_size argument to jax.lax.map? (Added in #19614.)
Think of it as a sequential loop over batches, where each batch is vmapped.
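For the 2D case described above, one way to set this up (a sketch; the scalar function and shapes are placeholders):

    import jax
    import jax.numpy as jnp

    def scalar_fun(x):  # hypothetical scalar function
        return jnp.sin(x) ** 2

    data = jnp.ones((50_000, 4096))  # made-up 2D array

    # vmap applies the scalar function across the entries of one row...
    row_fun = jax.vmap(scalar_fun)

    # ...and lax.map loops over the rows in batches of 1000, so only that many
    # rows' worth of intermediates are materialized at once.
    out = jax.lax.map(row_fun, data, batch_size=1000)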