ChatGLM-Tuning
Could anyone provide an api.py, similar to https://github.com/THUDM/ChatGLM-6B/blob/main/api.py,
as a serving script for the LoRA-finetuned model?
model, tokenizer: just edit them by hand and it works the same way; the only difference is loading the finetuned (pt) weights on top, everything else is identical. You can also copy the contents of that script and tweak it.
from fastapi import FastAPI, Request
from transformers import AutoTokenizer, AutoModel
import uvicorn
import json
import datetime
import torch
from peft import get_peft_model, LoraConfig, TaskType

DEVICE = "cuda"
DEVICE_ID = "0"
CUDA_DEVICE = f"{DEVICE}:{DEVICE_ID}" if DEVICE_ID else DEVICE


def torch_gc():
    # Release cached CUDA memory after each request.
    if torch.cuda.is_available():
        with torch.cuda.device(CUDA_DEVICE):
            torch.cuda.empty_cache()
            torch.cuda.ipc_collect()


app = FastAPI()


@app.post("/")
async def chat(request: Request):
    global model, tokenizer
    json_post_raw = await request.json()
    json_post = json.dumps(json_post_raw)
    json_post_list = json.loads(json_post)
    prompt = json_post_list.get('prompt')
    history = json_post_list.get('history')
    max_length = json_post_list.get('max_length')
    top_p = json_post_list.get('top_p')
    temperature = json_post_list.get('temperature')
    response, history = model.chat(tokenizer,
                                   prompt,
                                   history=history,
                                   max_length=max_length if max_length else 2048,
                                   top_p=top_p if top_p else 0.7,
                                   temperature=temperature if temperature else 0.95,
                                   do_sample=False)
    now = datetime.datetime.now()
    time = now.strftime("%Y-%m-%d %H:%M:%S")
    answer = {
        "response": response,
        "history": history,
        "status": 200,
        "time": time
    }
    log = "[" + time + "] " + '", prompt:"' + \
        prompt + '", response:"' + repr(response) + '"'
    print(log)
    torch_gc()
    return answer


if __name__ == '__main__':
    torch.set_default_tensor_type(torch.cuda.HalfTensor)
    tokenizer = AutoTokenizer.from_pretrained(
        "THUDM/chatglm-6b", trust_remote_code=True)
    model = AutoModel.from_pretrained(
        "THUDM/chatglm-6b", trust_remote_code=True).half().cuda()
    # Wrap the base model with LoRA and load the finetuned adapter weights on top.
    peft_path = "output/you/train/model/adapter_model.bin"
    peft_config = LoraConfig(
        task_type=TaskType.CAUSAL_LM, inference_mode=True,
        r=8,
        lora_alpha=32, lora_dropout=0.1
    )
    model = get_peft_model(model, peft_config)
    model.load_state_dict(torch.load(peft_path), strict=False)
    model.eval()
    uvicorn.run(app, host='0.0.0.0', port=8000, workers=1)
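For reference, a minimal way to call this endpoint once the script is running; a sketch that assumes the default host and port above, with request fields matching what the handler reads:

import requests

# Minimal client for the api.py above; assumes it is listening on http://127.0.0.1:8000.
resp = requests.post(
    "http://127.0.0.1:8000",
    json={"prompt": "你好", "history": []},
)
data = resp.json()
print(data["response"])   # the model reply; data["history"] can be sent back on the next turn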
(Quotes the api.py script above; the version quoted here also contains a response = engine.process(response) line right after the timestamp, which is what the errors later in this thread refer to.)
Awesome!
@Ling-yunchi
Awesome! It would be even better if you could also provide a gradio version.
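Since a gradio version was requested, here is a minimal sketch of a gradio front-end that simply forwards requests to the api.py above. It assumes the API is running on port 8000; the API_URL constant and the wiring below are illustrative, not part of the original script:

import gradio as gr
import requests

API_URL = "http://127.0.0.1:8000"  # assumed address of the api.py server above


def respond(message, chat_history):
    # The handler expects "prompt"/"history" and returns "response"/"history";
    # ChatGLM's history of [query, reply] pairs matches what gr.Chatbot displays.
    data = requests.post(API_URL, json={"prompt": message, "history": chat_history or []}).json()
    return "", data["history"]


with gr.Blocks() as demo:
    chatbot = gr.Chatbot()
    msg = gr.Textbox(placeholder="输入问题后回车")
    msg.submit(respond, [msg, chatbot], [msg, chatbot])

demo.launch(server_name="0.0.0.0", server_port=7860)

Because the API already returns the full updated history, the front-end stays stateless and just feeds that history back into the chatbot component.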
For a frontend/backend-separated setup like this one, you can take a look at fastchat.
@suc16 fastchat???
I mean use fastchat as a reference. This fastchat server is a fairly good example of frontend/backend separation: https://github.com/lm-sys/FastChat/blob/main/fastchat/serve/gradio_web_server.py
This repo seems to have integrated chatglm as well; worth a look: https://github.com/oobabooga/text-generation-webui
This repo is indeed easier to use as a reference, and it also implements a stream_generate API; modifying fastchat would be fairly difficult.
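If the goal is only token streaming on top of the FastAPI script above, one rough sketch is an extra route that wraps ChatGLM-6B's stream_chat generator (the one the official web_demo uses) in a StreamingResponse. The /stream route name and plain-text chunk format here are made up for illustration, and the route relies on the app, model, and tokenizer defined in the api.py above:

from fastapi.responses import StreamingResponse

@app.post("/stream")
async def stream(request: Request):
    body = await request.json()
    prompt = body.get("prompt")
    history = body.get("history") or []

    def token_chunks():
        sent = ""
        # stream_chat yields (response_so_far, history); emit only the newly generated part.
        for response, _ in model.stream_chat(tokenizer, prompt, history=history):
            yield response[len(sent):]
            sent = response

    return StreamingResponse(token_chunks(), media_type="text/plain")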
The browser can't open http://127.0.0.1:800, so how do I call it?
INFO:     Started server process [882]
INFO:     Waiting for application startup.
INFO:     Application startup complete.
INFO:     Uvicorn running on http://127.0.0.1:8004 (Press CTRL+C to quit)
The dtype of attention mask (torch.int64) is not bool
INFO:     127.0.0.1:42284 - "POST / HTTP/1.1" 500 Internal Server Error
ERROR:    Exception in ASGI application
Traceback (most recent call last):
    response = engine.process(response)
NameError: name 'engine' is not defined
@suc16 Getting an error. The browser can't open http://127.0.0.1:800, so how do I call it? (Same startup log and traceback as above, ending in NameError: name 'engine' is not defined.)
@mymusise
Yours uses model.generate; his uses response, history = model.chat.
curl -X POST "http://127.0.0.1:8000" -H 'Content-Type: application/json' -d '{"prompt": "你好", "history": []}'
When deployed, it did not actually use the LoRA-finetuned model; it still calls the official base model.
@Ling-yunchi Yours uses model.generate; his uses response, history = model.chat.
curl -X POST "http://127.0.0.1:8000/" -H 'Content-Type: application/json' -d '{"prompt": "你好", "history": []}'
When deployed, it did not actually use the LoRA-finetuned model; it still calls the official base model.
Also, I get the same startup log and NameError: name 'engine' is not defined traceback as above.
Just delete the line response = engine.process(response).
@Ling-yunchi
Is this how it should be called? Calling it like this does not use my own LoRA-finetuned model:
curl -X POST "http://127.0.0.1:8000" -H 'Content-Type: application/json' -d '{"prompt": "你好", "history": []}'
Is that right, @Ling-yunchi? My inference code is as follows:
from transformers import AutoModel
import torch
from transformers import AutoTokenizer
from peft import PeftModel

model = AutoModel.from_pretrained("../chatglm_models", trust_remote_code=True, load_in_8bit=True, device_map='auto')
tokenizer = AutoTokenizer.from_pretrained("../chatglm_models", trust_remote_code=True)
model = PeftModel.from_pretrained(model, "./output/")

import json
instructions = json.load(open("data/alpaca_data.json", encoding="utf-8"), strict=False)

answers = []
from cover_alpaca2jsonl import format_example

with torch.no_grad():
    for idx, item in enumerate(instructions[50:60]):
        feature = format_example(item)
        input_text = feature['context']
        ids = tokenizer.encode(input_text)
        input_ids = torch.LongTensor([ids])
        out = model.generate(
            input_ids=input_ids,
            max_length=768,
            do_sample=False,
            temperature=0
        )
        out_text = tokenizer.decode(out[0])
        answer = out_text.replace(input_text, "").replace("\nEND", "").strip()
        item['infer_answer'] = answer
        print(out_text)
        print(f"### {idx+1}.Answer:\n", item.get('output'), '\n\n')
        answers.append({'index': idx, **item})
Compared with this, when I deploy the API it does not load the LoRA model; it always loads the official base model.
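One way to check whether the LoRA weights are actually applied in the API script is to load the adapter the same way the inference code above does (PeftModel.from_pretrained on the saved adapter directory) and then look for lora_ parameters. A sketch, assuming the adapter was saved with save_pretrained so the directory contains adapter_config.json, and reusing the paths from this thread:

import torch
from transformers import AutoModel
from peft import PeftModel

# Paths taken from the scripts in this thread; adjust to your own setup.
base = AutoModel.from_pretrained("THUDM/chatglm-6b", trust_remote_code=True).half().cuda()
model = PeftModel.from_pretrained(base, "output/you/train/model")  # dir with adapter_config.json + adapter_model.bin
model.eval()

# A LoRA-wrapped model should expose lora_A / lora_B tensors; zero hits means the
# adapter was not applied and the server would answer exactly like the base model.
lora_params = [name for name, _ in model.named_parameters() if "lora_" in name]
print(f"{len(lora_params)} LoRA parameter tensors found")

If only adapter_model.bin was saved (no adapter_config.json), then the get_peft_model plus load_state_dict path in api.py is the part to debug: load_state_dict returns a named tuple whose unexpected_keys field shows whether any of the checkpoint keys failed to match the wrapped model.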