tf.keras.initializers.TruncatedNormal is not available when using tf.keras.models.load_model
Issue Type
Bug
Source
binary
Tensorflow Version
2.9.1
Custom Code
No
OS Platform and Distribution
tensorflow/tensorflow Docker image
Mobile device
No response
Python version
No response
Bazel version
No response
GCC/Compiler version
No response
CUDA/cuDNN version
No response
GPU model and memory
No response
Current Behaviour?
When using `tf.keras.initializers.TruncatedNormal` in a custom model, the model can be saved, but fails to load.
Adding `custom_objects={"TruncatedNormal": tf.keras.initializers.TruncatedNormal}` to `tf.keras.models.load_model` prevents the error.
I expected that passing TF classes as custom objects would not be necessary.
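For reference, a sketch of that workaround, assuming the `mymodel.sm` SavedModel produced by the reproduction below:

```python
import tensorflow as tf

# Workaround from above: pass the initializer class explicitly so it can be
# resolved during deserialization.
loaded_mymodel = tf.keras.models.load_model(
    "mymodel.sm",
    custom_objects={"TruncatedNormal": tf.keras.initializers.TruncatedNormal},
)
```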
Standalone code to reproduce the issue
```python
import tensorflow as tf


@tf.keras.utils.register_keras_serializable()
class MyModel(tf.keras.Model):
    def __init__(self, initializer):
        self._config = {"initializer": initializer}
        super().__init__(inputs=[], outputs=[])

    def get_config(self):
        return self._config

    @classmethod
    def from_config(cls, config, custom_objects=None):
        return cls(**config)


mymodel = MyModel(initializer=tf.keras.initializers.TruncatedNormal())
tf.keras.models.save_model(mymodel, "mymodel.sm", overwrite=True)
loaded_mymodel = tf.keras.models.load_model("mymodel.sm")
```
Relevant log output

```
Traceback (most recent call last):
  File "tfbug.py", line 18, in <module>
    loaded_mymodel = tf.keras.models.load_model("mymodel.sm")
  File "/usr/local/lib/python3.8/dist-packages/keras/utils/traceback_utils.py", line 67, in error_handler
    raise e.with_traceback(filtered_tb) from None
  File "/usr/local/lib/python3.8/dist-packages/tensorflow/python/training/tracking/base.py", line 587, in _method_wrapper
    result = method(self, *args, **kwargs)
TypeError: __init__() missing 2 required positional arguments: 'inputs' and 'outputs'
```
Have you tried with:
```python
import tensorflow as tf


class MyModel(tf.keras.Model):
    def __init__(self, initializer):
        self.initializer = initializer
        super().__init__(inputs=[], outputs=[])

    def get_config(self):
        return {"initializer": self.initializer}


mymodel = MyModel(initializer=tf.keras.initializers.TruncatedNormal())
tf.keras.models.save_model(mymodel, "mymodel.sm", overwrite=True)
```
```python
import tensorflow as tf


@tf.keras.utils.register_keras_serializable()
class MyModel(tf.keras.Model):
    def __init__(self, initializer):
        self.initializer = initializer
        super().__init__(inputs=[], outputs=[])

    def get_config(self):
        return {"initializer": self.initializer}


mymodel = MyModel(initializer=tf.keras.initializers.TruncatedNormal())
tf.keras.models.save_model(mymodel, "mymodel.sm", overwrite=True)
loaded_mymodel = tf.keras.models.load_model("mymodel.sm")
```
Raises the same error. Without `@tf.keras.utils.register_keras_serializable()` it does not recognize the custom model, and without the `tf.keras.models.load_model("mymodel.sm")` line it does not load the model at all, so the error is simply never triggered.
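For completeness, a sketch of the other direction, assuming `custom_objects` can also carry the custom model class itself instead of registering it with the decorator (I have not relied on this path, so treat it as an assumption):

```python
# Continuation of the snippet above (MyModel defined, model already saved).
# Sketch: supply the custom class and the initializer explicitly at load time
# instead of registering the class with @register_keras_serializable.
loaded_mymodel = tf.keras.models.load_model(
    "mymodel.sm",
    custom_objects={
        "MyModel": MyModel,
        "TruncatedNormal": tf.keras.initializers.TruncatedNormal,
    },
)
```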
Do you want to do something like:
```python
import tensorflow as tf
from tensorflow import keras


# Define a subclassed model with the same architecture
class MyModel(keras.Model):
    def __init__(self, output_dim, name=None, initializer=None):
        super(MyModel, self).__init__(name=name)
        self.output_dim = output_dim
        self.initializer = initializer
        self.dense_1 = keras.layers.Dense(
            64, activation="relu", name="dense_1",
            kernel_initializer=self.initializer)
        self.dense_2 = keras.layers.Dense(
            64, activation="relu", name="dense_2",
            kernel_initializer=self.initializer)
        self.dense_3 = keras.layers.Dense(
            output_dim, name="predictions",
            kernel_initializer=self.initializer)

    def call(self, inputs):
        x = self.dense_1(inputs)
        x = self.dense_2(x)
        x = self.dense_3(x)
        return x

    def get_config(self):
        return {
            "output_dim": self.output_dim,
            "name": self.name,
            "initializer": self.initializer,
        }


mymodel = MyModel(10, initializer=tf.keras.initializers.TruncatedNormal())
# Call the subclassed model once to create the weights.
mymodel(tf.ones((1, 784)))
tf.keras.models.save_model(mymodel, "mymodel.sm", overwrite=True)
loaded_mymodel = tf.keras.models.load_model("mymodel.sm")
```
This is just a minimal example; the specific model I have a problem with is BERT, in this line.
I understand that there might be a workaround, but I still don't understand where the initial error comes from, and it does look like a bug.
I don't see a subclassed model there, and the initializer is used in a custom layer there. So how did you derive your "minimal" subclassed model gist? Also, mine is minimal and was derived just by adapting the example in the official documentation.
This is a subclassed model. From there I derived this example to reproduce the error. But I am not really sure how relevant this is. Your code works fine, but it is different code. Am I misusing TF and Keras in my initial example?
I suppose this is very similar to your initial post, or not: https://colab.research.google.com/gist/bhack/f15b04a8181774c72254a8f72485fc4f/untitled129.ipynb
It is missing the tf.keras.models.load_model("mymodel.sm") line, which is what causes the error.
Yes, it was just to understand your use case, but what I meant was: why do you need to use super().__init__(inputs=[], outputs=[])?
As it is, it seems to me unrelated to the issue title "tf.keras.initializers.TruncatedNormal is not available".
See also some model subclassing official examples: https://www.tensorflow.org/guide/keras/custom_layers_and_models#the_model_class
As I understand it there are two ways to define a subclassed keras model:
1:
```python
import tensorflow as tf


@tf.keras.utils.register_keras_serializable()
class MyModel(tf.keras.Model):
    def __init__(self, initializer):
        self.initializer = initializer
        inputs = tf.keras.layers.Input(shape=(5,))
        dense = tf.keras.layers.Dense(
            10, activation="relu", name="dense",
            kernel_initializer=self.initializer)
        outputs = dense(inputs)
        super().__init__(inputs=[inputs], outputs=[outputs])

    def get_config(self):
        return {"initializer": self.initializer}


mymodel = MyModel(initializer=tf.keras.initializers.TruncatedNormal())
mymodel(tf.random.normal((1, 5)))
tf.keras.models.save_model(mymodel, "mymodel.sm", overwrite=True)
tf.keras.models.load_model("mymodel.sm")
```
2:
```python
import tensorflow as tf


@tf.keras.utils.register_keras_serializable()
class MyModel(tf.keras.Model):
    def __init__(self, initializer):
        super().__init__()
        self.initializer = initializer
        self.dense = tf.keras.layers.Dense(
            10, activation="relu", name="dense",
            kernel_initializer=self.initializer)

    def call(self, inputs):
        return self.dense(inputs)

    def get_config(self):
        return {"initializer": self.initializer}


mymodel = MyModel(initializer=tf.keras.initializers.TruncatedNormal())
mymodel(tf.random.normal((1, 5)))
tf.keras.models.save_model(mymodel, "mymodel.sm", overwrite=True)
tf.keras.models.load_model("mymodel.sm")
```
Method 2 does not raise errors, but method 1 does (though only when loading the saved model). Is method 1 a wrong way to define a Keras model? If not, why does the code fail when loading the model?
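As an aside, the idiom the Keras serialization guide uses for carrying an initializer through get_config is to serialize it explicitly with `tf.keras.initializers.serialize` and rebuild it with `tf.keras.initializers.get`. I have not checked whether doing this in method 1 avoids the load error, so this is only a sketch:

```python
import tensorflow as tf


@tf.keras.utils.register_keras_serializable()
class MyModel(tf.keras.Model):
    # Sketch only: method 1 with the initializer serialized in the config,
    # following the Keras serialization idiom. Not verified to avoid the error.
    def __init__(self, initializer):
        # tf.keras.initializers.get accepts an instance, a name, or a config dict.
        self.initializer = tf.keras.initializers.get(initializer)
        inputs = tf.keras.layers.Input(shape=(5,))
        dense = tf.keras.layers.Dense(
            10, activation="relu", name="dense",
            kernel_initializer=self.initializer)
        outputs = dense(inputs)
        super().__init__(inputs=[inputs], outputs=[outputs])

    def get_config(self):
        # Store a JSON-friendly config dict instead of the initializer object.
        return {"initializer": tf.keras.initializers.serialize(self.initializer)}
```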
I think (1) runs because __init__ accepts kwargs.
But I also don't find a subclassed model definition like your (1) in the tests:
https://github.com/keras-team/keras/blob/master/keras/tests/model_subclassing_test.py
@gadagashwini, I was able to reproduce the issue in TensorFlow v2.8, v2.9, and nightly. Please find the gist of it here.
Hi @ptarasiewiczNV, does the Making new Layers and Models via subclassing tutorial help? Thank you!
Hi @gadagashwini, I understand that there are better ways to create models. I just want to make sure that the way Keras is used in my example is an incorrect way of using Keras, and that this is not a bug. Can you confirm that?
This issue has been automatically marked as stale because it has no recent activity. It will be closed if no further activity occurs. Thank you.
@k-w-w Kathy, mind taking a look at this issue? Thanks!