GPT2-Chinese

KeyError: 'state_dict' when running generate.py (dependencies installed, using a pretrained model)

Open 1402366912 opened this issue 3 years ago • 5 comments

I'm a beginner and would appreciate some help: how do I fix this? Running generate.py fails with KeyError: 'state_dict'. The dependencies are installed and I am using a pretrained model.

python generate.py --model_path "D:\BaiduNetdiskDownload\gpt2\pytorch_model.bin"

args: Namespace(batch_size=1, device='0', fast_pattern=False, length=512, model_config='config/model_config.json', model_path='D:\BaiduNetdiskDownload\gpt2\pytorch_model.bin', n_ctx=1024, no_wordpiece=False, nsamples=10, prefix='我', repetition_penalty=1.0, save_samples=False, save_samples_path='.', segment=False, temperature=1, tokenizer_path='vocab/vocab.txt', topk=8, topp=0)
Traceback (most recent call last):
  File "generate.py", line 232, in <module>
    main()
  File "generate.py", line 205, in main
    for key, value in torch.load(args.model_path, map_location="cpu")[
KeyError: 'state_dict'

1402366912, Jul 16 '21 08:07

You need to install the exact versions listed in requirements.txt!

aigonna, Jul 26 '21 04:07

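"You need to install the exact versions listed in requirements.txt!"

Hello, I did install the versions as required and I still get the same error. How can this be resolved?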

zhanghongyong123456, Jul 27 '21 03:07

Have you solved this? I ran into the same error and would appreciate any suggestions.

zhanghongyong123456, Jul 27 '21 03:07

On Colab I used pip install torch==1.9.0 and it worked without problems, so it still looks like a version issue to me. I think requirements.txt specifies 1.8.0.

aigonna, Jul 27 '21 09:07

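I ran into the same error and found a solution through Google.

Reference link

The cause is probably that, after train.py was refactored, generate.py no longer supports models saved in the earlier format.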
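To illustrate the mismatch, here is a minimal sketch of the two checkpoint layouts involved. It uses plain dictionaries as stand-ins for what torch.load would return, so it runs without PyTorch. The helper name extract_state_dict, the sample key names, and the assumption that the 6-character prefix stripped by key[6:] is "model." are illustrative and not taken from the repository.

```python
from typing import Any, Dict, Mapping

# Assumed 6-character prefix that key[6:] strips in generate.py.
PREFIX = "model."


def extract_state_dict(checkpoint: Mapping[str, Any]) -> Dict[str, Any]:
    """Return bare model weights from either checkpoint layout."""
    if "state_dict" in checkpoint:
        # Wrapped layout: weights sit under "state_dict" with prefixed keys.
        return {
            (key[len(PREFIX):] if key.startswith(PREFIX) else key): value
            for key, value in checkpoint["state_dict"].items()
        }
    # Flat layout: the loaded object already is the state dict.
    return dict(checkpoint)


# Hypothetical stand-ins for loaded checkpoints; real values would be tensors.
wrapped = {"state_dict": {"model.transformer.wte.weight": 1, "model.lm_head.weight": 2}}
flat = {"transformer.wte.weight": 1, "lm_head.weight": 2}

# Indexing flat["state_dict"] is what raises the KeyError in the traceback.
print(extract_state_dict(wrapped) == extract_state_dict(flat))
```

In generate.py, the same check could be applied to the object returned by torch.load(args.model_path, map_location="cpu") before it is passed to model.load_state_dict, so that both layouts load with the default strict checking.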

Solution

Load the model with strict=False

As shown below, go to line 203 of generate.py, comment out the existing loading code, and add the new line.

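# Original loading code, which expects a checkpoint with a "state_dict" key:
# state_dict = {
#     key[6:]: value
#     for key, value in torch.load(args.model_path, map_location="cpu")[
#         "state_dict"
#     ].items()
# }
# model.load_state_dict(state_dict)
# The pretrained file has no "state_dict" key, so load it directly as a state dict.
model.load_state_dict(torch.load(args.model_path, map_location="cpu"), strict=False)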

After that, the model loads normally.
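One caveat: with strict=False, load_state_dict does not raise on mismatched keys. It returns an object with missing_keys and unexpected_keys, so a model can load "successfully" while some weights stay at their initial values. The sketch below reproduces that comparison on plain lists of key names so it runs without PyTorch. The helper diff_keys and the key names are hypothetical.

```python
from typing import Iterable, List, Tuple


def diff_keys(expected: Iterable[str], loaded: Iterable[str]) -> Tuple[List[str], List[str]]:
    """Return (missing, unexpected) key lists, the mismatches strict=False tolerates."""
    expected_set, loaded_set = set(expected), set(loaded)
    missing = sorted(expected_set - loaded_set)      # in the model, absent from the file
    unexpected = sorted(loaded_set - expected_set)   # in the file, unknown to the model
    return missing, unexpected


# Hypothetical key names for illustration only.
model_keys = ["transformer.wte.weight", "lm_head.weight"]
file_keys = ["model.transformer.wte.weight", "model.lm_head.weight"]

missing, unexpected = diff_keys(model_keys, file_keys)
if missing or unexpected:
    print(f"{len(missing)} missing, {len(unexpected)} unexpected: weights not fully loaded")
```

In generate.py, the equivalent check would be to keep the return value of model.load_state_dict(..., strict=False) and confirm that both its missing_keys and unexpected_keys are empty before generating text.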

fzf404, Jul 29 '21 12:07