[Question] Is vLLMRollout.generate_sequences the right place to implement tool calling?
Hi, I am trying to understand the code. I would like to try RL training on tool calling in an interactive environment.
As I understand it, the reward is calculated by some custom reward function for a particular dataset. In other words, the flow of data during PPO is like this:
```mermaid
graph TD
    DatasetExample --> InferenceRollout --> RewardFunction --> UpdateGradients
```
But the inference rollout step here is a one-shot input/output function. If online tool calling were desired, we'd have to hook the llm.generate function here, right?
https://github.com/volcengine/verl/blob/main/verl/workers/rollout/vllm_rollout/vllm_rollout.py#L181
Then we could inject function calling. But I'm confused because the inference engine is not an ordinary vLLM LLM class, but a subclass that monkey-patches the output to return tensors instead of the normal vLLM output format.
So what would be the best way to add in dynamic function calling? Hook the generate method of vLLM's LLM class, then call LLM._post_process_output to convert token_id and logprobs from VLLM into torch tensors at the very end?
Or is there a more obvious place to add this feature?
Actually-- I thought about it a bit more. Perhaps the best way is to implement a custom LogitsProcessor for vLLM, which does the function calling by hijacking the logits at each step to detect function calls and force-inject the function's output tokens. Then it would interface perfectly with this library, or any other using vLLM for inference, and make the resulting model production-ready.
Hi @accupham , thanks for your questions!
So what would be the best way to add in dynamic function calling? Hook the generate method of vLLM's LLM class, then call LLM._post_process_output to convert token_id and logprobs from VLLM into torch tensors at the very end?
Yes, hooking the generate method is a good way to add dynamic function calling. The _post_process_output does not need to be a method of LLM; it can be moved into vllm_rollout and, once all the results are ready, convert them into tensors there.
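As a rough illustration of pulling post-processing out of the patched subclass (all names here are hypothetical, not verl's actual API), a padding helper could live in vllm_rollout and run once all turns, including any tool-calling rounds, have finished:

```python
# Hypothetical sketch: post-processing moved out of the monkey-patched LLM
# subclass into the rollout, run after all generations are complete.
def post_process_output(token_id_lists, pad_token_id):
    """Right-pad variable-length responses into a rectangle, the shape
    the PPO pipeline expects before converting to tensors."""
    max_len = max(len(ids) for ids in token_id_lists)
    return [list(ids) + [pad_token_id] * (max_len - len(ids))
            for ids in token_id_lists]

padded = post_process_output([[1, 2, 3], [4, 5]], pad_token_id=0)
# padded -> [[1, 2, 3], [4, 5, 0]]
```

In this arrangement, the inference engine stays a plain vLLM LLM and the rollout owns the tensor conversion, so hooking generate no longer fights the monkey patch.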
Actually-- I thought about it a bit more. Perhaps the best way is to implement a custom LogitsProcessor for vLLM, which does the function calling by hijacking the logits at each step to detect function calls and force-inject the function's output tokens. Then it would interface perfectly with this library, or any other using vLLM for inference, and make the resulting model production-ready.
I agree. A custom LogitsProcessor can help detect the function calls. This can already be implemented in current vLLM by assigning your custom LogitsProcessor functions to the SamplingParams. I believe you can implement this by modifying the code of vllm_rollout.py. But it may be better if the customized function could be passed through the config file so that users won't need to modify the vllm_rollout.py file. Are you interested in contributing this feature?
Moreover, I'm not sure a customized LogitsProcessor is general enough to cover "all" function-calling scenarios. Some function calls may need to detokenize the token_ids first, and a customized LogitsProcessor may not achieve optimal throughput.
Yes I would be interested in contributing. Traditional function calling is usually done with the vLLM LLM.chat() calling semantics.
But we could leave this up to the user by letting them implement a pluggable function, which produces the final output tensors to pass onto the rest of the pipeline.
So we could take this from the vllm_rollout.py:
```python
with self.update_sampling_params(**kwargs):
    output = self.inference_engine.generate(
        prompts=None,  # because we have already converted it to prompt token ids
        sampling_params=self.sampling_params,
        prompt_token_ids=idx_list,
        use_tqdm=False)
```
And instead have this RolloutSampler as the pluggable module, which can be passed in when initializing vLLMRollout class:
```python
from typing import Protocol
from vllm.outputs import RequestOutput


class RolloutSampler(Protocol):
    def __call__(self, llm, prompts, sampling_params) -> list[RequestOutput]:
        ...


# default RolloutSampler implementation
class OneShotRolloutSampler:
    def __call__(self, llm, prompts, sampling_params) -> list[RequestOutput]:
        return llm.generate(
            prompts=prompts,  # pass in prompts instead of token_ids to make it user-friendly
            sampling_params=sampling_params,
            use_tqdm=False)


# tool-calling RolloutSampler implementation
class FnCallRolloutSampler:
    @property
    def tools(self) -> list[dict]:
        ...

    def __call__(self, llm, prompts, sampling_params) -> list[RequestOutput]:
        r1 = llm.chat(
            messages=[{"role": "user", "content": "blah blah blah please do tool calling"}],
            sampling_params=sampling_params,
            use_tqdm=False,
            tools=self.tools,
        )
        # execute tool calls from r1 and pass the results back
        r2 = llm.chat(...)
        # get the final RequestOutput response from vLLM
        r3 = llm.chat(...)
        # (etc.)
        return r3  # list[RequestOutput] from vLLM
```
So we modify the `__init__` signature of vLLMRollout as follows:
```python
class vLLMRollout(BaseRollout):
    def __init__(self, actor_module: nn.Module, config: DictConfig, tokenizer, model_hf_config,
                 rollout_sampler: RolloutSampler = OneShotRolloutSampler(), **kwargs):
        ...
        self.rollout_sampler = rollout_sampler
        ...
```
Now users can pass in their own sampler implementation according to their needs, writing familiar vLLM API code, or use the default one-shot sampler to get the old behavior.
I think for function calling we will always need to detokenize token_ids first, so we should pass prompts, not token_ids, into vLLM. It's not user-friendly to deal with tokens, because parsing function calls is string-oriented and we can't expect a dedicated token for function calling. Why not just let vLLM handle the tokenization?
I thought about it some more: a custom LogitsProcessor may also bottleneck the entire batch if a single function call has excess latency. Better to do function calling the traditional way instead of reinventing the wheel.
What do you think of this proposal?
@accupham The API design is really nice from my perspective. However, it seems that it relies on vLLM 0.7.0 for the chat API. We're working on integrating it in: #116 . Will let you know once merged!
As the key challenge for generation in RL training is throughput, not latency, I have a few questions/concerns about this proposal:

1. Does the `chat` API support batch processing in both prefill/decode and function calling? If the function call can also be batched and parallelized, the throughput would be acceptable.
2. With such a design, how do we overlap the prefill/decode computation with function calling? The function call may be a remote function and could be quite time-consuming.

I think for function calling we will always need to detokenize token_ids first. Why don't we just let vLLM handle the tokenization instead?

From our experience, tokenize/detokenize is quite time-consuming in vLLM, and veRL already tokenizes the prompts into token_ids in the dataloader. So, currently, if users want the original strings, they can simply call `tokenizer.batch_decode`, which is more efficient. What's your experience with tokenize/detokenize in vLLM?

3. I would like to raise another point that this proposal doesn't mention, related to point 2. It is not directly about the API design above but is very important in veRL. Currently, with the Orca scheduler in vLLM, some prompts can finish early and quit, but the LLMEngine will not return until all the prompts are finished. If we could fetch these early-exit sequences, it should be possible to overlap the prefill/decode computation with function calls. However, this can only be implemented with vLLM support or by hijacking the scheduler code.
After doing a bit of digging, perhaps this API design enabling multi-turn tool calling interaction is not feasible from a performance perspective. Here's why:
Does the chat API support batch processing in both prefill/decode and function calling? If the function call can also be batched and parallelized, the throughput would be acceptable.
- The vLLM chat API does support batch processing. All this method does is apply a chat template, then call the generate API internally, which, as expected, will do prefill and other vLLM optimizations if enabled.
- However, the function-call part has issues. The tool-call execution instructions arrive in batches, but we would have to execute them serially, wait for all calls to resolve, then pass the results back in as part of that batch group. Even then, some items may need multiple back-and-forth calls to complete while others don't, leaving the batch partially empty toward the end.
- In other words, traditional multi-turn tool calling has head-of-line blocking, causing throughput issues.
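The serial-execution half of this concern can at least be softened by fanning a round's tool calls out to a thread pool; here is a toy sketch (`run_tool` and its little tool map are hypothetical stand-ins for real tool execution):

```python
# Sketch: execute one round of a batch's tool calls in parallel instead of
# serially, so slow calls overlap. Stand-in tools for illustration only.
from concurrent.futures import ThreadPoolExecutor

def run_tool(call):
    name, args = call
    tools = {"add": lambda x, y: x + y, "multiply": lambda x, y: x * y}
    return str(tools[name](*args))

def run_tool_batch(calls, max_workers=8):
    # map preserves input order, so results line up with their conversations
    with ThreadPoolExecutor(max_workers=max_workers) as pool:
        return list(pool.map(run_tool, calls))

results = run_tool_batch([("add", (3, 302)), ("multiply", (3, 302))])
# results -> ["305", "906"]
```

This does not solve the straggler problem, though: the batch as a whole still waits for its slowest multi-turn conversation.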
From our experience, tokenize/detokenize is quite time-consuming in vLLM and veRL already tokenized the prompts into token_ids in the dataloader. So, currently, if the users want to use the original strings, they can simply call tokenizer.batch_decode and this operation is more efficient. What's your experience when using tokenizer/detokenize in vLLM?
- I've never benchmarked tokenization, so I can't comment on that, but I'll take your word for it. If it really is time-consuming, then using the chat API will be horrible for performance: we go from not having to tokenize/detokenize at all in the original implementation to tokenizing/detokenizing `4 * batch_size * n_turns` more times. The chat API needs to tokenize repeatedly to apply the chat template, then decode/encode to parse function calls. Not optimal.
Currently, with Orca scheduler in vLLM, some prompts can finish early and quit. But LLMEngine will not return until all the prompts are finished. I believe if we can fetch these early exit sequences, it's possible to support the overlapping of prefill/decode computation with function calls. However, these features can only be implemented with vLLM support or by hijacking the scheduler code.
- That sounds too complicated to implement in a user-friendly way-- multi-turn tool calling might be too hard to do if throughput is desired.
Alternative Proposal
- Allow the user to pass in a custom LogitProcessor into vLLM.
Pros
- This maintains all batching optimizations afforded by all subcomponents of the system. Throughput is maintained.
- Each LogitProcessor is cloned on vLLM's side, so if function calling is slow, it only affects that generation instance, and not others, making scheduling easy.
- A LogitProcessor can easily be used in production by passing it into vLLM.
- The style of inline function calling is very fluid and lends itself to the reasoning style of R1-like models.
Cons
- Very untraditional approach (I've never seen this done before)
- Token boundary issues may occur unless you train with special tokens
- Implementation is not as user-friendly because it deals with raw tokens and modifying logits manually.
- See proof of concept... It's hacky AF.
PoC Tool calling Logits Processor
```python
from typing import Callable, Dict, List

import torch
from transformers import PreTrainedTokenizer


class FunctionProcessor:
    def __init__(
        self,
        tokenizer: PreTrainedTokenizer,
        function_map: Dict[str, Callable],
        start_tag: str = " <function>",
        end_tag: str = " </function>",
        result_start: str = "<results>",
        result_end: str = "</results>"
    ):
        self.tokenizer = tokenizer
        self.function_map = function_map
        self.buffer = []
        self.in_function = False
        self.current_function = []
        # Pre-tokenize markers
        self.start_marker = tokenizer.encode(start_tag, add_special_tokens=False)
        self.end_marker = tokenizer.encode(end_tag, add_special_tokens=False)
        self.result_start = tokenizer.encode(result_start, add_special_tokens=False)
        self.result_end = tokenizer.encode(result_end, add_special_tokens=False)
        self.max_marker_len = max(
            len(self.start_marker),
            len(self.end_marker),
            len(self.result_start),
            len(self.result_end)
        )
        self.result_tokens = []

    def evaluate_expression(self, expr: str) -> str:
        # Strip the function markers
        expr = expr.replace("<function>", "").replace("</function>", "").strip()
        # Parse function call
        func_name = expr.split("(")[0]
        args_str = expr.split("(")[1].rstrip(")")
        # Get the function from our map
        if func_name not in self.function_map:
            return f"Error: Unknown function {func_name}"
        func = self.function_map[func_name]
        try:
            # Parse args - this could be made more sophisticated
            args = [float(arg.strip()) for arg in args_str.split(",")]
            result = func(*args)
            return str(result)
        except Exception as e:
            return f"Error: {str(e)}"

    def __call__(self, input_ids: List[int], scores: torch.Tensor) -> torch.Tensor:
        try:
            self.buffer.extend(input_ids[-1:])
            # While result tokens are pending, force-feed them one per step
            if self.result_tokens:
                # scores.fill_(-float('inf'))
                scores[self.result_tokens.pop()] = 100
                return scores
            self.buffer = self.buffer[-self.max_marker_len * 2:]
            if not self.in_function and self.check_marker(self.start_marker):
                self.in_function = True
                self.current_function = []
                return scores
            if self.in_function:
                self.current_function.extend(input_ids[-1:])
                if self.check_marker(self.end_marker):
                    self.in_function = False
                    func_text = self.tokenizer.decode(self.current_function)
                    result = self.evaluate_expression(func_text)
                    self.result_tokens = list(reversed(
                        self.result_start +
                        self.tokenizer.encode(result, add_special_tokens=False) +
                        self.result_end
                    ))
                    scores[self.result_tokens.pop()] = 100
            return scores
        except Exception as e:
            print(f"Error in processor: {e}")
            return scores

    def check_marker(self, marker: List[int]) -> bool:
        marker_len = len(marker)
        buffer_len = len(self.buffer)
        if buffer_len < marker_len:
            return False
        # Only need to check the last positions where the marker could fit
        start_pos = max(0, buffer_len - marker_len * 2)
        for i in range(start_pos, buffer_len - marker_len + 1):
            if self.buffer[i:i + marker_len] == marker:
                return True
        return False
```
PoC Usage
```python
from vllm import SamplingParams

def add(x, y):
    return x + y

def multiply(x, y):
    return x * y

# Create function map
function_map = {
    "add": add,
    "multiply": multiply,
}

my_tool_processor = FunctionProcessor(tokenizer, function_map)

prompts = [
    "Hello world. please say <function> multiply(3, 302) </function> \nthis is a test",
    "Hello world. please say <function> add(3, 302) </function> ",
]

r = llm.generate(  # llm is a vllm LLM instance
    prompts,
    SamplingParams(
        logits_processors=[my_tool_processor],
        max_tokens=200,
    ))

for rr in r:
    print(rr.outputs[0].text)
    print("----")
```
Output:
```
<function> multiply(3, 302) </function><result>906.0</result>. Can I assist with anything else?
----
<function> add(3, 302) </function><result>305.0</result>. Is there anything else I can help with?
```
@accupham Sorry for the late response. Too busy recently, I will investigate your proposal this weekend.
I don’t think my proposal is feasible… it’s fast, but too cumbersome to work with.
Check out the work being done in TRL though
https://github.com/huggingface/trl/pull/2810
if there's interest, would love to try and get this working with the verifiers repo i'm building, mostly focused on TRL so far (https://github.com/huggingface/trl/pull/2810) but hopefully the Environments can be interoperable across libraries
in TRL, the semantics are basically:
```python
if env is None:
    outputs = llm.generate(prompts, sampling_params)
    completion_ids = [o.outputs[0].token_ids for o in outputs]
else:
    completion_ids = env.generate(llm, prompts, sampling_params)
```
i.e. we're just wrapping the generate step to allow custom rollout logic (seems similar to the discussion here)
with regard to the performance discussions above, it seems to me like the repeated tokenization from llm.chat() shouldn't be a major bottleneck relative to actual generation, no? this is done anyway when using any LLM API for tool calling
i'm implementing batched vLLM requests; i dug into AsyncLLMEngine to allow async tool calls, but it seems to add a lot of complexity inside a trainer
@accupham can you elaborate on "too cumbersome to work with"? thanks.
@accupham can you elaborate on "too cumbersome to work with"? thanks.
It’s difficult to work with raw tokens when implementing tool calling via a logits processor. Token boundaries are unpredictable and rarely align with how you’d like to detect tool calls. So while it’s possible, perhaps by building a trie from the tokenizer vocabulary, it is very difficult to build tool calling on top of this idea.
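For what it's worth, the trie idea can be sketched on a toy vocabulary. The point is that every tokenization of a marker string (split across tokens or merged into one) has to be indexed before matching against a rolling token buffer; the token ids below are made up:

```python
# Toy sketch: index multiple tokenizations of a marker in a trie, then
# match against a rolling buffer regardless of token-boundary placement.
def build_trie(token_sequences):
    root = {}
    for seq in token_sequences:
        node = root
        for tok in seq:
            node = node.setdefault(tok, {})
        node["END"] = True  # sentinel marking a complete marker
    return root

def matches_suffix(trie, buffer):
    # Walk the trie from every possible start; True if any path completes.
    for start in range(len(buffer)):
        node = trie
        for tok in buffer[start:]:
            if tok not in node:
                break
            node = node[tok]
            if node.get("END"):
                return True
    return False

# Suppose "<fn>" tokenizes as [10, 11] in one context and as the single
# merged token [42] in another: both spellings must be indexed.
trie = build_trie([[10, 11], [42]])
```

Even this only covers the tokenizations you enumerate; BPE can still merge marker characters with neighboring text, which is exactly the fragility being described.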
So I like @willccbb’s idea much better, as we can use the familiar semantics of tool calling to build RL pipelines on top of.
Perhaps throughput requirements should be relaxed here in favor of something at least working?
I've tried the method mentioned above with verl 0.2 and vllm 0.6.3. However, it randomly hangs after running for 1 to 2 steps: the GPU utilization gets stuck at 100% while the power consumption drops significantly, and the logs stop updating. Below is the code snippet I used:
```python
from typing import List

import torch


class FunctionProcessor:
    def __init__(
        self,
        tokenizer,
        start_tag: str = "<tool_call>",
        end_tag: str = "</tool_call>",
        result_start: str = "\n<tool_result>\n",
        result_end: str = "\n</tool_result>\n<think>"
    ):
        self.tokenizer = tokenizer
        self.buffer = []
        self.in_function = False
        self.current_function = []
        # Pre-tokenize markers (assumes each tag maps to a single token)
        self.start_marker = tokenizer.encode(start_tag, add_special_tokens=False)[0]
        self.end_marker = tokenizer.encode(end_tag, add_special_tokens=False)[0]
        self.result_start = tokenizer.encode(result_start, add_special_tokens=False)
        self.result_end = tokenizer.encode(result_end, add_special_tokens=False)
        self.result_tokens = []
        self.state_dict = {}

    def evaluate_expression(self, expr: str) -> str:
        try:
            # get_tool_resp is the function that evaluates the expression;
            # it costs no more than 3 seconds.
            result = get_tool_resp(expr)
            return str(result)
        except Exception as e:
            return f"Error: {str(e)}"

    def __call__(self, input_ids: List[int], scores: torch.Tensor) -> torch.Tensor:
        try:
            if input_ids[-1] == self.end_marker:
                idx = 1
                while idx <= len(input_ids):
                    if input_ids[-idx] == self.start_marker:
                        if input_ids[-idx:].count(self.start_marker) > 1 or input_ids[-idx:].count(self.end_marker) > 1:
                            break
                        current_function = input_ids[-idx:]
                        func_text = self.tokenizer.decode(current_function)
                        try:
                            result = self.evaluate_expression(func_text)
                        except:
                            result = "{'result': 'Tool Call Error'}"
                        result_tokens = list(reversed(
                            self.result_start +
                            self.tokenizer.encode(str(result)) +
                            self.result_end
                        ))
                        # Key the pending result tokens by the full sequence so
                        # different sequences in the batch don't overwrite each other
                        state_dict_key = tuple(input_ids)
                        self.state_dict[state_dict_key] = result_tokens
                        token_id = self.state_dict[state_dict_key].pop()
                        scores[token_id] = 100
                        break
                    idx += 1
            else:
                for idx in range(1, len(self.end_marker)):
                    if input_ids[-idx] == self.start_marker:
                        state_dict_key = tuple(input_ids[:-idx + 1])
                        result_tokens = self.state_dict.get(state_dict_key, [])
                        if result_tokens:
                            self.state_dict[state_dict_key] = result_tokens
                            token_id = self.state_dict[state_dict_key].pop()
                            scores[token_id] = 100
                        break
        except Exception as e:
            print(f"Error in FunctionProcessor: {e}")
        return scores
```
I made adjustments to the code because, during batched decoding, the shared processor object led to values being overwritten across different sequences.
@PeterSH6 are you interested in a PR that implements this:
...the customized func to be passed through config file so that users won't need to modify the vllm_rollout.py file. Are you interested in contributing to this feature?
My thought is to add a string value to the config file allowing the user to specify a class to be instantiated at runtime with importlib. This hook will allow the user to pass in arbitrary values of logits_processors.
@frrad If you do end up implementing this, you should make it compatible with how VLLM passes in custom logits processors:
```shell
pip install logits-processor-zoo
vllm serve ... --logits-processor-pattern '.*'
```
And call the API with extra_body:
```json
{"logits_processors": ["logits_processor_zoo.vllm.cite_prompt"]}
```
You'll want to instantiate the tokenizer only once and cache it for performance.
Comments from @youkaichao - the best way to do tool calling with vLLM is to initialize vLLM with the async LLM engine (instead of using the current llm.generate API). The llm.generate API is for batch inference, while the async LLM engine is like an API server, which naturally supports async requests and multi-turn generation. It also supports SPMD.
Here is a reference implementation from the community: https://github.com/cfpark00/verl/blob/a3d761be3510974b9ad605475a31329cebf324e3/verl/workers/rollout/vllm_rollout/vllm_rollout.py
Here is a reference implementation from the community: https://github.com/cfpark00/verl/blob/a3d761be3510974b9ad605475a31329cebf324e3/verl/workers/rollout/vllm_rollout/vllm_rollout.py
Yeah I think currently there are some nice implementations about using AsyncLLMEngine. The core problem is that I have no idea how to update the weights of AsyncLLMEngine.
To my understanding, verl uses llm.generate right now. And because verl uses the SPMD mode, llm.generate is replicated TP (tensor-parallel size) times.
If we want to have multi-turn / tool calling inside the same process, and we write it this way:
```python
answer = llm.generate(prompts)
new_prompts = call_tool(answer, prompt)  # interact with the environment
new_answer = llm.generate(new_prompts)
```
Then call_tool will also be replicated TP times. What's worse, if the environment is not deterministic, then every TP rank gets different new_prompts, and this violates the SPMD mode, and the output from new_answer = llm.generate(new_prompts) will be wrong, or hanging.
To make it work, we will need to manually make it SPMD-compatible:
```python
answer = llm.generate(prompts)
if tp_rank == 0:
    new_prompts = call_tool(answer, prompt)  # interact with the environment, only on tp rank 0
    broadcast_across_tp(new_prompts, src=0)
else:
    new_prompts = broadcast_across_tp(src=0)
new_answer = llm.generate(new_prompts)
```
This might work functionally, but it will not be performant: all prompts need to finish generation before anyone can call the tool.
Ideally, we would want to have async between model generation and tool calling, but since verl forces vLLM to live in the same process as the training worker, and the tool caller in this case, I don't see any chance how we can have the async execution.
@youkaichao thanks for the comments. While the original issue discusses the design based on vllm v0.6.3 + verl, we actually have a branch that integrates the vLLM async server with a newer version of vLLM. We can upstream this implementation so that tool calling can be done with async generation, as long as the input to the TP ranks is consistent.
Here is a reference implementation from the community: https://github.com/cfpark00/verl/blob/a3d761be3510974b9ad605475a31329cebf324e3/verl/workers/rollout/vllm_rollout/vllm_rollout.py
Yeah I think currently there are some nice implementations about using AsyncLLMEngine. The core problem is that I have no idea how to update the weights of AsyncLLMEngine.
@SparkJiao do you mean this fork does not handle weight update correctly, or you just do not know how it is handled in actual implementation?
To my understanding, verl uses llm.generate right now. And because verl uses the SPMD mode, llm.generate is replicated TP times (tensor-parallel size times). If we want to have multi-turn / tool calling inside the same process, and we write it this way:
```python
answer = llm.generate(prompts)
new_prompts = call_tool(answer, prompt)  # interact with the environment
new_answer = llm.generate(new_prompts)
```
Then call_tool will also be replicated TP times. What's worse, if the environment is not deterministic, then every TP rank gets different new_prompts, which violates the SPMD mode, and the output from new_answer = llm.generate(new_prompts) will be wrong, or hang. To make it work, we will need to manually make it SPMD-compatible:
```python
answer = llm.generate(prompts)
if tp_rank == 0:
    new_prompts = call_tool(answer, prompt)  # only on tp rank 0
    broadcast_across_tp(new_prompts, src=0)
else:
    new_prompts = broadcast_across_tp(src=0)
new_answer = llm.generate(new_prompts)
```
This might work functionally, but will not be performant. All prompts need to finish generation before anyone can call the tool. Ideally, we would want async between model generation and tool calling, but since verl forces vLLM to live in the same process as the training worker (and the tool caller in this case), I don't see how we can have async execution.
This works for me! Specifically, using vllm_ps to get local tp rank and conduct broadcast across tp.
```python
tp_rank = vllm_ps.get_tensor_model_parallel_rank()
if tp_rank == 0:
    tool_call_results = tool_call(prompt)  # do the tool call only on tp rank 0
    broadcast_data = {
        'tool_call_results': tool_call_results,
    }
broadcast_data = vllm_ps._TP.broadcast_object(broadcast_data, src=0)  # broadcast tool call results across tp
```
Here is a reference implementation from the community: https://github.com/cfpark00/verl/blob/a3d761be3510974b9ad605475a31329cebf324e3/verl/workers/rollout/vllm_rollout/vllm_rollout.py
Yeah I think currently there are some nice implementations about using AsyncLLMEngine. The core problem is that I have no idea how to update the weights of AsyncLLMEngine.
@SparkJiao do you mean this fork does not handle weight update correctly, or you just do not know how it is handled in actual implementation?
@eric-haibin-lin No, it handles weight updates correctly. But it does not use AsyncLLM. We can use vllm.LLM as a temporary solution for now (like what search-r1 did). The problem is that all requests wait for each other at the end of each round, which causes high concurrency spikes on the environment and increases latency.
I think the ultimate solution is using AsyncLLM, like what Kaichao said, each request will do its own multi-round rollout till the termination state. Currently the problem is (1) as Kaichao said, we need some specific logic to ensure the consistency across TP ranks for each request; (2) AsyncLLM with online weight update is no longer supported since vllm == 0.8.2.
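For illustration, the per-request multi-round loop with an async engine might look like the toy sketch below; `generate` and `call_tool` are stubs standing in for AsyncLLM and the environment, and the TP-consistency and weight-update problems above still apply:

```python
# Toy sketch: each request runs its own generate -> tool-call loop until a
# termination condition, and requests overlap via asyncio instead of
# barriering at the end of every round. Stubs, not real vLLM APIs.
import asyncio

async def generate(prompt):       # stand-in for AsyncLLM generation
    await asyncio.sleep(0)        # yield control, as a real engine would
    return prompt + "|gen"

async def call_tool(text):        # stand-in for one environment step
    await asyncio.sleep(0)
    return text + "|tool"

async def rollout(prompt, turns=2):
    text = prompt
    for _ in range(turns):        # multi-round loop owned by this request
        text = await call_tool(await generate(text))
    return text

async def main():
    # Requests interleave freely; none waits for the whole batch per round.
    return await asyncio.gather(*(rollout(p) for p in ["a", "b"]))

results = asyncio.run(main())
# results -> ["a|gen|tool|gen|tool", "b|gen|tool|gen|tool"]
```

The structural point is that the round barrier disappears: a fast conversation finishes and frees capacity while slow ones keep looping.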
I think the repo below provides a nice implementation based on AsyncLLM:
The problem is that I have tried the few lines for online updating AsnycLLM weights for both vllm 0.8.2 and vllm 0.8.3, but both failed: https://github.com/agentica-project/verl-pipeline/blob/master/verl/workers/sharding_manager/fsdp_vllm.py#L99-L102
The problem is that I have tried the few lines for online updating AsnycLLM weights for both vllm 0.8.2 and vllm 0.8.3, but both failed: agentica-project/verl-pipeline@master/verl/workers/sharding_manager/fsdp_vllm.py#L99-L102
You need to follow
https://github.com/vllm-project/vllm/blob/fdcb850f1424eca5f914578187ef31642c6e422d/vllm/v1/engine/async_llm.py#L430
to see how async_llm.wake_up calls the underlying worker's wake_up function, and add async_llm.collective_rpc functionality so that you can call async_llm.collective_rpc("update_weight"), with your worker extension class defined when creating the async_llm. Then, follow https://github.com/vllm-project/vllm/blob/fdcb850f1424eca5f914578187ef31642c6e422d/examples/offline_inference/rlhf_colocate.py to pass CUDA IPC handles instead of passing tensors to synchronize weights.
Thank you @youkaichao ! I will take a look on this and think about how to implement it.
The agent loop is the recommended place: https://verl.readthedocs.io/en/latest/start/agentic_rl.html
I've tried the method mentioned above with verl 0.2 and vllm 0.6.3, However, it randomly hangs after running for 1 to 2 steps. Specifically, the GPU utilization gets stuck at 100%, while the power consumption drops significantly low, and the logs stop updating.
@AIBionics hi, I'm also training GRPO with verl and it always hangs during the backward pass of update_policy in the second iteration. The ray actors are all alive, GPU utilization is at 100%, and power draw is low. Did you end up solving this? What was the root cause?