bert4keras
bert4keras copied to clipboard
Bert模型输入端是否支持像input_mask,label_ids这些输入呢?
提问时请尽可能提供如下信息:
基本信息
- 你使用的操作系统:
- 你使用的Python版本:
- 你使用的Tensorflow版本:
- 你使用的Keras版本:
- 你使用的bert4keras版本: 0.11.3
- 你使用纯keras还是tf.keras:
- 你加载的预训练模型: bert
核心代码
# 参考的是您本仓库举的例子task_iflytek_adversarial_training.py中的一段代码
# 加载预训练模型
bert = build_transformer_model(
config_path=config_path,
checkpoint_path=checkpoint_path,
return_keras_model=False,
)
output = Lambda(lambda x: x[:, 0])(bert.model.output)
output = Dense(
units=num_classes,
activation='sigmoid',
kernel_initializer=bert.initializer
)(output)
model = keras.models.Model(bert.model.input, output)
# 预测部分
token_ids, segment_ids = tokenizer.encode(text, maxlen=maxlen)
pred = model.predict([[token_ids], [segment_ids]])
请问,bert4keras可以像pytorch一样支持input_mask/attention_mask, label_ids这些输入吗?
如何才能再输入端支持这样的输入呢?我拜读了您写的源码,奈何我能力有限,只看到了bert的输入是Input
补充:另外可以问下模型输入的Input-Segment和Input-Token可以自己自定义成其他名字吗?可以在外面再包一层接口吗?
attention_mask、Input-Segment和Input-Token改名,可以通过继承Transformer类/BERT类等来实现。
另外,bert4keras只负责构建bert(transformer),其余大部分需求是个人通过学习keras后实现的。
十分感谢