Persist best epoch per trial
System information.
TensorFlow version (you are using): 2.9.2
Feature and the current behavior/state
Currently, the logs passed to the on_train_end() method of a Keras callback are the output of the last call to on_epoch_end(), even though that is not necessarily the best epoch of the trial. Arguably, they should be the output of the best of all epochs.
Description
- It would be helpful to pass the best epoch to on_train_end() because, at the end of a trial, all we are interested in are the best metrics for the objective being monitored (val_f1/f1/loss, ...).
- Currently it is misleading to see the latest epoch's metrics being logged, even though that epoch is not the best one.
- This causes major issues when integrating with platforms like MLflow through callbacks: for every trial, the latest epoch's model and metrics are stored, even though they are not the best among all the epochs of that trial.
- On top of that, it would make it seamless to log and view the best metrics for a trial. I don't currently see a way to view the best epoch and its metrics for a given trial.
Will this change the current API? How? No.
Who will benefit from this feature? Anyone who wants to log and monitor the metrics of every trial.
@rakshithvsk, could you please elaborate on your feature? Also, please specify the use cases for this feature. Thank you!
@tilakrayal Let me explain with a requirement...
If I train my model with trials=1 and epochs=10, the best of the 10 epochs in that trial can fall anywhere between Epoch-0 and Epoch-9. At the end of every trial, I need the details of the best epoch for that trial (details can include metrics and the model). Currently I see that getting details at the end of a trial is possible with the keras.callbacks.Callback.on_train_end() callback.
However, as I noticed, the input to on_train_end() is the output of the last call to on_epoch_end(); in my example, that is the output of Epoch-9. This means that even though Epoch-9 might not be the best epoch, it is Epoch-9's model and metrics that are available in on_train_end(), which shouldn't be the case. Instead, the best epoch should be available in on_train_end().
Feature: How can this be handled? At the end of every epoch, persist the details only if the epoch improves on the best so far, very similar to keras.callbacks.ModelCheckpoint(save_best_only=True).
Let me know if this makes sense, @tilakrayal. Also, let me know if there is already a way to achieve this that I'm missing.
thanks!!
I don't think we would want to change the inputs to on_train_end; that would be disruptive to all users relying on the current behavior. If anything, the more general case would seem to be to return the entire epoch history rather than the last or best epoch, but that would break all the on_train_end users relying on logs today.
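For comparison, Keras already keeps the full per-epoch history outside the callback API: model.fit() returns a History object whose .history attribute maps each metric name to a list of per-epoch values. A minimal sketch of picking the best epoch after training from such a dict is below. The helper name best_epoch_from_history is hypothetical, and it assumes every series has one entry per epoch (for example, validation ran every epoch with validation_freq=1).

```python
def best_epoch_from_history(history, monitor="val_loss", mode="min"):
    """Hypothetical helper: return (epoch_index, {metric: value}) for the best epoch.

    `history` is assumed to be shaped like keras History.history:
    {metric_name: [value_epoch0, value_epoch1, ...]}, all lists the same length.
    """
    pick = {"min": min, "max": max}[mode]  # KeyError on an unknown mode
    values = history[monitor]
    # Index of the best value; on ties, min/max return the earliest epoch.
    best = pick(range(len(values)), key=values.__getitem__)
    return best, {name: series[best] for name, series in history.items()}
```

For example, with {"loss": [0.9, 0.6, 0.5], "val_loss": [1.0, 0.7, 0.75]} and monitor="val_loss", the helper picks epoch 1 even though epoch 2 has the lowest training loss.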
I think the best option would be to write your own callback here, overriding on_epoch_end and on_train_end. Use on_epoch_end to track your logs according to whatever metric you want for "best" and then access that callback state during on_train_end. ModelCheckpoint might be a good reference here, but given a specific model yours could probably be significantly simpler.
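A minimal sketch of such a callback, assuming the monitored metric and its direction are passed to the constructor, the same way ModelCheckpoint takes monitor and mode. The class name BestEpochLogger and its log_fn hook (where, for example, an MLflow logging call could go) are hypothetical and not part of Keras or KerasTuner. on_epoch_end keeps a copy of the logs only when the metric improves, and on_train_end hands the best epoch's logs to log_fn instead of relying on the last epoch's logs.

```python
import math

try:
    from tensorflow import keras
    CallbackBase = keras.callbacks.Callback
except ImportError:  # fallback only so this sketch runs without TensorFlow installed
    CallbackBase = object


class BestEpochLogger(CallbackBase):
    """Hypothetical callback: remembers the logs of the best epoch of one fit() call."""

    def __init__(self, monitor="val_loss", mode="min", log_fn=None):
        super().__init__()
        if mode not in ("min", "max"):
            raise ValueError("mode must be 'min' or 'max'")
        self.monitor = monitor  # the objective is supplied explicitly, as in ModelCheckpoint
        self.mode = mode
        self.log_fn = log_fn  # hypothetical hook, called as log_fn(best_epoch, best_logs)

    def on_train_begin(self, logs=None):
        # Reset state at the start of every fit() so the instance can be reused.
        self.best = math.inf if self.mode == "min" else -math.inf
        self.best_epoch = None
        self.best_logs = None

    def on_epoch_end(self, epoch, logs=None):
        value = (logs or {}).get(self.monitor)
        if value is None:
            return  # monitored metric not reported for this epoch
        improved = value < self.best if self.mode == "min" else value > self.best
        if improved:
            self.best = value
            self.best_epoch = epoch
            self.best_logs = dict(logs)  # copy, so later epochs cannot mutate it

    def on_train_end(self, logs=None):
        # `logs` here still belongs to the last epoch; report the tracked best instead.
        if self.log_fn is not None and self.best_logs is not None:
            self.log_fn(self.best_epoch, self.best_logs)
```

With KerasTuner, the instance would be passed through search, e.g. tuner.search(x, y, epochs=10, validation_data=(x_val, y_val), callbacks=[BestEpochLogger(monitor="val_f1", mode="max", log_fn=...)]), reusing the same metric name as the tuner's objective. KerasTuner may copy callbacks for each trial, which is why this sketch does the logging inside on_train_end rather than reading the instance after the search.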
Hi @mattdangerw, I definitely agree with not modifying the inputs to on_train_end. However, the issue with overriding on_train_end/on_epoch_end with an implementation similar to ModelCheckpoint is that, in on_epoch_end, no details are available about the objective by which the best model should be picked. Without them, it is difficult to decide on each call to on_epoch_end whether the current epoch is the best so far. ModelCheckpoint, on the other hand, has this objective readily available when it saves the model.
Let me know your thoughts, @mattdangerw. Thanks!!