keras-contrib
Saving & Loading CRF layer with load_model
I'm now using the CRF layer and find it very useful. Thanks for developing it :) But I found one inconvenient part: saving and loading the model.
I wanted to load the CRF layer with load_model(model_path).
Other layers in keras-contrib can be loaded automatically with load_model(), because calls such as get_custom_objects().update({'CosineDense': CosineDense}) inform the original Keras about the custom layer.
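Conceptually, this works because deserialization resolves objects by name through a dictionary. The sketch below is a plain-Python analogy of that lookup, not Keras's actual implementation; `_CUSTOM_OBJECTS`, `register` and `deserialize` are hypothetical names used only for illustration, and `CosineDense` here is an empty placeholder class.

```python
# Hypothetical stand-in for Keras's global custom-object registry.
_CUSTOM_OBJECTS = {}


def register(name, obj):
    """Make `obj` resolvable by `name`, similar in spirit to get_custom_objects().update(...)."""
    _CUSTOM_OBJECTS[name] = obj


def deserialize(name, custom_objects=None, kind="object"):
    """Resolve a saved name: explicit custom_objects take precedence over the registry."""
    lookup = dict(_CUSTOM_OBJECTS)
    lookup.update(custom_objects or {})
    if name not in lookup:
        # Same shape as the Keras message "Unknown loss function:loss".
        raise ValueError("Unknown " + kind + ":" + name)
    return lookup[name]


class CosineDense:  # placeholder for the real keras-contrib layer
    pass


register("CosineDense", CosineDense)
print(deserialize("CosineDense") is CosineDense)  # True
```

The point of the analogy: anything the saved model refers to by name must be resolvable at load time, which is exactly what fails for the CRF loss below.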
However, registering the CRF layer this way is not enough, because CRF also comes with a custom loss function, used like this:
model.compile(loss=crf.loss_function,
              optimizer='adam',
              metrics=[crf.accuracy])
I also tried the following, but it doesn't work:
model.save(path)  # Model.save() does not accept a custom_objects argument
load_model(path, custom_objects={'CRF': CRF})
  File "model.py", line 16, in __init__
    self.model = load_model(model_path, custom_objects={'CRF': CRF})
  File "/Library/Frameworks/Python.framework/Versions/3.6/lib/python3.6/site-packages/keras/models.py", line 264, in load_model
    sample_weight_mode=sample_weight_mode)
  File "/Library/Frameworks/Python.framework/Versions/3.6/lib/python3.6/site-packages/keras/models.py", line 781, in compile
    **kwargs)
  File "/Library/Frameworks/Python.framework/Versions/3.6/lib/python3.6/site-packages/keras/engine/training.py", line 681, in compile
    loss_function = losses.get(loss)
  File "/Library/Frameworks/Python.framework/Versions/3.6/lib/python3.6/site-packages/keras/losses.py", line 102, in get
    return deserialize(identifier)
  File "/Library/Frameworks/Python.framework/Versions/3.6/lib/python3.6/site-packages/keras/losses.py", line 94, in deserialize
    printable_module_name='loss function')
  File "/Library/Frameworks/Python.framework/Versions/3.6/lib/python3.6/site-packages/keras/utils/generic_utils.py", line 159, in deserialize_keras_object
    ':' + function_name)
ValueError: Unknown loss function:loss
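The name `loss` in the error most likely comes from the inner function: `loss_function` returns a closure literally named `loss`, and as far as I understand Keras 2.x stores a function-valued loss in the saved training config by its `__name__`. On reload, nothing called `loss` is registered, so deserialization fails. A plain-Python sketch of the naming, where `FakeCRF` is a hypothetical stand-in rather than the real layer:

```python
class FakeCRF:
    """Hypothetical stand-in for the CRF layer; only the naming matters here."""

    def __init__(self, units):
        self.units = units

    @property
    def loss_function(self):
        # Like the real layer, return a closure that captures `self`.
        def loss(y_true, y_pred):
            return 0.0  # placeholder; the real layer computes a CRF negative log-likelihood
        return loss


crf = FakeCRF(units=5)
saved_name = crf.loss_function.__name__  # the name a saved config would record
print(saved_name)  # loss
```

Note also that every access to the property builds a new closure tied to one particular instance, so there is no single module-level function Keras could look up under that name.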
Any ideas on how to save the CRF layer together with its loss function? :)
Here is an approach I found, but it has an issue:
get_custom_objects().update({'CRF':CRF, 'CRFLoss':CRF.loss_function})
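Two things plausibly go wrong with the line above. First, `loss_function` is a property, so `CRF.loss_function` accessed on the class returns the `property` descriptor itself, not a callable loss. Second, it is registered under the key `'CRFLoss'`, while the traceback shows the saved model asks for a loss named `loss`. A plain-Python sketch with a hypothetical `FakeCRF` stand-in and an ordinary dict in place of the Keras registry:

```python
class FakeCRF:
    """Hypothetical stand-in: loss_function is a property, as in the CRF layer."""

    @property
    def loss_function(self):
        def loss(y_true, y_pred):
            return 0.0  # placeholder value
        return loss


registry = {"CRFLoss": FakeCRF.loss_function}

# On the class, the attribute is the property descriptor, not a loss function.
print(type(registry["CRFLoss"]).__name__)  # property
print(callable(registry["CRFLoss"]))  # False
# The saved model asks for "loss", which is not a registered key.
print("loss" in registry)  # False
```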
The code below shows the CRF layer's loss_function property:
@property
def loss_function(self):
    if self.learn_mode == 'join':
        def loss(y_true, y_pred):
            assert self.inbound_nodes, 'CRF has not connected to any layer.'
            assert not self.outbound_nodes, 'When learn_model="join", CRF must be the last layer.'
            if self.sparse_target:
                y_true = K.one_hot(K.cast(y_true[:, :, 0], 'int32'), self.units)
            X = self.inbound_nodes[0].input_tensors[0]
            mask = self.inbound_nodes[0].input_masks[0]
            nloglik = self.get_negative_log_likelihood(y_true, X, mask)
            return nloglik
        return loss
    else:
        if self.sparse_target:
            return sparse_categorical_crossentropy
        else:
            return categorical_crossentropy
So the loss function is not static (it is not a standalone public function); it computes the loss using the layer object's own state. Is there any way to make an independent crf_loss_function?
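One way to get an instance-independent loss is a module-level function that finds the layer through the prediction it receives, so the function has a stable name that can be registered once. The sketch below assumes the output object carries a back-reference to the layer that produced it; `Output`, `FakeCRF`, `crf_loss` and `negative_log_likelihood` are all hypothetical, and the loss value is a placeholder. In Keras 2, symbolic tensors carry a `_keras_history` attribute pointing back to the producing layer, which might serve as that back-reference, but that wiring is not shown here.

```python
class Output:
    """Hypothetical output object carrying a back-reference to its layer."""

    def __init__(self, values, producer):
        self.values = values
        self.producer = producer


class FakeCRF:
    """Hypothetical stand-in for the CRF layer."""

    def __call__(self, values):
        # Tag the output with the layer that produced it.
        return Output(values, producer=self)

    def negative_log_likelihood(self, y_true, y_pred):
        # Placeholder computation; the real layer would use its transition weights.
        return sum(abs(t - p) for t, p in zip(y_true, y_pred.values))


def crf_loss(y_true, y_pred):
    """Module-level loss: no closure over a layer, so its name is stable."""
    layer = y_pred.producer
    return layer.negative_log_likelihood(y_true, y_pred)


crf = FakeCRF()
pred = crf([0.5, 0.25])
print(crf_loss([1.0, 0.0], pred))  # 0.75
```

Because `crf_loss` is a plain function, it could be registered once by name (e.g. in `custom_objects`) and still reach the correct layer after loading.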
@codertimo I think a good solution is given here (https://github.com/farizrahman4u/keras-contrib/issues/125).
@codertimo Have you solved the issue? Could you share some hints with me?
@codertimo Did you find a viable solution other than #125?
Taking inspiration from https://github.com/Hironsan/keras-crf-layer, I think I've found a workaround.
To save a model, just use the model's save method:
model.save(path)
To reload the model, I use this:
from keras_contrib.layers import CRF
from keras.models import load_model


def create_custom_objects():
    # Holds the CRF instance that load_model() constructs while rebuilding the model.
    instanceHolder = {"instance": None}

    class ClassWrapper(CRF):
        def __init__(self, *args, **kwargs):
            instanceHolder["instance"] = self
            super(ClassWrapper, self).__init__(*args, **kwargs)

    # Free functions that delegate to the rebuilt instance's own
    # loss_function / accuracy properties at call time.
    def loss(*args):
        method = getattr(instanceHolder["instance"], "loss_function")
        return method(*args)

    def accuracy(*args):
        method = getattr(instanceHolder["instance"], "accuracy")
        return method(*args)

    return {"ClassWrapper": ClassWrapper, "CRF": ClassWrapper, "loss": loss, "accuracy": accuracy}


def load_keras_model(path):
    model = load_model(path, custom_objects=create_custom_objects())
    return model
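The workaround above relies on a closure shared between a subclass and free functions: constructing the layer during load_model() records the instance, and the functions registered under the saved names delegate to it. A plain-Python sketch of the same pattern, with a hypothetical `FakeCRF` in place of the real layer, also makes one limitation visible: the holder keeps only the most recently constructed instance, so a model with more than one CRF layer would route every call to the last one.

```python
class FakeCRF:
    """Hypothetical stand-in for keras_contrib.layers.CRF."""

    def __init__(self, units):
        self.units = units

    @property
    def loss_function(self):
        def loss(y_true, y_pred):
            return ("loss", self.units)  # placeholder result identifying the instance
        return loss


def create_custom_objects():
    holder = {"instance": None}

    class ClassWrapper(FakeCRF):
        def __init__(self, *args, **kwargs):
            holder["instance"] = self  # remember the layer built during loading
            super().__init__(*args, **kwargs)

    def loss(*args):
        # Resolve the instance's loss lazily, once the layer exists.
        return holder["instance"].loss_function(*args)

    return {"CRF": ClassWrapper, "loss": loss}


objs = create_custom_objects()
first = objs["CRF"](units=3)
second = objs["CRF"](units=7)
print(objs["loss"](None, None))  # ('loss', 7): only the last instance is used
```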
@kamei86i Thank you for this solution. Reloading the model with it works well for predict(). However, when I call evaluate() after loading the model, it raises an "Incompatible shapes" exception.
@kamei86i Thanks for the solution. However, when I load the model this way, the weights are actually randomly initialized: evaluating the model gives different results each time I load the same file. Can you post the full working solution?
@kamei86i Thank you for the solution. I am able to reload the model. However, when I try to continue training it, it raises an "Incompatible shapes" exception. Could you help me?