Not benefiting from checkpointing
Hello,
I save checkpoints with:
import tensorflow as tf

# Save the full model (architecture, weights, and optimizer state) into a
# separate SavedModel directory at the end of every epoch.
checkpoint_cb = tf.keras.callbacks.ModelCheckpoint(
    filepath=checkpoint_path + '/epoch-{epoch:02d}/',
    monitor='val_loss',
    save_freq='epoch',
    save_weights_only=False,
    save_best_only=False,
    mode='auto')
After loading the latest checkpoint and continuing training, I would expect the loss to be close to the value recorded at the last checkpoint.
# Restore the full model, including the custom model class and optimizer.
model = tf.keras.models.load_model(
    model_path,
    custom_objects={'SimilarityModel': tfsim.models.SimilarityModel,
                    'MyOptimizer': tfa.optimizers.RectifiedAdam})
model.load_index(model_path)

# Resume training from the epoch at which the checkpoint was saved.
model.fit(
    datasampler,
    callbacks=callbacks,
    epochs=args.epochs,
    initial_epoch=initial_epoch_number,
    steps_per_epoch=N_TRAIN_SAMPLES,
    verbose=2)
However, the loss does not continue from where it left off. It looks like training simply restarts from scratch and does not benefit from the checkpoints at all.
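One quick way to see this (a minimal sketch, assuming the standard Keras optimizer API; the print statements are illustrative and not part of my training code) is to inspect the restored optimizer right after loading:

# Sketch: if the optimizer state was restored, the iteration counter should
# match where training stopped rather than reading zero.
restored = tf.keras.models.load_model(
    model_path,
    custom_objects={'SimilarityModel': tfsim.models.SimilarityModel,
                    'MyOptimizer': tfa.optimizers.RectifiedAdam})
print('optimizer iterations:', restored.optimizer.iterations.numpy())
print('optimizer variables:', len(restored.optimizer.variables()))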
Thanks for submitting the issue @mstfldmr. Do you have a simple example I can use to try to repro the issue? I can also try to repro this using our basic example, but it might be good to get closer to your current setup as well.
@owenvallis I'm sorry, I can't share the full code because it has some confidential pieces we developed. This was how I configured checkpointing:
checkpoint_cb = tf.keras.callbacks.ModelCheckpoint(
    checkpoint_path,
    monitor='loss',
    save_freq='epoch',
    save_weights_only=False,
    save_best_only=False)
and how I loaded a checkpoint back:
resumed_model_from_checkpoints = tf.keras.models.load_model(f'{checkpoint_path}/{max_epoch_filename}')
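For context, max_epoch_filename just points at the most recent epoch directory. A hypothetical helper to select it, assuming the 'epoch-NN' naming scheme from above, might look like:

import os
import re

# Hypothetical helper: return the checkpoint subdirectory with the highest
# epoch number, assuming subdirectories named 'epoch-01', 'epoch-02', ...
def latest_epoch_dir(checkpoint_path):
    epoch_dirs = [d for d in os.listdir(checkpoint_path)
                  if re.fullmatch(r'epoch-\d+', d)]
    return max(epoch_dirs, key=lambda d: int(d.split('-')[1]))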
@owenvallis could you reproduce it?
Hi @mstfldmr, sorry for the delay here. I'll try and get to this this week.
Looking into this now. It also looks like there is a breaking change in 2.8 that removed Optimizer.get_weights() (see https://github.com/keras-team/tf-keras/issues/442). That issue also mentions that SavedModel did not properly save the weights for certain optimizers in the past (see https://github.com/tensorflow/tensorflow/issues/44670).
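In the meantime, one possible workaround (a sketch only; I haven't verified it against your setup) is to checkpoint the optimizer state explicitly with tf.train.Checkpoint, which tracks the optimizer's slot variables directly instead of relying on SavedModel:

# Sketch: track model and optimizer explicitly so the optimizer's slot
# variables are saved and restored independently of SavedModel behavior.
ckpt = tf.train.Checkpoint(model=model, optimizer=model.optimizer)
manager = tf.train.CheckpointManager(ckpt, checkpoint_path, max_to_keep=3)

manager.save()                           # e.g., at the end of each epoch
ckpt.restore(manager.latest_checkpoint)  # before resuming training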
Which optimizer were you using? Was it Adam?
@owenvallis yes, it was tfa.optimizers.RectifiedAdam.