flax
Inconsistency in `keep` argument for save_checkpoint and save_checkpoint_multiprocess
There seems to be an inconsistency in how the `keep` argument is handled by `save_checkpoint` and `save_checkpoint_multiprocess`.
System information
Internal
Problem you have encountered:
I am trying to migrate to the Orbax Checkpointer with `flax.save_checkpoint_multiprocess` and have noticed inconsistent behavior between:
- current and past behavior of Flax checkpointing
- `save_checkpoint` and `save_checkpoint_multiprocess`
`save_checkpoint` first saves the current checkpoint and then removes excess checkpoints, whereas `save_checkpoint_multiprocess` first removes old checkpoints and then saves the current one.
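To make the effect of the two orderings concrete, here is a minimal simulation. It is not Flax code: it assumes checkpoints can be modeled as a sorted list of integer steps, that `keep >= 1`, and that each function reproduces only the order of operations described above (`save_then_prune` and `prune_then_save` are hypothetical names).

```python
from typing import Callable, List


def save_then_prune(existing: List[int], step: int, keep: int) -> List[int]:
    """Order described for save_checkpoint: write first, then prune to `keep`."""
    ckpts = sorted(existing + [step])
    return ckpts[-keep:]  # assumes keep >= 1 (ckpts[-0:] would keep everything)


def prune_then_save(existing: List[int], step: int, keep: int) -> List[int]:
    """Order described for save_checkpoint_multiprocess: prune to `keep`, then write."""
    ckpts = sorted(existing)[-keep:]
    return ckpts + [step]  # the new step is the newest, so the list stays sorted


def simulate(save_fn: Callable[[List[int], int, int], List[int]],
             keep: int, num_steps: int) -> List[int]:
    """Run `num_steps` consecutive saves and return the surviving checkpoint steps."""
    ckpts: List[int] = []
    for step in range(num_steps):
        ckpts = save_fn(ckpts, step, keep)
    return ckpts


print(simulate(save_then_prune, keep=3, num_steps=10))  # [7, 8, 9]
print(simulate(prune_then_save, keep=3, num_steps=10))  # [6, 7, 8, 9]
```

With `keep=3`, save-then-prune leaves 3 checkpoints while prune-then-save leaves 4, i.e. `keep`+1.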
What you expected to happen:
In the past, specifying `keep=N` for `save_checkpoint` left N checkpoints once the call returned. Now, calling `save_checkpoint_multiprocess` with `keep=N` leaves N+1 checkpoints after the save.
Logs, error messages, etc:
Internal
In `save_checkpoint_multiprocess`, we remove old checkpoints before writing the new one, so the total number of checkpoints ends up as `keep`+1. This is because the Orbax `AsyncCheckpointer` does not support a custom on-commit callback, so there is no way to explicitly delete old checkpoints after the current checkpoint has been saved asynchronously.
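The constraint can be illustrated with Python's `concurrent.futures` standing in for an async checkpointer. If an async save exposed a commit hook (here, `Future.add_done_callback`), pruning could run after the write lands and the `keep=N` semantics would be preserved. This is a hypothetical sketch only: `FakeCheckpointDir` and `async_save_then_prune` are illustrative names, not Orbax or Flax APIs, and the in-memory list stands in for a real checkpoint directory.

```python
import concurrent.futures
import threading
import time
from typing import List


class FakeCheckpointDir:
    """Hypothetical in-memory stand-in for a checkpoint directory (not Orbax/Flax)."""

    def __init__(self) -> None:
        self._lock = threading.Lock()
        self.steps: List[int] = []

    def write(self, step: int) -> None:
        time.sleep(0.01)  # simulate slow, asynchronous I/O
        with self._lock:
            self.steps.append(step)
            self.steps.sort()

    def prune(self, keep: int) -> None:
        with self._lock:
            self.steps = self.steps[-keep:]  # assumes keep >= 1


def async_save_then_prune(executor, ckpt_dir: FakeCheckpointDir,
                          step: int, keep: int) -> concurrent.futures.Future:
    future = executor.submit(ckpt_dir.write, step)

    # Hypothetical on-commit hook: prune only once the write has succeeded.
    def _on_commit(f: concurrent.futures.Future) -> None:
        if f.exception() is None:
            ckpt_dir.prune(keep)

    future.add_done_callback(_on_commit)
    return future


d = FakeCheckpointDir()
# One worker serializes the writes; leaving the `with` block joins the worker
# thread, which also guarantees every done-callback has finished running.
with concurrent.futures.ThreadPoolExecutor(max_workers=1) as ex:
    for step in range(10):
        async_save_then_prune(ex, d, step, keep=3)
print(d.steps)  # [7, 8, 9]
```

Without such a hook, the only safe place to prune is before the save is issued, which is why the count ends up at `keep`+1.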
In `save_checkpoint`, removal always happens after the save, and we kept it that way for backward compatibility. It may make more sense to unify the behavior; either way, I should make this more explicit in the API docs.