upgrade map

Open cgarciae opened this issue 1 year ago • 6 comments

Changes

Adds in_axes, out_axes, and batch_size arguments to map. in_axes and out_axes let you select, per array, the axis to iterate over, similar to vmap. batch_size lets elements be processed in batches, with the elements of each batch processed in parallel using vmap. With these changes, map can be used as a memory-efficient alternative to vmap.

Original description: Upstreams the batch_vmap implementation by @shoyer. It has similar semantics to vmap, but it is executed in independent batches using scan to reduce memory consumption.
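
As a rough illustration of the idea described above (not the code in this PR), the sketch below moves the mapped axis to the front, splits it into batches, runs scan over the batches and vmap inside each batch, then moves the result axis to out_axis. The name batched_map_sketch and its single-array signature are assumptions made for the example; it also assumes f returns a single array and that the mapped axis is divisible by batch_size.

```python
import jax
import jax.numpy as jnp
from jax import lax


def batched_map_sketch(f, x, *, in_axis=0, out_axis=0, batch_size=1):
    """Hypothetical helper: map f over in_axis of x, batch_size items at a time."""
    x = jnp.moveaxis(x, in_axis, 0)  # iterate over the leading axis
    n = x.shape[0]
    if n % batch_size != 0:
        raise ValueError("sketch assumes the mapped axis is divisible by batch_size")
    # Shape (num_batches, batch_size, ...): scan walks axis 0, vmap handles axis 1.
    chunks = x.reshape((n // batch_size, batch_size) + x.shape[1:])

    def step(carry, chunk):
        # No state is carried between batches; each batch is vectorized with vmap.
        return carry, jax.vmap(f)(chunk)

    _, ys = lax.scan(step, None, chunks)
    ys = ys.reshape((n,) + ys.shape[2:])  # merge the two batch axes again
    return jnp.moveaxis(ys, 0, out_axis)
```

Only one batch of intermediate values is live at a time inside the scan body, which is where the memory saving relative to a single vmap over the whole axis would come from.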

cgarciae avatar Feb 01 '24 12:02 cgarciae

Looks like a great start! Cross-referencing a couple related issues: #11319 #18398

jakevdp avatar Feb 01 '24 18:02 jakevdp

Any update on this? This came up again in cl/606762040

Conchylicultor avatar Feb 22 '24 08:02 Conchylicultor

Hi, I think this would be a great addition to jax!

I have two minor comments:

  • When we implemented batch_vmap in netket a few years ago, we soon found out that we had a large need for a batch_apply as well. Actually, in our codebase we have more uses of batch_apply than of batch_vmap.
    • An added benefit is that batch_vmap can be implemented as a thin shim on top of batch_apply.
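
To make the "thin shim" point concrete, here is a minimal sketch. It assumes batch_apply means "apply a function that already accepts a leading batch axis, one chunk at a time"; the names batch_apply_sketch and batch_vmap_sketch are hypothetical and are not netket's or jax's actual API.

```python
import jax
import jax.numpy as jnp
from jax import lax


def batch_apply_sketch(f_batched, xs, batch_size):
    """Hypothetical: f_batched already accepts a leading batch axis."""
    n = xs.shape[0]
    # Assumes n is divisible by batch_size.
    chunks = xs.reshape((n // batch_size, batch_size) + xs.shape[1:])
    # Run f_batched on one chunk per scan step; nothing is carried between steps.
    _, ys = lax.scan(lambda carry, chunk: (carry, f_batched(chunk)), None, chunks)
    return ys.reshape((n,) + ys.shape[2:])


def batch_vmap_sketch(f, xs, batch_size):
    # The "thin shim": vectorize f once, then hand it to batch_apply_sketch.
    return batch_apply_sketch(jax.vmap(f), xs, batch_size)
```

With this layering, functions that are already written for batched inputs can use batch_apply_sketch directly, and per-example functions go through the one-line wrapper.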

Also, naming wise:

  • batch_vmap seems more like a function that batches a vmap. As this function is performing a batched vmap, shouldn't the name be batched_vmap?
  • batching means many different things in ML jargon. I personally don't associate it with computing things in distinct pieces to reduce memory cost.
    • In jax jargon, I think that scanning is the right term instead of batching, so I find that scan_vmap (or scanned_vmap) would make more sense.
    • However, this is very jax-specific, and users less experienced with jax control flow might find it unclear. I personally find that chunk_vmap (or chunked_vmap) would be more appropriate. Then again, I'm partisan because we've had this for a while.

PhilipVinc avatar Feb 26 '24 17:02 PhilipVinc

Update: cleaned up the implementation a bit, required every in_axis dimension to be divisible by batch_size, removed the remainders, and improved the docstring.

@shoyer @mattjj seems like the current implementation doesn't support None in the out_axis, do we want to support this?

cgarciae avatar Mar 20 '24 11:03 cgarciae

@PhilipVinc batched_vmap does sound appropriate, I'll do the renaming unless someone has a different opinion.

cgarciae avatar Mar 20 '24 11:03 cgarciae

I'm curious about the reasoning for putting this into map vs vmap? I do think the latter is a more accessible/widely used API, but maybe that also makes it higher risk?

shoyer avatar May 23 '24 04:05 shoyer

  1. Re @shoyer's question about map vs vmap, batching would essentially make jax.vmap and lax.map two extremes on a continuum. In both cases, the input is partitioned and vmap is applied to each batch serially. Standard vmap is like batched map with an unbounded batch size, and standard map is like batched vmap with a batch size of 1 (but without in_axes, etc). So, perhaps it's reasonable to include batching first in map without worrying about in_axes/out_axes. Later, batching could be included in vmap, and then map could be subsumed into vmap (which I agree has a much better interface).

  2. I completely agree with @BrandonSmithJ that the value of batched map/vmap is greatly increased if it handles input sizes that are not divisible by batch_size. This comes up very often in scientific applications, and having to deal with this outside of map/vmap would be really clunky and lead to a lot of code duplication.
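
One possible way to handle the non-divisible case from point 2, as a sketch only: run the full batches through scan and vmap, then process the leftover elements with one extra vmap call. The helper name batched_map_with_remainder is hypothetical, and this is not necessarily how map or vmap would implement it.

```python
import jax
import jax.numpy as jnp
from jax import lax


def batched_map_with_remainder(f, xs, batch_size):
    """Hypothetical helper: batched map over axis 0 for any input length."""
    n = xs.shape[0]
    num_full = n // batch_size
    split = num_full * batch_size
    head, tail = xs[:split], xs[split:]

    # Full batches: scan over batches, vmap within each batch.
    chunks = head.reshape((num_full, batch_size) + xs.shape[1:])
    _, ys = lax.scan(lambda c, chunk: (c, jax.vmap(f)(chunk)), None, chunks)
    ys = ys.reshape((split,) + ys.shape[2:])

    # Leftover elements (fewer than batch_size): one extra vmap call.
    if tail.shape[0] > 0:
        ys = jnp.concatenate([ys, jax.vmap(f)(tail)], axis=0)
    return ys
```

The same helper also illustrates the continuum from point 1: with batch_size=1 it steps through the input one element at a time like lax.map, and with batch_size equal to the input length it reduces to a single vmap over the whole axis.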

gbuzzard avatar May 29 '24 14:05 gbuzzard

@shoyer There are pros and cons for both, but not too many. map is a safer option.

cgarciae avatar May 29 '24 14:05 cgarciae

@cgarciae FYI @inailuig noticed a while ago that the approach you are using here (which we also are using) breaks sharding propagation if you chunk an axis that is sharded. Basically JAX ends up replicating the arrays after the chunked shard map.

To fix it we had to resort to shard map and sharding constraints.

It would be nice if this could be done automatically by JAX…

PhilipVinc avatar Jun 05 '24 17:06 PhilipVinc

@cgarciae FYI @inailuig noticed a while ago that the approach you are using here (which we also are using) breaks sharding propagation if you chunk an axis that is sharded.

+1, it would be really nice to implement shard(map) as scan(shard(vmap(...))). This could potentially be a follow-on improvement, though...

shoyer avatar Jun 05 '24 18:06 shoyer