mindnlp
mindnlp copied to clipboard
LLaMA-3 模型在 MindNLP(MindSpore)下的生成结果与 PyTorch(transformers)版本不一致
复现代码如下:
from mindspore import set_context
set_context(device_target="GPU",device_id=1)
from transformers import AutoModelForCausalLM as PT_AutoModelForCausalLM
from transformers import AutoTokenizer as PT_AutoTokenizer
from mindnlp.transformers import AutoModelForCausalLM as MS_AutoModelForCausalLM
from mindnlp.transformers import AutoTokenizer as MS_AutoTokenizer
# Local checkpoint directory; both frameworks load weights and tokenizer from it.
model_path = "/data00/jiajie_jin/model/LLaMA-3-8b-Instruct"

# Generation settings shared by both runs so that the only variable is the
# framework: sampling is off and the new-token budget and stop ids are fixed.
generation_params = {
    "do_sample": False,
    "max_new_tokens": 100,
    "eos_token_id": [128001, 128009],
}

# Chat-formatted prompt written out by hand (system turn, user turn, then an
# open assistant header). Adjacent literals are concatenated at compile time,
# so the resulting string is a single unbroken prompt.
prompt = (
    "<|begin_of_text|>"
    "<|start_header_id|>system<|end_header_id|>\n\n"
    "Answer the question based on your own knowledge. "
    "Only give me the answer and do not output any other words."
    "<|eot_id|>"
    "<|start_header_id|>user<|end_header_id|>\n\n"
    "Question: What is Osbert Lancaster best known for producing?"
    "<|eot_id|>"
    "<|start_header_id|>assistant<|end_header_id|>\n\n"
)
# --- MindNLP run -------------------------------------------------------------
# Load model and tokenizer from the same checkpoint directory.
ms_model = MS_AutoModelForCausalLM.from_pretrained(model_path)
ms_tokenizer = MS_AutoTokenizer.from_pretrained(model_path)

# Reuse the EOS token as the pad token so that `padding=True` below has a
# pad token to work with.
ms_tokenizer.pad_token = ms_tokenizer.eos_token

# Tokenize the prompt, requesting the "ms" tensor format.
ms_input = ms_tokenizer(
    prompt,
    return_tensors="ms",
    padding=True,
    truncation=True,
    max_length=512,
)

# Generate with the shared settings; the dict-style output exposes
# `.sequences`, which is read below.
ms_outputs = ms_model.generate(
    **ms_input,
    output_scores=True,
    return_dict_in_generate=True,
    **generation_params,
)

# Decode the first (only) sequence in the batch, dropping special tokens.
ms_first_sequence = ms_outputs.sequences[0]
ms_response = ms_tokenizer.decode(
    ms_first_sequence,
    skip_special_tokens=True,
    clean_up_tokenization_spaces=False,
)
# --- PyTorch (transformers) run ----------------------------------------------
# Device string used for both the model and its inputs.
pt_device = "cuda:2"

# Load model and tokenizer from the same checkpoint directory.
pt_model = PT_AutoModelForCausalLM.from_pretrained(model_path)
pt_tokenizer = PT_AutoTokenizer.from_pretrained(model_path)

# Move the model to the target device and switch it to evaluation mode.
pt_model.to(pt_device)
pt_model.eval()

# Reuse the EOS token as the pad token so that `padding=True` below has a
# pad token to work with.
pt_tokenizer.pad_token = pt_tokenizer.eos_token

# Tokenize the prompt as PyTorch tensors, then move them to the model's device.
pt_encoded = pt_tokenizer(
    prompt,
    return_tensors="pt",
    padding=True,
    truncation=True,
    max_length=512,
)
pt_input = pt_encoded.to(pt_device)

# Generate with the shared settings; the dict-style output exposes
# `.sequences`, which is read below.
pt_outputs = pt_model.generate(
    **pt_input,
    output_scores=True,
    return_dict_in_generate=True,
    **generation_params,
)

# Decode the first (only) sequence in the batch, dropping special tokens.
pt_first_sequence = pt_outputs.sequences[0]
pt_response = pt_tokenizer.decode(
    pt_first_sequence,
    skip_special_tokens=True,
    clean_up_tokenization_spaces=False,
)
# Print both decoded responses under distinct headers so the two frameworks'
# outputs can be told apart in the log. The second header previously repeated
# "MS RESPONSE", mislabelling the PyTorch output.
print("###### MS RESPONSE:")
print(ms_response)
print("###### PT RESPONSE:")
print(pt_response)
结果如下: