transformers_tasks
transformers_tasks copied to clipboard
多卡训练后,推理时提示找不到tokenization_chatglm.py
老师您好,在根据您的方法进行多gpu运算之后,在执行python inference.py进行推理的时候,出现了错误。
期待您的解惑。
我意识到自己,在train_multi_gpu.py中更改了本地的模型地址。请问我该如何做?
我意识到可能是config.json中的寻址可以调整为本地,不知道是否如此
同样的问题,请问解决了吗
不知道为啥,ptune训练结束后的模型没有包含一些关键的py文件,我的解决办法是 从原始的chatglm-6b模型处加载基础模型,然后从ptune训练后的模型中加载ptune训练后的参数到基础模型中
diff --git a/LLM/finetune/inference.py b/LLM/finetune/inference.py
index f7d1311..77f9c30 100644
--- a/LLM/finetune/inference.py
+++ b/LLM/finetune/inference.py
@@ -1,3 +1,4 @@
+# coding: utf8
# !/usr/bin/env python3
"""
==== No Bugs in code, just some Random Unexpected FEATURES ====
@@ -20,10 +21,10 @@ inference 训练好的模型。
Author: pankeyu
Date: 2023/03/17
"""
-import time
+import time,os
import torch
-from transformers import AutoTokenizer, AutoModel
+from transformers import AutoTokenizer, AutoModel, AutoConfig
torch.set_default_tensor_type(torch.cuda.HalfTensor)
@@ -64,26 +65,54 @@ if __name__ == '__main__':
device = 'cuda:0'
max_new_tokens = 300
- model_path = "checkpoints/model_1000"
+ ptune_model_path = "D://Software//transformers_tasks-main//LLM//finetune//checkpoints//ptuning//model_best"
+
+ config = AutoConfig.from_pretrained("D:\\software\\chatglm-6b\\chatglm-6b", trust_remote_code=True)
+ config.pre_seq_len = 128
+ config.prefix_projection = False
tokenizer = AutoTokenizer.from_pretrained(
- model_path,
+ "D:\\software\\chatglm-6b\\chatglm-6b",
trust_remote_code=True
)
model = AutoModel.from_pretrained(
- model_path,
+ "D:\\software\\chatglm-6b\\chatglm-6b", config=config,
trust_remote_code=True
).half().to(device)
+
+ prefix_state_dict = torch.load(os.path.join(ptune_model_path, "pytorch_model-00002-of-00002.bin"))
+ new_prefix_state_dict = {}
+ if k.startswith("transformer.prefix_encoder."):
+ new_prefix_state_dict[k[len("transformer.prefix_encoder."):]] = v
+
+ model.transformer.prefix_encoder.load_state_dict(new_prefix_state_dict)
+
+ # 非常重要,用于将pytorch_model-00002-of-00002.bin从gpu内存中删除,否则oom
+ del prefix_state_dict
+
+
+ model = model.half()
+ model.transformer.prefix_encoder.float()
你好,请问解决了吗
你本地模型中包含tokenization_chatglm.py吗