dvclive
Make checkpoints part of the logger setup?
An idea (that has probably been discussed a lot already). I just want to come back to it, since it bothers me that I can't migrate fully and easily to the DVC logger from Keras (and other frameworks?):
Here is the piece of code:
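```python
import os

import tensorflow as tf
from dvclive.keras import DvcLiveCallback

# `model`, `train`, `valid` and `params` are defined earlier in the training script.
# Keep only the weights of the epoch with the highest validation accuracy.
checkpoint = tf.keras.callbacks.ModelCheckpoint(
    os.path.join("model", "best_model"),
    monitor="val_accuracy",
    mode="max",
    verbose=1,
    save_best_only=True,
    save_weights_only=True,
)
history = model.fit(
    train,
    validation_data=valid,
    epochs=params["epochs"],
    # Both callbacks are needed: one for checkpoints, one for DVCLive metrics.
    callbacks=[checkpoint, DvcLiveCallback()],
)
```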
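To make "track the best model only" concrete: the `monitor`, `mode` and `save_best_only` arguments boil down to a per-epoch decision of whether the monitored metric improved. The sketch below is a simplified, hypothetical helper (`BestOnlyTracker` is not part of Keras or DVCLive) that reproduces that decision in plain Python; it is not the actual Keras implementation.

```python
import math


class BestOnlyTracker:
    """Hypothetical helper: decides whether an epoch's model should be saved,
    in the spirit of ModelCheckpoint(monitor=..., mode=..., save_best_only=True)."""

    def __init__(self, monitor="val_accuracy", mode="max"):
        if mode not in ("max", "min"):
            raise ValueError("mode must be 'max' or 'min'")
        self.monitor = monitor
        self.mode = mode
        # Start from the worst possible value so the first logged epoch always improves.
        self.best = -math.inf if mode == "max" else math.inf

    def is_improvement(self, logs):
        """Return True (and record the new best) if logs[monitor] beats the best so far."""
        current = logs.get(self.monitor)
        if current is None:
            # Metric not logged this epoch: nothing to compare, so do not save.
            return False
        better = current > self.best if self.mode == "max" else current < self.best
        if better:
            self.best = current
        return better


tracker = BestOnlyTracker(monitor="val_accuracy", mode="max")
epoch_logs = [{"val_accuracy": 0.71}, {"val_accuracy": 0.78}, {"val_accuracy": 0.75}]
saved = [epoch for epoch, logs in enumerate(epoch_logs) if tracker.is_improvement(logs)]
print(saved)  # [0, 1]
print(tracker.best)  # 0.78
```

A logger that wanted to "accept the same set of params" would need to replicate at least this logic, which is part of why duplicating the checkpoint callback's arguments comes up later in the thread.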
Can we make DvcLiveCallback handle checkpoints (and accept the same set of params, e.g. tracking only the best model)? It would simplify the migration a lot.
At the moment, it's impossible to achieve the same. Even if I start saving all checkpoints, it's not clear how we should then get the best one programmatically (e.g. to run inference later).
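One possible workaround when every checkpoint is saved is to pick the best one afterwards from the per-epoch metrics. The sketch below assumes a Keras-style history dict (as in `model.fit(...).history`, a mapping from metric name to a list of per-epoch values) and a hypothetical per-epoch filename template such as `model/epoch_{epoch:02d}`. It also assumes 1-based epoch numbers in the filenames, which is how Keras formats `{epoch}` in `ModelCheckpoint`'s filepath. `best_epoch` and `best_checkpoint_path` are illustrative names, not DVCLive APIs.

```python
import os


def best_epoch(history, monitor="val_accuracy", mode="max"):
    """Return the 0-based index of the best epoch in a Keras-style history dict."""
    values = history[monitor]
    if not values:
        raise ValueError(f"no values recorded for {monitor!r}")
    pick = max if mode == "max" else min
    # Ties resolve to the earliest epoch.
    return pick(range(len(values)), key=values.__getitem__)


def best_checkpoint_path(history, template, monitor="val_accuracy", mode="max"):
    """Map the best epoch to a path built from a per-epoch filename template.

    Assumes the template expects 1-based epoch numbers.
    """
    return template.format(epoch=best_epoch(history, monitor, mode) + 1)


history = {"val_accuracy": [0.71, 0.78, 0.75], "val_loss": [0.9, 0.6, 0.7]}
template = os.path.join("model", "epoch_{epoch:02d}")
print(best_epoch(history))  # 1
print(best_checkpoint_path(history, template))  # model/epoch_02 on POSIX systems
```

This only works if the filename template and the epoch numbering stay in sync with whatever saved the checkpoints, which is exactly the kind of coupling a logger-level integration could hide.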
Related: https://github.com/iterative/dvclive/issues/89.
@shcheklein Are you encountering an issue when using both tf.keras.callbacks.ModelCheckpoint and DvcLiveCallback? Or do you want to avoid needing tf.keras.callbacks.ModelCheckpoint?
I don't encounter problems, but I effectively have to use both, and I have to manually report certain metrics after the training loop is done. Also (a bigger question), there are the checkpoints themselves: we do have our own mechanism, but it's not clear whether we can integrate it into this workflow.
> have to manually report certain metrics after the train loop is done
Like best accuracy? Keras doesn't seem to make it easy to retrieve this info, but maybe we can make it easier for users here with an option like DvcLiveCallback(summary="last"/"best")?
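A minimal sketch of what a `summary="last"/"best"` option could compute, assuming the metrics are available as a Keras-style history dict. `summarize` and its `monitor`/`mode` arguments are hypothetical and do not exist in DVCLive; they only illustrate the idea.

```python
def summarize(history, summary="last", monitor="val_accuracy", mode="max"):
    """Hypothetical sketch of a summary="last"/"best" option.

    `history` is a Keras-style dict of metric name -> per-epoch values, all of
    the same length. "last" reports the final epoch; "best" reports every metric
    from the epoch where `monitor` was best according to `mode`.
    """
    if summary not in ("last", "best"):
        raise ValueError("summary must be 'last' or 'best'")
    if mode not in ("max", "min"):
        raise ValueError("mode must be 'max' or 'min'")
    monitored = history[monitor]
    if not monitored:
        raise ValueError(f"no values recorded for {monitor!r}")
    if summary == "last":
        epoch = len(monitored) - 1
    else:
        pick = max if mode == "max" else min
        epoch = pick(range(len(monitored)), key=monitored.__getitem__)
    # Take all metrics from the same epoch so the summary matches one checkpoint.
    return {"epoch": epoch, **{name: vals[epoch] for name, vals in history.items()}}


history = {"val_accuracy": [0.71, 0.78, 0.75], "val_loss": [0.9, 0.6, 0.7]}
print(summarize(history, "last"))  # {'epoch': 2, 'val_accuracy': 0.75, 'val_loss': 0.7}
print(summarize(history, "best"))  # {'epoch': 1, 'val_accuracy': 0.78, 'val_loss': 0.6}
```

Reporting every metric from the single best epoch (rather than each metric's own best value) keeps the summary consistent with the one checkpoint that `save_best_only=True` would keep; taking per-metric bests could mix values from different epochs.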
Edit: related to https://github.com/iterative/dvclive/issues/89
Leaving aside what our plans are for DVC checkpoints, the primary need here seems to be auto-integrating with frameworks' existing checkpointing features (at least to start; we might also want to go further than what exists in the frameworks).
> Can we make DvcLiveCallback handle checkpoints (and accept the same set of params, e.g. track the best model only, etc)? It would simplify the migration a lot.
Do we really want to repeat these params in the dvclive callback? I would rather go with an approach similar to what's in Lightning: include some default checkpointing, but if people want to customize it, they can continue to use the separate checkpoint callbacks, and hopefully dvclive can be aware of these.
If we really need to do something relevant (i.e. #305 / #378) with the models down the road, I would prefer to go in the direction of making DVCLive aware of the existing checkpoint callback and remove our existing logic for saving (and especially not duplicate args). Right now we just duplicate existing logic but with fewer features for no real reason.
I think this awareness is feasible in most frameworks (it is what other loggers already do, for example in PyTorch Lightning).
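A minimal, framework-agnostic sketch of the "awareness" approach: the logger looks for an existing checkpoint callback and reads its results instead of taking its own `monitor`/`mode`/`save_best_only` arguments, falling back to a simple default when none is present. `CheckpointCallback` and `LoggerCallback` are hypothetical stand-ins. The attribute names `best_model_path` and `best_model_score` mirror those on PyTorch Lightning's `ModelCheckpoint`, where other callbacks can reach it through the trainer (e.g. `trainer.checkpoint_callback`). In Keras, callbacks do not see each other directly, so passing the callback list in as done here is an assumption.

```python
from dataclasses import dataclass, field
from typing import List, Optional


@dataclass
class CheckpointCallback:
    """Hypothetical stand-in for a framework checkpoint callback."""
    monitor: str
    mode: str = "max"
    best_model_path: Optional[str] = None
    best_model_score: Optional[float] = None


@dataclass
class LoggerCallback:
    """Hypothetical logger that reuses an existing checkpoint callback
    instead of duplicating its arguments."""
    default_path: str = "model/last"  # hypothetical default checkpoint location
    logged: dict = field(default_factory=dict)

    def on_train_end(self, callbacks: List[object]) -> None:
        ckpt = next((cb for cb in callbacks if isinstance(cb, CheckpointCallback)), None)
        if ckpt is not None and ckpt.best_model_path is not None:
            # Defer to the user's checkpoint configuration.
            self.logged["model_path"] = ckpt.best_model_path
            self.logged[f"best_{ckpt.monitor}"] = ckpt.best_model_score
        else:
            # No checkpoint callback found: fall back to default checkpointing.
            self.logged["model_path"] = self.default_path


ckpt = CheckpointCallback(
    monitor="val_accuracy", best_model_path="model/best_model", best_model_score=0.78
)
logger = LoggerCallback()
logger.on_train_end([ckpt, logger])
print(logger.logged)  # {'model_path': 'model/best_model', 'best_val_accuracy': 0.78}
```

The design point is that the checkpoint callback stays the single source of truth for how checkpoints are selected; the logger only records where the chosen model ended up.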
Closing in favor of the log_model issues (i.e. #665).