equinox
Implementation of `filter_map` using `jax.lax.map`
I am really enjoying using `filter_vmap` and `filter_pmap` to very easily transform functions that act on PyTrees. Truly a game-changer!
Playing around on a Google Cloud TPU, I have now managed to use the transformations to parallelise over 8 TPU cores, but I am running into problems with memory usage on the TPUs. I have plenty of memory on the host, but am restricted to ≈8 GB on the accelerator.
It seems a suggested workaround for this would be to use `jax.lax.map` instead of `jax.vmap`. I am wondering if it would be feasible to make a `filter_map` that could make use of this and be easily exchangeable with the other `filter_{pmap,vmap}` functions.
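
A minimal sketch of what such a `filter_map` could look like, written against plain JAX so it stays self-contained. The name `filter_map` and its signature are hypothetical and not part of Equinox. It assumes every array leaf shares the same leading axis, that every non-array leaf should be passed through unchanged, and that `fun` returns only arrays (which `jax.lax.map` requires). In Equinox itself the array/static split could presumably be done with `eqx.partition` and `eqx.combine` instead of the manual flattening used here.

```python
import jax
import jax.numpy as jnp
import numpy as np


def filter_map(fun, xs):
    """Hypothetical helper: apply `fun` sequentially over the leading axis
    of every array leaf in `xs`, leaving non-array leaves untouched."""
    leaves, treedef = jax.tree_util.tree_flatten(xs)
    # Decide per leaf whether it is mapped (array) or held static (anything else).
    is_array = [isinstance(leaf, (jax.Array, np.ndarray)) for leaf in leaves]
    arrays = [leaf for leaf, mapped in zip(leaves, is_array) if mapped]

    def body(array_slices):
        # Re-insert the static leaves around the per-step array slices.
        slices = iter(array_slices)
        merged = [next(slices) if mapped else leaf
                  for leaf, mapped in zip(leaves, is_array)]
        return fun(jax.tree_util.tree_unflatten(treedef, merged))

    # jax.lax.map evaluates `body` one slice at a time rather than
    # vectorising over the whole batch as jax.vmap would.
    return jax.lax.map(body, arrays)


# Example: "w" is mapped over its leading axis; "scale" and "name" are static.
params = {"w": jnp.arange(6.0).reshape(3, 2), "scale": 2.0, "name": "demo"}


def f(p):
    return p["scale"] * jnp.sum(p["w"])


out = filter_map(f, params)
```

Because the slices are processed one after another, peak accelerator memory should scale with a single slice instead of the whole batch, at the cost of losing the parallelism that `vmap` provides.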