
Inconsistency in the `keep` argument for `save_checkpoint` and `save_checkpoint_multiprocess`

Open · priyakasimbeg opened this issue 2 years ago · 1 comment

There seems to be an inconsistency in how the `keep` argument is handled by `save_checkpoint` and `save_checkpoint_multiprocess`.

System information

Internal

Problem you have encountered:

I am trying to migrate to the Orbax checkpointer with Flax's `save_checkpoint_multiprocess` and have noticed an inconsistency in behavior between:

  1. the current and past behavior of Flax checkpointing, and
  2. `save_checkpoint` and `save_checkpoint_multiprocess`.

`save_checkpoint` first saves the current checkpoint and then removes excess checkpoints, whereas `save_checkpoint_multiprocess` first removes old checkpoints and then saves the current one.

What you expected to happen:

In the past, when specifying `keep=N` for the save method, exactly N checkpoints remained at the end of the `save_checkpoint` call. Currently, after calling `save_checkpoint_multiprocess`, N+1 checkpoints remain.
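The off-by-one follows directly from the order of the two steps. The toy simulation below is a sketch, not Flax code: checkpoints are just step numbers in a list, and `save_then_prune` / `prune_then_save` are hypothetical helpers that mimic the two orderings described above (the real functions operate on checkpoint directories).

```python
from typing import List


def most_recent(ckpts: List[int], n: int) -> List[int]:
    """Return the n most recent steps (hypothetical helper)."""
    ordered = sorted(ckpts)
    return ordered[max(len(ordered) - n, 0):]


def save_then_prune(ckpts: List[int], step: int, keep: int) -> List[int]:
    # save_checkpoint-style ordering: write first, then prune to `keep`.
    return most_recent(ckpts + [step], keep)


def prune_then_save(ckpts: List[int], step: int, keep: int) -> List[int]:
    # save_checkpoint_multiprocess-style ordering: prune to `keep` while the
    # new checkpoint does not exist yet, then add it on top.
    return most_recent(ckpts, keep) + [step]


keep = 3
a: List[int] = []
b: List[int] = []
for step in range(1, 11):
    a = save_then_prune(a, step, keep)
    b = prune_then_save(b, step, keep)

print(len(a), a)  # 3 [8, 9, 10]
print(len(b), b)  # 4 [7, 8, 9, 10]
```

With the same `keep=3`, the save-then-prune ordering ends with 3 checkpoints and the prune-then-save ordering ends with 4.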

Logs, error messages, etc:

Internal

priyakasimbeg · Jul 12 '23 00:07

In `save_checkpoint_multiprocess`, we remove old checkpoints before writing the new one, so the total number of checkpoints ends up being `keep + 1`. This is because the Orbax `AsyncCheckpointer` does not support a custom on-commit callback, so there is no way to explicitly delete old checkpoints after the current checkpoint has been asynchronously saved.
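To illustrate why the missing on-commit callback forces the prune-first order, here is a minimal sketch using a hypothetical `ToyAsyncCheckpointer` (a thread-based stand-in, not the Orbax API). Its `save` returns immediately and gives the caller no hook that runs once the write commits, so the only place to prune is before the next save starts. Waiting for earlier writes before pruning is an assumption of this toy.

```python
import threading
import time
from typing import List


class ToyAsyncCheckpointer:
    """Hypothetical async checkpointer with no on-commit callback."""

    def __init__(self) -> None:
        self.committed: List[int] = []  # steps whose writes have finished
        self._lock = threading.Lock()
        self._pending: List[threading.Thread] = []

    def save(self, step: int) -> None:
        def _write() -> None:
            time.sleep(0.01)  # simulated I/O latency
            with self._lock:
                self.committed.append(step)

        thread = threading.Thread(target=_write)
        thread.start()
        self._pending.append(thread)
        # Returns now; nothing lets the caller run cleanup after _write commits.

    def wait(self) -> None:
        """Block until all in-flight writes have committed."""
        for thread in self._pending:
            thread.join()
        self._pending.clear()

    def keep_most_recent(self, keep: int) -> None:
        """Delete all but the `keep` most recent committed checkpoints."""
        with self._lock:
            ordered = sorted(self.committed)
            self.committed = ordered[max(len(ordered) - keep, 0):]


def save_prune_first(ckptr: ToyAsyncCheckpointer, step: int, keep: int) -> None:
    # Without an on-commit hook, cleanup cannot be scheduled after the new
    # write lands, so pruning runs before the save is even started.
    ckptr.wait()
    ckptr.keep_most_recent(keep)
    ckptr.save(step)


keep = 2
ckptr = ToyAsyncCheckpointer()
for step in range(1, 6):
    save_prune_first(ckptr, step, keep)
ckptr.wait()
print(ckptr.committed)  # [3, 4, 5]: keep + 1 checkpoints
```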

In `save_checkpoint`, removal always happens after the save, and we kept it that way for backward compatibility. It may make more sense to unify the behavior. Either way, I should document this more explicitly in the API docs.
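One conceivable way to unify the two, sketched here purely as a hypothetical and not as anything Flax implements, is to keep the prune-first order but leave room for the checkpoint about to be written, i.e. prune to `keep - 1`. `prune_then_save_unified` is an illustrative name, and the simulation again uses plain step numbers.

```python
from typing import List


def most_recent(ckpts: List[int], n: int) -> List[int]:
    """Return the n most recent steps (hypothetical helper)."""
    ordered = sorted(ckpts)
    return ordered[max(len(ordered) - n, 0):]


def prune_then_save_unified(ckpts: List[int], step: int, keep: int) -> List[int]:
    if keep < 1:
        raise ValueError("keep must be at least 1")
    # Reserve one slot for the new checkpoint so the total stays at `keep`.
    return most_recent(ckpts, keep - 1) + [step]


keep = 3
ckpts: List[int] = []
for step in range(1, 11):
    ckpts = prune_then_save_unified(ckpts, step, keep)

print(ckpts)  # [8, 9, 10]
```

The trade-off in this sketch is that, while an asynchronous save is still in flight, only `keep - 1` committed checkpoints exist on disk.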

IvyZX · Jul 12 '23 20:07