MrVI slowdown due to JAX compilation update
With recent updates to JAX, MrVI trains significantly slower than before. We suspect it is due to the new AOT compilation strategy (https://jax.readthedocs.io/en/latest/aot.html).
Reproducible with any basic MrVI training run on a fresh install. Reproduced by @PierreBoyeau and myself.
@justjhong lmk if you want me to upper bound jax for now and to which version. We will probably have a release in the coming week or so, so it can get into there.
Hi @ori-kron-wis, thanks for checking. I took some time this morning to try to debug it but was not able to find a solution. I was able to figure out that the problems arise starting from jax==0.4.36. So for now, let's upper bound to jax<0.4.36 (non-inclusive).
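For anyone who wants to check whether their environment falls under the proposed bound, here is a minimal sketch using only the standard library. The helper names (`parse_release`, `is_below_bound`, `installed_jax_is_pinned`) are made up for illustration and are not part of scvi-tools or JAX; the only fact taken from this thread is the exclusive bound jax<0.4.36. It only looks at the numeric release part of the version string, so treat it as a quick sanity check rather than a full version parser.

```python
import re
from importlib import metadata

# Exclusive upper bound discussed in this thread: jax<0.4.36
UPPER_BOUND = (0, 4, 36)


def parse_release(version: str) -> tuple:
    """Return the leading numeric release, e.g. '0.4.36.dev1' -> (0, 4, 36)."""
    match = re.match(r"\d+(?:\.\d+)*", version)
    if match is None:
        raise ValueError(f"Cannot parse version string: {version!r}")
    return tuple(int(part) for part in match.group().split("."))


def is_below_bound(version: str, bound: tuple = UPPER_BOUND) -> bool:
    """True if the numeric release is strictly below the bound."""
    return parse_release(version) < bound


def installed_jax_is_pinned():
    """Return True/False for the installed jax, or None if jax is not installed."""
    try:
        return is_below_bound(metadata.version("jax"))
    except metadata.PackageNotFoundError:
        return None


if __name__ == "__main__":
    print("jax below 0.4.36:", installed_jax_is_pinned())
```

A result of `None` just means jax is not installed in the current environment.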
Hi, here are a few more details about my understanding of the problem. It seems that lightning introduces a significant overhead when training the model, here and there for instance. MrVI is notably faster at this stage without the lightning wrapper.
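In case it helps others narrow this down, below is a rough sketch of how one could separate the cost of the first call (which typically includes compilation) from the cost of later steps. It is generic Python and makes no assumptions about MrVI or lightning internals; `time_steps` and `step_fn` are hypothetical names, and `step_fn` stands for whatever zero-argument callable runs one training step in your setup.

```python
import time


def time_steps(step_fn, n_steps=5):
    """Time a zero-argument callable.

    Returns (first_call_seconds, mean_seconds_of_later_calls).
    """
    if n_steps < 2:
        raise ValueError("n_steps must be at least 2")
    durations = []
    for _ in range(n_steps):
        start = time.perf_counter()
        step_fn()
        durations.append(time.perf_counter() - start)
    later = durations[1:]
    return durations[0], sum(later) / len(later)
```

When timing JAX code, `step_fn` should wait for the result before returning (for example by calling `.block_until_ready()` on the output array), because JAX dispatches work asynchronously and the timings would otherwise only cover the dispatch. Running the same helper with and without the lightning wrapper would show whether the gap sits in the first call or in every step.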
For now pinning jax<0.4.36. Potentially related to https://github.com/jax-ml/jax/issues/26162. Check again when this is addressed. Leaving this open, as pinning circumvents the problem but might create issues in the near future.
@ori-kron-wis can you try with the fix suggested in https://github.com/jax-ml/jax/issues/26162? We can also wait for the next jax release that contains it.
The fix is in the XLA compiler used by JAX (i.e., jaxlib), so we can't just apply it on our side. It is not merged yet AFAIK.
I did try the JAX nightly release (which includes the jaxlib nightly), but the issue remains. Install command: pip install -U --pre jax jaxlib "jax-cuda12-plugin[with_cuda]" jax-cuda12-pjrt -f https://storage.googleapis.com/jax-releases/jax_nightly_releases.html
We need to wait for the next version that contains it unfortunately.
We also found that JAX 0.4.35 does not work with CUDA 12.6+, so CUDA 12.4 is needed for it. Newer versions of JAX (0.6) work with CUDA 12.6+, but are still slower than before.
Also, we are currently limited to jax<0.7.0 (not installed). JAX will become optional in the next version.
Also related to: https://github.com/scverse/scvi-tools/pull/3482
The limit was removed. Not sure whether it is still slow, but JAX is optional now either way.
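Since JAX is optional now, user code that relies on it may want to guard the import. A minimal sketch of one common pattern follows, using only the standard library. This is not how scvi-tools implements it; `is_available`, `HAS_JAX` and `require_jax` are hypothetical names used for illustration.

```python
from importlib import util


def is_available(module_name: str) -> bool:
    """True if a top-level module can be found without importing it."""
    return util.find_spec(module_name) is not None


# Evaluated once at import time; jax itself is not imported here.
HAS_JAX = is_available("jax")


def require_jax():
    """Import and return jax, or raise a clear error if it is missing."""
    if not HAS_JAX:
        raise ImportError("This feature needs jax, which is not installed.")
    import jax

    return jax
```

Calling `require_jax()` only at the point where a JAX-based model is actually used keeps the rest of the code importable without it.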