dLux
Explore the use of xmap under the hood for vectorised chunks to prevent hitting RAM bottlenecks.
Currently, vmap batches the whole computation at once, which can hit RAM bottlenecks and cause a severe slowdown. There appear to be long-term plans to implement this in JAX: https://github.com/google/jax/issues/11319.
This issue is simply to track progress over time.
Batching can instead be done with jax.lax.map:
https://jax.readthedocs.io/en/latest/_autosummary/jax.lax.map.html
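A minimal sketch of chunked batching, assuming the inputs are stacked along a leading axis whose length is divisible by the chunk size. `chunked_vmap` and `toy_model` are hypothetical names, not dLux or JAX API. The idea is to vectorise with `jax.vmap` inside each chunk and loop over chunks sequentially with `jax.lax.map`, so peak memory should scale with `chunk_size` rather than with the full batch.

```python
import jax
import jax.numpy as jnp


def chunked_vmap(fn, xs, chunk_size):
    """Hypothetical helper: apply fn over the leading axis of xs in chunks.

    Assumes xs.shape[0] is divisible by chunk_size (no padding in this sketch).
    """
    n = xs.shape[0]
    if n % chunk_size != 0:
        raise ValueError("xs.shape[0] must be divisible by chunk_size")
    # Reshape (n, ...) -> (n_chunks, chunk_size, ...)
    chunks = xs.reshape((n // chunk_size, chunk_size) + xs.shape[1:])
    # lax.map loops over chunks sequentially; vmap vectorises within a chunk
    out = jax.lax.map(jax.vmap(fn), chunks)
    # Flatten (n_chunks, chunk_size, ...) back to (n, ...)
    return out.reshape((n,) + out.shape[2:])


def toy_model(x):
    # Hypothetical stand-in for an expensive per-sample computation
    return jnp.sum(jnp.sin(x) ** 2)


xs = jnp.arange(32.0).reshape(8, 4)
out_chunked = chunked_vmap(toy_model, xs, chunk_size=2)
out_vmap = jax.vmap(toy_model)(xs)  # full-batch vmap, for comparison
```

Smaller chunks trade throughput for lower memory use. Recent JAX versions also accept a `batch_size` argument on `jax.lax.map` that does similar chunking internally; check the documentation for the installed version before relying on it.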