orbax
Is it possible to store checkpoints in an external storage such as S3?
I wasn't able to find the answers to my questions in the docs, so I'll just ask here:
- What storage types other than local filesystem are supported with orbax? For instance, can I use S3?
- Is it possible to add my own storage type somehow?
Thanks!
We support a Google-internal distributed file system as well as Google Cloud Storage. I don't know whether you would run into any issues with S3, but you could give it a try.
Depending on what issues you encounter, if any, implementing your own TypeHandlers and AggregateHandler would probably be the best approach to customize serialization / deserialization logic if you need to. See here: https://orbax.readthedocs.io/en/latest/api_reference/checkpoint.html. Once implemented, you just register the handlers to start using them.
One way to support a large number of various filesystems would be to use fsspec for reading/writing weight files. Is that something the orbax/jax team might consider?
A temporary workaround is to save to a temp directory and copy the saved content to the remote file system, though this wouldn't work so easily with the checkpoint manager (e.g., only save the last n checkpoints).
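For anyone who wants to try the workaround, here is a minimal sketch of it using only the standard library. Everything named here is hypothetical and not part of Orbax: `save_then_upload`, `save_fn` (a callable you supply that writes one checkpoint into the directory it is given), and `remote_root` (a local path standing in for the remote location). The copy step is a placeholder for a real upload, for example fsspec's `fs.put(local, remote, recursive=True)`. It also shows why "keep the last n" has to be reimplemented by hand in this setup.

```python
import shutil
import tempfile
from pathlib import Path
from typing import Callable


def save_then_upload(
    save_fn: Callable[[Path], None],
    step: int,
    remote_root: Path,
    keep_last_n: int = 3,
) -> Path:
    """Save one checkpoint to a temp dir, copy it to remote_root, prune old steps."""
    if keep_last_n < 1:
        raise ValueError("keep_last_n must be at least 1")

    remote_root.mkdir(parents=True, exist_ok=True)
    target = remote_root / str(step)

    with tempfile.TemporaryDirectory() as tmp:
        local_dir = Path(tmp) / str(step)
        # save_fn is assumed to create local_dir and write the checkpoint into it.
        save_fn(local_dir)
        # Placeholder for the upload to the remote file system.
        shutil.copytree(local_dir, target)

    # Retention is done manually here because the checkpoint manager only
    # ever sees the temporary directory, never the remote copies.
    steps = sorted(int(p.name) for p in remote_root.iterdir() if p.name.isdigit())
    for old_step in steps[:-keep_last_n]:
        shutil.rmtree(remote_root / str(old_step))

    return target
```

Note that the copy itself is not atomic: a reader could observe a partially uploaded step directory while the copy is in progress.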
There's a recent change to offer better support for this problem. Previously S3 would not work correctly because atomic rename was not supported, but alternative atomicity logic can be configured using checkpoint/orbax/checkpoint/path/atomicity.py.
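To illustrate what "alternative atomicity logic" can mean on a store without atomic rename, here is a generic commit-marker sketch. It does not reproduce the code in atomicity.py; the marker name `COMMIT_OK` and all function names are made up for illustration. The idea is to write the payload files first and a marker file last, and to have readers ignore any checkpoint directory that lacks the marker.

```python
from pathlib import Path
from typing import Dict, Optional

# Hypothetical marker name, chosen for this example only.
COMMIT_MARKER = "COMMIT_OK"


def write_checkpoint(ckpt_dir: Path, files: Dict[str, bytes]) -> None:
    """Write all payload files, then the marker that declares them complete."""
    ckpt_dir.mkdir(parents=True, exist_ok=True)
    for name, data in files.items():
        (ckpt_dir / name).write_bytes(data)
    # Written last: its presence signals that every payload file was written.
    (ckpt_dir / COMMIT_MARKER).write_bytes(b"")


def is_committed(ckpt_dir: Path) -> bool:
    """A checkpoint counts as complete only if the marker file exists."""
    return (ckpt_dir / COMMIT_MARKER).is_file()


def latest_committed_step(root: Path) -> Optional[int]:
    """Return the highest step number with a marker, or None if there is none."""
    steps = [
        int(p.name)
        for p in root.iterdir()
        if p.name.isdigit() and is_committed(p)
    ]
    return max(steps, default=None)
```

With this scheme, a save that is interrupted halfway leaves a directory without a marker, so it is skipped on restore instead of being loaded as a corrupt checkpoint.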