
Integrate emergency checkpointer into standalone_checkpointer for CPUs.

Open RoshaniN opened this issue 7 months ago • 1 comment

  1. When local checkpoints are available for restore, alter the mesh setup as follows:
     • Ignore the JAX coordinator provided by XPK and override it with the pod that contains process_id 0.
     • Use socket APIs to obtain this pod's IP address and persist it to GCS.
     • Let the other processes retrieve this coordinator's address, similar to TPUs.
  2. Restore state, not state["items"], if the emergency checkpoint manager is enabled.

  3. Some changes in JAX/Orbax are needed to get this working end-to-end on CPUs. Other changes are addressed in this PR.
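
A rough sketch of the coordinator override and the restore selection described above, not the actual MaxText change. All function names, the file name, and the port are hypothetical, and a plain shared directory stands in for the GCS bucket so the example runs without cloud credentials.

```python
import os
import socket
import time

# Hypothetical file name; the real change would write an object to GCS instead.
COORDINATOR_FILE = "coordinator_address.txt"


def publish_coordinator_address(shared_dir, port=1234):
    """Process 0: look up this pod's IP with socket APIs and persist it."""
    try:
        ip = socket.gethostbyname(socket.gethostname())
    except socket.gaierror:
        ip = "127.0.0.1"  # fallback for hosts whose name does not resolve
    address = f"{ip}:{port}"
    tmp_path = os.path.join(shared_dir, COORDINATOR_FILE + ".tmp")
    final_path = os.path.join(shared_dir, COORDINATOR_FILE)
    with open(tmp_path, "w") as f:
        f.write(address)
    # Rename so that readers on a local filesystem never see a partial write.
    os.replace(tmp_path, final_path)
    return address


def wait_for_coordinator_address(shared_dir, timeout_s=60.0, poll_s=0.5):
    """Other processes: poll shared storage until the address appears."""
    path = os.path.join(shared_dir, COORDINATOR_FILE)
    deadline = time.monotonic() + timeout_s
    while time.monotonic() < deadline:
        if os.path.exists(path):
            with open(path) as f:
                address = f.read().strip()
            if address:
                return address
        time.sleep(poll_s)
    raise TimeoutError(f"no coordinator address found at {path}")


def resolve_coordinator_address(process_id, shared_dir, port=1234):
    """Ignore any externally provided coordinator; process 0 becomes it."""
    if process_id == 0:
        return publish_coordinator_address(shared_dir, port)
    return wait_for_coordinator_address(shared_dir)


def select_restored_state(restored, emergency_checkpointing_enabled):
    """Return the restored object itself when the emergency manager is
    enabled; otherwise unwrap restored["items"]."""
    if emergency_checkpointing_enabled:
        return restored
    return restored["items"]
```

The resolved address would then presumably be passed as `coordinator_address` to `jax.distributed.initialize`, together with `num_processes` and `process_id`, before the mesh is built.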

RoshaniN · Jul 12 '24 22:07