upgrade map

Open cgarciae opened this issue 5 months ago • 6 comments

Changes

Adds in_axes, out_axes, and batch_size arguments to map. in_axes and out_axes let you select, per array, the axis to iterate over (or to place the results on), similar to vmap. batch_size allows elements to be processed in batches, with the elements of each batch processed in parallel using vmap. With these changes, map can be used as a memory-efficient alternative to vmap.
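A minimal sketch of the idea described above, using a hypothetical helper `batched_map` (this is not the PR's code or an existing JAX signature). The mapped axis is moved to the front and split into chunks of `batch_size`; `lax.scan` then walks over the chunks sequentially while `jax.vmap` processes each chunk in parallel. For brevity, the sketch applies a single integer `in_axes`/`out_axes` to every leaf instead of per-array axes, and it requires the mapped size to be divisible by `batch_size`.

```python
import jax
import jax.numpy as jnp
from jax import lax


def batched_map(f, xs, *, batch_size, in_axes=0, out_axes=0):
    """Apply f to every slice of xs along in_axes, batch_size slices at a time.

    Hypothetical helper: one integer in_axes/out_axes is used for every leaf,
    and the mapped dimension must be divisible by batch_size.
    """
    # Move the mapped axis of every input leaf to the front.
    xs = jax.tree_util.tree_map(lambda x: jnp.moveaxis(x, in_axes, 0), xs)
    n = jax.tree_util.tree_leaves(xs)[0].shape[0]
    if n % batch_size != 0:
        raise ValueError(f"mapped size {n} is not divisible by batch_size={batch_size}")
    num_batches = n // batch_size

    # Reshape (n, ...) -> (num_batches, batch_size, ...).
    xs = jax.tree_util.tree_map(
        lambda x: x.reshape(num_batches, batch_size, *x.shape[1:]), xs)

    def body(carry, batch):
        # vmap processes one batch in parallel; scan iterates over the batches
        # one after another, so f's intermediates exist for one batch at a time.
        return carry, jax.vmap(f)(batch)

    _, ys = lax.scan(body, None, xs)

    # Merge (num_batches, batch_size, ...) -> (n, ...), then place the mapped axis.
    return jax.tree_util.tree_map(
        lambda y: jnp.moveaxis(y.reshape(n, *y.shape[2:]), 0, out_axes), ys)


# Example: map over axis 1 (size 6) in batches of 3 and compare with vmap.
xs = jnp.arange(24.0).reshape(4, 6)
f = lambda v: jnp.sum(v ** 2)
out = batched_map(f, xs, batch_size=3, in_axes=1)
assert jnp.allclose(out, jax.vmap(f, in_axes=1)(xs))
```

In this sketch, batch_size trades memory for parallelism: a batch_size equal to the full mapped size runs a single vmap over everything, while batch_size=1 behaves roughly like a sequential map.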

Original description: Upstreams the batch_vmap implementation by @shoyer. It has similar semantics to vmap, but it's executed in independent batches using scan to reduce memory consumption.

cgarciae • Feb 01 '24 12:02

Looks like a great start! Cross-referencing a couple of related issues: #11319 #18398

jakevdp • Feb 01 '24 18:02

Any update on this? This came up again in cl/606762040.

Conchylicultor • Feb 22 '24 08:02

Hi, I think this would be a great addition to jax!

I have two minor comments:

  • When we implemented batch_vmap in netket a few years ago, we soon found that we also needed a batch_apply. In fact, our codebase has more uses of batch_apply than of batch_vmap.
    • An added benefit is that batch_vmap can be implemented as a thin shim on top of batch_apply.
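To illustrate the layering suggested above, here is a sketch with hypothetical `batch_apply` and `batch_vmap` helpers. The names come from the comment, but the signatures are assumptions; this is not netket's or JAX's code. `batch_apply` calls a function that already accepts a leading batch dimension on one chunk at a time via `lax.scan`, and `batch_vmap` is then just `batch_apply` applied to `jax.vmap(f)`.

```python
import jax
import jax.numpy as jnp
from jax import lax


def batch_apply(fun, xs, batch_size):
    """Call fun on consecutive chunks of batch_size rows of xs (a pytree).

    Hypothetical helper. fun must accept inputs with a leading batch
    dimension and return outputs with the same leading dimension; the
    leading size of xs must be divisible by batch_size.
    """
    n = jax.tree_util.tree_leaves(xs)[0].shape[0]
    if n % batch_size != 0:
        raise ValueError(f"leading size {n} is not divisible by batch_size={batch_size}")
    chunks = jax.tree_util.tree_map(
        lambda x: x.reshape(n // batch_size, batch_size, *x.shape[1:]), xs)
    # scan calls fun on one chunk per iteration.
    _, ys = lax.scan(lambda carry, chunk: (carry, fun(chunk)), None, chunks)
    # Merge the (num_chunks, batch_size) leading axes back into one axis.
    return jax.tree_util.tree_map(lambda y: y.reshape(n, *y.shape[2:]), ys)


def batch_vmap(f, batch_size):
    """batch_vmap as a thin shim: batch_apply over the vmapped per-example f."""
    return lambda xs: batch_apply(jax.vmap(f), xs, batch_size)


# batch_apply with a function that already works on whole batches.
W = jnp.arange(12.0).reshape(3, 4)
xs = jnp.arange(30.0).reshape(10, 3)
assert jnp.allclose(batch_apply(lambda b: b @ W, xs, 5), xs @ W)

# batch_vmap with a per-example function.
per_example = lambda v: jnp.dot(v, v)
assert jnp.allclose(batch_vmap(per_example, 2)(xs), jax.vmap(per_example)(xs))
```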

Also, naming-wise:

  • batch_vmap seems more like a function that batches a vmap. As this function is performing a batched vmap, shouldn't the name be batched_vmap?
  • Batching means many different things in ML jargon. I personally don't associate it with computing things in distinct pieces to reduce memory cost.
    • In JAX jargon, I think scanning is the right term rather than batching, so I find that scan_vmap (or scanned_vmap) would make more sense.
    • However, this is very JAX-specific, and users less experienced with JAX control flow might find it unclear. I personally find chunk_vmap (or chunked_vmap) more appropriate. Then again, I'm partisan because we've had this for a while.

PhilipVinc • Feb 26 '24 17:02

Update: cleaned up the implementation a bit, now requiring every in_axes dimension to be divisible by batch_size, removed the remainder handling, and improved the docstring.
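The divisibility requirement comes from splitting the mapped axis into batches with a reshape. The sketch below shows that check, plus zero-padding as one possible caller-side workaround. Both helpers (`split_into_batches`, `pad_to_multiple`) are hypothetical, and padding is an assumption on my part, not something proposed in this PR; it is only safe when the mapped function can be evaluated on the padded values and the extra outputs are sliced off afterwards.

```python
import jax.numpy as jnp


def split_into_batches(x, batch_size):
    """Reshape (n, ...) -> (n // batch_size, batch_size, ...).

    Hypothetical helper: without remainder handling, this split is only
    possible when batch_size divides n.
    """
    n = x.shape[0]
    if n % batch_size != 0:
        raise ValueError(f"leading dimension {n} is not divisible by batch_size={batch_size}")
    return x.reshape(n // batch_size, batch_size, *x.shape[1:])


def pad_to_multiple(x, batch_size):
    """Zero-pad the leading axis up to the next multiple of batch_size.

    Hypothetical helper. Returns the padded array and the original length,
    so results computed on the padded input can be cut back with out[:n].
    """
    n = x.shape[0]
    pad = (-n) % batch_size  # rows needed to reach the next multiple
    padded = jnp.pad(x, [(0, pad)] + [(0, 0)] * (x.ndim - 1))
    return padded, n


x = jnp.ones((10, 3))
padded, n = pad_to_multiple(x, 4)        # padded has shape (12, 3), n == 10
batches = split_into_batches(padded, 4)  # shape (3, 4, 3)
```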

@shoyer @mattjj it seems the current implementation doesn't support None in out_axes; do we want to support it?
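For reference, this is what None in out_axes means for jax.vmap itself: the output is not mapped, so a single copy is returned instead of a stacked one. vmap only accepts this when that output does not depend on the mapped inputs. A minimal illustration with made-up inputs:

```python
import jax
import jax.numpy as jnp

xs = jnp.arange(6.0).reshape(3, 2)  # mapped over axis 0 (size 3)
bias = jnp.array([10.0, 20.0])      # not mapped (in_axes=None)

# The first output depends on the mapped input, so it uses out_axes=0.
# The second output is just `bias`, identical for every element, so
# out_axes=None returns it once instead of stacking 3 copies.
per_row, shared = jax.vmap(
    lambda x, b: (x + b, b), in_axes=(0, None), out_axes=(0, None)
)(xs, bias)

print(per_row.shape)  # (3, 2)
print(shared.shape)   # (2,)
```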

cgarciae • Mar 20 '24 11:03

@PhilipVinc batched_vmap does sound appropriate; I'll do the renaming unless someone has a different opinion.

cgarciae • Mar 20 '24 11:03