Tutorial for Batch Decoding and Obtaining Log Probs
Hi, thanks for the great library! I have a use case which I think will benefit a lot from RadixAttention. I need to obtain log probs for around 100K sequences, which can be binned into groups of 100 sharing a similar prefix like 'Wikipedia originated in' and having 100 different suffixes. I do not need to generate anything; I only need the log probs for the input. Is there a tutorial for such a use case?
Yes, RadixAttention can help your case a lot. We do not have this interface/tutorial ready, but I can easily make one for you.
What do you need specifically?
Given an input token sequence [a, b, c, d]:
- the probability of all tokens at each location, which is a tensor of the shape [4, 32000]
- the logprob for each selected token, which is a tensor of the shape [4, 1]
- the logprob for the whole sequence (sum), which is a scalar.
- the logprob for the whole sequence (mean), which is a scalar.
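As a rough PyTorch sketch of how these four quantities relate (the shapes follow the example above; this is illustrative only, not part of sglang's API, and it glosses over the usual one-position offset between causal-LM logits and target tokens):
import torch
import torch.nn.functional as F

token_ids = torch.tensor([3, 17, 42, 8])        # the sequence [a, b, c, d]
logits = torch.randn(4, 32000)                  # model logits, one row per position

all_logprobs = F.log_softmax(logits, dim=-1)                     # option 1: shape [4, 32000]
token_logprobs = all_logprobs.gather(1, token_ids.unsqueeze(1))  # option 2: shape [4, 1]
seq_logprob_sum = token_logprobs.sum()                           # option 3: scalar
seq_logprob_mean = token_logprobs.mean()                         # option 4: scalar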
@merrymercy Thanks for the quick response! I just need the log probs of the selected tokens at each position in the sequence (option 2). Here is what I currently do with vLLM:
from vllm import LLM, SamplingParams

sampling_params = SamplingParams(max_tokens=1, prompt_logprobs=1)
llm = LLM(model=model_path, tokenizer=model_path)
outputs = llm.generate(all_input_texts, sampling_params=sampling_params)
# to get log probs for the i-th sample
outputs[i].prompt_logprobs
Do you also need the logprob for the shared prefix? Unfortunately, we do not store the logprob for the prefix. We only store the KV cache, so it is not very easy to also return the logprob for the shared prefix.
@merrymercy Nope, I don't need it for the shared prefix, only for the non-shared portions.
For example, if the sentences are "Wikipedia originated in India", "Wikipedia originated in U.S.A.", etc., I need it only for "India", "U.S.A.", etc.
Great! This is easier. What do you do with the logprob? Do you compute the normalized logprob for selection purposes?
Actually, the choices argument in sglang is implemented by comparing the logprobs of these choices. (https://github.com/sgl-project/sglang/blob/9076386d904171c7cc88ace681ca3ebbec2c71ea/examples/usage/readme_examples.py#L7)
Yeah, I use the normalized logprobs and store them for later analysis. This example looks very relevant. If I understand correctly, something like this would populate the most likely choice from the options, right?
import sglang as sgl
from sglang import Runtime, set_default_backend

@sgl.function
def tool_use(s, question):
    s += question
    s += sgl.gen("tool", choices=["U.S.A", "India"])

runtime = Runtime(model_path='Model_Saves/teknium--OpenHermes-2.5-Mistral-7B')
set_default_backend(runtime)
state = tool_use.run(question="Wikipedia originated in ")
How can I now access the logprobs as well?
Yes, it will populate the most likely choice from the options based on the normalized logprob (sum of the logprobs divided by the number of tokens).
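As a rough illustration of that normalization (the per-token numbers below are made up, not actual model outputs):
import numpy as np

logprobs_calculator = [-4.42, -0.0002]            # hypothetical: "calculator" -> 2 tokens
logprobs_search_engine = [-12.68, -0.09, -0.03]   # hypothetical: "search engine" -> 3 tokens

# Normalized logprob = sum of token logprobs / number of tokens,
# so longer choices are not penalized just for having more tokens.
norm_calculator = np.sum(logprobs_calculator) / len(logprobs_calculator)
norm_search_engine = np.sum(logprobs_search_engine) / len(logprobs_search_engine)
best_choice = "calculator" if norm_calculator > norm_search_engine else "search engine"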
I am working on some examples and interface updates for you to easily get the logprobs. I will upload them very soon!
Thank you for taking the time! That would be really helpful!
@aflah02 Could you try this with the main branch? Does it meet your needs?
https://github.com/sgl-project/sglang/blob/main/examples/usage/choices_logprob.py
Output:
questions: What is 5 + 5?
choice: calculator
logprobs of choice 1 [-4.4240264892578125, -0.0002205615019192919]
logprobs of choice 2 [-12.680765151977539, -0.08715292066335678]
--------------------------------------------------
questions: What is 5 + 6?
choice: calculator
logprobs of choice 1 [-5.266744136810303, -0.00022354240354616195]
logprobs of choice 2 [-12.893030166625977, -0.09100916236639023]
--------------------------------------------------
questions: Who is Michael Jordan?
choice: search engine
logprobs of choice 1 [-10.858648300170898, -0.002947198925539851]
logprobs of choice 2 [-6.427036762237549, -0.00434991717338562]
--------------------------------------------------
Thanks a lot for sharing this! I need to install from source and then try this, right? I'll do this by tomorrow and let you know!
Yes
Hi, I see that there is a parameter that can be passed by requests here to return logprobs: https://github.com/sgl-project/sglang/blob/d3fc86a43e2287e0446a4b3c9acf1300611f1f85/python/sglang/srt/managers/router/model_rpc.py#L222-L223
Is there a way that we could specify this from the Python end with sgl.gen or SglFunction.run?
@Ja1Zhou It is possible. I can work on an interface for this later. What kind of logprob do you need?
Do you need the logprob of prompts, the logprob of generation, the logprob of selected tokens, or the logprob of top-5 tokens?
Many thanks! Currently I would need logprobs of the top-5 (or top-n, passed as a parameter) tokens for each generated token. The scenario is essentially the same as passing the top_logprobs=n parameter to the OpenAI API. An example would be the top_logprobs field in this discussion.
One related question: is the regex constraint going to affect the top_logprobs returned?
Thanks again for the swift reply. I would also love to look into supporting this logprobs feature!
Great! If you are interested, please go ahead. Our bandwidth is limited so your help would be great.
You can start from https://github.com/sgl-project/sglang/blob/0147f940ddc5642e6f88e404123881d69c2b7f0a/test/srt/test_httpserver_decode.py#L21-L33
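In the meantime, a rough sketch of hitting the HTTP server directly from Python (the field names return_logprob and top_logprobs_num are assumptions based on the router code linked above, and the response layout may differ between versions, so treat this as illustrative only):
import requests

response = requests.post(
    "http://localhost:30000/generate",
    json={
        "text": "To answer this question: What is 5 + 5?, I need to use a",
        "sampling_params": {"max_new_tokens": 8, "temperature": 0},
        # Assumed field names; check model_rpc.py / the server's request schema for your version.
        "return_logprob": True,
        "top_logprobs_num": 5,
    },
)
# The logprobs are expected under meta_info in the response, if supported.
print(response.json()["meta_info"])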
Sorry for the delay @merrymercy Thanks a lot! This works really well. Leaving the issue open though as it seems there's another ongoing discussion, but my original issues have been resolved.
@merrymercy In your example, the exponentiated log probs don't seem to sum to one. I've been running this locally with Mistral 7B:
# launch server
# python -m sglang.launch_server --model-path /user/models/Mistral-7B-Instruct-v0.2-AWQ --port 30000
import numpy as np
import sglang as sgl
from sglang import RuntimeEndpoint, set_default_backend

set_default_backend(RuntimeEndpoint("http://localhost:30000"))
@sgl.function
def tool_use(s, question):
s += "To answer this question: " + question + ", "
s += "I need to use a " + sgl.gen("tool", choices=["calculator", "search engine"])
# Run one case
question = "What is 5 + 5?"
state = tool_use.run(question)
print("questions:", question)
print("choice:", state["tool"])
meta_info = state.get_meta_info("tool")
print("logprobs of choice 1", meta_info["prompt_logprob"][0])
print("logprobs of choice 2", meta_info["prompt_logprob"][1])
print("probs of choice 1", np.exp(meta_info["prompt_logprob"][0]))
print("probs of choice 2", np.exp(meta_info["prompt_logprob"][1]))
print("prob sum", np.exp(meta_info["prompt_logprob"][0][0]) + np.exp(meta_info["prompt_logprob"][0][1]))
print("prob sum", np.exp(meta_info["prompt_logprob"][1][0]) + np.exp(meta_info["prompt_logprob"][1][1]))
print('-' * 50)
# Run a batch
questions = [
"What is 5 + 6?",
"Who is Michael Jordan?",
]
states = tool_use.run_batch([{"question": q} for q in questions])
for question, state in zip(questions, states):
print("questions:", question)
print("choice:", state["tool"])
meta_info = state.get_meta_info("tool")
print("logprobs of choice 1", meta_info["prompt_logprob"][0])
print("logprobs of choice 2", meta_info["prompt_logprob"][1])
print("probs of choice 1", np.exp(meta_info["prompt_logprob"][0]))
print("probs of choice 2", np.exp(meta_info["prompt_logprob"][1]))
print("prob sum", np.exp(meta_info["prompt_logprob"][0][0]) + np.exp(meta_info["prompt_logprob"][0][1]))
print("prob sum", np.exp(meta_info["prompt_logprob"][1][0]) + np.exp(meta_info["prompt_logprob"][1][1]))
print('-' * 50)
With output:
questions: What is 5 + 5?
choice: calculator
logprobs of choice 1 [-8.053388595581055, -0.011829511262476444]
logprobs of choice 2 [-12.069820404052734, -0.0010686860186979175]
probs of choice 1 [3.18022445e-04 9.88240182e-01]
probs of choice 2 [5.72985459e-06 9.98931885e-01]
prob sum 0.9885582047666096
prob sum 0.9989376146774299
--------------------------------------------------
questions: What is 5 + 6?
choice: calculator
logprobs of choice 1 [-6.829063892364502, -0.009234977886080742]
logprobs of choice 2 [-11.620172500610352, -0.0014222837053239346]
probs of choice 1 [0.00108187 0.99080753]
probs of choice 2 [8.98303733e-06 9.98578727e-01]
prob sum 0.9918894039483331
prob sum 0.9985877102981204
--------------------------------------------------
questions: Who is Michael Jordan?
choice: search engine
logprobs of choice 1 [-11.84224796295166, -0.2913426160812378]
logprobs of choice 2 [-9.289620399475098, -0.0018001894932240248]
probs of choice 1 [7.19410971e-06 7.47259611e-01]
probs of choice 2 [9.23781204e-05 9.98201430e-01]
prob sum 0.7472668051043772
prob sum 0.9982938079964204
Am I doing anything wrong? Ideally, I think the exp sum of binary log probs should sum to one.
@mlinegar They are not binary log probs; each is the log prob over the whole vocabulary. The meaning of this log prob is the same as the log prob defined in the OpenAI API.
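If a "binary" probability over just the two candidates is wanted, one can renormalize the returned sequence-level logprobs; a small sketch (the numbers are copied from the first question in the output above, and the variable names are illustrative):
import numpy as np

# Sum the per-token logprobs of each candidate, then softmax over the two candidates only.
logprob_calculator = -8.053388595581055 + -0.011829511262476444
logprob_search_engine = -12.069820404052734 + -0.0010686860186979175

pair = np.array([logprob_calculator, logprob_search_engine])
binary_probs = np.exp(pair - pair.max())
binary_probs /= binary_probs.sum()          # these two now sum to 1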
For any new questions, please open a new issue.
@aflah02 Did you notice any performance improvement vs vllm or other libraries?
@merrymercy Yep, it's a very significant speed-up over vLLM for my use case :) Thanks for this library. My only pain point is that some of the models I'm using are not supported by sglang for now, so I need to use vLLM for them. It would be great to have support for Pythia, OPT, and Falcon models. I also fail to load Mixtral but haven't had the time to open an issue yet; let me do that.
We currently do not have the bandwidth to add these models. If you are interested, you can help us contribute them.
Adding a new model is very easy. We use an architecture very similar to vLLM's. Here are the steps to add a new model:
- Compare these two files (https://github.com/sgl-project/sglang/blob/main/python/sglang/srt/models/llama2.py, https://github.com/vllm-project/vllm/blob/main/vllm/model_executor/models/llama.py). You can learn how to convert a model implementation from vLLM to SGLang. We need to replace PagedAttention with RadixAttention. The other parts are almost the same.
- Convert models like OPT, Falcon, and Pythia from vLLM to SGLang (a rough sketch of the attention swap follows below).
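As a loose sketch of what that swap looks like inside a model's attention module (the RadixAttention import path and constructor arguments here are assumptions based on the llama2.py file linked above and may not match every version exactly):
from torch import nn
from sglang.srt.layers.radix_attention import RadixAttention


class MyModelAttention(nn.Module):
    # Hypothetical attention wrapper showing the shape of the change only.
    def __init__(self, num_heads, head_dim, num_kv_heads, layer_id):
        super().__init__()
        self.scaling = head_dim ** -0.5
        # vLLM equivalent (removed):
        #   self.attn = PagedAttention(num_heads, head_dim, self.scaling, num_kv_heads=num_kv_heads)
        self.attn = RadixAttention(
            num_heads, head_dim, self.scaling,
            num_kv_heads=num_kv_heads, layer_id=layer_id,
        )

    def forward(self, q, k, v, input_metadata):
        # SGLang passes an InputMetadata object instead of vLLM's KV-cache tensors.
        return self.attn(q, k, v, input_metadata)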
Thanks! I'll take a look at this
Hi, how do I get the last logits out? I don't need the logprob for every token, just the last one. I am using LLaVA-1.6 Mistral 7B, and it has a bug where I get an image-encoding tensor dimension error whenever I use the choices arg, so I cannot use it. But if I don't use it, how do I get the logits? I checked the source code and there seems to be a part where I can get the last logit instead of the logprob for every token. How do I achieve that?