
The returned embeddings do not exactly match huggingface's output

Open · mmmwhy opened this issue 3 years ago · 4 comments

Taking bert-base-chinese as an example, has the author done any evaluation or testing of this?

mmmwhy · Jan 07 '22 07:01

Setup: sentence = "我是一个好男人!", max_len = 32, and .eval() has been set on both models.

huggingface result:

[screenshot: huggingface embedding output]

bert4pytorch result:

[screenshot: bert4pytorch embedding output]

mmmwhy · Jan 07 '22 07:01

Original (huggingface) version:

```python
import torch
from transformers import BertModel, BertTokenizer

sentence = "我是一个好男人!"
max_len = 32

bert_model = BertModel.from_pretrained("/bert-base-chinese")
bert_model.eval()

text_tokenizer = BertTokenizer.from_pretrained("/bert-base-chinese", do_lower_case=True)
tensor_caption = text_tokenizer.encode(
    sentence,
    return_tensors="pt",
    padding="max_length",
    truncation=True,
    max_length=max_len,
)

# Run the forward pass once and read both outputs from it
with torch.no_grad():
    outputs = bert_model(tensor_caption)
pooler_output = outputs.pooler_output
last_hidden_state = outputs.last_hidden_state
```

bert4pytorch version:

```python
import torch
from bert4pytorch.modeling import build_transformer_model
from bert4pytorch.tokenization import Tokenizer

sentence = "我是一个好男人!"
max_len = 32

root_model_path = "/bert-base-chinese"
vocab_path = root_model_path + "/vocab.txt"
config_path = root_model_path + "/config.json"
checkpoint_path = root_model_path + "/pytorch_model.bin"

# Build the tokenizer
tokenizer = Tokenizer(vocab_path)

# Encode the input and zero-pad both id sequences to max_len
tokens_ids, segments_ids = tokenizer.encode(sentence, max_len=max_len)
tokens_ids = tokens_ids + (max_len - len(tokens_ids)) * [0]
segments_ids = segments_ids + (max_len - len(segments_ids)) * [0]
tokens_ids_tensor = torch.tensor([tokens_ids])
segment_ids_tensor = torch.tensor([segments_ids])

model = build_transformer_model(config_path, checkpoint_path, with_pool=True)
model.eval()

encoded_layers, pooled_output = model(tokens_ids_tensor, segment_ids_tensor)
```
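To quantify the mismatch, the two last-hidden-state tensors can be compared directly. A minimal sketch, assuming both scripts above were run in one session and that encoded_layers is the final hidden-state tensor (if bert4pytorch returns a list of layers, take the last element first):

```python
# Hypothetical comparison; `last_hidden_state` comes from the huggingface
# script and `encoded_layers` from the bert4pytorch script above.
diff = (last_hidden_state - encoded_layers).abs()
print("max abs diff:", diff.max().item())
print("allclose:", torch.allclose(last_hidden_state, encoded_layers, atol=1e-5))
```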

mmmwhy · Jan 08 '22 04:01

I tried removing the max_length argument on the transformers side, and with it gone the two outputs match (presumably because the huggingface call above never passes an attention_mask, so the pad tokens are attended to).
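A minimal sketch of the alternative check: keep the max_length padding but pass the attention mask explicitly on the huggingface side, so pad positions are masked out. The tokenizer call is standard transformers usage; whether this then matches bert4pytorch's internal masking is an assumption.

```python
import torch
from transformers import BertModel, BertTokenizer

model = BertModel.from_pretrained("/bert-base-chinese")
model.eval()
tokenizer = BertTokenizer.from_pretrained("/bert-base-chinese")

# Tokenize with padding, but keep the attention mask so the model
# ignores the pad positions instead of attending to them.
enc = tokenizer("我是一个好男人!", return_tensors="pt",
                padding="max_length", truncation=True, max_length=32)
with torch.no_grad():
    out = model(input_ids=enc["input_ids"],
                attention_mask=enc["attention_mask"])
last_hidden_state = out.last_hidden_state
```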

Tongjilibo · Jan 25 '22 03:01

After debugging, I finally traced this to the fact that the hugging face model names its LayerNorm parameters "gamma" and "beta", while the author's parameter-loading mapping uses "weight" and "bias", so those parameters fail to load.
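A minimal sketch of a workaround along those lines: rename the old-style LayerNorm keys in the checkpoint before bert4pytorch loads it. The renaming logic here is an illustration, not bert4pytorch's actual loading code.

```python
import torch

# Load the raw huggingface checkpoint and rename old-style LayerNorm
# keys ("gamma"/"beta") to the names the mapping expects ("weight"/"bias").
state_dict = torch.load("/bert-base-chinese/pytorch_model.bin", map_location="cpu")
renamed = {
    k.replace(".gamma", ".weight").replace(".beta", ".bias"): v
    for k, v in state_dict.items()
}
torch.save(renamed, "/bert-base-chinese/pytorch_model_renamed.bin")
```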

DimariaW · Apr 23 '22 10:04