Explore the use of xmap under the hood for vectorised chunks to prevent hitting RAM bottlenecks.

Open · LouisDesdoigts opened this issue 1 year ago · 1 comment

Currently vmap batches the whole computation as one, which can hit RAM bottlenecks and cause a severe slowdown. There appear to be long-term plans to implement chunked vectorisation in JAX: https://github.com/google/jax/issues/11319.
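A minimal sketch of the problem, assuming a hypothetical per-sample function `per_sample` (not part of dLux) that builds a large intermediate array before reducing it to a scalar. Under `jax.vmap` all samples are processed in one batched computation, so the intermediate gains a leading batch axis and its size grows with the number of samples.

```python
import jax
import jax.numpy as jnp

# Hypothetical per-sample function: builds a (size, size) intermediate,
# then reduces it to a scalar.
def per_sample(x, size=512):
    grid = jnp.outer(jnp.full(size, x), jnp.arange(size, dtype=jnp.float32))
    return jnp.sum(jnp.sin(grid))

xs = jnp.linspace(0.0, 1.0, 64)

# vmap turns all 64 samples into one batched computation: inside it,
# `grid` has shape (64, 512, 512), so peak memory can scale with the
# number of samples unless XLA manages to fuse the intermediate away.
batched = jax.jit(jax.vmap(per_sample))
out = batched(xs)
print(out.shape)  # (64,)
```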

This issue is simply to track progress over time.

LouisDesdoigts · Jun 21 '23 02:06

Chunked batching can be done with jax.lax.map, which applies a function sequentially over the leading axis of its input: https://jax.readthedocs.io/en/latest/_autosummary/jax.lax.map.html
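A sketch of one way to combine the two, assuming a hypothetical helper `chunked_vmap` and the same hypothetical `per_sample` function as above: the inputs are split into fixed-size chunks, each chunk is processed with `jax.vmap`, and `jax.lax.map` loops over the chunks sequentially. Peak memory then scales with `chunk_size` rather than with the total number of samples. Newer JAX releases also accept a `batch_size` argument on `jax.lax.map` that does similar chunking; check whether your installed version supports it before relying on it.

```python
import jax
import jax.numpy as jnp

# Hypothetical per-sample function with a large intermediate.
def per_sample(x, size=512):
    grid = jnp.outer(jnp.full(size, x), jnp.arange(size, dtype=jnp.float32))
    return jnp.sum(jnp.sin(grid))

def chunked_vmap(f, xs, chunk_size):
    """Hypothetical helper: apply f over the leading axis of xs, vmapping
    `chunk_size` samples at a time and looping over chunks with lax.map."""
    n = xs.shape[0]
    n_chunks = -(-n // chunk_size)  # ceiling division
    pad = n_chunks * chunk_size - n
    # Pad with copies of the first sample so the array splits evenly.
    padded = jnp.concatenate([xs, jnp.repeat(xs[:1], pad, axis=0)], axis=0)
    chunks = padded.reshape((n_chunks, chunk_size) + xs.shape[1:])
    # lax.map runs the vmapped function once per chunk, sequentially.
    out = jax.lax.map(jax.vmap(f), chunks)
    # Merge the chunk axis back into the batch axis and drop padded results.
    return out.reshape((n_chunks * chunk_size,) + out.shape[2:])[:n]

xs = jnp.linspace(0.0, 1.0, 70)
out = chunked_vmap(per_sample, xs, chunk_size=16)  # 5 chunks of 16
ref = jax.vmap(per_sample)(xs)                     # one batch of 70
print(out.shape)  # (70,)
```

Padding keeps every chunk the same shape, so `lax.map` compiles the vmapped body only once; the cost is a few wasted evaluations in the final chunk.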

LouisDesdoigts · Jul 24 '23 13:07