Baichuan-13B
How do I set the decoding method to greedy decoding?
My question is as in the title. Looking at /blob/main/modeling_baichuan.py, I found that the chat() method accepts a generation_config parameter that can change the decoding method, so I wrote the following:
`generation_config = {
    "pad_token_id": 0,
    "bos_token_id": 1,
    "eos_token_id": 2,
    "user_token_id": 195,
    "assistant_token_id": 196,
    "max_new_tokens": 2048,
    "temperature": 0.3,
    "top_k": 5,
    "top_p": 0.85,
    "repetition_penalty": 1.1,
    "do_sample": False,
    "transformers_version": "4.29.2"
}
messages = []
messages.append({"role": "user", "content": instructs[i]})
generation_config={"num_beams": 1, "do_sample": False}
response = model.chat(tokenizer, messages, generation_config)
print(response)
print('-'*80)`
But after running it, the returned result is not text. It is a streaming object of type generation_utils.TextIterStreamer, as shown below:
<transformers_modules.baichuan-inc_Baichuan-13B-Chat.generation_utils.TextIterStreamer object at 0x7f0fbbaf3820>
But I never specified stream in chat(). I would appreciate it if anyone passing by, or one of the engineers, could answer when they have time. @bc-gpd
Sorry, I only just saw this. generation_config has to be a GenerationConfig object:
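from transformers.generation.utils import GenerationConfig

# model, tokenizer, and model_path are assumed to be set up already
generation_config = GenerationConfig.from_pretrained(model_path)
generation_config.do_stream=False
generation_config.do_sample=False  # disable sampling so decoding is greedy
messages = []
# Prompt: "Which is the highest mountain in the world?"
messages.append({"role": "user", "content": "世界上最高的山是哪个"})
response = model.chat(tokenizer, messages, stream=False, generation_config=generation_config)
print(response)
print('-'*80)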
I tested the code above, and greedy decoding works.
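A possible reason the original call returned a streamer even though stream was never set, sketched under an assumption: if chat() declares stream before generation_config (which the keyword arguments in the working call suggest, but which is not confirmed in this thread), then a third positional argument is bound to stream, and a non-empty dict is truthy. The chat function and FakeStreamer class below are hypothetical stand-ins, not the real Baichuan code.

```python
from typing import Optional


class FakeStreamer:
    """Hypothetical stand-in for the TextIterStreamer object seen in the question."""


def chat(tokenizer, messages, stream=False, generation_config: Optional[dict] = None):
    # Assumed parameter order: `stream` comes before `generation_config`.
    if stream:
        return FakeStreamer()
    return "text response"


config = {"num_beams": 1, "do_sample": False}

# The third positional argument lands in `stream`; a non-empty dict is truthy,
# so the streaming branch is taken.
positional = chat(None, [], config)

# Passing both arguments by keyword avoids the mix-up.
keyword = chat(None, [], stream=False, generation_config=config)

print(type(positional).__name__)
print(keyword)
```

If that assumption holds, passing generation_config by keyword, as in the working call, is what matters in addition to using a GenerationConfig object.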