
FR: batched mapping for `named_axes.nmap`

Open · amifalk opened this issue 7 months ago · 5 comments

Sometimes `nmap`-ed computations don't all fit in memory at once, and there are not enough devices to shard the computation over. (This limitation is particularly salient when using Penzai, because adding an arbitrary number of named axes is so darn convenient :) )

JAX now supports semantics for batched vmapping via the `batch_size` argument of `jax.lax.map`. This would be awesome to add to Penzai!

amifalk, Jul 26 '24 19:07