When caching is enabled, also enable XLA caching features
This PR makes it easier to enable all of the caching features in JAX and XLA with a single option. Now, when the JAX persistent cache is enabled (JAX_COMPILATION_CACHE_DIR), some XLA caching features are also enabled, and they write to subdirectories of the JAX cache dir. Which XLA caching features are used can be selected via JAX_PERSISTENT_CACHE_ENABLE_XLA_CACHES.
Currently, there is an issue related to kernel naming when both xla_gpu_kernel_cache_file and the JAX persistent cache are enabled together, so only the autotune cache is enabled by default now. Once this is fixed, the default value of JAX_PERSISTENT_CACHE_ENABLE_XLA_CACHES should be all.
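As a rough illustration of the behavior described above, the sketch below resolves a JAX_PERSISTENT_CACHE_ENABLE_XLA_CACHES-style setting into per-cache paths under the JAX cache dir. It is not the PR's implementation: the helper name `resolve_xla_cache_paths`, the `none` value, the comma-separated syntax, the autotune flag name `xla_gpu_per_fusion_autotune_cache_dir`, and the choice of subdirectory names are all assumptions made for illustration.

```python
import os

# xla_gpu_kernel_cache_file is named in this thread; the autotune flag name
# below is an assumption used only for this sketch.
KNOWN_XLA_CACHES = (
    "xla_gpu_kernel_cache_file",
    "xla_gpu_per_fusion_autotune_cache_dir",
)
# Per the description, only the autotune cache is on by default for now.
DEFAULT_XLA_CACHES = "xla_gpu_per_fusion_autotune_cache_dir"


def resolve_xla_cache_paths(cache_dir, enable_xla_caches=DEFAULT_XLA_CACHES):
    """Hypothetical helper: map each selected XLA cache to a path under cache_dir."""
    if not cache_dir:
        # The JAX persistent cache is disabled, so no XLA caches are enabled.
        return {}
    setting = enable_xla_caches.strip()
    if setting in ("", "none"):
        return {}
    if setting == "all":
        selected = KNOWN_XLA_CACHES
    else:
        # Assumed syntax: a comma-separated list of cache names.
        selected = tuple(s.strip() for s in setting.split(",") if s.strip())
        unknown = [name for name in selected if name not in KNOWN_XLA_CACHES]
        if unknown:
            raise ValueError(f"Unknown XLA cache(s): {unknown}")
    # Assumed layout: one subdirectory (or file) per cache, named after it.
    return {name: os.path.join(cache_dir, name) for name in selected}
```

With the default setting this yields a single entry for the autotune cache; passing `"all"` would yield both entries once the kernel naming issue is resolved.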
~Requires https://github.com/openxla/xla/pull/15636~ Requires https://github.com/openxla/xla/pull/18450
@nouiz
The required XLA PR is merged: https://github.com/openxla/xla/pull/15636 @hawkinsp can you review this PR?
The change is fine, but there are CI failures (possibly stale).
@hawkinsp Thanks for reviewing! I've rebased which should fix the CI failures
One more thing: please squash your commits.
@trevor-m — Thanks for your patience here! Can you rebase your PR onto the current main branch? We'll get this in ASAP after that. Thanks!
@dfm Thanks for looking at this. However, we may need to hold off merging this a bit longer. We think there will be issues when using this feature with multihost. To solve it, we can set xla_gpu_experimental_autotune_cache_mode to update for rank 0 only and to read for the other ranks. We will need to expose that flag in the XLA Python bindings first.
We will need to do something similar for the kernel cache.
@dfm I've opened https://github.com/openxla/xla/pull/18450 to expose the cache mode and updated this PR to set it to update for process 0 and read-only for the other processes. I confirmed this fixes the issue with multihost.
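The per-process selection described here can be sketched roughly as follows. The helper name `autotune_cache_mode_for` is hypothetical, and the plain strings `"update"` and `"read"` stand in for the cache mode values exposed by the XLA bindings; how the result is passed to XLA is not shown.

```python
def autotune_cache_mode_for(process_index: int) -> str:
    """Hypothetical helper: pick the autotune cache mode for one process.

    Process 0 may read and update the cache; every other process only reads.
    The returned strings are stand-ins for the real cache mode values.
    """
    if process_index < 0:
        raise ValueError("process_index must be non-negative")
    return "update" if process_index == 0 else "read"


# Example: the modes chosen for a 4-process job.
modes = [autotune_cache_mode_for(i) for i in range(4)]
```

In a real multihost run the index would presumably come from something like `jax.process_index()`; that wiring is left out so the sketch stays self-contained.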
This is blocked on https://github.com/openxla/xla/pull/18450, right?
Right, it got approved twice. I'm not sure why it isn't merged. CI complains about a clang-format issue in unchanged code: https://github.com/openxla/xla/actions/runs/11390477508/job/31692190615?pr=18450
@hawkinsp The prerequisite MR is now merged
It looks like all the new tests need to be conditioned on the jaxlib version. The issue here is that most of the CI jobs run with the released version of jaxlib, which doesn't include the AutotuneCacheMode enum yet. Something like:
from jax._lib import version as jaxlib_version
and then, in the test:
if jaxlib_version <= (0, 4, 35):
    self.skipTest("...") # <- add a description here
should do the trick!
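Put together, a version-gated test along these lines might look like the sketch below. To keep it runnable without jaxlib, `jaxlib_version` is a hard-coded stand-in tuple rather than the real import, and the test class name, skip message, and `run_sketch` helper are placeholders, not the PR's actual tests.

```python
import unittest

# Stand-in for `from jax._lib import version as jaxlib_version`, hard-coded
# so the sketch runs without jaxlib installed.
jaxlib_version = (0, 4, 35)


class AutotuneCacheModeTest(unittest.TestCase):  # hypothetical test class

    def test_cache_mode(self):
        # Tuples compare element-wise, so this skips on 0.4.35 and older.
        if jaxlib_version <= (0, 4, 35):
            self.skipTest("placeholder: requires a newer jaxlib")
        self.assertTrue(True)  # the real assertions would go here


def run_sketch():
    """Run the test case and return the unittest.TestResult."""
    suite = unittest.defaultTestLoader.loadTestsFromTestCase(AutotuneCacheModeTest)
    result = unittest.TestResult()
    suite.run(result)
    return result
```

With the stand-in version above, the test is reported as skipped rather than failed, which is what keeps CI jobs on the released jaxlib green.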
It also looks like you need to add an explicit dependency on :path here:
https://github.com/jax-ml/jax/blob/4b4fb9dae9eb7e2740d70de5b4a610f979530382/jax/BUILD#L425-L438
Thanks @dfm, I've added those changes.
@trevor-m — This should be good to go now! ~~but can you rebase onto the current main branch so that we can rerun the import?~~ Quite a few internal google tests were still failing because https://github.com/openxla/xla/pull/19101 hadn't been merged yet, but now that that's in, I expect everything should be green. Thanks for your patience with this!
Edit: it looks like maybe the import just needed me to re-approve. You shouldn't need to rebase after all. I'll let you know if there are any issues!