generate
generate copied to clipboard
The BaichuanChatParameters missed 'max_tokens' attribute
import asyncio
import os
from generate import BaichuanChat
from generate.platforms.baichuan import BaichuanSettings
async def main():
model = BaichuanChat(
settings=BaichuanSettings(
api_key=os.environ.get("BAICHUAN_API_KEY"), secret_key=""
),
).session()
messages = [
{"role": "system", "content": "You're an assistant."},
{"role": "user", "content": "北京有哪些好玩的地方?"},
]
model.history = messages[:-1]
response = model.async_stream_generate(messages[-1], max_tokens=5)
async for chunk in response:
token = chunk.stream.delta
print(token, flush=True, end="")
asyncio.run(main())
上面代码通过max_tokens参数,期望约束回答的长度。但结果表明约束失败,其代码执行结果如下:
北京有很多好玩的地方,以下是一些建议:
1. 故宫:这是世界上最大的宫殿建筑群,也是中国历史的重要象征。
2. 颐和园:这是一个融合了皇家园林艺术的大型湖光山色园林,被誉为“皇家园林博物馆”。
3. 天安门广场:这是世界上最大的城市广场,也是中国的政治和文化中心。
4. 长城:这是中国最具代表性的景点之一,被誉为世界七大奇迹之一。
5. 北京胡同:这里可以体验到老北京的风情和生活方式。
6. 南锣鼓巷:这里有许多特色小吃和手工艺品店。
7. 王府井步行街:这里是购物和品尝北京小吃的好去处。
8. 798艺术区:这里是中国当代艺术的发源地,有许多画廊、艺术展览和创意设计店。
9. 北京动物园:这里有各种各样的动物,包括大熊猫。
10. 国家大剧院:这是一个现代化的艺术中心,有各种音乐、戏剧和舞蹈表演。
分析可能的原因是:BaichuanChatParameters类缺失了max_tokens定义。 具体可见这个fix: https://github.com/cddc/generate/commit/8a1076586f9c3d1ea03d408c9b27874056f4ebce