axlearn
Add Orbax Checkpointing
At this time, Orbax checkpointing is required to test in-memory checkpointing. It is also better integrated with Pathways, for when the time comes to test that.
This PR is currently a draft. It does the following (more notes are here):
- Adds the option to use an Orbax checkpointer
- Renames the original checkpointer to `StateStorageCheckpointer`
- Sets `use_orbax` to `False` by default for models under the `experiments/text/gpt` folder
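The flag-based dispatch described above can be sketched in isolation. Note that the class bodies and the `instantiate` helper here are illustrative stand-ins, not axlearn's actual API; only the `use_orbax` flag name and its `False` default come from this PR.

```python
from dataclasses import dataclass

# Illustrative stand-ins for the two checkpointers mentioned in this PR;
# these are NOT axlearn's real class definitions.
class StateStorageCheckpointer:
    """The original checkpointer, renamed by this PR."""
    def save(self, step, state):
        return f"state_storage: saved step {step}"

class OrbaxCheckpointer:
    """A checkpointer backed by orbax-checkpoint."""
    def save(self, step, state):
        return f"orbax: saved step {step}"

@dataclass
class CheckpointerConfig:
    # Mirrors the `use_orbax` option added by this PR; defaults to False,
    # matching the default set for experiments/text/gpt models.
    use_orbax: bool = False

    def instantiate(self):
        # Hypothetical dispatch: select the implementation from the flag.
        return OrbaxCheckpointer() if self.use_orbax else StateStorageCheckpointer()

cfg = CheckpointerConfig()
print(type(cfg.instantiate()).__name__)  # StateStorageCheckpointer
cfg.use_orbax = True
print(type(cfg.instantiate()).__name__)  # OrbaxCheckpointer
```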
To test the Orbax checkpointer:
- Manually set `cfg.checkpointer.use_orbax = True`
- Launch training as usual; the checkpoints will be stored in the expected `checkpoints` folder in GCS
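The test steps above amount to a one-line override on the trainer config before launch. The attribute path is taken from this PR's description; `cfg` stands for whatever experiment config is being launched.

```
# Enable the Orbax path on the trainer config before launching training;
# everything else about the launch flow is unchanged.
cfg.checkpointer.use_orbax = True
```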
Planned work on my side:
- Add more comments
- Upgrade the `orbax-checkpoint` version and consider pinning it in `pyproject.toml`
- Test on TPUs
- Test in-memory checkpointing with Orbax
@markblee Do you mind taking a look and giving some feedback? Please feel free to edit the branch directly.