
Add Orbax Checkpointing

Opened by jiya-zhang · 0 comments

At this time, Orbax checkpointing is required to test in-memory checkpointing. Orbax is also better integrated with Pathways, which will matter when the time comes to test that.

This PR is currently a draft. It does the following (more notes are here):

  • Adds the option to use an Orbax checkpointer (a minimal sketch follows this list)
  • Renames the original checkpointer to StateStorageCheckpointer
  • Sets use_orbax to False by default for models under the experiments/text/gpt folder
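For illustration only, here is a minimal Python sketch of how a use_orbax toggle could select between the renamed StateStorageCheckpointer and an Orbax-backed checkpointer. This is not the actual axlearn implementation; apart from use_orbax and StateStorageCheckpointer, the class names and the instantiate helper are hypothetical placeholders.

```python
# Hypothetical sketch, not the actual axlearn code: a `use_orbax` config field
# chooses between the renamed StateStorageCheckpointer and an Orbax-backed one.
from dataclasses import dataclass


class StateStorageCheckpointer:
    """Placeholder for the existing (renamed) checkpointer."""


class OrbaxCheckpointer:
    """Placeholder for the new Orbax-backed checkpointer."""


@dataclass
class CheckpointerConfig:
    # Default remains False, matching the experiments/text/gpt configs.
    use_orbax: bool = False

    def instantiate(self):
        # Pick the checkpointer implementation based on the toggle.
        return OrbaxCheckpointer() if self.use_orbax else StateStorageCheckpointer()


cfg = CheckpointerConfig()
cfg.use_orbax = True  # The manual override described in the testing steps below.
print(type(cfg.instantiate()).__name__)  # -> OrbaxCheckpointer
```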

To test the Orbax checkpointer:

  • Manually set cfg.checkpointer.use_orbax = True
  • Launch training as usual; checkpoints will be stored in the expected checkpoints folder in GCS (a standalone Orbax example follows this list)
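For reference, this is a minimal standalone example of saving and restoring a pytree with the orbax-checkpoint library, which is what the new checkpointer wraps. The local /tmp path is a placeholder (in the PR the trainer writes to the experiment's checkpoints folder in GCS), and the exact save/restore call signatures vary somewhat across orbax-checkpoint versions.

```python
# Standalone orbax-checkpoint sketch: save and restore a simple pytree.
# The /tmp path is a placeholder; gs:// paths work the same way.
import jax.numpy as jnp
import orbax.checkpoint as ocp

state = {"step": 0, "params": {"w": jnp.ones((2, 2)), "b": jnp.zeros((2,))}}

ckptr = ocp.PyTreeCheckpointer()
ckptr.save("/tmp/orbax_demo/step_0", state)         # Fails if the directory already exists.
restored = ckptr.restore("/tmp/orbax_demo/step_0")
print(restored["params"]["w"])
```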

Planned work on my side:

  • Add more comments
  • Upgrade the orbax-checkpoint version and consider pinning it in pyproject.toml
  • Test on TPUs
  • Test in-memory checkpointing with Orbax

@markblee Do you mind taking a look and giving some feedback? Please feel free to edit the branch directly.

jiya-zhang · Aug 05 '24