pytorch-lightning
A new on_after_save_checkpoint callback for AsyncCheckpointIO
Description & Motivation
We want to call log_artifact and publish tags (e.g. with MLFlowLogger) only after the checkpoint has been saved successfully. AsyncCheckpointIO does not seem to provide a way to know when checkpoint saving is actually complete.
Pitch
It would help to add a new on_after_save_checkpoint hook, either to CheckpointHooks or to a new AsyncCheckpointHooks.
Then AsyncCheckpointIO can provide a way to register callbacks that it passes to the Future's add_done_callback. During initialization, the Trainer can iterate over the Callbacks that implement the new hook and register them with the Strategy or with AsyncCheckpointIO.
Regular CheckpointIO can also allow registration but will call the callbacks as soon as save_checkpoint is done.
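A rough sketch of the idea, using only the standard library so it runs without Lightning. Everything here is hypothetical and only for illustration: `SketchCheckpointIO`, `SketchAsyncCheckpointIO`, `register_after_save_callback` and the `(path, error)` callback signature are not existing Lightning APIs, and the in-memory `saved` dict stands in for the real write to disk.

```python
from concurrent.futures import Future, ThreadPoolExecutor
from typing import Any, Callable, Dict, List, Optional

# Hypothetical callback signature: (path, error), where error is None on success.
AfterSaveCallback = Callable[[str, Optional[BaseException]], None]


class SketchCheckpointIO:
    """Stand-in for a synchronous CheckpointIO (not the Lightning class)."""

    def __init__(self) -> None:
        self.saved: Dict[str, Dict[str, Any]] = {}  # pretend storage
        self._after_save: List[AfterSaveCallback] = []

    def register_after_save_callback(self, fn: AfterSaveCallback) -> None:
        self._after_save.append(fn)

    def _write(self, checkpoint: Dict[str, Any], path: str) -> None:
        self.saved[path] = dict(checkpoint)

    def _notify(self, path: str, error: Optional[BaseException]) -> None:
        for fn in self._after_save:
            fn(path, error)

    def save_checkpoint(self, checkpoint: Dict[str, Any], path: str) -> None:
        # Synchronous case: callbacks fire as soon as the write returns.
        self._write(checkpoint, path)
        self._notify(path, None)


class SketchAsyncCheckpointIO(SketchCheckpointIO):
    """Stand-in for an async variant that saves on a single worker thread."""

    def __init__(self) -> None:
        super().__init__()
        self._executor = ThreadPoolExecutor(max_workers=1)

    def save_checkpoint(self, checkpoint: Dict[str, Any], path: str) -> None:
        future: Future = self._executor.submit(self._write, checkpoint, path)
        # Callbacks run once the write has finished, normally on the worker
        # thread; f.exception() is None when the write succeeded.
        future.add_done_callback(lambda f: self._notify(path, f.exception()))

    def teardown(self) -> None:
        # Block until all pending saves (and their callbacks) have run.
        self._executor.shutdown(wait=True)


# Usage: record completions instead of calling a real logger.
done = []
io = SketchAsyncCheckpointIO()
io.register_after_save_callback(lambda path, error: done.append((path, error)))
io.save_checkpoint({"epoch": 1}, "epoch=1.ckpt")
io.teardown()
```

In a real implementation the callback would be where something like log_artifact gets called. Passing the error along lets the callback skip publishing when the save failed, which matches the "only after saving is complete (successfully)" requirement.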
Alternatives
Currently, users are forced to copy-paste AsyncCheckpointIO and alter it to provide such a hook.
Additional context
I'm willing to contribute if there are no concerns with the proposed solution.
cc @borda