MrVI slowdown due to JAX compilation update

Open justjhong opened this issue 10 months ago • 10 comments

With recent updates to JAX, MrVI trains significantly slower than before. We suspect it is due to the new AOT compilation strategy (https://jax.readthedocs.io/en/latest/aot.html).

To reproduce: run any basic MrVI training on a fresh install. Reproduced by @PierreBoyeau and myself.
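One way to check whether the slowdown comes from compilation rather than from the steady-state training step is to time the first call separately from the later ones. The sketch below is only an illustration and not part of scvi-tools: `time_steps` and its `step_fn` argument are hypothetical names, and `step_fn` is assumed to be a zero-argument callable that runs one training step. If the returned value has a `block_until_ready` method (as JAX arrays do), the helper calls it so that asynchronous dispatch does not hide the real cost.

```python
import time
from typing import Callable, Dict


def time_steps(step_fn: Callable[[], object], n_steps: int = 5) -> Dict[str, float]:
    """Time the first call (usually includes compilation) and the later calls.

    `step_fn` is a hypothetical zero-argument callable that runs one step.
    """
    if n_steps < 2:
        raise ValueError("n_steps must be at least 2")

    durations = []
    for _ in range(n_steps):
        start = time.perf_counter()
        out = step_fn()
        # JAX dispatches work asynchronously; wait for the result if possible.
        block = getattr(out, "block_until_ready", None)
        if callable(block):
            block()
        durations.append(time.perf_counter() - start)

    return {
        "first_call_s": durations[0],
        "steady_state_s": sum(durations[1:]) / (n_steps - 1),
    }
```

A large `first_call_s` with an unchanged `steady_state_s` would point at compilation, while a higher `steady_state_s` would point at per-step overhead. Running the same measurement under two JAX versions makes the comparison concrete.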

justjhong avatar Feb 10 '25 17:02 justjhong

@justjhong lmk if you want me to upper bound jax for now and to which version. We will probably have a release in the coming week or so, so it can get into there.

ori-kron-wis avatar Feb 11 '25 10:02 ori-kron-wis

Hi @ori-kron-wis, thanks for checking. I took some time this morning to try to debug it but was not able to find a solution. I was able to figure out that the problems arise starting from jax==0.4.36. So for now, let's upper bound to jax<0.4.36 (non-inclusive).
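For anyone who wants to check their own environment against this bound, here is a minimal sketch. The helper names `parse_release` and `is_affected` and the constant `FIRST_SLOW_JAX` are hypothetical, not part of scvi-tools or JAX. The sketch assumes that every release from 0.4.36 onward is affected, which is what this thread reports so far, not a confirmed fact.

```python
import re
from typing import Tuple

# Assumption taken from this thread: the slowdown starts at jax==0.4.36.
FIRST_SLOW_JAX = (0, 4, 36)


def parse_release(version: str) -> Tuple[int, int, int]:
    """Extract (major, minor, patch) from a version string such as '0.4.36.dev1'."""
    match = re.match(r"(\d+)\.(\d+)\.(\d+)", version)
    if match is None:
        raise ValueError(f"unrecognized version string: {version!r}")
    major, minor, patch = (int(part) for part in match.groups())
    return (major, minor, patch)


def is_affected(version: str) -> bool:
    """Return True if the version falls outside the suggested jax<0.4.36 pin."""
    return parse_release(version) >= FIRST_SLOW_JAX


# Example usage (requires jax to be installed):
#   import jax
#   print(is_affected(jax.__version__))
```

Tuples are compared element by element, so `(0, 4, 9)` sorts before `(0, 4, 36)`, which a plain string comparison would get wrong.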

justjhong avatar Feb 11 '25 13:02 justjhong

Hi, here are a few more details about my understanding of the problem. It seems that lightning introduces a significant overhead when training the model, here and there for instance. MrVI is notably faster at this stage without the lightning wrapper.

PierreBoyeau avatar Feb 13 '25 00:02 PierreBoyeau

For now, pinning jax<0.4.36. Potentially related to https://github.com/jax-ml/jax/issues/26162. Check again when this is addressed. Leaving this open, as pinning circumvents the problem but might create issues in the near future.

canergen avatar Feb 13 '25 00:02 canergen

@ori-kron-wis can you try with the fix suggested in https://github.com/jax-ml/jax/issues/26162? We can also wait for the next jax release that contains it.

canergen avatar Feb 27 '25 08:02 canergen

The fix is in the XLA compiler used by JAX (i.e., jaxlib), so we can't just implement it out of the box. It is not merged yet AFAIK.

I did try the Jax nightly release (which includes the jaxlib nightly), but the issue remains:

pip install -U --pre jax jaxlib "jax-cuda12-plugin[with_cuda]" jax-cuda12-pjrt -f https://storage.googleapis.com/jax-releases/jax_nightly_releases.html

Unfortunately, we need to wait for the next version that contains it.

ori-kron-wis avatar Mar 03 '25 10:03 ori-kron-wis

We also found that Jax 0.4.35 does not work with CUDA 12.6+, so CUDA 12.4 is needed. Newer versions of Jax (0.6) work with CUDA 12.6+, but are still slower than before.

ori-kron-wis avatar May 29 '25 06:05 ori-kron-wis

Also, we are currently limited to jax<0.7.0 (not installed). Jax will become optional in the next version.

ori-kron-wis avatar Aug 03 '25 12:08 ori-kron-wis

Also related: https://github.com/scverse/scvi-tools/pull/3482

ori-kron-wis avatar Sep 11 '25 12:09 ori-kron-wis

The limit was removed. Not sure whether it is still slow, but Jax is optional now either way.

ori-kron-wis avatar Oct 20 '25 12:10 ori-kron-wis