
No batch inference when calling through the OpenAI-compatible API.

xjw00654 opened this issue on Apr 30, 2024 · 0 comments

I start the server with CUDA_VISIBLE_DEVICES=0 python -m sglang.launch_server --model-path LLMs/Qwen-14B-Chat --port 30000 --trust-remote-code --stream-interval 1 --enable-flashinfer --schedule-conservativeness 50 and use the following code to test its concurrent capability.

It only generates ~10 tokens/s, whereas vLLM reaches ~30 tokens/s. It seems this way of calling does not support batch inference. The logs (screenshot omitted) always show 1 running_req.

The question is: do I need to implement batching myself when calling through the API, or is something wrong with my setup? BTW, I also tried the batching example from the README, and it works fine and runs faster than I expected!
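For reference, the README batching I mention looks roughly like this. This is a minimal sketch from memory, assuming the sglang frontend API (sgl.function, sgl.gen, run_batch) pointed at my local endpoint; the prompts are placeholders:

import sglang as sgl

# Point the sglang frontend at the already-running server.
sgl.set_default_backend(sgl.RuntimeEndpoint("http://localhost:30000"))

@sgl.function
def answer(s, question):
    s += sgl.user(question)
    s += sgl.assistant(sgl.gen("reply", max_tokens=256))

# run_batch submits all prompts at once, so the scheduler can batch them.
states = answer.run_batch(
    [{"question": q} for q in ["Hello!", "What is 2 + 2?", "Tell me a joke."]],
    progress_bar=True,
)
for s in states:
    print(s["reply"])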

Thank you so much in advance.

SCRIPTS

import json
import random
import time
from threading import Thread

import requests
import sseclient


def run(ds):
    # `ds` is a single record with "winner", "turn", and
    # "conversation_a"/"conversation_b" fields.
    winner = "a" if "_" not in ds["winner"] else ds["winner"].split("_")[1]
    conversation = ds[f"conversation_{winner}"]

    st = time.time()
    answers = []
    for i in range(ds["turn"]):
        current_msg_start_time = time.time()
        query = conversation[i * 2]
        history = conversation[: i * 2]
        messages = history + [query]
        assert query["role"] == "user"
        resp = requests.post(
            "http://localhost:30000/v1/chat/completions",
            data=json.dumps(
                {
                    "messages": messages,
                    "stream": True,
                    "temperature": 0,
                    "stop": [
                        "<|endoftext|>", "<|im_end|>",
                    ],
                    "model": "Qwen-14B-Chat",
                    "max_tokens": 2048,  # int, not a string
                }
            ),
            headers={"accept": "application/json"},
            timeout=600,
            stream=True,
        )

        # Consume the SSE stream; keep the last packet as the final answer.
        client = sseclient.SSEClient(resp)
        data = ""
        for j, event in enumerate(client.events()):  # `j`, so the turn index `i` is not shadowed
            if event.data != "[DONE]":
                if j == 0:
                    first_packet_time = time.time() - current_msg_start_time
                data = event.data
        answers.append(json.loads(data))


if __name__ == "__main__":
    # `ds` here is the full dataset (~33k records), loaded elsewhere.
    for N in [5, ]:
        user_pool = []
        for _ in range(N):
            th = Thread(target=run, args=(ds[random.randint(0, 32999)],))
            user_pool.append(th)
        for t in user_pool:
            t.start()
        for t in user_pool:
            t.join()
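As a sanity check, a stripped-down non-streaming variant of the same test should exercise concurrent requests without any SSE parsing. This is a hypothetical sketch (same endpoint and model as above, ThreadPoolExecutor instead of raw threads), not something I have benchmarked:

from concurrent.futures import ThreadPoolExecutor

import requests


def one_shot(question):
    # Non-streaming request: one POST, one JSON response.
    resp = requests.post(
        "http://localhost:30000/v1/chat/completions",
        json={
            "model": "Qwen-14B-Chat",
            "messages": [{"role": "user", "content": question}],
            "temperature": 0,
            "max_tokens": 256,
        },
        timeout=600,
    )
    return resp.json()["choices"][0]["message"]["content"]


# Fire 5 requests at once; if batching works, all should be in flight together.
with ThreadPoolExecutor(max_workers=5) as pool:
    print(list(pool.map(one_shot, ["Hi"] * 5)))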
