bert4keras icon indicating copy to clipboard operation
bert4keras copied to clipboard

Bert模型输入端是否支持像input_mask,label_ids这些输入呢?

Open jxyxiangyu opened this issue 3 years ago • 3 comments

提问时请尽可能提供如下信息:

基本信息

  • 你使用的操作系统:
  • 你使用的Python版本:
  • 你使用的Tensorflow版本:
  • 你使用的Keras版本:
  • 你使用的bert4keras版本: 0.11.3
  • 你使用纯keras还是tf.keras:
  • 你加载的预训练模型: bert

核心代码

# 参考的是您本仓库举的例子task_iflytek_adversarial_training.py中的一段代码
# 加载预训练模型
bert = build_transformer_model(
    config_path=config_path,
    checkpoint_path=checkpoint_path,
    return_keras_model=False,
)

output = Lambda(lambda x: x[:, 0])(bert.model.output)
output = Dense(
    units=num_classes,
    activation='sigmoid',
    kernel_initializer=bert.initializer
)(output)

model = keras.models.Model(bert.model.input, output)
# 预测部分
token_ids, segment_ids = tokenizer.encode(text, maxlen=maxlen)
pred = model.predict([[token_ids], [segment_ids]])

请问,bert4keras可以像pytorch一样支持input_mask/attention_mask, label_ids这些输入吗?
如何才能再输入端支持这样的输入呢?我拜读了您写的源码,奈何我能力有限,只看到了bert的输入是Input

jxyxiangyu avatar Aug 15 '22 09:08 jxyxiangyu

补充:另外可以问下模型输入的Input-Segment和Input-Token可以自己自定义成其他名字吗?可以在外面再包一层接口吗?

jxyxiangyu avatar Aug 15 '22 09:08 jxyxiangyu

attention_mask、Input-Segment和Input-Token改名,可以通过继承Transformer类/BERT类等来实现。

另外,bert4keras只负责构建bert(transformer),其余大部分需求是个人通过学习keras后实现的。

bojone avatar Sep 19 '22 06:09 bojone

十分感谢

jxyxiangyu avatar Sep 19 '22 07:09 jxyxiangyu