penzai
FR: batched mapping for `named_axes.nmap`
Sometimes `nmap`'ed computations don't all fit in memory at once, and there are not enough devices to shard the computation over. (This limitation is particularly salient when using Penzai, because adding an arbitrary number of named axes is so darn convenient :) )
JAX now supports semantics for batched vmapping via the `batch_size` argument of `jax.lax.map`. This would be awesome to add to Penzai!
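To make the requested semantics concrete, here is a minimal pure-JAX sketch (not Penzai code, and not how `jax.lax.map` is implemented internally). The helper name `chunked_vmap` is hypothetical. It splits the mapped axis into chunks of `batch_size`, vectorizes within each chunk with `jax.vmap`, and runs the chunks sequentially with `jax.lax.map`, so peak memory scales with the chunk size rather than with the full axis length. For simplicity it assumes a single array input whose leading axis is divisible by `batch_size`.

```python
import jax
import jax.numpy as jnp


def chunked_vmap(fn, xs, batch_size):
    """Hypothetical helper: map `fn` over axis 0 of `xs`, `batch_size` rows at a time.

    Rows within a chunk are vectorized with jax.vmap; chunks are processed
    one after another by jax.lax.map (a scan), bounding peak memory.
    Simplifying assumption: xs.shape[0] is divisible by batch_size.
    """
    n = xs.shape[0]
    if n % batch_size != 0:
        raise ValueError("xs.shape[0] must be divisible by batch_size")
    # Reshape (n, ...) -> (num_chunks, batch_size, ...).
    chunks = xs.reshape((n // batch_size, batch_size) + xs.shape[1:])
    # Sequential map over chunks, vectorized map inside each chunk.
    out = jax.lax.map(jax.vmap(fn), chunks)
    # Merge the chunk and within-chunk axes back into a single axis.
    return out.reshape((n,) + out.shape[2:])


def sum_of_squares(row):
    return jnp.sum(row ** 2)


xs = jnp.arange(12.0).reshape(6, 2)
full = jax.vmap(sum_of_squares)(xs)                   # all 6 rows at once
chunked = chunked_vmap(sum_of_squares, xs, batch_size=3)  # 3 rows at a time
```

Both calls compute the same values; only the memory/parallelism trade-off differs. On JAX versions where `jax.lax.map` accepts `batch_size`, `jax.lax.map(sum_of_squares, xs, batch_size=3)` should give the same result without a helper. A batched `nmap` in Penzai could presumably expose a similar knob per named axis, though how that would fit Penzai's API is an open design question.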