
Batched version of jax.vmap

Open michael-0brien opened this issue 6 months ago • 1 comments

Hi, I was hoping to understand better and perhaps re-open the discussion in #19614, regarding adding the batch_size argument to jax.lax.map vs jax.vmap. I understand that modifying jax.lax.map is a much simpler endeavor, so the conclusions there make sense.

However, I wanted to advocate for adding batch_size to jax.vmap in a future PR. I see there was mention in #19614 that one reason for adding batch_size to jax.vmap is that it is the more widely used API. I would go a step further: developers who 1) work on memory-limited applications, 2) develop on memory-limited GPU hardware, end up having to design scripts and library code around jax.lax.map. This can make for significantly more cumbersome development, and a higher barrier to entry, than is required with jax.vmap.

Consider the following (small) example.

With jax.lax.map:

def f_vmap(arr: jax.Array, static_arg: bool) -> jax.Array:
    f_wrapped = lambda x: f(x, static_arg)
    return jax.lax.map(f_wrapped, arr, batch_size=10)

def f(arr: jax.Array, static_arg: bool) -> jax.Array:
    return ...

With jax.vmap:

@partial(jax.vmap, in_axes=[0, None], batch_size=10)
def f_vmap(arr: jax.Array, static_arg: bool) -> jax.Array:
    return ...

While this is a toy example, I mean to portray that the wrappers in jax.lax.map can start to add significant burden over the course of a complex piece / body of code. I think that the jax.vmap paradigm (or more broadly the function transformation paradigm) has a wonderfully low barrier to entry to new JAX users, and it seems to me that there would be a lot of value added by allowing memory-limited users to stick to that!

michael-0brien avatar Jun 19 '25 15:06 michael-0brien

For more context, I develop a JAX package for scientific applications (https://github.com/mjo22/cryojax/). In my discipline, batching is very common, but there are many different ways downstream users may want to batch. My design philosophy is therefore to provide users with the un-batched version of what they need, and then tell them to use the power of JAX to build their particular workflow. Our discipline is fundamentally memory-limited most of the time, and most of my users are not JAX programmers.

What this means for us is that these new JAX programmers either need to dive into jax.lax when starting out, or they need to write more Python for-loops than one would hope for with JAX. If this feature were added, users could focus their learning solely on JAX's great function transformation paradigm!

michael-0brien avatar Jun 19 '25 16:06 michael-0brien

Thanks for raising this so clearly!

The direction sounds right to me, though I have two questions.

First I'm curious if you'd have a strong preference between extending the jax.vmap API itself vs extending a convenience wrapper like jax.lax.map (or adding a new one) to (1) have in_axes/out_axes and (2) be curried like a function transformation. That is, do you care much if your example is written like

@partial(jax.vmap, in_axes=[0, None], batch_size=10)
def f_vmap(arr: jax.Array, static_arg: bool) -> jax.Array:
    return ...

vs

@partial(jax.lax.map, in_axes=[0, None], batch_size=10)
def f_vmap(arr: jax.Array, static_arg: bool) -> jax.Array:
    return ...

vs writing it with some new convenience wrapper like, I dunno, jax.bmap? Or do you mainly care about having one API function you can call that offers looping, vectorization, and in/out axes all in one place?

I ask that mainly because I lean towards keeping the fundamental APIs (like jax.vmap) as simple as possible, then offering built-in convenience wrappers for common compositions.

A second issue is that because (a) this would presumably be a convenience wrapper around scan+vmap, and because (b) scan only scans over leading axes, we'd currently have to implement it with transposes on the inputs and outputs. We could avoid those transposes if we extend scan to have in_axes/out_axes. I can imagine that being an important optimization in some cases. Should we consider generalizing scan to have in_axes/out_axes as a prerequisite here?

mattjj avatar Jun 21 '25 18:06 mattjj

@mattjj

Should we consider generalizing scan to have in_axes/out_axes as a prerequisite here?

I'm in favor of this part. There was an issue for it at https://github.com/jax-ml/jax/issues/2509.

carlosgmartin avatar Jun 21 '25 21:06 carlosgmartin

I don’t have a strong preference between generalizing jax.vmap and writing a new API! I can definitely see why the latter would be preferable and perhaps even more powerful.

Taking more of a bird's-eye view, I think the core aspect of this is giving users the tools to stay in the function transformation paradigm. From my standpoint developing a library and also observing people learn JAX, writing scans / maps takes users out of the paradigm enough that there starts to be a bit of a cost to JAX’s usability, of course varying from individual to individual.

In that sense, perhaps, as you say, adding in_axes and out_axes to a jax.lax.scan-like transformation would address both my specific request and this larger picture. I’ll just add that, as with vmap, None must also be supported, so that users can stick to using decorators rather than in-line wrappers for static arguments. Would this mean that a scan-like transformation would be migrated to the core jax namespace, along with some kind of jax.bmap as you say? I would certainly find this very useful.

I’m happy with whatever implementation you all think is best as long as highlighting function transformations is the priority!

michael-0brien avatar Jun 21 '25 23:06 michael-0brien

@mattjj @carlosgmartin opened a new issue for the more general points discussed here: #30528

michael-0brien avatar Jul 26 '25 16:07 michael-0brien