Feature request: cache compilations to the filesystem

Open dionhaefner opened this issue 6 years ago • 5 comments

I am building a larger application with JAX, and compiling all the kernels adds about a minute of overhead before we're ready to run.

It would be great if JAX could cache compilations to disk, like the cache=True argument to Numba's JIT decorator.

dionhaefner avatar Mar 23 '20 08:03 dionhaefner

(I edited the title based on my understanding of what you mean, and a brief read of some numba docs. Let me know if I got that wrong.)

mattjj avatar Mar 24 '20 03:03 mattjj

Spot on!

dionhaefner avatar Mar 24 '20 07:03 dionhaefner

This is now implemented in jax.experimental: https://github.com/google/jax/blob/main/jax/experimental/compilation_cache/compilation_cache.py (I had nothing to do with the implementation.)

Not sure if the docs have this yet, but all you need to do is import it and call compilation_cache.initialize_cache(path-to-where-you-want-to-keep-the-cache).

mhauru avatar Aug 24 '21 14:08 mhauru
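For readers who want to try the call described in the comment above, here is a minimal sketch. It assumes the import path matches the linked file. The temporary cache directory and the toy `kernel` function are made up for illustration. The fallback to `set_cache_dir` is an assumption for newer JAX releases where `initialize_cache` may no longer be available. Whether anything is actually written to disk depends on the backend and JAX version, as the following comments discuss.

```python
import tempfile

import jax
import jax.numpy as jnp
from jax.experimental.compilation_cache import compilation_cache as cc

# Hypothetical cache location; in a real application this would be a
# persistent directory so the cache survives between runs.
cache_dir = tempfile.mkdtemp(prefix="jax_compilation_cache_")

# The thread names initialize_cache(path). Assumption: on releases where
# that name is gone, set_cache_dir(path) plays the same role.
if hasattr(cc, "initialize_cache"):
    cc.initialize_cache(cache_dir)
else:
    cc.set_cache_dir(cache_dir)


@jax.jit
def kernel(x):
    # Toy stand-in for an expensive-to-compile kernel.
    return jnp.sin(x) ** 2 + jnp.cos(x) ** 2


# The first call triggers compilation; with a supported backend the
# compiled executable may be stored under cache_dir for later runs.
result = kernel(jnp.arange(4.0))
```

The cache has to be configured before the first jitted function is compiled, so this setup belongs at the very top of the application's entry point.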

This looks awesome, but unfortunately it only works on TPU so far:

https://github.com/google/jax/commit/8190286b5368edc17553dba3f7fbd5f2a7496134#diff-121975a1a88d61a259f8423655e23a1b389eb18c399b267395d2aba5a196bae9R70-R71

dionhaefner avatar Aug 24 '21 16:08 dionhaefner

Is there any ongoing work to generalize this to CPU and GPU backends?

ncoish avatar Sep 17 '21 19:09 ncoish

Incidentally, this should work on GPU these days.

hawkinsp avatar May 19 '23 18:05 hawkinsp

Are the current best practices for caching documented somewhere?

clemisch avatar May 21 '23 16:05 clemisch