
Unable to restore a layer of class TextVectorization - Text Classification

Open sachinprasadhs opened this issue 3 years ago • 6 comments

Moving user issue from: https://github.com/tensorflow/tensorflow/issues/45231

Describe the problem.

When I run the official TensorFlow "Basic text classification" example, everything runs fine until the model is saved. But when I load the model, it gives me this error:

RuntimeError: Unable to restore a layer of class TextVectorization. Layers of class TextVectorization require that the class be provided to the model loading code, either by registering the class using @keras.utils.register_keras_serializable on the class def and including that file in your program, or by passing the class in a keras.utils.CustomObjectScope that wraps this load call.

Expected behavior: the model should load successfully and process raw input.

https://colab.research.google.com/gist/amahendrakar/8b65a688dc87ce9ca07ffb0ce50b84c7/44199.ipynb#scrollTo=fEjmSrKIqiiM

Example Link: https://tensorflow.google.cn/tutorials/keras/text_classification

sachinprasadhs, Feb 09 '22

Attaching a gist that reproduces the error here.

The reported error can be avoided by registering the custom standardization function with @keras.utils.register_keras_serializable(); here is the working gist.

However, this comment from the user https://github.com/tensorflow/tensorflow/issues/45231#issuecomment-1026512621 disagrees with that approach.

sachinprasadhs, Feb 09 '22

Notes from triage: the error message could be improved here, since the problem is with the standardize argument, not the layer.

LukeWood, Feb 10 '22

When a custom callable is passed to either the standardize or the split argument of TextVectorization, it is expected that the function must be registered with register_keras_serializable or passed in the custom_objects argument during loading.

We should improve the error message here, though, and make it clear that this is an issue with serializing the argument to the layer, not the layer itself.

mattdangerw, Feb 17 '22

Replying to sachinprasadhs's comment above about avoiding the error with @keras.utils.register_keras_serializable():

This approach helps those who load the model in the same notebook they trained it in, but it still does not address loading the same model from a different notebook. If you open a new notebook that can access the saved model and run the last cell from the training notebook, it errors out.

To load the model back in the new notebook, you must run:

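import re
import string

import tensorflow as tf

# Re-create and register the exact function used at training time, so Keras
# can find it by name when rebuilding the TextVectorization layer.
@tf.keras.utils.register_keras_serializable()
def custom_standardization(input_data):
  lowercase = tf.strings.lower(input_data)
  stripped_html = tf.strings.regex_replace(lowercase, '<br />', ' ')
  return tf.strings.regex_replace(stripped_html,
                                  '[%s]' % re.escape(string.punctuation),
                                  '')

# load model
loaded_model = tf.keras.models.load_model('./model/test/basic-text-class-export')
loaded_model.summary()  # summary() prints directly and returns None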
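For anyone trying to work out what custom_standardization does to a review, here is a plain-Python mirror written with the standard re module. It is an illustration only (custom_standardization_py is a hypothetical name), not a replacement for the TensorFlow function inside the saved model, and it matches only for ASCII text: tf.strings.lower with its default encoding lowercases ASCII characters only, whereas str.lower also lowercases non-ASCII letters.

```python
import re
import string

# Same character class as the TensorFlow version: any ASCII punctuation.
PUNCT_PATTERN = '[%s]' % re.escape(string.punctuation)


def custom_standardization_py(text: str) -> str:
    """Hypothetical plain-Python mirror of custom_standardization (ASCII input)."""
    lowercase = text.lower()                          # like tf.strings.lower
    stripped_html = lowercase.replace('<br />', ' ')  # HTML line breaks -> space
    return re.sub(PUNCT_PATTERN, '', stripped_html)   # drop punctuation


print(custom_standardization_py("I loved it!<br />Great."))  # i loved it great
```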

This seems like a poor solution for someone who is trying to reload the model from a different notebook, especially if they don't know how custom_standardization was constructed in the first place. In that case, they are stuck.

tmbluth, Mar 03 '22

I can confirm that this is a problem not only when working in notebooks but also with custom model deployments, and possibly with TensorFlow Serving.

mihailyanchev, Jun 10 '22

Any update on this?

innat, Jul 09 '22