equinox

Support TensorStore checkpoints using the gda_serialization API

Open · GallagherCommaJack opened this issue 1 year ago · 11 comments

I'd like to use Equinox for some fairly large-scale training runs, but the state for those models is often too large to fit on a single accelerator, so gathering all of it onto one host just to serialize it with numpy is far from ideal.

Also, checkpoints can be large and saving them can take a while, so asynchronous saving would be nice to have.

JAX provides a low-level serialization API for individual arrays, which should be perfect here, but making it work nicely with arbitrary Equinox modules is going to be nontrivial.

Flax creates a directory tree based on the nested structure of its state dicts; maybe something similar could be done with Equinox modules?
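To make the "directory tree from nested structure" idea concrete, here is a minimal sketch (not Flax's or Equinox's actual implementation) that walks a pytree with `jax.tree_util.tree_flatten_with_path` and assigns each array leaf a path under a checkpoint root. The helper names `leaf_paths` and `_key_name` are hypothetical, and the sketch assumes JAX 0.4.6 or newer (where key-path flattening exists). It only computes the layout; the actual per-array write step, which would hand each array and its path to JAX's low-level serialization so each host writes only its own shards, is deliberately left out.

```python
import posixpath

import jax
import jax.numpy as jnp
import jax.tree_util as jtu


def _key_name(key) -> str:
    """Hypothetical helper: turn one key-path entry into a path component."""
    if isinstance(key, jtu.DictKey):
        return str(key.key)
    if isinstance(key, jtu.SequenceKey):
        return str(key.idx)
    if isinstance(key, jtu.GetAttrKey):  # attribute-style fields, e.g. dataclass-like modules
        return key.name
    if isinstance(key, jtu.FlattenedIndexKey):  # pytrees registered without key info
        return str(key.key)
    return str(key)


def leaf_paths(tree, root: str) -> dict:
    """Hypothetical helper: map each array leaf of `tree` to a path under `root`.

    Non-array leaves (strings, Python callables, etc.) are skipped, since only
    the arrays would need to be written to a checkpoint.
    """
    leaves_with_paths, _ = jtu.tree_flatten_with_path(tree)
    out = {}
    for path, leaf in leaves_with_paths:
        if isinstance(leaf, jax.Array):
            out[posixpath.join(root, *(_key_name(k) for k in path))] = leaf
    return out


# Usage: a nested structure standing in for a model's state.
params = {
    "encoder": {"w": jnp.zeros((2, 2)), "b": jnp.zeros(2)},
    "layers": [jnp.ones(3), jnp.ones(3)],
    "name": "mlp",  # non-array leaf, skipped
}
for path, arr in leaf_paths(params, "ckpt").items():
    print(path, arr.shape)
```

Because JAX flattens dicts in sorted-key order, the layout is deterministic across runs, which matters if a restore step has to rebuild the same paths from a freshly initialized model. Whether an Equinox module's fields show up as named `GetAttrKey` entries or as bare indices depends on how the installed Equinox version registers its pytrees, so that part is an assumption to check.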

GallagherCommaJack commented Oct 28 '22, 20:10