
When caching is enabled, also enable XLA caching features

Open trevor-m opened this issue 1 year ago • 7 comments

This PR makes it easier to enable all of the caching features in JAX and XLA with a single option. When the JAX persistent cache is enabled (via JAX_COMPILATION_CACHE_DIR), some XLA caching features are now enabled as well, with their files written to subdirectories of the JAX cache dir. The XLA caches to enable can be selected via JAX_PERSISTENT_CACHE_ENABLE_XLA_CACHES.

Currently, there is an issue related to kernel naming when xla_gpu_kernel_cache_file and the JAX persistent cache are enabled together, so only the autotune cache is enabled by default for now. Once this is fixed, the default value of JAX_PERSISTENT_CACHE_ENABLE_XLA_CACHES should be changed to all.
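As a rough illustration of the single-option setup described above, the two environment variables could be set from Python before JAX is imported. This is only a sketch: the helper name `enable_caches` and the cache path are hypothetical, and it assumes JAX reads these variables from the environment at import time. The value `all` is the one mentioned in this description.

```python
import os
from typing import Dict, Optional


def enable_caches(cache_dir: str, xla_caches: Optional[str] = None) -> Dict[str, str]:
    """Set the cache-related environment variables named in this PR.

    Hypothetical helper: call it before `import jax` so the values are
    visible when JAX reads its configuration from the environment.
    """
    settings = {"JAX_COMPILATION_CACHE_DIR": cache_dir}
    if xla_caches is not None:
        # Selects which XLA caches are enabled alongside the JAX cache.
        settings["JAX_PERSISTENT_CACHE_ENABLE_XLA_CACHES"] = xla_caches
    os.environ.update(settings)
    return settings


# Example path only; any writable directory would do.
applied = enable_caches("/tmp/jax_cache", xla_caches="all")
print(applied)
```

Leaving `xla_caches` unset keeps whatever default the installed JAX version uses.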

~Requires https://github.com/openxla/xla/pull/15636~ Requires https://github.com/openxla/xla/pull/18450

trevor-m avatar Aug 06 '24 18:08 trevor-m

@nouiz

trevor-m avatar Aug 06 '24 18:08 trevor-m

The required XLA PR is merged: https://github.com/openxla/xla/pull/15636 @hawkinsp can you review this PR?

nouiz avatar Aug 20 '24 23:08 nouiz

The change is fine, but there are CI failures (possibly stale).

@hawkinsp Thanks for reviewing! I've rebased, which should fix the CI failures.

trevor-m avatar Sep 25 '24 17:09 trevor-m

One more thing: please squash your commits.

hawkinsp avatar Sep 27 '24 12:09 hawkinsp

@trevor-m — Thanks for your patience here! Can you rebase your PR onto the current main branch? We'll get this in ASAP after that. Thanks!

dfm avatar Oct 16 '24 17:10 dfm

@dfm Thanks for looking at this. However, we may need to hold off on merging this a bit longer. We think there will be issues when using this feature with multihost. To solve it, we can set xla_gpu_experimental_autotune_cache_mode to update for rank 0 only and to read for the other ranks. We will need to expose that flag in the XLA Python bindings first.

We will need to do something similar for the kernel cache.
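The rank-based selection described above could look roughly like the sketch below. `CacheMode` is a hypothetical stand-in for whatever enum the XLA Python bindings end up exposing, and `autotune_cache_mode_for` is an illustrative name, not an existing API.

```python
from enum import Enum


class CacheMode(Enum):
    """Hypothetical stand-in for the autotune cache mode exposed by XLA."""
    UPDATE = "update"  # may read and write cache entries
    READ = "read"      # may only read existing cache entries


def autotune_cache_mode_for(process_index: int) -> CacheMode:
    """Let only process 0 write the shared cache; all others just read it."""
    return CacheMode.UPDATE if process_index == 0 else CacheMode.READ


# Simulate a 4-process job; in a real program the index would come from
# something like jax.process_index().
modes = [autotune_cache_mode_for(i).value for i in range(4)]
print(modes)
```

Restricting writes to a single process avoids several hosts updating the same cache files at once, which is the multihost concern raised here.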

trevor-m avatar Oct 16 '24 18:10 trevor-m

@dfm I've opened https://github.com/openxla/xla/pull/18450 to expose the cache mode and updated this PR to set it to update for process 0 and read-only for the other processes. I confirmed this fixes the issue with multihost.

trevor-m avatar Oct 17 '24 18:10 trevor-m

This is blocked on https://github.com/openxla/xla/pull/18450, right?

hawkinsp avatar Oct 30 '24 16:10 hawkinsp

> This is blocked on openxla/xla#18450, right?

Right, it got approved twice. I'm not sure why it isn't merged. CI complains about a clang-format issue in unchanged code: https://github.com/openxla/xla/actions/runs/11390477508/job/31692190615?pr=18450

nouiz avatar Oct 30 '24 16:10 nouiz

@hawkinsp The prerequisite PR is now merged.

trevor-m avatar Nov 04 '24 17:11 trevor-m

It looks like all the new tests need to be conditioned on the jaxlib version. The issue here is that most of the CI jobs run with the released version of jaxlib, which doesn't include the AutotuneCacheMode enum yet. Something like:

from jax._lib import version as jaxlib_version

and then, in the test:

if jaxlib_version <= (0, 4, 35):
  self.skipTest("...")  # <- add a description here

should do the trick!
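For reference, here is a self-contained sketch of that skip pattern. It uses a stand-in version tuple instead of the real jaxlib import so that it runs without JAX installed; the test class name and skip message are illustrative only.

```python
import unittest

# Stand-in for the installed jaxlib version tuple. In the real tests this
# would be the `jaxlib_version` value imported as suggested above.
JAXLIB_VERSION = (0, 4, 35)


class AutotuneCacheModeTest(unittest.TestCase):

    def test_needs_new_jaxlib(self):
        # Tuples compare element by element, so (0, 4, 35) <= (0, 4, 35).
        if JAXLIB_VERSION <= (0, 4, 35):
            self.skipTest("AutotuneCacheMode requires jaxlib > 0.4.35")
        # Placeholder for the real assertions that need the new enum.
        self.assertTrue(True)


suite = unittest.defaultTestLoader.loadTestsFromTestCase(AutotuneCacheModeTest)
result = unittest.TextTestRunner(verbosity=0).run(suite)
print(len(result.skipped), "test(s) skipped")
```

A skipped test is reported separately from failures, so CI jobs running the released jaxlib stay green while jobs with a newer jaxlib still exercise the test.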

dfm avatar Nov 06 '24 10:11 dfm

It also looks like you need to add an explicit dependency on :path here:

https://github.com/jax-ml/jax/blob/4b4fb9dae9eb7e2740d70de5b4a610f979530382/jax/BUILD#L425-L438

dfm avatar Nov 06 '24 10:11 dfm

Thanks @dfm, I've added those changes.

trevor-m avatar Nov 06 '24 18:11 trevor-m

@trevor-m, this should be good to go now! ~~but can you rebase onto the current main branch so that we can rerun the import?~~ Quite a few internal Google tests were still failing because https://github.com/openxla/xla/pull/19101 hadn't been merged yet, but now that it's in, I expect everything to be green. Thanks for your patience with this!

Edit: it looks like maybe the import just needed me to re-approve. You shouldn't need to rebase after all. I'll let you know if there are any issues!

dfm avatar Nov 08 '24 16:11 dfm