jax
Allow the compilation cache to be saved from processes other than process index 0
At present, this check prevents any process other than process index 0 from writing to the compilation cache. That makes sense when the compilation cache directory resides on shared storage. In our case, however, we do not want to put this directory on shared storage and would prefer to save the cache locally on every process. WDYT of adding an enum option jax_persistent_cache_write with values ['always', 'never', 'on_process_0'] to control this behavior?
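To make the proposed semantics concrete, here is a minimal sketch of the decision each process would make. It assumes the proposed option name jax_persistent_cache_write and its three values. The PersistentCacheWrite enum and should_write_cache function are hypothetical names, not existing JAX APIs. The process index is passed in as an argument; in a real multi-process run it would come from jax.process_index().

```python
import enum


class PersistentCacheWrite(enum.Enum):
    """Hypothetical values for the proposed jax_persistent_cache_write option."""
    ALWAYS = "always"
    NEVER = "never"
    ON_PROCESS_0 = "on_process_0"


def should_write_cache(mode: str, process_index: int) -> bool:
    """Return True if this process should write entries to the compilation cache.

    `mode` is one of the proposed option strings; `process_index` would come
    from jax.process_index() in a real multi-process run.
    """
    policy = PersistentCacheWrite(mode)  # raises ValueError for unknown strings
    if policy is PersistentCacheWrite.ALWAYS:
        # Each process has its own local cache directory, so every process writes.
        return True
    if policy is PersistentCacheWrite.NEVER:
        # Caching is read-only (or effectively disabled) on every process.
        return False
    # ON_PROCESS_0 matches the current behavior, suited to shared storage.
    return process_index == 0


# Example: which of four processes would write under each mode.
for mode in ("always", "never", "on_process_0"):
    print(mode, [should_write_cache(mode, i) for i in range(4)])
```

Making on_process_0 the default would keep today's behavior unchanged for users who already rely on a shared cache directory.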