ModelCheckpoint used with `save_best_only` doesn't handle interruptions, even with BackupAndRestore
System information.
- Have I written custom code (as opposed to using a stock example script provided in Keras): Yes
- OS Platform and Distribution (e.g., Linux Ubuntu 16.04): Linux Ubuntu 20.04
- TensorFlow installed from (source or binary): docker image tensorflow/tensorflow:2.10.0
- TensorFlow version (use command below): 2.10.0 (the issue is the same for 2.8 and 2.9)
- Python version: 3.8.10
Describe the problem.
I am using the ModelCheckpoint callback to save the best model, combined with the BackupAndRestore callback to handle interruptions. The problem arises when re-running a training script after an interruption: the training state restored by BackupAndRestore does not include the previous values of the losses and metrics. As a result, ModelCheckpoint saves the model on the first epoch of the new run regardless of the loss value, and it even overwrites the "best" model with a worse one.
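For context (based on my reading of the callback, so treat this as an assumption rather than documented behavior): ModelCheckpoint keeps the best monitored value in an internal `best` attribute that is initialized to infinity when the callback is constructed, and this value does not appear to be part of the state that BackupAndRestore saves. A minimal sketch:

```python
from tensorflow import keras

# A freshly constructed ModelCheckpoint starts from an "infinitely bad" best value,
# so the first epoch of a resumed run always looks like an improvement.
ckpt_cb = keras.callbacks.ModelCheckpoint(
    '/tmp/best_model', save_best_only=True, monitor='val_loss', mode='min')
print(ckpt_cb.best)  # inf
```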
Describe the current behavior.
- I run a training script :heavy_check_mark:
- The script gets interrupted :heavy_check_mark:
- When running it again, the model is correctly restored by BackupAndRestore :heavy_check_mark:
- However, once the model is restored and training resumes, ModelCheckpoint doesn't behave as expected: it saves the model on the first epoch of the new run, regardless of whether the loss improved. :negative_squared_cross_mark:
Describe the expected behavior.
- I run a training script :heavy_check_mark:
- The script gets interrupted :heavy_check_mark:
- When running it again, the model is correctly restored by BackupAndRestore :heavy_check_mark:
- The training state is fully restored, including validation loss and metrics, so ModelCheckpoint keeps doing its job: saving the best model. :heavy_check_mark:
Standalone code to reproduce the issue.
First training
```python
import tensorflow as tf
from tensorflow import keras
import numpy as np

# Dummy datasets
np.random.seed(12)
train_ds = tf.data.Dataset.from_tensor_slices((np.random.rand(500, 10, 4), np.random.randint(0, 5, (500, 1))))
valid_ds = tf.data.Dataset.from_tensor_slices((np.random.rand(100, 10, 4), np.random.randint(0, 5, (100, 1))))

model = keras.Sequential(
    [
        keras.layers.Dense(40, activation="relu"),
        keras.layers.Dense(100, activation="relu"),
        keras.layers.Dense(400, activation="relu"),
        keras.layers.Dense(10, activation="relu"),
        keras.layers.Dense(3, activation="relu"),
        keras.layers.Dense(1),
    ]
)
model.compile(loss='mean_squared_error',
              optimizer=tf.keras.optimizers.RMSprop(learning_rate=0.005))

# Callbacks
backup_cb = keras.callbacks.BackupAndRestore(backup_dir='/tmp/backup_dir')
ckpt_cb = keras.callbacks.ModelCheckpoint('/tmp/best_model', save_best_only=True, monitor='val_loss', verbose=1)

# Callback that fakes an interruption
class InterruptingCallback(keras.callbacks.Callback):
    def on_epoch_begin(self, epoch, logs=None):
        if epoch == 4:
            raise RuntimeError('Interrupting!')

model.fit(train_ds, epochs=6, validation_data=valid_ds, verbose=1, callbacks=[backup_cb, ckpt_cb, InterruptingCallback()])
```
The output is:
```
Epoch 1/6
476/500 [===========================>..] - ETA: 0s - loss: 2.4351
Epoch 1: val_loss improved from inf to 2.07992, saving model to /tmp/best_model
500/500 [==============================] - 2s 3ms/step - loss: 2.4528 - val_loss: 2.0799
Epoch 2/6
475/500 [===========================>..] - ETA: 0s - loss: 2.2506
Epoch 2: val_loss did not improve from 2.07992
500/500 [==============================] - 1s 1ms/step - loss: 2.2690 - val_loss: 2.0819
Epoch 3/6
476/500 [===========================>..] - ETA: 0s - loss: 2.2212
Epoch 3: val_loss did not improve from 2.07992
500/500 [==============================] - 1s 1ms/step - loss: 2.2440 - val_loss: 2.0859
Epoch 4/6
486/500 [============================>.] - ETA: 0s - loss: 2.2116
Epoch 4: val_loss did not improve from 2.07992
500/500 [==============================] - 1s 1ms/step - loss: 2.2372 - val_loss: 2.0894
Traceback (most recent call last):
  File "1st_training.py", line 36, in <module>
    model.fit(train_ds, epochs=6, validation_data=valid_ds, verbose=1, callbacks=[backup_cb, ckpt_cb, InterruptingCallback()])
  File "/usr/local/lib/python3.8/dist-packages/keras/utils/traceback_utils.py", line 70, in error_handler
    raise e.with_traceback(filtered_tb) from None
  File "1st_training.py", line 33, in on_epoch_begin
    raise RuntimeError('Interrupting!')
RuntimeError: Interrupting!
```
Second training
```python
import tensorflow as tf
from tensorflow import keras
import numpy as np

# Dummy datasets
np.random.seed(12)
train_ds = tf.data.Dataset.from_tensor_slices((np.random.rand(500, 10, 4), np.random.randint(0, 5, (500, 1))))
valid_ds = tf.data.Dataset.from_tensor_slices((np.random.rand(100, 10, 4), np.random.randint(0, 5, (100, 1))))

model = keras.Sequential(
    [
        keras.layers.Dense(40, activation="relu"),
        keras.layers.Dense(100, activation="relu"),
        keras.layers.Dense(400, activation="relu"),
        keras.layers.Dense(10, activation="relu"),
        keras.layers.Dense(3, activation="relu"),
        keras.layers.Dense(1),
    ]
)
model.compile(loss='mean_squared_error',
              optimizer=tf.keras.optimizers.RMSprop(learning_rate=0.005))

# Callbacks
backup_cb = keras.callbacks.BackupAndRestore(backup_dir='/tmp/backup_dir')
ckpt_cb = keras.callbacks.ModelCheckpoint('/tmp/best_model', save_best_only=True, monitor='val_loss', verbose=1)

model.fit(train_ds, epochs=6, validation_data=valid_ds, verbose=1, callbacks=[backup_cb, ckpt_cb])
```
The output is:
```
Epoch 5/6
476/500 [===========================>..] - ETA: 0s - loss: 2.2097
Epoch 5: val_loss improved from inf to 2.09097, saving model to /tmp/best_model
500/500 [==============================] - 2s 3ms/step - loss: 2.2322 - val_loss: 2.0910
Epoch 6/6
500/500 [==============================] - ETA: 0s - loss: 2.2200
Epoch 6: val_loss did not improve from 2.09097
500/500 [==============================] - 1s 1ms/step - loss: 2.2200 - val_loss: 2.0978
```
The problem lies in `val_loss improved from inf to 2.09097`: the state restored by BackupAndRestore does not include the previous value of `val_loss`. ModelCheckpoint is re-initialized with an `inf` best value, so it no longer does what it is supposed to do, and it even overwrites the "best" model with a worse one.
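A possible workaround until this is fixed (only lightly tested, and it relies on ModelCheckpoint's internal `best` attribute, which may change between versions): persist the best monitored value yourself and restore it before resuming training. A sketch, using a hypothetical side file `/tmp/best_val_loss.json`:

```python
import json
import os
from tensorflow import keras

BEST_FILE = '/tmp/best_val_loss.json'  # hypothetical side file holding the best val_loss

ckpt_cb = keras.callbacks.ModelCheckpoint(
    '/tmp/best_model', save_best_only=True, monitor='val_loss', verbose=1)

# Restore the previously observed best value, if any, so that save_best_only
# does not treat the first epoch after a restart as an improvement.
if os.path.exists(BEST_FILE):
    with open(BEST_FILE) as f:
        ckpt_cb.best = json.load(f)['best']

class PersistBest(keras.callbacks.Callback):
    """Write ModelCheckpoint's current best value to disk after every epoch."""
    def on_epoch_end(self, epoch, logs=None):
        with open(BEST_FILE, 'w') as f:
            json.dump({'best': float(ckpt_cb.best)}, f)
```

Passing `PersistBest()` to `fit()` alongside `backup_cb` and `ckpt_cb` keeps the saved best value in sync across restarts.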
@gowthamkpr, I was able to reproduce the issue on tensorflow v2.8, v2.9 and nightly. Kindly find the gist of it here.
@nicolasnn - Thanks for reporting this issue with detailed examples.
I am working on an update to the BackupAndRestore callback to address this specific scenario. I will submit a PR in the next few weeks. @rchao - please assign this issue to me.
Thanks Ramesh!