seq2seq
Using keras.models.load_model leads to a TypeError
from keras.models import load_model
from recurrentshop.engine import RecurrentContainer
from seq2seq.cells import LSTMDecoderCell
def mean_squared_displacement(y_true, y_pred):
    """Custom Keras loss: mean of per-point squared Euclidean displacements.

    The element-wise difference ``y_true - y_pred`` is squared and summed
    over the last axis, which gives one squared distance per point. Those
    distances are then averaged over every remaining element, and the
    result is returned as a single scalar.

    Args:
        y_true: target tensor.
        y_pred: prediction tensor, assumed to match ``y_true`` in shape.
            NOTE(review): presumably (batch, output_length, output_dim),
            i.e. 2-D points per step given output_dim=2 -- confirm.

    Returns:
        Scalar tensor holding the averaged squared displacement.

    NOTE(review): ``K`` is not imported in this snippet; presumably it is
    ``from keras import backend as K`` -- confirm.
    """
    squared_error = K.square(y_true - y_pred)
    # Collapse the coordinate axis into one squared distance per point.
    per_point_distance = K.sum(squared_error, axis=-1)
    return K.mean(per_point_distance)
# Reproduction of the bug: build, train and save a full Seq2Seq model, then
# reload it with keras.models.load_model, which raises the TypeError shown
# in the traceback.
# NOTE(review): Seq2Seq, X_train and Y_train are not defined in this snippet;
# presumably Seq2Seq comes from the seq2seq package -- confirm.
model = Seq2Seq(input_dim=2,
input_length=8,
output_dim=2,
output_length=12,
hidden_dim=128,
dropout=0.0,
peek=False,
depth=1)
# The custom loss is referenced by name in the saved file, so it has to be
# supplied again through custom_objects when reloading (see load_model below).
# NOTE(review): 'accuracy' is likely not meaningful for this regression-style
# loss -- consider dropping it or using a distance-based metric.
model.compile(loss= mean_squared_displacement, optimizer='adam', metrics=['accuracy'])
# nb_epoch is the Keras 1.x keyword (renamed to `epochs` in Keras 2).
model.fit(X_train,Y_train, nb_epoch=20)
# Save the whole model (architecture + weights) to a single HDF5 file.
model.save('my_model.h5')
del model
# This call fails: while rebuilding the graph, recurrentshop's
# RecurrentContainer.__call__ (engine.py line 531) runs len(x[3]) on a value
# that is a single tensor, raising
# "TypeError: object of type 'TensorVariable' has no len()".
model = load_model('my_model.h5', custom_objects={'RecurrentContainer': RecurrentContainer, 'LSTMDecoderCell': LSTMDecoderCell, 'mean_squared_displacement': mean_squared_displacement})
Traceback (most recent call last):
  File "save_load_test.py", line 25, in <module>
    results = runPredictionByConfig(trajectories, config, save_dir)
  File "/Prediction/run_prediction.py", line 83, in runTrajectoryPredictionByConfig
    verbose=config.getint('general', 'verbose'))
  File "/Prediction/run_prediction.py", line 192, in runTrajectoryPrediction
    'mean_squared_displacement': mean_squared_displacement})
  File "/.local/lib/python2.7/site-packages/keras/models.py", line 140, in load_model
    model = model_from_config(model_config, custom_objects=custom_objects)
  File "/.local/lib/python2.7/site-packages/keras/models.py", line 189, in model_from_config
    return layer_from_config(config, custom_objects=custom_objects)
  File "/.local/lib/python2.7/site-packages/keras/utils/layer_utils.py", line 34, in layer_from_config
    return layer_class.from_config(config['config'])
  File "/.local/lib/python2.7/site-packages/keras/engine/topology.py", line 2398, in from_config
    process_layer(layer_data)
  File "/.local/lib/python2.7/site-packages/keras/engine/topology.py", line 2395, in process_layer
    layer(input_tensors)
  File "/.local/lib/python2.7/site-packages/recurrentshop/engine.py", line 531, in __call__
    for i in range(len(x[3])):
TypeError: object of type 'TensorVariable' has no len()
An obvious workaround is to save and load only the weights, rebuilding the model in code before loading them:
# Workaround for the load_model TypeError: persist only the weights, rebuild
# the same architecture in code, and load the weights into the new model.
#
# Both models are built from ONE shared argument dict. load_weights only works
# when the rebuilt model has exactly the same layers and shapes as the saved
# one, and two hand-copied constructor calls can silently drift apart.
# dropout/peek are spelled out to match the model trained in the reproduction
# script earlier in this issue (dropout=0.0, peek=False).
# NOTE(review): Seq2Seq is not imported in this snippet; presumably it comes
# from the seq2seq package -- confirm.
seq2seq_kwargs = dict(input_dim=2,
                      input_length=8,
                      output_dim=2,
                      output_length=12,
                      hidden_dim=128,
                      dropout=0.0,
                      peek=False,
                      depth=1)

model1 = Seq2Seq(**seq2seq_kwargs)
model1.save_weights('model_weights.h5')
...  # elided in the original snippet (presumably training / a later session)
model2 = Seq2Seq(**seq2seq_kwargs)
model2.load_weights('model_weights.h5')
# NOTE: save_weights/load_weights restore weights only, not the compile
# configuration. Re-run model2.compile(...) with the same custom loss before
# training or evaluating model2.
@BmlmnnsDev Hi, are there any updates on this thread? I've run into the same problem.