textClassifier
save model error
In textClassifierHATT.py, I tried to save the model using the following callback:
from keras.callbacks import ModelCheckpoint
mcp = ModelCheckpoint('HANmodel_weights.h5', monitor="val_acc", save_best_only=True, save_weights_only=False)
model.fit(x_train, y_train, validation_data=(x_val, y_val), epochs=20, batch_size=50, callbacks=[mcp])
But the following error occurred:
...
RuntimeError: Unable to create link (Name already exists)
There are suggestions here that non-unique layer names cause this problem, but I haven't found any duplicate names in this model.
Have you solved it? I also have this problem. Thanks.
Hey guys, @kcsmta was right!
The issue comes from saving the weights of the attention layer. I've fixed the AttLayer class by giving its variables custom names:
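def build(self, input_shape):
    assert len(input_shape) == 3
    # Give each variable a distinct name: unnamed variables can end up with the
    # same default name, which collides when the weights are written to HDF5.
    self.W = K.variable(self.init((input_shape[-1], self.attention_dim)), name="W")
    self.b = K.variable(self.init((self.attention_dim, )), name="b")
    self.u = K.variable(self.init((self.attention_dim, 1)), name="u")
    # Works with the Keras 2.x of this repo; newer tf.keras makes trainable_weights
    # read-only, so there the weights should be created with self.add_weight(name=...).
    self.trainable_weights = [self.W, self.b, self.u]
    super(AttLayer, self).build(input_shape)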
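Why renaming helps, as far as I can tell from Keras 2.x's HDF5 saving: it creates one group per layer and one dataset per weight inside that group, keyed by the weight's name. So two weights with the same name inside the same layer are enough to trigger "Name already exists", even when every layer name is unique. The plain-Python sketch below (it does not import Keras) checks for exactly that kind of clash; the layer and weight names in it are made up for illustration, and the default name unnamed variables fall back to (e.g. Variable:0) depends on the backend and version.

```python
from collections import Counter
from typing import Dict, Iterable, List, Tuple


def find_duplicate_weight_names(
    layers: Iterable[Tuple[str, List[str]]]
) -> Dict[str, List[str]]:
    """Return {layer_name: [weight names that occur more than once in that layer]}.

    Mimics the HDF5 rule that one group cannot hold two datasets with the same
    name: weights are stored per layer, so a clash only has to happen inside a
    single layer, not across layers.
    """
    clashes: Dict[str, List[str]] = {}
    for layer_name, weight_names in layers:
        counts = Counter(weight_names)
        dupes = sorted(name for name, n in counts.items() if n > 1)
        if dupes:
            clashes[layer_name] = dupes
    return clashes


# Hypothetical names: three unnamed variables sharing one default name,
# versus the renamed W / b / u from the fix above.
before = [("dense_1", ["kernel:0", "bias:0"]),
          ("att_layer_1", ["Variable:0", "Variable:0", "Variable:0"])]
after = [("dense_1", ["kernel:0", "bias:0"]),
         ("att_layer_1", ["W:0", "b:0", "u:0"])]

print(find_duplicate_weight_names(before))  # {'att_layer_1': ['Variable:0']}
print(find_duplicate_weight_names(after))   # {}
```

With a real model, the input could be built as [(layer.name, [w.name for w in layer.weights]) for layer in model.layers].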
@richliao, could you accept the pull request?
@pmm-511, ... if you still need a fix
Just using a named layer wasn't enough for me to save the model; I tried saving both the complete model and only the weights. The problem was that if you initialize variables in the __init__ method, you should also return them in the get_config method.
from keras import initializers
from keras.layers import Layer


class AttentionLayer(Layer):
    def __init__(self, attention_dim, supports_masking=True, **kwargs):
        # Call the base constructor once. Default the name, but let a reloaded
        # config (which carries its own "name") override it.
        kwargs.setdefault("name", "attention_layer")
        super(AttentionLayer, self).__init__(**kwargs)
        self.init = initializers.get("normal")
        self.supports_masking = supports_masking
        self.attention_dim = attention_dim

    def get_config(self):
        # Return every constructor argument so the layer can be rebuilt on load.
        config = {
            "supports_masking": self.supports_masking,
            "attention_dim": self.attention_dim,
        }
        base_config = super(AttentionLayer, self).get_config()
        return dict(list(base_config.items()) + list(config.items()))
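Please make sure that you are not returning the init variable in the get_config method: init is set inside __init__ rather than passed in as a constructor argument, so it does not belong in the config.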
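Both tips follow from how a layer is rebuilt on load: in Keras 2.x the default Layer.from_config calls cls(**config), and the base Layer constructor raises a TypeError for keyword arguments it does not recognize. The plain-Python sketch below mimics that contract without importing Keras; BaseLayer, ALLOWED_KWARGS, MissingArgLayer and ExtraArgLayer are hypothetical stand-ins, not real Keras classes, and init is a plain string here instead of an initializer object. It shows that a config missing attention_dim, or carrying an extra init key, both fail at reload time.

```python
import json

# Hypothetical subset of the keyword arguments a Keras base layer accepts.
ALLOWED_KWARGS = {"name", "trainable", "dtype"}


class BaseLayer:
    """Stand-in for the Keras base layer: validates kwargs, round-trips via config."""

    def __init__(self, **kwargs):
        for key in kwargs:
            if key not in ALLOWED_KWARGS:
                raise TypeError(f"Keyword argument not understood: {key}")
        self.name = kwargs.get("name")

    def get_config(self):
        return {"name": self.name}

    @classmethod
    def from_config(cls, config):
        # Mirrors the default behaviour: rebuild by calling the constructor.
        return cls(**config)


class AttentionLayer(BaseLayer):
    def __init__(self, attention_dim, supports_masking=True, **kwargs):
        kwargs.setdefault("name", "attention_layer")
        super().__init__(**kwargs)
        self.init = "normal"  # fixed inside __init__, not a constructor argument
        self.supports_masking = supports_masking
        self.attention_dim = attention_dim

    def get_config(self):
        config = {"supports_masking": self.supports_masking,
                  "attention_dim": self.attention_dim}
        return {**super().get_config(), **config}


class MissingArgLayer(AttentionLayer):
    def get_config(self):
        return BaseLayer.get_config(self)  # forgets attention_dim


class ExtraArgLayer(AttentionLayer):
    def get_config(self):
        return {**super().get_config(), "init": self.init}  # init is not an __init__ arg


def reload(layer):
    """Send the config through JSON and back, then rebuild; return the layer or the error."""
    config = json.loads(json.dumps(layer.get_config()))
    try:
        return type(layer).from_config(config)
    except TypeError as err:
        return err


ok = reload(AttentionLayer(attention_dim=100))
missing = reload(MissingArgLayer(attention_dim=100))
extra = reload(ExtraArgLayer(attention_dim=100))
print(ok.attention_dim)  # 100
print(type(missing).__name__, type(extra).__name__)  # TypeError TypeError
```

The complete config reloads cleanly, while the other two fail for opposite reasons: one lacks a required constructor argument, the other passes one the constructor does not accept.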