Support TensorStore checkpoints using the gda_serialization API
I'd like to use Equinox for some fairly large-scale training runs, but the state for those models is often too large to fit on a single accelerator, so gathering all of the state onto one host to serialize it with NumPy is far from ideal.
Also, checkpoints can be large, and saving them can take a while, so async support is nice to have.
JAX provides a low-level serialization API for individual arrays, which should be a good fit here, but making it work nicely with arbitrary Equinox modules is going to be nontrivial.
Flax creates a directory tree based on the nested structure of its state dicts. Maybe something similar could be done with Equinox modules?
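To make the directory-tree idea concrete, here is a rough sketch of how a nested module could be mapped to one relative path per leaf. It is only an illustration: `leaf_paths`, `Linear`, and `MLP` are hypothetical names, plain dataclasses stand in for Equinox modules, and strings stand in for arrays. In a real implementation each leaf would be a (possibly sharded) array written to its own directory with the per-array serialization API, and the paths could probably come from `jax.tree_util.tree_flatten_with_path` instead of a hand-rolled walk.

```python
from dataclasses import dataclass, fields, is_dataclass
from typing import Any, Iterator, Tuple


def leaf_paths(tree: Any, prefix: Tuple[str, ...] = ()) -> Iterator[Tuple[str, Any]]:
    """Yield (relative_path, leaf) pairs for a nested structure.

    Dataclass fields, dict keys, and sequence indices become path components,
    so every leaf gets its own directory-like path.
    """
    if is_dataclass(tree) and not isinstance(tree, type):
        # Dataclass instance: recurse into each field, in declaration order.
        for f in fields(tree):
            yield from leaf_paths(getattr(tree, f.name), prefix + (f.name,))
    elif isinstance(tree, dict):
        # Sort keys so the layout does not depend on insertion order.
        for key in sorted(tree, key=str):
            yield from leaf_paths(tree[key], prefix + (str(key),))
    elif isinstance(tree, (list, tuple)):
        for i, item in enumerate(tree):
            yield from leaf_paths(item, prefix + (str(i),))
    else:
        # Anything else is treated as a leaf (an array, in practice).
        yield "/".join(prefix), tree


# Hypothetical stand-ins for Equinox modules.
@dataclass
class Linear:
    weight: Any
    bias: Any


@dataclass
class MLP:
    layers: list
    scale: Any


model = MLP(layers=[Linear("w0", "b0"), Linear("w1", "b1")], scale="s")
layout = dict(leaf_paths(model))
for path in layout:
    print(path)
```

With a layout like this, each leaf could be saved and restored independently, which is what would let sharded arrays be written without gathering them first.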