Feature request: cache compilations to the filesystem
I am building a larger application with JAX, and compiling all the kernels adds about a minute of overhead before we're ready to run.
It would be great if JAX were able to cache compilations, like the cache=True argument to Numba's JIT decorator.
(I edited the title based on my understanding of what you mean, and a brief read of some numba docs. Let me know if I got that wrong.)
Spot on!
This is now implemented in jax.experimental: https://github.com/google/jax/blob/main/jax/experimental/compilation_cache/compilation_cache.py
(I had nothing to do with the implementation.)
Not sure if the docs have this yet, but all you need to do is import it and call compilation_cache.initialize_cache(cache_dir), where cache_dir is the directory where you want the cache to be stored.
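A minimal sketch of turning the cache on, assuming a more recent JAX release in which the persistent cache is configured through the jax_compilation_cache_dir config option instead of the compilation_cache.initialize_cache call mentioned above (check the changelog for your version). The temporary directory is only a stand-in so the sketch runs anywhere; in practice you would point it at a directory that survives between runs.

```python
import tempfile

import jax
import jax.numpy as jnp

# Hypothetical cache location: a temp dir keeps this sketch self-contained,
# but a real setup should use a directory that persists across runs.
cache_dir = tempfile.mkdtemp(prefix="jax_cache_")

# Assumes a JAX version that exposes the persistent cache via this option
# (older versions used compilation_cache.initialize_cache(cache_dir)).
jax.config.update("jax_compilation_cache_dir", cache_dir)


@jax.jit
def f(x):
    # Any jitted function works; its compiled executable becomes a
    # candidate for the on-disk cache on the first call.
    return jnp.sin(x) ** 2 + jnp.cos(x) ** 2


result = f(jnp.arange(4.0))
print(result)
```

The config call has to happen before the first compilation, so it usually goes at the top of the program's entry point.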
This looks awesome, but unfortunately it only works on TPU so far:
https://github.com/google/jax/commit/8190286b5368edc17553dba3f7fbd5f2a7496134#diff-121975a1a88d61a259f8423655e23a1b389eb18c399b267395d2aba5a196bae9R70-R71
Is there any ongoing work to generalize this to CPU and GPU backends?
Incidentally, this should work on GPU these days.
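To check whether the cache is actually being written on whatever backend you are running (CPU, GPU, or TPU), one option is to lower the minimum compile time for cache entries so that even small test functions qualify, then look inside the cache directory. This is a sketch under assumptions: it relies on the jax_compilation_cache_dir and jax_persistent_cache_min_compile_time_secs config options existing in your JAX version, and whether entries appear depends on that version supporting the active backend.

```python
import os
import tempfile

import jax
import jax.numpy as jnp

cache_dir = tempfile.mkdtemp(prefix="jax_cache_")  # hypothetical location
jax.config.update("jax_compilation_cache_dir", cache_dir)
# Assumption: by default only compilations slower than a minimum time are
# cached; setting the threshold to 0 makes tiny functions eligible too.
jax.config.update("jax_persistent_cache_min_compile_time_secs", 0.0)

backend = jax.default_backend()  # name of the platform JAX is using


@jax.jit
def g(x):
    return x @ x.T


out = g(jnp.ones((8, 8)))
out.block_until_ready()  # make sure compilation and execution finished

# If caching is supported on this backend, files should show up here.
print(backend, sorted(os.listdir(cache_dir)))
```

Lowering the threshold is useful for confirming the setup; for real workloads the default threshold avoids filling the cache with entries that are cheap to recompile anyway.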
Are the current best practices for caching documented somewhere?