Keras `predict_step` is not preserved across save and restore
System information
- Have I written custom code (as opposed to using a stock example script provided in TensorFlow): No
- OS Platform and Distribution (e.g., Linux Ubuntu 16.04): Ubuntu 20.04 and macOS
- TensorFlow installed from (source or binary): binary
- TensorFlow version (use command below): 2.6.0 and nightly
- Python version: 3.7, 3.8 and 3.9
Describe the current behavior
When custom prediction logic is implemented for a Keras model by overriding `predict_step` as explained here, saving and restoring the model in the SavedModel format drops that logic. The failure is silent: nothing informs the user that this is unsupported, which can lead to hard-to-detect bugs.
The issue is explained in detail with a minimal example in this Colab notebook.
I know I can save a custom serving function using
class MyModel(tf.keras.Model):
    @tf.function(input_signature=[tf.TensorSpec(shape=[None], dtype=tf.string)])
    def serve(self, data):
        ...
as described here. However, I feel the current behaviour breaks user expectations: the SavedModel format is now the default saving format, yet it doesn't support all features and can fail silently, resulting in unexpected behaviour. Users then have to break the abstraction and fall back to low-level TF APIs, which I think doesn't fit well with the progressive disclosure of complexity that Keras strives for.
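For reference, a minimal sketch of that workaround applied to a small classifier might look like the following. It assumes a float input of shape `[None, 20]`, and the method name `serve`, the output key `class_id`, and the use of `tf.saved_model.save` / `tf.saved_model.load` (rather than `model.save` / `load_model`) are illustrative choices on my part, not something Keras prescribes. Behaviour may differ between TF/Keras versions.

```python
import tempfile

import numpy as np
import tensorflow as tf


class FullyConnectedModel(tf.keras.Model):
    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        self.dense = tf.keras.layers.Dense(10)

    def call(self, inputs):
        return self.dense(inputs)

    # Hypothetical serving function: the name `serve`, the input shape and
    # the output key are assumptions made for this sketch.
    @tf.function(input_signature=[tf.TensorSpec(shape=[None, 20], dtype=tf.float32)])
    def serve(self, data):
        logits = self(data, training=False)
        return {"class_id": tf.argmax(logits, axis=-1)}


x = np.random.uniform(size=(8, 20)).astype(np.float32)
model = FullyConnectedModel()
model(x)  # create the variables eagerly before `serve` is traced

# Export the custom logic explicitly as a named signature.
export_dir = tempfile.mkdtemp()
tf.saved_model.save(model, export_dir, signatures={"serving_default": model.serve})

# Load through the low-level API and call the signature by keyword.
restored = tf.saved_model.load(export_dir)
serving_fn = restored.signatures["serving_default"]
restored_ids = serving_fn(data=tf.constant(x))["class_id"].numpy()

# Reference result computed from the original in-memory model.
expected_ids = np.argmax(model(x).numpy(), axis=-1)
```

Note that the restored object here is not a Keras model, so `predict` is not available on it; the custom logic is only reachable through the signature, which is exactly the break in abstraction described above.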
Describe the expected behavior
Keras models should preserve custom predict_step logic when saving and restoring models.
Standalone code to reproduce the issue
import tensorflow as tf
import numpy as np


class FullyConnectedModel(tf.keras.Model):
    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        self.dense = tf.keras.layers.Dense(10)

    def predict_step(self, data):
        logits = self(data, training=False)
        return tf.argmax(logits, axis=-1)

    def call(self, inputs):
        return self.dense(inputs)


x, y = np.random.uniform(size=(128, 20)).astype(np.float32), np.random.randint(0, 10, size=(128))

model = FullyConnectedModel()
model.compile(optimizer="sgd", loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True))
model.fit(x, y, epochs=2, batch_size=32)

model.save("/tmp/model", save_traces=True)
reloaded_model = tf.keras.models.load_model("/tmp/model")

y_pred = model.predict(x)
reloaded_y_pred = reloaded_model.predict(x)
np.testing.assert_allclose(reloaded_y_pred, y_pred)
See this notebook for more information.
Also check out https://github.com/tensorflow/tensorflow/issues/48149, which was originally posted to TF before the move to keras-team/keras.
This issue has now been open for 1.5 years (including the time on https://github.com/tensorflow/tensorflow/issues/48149).
@fchollet @mattdangerw @k-w-w Is there any chance this will be fixed? I am very happy to look into a fix and open a PR, but for that it would be good to know why these functions are explicitly ignored during saving.