
A new on_after_save_checkpoint callback for AsyncCheckpointIO

Open clumsy opened this issue 1 year ago • 0 comments

Description & Motivation

We want to call log_artifact and publish tags (e.g. with MLFlowLogger) only after the checkpoint has been saved successfully. AsyncCheckpointIO does not seem to provide a way to know when checkpoint saving is actually complete.

Pitch

It would help to add a new on_after_save_checkpoint hook, either to CheckpointHooks or to a new AsyncCheckpointHooks. AsyncCheckpointIO can then provide a way to register callbacks through the Future's add_done_callback. During initialization, the Trainer can iterate over the Callbacks that implement the new hook and register them with the Strategy or AsyncCheckpointIO.

Regular CheckpointIO can also allow registration, but it will call the callbacks as soon as save_checkpoint returns, since the save is synchronous.
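A rough sketch of the idea, using only the standard library so it runs without Lightning. All class and method names here (SketchAsyncCheckpointIO, SketchSyncCheckpointIO, register_after_save_hook) are hypothetical and are not part of the current Lightning API. The sketch assumes the async writer submits the save to a single-worker ThreadPoolExecutor, so the returned Future is where add_done_callback can be attached.

```python
import pickle
import tempfile
from concurrent.futures import Future, ThreadPoolExecutor
from pathlib import Path
from typing import Any, Callable, Dict, List

# Hypothetical hook signature: receives the path of the checkpoint that was saved.
AfterSaveHook = Callable[[Path], None]


def _write(checkpoint: Dict[str, Any], path: Path) -> None:
    """Stand-in for the real serialization step."""
    with open(path, "wb") as fh:
        pickle.dump(checkpoint, fh)


class SketchAsyncCheckpointIO:
    """Illustrative stand-in, NOT Lightning's AsyncCheckpointIO."""

    def __init__(self) -> None:
        self._executor = ThreadPoolExecutor(max_workers=1)
        self._hooks: List[AfterSaveHook] = []

    def register_after_save_hook(self, hook: AfterSaveHook) -> None:
        self._hooks.append(hook)

    def save_checkpoint(self, checkpoint: Dict[str, Any], path: Path) -> Future:
        future = self._executor.submit(_write, checkpoint, path)
        # The callback fires once the background write has finished.
        future.add_done_callback(lambda f: self._on_done(f, path))
        return future

    def _on_done(self, future: Future, path: Path) -> None:
        # Notify hooks only when the write completed without an error.
        if future.cancelled() or future.exception() is not None:
            return
        for hook in self._hooks:
            hook(path)

    def teardown(self) -> None:
        # Waits for pending writes; their done-callbacks have run by then.
        self._executor.shutdown(wait=True)


class SketchSyncCheckpointIO:
    """Synchronous variant: hooks run as soon as the write returns."""

    def __init__(self) -> None:
        self._hooks: List[AfterSaveHook] = []

    def register_after_save_hook(self, hook: AfterSaveHook) -> None:
        self._hooks.append(hook)

    def save_checkpoint(self, checkpoint: Dict[str, Any], path: Path) -> None:
        _write(checkpoint, path)
        for hook in self._hooks:
            hook(path)


def demo() -> List[Path]:
    """Save one checkpoint asynchronously and collect hook notifications."""
    notified: List[Path] = []
    io = SketchAsyncCheckpointIO()
    io.register_after_save_hook(notified.append)  # e.g. log_artifact in practice
    target = Path(tempfile.mkdtemp()) / "step.ckpt"
    io.save_checkpoint({"epoch": 1}, target)
    io.teardown()
    return notified
```

In this sketch the hooks run on the worker thread that finished the write, so a real implementation would need to decide whether logger calls such as log_artifact are safe to make from that thread.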

Alternatives

Currently, users are forced to copy-paste AsyncCheckpointIO and alter it to provide such a hook.

Additional context

I'm willing to contribute if there are no concerns with the proposed solution.

cc @borda

clumsy, Feb 05 '24 21:02