jax
upgrade map
Changes
Adds `in_axes`, `out_axes`, and `batch_size` arguments to `map`. `in_axes` and `out_axes` let you select, per array, which axis to iterate over, similar to `vmap`. `batch_size` allows elements to be processed in batches, with each batch processed in parallel using `vmap`. With these changes, `map` can be used as a memory-efficient alternative to `vmap`.
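To illustrate the axis selection described above, here is a minimal sketch, not the implementation in this PR. `map_over_axis` is a hypothetical helper name; it assumes a single array argument and emulates `in_axes`/`out_axes` by moving the chosen axis to the front before calling `jax.lax.map` and moving the result axis afterwards.

```python
import jax
import jax.numpy as jnp


def map_over_axis(f, x, in_axis=0, out_axis=0):
    """Hypothetical helper: sequentially map `f` over `in_axis` of `x`."""
    # jax.lax.map always iterates over the leading axis, so move the
    # requested input axis to the front first.
    xs = jnp.moveaxis(x, in_axis, 0)
    ys = jax.lax.map(f, xs)
    # Place the mapped axis of the result where the caller asked for it.
    return jnp.moveaxis(ys, 0, out_axis)


def double(col):
    return 2.0 * col


x = jnp.arange(12.0).reshape(3, 4)
mapped = map_over_axis(double, x, in_axis=1, out_axis=1)
reference = jax.vmap(double, in_axes=1, out_axes=1)(x)
```

Both calls iterate over the columns of `x`; the sketch does so one column at a time, while `vmap` vectorizes over all of them at once.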
Original description: Upstreams the `batch_vmap` implementation by @shoyer. It has similar semantics to `vmap`, but it's executed in independent batches using `scan` to reduce memory consumption.
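The idea in the original description can be sketched roughly as follows. This is an assumption-laden illustration rather than the upstreamed code: `batched_vmap_sketch` is a hypothetical name, it handles a single array mapped over its leading axis, and it requires that axis to be divisible by `batch_size`.

```python
import jax
import jax.numpy as jnp


def batched_vmap_sketch(f, xs, batch_size):
    """Hypothetical sketch: vmap `f` within each batch, scan across batches."""
    n = xs.shape[0]
    if n % batch_size != 0:
        raise ValueError("leading dimension must be divisible by batch_size")
    # Split the leading axis into (num_batches, batch_size).
    batches = xs.reshape((n // batch_size, batch_size) + xs.shape[1:])

    def step(carry, batch):
        # Each scan step vectorizes only over one batch, so only one
        # batch's intermediates need to be live at a time.
        return carry, jax.vmap(f)(batch)

    _, ys = jax.lax.scan(step, None, batches)
    # Merge (num_batches, batch_size) back into a single leading axis.
    return ys.reshape((n,) + ys.shape[2:])


def sum_of_squares(x):
    return jnp.sum(x ** 2)


xs = jnp.arange(12.0).reshape(6, 2)
batched_out = batched_vmap_sketch(sum_of_squares, xs, batch_size=3)
vmap_out = jax.vmap(sum_of_squares)(xs)
```

The results match plain `vmap`; the difference is that the scan runs the batches one after another instead of materializing everything at once.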
Looks like a great start! Cross-referencing a couple related issues: #11319 #18398
Any update on this? This came up again in cl/606762040
Hi, I think this would be a great addition to jax!
I have two minor comments:
- When we implemented `batch_vmap` in netket a few years ago, we soon found out that we had a large need for a `batch_apply` as well. Actually, in our codebase we have more uses of `batch_apply` than of `batch_vmap`.
  - An added benefit is that `batch_vmap` can be implemented as a thin shim on top of a `batch_apply`.
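A rough sketch of the "thin shim" idea, under assumptions: the comment does not define `batch_apply`, so here it is taken to mean "apply an already-vectorized function to one chunk of the leading axis at a time". The names `batch_apply_sketch` and `batch_vmap_from_apply` are hypothetical and are not netket's or jax's API.

```python
import jax
import jax.numpy as jnp


def batch_apply_sketch(f_batched, xs, batch_size):
    """Hypothetical: call a function that already accepts a batch, chunk by chunk.

    Assumes `f_batched` keeps the batch dimension as the leading output axis.
    """
    n = xs.shape[0]
    if n % batch_size != 0:
        raise ValueError("leading dimension must be divisible by batch_size")
    chunks = xs.reshape((n // batch_size, batch_size) + xs.shape[1:])
    # jax.lax.map runs f_batched on one chunk after another.
    ys = jax.lax.map(f_batched, chunks)
    return ys.reshape((n,) + ys.shape[2:])


def batch_vmap_from_apply(f, xs, batch_size):
    """The shim: vectorize a per-example `f`, then apply it in chunks."""
    return batch_apply_sketch(jax.vmap(f), xs, batch_size)


def per_example(x):
    return jnp.tanh(x).sum()


xs = jnp.linspace(-1.0, 1.0, 16).reshape(8, 2)
shim_out = batch_vmap_from_apply(per_example, xs, batch_size=2)
direct_out = jax.vmap(per_example)(xs)
```

With this split, code that already has a vectorized function can call the chunked-apply helper directly, and the `vmap` variant is one extra line.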
Also, naming wise:
- `batch_vmap` seems more like a function that `batches` a `vmap`. As this function is performing a `batched` vmap, shouldn't the name be `batched_vmap`?
- `batching` means many different things in ML jargon. I personally don't associate it with computing things in distinct pieces to reduce memory cost.
  - In jax jargon, I think that `scanning` is the right term instead of `batching`, so I find that `scan_vmap` (or `scanned_vmap`) would make more sense.
  - However, this is very jax-specific, and users less experienced with jax control flow might find it unclear. I personally find that `chunk_vmap` (or `chunked_vmap`) would be more appropriate. Then again, I'm partisan because we've had this for a while.
Update: cleaned up the implementation a bit, required all `in_axes` dimensions to be divisible by `batch_size`, removed the remainder handling, and improved the docstring.
@shoyer @mattjj it seems like the current implementation doesn't support `None` in `out_axes`. Do we want to support this?
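For context on the question above, this is how `jax.vmap` treats `out_axes=None`, shown as a small sketch. It says nothing about what this PR does or should do. `pick_unmapped` is a hypothetical function used only for illustration.

```python
import jax
import jax.numpy as jnp


def pick_unmapped(x, c):
    # The result depends only on `c`, which is not mapped over.
    return c


xs = jnp.arange(4.0)
c = jnp.float32(2.5)

# With out_axes=None, vmap returns the output without adding a mapped axis.
# This is only valid because the output does not depend on the mapped input.
unmapped_out = jax.vmap(pick_unmapped, in_axes=(0, None), out_axes=None)(xs, c)
```

If the output did depend on `xs`, `vmap` would reject `out_axes=None` with an error, since there would be no single unbatched value to return.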
@PhilipVinc `batched_vmap` does sound appropriate. I'll do the renaming unless someone has a different opinion.