
[Question] Is vLLMRollout.generate_sequences the right place to implement tool calling?

Open irdbl opened this issue 1 year ago • 14 comments

Hi, I am trying to understand the code. I would like to try RL training on tool calling in an interactive environment.

As I understand it, the reward is calculated by some custom reward function for a particular dataset. In other words, the flow of data during PPO is like this:

graph TD
   DatasetExample --> InferenceRollout --> RewardFunction --> UpdateGradients
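That flow can be sketched as a minimal pipeline; every function name below is hypothetical, chosen only to mirror the diagram, not taken from verl:

```python
# Toy sketch of the PPO data flow in the diagram above; all names are
# hypothetical stand-ins, not verl APIs.

def inference_rollout(example: str) -> str:
    # One-shot generation: prompt in, completion out.
    return example + " -> completion"

def reward_function(completion: str) -> float:
    # Custom, dataset-specific reward.
    return 1.0 if "completion" in completion else 0.0

def update_gradients(reward: float) -> float:
    # Stand-in for the PPO update step.
    return reward

def ppo_step(example: str) -> float:
    completion = inference_rollout(example)
    reward = reward_function(completion)
    return update_gradients(reward)
```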

But the rollout's inference step here is a one-shot input/output function. If online tool calling were desired, we'd have to hook the llm.generate function here, right?

https://github.com/volcengine/verl/blob/main/verl/workers/rollout/vllm_rollout/vllm_rollout.py#L181

Then we could inject function calling. But I'm confused because the inference engine is not an ordinary vLLM LLM class, but a subclass which monkey-patches the output to return tensors instead of the normal vLLM output format.

So what would be the best way to add in dynamic function calling? Hook the generate method of vLLM's LLM class, then call LLM._post_process_output to convert token_id and logprobs from VLLM into torch tensors at the very end?

Or is there a more obvious place to add this feature?

irdbl avatar Jan 31 '25 05:01 irdbl

Actually-- I thought about it a bit more. Perhaps the best way is to implement a custom LogitsProcessor for vLLM, which does function calling by hijacking the logits at each step to detect function calls and force-inject the function's output tokens. Then it would interface perfectly with this library, or any other using vLLM for inference, and make the resulting model production-ready.

irdbl avatar Jan 31 '25 06:01 irdbl

Hi @accupham, thanks for your questions!

So what would be the best way to add in dynamic function calling? Hook the generate method of vLLM's LLM class, then call LLM._post_process_output to convert token_id and logprobs from VLLM into torch tensors at the very end?

Yes, hooking into the generate method is a good way to add dynamic function calling. _post_process_output doesn't have to be a class method of LLM; it can be moved into vllm_rollout and called after all the results are ready to convert them into tensors.
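To make that concrete, here is a rough sketch of what such a post-processing step could look like once moved out of the LLM subclass. Plain Python lists stand in for vLLM's RequestOutput objects, and `pad_and_stack` is a hypothetical helper name, not a verl or vLLM function:

```python
# Hypothetical post-processing step: gather per-request token ids from
# finished generations and pad them into a rectangular batch, which is the
# shape a downstream tensor conversion (e.g. torch.tensor(batch)) expects.

def pad_and_stack(token_id_lists, pad_token_id=0):
    max_len = max(len(ids) for ids in token_id_lists)
    return [ids + [pad_token_id] * (max_len - len(ids)) for ids in token_id_lists]

# Toy stand-ins for vLLM outputs: each request yields a list of token ids.
outputs = [[5, 6, 7], [8, 9], [10]]
batch = pad_and_stack(outputs)
```

The point is that this conversion only needs the finished results, so it can live in vllm_rollout rather than inside a monkey-patched LLM subclass.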

PeterSH6 avatar Jan 31 '25 13:01 PeterSH6

Actually-- I thought about it a bit more. Perhaps the best way is to implement a custom LogitsProcessor for vLLM, which does function calling by hijacking the logits at each step to detect function calls and force-inject the function's output tokens. Then it would interface perfectly with this library, or any other using vLLM for inference, and make the resulting model production-ready.

I agree. Using a custom LogitsProcessor can help detect the function calls. This can already be implemented in current vLLM by assigning your custom LogitsProcessor functions to the SamplingParams. I believe you can implement this by modifying the code of vllm_rollout.py. But it may be better if the customized function could be passed through the config file so that users won't need to modify the vllm_rollout.py file. Are you interested in contributing this feature?

Moreover, I'm not sure a customized LogitsProcessor is general enough to cover "all" function calling scenarios. Some function calls may need to detokenize the token_ids first, and a customized LogitsProcessor may not achieve optimal throughput.
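For reference, a vLLM logits processor is just a callable taking the generated token ids and the next-step logits. A minimal standalone sketch (a plain list stands in for the logits tensor so this runs without vLLM or torch; in real code the callable is attached via SamplingParams(logits_processors=[...])):

```python
# Minimal logits-processor shape: (input_ids, scores) -> scores.
# A plain list stands in for the torch tensor vLLM actually passes.

def ban_token_processor(banned_id):
    def processor(input_ids, scores):
        scores = list(scores)
        scores[banned_id] = float("-inf")  # make the banned token unsamplable
        return scores
    return processor

proc = ban_token_processor(banned_id=2)
new_scores = proc(input_ids=[1, 5, 7], scores=[0.1, 0.2, 0.9, 0.3])
# In vLLM this would be attached with:
#   SamplingParams(logits_processors=[proc])
```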

PeterSH6 avatar Jan 31 '25 13:01 PeterSH6

Yes I would be interested in contributing. Traditional function calling is usually done with the vLLM LLM.chat() calling semantics.

But we could leave this up to the user by letting them implement a pluggable function, which produces the final output tensors to pass onto the rest of the pipeline.

So we could take this from the vllm_rollout.py:

        with self.update_sampling_params(**kwargs):
            output = self.inference_engine.generate(
                prompts=None,  # because we have already convert it to prompt token id
                sampling_params=self.sampling_params,
                prompt_token_ids=idx_list,
                use_tqdm=False)

And instead have this RolloutSampler as the pluggable module, which can be passed in when initializing vLLMRollout class:

from typing import Protocol
from vllm.outputs import RequestOutput

class RolloutSampler(Protocol):
    def __call__(self, llm, prompts, sampling_params) -> list[RequestOutput]:
        ...

# Default RolloutSampler implementation
class OneShotRolloutSampler:
    def __call__(self, llm, prompts, sampling_params) -> list[RequestOutput]:
        return llm.generate(
            prompts=prompts,  # pass in prompts instead of token_ids to make it user-friendly
            sampling_params=sampling_params,
            use_tqdm=False)

# Tool-calling RolloutSampler implementation
class FnCallRolloutSampler:
    @property
    def tools(self) -> list[dict]:
        ...

    def __call__(self, llm, prompts, sampling_params) -> list[RequestOutput]:
        r1 = llm.chat(
            messages=[{"role": "user", "content": "blah blah blah please do tool calling"}],
            sampling_params=sampling_params,
            use_tqdm=False,
            tools=self.tools,
        )

        ...
        # execute tool calls from r1 and return results
        r2 = llm.chat(...)
        # get final RequestOutput response from vllm
        r3 = llm.chat(...)
        # (etc etc)
        return r3  # list[RequestOutput] from vllm

So we modify the init function signature of vLLMRollout as follows:

class vLLMRollout(BaseRollout):
    def __init__(self, actor_module: nn.Module, config: DictConfig, tokenizer, model_hf_config,
                 rollout_sampler: RolloutSampler = OneShotRolloutSampler(), **kwargs):
        ...
        self.rollout_sampler = rollout_sampler
        ...

Now users can pass in their own sampler implementation according to their needs by writing familiar vLLM API code, or use the default one-shot sampler to keep the old behavior.
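The pluggability can be exercised without vLLM at all; here a tiny fake llm object shows how the rollout would only ever call self.rollout_sampler(llm, prompts, sampling_params) and stay agnostic to what happens inside (all class names below are illustrative, not verl code):

```python
from typing import Protocol

class RolloutSampler(Protocol):
    def __call__(self, llm, prompts, sampling_params): ...

class OneShotSampler:
    # Default behavior: single generate call.
    def __call__(self, llm, prompts, sampling_params):
        return llm.generate(prompts)

class TwoTurnSampler:
    # Stand-in for a tool-calling sampler: generate, "call tools", generate again.
    def __call__(self, llm, prompts, sampling_params):
        first = llm.generate(prompts)
        follow_ups = [out + " +tool_result" for out in first]
        return llm.generate(follow_ups)

class FakeLLM:
    def generate(self, prompts):
        return [p + " ->out" for p in prompts]

def rollout(sampler: RolloutSampler, prompts):
    # The rollout code never changes; only the injected sampler does.
    return sampler(FakeLLM(), prompts, sampling_params=None)
```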


I think for function calling we will always need to detokenize token_ids first. Pass prompts, not token_ids, into vLLM. It's not user-friendly to deal with tokens, because parsing function calls is string-oriented and we can't expect there to be a dedicated token for function calling. Why don't we just let vLLM handle the tokenization instead?

I thought about it some more, and a custom LogitsProcessor may bottleneck the entire batch if there is excess latency for a single function call. Better to just do function calling the traditional way instead of reinventing the wheel.

What do you think of this proposal?

irdbl avatar Jan 31 '25 18:01 irdbl

@accupham The API design is really nice from my perspective. However, it seems that it relies on vLLM 0.7.0 for the chat API. We're working on integrating it in: #116 . Will let you know once merged!

As the key challenge for generation in RL training tasks is throughput, not latency, I have a few questions/concerns about this proposal:

  1. Does the chat API support batch processing in both prefill/decode and function calling? If the function call can also be batched and parallelized, the throughput would be acceptable.

  2. With such a design, how do we overlap prefill/decode computation with function calling, given that the function call may be remote and quite time-consuming?

I think for function calling we will always need to detokenize token_ids first. Why don't we just let vLLM handle the tokenization instead?

  1. From our experience, tokenize/detokenize is quite time-consuming in vLLM and veRL already tokenized the prompts into token_ids in the dataloader. So, currently, if the users want to use the original strings, they can simply call tokenizer.batch_decode and this operation is more efficient. What's your experience when using tokenizer/detokenize in vLLM?

  2. I would like to raise another point that this proposal doesn't mention, which is also related to point 2. It is not directly relevant to the API design in your proposal but is very important in veRL. Currently, with Orca scheduler in vLLM, some prompts can finish early and quit. But LLMEngine will not return until all the prompts are finished. I believe if we can fetch these early exit sequences, it's possible to support the overlapping of prefill/decode computation with function calls. However, these features can only be implemented with vLLM support or by hijacking the scheduler code.

PeterSH6 avatar Feb 01 '25 04:02 PeterSH6

After doing a bit of digging, perhaps this API design enabling multi-turn tool calling interaction is not feasible from a performance perspective. Here's why:

Does the chat API support batch processing in both prefill/decode and function calling? If the function call can also be batched and parallelized, the throughput would be acceptable.

  • The vLLM chat API does support batch processing. All this method does is apply a chat template, then call the generate API internally, which, as expected, will do prefill and other vLLM optimizations if enabled.
  • However, the function call part has issues. The tool call execution instructions will arrive in batches, but we would have to execute them serially, wait for all calls to resolve, then pass the results back in as part of that batch group. Even then, some items may need multiple back-and-forth calls to complete while others don't, leaving the batch partially empty toward the end.
    • In other words, we have head-of-line blocking with traditional multi-turn tool calling, causing throughput issues.
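A back-of-the-envelope model of that head-of-line effect: if tool calls are resolved per turn and the whole batch waits on the slowest item each turn, per-turn cost is the max latency, and items that finish early leave the batch partially idle (numbers below are made up for illustration):

```python
# Toy head-of-line model: each row is one sequence's per-turn tool latencies
# (seconds); a sequence with fewer turns finishes early but the batch keeps going.
latencies = [
    [1.0, 1.0, 1.0],  # 3 tool-calling turns
    [5.0],            # 1 slow turn, then idle
    [1.0, 1.0],       # 2 turns
]

n_turns = max(len(row) for row in latencies)

# Batched-serial execution: every turn costs the max over still-active items.
batched_time = sum(
    max((row[t] for row in latencies if t < len(row)), default=0.0)
    for t in range(n_turns)
)

# Fully overlapped execution: each item only pays for its own turns.
ideal_time = max(sum(row) for row in latencies)
```

Here batched execution pays 5 + 1 + 1 = 7 seconds against an ideal 5, and the gap grows with batch size and turn-count variance.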

From our experience, tokenize/detokenize is quite time-consuming in vLLM and veRL already tokenized the prompts into token_ids in the dataloader. So, currently, if the users want to use the original strings, they can simply call tokenizer.batch_decode and this operation is more efficient. What's your experience when using tokenizer/detokenize in vLLM?

  • I've never benchmarked tokenization so I can't comment on that, but I'll have to take your word for it. If it really is time-consuming, then using the chat API will be terrible for performance. We go from not having to tokenize/detokenize at all in the original implementation to doing tokenize/detokenize roughly 4 * batch_size * n_turns more times: the chat API needs to tokenize repeatedly to apply the chat template, then decode/encode to parse each function call. Not optimal.

Currently, with Orca scheduler in vLLM, some prompts can finish early and quit. But LLMEngine will not return until all the prompts are finished. I believe if we can fetch these early exit sequences, it's possible to support the overlapping of prefill/decode computation with function calls. However, these features can only be implemented with vLLM support or by hijacking the scheduler code.

  • That sounds too complicated to implement in a user-friendly way -- multi-turn tool calling might be too hard to do if throughput is desired.

Alternative Proposal

  • Allow the user to pass in a custom LogitsProcessor into vLLM.

Pros

  • This maintains all the batching optimizations afforded by the subcomponents of the system. Throughput is maintained.
  • Each LogitsProcessor is cloned on vLLM's side, so if function calling is slow, it only affects that generation instance, and not others, making scheduling easy.
  • A LogitsProcessor can easily be used in production by passing it into vLLM.
  • The style of inline function calling is very fluid and lends itself to the reasoning style of R1-like models.

Cons

  • Very unconventional approach (I've never seen this done before)
  • Token boundary issues may occur unless you train with special tokens
  • Implementation is not as user-friendly because it deals with raw tokens and modifying logits manually.
    • See proof of concept... It's hacky AF.

PoC Tool calling Logits Processor

import torch
from typing import Dict, List, Callable
from transformers import PreTrainedTokenizer

class FunctionProcessor:
    def __init__(
        self,
        tokenizer: PreTrainedTokenizer,
        function_map: Dict[str, Callable],
        start_tag: str = " <function>",
        end_tag: str = " </function>",
        result_start: str = "<results>",
        result_end: str = "</results>"
    ):
        self.tokenizer = tokenizer
        self.function_map = function_map
        self.buffer = []
        self.in_function = False
        self.current_function = []
        
        # Pre-tokenize markers 
        self.start_marker = tokenizer.encode(start_tag, add_special_tokens=False)
        self.end_marker = tokenizer.encode(end_tag, add_special_tokens=False)
        self.result_start = tokenizer.encode(result_start, add_special_tokens=False)
        self.result_end = tokenizer.encode(result_end, add_special_tokens=False)
        self.max_marker_len = max(
                len(self.start_marker), 
                len(self.end_marker),
                len(self.result_start),
                len(self.result_end)
            )
        
        self.result_tokens = []
        
    def evaluate_expression(self, expr: str) -> str:
        # Strip the function markers
        expr = expr.replace("<function>", "").replace("</function>", "").strip()
        
        # Parse function call
        func_name = expr.split("(")[0]
        args_str = expr.split("(")[1].rstrip(")")
        
        # Get the function from our map
        if func_name not in self.function_map:
            return f"Error: Unknown function {func_name}"
            
        func = self.function_map[func_name]
        
        try:
            # Parse args - this could be made more sophisticated
            args = [float(arg.strip()) for arg in args_str.split(",")]
            result = func(*args)
            return str(result)
        except Exception as e:
            return f"Error: {str(e)}"
        
    def __call__(self, input_ids: List[int], scores: torch.Tensor) -> torch.Tensor:
        try:
            self.buffer.extend(input_ids[-1:])

            if self.result_tokens:
                #scores.fill_(-float('inf'))
                scores[self.result_tokens.pop()] = 100
                return scores
            
            self.buffer = self.buffer[-self.max_marker_len*2:]
            #print(self.tokenizer.decode(self.buffer))
            if not self.in_function and self.check_marker(self.start_marker):
                self.in_function = True
                self.current_function = []
                return scores
                
            if self.in_function:
                self.current_function.extend(input_ids[-1:])
                
                if self.check_marker(self.end_marker):
                    self.in_function = False
                    func_text = self.tokenizer.decode(self.current_function)
                    result = self.evaluate_expression(func_text)
                    
                    self.result_tokens = list(reversed(
                        self.result_start +
                        self.tokenizer.encode(result) +
                        self.result_end
                    ))
                    scores[self.result_tokens.pop()] = 100
                    

                        
            return scores
            
        except Exception as e:
            print(f"Error in processor: {e}")
            return scores
            
    def check_marker(self, marker: List[int]) -> bool:
        #print(marker, self.buffer)
        marker_len = len(marker)
        buffer_len = len(self.buffer)
        
        if buffer_len < marker_len:
            return False
            
        # Only need to check the last possible positions where marker could fit
        start_pos = max(0, buffer_len - marker_len * 2)
        
        for i in range(start_pos, buffer_len - marker_len + 1):
            if self.buffer[i:i + marker_len] == marker:
                return True
                
        return False

PoC Usage

def add(x, y):
    return x + y

def multiply(x, y):
    return x * y

# Create function map
function_map = {
    "add": add,
    "multiply": multiply
}

my_tool_processor = FunctionProcessor(tokenizer, function_map)

prompts = [
    "Hello world. please say <function> multiply(3, 302) </function> \nthis is a test",
    "Hello world. please say <function> add(3, 302) </function> ",
]
r = llm.generate( # llm is a vllm LLM instance
    prompts,
    SamplingParams(
        logits_processors=[my_tool_processor],
        max_tokens=200,
    ))

#print(r)
for rr in r:
    print(rr.outputs[0].text)
    print("----")
<function> multiply(3, 302) </function><result>906.0</result>. Can I assist with anything else?
---
<function> add(3, 302) </function><result>305.0</result>. Is there anything else I can help with?

irdbl avatar Feb 01 '25 06:02 irdbl

@accupham Sorry for the late response. Too busy recently, I will investigate your proposal this weekend.

PeterSH6 avatar Feb 13 '25 06:02 PeterSH6

I don’t think my proposal is feasible… it’s fast, but too cumbersome to work with.

Check out the work being done in TRL though

https://github.com/huggingface/trl/pull/2810

irdbl avatar Feb 13 '25 14:02 irdbl

if there's interest, would love to try and get this working with the verifiers repo i'm building, mostly focused on TRL so far (https://github.com/huggingface/trl/pull/2810) but hopefully the Environments can be interoperable across libraries

in TRL, the semantics are basically:

if env is None:
    outputs = llm.generate(prompts, sampling_params)
    completion_ids = [o.outputs[0].token_ids for o in outputs]
else:
    completion_ids = env.generate(llm, prompts, sampling_params)

i.e. we're just wrapping the generate step to allow custom rollout logic (seems similar to the discussion here)

with regard to the performance discussions above, it seems to me like the repeated tokenization from llm.chat() shouldn't be a major bottleneck relative to actual generation, no? this is done anyways when using any LLM API for tool calling

i'm implementing batched vLLM requests, dug into AsyncLLMEngine for allowing async tool calls but seems to add a high amount of complexity when inside of a trainer

willccbb avatar Feb 15 '25 05:02 willccbb

@accupham can you elaborate on "too cumbersome to work with"? thanks.

oleole avatar Feb 18 '25 17:02 oleole

@accupham can you elaborate on "too cumbersome to work with"? thanks.

It’s difficult to work with raw tokens to implement tool calling via logit processor. Token boundaries are very weird and rarely align with how you’d like to do tool calling. So while it’s possible, perhaps by building a trie from a tokenizer vocabulary, it is very difficult to build tool calling on top of this idea.

So I like @willccbb’s idea much better, as we can use the familiar semantics of tool calling to build RL pipelines on top of.

Perhaps throughput requirements should be relaxed here in favor of something at least working?

irdbl avatar Feb 18 '25 18:02 irdbl

(Quoting the full "Alternative Proposal" comment and PoC logits-processor code from above.)

I've tried the method mentioned above with verl 0.2 and vLLM 0.6.3. However, it randomly hangs after running for one or two steps: the GPU utilization gets stuck at 100% while power consumption drops significantly, and the logs stop updating. Below is the code snippet I used:

from typing import List

import torch

class FunctionProcessor:
    def __init__(
        self,
        tokenizer,
        start_tag: str = "<tool_call>",
        end_tag: str = "</tool_call>",
        result_start: str = "\n<tool_result>\n",
        result_end: str = "\n</tool_result>\n<think>"
    ):
        self.tokenizer = tokenizer
        self.buffer = []
        self.in_function = False
        self.current_function = []
        
        # Pre-tokenize markers 
        self.start_marker = tokenizer.encode(start_tag, add_special_tokens=False)[0]
        self.end_marker = tokenizer.encode(end_tag, add_special_tokens=False)[0]
        self.result_start = tokenizer.encode(result_start, add_special_tokens=False)
        self.result_end = tokenizer.encode(result_end, add_special_tokens=False)

        self.result_tokens = []
        self.state_dict = {}
    
    
    def evaluate_expression(self, expr: str) -> str:
        try:
            # get_tool_resp is the function that will be called to evaluate the expression, time cost no more than 3 seconds.
            result = get_tool_resp(expr)

            return str(result)
        except Exception as e:
            return f"Error: {str(e)}"
        

    
    def __call__(self, input_ids: List[int], scores: torch.Tensor) -> torch.Tensor:
        try:
            if input_ids[-1] == self.end_marker:
                idx = 1
                while idx <= len(input_ids):

                    if input_ids[-idx] == self.start_marker:

                        if input_ids[-idx:].count(self.start_marker) > 1 or input_ids[-idx:].count(self.end_marker) > 1:
                            break
                        
                        current_function = input_ids[-idx:]
                        func_text = self.tokenizer.decode(current_function)
                        try:
                            result = self.evaluate_expression(func_text)
                        except:
                            result = "{'result': 'Tool Call Error'}"
                        result_tokens = list(reversed(
                            self.result_start +
                            self.tokenizer.encode(str(result)) +
                            self.result_end
                        ))
                        state_dict_key = tuple(input_ids)
                        
                        self.state_dict[state_dict_key] = result_tokens
                        token_id = self.state_dict[state_dict_key].pop()
                        scores[token_id] = 100
                        break
                        
                    idx += 1
            else:
                for idx in range(1, len(self.end_marker)):
                    if input_ids[-idx] == self.start_marker:
                        state_dict_key = tuple(input_ids[:-idx + 1])
                        result_tokens = self.state_dict.get(state_dict_key, [])
                        if result_tokens:
                            self.state_dict[state_dict_key] = result_tokens
                            token_id = self.state_dict[state_dict_key].pop()
                            scores[token_id] = 100

                        break
            
        except Exception as e:
            print(f"Error in FunctionProcessor: {e}")
        return scores

I made adjustments to the code because, during batched decoding, the shared processor object led to state being overwritten across different sequences.
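That per-sequence keying idea -- storing pending result tokens under the sequence's token-id prefix instead of on a shared buffer -- can be illustrated without vLLM (a hypothetical minimal version, not the code above):

```python
# One logits-processor instance is shared across all sequences in a batch,
# so state stored directly on `self` gets clobbered. Keying queued result
# tokens by the sequence's token-id prefix keeps sequences independent.

class PrefixKeyedState:
    def __init__(self):
        self.pending = {}  # tuple(prefix token ids) -> queued result tokens

    def start(self, input_ids, result_tokens):
        self.pending[tuple(input_ids)] = list(result_tokens)

    def next_token(self, input_ids):
        queue = self.pending.get(tuple(input_ids))
        return queue.pop(0) if queue else None

state = PrefixKeyedState()
state.start([1, 2], [100, 101])   # sequence A queues two forced tokens
state.start([9, 8], [200])        # sequence B does not clobber A's queue
```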


AIBionics avatar Feb 21 '25 13:02 AIBionics

@PeterSH6 are you interested in a PR that implements this:

...the customized func to be passed through config file so that users won't need to modify the vllm_rollout.py file. Are you interested in contributing to this feature?

My thought is to add a string value to the config file allowing the user to specify a class to be instantiated at runtime with importlib. This hook will allow the user to pass in arbitrary values of logits_processors.
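The importlib mechanism itself is standard: a config string like "my_pkg.MyProcessor" can be resolved to a class at runtime. A runnable sketch, using a stdlib class as a stand-in for the user's logits-processor class:

```python
import importlib

def load_class(dotted_path: str):
    # Resolve "package.module.ClassName" to the class object at runtime,
    # so a config file can name any user-defined logits processor.
    module_path, _, class_name = dotted_path.rpartition(".")
    return getattr(importlib.import_module(module_path), class_name)

# Stdlib class used as a stand-in for a user-supplied class:
cls = load_class("collections.OrderedDict")
instance = cls(a=1)
```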

frrad avatar Feb 27 '25 19:02 frrad

@frrad If you do end up implementing this, you should make it compatible with how VLLM passes in custom logits processors:

pip install logits-processor-zoo
vllm serve ... --logits-processor-pattern '.*'

And call API with extra_body:

{"logits_processors": [{"qualname": "logits_processor_zoo.vllm.cite_prompt"}]}

VLLM docs

You'll want to instantiate the tokenizer only once and cache it for performance.

irdbl avatar Feb 27 '25 19:02 irdbl

Comments from @youkaichao: the best way to do tool calling with vLLM is to initialize it with the async LLM engine, instead of using the current llm.generate API. The llm.generate API is for batch inference, while the async LLM engine works like an API server, naturally supporting async requests and multi-turn generation. It also supports SPMD.
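The gain from an async engine can be seen with plain asyncio: per-request loops of generate, tool call, generate overlap instead of the whole batch waiting at each turn. This is a pure-asyncio simulation with sleeps standing in for decode and tool latency; it does not use vLLM's actual async engine API:

```python
import asyncio
import time

async def fake_generate(text):
    await asyncio.sleep(0.05)          # stand-in for one decode phase
    return f"{text}:gen"

async def fake_tool_call(text):
    await asyncio.sleep(0.05)          # stand-in for a remote tool call
    return f"{text}:tool"

async def handle_request(req):
    # Each request runs its own multi-turn loop independently.
    first = await fake_generate(req)
    tooled = await fake_tool_call(first)
    return await fake_generate(tooled)

async def main(n_requests=3):
    start = time.monotonic()
    results = await asyncio.gather(*(handle_request(f"r{i}") for i in range(n_requests)))
    return results, time.monotonic() - start

results, elapsed = asyncio.run(main())
# Fully serial: n_requests * 3 steps * 0.05 s; overlapped: about 3 * 0.05 s.
```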

eric-haibin-lin avatar Apr 11 '25 03:04 eric-haibin-lin

Here is a reference implementation from the community: https://github.com/cfpark00/verl/blob/a3d761be3510974b9ad605475a31329cebf324e3/verl/workers/rollout/vllm_rollout/vllm_rollout.py

eric-haibin-lin avatar Apr 11 '25 16:04 eric-haibin-lin

Here is a reference implementation from the community: https://github.com/cfpark00/verl/blob/a3d761be3510974b9ad605475a31329cebf324e3/verl/workers/rollout/vllm_rollout/vllm_rollout.py

Yeah I think currently there are some nice implementations about using AsyncLLMEngine. The core problem is that I have no idea how to update the weights of AsyncLLMEngine.

SparkJiao avatar Apr 13 '25 00:04 SparkJiao

To my understanding, verl uses llm.generate right now. And because verl uses the SPMD mode, llm.generate is replicated TP times (tensor-parallel size times).

If we want to have multi-turn / tool calling inside the same process, and we write it this way:

answer = llm.generate(prompts)
new_prompts = call_tool(answer, prompt) # interact with the environment
new_answer = llm.generate(new_prompts)

Then call_tool will also be replicated TP times. What's worse, if the environment is not deterministic, then every TP rank gets different new_prompts, and this violates the SPMD mode, and the output from new_answer = llm.generate(new_prompts) will be wrong, or hanging.

To make it work, we will need to manually make it SPMD-compatible:

answer = llm.generate(prompts)
if tp_rank == 0:
    new_prompts = call_tool(answer, prompt) # interact with the environment, only for tp rank 0
    broadcast_across_tp(new_prompts, src=0)
else:
    new_prompts = broadcast_across_tp(src=0)
new_answer = llm.generate(new_prompts)

This might work functionally, but will not be performant. All prompts need to finish generation before anyone can call the tool.

Ideally, we would want to have async between model generation and tool calling, but since verl forces vLLM to live in the same process as the training worker, and the tool caller in this case, I don't see any chance how we can have the async execution.

youkaichao avatar Apr 13 '25 07:04 youkaichao

@youkaichao thanks for the comments. while the original issue is discussing the design based on vllm v0.6.3 + verl, we actually have a branch integrated vllm async server the newer version of vllm. we can upstream this implementation such that tool calling can be done with async generation, as long as the input to TP ranks are consistent

eric-haibin-lin avatar Apr 14 '25 16:04 eric-haibin-lin

Here is a reference implementation from the community: https://github.com/cfpark00/verl/blob/a3d761be3510974b9ad605475a31329cebf324e3/verl/workers/rollout/vllm_rollout/vllm_rollout.py

Yeah I think currently there are some nice implementations about using AsyncLLMEngine. The core problem is that I have no idea how to update the weights of AsyncLLMEngine.

@SparkJiao do you mean this fork does not handle weight update correctly, or you just do not know how it is handled in actual implementation?

eric-haibin-lin avatar Apr 14 '25 16:04 eric-haibin-lin

To my understanding, verl uses llm.generate right now. And because verl uses the SPMD mode, llm.generate is replicated TP times (tensor-parallel size times).

If we want to have multi-turn / tool calling inside the same process, and we write it this way:

answer = llm.generate(prompts)
new_prompts = call_tool(answer, prompt) # interact with the environment
new_answer = llm.generate(new_prompts)

Then call_tool will also be replicated TP times. What's worse, if the environment is not deterministic, then every TP rank gets different new_prompts, and this violates the SPMD mode, and the output from new_answer = llm.generate(new_prompts) will be wrong, or hanging.

To make it work, we will need to manually make it SPMD-compatible:

answer = llm.generate(prompts)
if tp_rank == 0:
    new_prompts = call_tool(answer, prompt) # interact with the environment, only for tp rank 0
    broadcast_across_tp(new_prompts, src=0)
else:
    new_prompts = broadcast_across_tp(src=0)
new_answer = llm.generate(new_prompts)

This might work functionally, but will not be performant. All prompts need to finish generation before anyone can call the tool.

Ideally, we would want to have async between model generation and tool calling, but since verl forces vLLM to live in the same process as the training worker, and the tool caller in this case, I don't see any chance how we can have the async execution.

This works for me! Specifically, I use vllm_ps to get the local TP rank and broadcast across TP:

tp_rank = vllm_ps.get_tensor_model_parallel_rank()
if tp_rank == 0:
    tool_call_results = tool_call(prompt)  # do tool call only for tp rank 0
    broadcast_data = {
        'tool_call_results': tool_call_results,
    }
else:
    broadcast_data = None  # non-src ranks receive the broadcast value
broadcast_data = vllm_ps._TP.broadcast_object(broadcast_data, src=0)  # broadcast tool call results across tp

AnselCmy avatar Apr 14 '25 18:04 AnselCmy

Here is a reference implementation from the community: https://github.com/cfpark00/verl/blob/a3d761be3510974b9ad605475a31329cebf324e3/verl/workers/rollout/vllm_rollout/vllm_rollout.py

Yeah I think currently there are some nice implementations about using AsyncLLMEngine. The core problem is that I have no idea how to update the weights of AsyncLLMEngine.

@SparkJiao do you mean this fork does not handle weight update correctly, or you just do not know how it is handled in actual implementation?

@eric-haibin-lin No, it handles weight updates correctly, but it does not use AsyncLLM. We can use vllm.LLM as a temporary solution for now (like search-r1 does). The problem is that all requests wait for each other at the end of each round, which causes bursts of high concurrency against the environment and increases latency.

I think the ultimate solution is using AsyncLLM, like what Kaichao said, each request will do its own multi-round rollout till the termination state. Currently the problem is (1) as Kaichao said, we need some specific logic to ensure the consistency across TP ranks for each request; (2) AsyncLLM with online weight update is no longer supported since vllm == 0.8.2.
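The "each request does its own multi-round rollout" shape can be sketched with plain asyncio; here `generate` and `call_tool` are stand-ins for AsyncLLM generation and the environment, not real verl or vLLM APIs:

```python
import asyncio


async def rollout(request_id: int, n_rounds: int, generate, call_tool):
    # Per-request loop: this request advances through its own
    # generate -> tool rounds until termination, independently of others.
    prompt = f"req-{request_id}"
    for _ in range(n_rounds):
        answer = await generate(prompt)
        prompt = await call_tool(answer)
    return prompt


async def run_all(n_requests: int, n_rounds: int, generate, call_tool):
    # Launch all rollouts concurrently; a fast request never waits at a
    # round boundary for slow ones, unlike a batched llm.generate loop.
    return await asyncio.gather(
        *(rollout(i, n_rounds, generate, call_tool) for i in range(n_requests))
    )
```

This also spreads tool calls over time instead of firing them all at once at the end of each round.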

I think the repo below provides a nice implementation based on AsyncLLM:

https://github.com/agentica-project/verl-pipeline/blob/master/verl/workers/rollout/vllm_rollout/vllm_rollout_spmd.py#L404-L565

The problem is that I have tried the few lines for online AsyncLLM weight updates with both vllm 0.8.2 and vllm 0.8.3, but both failed: https://github.com/agentica-project/verl-pipeline/blob/master/verl/workers/sharding_manager/fsdp_vllm.py#L99-L102

SparkJiao avatar Apr 15 '25 02:04 SparkJiao

The problem is that I have tried the few lines for online AsyncLLM weight updates with both vllm 0.8.2 and vllm 0.8.3, but both failed: agentica-project/verl-pipeline@master/verl/workers/sharding_manager/fsdp_vllm.py#L99-L102

You need to follow

https://github.com/vllm-project/vllm/blob/fdcb850f1424eca5f914578187ef31642c6e422d/vllm/v1/engine/async_llm.py#L430

to see how async_llm.wake_up calls the underlying worker's wake_up function, and add async_llm.collective_rpc functionality so that you can call async_llm.collective_rpc("update_weight"), with your worker extension class defined when the async_llm is created. Then, follow https://github.com/vllm-project/vllm/blob/fdcb850f1424eca5f914578187ef31642c6e422d/examples/offline_inference/rlhf_colocate.py to pass CUDA IPC handles instead of passing tensors to synchronize weights.
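A self-contained mock of that dispatch pattern: the engine broadcasts a method name to every worker, and a user-supplied extension class mixed into the worker provides the method. The names (`WorkerExtension`, `update_weight`) mirror the rlhf_colocate example above, but this mock is illustrative, not vLLM API:

```python
class WorkerExtension:
    # user-supplied mixin; real code would copy a tensor (received via a
    # CUDA IPC handle) into the model parameter with this name
    def update_weight(self, name, weight):
        self.weights[name] = weight
        return name


class Worker(WorkerExtension):
    def __init__(self):
        self.weights = {}


def collective_rpc(workers, method, *args, **kwargs):
    # engine-side dispatch: invoke `method` on every worker, as
    # async_llm.collective_rpc would do across TP ranks
    return [getattr(w, method)(*args, **kwargs) for w in workers]
```

Calling `collective_rpc(workers, "update_weight", ...)` then applies the same update on every rank.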

youkaichao avatar Apr 16 '25 02:04 youkaichao

Thank you @youkaichao ! I will take a look on this and think about how to implement it.

SparkJiao avatar Apr 17 '25 05:04 SparkJiao

The agent loop is the recommended place: https://verl.readthedocs.io/en/latest/start/agentic_rl.html

eric-haibin-lin avatar Jul 24 '25 20:07 eric-haibin-lin

I've tried the method mentioned above with verl 0.2 and vllm 0.6.3. However, it randomly hangs after running for 1 to 2 steps: the GPU utilization gets stuck at 100% while the power consumption drops significantly, and the logs stop updating.

@AIBionics hi, I also hit this when training GRPO with verl: it always hangs in the backward pass of update_policy at the second iteration. All the ray actors are alive and GPU utilization is at 100%, but power draw is low. Did you ever solve this? What was the root cause?

lkygithub avatar Aug 02 '25 09:08 lkygithub