GPT2-Chinese
GPT2-Chinese copied to clipboard
运行generate.py是出错KeyError: 'state_dict'已安装好依赖并使用预训练模型
求教,小白一枚,请问如何解决
运行generate.py是出错KeyError: 'state_dict'已安装好依赖并使用预训练模型
python generate.py --model_path "D:\BaiduNetdiskDownload\gpt2\pytorch_model.bin"
args:
Namespace(batch_size=1, device='0', fast_pattern=False, length=512, model_config='config/model_config.json', model_path='D:\BaiduNetdiskDownload\gpt2\pytorch_model.bin', n_ctx=1024, no_wordpiece=False, nsamples=10, prefix='我', repetition_penalty=1.0, save_samples=False, save_samples_path='.', segment=False, temperature=1, tokenizer_path='vocab/vocab.txt', topk=8, topp=0)
Traceback (most recent call last):
File "generate.py", line 232, in
你要按照requirements.txt中版本安装!
你要按照requirements.txt中版本安装!
您好,我是按照要求安装的,也出现相同的问题,请问这如何解决
求教,小白一枚,请问如何解决 运行generate.py是出错KeyError: 'state_dict'已安装好依赖并使用预训练模型 python generate.py --model_path "D:\BaiduNetdiskDownload\gpt2\pytorch_model.bin" args: Namespace(batch_size=1, device='0', fast_pattern=False, length=512, model_config='config/model_config.json', model_path='D:\BaiduNetdiskDownload\gpt2\pytorch_model.bin', n_ctx=1024, no_wordpiece=False, nsamples=10, prefix='我', repetition_penalty=1.0, save_samples=False, save_samples_path='.', segment=False, temperature=1, tokenizer_path='vocab/vocab.txt', topk=8, topp=0) Traceback (most recent call last): File "generate.py", line 232, in main() File "generate.py", line 205, in main for key, value in torch.load(args.model_path, map_location="cpu")[ KeyError: 'state_dict'
请问您解决了,我也遇到相同的错误,希望给点建议
我在colab上用的pip install torch==1.9.0
,这样是没问题的,感觉上还是版本问题。requirements好像是1.8.0吧
我遇到了同样的错误,通过Google 找到了解决方案。
应该是由于train.py
重构后,generate.py
不支持从前的模型了。
解决方案
使用
strict=False
加载模型
如下所示,找到generate.py
的203行,注释部分代码,并加入新的代码。
# state_dict = {
# key[6:]: value
# for key, value in torch.load(args.model_path, map_location="cpu")[
# "state_dict"
# ].items()
# }
# model.load_state_dict(state_dict)
model.load_state_dict(torch.load(args.model_path), strict=False)
然后就能正常加载模型了。