
Checkpoint tfds data iterator

Open: mattdonati opened this issue 4 months ago

When the tfds input pipeline is used, enables deterministic training across preemptions by checkpointing the data iterator.

Creates a checkpoint handler for the data iterator that implements orbax.checkpoint.CheckpointHandler, similar to https://github.com/google/grain/blob/main/grain/_src/python/checkpoint_handlers.py. The handler uses tf.train.Checkpoint to save and restore the iterator.
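The core mechanism can be illustrated with a minimal sketch, assuming TensorFlow 2 in eager mode. The class TfdsIteratorCheckpointer and its per-process file naming are hypothetical: they mirror the save/restore shape of a checkpoint handler but do not use the Orbax API, and they are not the code in this PR. It only shows that tf.train.Checkpoint can record a tf.data iterator's position and restore it into a fresh iterator over the same dataset.

```python
import os
import tempfile

import tensorflow as tf


class TfdsIteratorCheckpointer:
    """Hypothetical helper: save/restore a tf.data iterator via tf.train.Checkpoint.

    Mirrors the save/restore shape of a checkpoint handler without using Orbax.
    """

    def __init__(self, process_index: int = 0, process_count: int = 1):
        # Assumption: each host process owns one iterator and writes its own file.
        self._prefix = f"process_{process_index}-of-{process_count}"

    def save(self, directory: str, iterator) -> str:
        os.makedirs(directory, exist_ok=True)
        ckpt = tf.train.Checkpoint(iterator=iterator)
        # write() saves to exactly this prefix, without appending a save counter.
        return ckpt.write(os.path.join(directory, self._prefix))

    def restore(self, directory: str, iterator):
        ckpt = tf.train.Checkpoint(iterator=iterator)
        # In eager mode this restores the iterator's position in place.
        ckpt.read(os.path.join(directory, self._prefix))
        return iterator


if __name__ == "__main__":
    ds = tf.data.Dataset.range(10)
    it = iter(ds)
    for _ in range(3):
        next(it)  # consume 0, 1, 2
    ckpt_dir = tempfile.mkdtemp()
    handler = TfdsIteratorCheckpointer()
    handler.save(ckpt_dir, it)
    next(it)  # simulate progress that is lost on preemption
    resumed = handler.restore(ckpt_dir, iter(ds))
    print(int(next(resumed)))  # resumes at the saved position
```

Because the checkpoint captures the iterator's internal state, including any buffered elements, a resumed run sees the same element sequence as an uninterrupted one.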

Makes checkpointing the data iterator optional, since it can produce large checkpoints, and adds a bool flag to base.yml to control it.
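tf.data iterator checkpoints can be large because transformations such as shuffle and prefetch keep buffered elements inside the iterator state, which is why an opt-in flag makes sense. The sketch below shows how such a flag might gate what goes into a checkpoint step. The helper checkpoint_items and the key checkpoint_tfds_iterator are hypothetical names for illustration, not the actual flag added to base.yml.

```python
from typing import Any, Dict, Optional


def checkpoint_items(
    train_state: Any,
    data_iterator: Optional[Any],
    config: Dict[str, Any],
) -> Dict[str, Any]:
    """Hypothetical helper: choose which items are saved at a checkpoint step.

    `checkpoint_tfds_iterator` is an illustrative flag name, not the real key.
    """
    items: Dict[str, Any] = {"state": train_state}
    # Off by default: iterator checkpoints may include large buffers.
    if config.get("checkpoint_tfds_iterator", False) and data_iterator is not None:
        items["iter"] = data_iterator
    return items
```

Defaulting the flag to False keeps existing runs' checkpoint size unchanged; only users who need exact data-order resumption pay the storage cost.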

Asynchronous checkpointing is handled by the Orbax checkpoint manager rather than by the handler itself.

Updates the input pipeline description to reflect the option to checkpoint the tfds iterator.

mattdonati, Oct 07 '24 20:10