bert4keras
How should train_generator.to_dataset() be written for multi-GPU training?
When asking a question, please provide as much of the following information as possible:
Basic information
- Operating system: linux
- Python version: 3.6.4
- Tensorflow version: 1.14
- Keras version: 2.3.1
- bert4keras version: 0.10.6
- Pure keras or tf.keras: the first line of the script sets os.environ['TF_KERAS'] = '1'  # tf.keras is required
- Pretrained model loaded: a BERT model I further pretrained on top of google-bert-base. It is applied to a downstream text-matching task; the downstream model directly reuses the network BERT uses for the NSP pretraining task, as shown in build_model() below.
I want to adapt the text-matching task to multiple GPUs by restructuring task_sentence_similarity_lcqmc.py along the lines of task_seq2seq_autotitle_multigpu.py. The problem is that I don't know how to convert the data to tf.data.Dataset: the example task_seq2seq_autotitle_multigpu.py has no y_true, so how do I introduce y_true in train_generator.to_dataset()? The core code below contains my own train_generator.to_dataset call, but it raises an error.
Core code
# Please paste your core code here.
# Keep only the key parts; do not paste all of it blindly.
def build_model():
    """Build the model."""
    global tokenizer
    # Build the tokenizer
    tokenizer = Tokenizer(dict_path, do_lower_case=True)
    bert = build_transformer_model(
        config_path=config_path,
        checkpoint_path=None,
        return_keras_model=False,
        with_nsp='linear'
    )
    model = bert.model  # this is the actual keras model
    AdamLR = extend_with_piecewise_linear_lr(Adam)
    model.compile(
        loss='sparse_categorical_crossentropy',
        # optimizer=Adam(1e-5),  # use a sufficiently small learning rate
        optimizer=AdamLR(lr=1e-4, lr_schedule={1000: 1, 2000: 0.1}),
        metrics=['accuracy'],
    )
    bert.load_weights_from_checkpoint(checkpoint_path)  # the pretrained weights must be loaded last
    return model
class data_generator(DataGenerator):
    """Data generator."""
    def __iter__(self, random=False):
        for is_end, (label, text1, text2) in self.sample(random):
            token_ids, segment_ids = tokenizer.encode(text1, text2, maxlen=maxlen)
            yield [token_ids, segment_ids], [label]
def do_train():
    # Convert the dataset
    train_generator = data_generator(train_data, batch_size)
    # The data must be converted to tf.data.Dataset format;
    # names must match the names of the input layers.
    train_dataset = train_generator.to_dataset(
        types=(['float32', 'float32'], ['int32']),
        shapes=([[None], [None]], [[None]]),  # with padded_batch=True below, padding is applied automatically
        names=(['Input-Token', 'Input-Segment'], ['NSP-Proba']),
        padded_batch=True
    )
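As background on the shapes argument: with padded_batch=True, each component declared with shape [None] (a rank-1, variable-length vector) is padded to the longest element in the batch. A toy numpy sketch of the idea (pad_to_longest is hypothetical, not the tf.data or bert4keras implementation):

```python
import numpy as np

def pad_to_longest(seqs, pad=0):
    # Pad a list of rank-1 sequences to the longest length, like
    # tf.data's padded_batch does for a component of shape [None].
    maxlen = max(len(s) for s in seqs)
    return np.array([list(s) + [pad] * (maxlen - len(s)) for s in seqs])

batch = pad_to_longest([[101, 7, 102], [101, 7, 8, 9, 102]])
print(batch.shape)  # (2, 5)
```

This is also why every element yielded for such a component must already be rank 1: the padding logic compares each element's rank against the declared padded shape.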
Error message
Epoch 1/5
2021-08-03 00:40:29.185873: I tensorflow/stream_executor/platform/default/dso_loader.cc:42] Successfully opened dynamic library libcublas.so.10.0
Traceback (most recent call last):
File "train_with_nsp_multigpu.py", line 210, in <module>
do_train()
File "train_with_nsp_multigpu.py", line 184, in do_train
callbacks=[evaluator]
File "/dockerdata/ashersu/Virtualenv/tf1.14/lib/python3.6/site-packages/tensorflow/python/keras/engine/training.py", line 649, in fit
validation_freq=validation_freq)
File "/dockerdata/ashersu/Virtualenv/tf1.14/lib/python3.6/site-packages/tensorflow/python/keras/engine/training_distributed.py", line 143, in fit_distributed
steps_name='steps_per_epoch')
File "/dockerdata/ashersu/Virtualenv/tf1.14/lib/python3.6/site-packages/tensorflow/python/keras/engine/training_arrays.py", line 274, in model_iteration
batch_outs = f(actual_inputs)
File "/dockerdata/ashersu/Virtualenv/tf1.14/lib/python3.6/site-packages/tensorflow/python/keras/backend.py", line 3292, in __call__
run_metadata=self.run_metadata)
File "/dockerdata/ashersu/Virtualenv/tf1.14/lib/python3.6/site-packages/tensorflow/python/client/session.py", line 1458, in __call__
run_metadata_ptr)
tensorflow.python.framework.errors_impl.InvalidArgumentError: 2 root error(s) found.
(0) Invalid argument: All elements in a batch must have the same rank as the padded shape for component2: expected rank 1 but got element with rank 0
[[{{node MultiDeviceIteratorGetNextFromShard}}]]
[[RemoteCall]]
[[IteratorGetNext]]
(1) Invalid argument: All elements in a batch must have the same rank as the padded shape for component2: expected rank 1 but got element with rank 0
[[{{node MultiDeviceIteratorGetNextFromShard}}]]
[[RemoteCall]]
[[IteratorGetNext]]
[[Adam/Adam/update_Transformer-11-MultiHeadSelfAttention/dense_66/bias/update_1/GreaterEqual_1/_7661]]
0 successful operations.
1 derived errors ignored.
Self-attempts
Whatever the problem, please try to solve it yourself first; ask here only after "every possible effort" has failed. Paste your attempts in this section.
I experimented with many approaches but still get the error. Please take a look; tf.data.Dataset has been tormenting me and I still can't make sense of it.
Try changing [label] to [[label]].
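To illustrate why the extra nesting matters, here is a minimal self-contained sketch (train_data and fake_encode are toy stand-ins, not the real data or tokenizer.encode): with shapes=([[None], [None]], [[None]]), each label element must be rank 1, so the generator has to yield [[label]] (outputs tuple containing a one-element vector) rather than [label] (outputs tuple containing a scalar):

```python
import numpy as np

# Toy stand-ins for the question's data and tokenizer (hypothetical values).
train_data = [(1, "text a", "text b"), (0, "text c", "text d")]

def fake_encode(t1, t2):
    # Stand-in for tokenizer.encode: returns (token_ids, segment_ids).
    return [101, 7, 8, 102], [0, 0, 0, 1]

def sample_iter():
    for label, t1, t2 in train_data:
        token_ids, segment_ids = fake_encode(t1, t2)
        # [[label]], not [label]: the outer list is the tuple of model
        # outputs; the inner [label] gives each label rank 1, matching
        # the declared padded shape [[None]].
        yield [token_ids, segment_ids], [[label]]

x, y = next(sample_iter())
print(np.asarray(y[0]).ndim)  # 1 — rank 1, as the error message demands
```

With the original [label], y[0] would be the scalar 1 (rank 0), which is exactly the "expected rank 1 but got element with rank 0" complaint in the traceback.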
Has this been solved? I'm running into a similar problem.
> Try changing [label] to [[label]].
After modifying it as you suggested, multi-GPU training now works. FYI @1073521013
Thanks!
@SuMeng123 Could I take a look at your multi-GPU text-matching training code for reference? Thanks!