
Implementation of `filter_map` using `jax.lax.map`

Open JeppeKlitgaard opened this issue 1 year ago • 4 comments

I am really enjoying using `filter_vmap` and `filter_pmap` to easily transform functions that act on PyTrees. Truly a game-changer!

Playing around on a Google Cloud TPU, I have managed to use these transformations to parallelise over 8 TPU cores, but I am running into memory problems on the TPUs. I have plenty of memory on the host, but am restricted to ≈8 GB on the accelerator.

A suggested workaround seems to be to use `jax.lax.map` instead of `jax.vmap`. Would it be feasible to add a `filter_map` that makes use of this and is easily interchangeable with the other `filter_{pmap,vmap}` functions?

JeppeKlitgaard, Apr 17 '23 00:04