maxtext
Checkpoint tfds data iterator
Enables deterministic training across preemptions when using the tfds pipeline by checkpointing the data iterator.
Creates a checkpoint handler for the data iterator that implements orbax.checkpoint.CheckpointHandler, similar to https://github.com/google/grain/blob/main/grain/_src/python/checkpoint_handlers.py. The handler uses tf.train.Checkpoint to save and restore the iterator.
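To illustrate why saving the iterator makes resumption deterministic, here is a minimal standard-library sketch. It is not the handler from this change: `ToyDataIterator` and `ToyIteratorCheckpointHandler` are hypothetical names, the state is written as JSON rather than through tf.train.Checkpoint, and the class only loosely mirrors the save/restore shape of a checkpoint handler without implementing the orbax interface.

```python
import json
import os
import random
import tempfile


class ToyDataIterator:
    """Hypothetical stand-in for a data iterator: state is (num_examples, seed, position)."""

    def __init__(self, num_examples, seed):
        self.num_examples = num_examples
        self.seed = seed
        self.position = 0
        # The shuffle order depends only on the seed, so it can be rebuilt on restore.
        self._order = list(range(num_examples))
        random.Random(seed).shuffle(self._order)

    def __iter__(self):
        return self

    def __next__(self):
        if self.position >= self.num_examples:
            raise StopIteration
        example = self._order[self.position]
        self.position += 1
        return example


class ToyIteratorCheckpointHandler:
    """Hypothetical handler exposing save/restore; not the orbax CheckpointHandler API."""

    FILENAME = "iterator_state.json"

    def save(self, directory, iterator):
        os.makedirs(directory, exist_ok=True)
        state = {
            "num_examples": iterator.num_examples,
            "seed": iterator.seed,
            "position": iterator.position,
        }
        with open(os.path.join(directory, self.FILENAME), "w") as f:
            json.dump(state, f)

    def restore(self, directory):
        with open(os.path.join(directory, self.FILENAME)) as f:
            state = json.load(f)
        iterator = ToyDataIterator(state["num_examples"], state["seed"])
        iterator.position = state["position"]
        return iterator


def run_with_preemption(num_examples, seed, preempt_after):
    """Consume some examples, checkpoint, drop the iterator, then resume from disk."""
    handler = ToyIteratorCheckpointHandler()
    with tempfile.TemporaryDirectory() as ckpt_dir:
        iterator = ToyDataIterator(num_examples, seed)
        seen = [next(iterator) for _ in range(preempt_after)]
        handler.save(ckpt_dir, iterator)
        del iterator  # simulate the job being preempted
        restored = handler.restore(ckpt_dir)
        seen.extend(restored)  # continue from the saved position
    return seen


uninterrupted = list(ToyDataIterator(10, seed=0))
resumed = run_with_preemption(10, seed=0, preempt_after=4)
```

In this toy the resumed run sees the same examples in the same order as the uninterrupted run, with none repeated or skipped. The JSON state here is only a few bytes; it does not reflect the size of a checkpoint written with tf.train.Checkpoint.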
Makes checkpointing the data iterator optional, since this method saves large checkpoints. Adds a bool flag to base.yml.
Async checkpointing is handled at the level of the orbax checkpoint manager.
Updates the input pipeline description to reflect the option to checkpoint the tfds iterator.