maxtext
Integrate emergency checkpointer into standalone_checkpointer for CPUs.
- When local checkpoints are available for restore, alter the mesh setup as follows:
  - Ignore the JAX coordinator provided by XPK and use the pod containing process_id 0 as the JAX coordinator instead.
  - Obtain the IP address of this pod using socket APIs and persist it to GCS.
  - Let other processes retrieve the address of this coordinator, similar to TPUs.
- Restore state, not state["items"], if the emergency checkpoint manager is enabled.
- Some changes in JAX/Orbax are needed to get this working end-to-end on CPUs. Other changes are addressed in this PR.
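
The coordinator override described in the mesh setup bullets can be sketched roughly as follows. This is an illustrative sketch, not the code in this PR: the function names, the file name `coordinator_address.txt`, and the port are hypothetical, and a local directory stands in for the GCS bucket so the example runs without cloud credentials.

```python
import socket
import time
from pathlib import Path

COORDINATOR_PORT = 8476  # hypothetical port, chosen only for illustration
ADDRESS_FILE = "coordinator_address.txt"  # hypothetical file name


def local_ip() -> str:
    """Look up this host's IP address with the standard socket APIs."""
    return socket.gethostbyname(socket.gethostname())


def publish_coordinator_address(process_id, shared_dir, ip_lookup=local_ip):
    """Process 0 writes its address to shared storage; other processes do nothing.

    `shared_dir` is a stand-in for a GCS location (assumption).
    """
    if process_id != 0:
        return None
    address = f"{ip_lookup()}:{COORDINATOR_PORT}"
    target = Path(shared_dir) / ADDRESS_FILE
    tmp = target.with_suffix(".tmp")
    tmp.write_text(address)
    tmp.replace(target)  # rename so readers never see a partially written file
    return address


def wait_for_coordinator_address(shared_dir, timeout_s=300.0, poll_s=1.0):
    """Poll shared storage until process 0 has published its address."""
    target = Path(shared_dir) / ADDRESS_FILE
    deadline = time.monotonic() + timeout_s
    while time.monotonic() < deadline:
        if target.exists():
            return target.read_text().strip()
        time.sleep(poll_s)
    raise TimeoutError(f"no coordinator address found in {shared_dir}")
```

Each process would then pass the retrieved address to `jax.distributed.initialize(coordinator_address=..., num_processes=..., process_id=...)` in place of the coordinator supplied by XPK.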
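
The state versus state["items"] distinction can be illustrated with a small helper. This is a sketch under assumptions: the helper name and the plain-dict layout are hypothetical, and the only behavior taken from the description above is that the restored object is used directly when the emergency checkpoint manager is enabled and unwrapped via its "items" entry otherwise.

```python
from typing import Any


def select_restored_state(restored: Any, emergency_enabled: bool) -> Any:
    """Return the train state from whatever the checkpoint restore produced.

    Hypothetical helper: with the emergency checkpoint manager the restored
    object is assumed to already be the state; otherwise the state is assumed
    to sit under the "items" key.
    """
    if emergency_enabled:
        return restored
    return restored["items"]
```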