P-tuning

Error in the PtuningEmbedding layer when converting to a pb file

Open MonkeyTB opened this issue 4 years ago • 2 comments

https://github.com/bojone/P-tuning/blob/aec82943f21268a6c813877ac055631e57cb96c3/bert.py#L122

Hi Su (bojone), I saved a model trained with the P-tuning code in h5 format. When I convert it to a pb file (following https://github.com/bojone/bert4keras/issues/194), it fails with `ValueError: Unknown layer: PtuningEmbedding`. Do you know why this happens? My code is below:

```python
import os
os.environ['TF_KERAS'] = '1'
import numpy as np
import pandas as pd
from bert4keras.backend import keras, K
from bert4keras.layers import Loss, Embedding
from bert4keras.tokenizers import Tokenizer
from bert4keras.models import build_transformer_model, BERT
from bert4keras.optimizers import Adam
from bert4keras.snippets import sequence_padding, DataGenerator
from bert4keras.snippets import open
from bert4keras.layers import Lambda, Dense
from keras.models import load_model
import tensorflow as tf
from tensorflow.python.framework.ops import disable_eager_execution
disable_eager_execution()

model = 'model/Bert_Ptuning.h5'
base = '/model/pb'
keras_model = load_model(model, compile=False)
# <==== Note: the "1" in the model path is the version number. It is required;
# without it, TF Serving reports that it cannot find a servable model.
keras_model.save(base + '/Bert_Ptuning/1', save_format='tf')
```

MonkeyTB · Nov 08 '21 02:11

Pass `custom_objects={'PtuningEmbedding': PtuningEmbedding}` when calling `load_model`.

bojone · Nov 08 '21 02:11
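The following is a minimal sketch of why passing `custom_objects` fixes the error. It assumes TensorFlow 2.x with `tf.keras` and h5py installed. `MyEmbedding` is a hypothetical stand-in for `PtuningEmbedding`. The real class is defined in the P-tuning repo's `bert.py` and would need to be importable or defined in the conversion script. The model, shapes, and file path are illustrative only. An HDF5 file records a custom layer by its class name and config, not by its code, so `load_model` can rebuild that layer only when you supply the class object.

```python
import os
import tempfile

import numpy as np
import tensorflow as tf


class MyEmbedding(tf.keras.layers.Embedding):
    """Hypothetical stand-in for PtuningEmbedding: a custom subclass of a
    built-in layer. The saved file stores only its class name and config,
    so load_model cannot rebuild it unless the class is supplied."""

    def call(self, inputs):
        # Same lookup as Embedding, scaled by 2 only so the layer differs.
        return super().call(inputs) * 2.0


def build_model():
    # Tiny illustrative model: integer token ids -> custom embedding.
    inputs = tf.keras.Input(shape=(4,), dtype="int32")
    outputs = MyEmbedding(input_dim=10, output_dim=3)(inputs)
    return tf.keras.Model(inputs, outputs)


model = build_model()
x = np.array([[1, 2, 3, 4]], dtype="int32")
expected = model.predict(x, verbose=0)

# Save as HDF5, like Bert_Ptuning.h5 in the issue.
path = os.path.join(tempfile.mkdtemp(), "model.h5")
model.save(path)

# Without custom_objects this call fails, because Keras does not know
# which Python class the name "MyEmbedding" refers to.
restored = tf.keras.models.load_model(
    path, custom_objects={"MyEmbedding": MyEmbedding}, compile=False
)
result = restored.predict(x, verbose=0)
print(np.allclose(expected, result))
```

If several loading calls need the same mapping, wrapping them in `tf.keras.utils.custom_object_scope({"MyEmbedding": MyEmbedding})` is an equivalent alternative to passing `custom_objects` each time.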

That worked, thanks!

MonkeyTB · Nov 08 '21 06:11