Allow ModelCheckpoint callback to save Tensorflow SavedModel

Open jordisoler opened this issue 3 years ago • 5 comments

Here's why we have that policy:

Keras developers respond to issues. We want to focus on work that benefits the whole community, e.g., fixing bugs and adding features. Support only helps individuals. GitHub also notifies thousands of people when issues are filed. We want them to see you communicating an interesting problem, rather than being redirected to Stack Overflow.

System information.

TensorFlow version (you are using): 2.9.1

Are you willing to contribute it (Yes/No): Yes

Describe the feature and the current behavior/state.

Currently, tf.keras.callbacks.ModelCheckpoint does not accept a save_format parameter, while tf.keras.Model.save does accept it to specify whether to use the TensorFlow SavedModel format or an HDF5 file. Hence, only HDF5 models will be stored by the callback.

A use case where this is a problem is when you want to use ModelCheckpoint with a model that contains weights that are not instances of tf.Variable, such as those of tf.keras.layers.TextVectorization. In that case, an error like the following is shown:

NotImplementedError: Save or restore weights that is not an instance of `tf.Variable` is not supported in h5, use `save_format='tf'` instead. Received a model or layer TextVectorization with weights [<keras.layers.preprocessing.index_lookup.VocabWeightHandler object at 0x7f86859bf520>]

even though save_format='tf' is not accessible via ModelCheckpoint.

Will this change the current api? How?

Yes. ModelCheckpoint would accept the optional parameter save_format='h5' in its constructor

Who will benefit from this feature? Those who are trying to use ModelCheckpoint with a layer that is not storable in HDF5 format

Contributing

  • Do you want to contribute a PR? (yes/no): yes
  • If yes, please read this page for instructions
  • Briefly describe your candidate solution (if contributing): I'd add the optional argument save_format='h5' to ModelCheckpoint.__init__(...), store it as an instance attribute, and use it when saving the complete model with self.model.save(...). It would have no effect when save_weights_only=True.
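
A minimal sketch of that candidate solution, assuming a simplified stand-in rather than the real Keras class. `CheckpointWithFormat` and `RecordingModel` are hypothetical names used only for illustration; the real ModelCheckpoint has many more options (monitoring, save frequency, and so on). The sketch only shows how a stored save_format could be forwarded to model.save(...) and ignored when save_weights_only=True.

```python
class CheckpointWithFormat:
    """Hypothetical, simplified stand-in for the proposed ModelCheckpoint change."""

    def __init__(self, filepath, save_weights_only=False, save_format="h5"):
        self.filepath = filepath
        self.save_weights_only = save_weights_only
        self.save_format = save_format  # the proposed new argument
        self.model = None

    def set_model(self, model):
        self.model = model

    def on_epoch_end(self, epoch, logs=None):
        # Fill named formatting options such as {epoch:02d} or {val_loss:.2f}.
        path = self.filepath.format(epoch=epoch + 1, **(logs or {}))
        if self.save_weights_only:
            # save_format has no effect when only the weights are saved.
            self.model.save_weights(path)
        else:
            self.model.save(path, save_format=self.save_format)


class RecordingModel:
    """Hypothetical stub that records calls instead of writing files."""

    def __init__(self):
        self.calls = []

    def save(self, filepath, save_format=None):
        self.calls.append(("save", filepath, save_format))

    def save_weights(self, filepath):
        self.calls.append(("save_weights", filepath))


model = RecordingModel()
callback = CheckpointWithFormat("ckpt-{epoch:02d}", save_format="tf")
callback.set_model(model)
callback.on_epoch_end(0, logs={"val_loss": 0.25})
print(model.calls)
```

With an explicit argument like this, the format no longer depends on the file name, which is the point of the request.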

jordisoler avatar Jun 06 '22 07:06 jordisoler

The default file format of the Keras Model.save API is the TensorFlow format in versions 2.x. HDF5 is the default in versions 1.x.

So in version 2.9.x, the model saved by the ModelCheckpoint callback will be in the TensorFlow protobuf format.

AshwinJay101 avatar Jun 06 '22 18:06 AshwinJay101

@AshwinJay101 Then I guess what is wrong is the format inference based on the filename: using ModelCheckpoint with modelname.keras will try to use HDF5 and crash if the model contains something like TextVectorization. Does that make sense?
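
To make the suspected inference concrete, here is a rough sketch in plain Python. It is not the Keras source: `infer_save_format` and `HDF5_SUFFIXES` are hypothetical names, and the suffix list is an assumption based on the behaviour observed in this thread with TensorFlow 2.9.

```python
# Assumed set of suffixes that select HDF5 (based on behaviour seen in this thread).
HDF5_SUFFIXES = (".h5", ".hdf5", ".keras")


def infer_save_format(filepath, save_format=None):
    """Hypothetical helper: return 'h5' or 'tf' for a given checkpoint path."""
    if save_format is not None:
        # An explicit choice wins over whatever the file name suggests.
        return save_format
    if str(filepath).endswith(HDF5_SUFFIXES):
        return "h5"
    return "tf"


for name in ("modelname.keras", "weights.h5", "modelname"):
    print(name, "->", infer_save_format(name))
```

Under that assumption, the callback never gets a chance to pass an explicit format, so the file name alone decides which saver runs.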

jordisoler avatar Jun 07 '22 06:06 jordisoler

Hi @jordisoler, As per the description mentioned here, tf.keras.callbacks.ModelCheckpoint accepts hdf5 file format.

filepath string or PathLike, path to save the model file. e.g. filepath = os.path.join(working_dir, 'ckpt', file_name). filepath can contain named formatting options, which will be filled the value of epoch and keys in logs (passed in on_epoch_end). For example: if filepath is weights.{epoch:02d}-{val_loss:.2f}.hdf5, then the model checkpoints will be saved with the epoch number and the validation loss in the filename. The directory of the filepath should not be reused by any other callbacks to avoid conflicts.

I was able to save weights in the hdf5 file format:

import os

import tensorflow as tf

# `model`, `train_images`, `train_labels`, `test_images` and `test_labels`
# are assumed to be defined already (a compiled Keras model and its data).
checkpoint_path = "cp.hdf5"
checkpoint_dir = os.path.dirname(checkpoint_path)

# Create a callback that saves the model's weights
cp_callback = tf.keras.callbacks.ModelCheckpoint(filepath=checkpoint_path,
                                                 save_weights_only=True,
                                                 verbose=1)

# Train the model with the new callback
model.fit(train_images,
          train_labels,
          epochs=10,
          validation_data=(test_images, test_labels),
          callbacks=[cp_callback])  # Pass callback to training

gadagashwini avatar Jun 08 '22 06:06 gadagashwini

Hi @gadagashwini . Yes, the checkpoint is definitely capable of saving the model in HDF5 format, my point is that the message may not be obvious. What happened to me, and could happen to other users is:

  1. I used tf.keras.callbacks.ModelCheckpoint with a filename whatever.keras and it worked just fine. It was using HDF5 because of the filename extension, but I didn't really know or care.
  2. I introduced a TextVectorization layer, which can't be stored in HDF5 format. At this point the callback crashed with a message telling me to use the option save_format='tf'. But through the callback I had no real access to the save_format argument of Model.save, and I didn't realize that the problem was that HDF5 was being used because of the filename.

So, my suggestion is either:

  • Would it make sense to allow an explicit save_format in the ModelCheckpoint callback that would prevent a rather obscure file format inference from the filename?
  • Should the error message be changed/rebranded in the callback to avoid the situation above? I think the user experience can be improved

jordisoler avatar Jun 08 '22 08:06 jordisoler

I just faced the same issue. Just changing the file name (in the callback) from abc.hdf5 to abc made it use the TF SavedModel format instead. The error message is very misleading and needs to be changed.
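
A small sketch of that workaround as a helper, in case it is useful to others. `savedmodel_checkpoint_path` is a hypothetical name, and the list of extensions is an assumption based on what was reported in this thread; it only rewrites the path string and does not call Keras.

```python
import os

# Assumed extensions that make the callback pick HDF5 (per this thread).
HDF5_SUFFIXES = (".h5", ".hdf5", ".keras")


def savedmodel_checkpoint_path(filepath):
    """Hypothetical helper: drop an HDF5-style extension from a checkpoint path."""
    root, ext = os.path.splitext(filepath)
    return root if ext.lower() in HDF5_SUFFIXES else filepath


print(savedmodel_checkpoint_path("abc.hdf5"))
print(savedmodel_checkpoint_path("ckpt/abc-{epoch:02d}.h5"))
print(savedmodel_checkpoint_path("abc"))
```

The returned string can then be passed as filepath to tf.keras.callbacks.ModelCheckpoint, which should write a SavedModel directory instead of an HDF5 file.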

TalhaUsuf avatar Sep 13 '22 18:09 TalhaUsuf