
Functionality to chunk `vmap`.

Open joeryjoery opened this issue 3 years ago • 13 comments

Occasionally I run into the problem that batch sizes are too large for GPU memory, and with the current public API of JAX I have to commit to one of two extremes: `vmap` (everything at once) or the much more limited `scan` (one element at a time). (I'm still relatively new to JAX, so please correct me if I'm wrong.) Could there perhaps be a chunk argument to the `vmap` function that limits how much is passed to the vmapped function at any time?

I tried implementing this and have the following mock-up. At the moment it is based on a `num_chunks` parameter, but it could be made more flexible, for example in terms of a `chunk_size`. It splits or repeats the inputs of the vmapped function based on the given `in_axes` and pads all input arrays uniformly with zeros along the batch dimension. It then loops over the chunks with a Python for-loop, passing each one through a conventional `vmap`-ed function. Finally, the results are trimmed back to their unpadded sizes and concatenated.

Not all the functionality of the normal `vmap` is implemented, but it works within a somewhat narrower scope. Could something like this be supported by the public JAX API?

```python
import jax
import jax.numpy as jnp


def pad_along_axis(array: jnp.ndarray, axis_length: int, axis: int = 0, *args, **kwargs) -> jnp.ndarray:
    # Zero-pad `array` at the end of `axis` so that its size along `axis` becomes `axis_length`.
    target_size = axis_length - jnp.shape(array)[axis]

    padding = [(0, 0)] * jnp.ndim(array)
    padding[axis] = (0, target_size)

    return jnp.pad(array, padding, *args, **kwargs)


def chunked_vmap(fun, num_chunks: int = 1, in_axes=0, out_axes=0, axis_name=None, axis_size=None):
    # TODO: Compatibility on flattened in_axes. Implementation for out_axes, axis_name, axis_size.

    # Note, num_chunks == 1 is equivalent to just using `vmap_fun`.
    vmap_fun = jax.vmap(fun, in_axes=in_axes, out_axes=out_axes, axis_name=axis_name, axis_size=axis_size)

    # Leaf structure of input splitting: ([chunk_a, chunk_b, ...], [size_a, size_b, ...])
    splitted_treedef = jax.tree_util.tree_structure(([1] * num_chunks,) * 2)

    def split_fun(arg, ax):
        # Operates on pytree leaves.
        if ax is None:
            # Unmapped argument: every chunk receives the same value.
            return [arg] * num_chunks, [0] * num_chunks

        chunks = jnp.array_split(arg, num_chunks, axis=ax)

        # array_split makes the leading chunk(s) the largest; pad the others up to that
        # size so that all chunks share one shape, and record the unpadded sizes.
        leading_size = jnp.shape(chunks[0])[ax]
        batch_dims = [jnp.shape(a)[ax] for a in chunks]

        padded_chunks = [pad_along_axis(a, axis_length=leading_size, axis=ax) for a in chunks]

        return padded_chunks, batch_dims

    def vmap_f(*args, **kwargs):  # TODO: Incorporate kwargs?
        splitted = jax.tree_util.tree_map(split_fun, args, in_axes)

        # Turn "an args-tree of (chunks, sizes) pairs" into
        # "(a list of args-trees, one per chunk; a list of size-trees, one per chunk)".
        input_chunks, canonical_sizes = jax.tree_util.tree_transpose(
            jax.tree_util.tree_structure(args), splitted_treedef, splitted
        )
        # Number of valid (unpadded) rows in each chunk.
        out_sizes = [max(jax.tree_util.tree_leaves(s)) for s in canonical_sizes]

        # TODO: use jax.lax.scan? Note the dynamic shapes of the slicing and that in_axes is not yet supported.
        # Python loop over the chunks; padded rows are dropped from every output leaf (assumes out_axes=0).
        results = [
            jax.tree_util.tree_map(lambda x: x[:s], vmap_fun(*c))
            for c, s in zip(input_chunks, out_sizes)
        ]

        # TODO: collect all outputs immediately, or use a generator with `yield`?
        out = jax.tree_util.tree_map(lambda *a: jnp.concatenate(a), *results)
        return out

    return vmap_f
```
```python
def myfun(a, b, c):
    return jnp.square(a) * c + jnp.squeeze(b['val'])


v = jnp.arange(100)

# b['val'] has shape (1, 1, 1, 100) and is mapped over its last axis; c is not mapped.
args = (v, {'val': v.reshape(1, 1, -1, 100)}, 0.4)
in_axes = (0, {'val': -1}, None)


vmap_fun = jax.vmap(myfun, in_axes=in_axes)
out = vmap_fun(*args)


# The second positional argument of chunked_vmap is the number of chunks, not a chunk size.
for num_chunks in [1, 2, 5, 10, 50]:
    chunk_out = chunked_vmap(myfun, num_chunks, in_axes=in_axes)(*args)

    assert jnp.isclose(out, chunk_out).all()  # Runs fine
```

joeryjoery, Jun 30 '22
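The split-and-pad step that `chunked_vmap` relies on can be seen in isolation on a toy array. This is only an illustration. The data (`x`), the chunk count of 3, and the name `width` are arbitrary choices and not part of the proposal. `jnp.array_split` produces uneven chunks. Padding gives every chunk the same shape, which a single `vmap` or a `lax.scan` over chunks could then process. The recorded sizes allow the padding to be stripped again afterwards.

```python
import jax.numpy as jnp

x = jnp.arange(10)
chunks = jnp.array_split(x, 3)            # uneven split: sizes 4, 3, 3
sizes = [c.shape[0] for c in chunks]      # remember the unpadded sizes
width = sizes[0]                          # array_split puts the larger chunks first
padded = jnp.stack([jnp.pad(c, (0, width - c.shape[0])) for c in chunks])
print(sizes)         # [4, 3, 3]
print(padded.shape)  # (3, 4)

# After processing, drop the zero padding again using the recorded sizes.
restored = jnp.concatenate([row[:s] for row, s in zip(padded, sizes)])
print(jnp.array_equal(restored, x))  # True
```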

Take a look at `jax.experimental.maps.xmap`, which is designed for exactly this sort of thing:
https://jax.readthedocs.io/en/latest/notebooks/xmap_tutorial.html
https://jax.readthedocs.io/en/latest/_autosummary/jax.experimental.maps.xmap.html

You can chunk `vmap` by making use of `SerialLoop`.

shoyer, Jul 15 '22

This here might be something you could use: https://netket.readthedocs.io/en/stable/api/_generated/jax/netket.jax.vmap_chunked.html#netket.jax.vmap_chunked

ynotzort, Mar 12 '23

> This here might be something you could use: https://netket.readthedocs.io/en/stable/api/_generated/jax/netket.jax.vmap_chunked.html#netket.jax.vmap_chunked

Thanks, that is an excellent reference. I am of the opinion, though, that JAX could implement this by default, given how generally useful it is.

joeryjoery, Mar 13 '23

Hi, more than a year later ;P, are there any plans to implement this in JAX by default? Or would the authors be open to a PR?

joeryjoery, Nov 04 '23

@mattjj we were just talking about this :)

shoyer, Nov 04 '23

If you do get to it, I would love it if you also added support for `vjp` over 'chunked' axes, like we did in https://netket.readthedocs.io/en/stable/api/_generated/jax/netket.jax.vjp_chunked.html

PhilipVinc, Nov 07 '23
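NetKet's `vjp_chunked` is not reproduced here. As a rough illustration of one way to read "vjp over chunked axes", the sketch below computes a VJP for a function vmapped over a batch, with respect to its parameters. It processes the batch chunk by chunk with `jax.lax.scan` and sums the per-chunk results. This works because the parameter cotangent of a batched function is the sum of the contributions from each batch element. `chunked_vjp_params`, `f`, and the toy shapes are hypothetical. The sketch also assumes the batch size is divisible by `chunk_size`.

```python
import jax
import jax.numpy as jnp


def chunked_vjp_params(f, params, xs, cts, chunk_size):
    # Hypothetical helper: VJP of params -> vmap(f, in_axes=(None, 0))(params, xs)
    # with cotangents `cts`, evaluated one chunk of the batch at a time.
    # Assumes xs.shape[0] (== cts.shape[0]) is divisible by chunk_size.
    n = xs.shape[0]
    xs_c = xs.reshape(n // chunk_size, chunk_size, *xs.shape[1:])
    cts_c = cts.reshape(n // chunk_size, chunk_size, *cts.shape[1:])

    def body(acc, chunk):
        x, ct = chunk
        _, vjp_fn = jax.vjp(lambda p: jax.vmap(f, in_axes=(None, 0))(p, x), params)
        (g,) = vjp_fn(ct)
        # The full-batch parameter VJP is the sum of the per-chunk VJPs.
        return jax.tree_util.tree_map(jnp.add, acc, g), None

    zeros = jax.tree_util.tree_map(jnp.zeros_like, params)
    total, _ = jax.lax.scan(body, zeros, (xs_c, cts_c))
    return total


def f(w, x):
    # Toy per-example function: x has shape (3,), w has shape (3, 2), output has shape (2,).
    return jnp.tanh(x @ w)


w = jnp.ones((3, 2)) * 0.1
xs = jnp.arange(24.0).reshape(8, 3) / 10
cts = jnp.ones((8, 2))

chunked = chunked_vjp_params(f, w, xs, cts, chunk_size=4)

# Reference: the same VJP computed over the whole batch at once.
_, vjp_full = jax.vjp(lambda p: jax.vmap(f, in_axes=(None, 0))(p, xs), w)
(full,) = vjp_full(cts)
```

Only one chunk's forward and backward pass appears inside the scan body. The accumulator has the shape of the parameters, not of the batch.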

+1 for adding this feature; it would be super useful in several of my projects.

f0uriest, Nov 07 '23

Discussion https://github.com/google/jax/discussions/18398 asks for this as well.

froystig, Nov 08 '23

+1 upvote for implementing this

spongepuddingg, Nov 21 '23

It would be nice if JAX's compiler could automatically convert computations from parallel to sequential when necessary, given a known memory constraint. Has this possibility been discussed anywhere?

carlosgmartin, Jul 13 '24

`jax.lax.map` now has a `batch_size` argument that will chunk the computation and internally use `vmap` to operate over each batch in parallel. See #19614.

cgarciae, Jul 16 '24

I tried using JAX as an alternative to numba, since it is much easier to work with and I can switch between CPU and GPU using a flag. The biggest problem I am facing now is that vmap tries to put all of the data onto the GPU at the same time. I am already using vmap to apply a scalar function to a 2D array, so I was wondering if and how map could help me batch the data so that I stop getting out-of-memory errors!

arunoruto, Sep 06 '24

Try the batch_size argument to jax.lax.map? (Added in #19614.)

Think of it as a sequential loop over batches, where each batch is vmapped.

froystig, Sep 08 '24
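To make the "sequential loop over batches, each batch vmapped" model concrete, here is a hand-rolled sketch. It reshapes the leading axis into `(num_batches, batch_size)` and then runs a sequential `jax.lax.map` over the batches of a `vmap`-ed function. `batched_map` is a hypothetical helper for illustration, not the actual `lax.map` implementation. It assumes a single array input whose leading dimension divides evenly; the real `batch_size` argument also handles a remainder.

```python
import jax
import jax.numpy as jnp


def batched_map(f, xs, batch_size):
    # Hypothetical helper: a sequential loop (lax.map, i.e. a scan) over batches,
    # with each batch processed in parallel by vmap.
    # Assumes xs is one array whose leading dimension is divisible by batch_size.
    n = xs.shape[0]
    assert n % batch_size == 0, "sketch assumes an evenly divisible leading dimension"
    batches = xs.reshape(n // batch_size, batch_size, *xs.shape[1:])
    out = jax.lax.map(jax.vmap(f), batches)   # shape: (num_batches, batch_size, ...)
    return out.reshape(n, *out.shape[2:])     # flatten back to (n, ...)


def f(x):
    # Toy per-example function: x has shape (2,), returns a scalar.
    return jnp.sum(x ** 2)


xs = jnp.arange(64.0).reshape(32, 2)
out = batched_map(f, xs, batch_size=8)

print(out.shape)                             # (32,)
print(jnp.allclose(out, jax.vmap(f)(xs)))    # True
```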