fastllm icon indicating copy to clipboard operation
fastllm copied to clipboard

chatglm-6B 用lora微调后导出的模型进行推理时,前半部分答案正确,后半部分会一直重复。

Open Vvegetables opened this issue 1 year ago • 5 comments

image

Vvegetables avatar Aug 21 '23 07:08 Vvegetables

可以提供更多信息用于复现问题吗? 例如code、模型以及使用的什么精度?

siemonchan avatar Aug 22 '23 03:08 siemonchan

我尝试了一下，qwen模型也有类似问题，但只在使用导出模型的情况下出现。也就是说，将qwen模型以int4精度导出为本地flm文件，再加载进行推理时，大概率会出现不断重复的情况；如果在同一个脚本中先用transformers加载模型，再通过接口直接转换并进行推理，则没有遇到这种情况。代码其实就是qwen2flm和cli_demo这两个脚本。

Zhangtiande avatar Sep 05 '23 00:09 Zhangtiande

如果在同一个脚本中先进行transformers的加载，再通过接口直接转换并进行推理，就没有遇到这种情况。代码其实就是qwen2flm和cli_demo这两个脚本。

谢谢,可以给出一定的代码示例吗?没太理解您的意思

Vvegetables avatar Sep 12 '23 06:09 Vvegetables

后来我用了之后，发现转出来的int4模型好像偶尔也会出现重复的情况：

# coding=utf-8
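# Implements API for Qwen-7B in OpenAI's format (https://platform.openai.com/docs/api-reference/chat),
# serving the model through fastllm after an in-process int4 conversion.
# Usage: python openai_api.py
# Visit http://localhost:8090/docs (the default --server-port) for documents.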

from argparse import ArgumentParser
import time
import torch
import uvicorn
from pydantic import BaseModel, Field
from fastapi import FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from contextlib import asynccontextmanager
from typing import Any, Dict, List, Literal, Optional, Union
from transformers import AutoTokenizer, AutoModel, AutoModelForCausalLM
from transformers.generation import GenerationConfig
from sse_starlette.sse import ServerSentEvent, EventSourceResponse


@asynccontextmanager
async def lifespan(app: FastAPI):  # collects GPU memory
    """FastAPI lifespan hook that frees cached GPU memory at shutdown.

    Nothing happens at startup. Once the application stops serving, cached
    CUDA allocations are released, but only when CUDA is available.

    The cleanup follows a bare ``yield`` (no ``try``/``finally``), so it is
    skipped if the lifespan is exited by an exception.
    """
    yield
    if not torch.cuda.is_available():
        return
    torch.cuda.empty_cache()
    torch.cuda.ipc_collect()


app = FastAPI(lifespan=lifespan)

# Wide-open CORS: any origin, any method and any header, with credentials allowed.
# NOTE(review): with allow_origins=["*"] plus allow_credentials=True, Starlette's
# CORSMiddleware reflects the caller's Origin instead of sending "*", so any site
# can make credentialed calls. Fine for a local demo; restrict allow_origins before
# exposing this server (it binds 0.0.0.0 by default). Confirm for the installed version.
app.add_middleware(
    CORSMiddleware,
    allow_credentials=True,
    allow_origins=["*"],
    allow_headers=["*"],
    allow_methods=["*"],
)


class ModelCard(BaseModel):
    """One entry of the ``/v1/models`` listing (OpenAI-style model object)."""

    id: str
    object: str = Field(default="model")
    # Unix timestamp in whole seconds, taken when the card is instantiated.
    created: int = Field(default_factory=lambda: int(time.time()))
    owned_by: str = Field(default="owner")
    root: Union[str, None] = None
    parent: Union[str, None] = None
    permission: Union[list, None] = None


class ModelList(BaseModel):
    """Response body of ``GET /v1/models``: a ``"list"`` object wrapping model cards."""

    object: str = Field(default="list")
    data: List[ModelCard] = Field(default=[])


class ChatMessage(BaseModel):
    """A complete chat message: the author's role plus the full text."""

    # Only these three roles are accepted; anything else fails validation.
    role: Literal["user", "assistant", "system"]
    content: str


class DeltaMessage(BaseModel):
    """Incremental message fragment carried by a streaming chunk.

    Both fields are optional; a chunk typically sets just one of them
    (the role announcement, or a piece of new text).
    """

    role: Union[Literal["user", "assistant", "system"], None] = None
    content: Union[str, None] = None


class ChatCompletionRequest(BaseModel):
    """Request body of ``POST /v1/chat/completions``.

    ``stop`` may be sent as ``null`` by a client, so consumers must handle
    ``None`` as well as a list.

    NOTE(review): the handler in this file does not forward ``temperature``,
    ``top_p`` or ``max_length`` to the model, so they are accepted but have
    no effect.
    """

    model: str
    messages: List[ChatMessage]
    temperature: Union[float, None] = None
    top_p: Union[float, None] = None
    max_length: Union[int, None] = None
    stream: Union[bool, None] = False
    stop: Union[List[str], None] = []


class ChatCompletionResponseChoice(BaseModel):
    """One choice of a non-streaming ``chat.completion`` response."""

    index: int
    message: ChatMessage
    # Required; "stop" for a natural end, "length" when cut off by a length limit.
    finish_reason: Literal["stop", "length"]


class ChatCompletionResponseStreamChoice(BaseModel):
    """One choice of a streaming ``chat.completion.chunk``."""

    index: int
    delta: DeltaMessage
    # No default: callers pass it explicitly; None while the stream is still
    # running, "stop"/"length" on the final chunk.
    finish_reason: Union[Literal["stop", "length"], None]


class ChatCompletionResponse(BaseModel):
    """Envelope shared by full completions and streaming chunks.

    ``object`` tells the two apart: "chat.completion" is used with
    ``ChatCompletionResponseChoice`` entries and "chat.completion.chunk"
    with ``ChatCompletionResponseStreamChoice`` entries.
    """

    model: str
    object: Literal["chat.completion", "chat.completion.chunk"]
    choices: List[Union[ChatCompletionResponseChoice, ChatCompletionResponseStreamChoice]]
    # Unix timestamp in whole seconds, taken when the response is built.
    created: Union[int, None] = Field(default_factory=lambda: int(time.time()))


@app.get("/v1/models", response_model=ModelList)
async def list_models():
    """List the models exposed by this server (OpenAI ``GET /v1/models``).

    A single fixed id, "gpt-3.5-turbo", is advertised regardless of which
    checkpoint was loaded. Presumably this lets OpenAI clients that expect
    that name work unchanged.
    """
    return ModelList(data=[ModelCard(id="gpt-3.5-turbo")])


@app.post("/v1/chat/completions", response_model=ChatCompletionResponse)
async def create_chat_completion(request: ChatCompletionRequest):
    """Serve an OpenAI-style chat completion using the global ``model``.

    The last message must come from the user and becomes the query. All
    earlier messages must form consecutive (user, assistant) pairs, which
    become ``history``. With ``request.stream`` the reply is streamed as
    server-sent events via ``predict``; otherwise a single
    ``chat.completion`` object is returned.

    Raises:
        HTTPException: 400 when there are no messages, the last message is not
            from the user, the earlier messages are not (user, assistant)
            pairs, or stop words are combined with ``stream=True``.
    """
    global model, tokenizer
    # Debug echo of every user message to stdout.
    print('-'*5 + "Related documents start" + '-'*5)
    for each in request.messages:
        if each.role == "user":
            print(each.content)
            print()
    print('-'*5 + "Related documents end" + '-'*5)
    # Also reject an empty list: indexing [-1] on it would surface as a 500.
    if not request.messages or request.messages[-1].role != "user":
        raise HTTPException(status_code=400, detail="The last message must be from user.")
    query = request.messages[-1].content
    stop_words = _normalize_stop_words(request.stop)
    prev_messages = request.messages[:-1]
    # The system role is not supported yet: a leading system message fails the
    # pairing checks below. Put role-play setup into the user query instead.

    if len(prev_messages) % 2 != 0:
        raise HTTPException(status_code=400, detail="The message number is invalid.")
    history = []
    for user_msg, assistant_msg in zip(prev_messages[::2], prev_messages[1::2]):
        if user_msg.role != "user" or assistant_msg.role != "assistant":
            raise HTTPException(status_code=400, detail="The message order is invalid.")
        history.append([user_msg.content, assistant_msg.content])

    if request.stream:
        # Stop words cannot be stripped from chunks that were already sent, and
        # ``predict`` refuses them. Reject with a clean 400 instead of failing
        # after the event stream has started.
        if stop_words:
            raise HTTPException(status_code=400, detail="Stop words are not supported with stream=True.")
        generate = predict(query, history, request.model, stop_words)
        return EventSourceResponse(generate, media_type="text/event-stream")

    # NOTE(review): request.temperature / top_p / max_length are accepted by the
    # schema but not forwarded to model.chat, so the model's own defaults apply.
    if stop_words:
        stop_words_ids = [tokenizer.encode(stop_) for stop_ in stop_words]
        response, _ = model.chat(tokenizer, query, history=history, stop_words_ids=stop_words_ids)
        response = _strip_stop_suffix(response, stop_words)
    else:
        response, _ = model.chat(tokenizer, query, history=history)

    choice_data = ChatCompletionResponseChoice(
        index=0,
        message=ChatMessage(role="assistant", content=response),
        finish_reason="stop"
    )

    return ChatCompletionResponse(model=request.model, choices=[choice_data], object="chat.completion")


def _normalize_stop_words(stop):
    """Return a fresh list of non-empty stop words plus newline-stripped variants.

    ``stop`` may be ``None`` (the request field is Optional). The input list
    is never mutated. Each stop word starting with "\\n" also contributes the
    same word without that leading newline, presumably so the stop sequence
    matches whether or not the model emits the newline.

    Empty strings are dropped. They encode to no tokens, and an empty suffix
    would "match" the end of every response and erase it.
    """
    words = [word for word in (stop or []) if word]
    stripped = [word[1:] for word in words if word.startswith("\n") and len(word) > 1]
    return words + stripped


def _strip_stop_suffix(response, stop_words):
    """Remove any stop word that ends *response*, checking them in list order.

    Slices by length rather than with ``str.find``, so an earlier occurrence
    of the same text inside the answer is kept. ``stop_words`` must not
    contain empty strings.
    """
    for stop_ in stop_words:
        if response.endswith(stop_):
            response = response[:-len(stop_)]
    return response


async def predict(query: str, history: List[List[str]], model_id: str, stop_words: List[str]):
    """Stream a chat reply as OpenAI-style ``chat.completion.chunk`` JSON strings.

    Yields, in order:
      1. a chunk announcing the assistant role;
      2. one chunk per newly generated text fragment, also echoed to stdout;
      3. an empty chunk with ``finish_reason="stop"``;
      4. the literal ``'[DONE]'`` sentinel.

    Raises:
        ValueError: if ``stop_words`` is non-empty. Stop words would be sent to
            the client instead of being stripped, so they are refused. The error
            is raised on the first iteration, before any chunk is produced.
    """
    global model, tokenizer
    # Explicit check rather than ``assert`` so it still applies under ``python -O``.
    if stop_words:
        raise ValueError("in stream format, stop word is output")

    def _chunk(delta: DeltaMessage, finish_reason=None) -> str:
        # Serialise one streaming chunk. finish_reason is always passed
        # explicitly, so it survives exclude_unset even when it is None.
        choice = ChatCompletionResponseStreamChoice(index=0, delta=delta, finish_reason=finish_reason)
        chunk = ChatCompletionResponse(model=model_id, choices=[choice], object="chat.completion.chunk")
        return chunk.model_dump_json(exclude_unset=True)

    yield _chunk(DeltaMessage(role="assistant"))

    current_length = 0
    for new_response in model.stream_chat(tokenizer, query, history=history):
        # Each item is indexed with [0] to get the full text generated so far;
        # presumably a (response, history) tuple — confirm against fastllm.
        new_response = new_response[0]
        if len(new_response) == current_length:
            continue
        # Emit only the suffix that is new since the previous item.
        new_text = new_response[current_length:]
        print(new_text, end='', flush=True)
        current_length = len(new_response)
        yield _chunk(DeltaMessage(content=new_text))

    print('', flush=True)
    yield _chunk(DeltaMessage(), finish_reason="stop")
    yield '[DONE]'


def _get_args():
    parser = ArgumentParser()
    parser.add_argument("-c", "--checkpoint-path", type=str, default='Qwen/Qwen-7B-Chat',
                        help="Checkpoint name or path, default to %(default)r")
    parser.add_argument("--cpu-only", action="store_true", help="Run demo with CPU only")
    parser.add_argument("--server-port", type=int, default=8090,
                        help="Demo server port.")
    parser.add_argument("--server-name", type=str, default="0.0.0.0",
                        help="Demo server name.")

    args = parser.parse_args()
    return args


if __name__ == "__main__":
    args = _get_args()
    tokenizer = AutoTokenizer.from_pretrained(
        args.checkpoint_path,
        trust_remote_code=True,
        resume_download=True,
    )

    # Load the HF checkpoint; it is converted to fastllm in-process further down.
    # CPU mode uses float32 weights; otherwise Qwen's fp16 flag is passed.
    device_map = "cpu" if args.cpu_only else "auto"
    extra_load_kwargs = {} if args.cpu_only else {"fp16": True}
    model_origin = AutoModelForCausalLM.from_pretrained(
        args.checkpoint_path,
        device_map=device_map,
        trust_remote_code=True,
        resume_download=True,
        **extra_load_kwargs,
    ).eval()
    if args.cpu_only:
        model_origin = model_origin.float()

    # NOTE(review): in HF generate, min_new_tokens=128 suppresses EOS until 128
    # new tokens exist, which can force filler or repetition on short answers.
    # Whether llm.from_hf carries this config over is not visible here. Confirm,
    # given the repetition reports.
    model_origin.generation_config = GenerationConfig.from_pretrained(
        args.checkpoint_path,
        trust_remote_code=True,
        resume_download=True,
        max_new_tokens=1024,
        min_new_tokens=128,
    )

    from fastllm_pytools import llm

    # In-process conversion to a fastllm model with int4 weights; this global
    # `model` is what the HTTP handlers call.
    model = llm.from_hf(model_origin, tokenizer, dtype="int4")

    # The HF copy is no longer needed once converted; drop it and release cached CUDA memory.
    del model_origin
    torch.cuda.empty_cache()

    uvicorn.run(app, host=args.server_name, port=args.server_port, workers=1)

Zhangtiande avatar Sep 12 '23 07:09 Zhangtiande

就算没有LoRA，直接把Qwen-7B导出成int4，问“你是谁”，也是这种效果。

./webui -p tools/qwen-7b-int4.flm --port 1234
image

failable avatar Oct 11 '23 07:10 failable