GlobalPointer
GlobalPointer copied to clipboard
使用 TF 2.4 保存 pb（SavedModel）报错 — Error when saving the model as a SavedModel (.pb) with TensorFlow 2.4
加载预训练模型（Load the pre-trained model）
import os

# Build the pre-trained transformer encoder (presumably bert4keras'
# build_transformer_model -- the import is not shown here). With
# return_keras_model=False a wrapper object is returned; its .input/.output
# are used below to attach the task head.
model = build_transformer_model(
    config_path=config_path,
    checkpoint_path=checkpoint_path,
    return_keras_model=False,
)

# Attach a GlobalPointer head on top of the encoder output: one head per entry
# in `categories`. The second argument (64) is presumably the per-head
# projection size -- TODO confirm against the GlobalPointer definition.
# These used to be three statements fused onto one line, which is a syntax
# error; they must be separate statements.
output = GlobalPointer(len(categories), 64)(model.output)
model = keras.models.Model(model.input, output)
model.summary()

model.compile(
    loss=global_pointer_crossentropy,
    optimizer=Adam(learning_rate),
    metrics=[global_pointer_f1_score],
)

evaluator = Evaluator()
train_generator = data_generator(train_data, batch_size)

# forfit() is used as the training input, with len(train_generator) as the
# number of steps per epoch.
model.fit(
    train_generator.forfit(),
    steps_per_epoch=len(train_generator),
    epochs=epochs,
    callbacks=[evaluator],
)

# Export as a SavedModel into a versioned directory, <export_path>/<version>.
# This is the layout TF Serving loads. The previous string concatenation
# (export_path + version) wrote to "model1" instead of "model/1".
export_path = 'model'
version = "1"
# NOTE(review): the reported
#   AttributeError: 'Dropout' object has no attribute '_saved_model_inputs_spec'
# is raised by this call. That error usually indicates layers built with
# standalone `keras` being saved through tf.keras' SavedModel machinery.
# A commonly suggested workaround for bert4keras is to set the environment
# variable TF_KERAS=1 *before* importing it, so every layer is a tf.keras
# layer -- TODO confirm for the installed bert4keras/keras versions.
model.save(os.path.join(export_path, version), save_format="tf")
`model.save(..., save_format="tf")` 报错（error raised by the save call）：
AttributeError: 'Dropout' object has no attribute '_saved_model_inputs_spec'