[Bug?] vLLMRollout.generate_sequences Randomly Hangs After 1-2 Steps When Trying to Implement Tool Calling with Logits Processors

Open AIBionics opened this issue 1 year ago • 8 comments

I've tried using vLLMRollout.generate_sequences to implement tool calling with verl 0.2 and vLLM 0.6.3. However, it randomly hangs after running for 1 to 2 steps. When it hangs, GPU utilization is stuck at 100% while power consumption drops sharply, and the logs stop updating.

Below is the code snippet I used:

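# Inside vLLMRollout.generate_sequences:
my_tool_processor = FunctionProcessor(self.tokenizer)
sampling_params.logits_processors = [my_tool_processor]

# Processor definition:
from typing import List

import torch

class FunctionProcessor: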
    def __init__(
        self,
        tokenizer,
        start_tag: str = "<tool_call>",
        end_tag: str = "</tool_call>",
        result_start: str = "\n<tool_result>\n",
        result_end: str = "\n</tool_result>\n<think>"
    ):
        self.tokenizer = tokenizer
        self.buffer = []
        self.in_function = False
        self.current_function = []
        
        # Pre-tokenize markers 
        self.start_marker = tokenizer.encode(start_tag, add_special_tokens=False)[0]
        self.end_marker = tokenizer.encode(end_tag, add_special_tokens=False)[0]
        self.result_start = tokenizer.encode(result_start, add_special_tokens=False)
        self.result_end = tokenizer.encode(result_end, add_special_tokens=False)

        self.result_tokens = []
        self.state_dict = {}
    
    
    def evaluate_expression(self, expr: str) -> str:
        try:
            # get_tool_resp is the function that will be called to evaluate the expression, time cost no more than 3 seconds.
            result = get_tool_resp(expr)

            return str(result)
        except Exception as e:
            return f"Error: {str(e)}"
        

    
    def __call__(self, input_ids: List[int], scores: torch.Tensor) -> torch.Tensor:
        try:
            if input_ids[-1] == self.end_marker:
                idx = 1
                while idx <= len(input_ids):

                    if input_ids[-idx] == self.start_marker:

                        if input_ids[-idx:].count(self.start_marker) > 1 or input_ids[-idx:].count(self.end_marker) > 1:
                            break
                        
                        current_function = input_ids[-idx:]
                        func_text = self.tokenizer.decode(current_function)
                        try:
                            result = self.evaluate_expression(func_text)
                        except:
                            result = "{'result': 'Tool Call Error'}"
                        result_tokens = list(reversed(
                            self.result_start +
                            self.tokenizer.encode(str(result)) +
                            self.result_end
                        ))
                        state_dict_key = tuple(input_ids)
                        
                        self.state_dict[state_dict_key] = result_tokens
                        token_id = self.state_dict[state_dict_key].pop()
                        scores[token_id] = 100
                        break
                        
                    idx += 1
            else:
                # Continue a pending injection: find the most recent end marker and
                # look up the tokens stored when it was generated. self.end_marker is
                # an int, so iterate over input_ids rather than len(self.end_marker).
                for idx in range(2, len(input_ids) + 1):
                    if input_ids[-idx] == self.end_marker:
                        state_dict_key = tuple(input_ids[:-idx + 1])
                        result_tokens = self.state_dict.get(state_dict_key, [])
                        if result_tokens:
                            token_id = result_tokens.pop()
                            scores[token_id] = 100

                        break
            
        except Exception as e:
            print(f"Error in FunctionProcessor: {e}")
        return scores
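Below is a more defensive variant of the processor above, sketched under stated assumptions. Plain Python lists stand in for vLLM's logits tensor so the sketch runs without torch. `ForcedInjectionProcessor` and `run_tool` are hypothetical names and are not part of verl or vLLM.

The sketch differs from the snippet in two ways:

- It masks every other logit to -inf. The snippet uses `scores[token_id] = 100`, which makes the injected token very likely but not certain.
- It derives the next token from the distance to the latest closing tag and never mutates stored results. The snippet pops tokens from a list in `self.state_dict`.

One processor object is attached to the shared sampling params, so every sequence in the batch passes through the same instance. With the two-argument signature, the processor also sees only generated token ids, not the prompt. The offset-based lookup means two sequences that share an output prefix cannot consume each other's tokens. This is a robustness sketch, not a confirmed fix for the hang.

```python
from typing import Callable, Dict, List, Optional, Tuple

NEG_INF = float("-inf")


class ForcedInjectionProcessor:
    """Hypothetical tool-calling logits processor (illustrative, not a verl/vLLM API).

    `run_tool` maps the token ids of one <tool_call>...</tool_call> span to the
    token ids that should be injected right after the closing tag.
    """

    def __init__(self, start_id: int, end_id: int,
                 run_tool: Callable[[List[int]], List[int]]):
        self.start_id = start_id
        self.end_id = end_id
        self.run_tool = run_tool
        # Output prefix ending in end_id -> tokens to inject after it.
        # Entries are never mutated, so sequences sharing a prefix cannot
        # consume each other's tokens.
        self.results: Dict[Tuple[int, ...], List[int]] = {}

    def _last_call(self, output_ids: List[int]) -> Optional[List[int]]:
        # Span from the most recent start tag through the closing tag.
        if self.start_id not in output_ids:
            return None
        start = len(output_ids) - 1 - output_ids[::-1].index(self.start_id)
        call = output_ids[start:]
        return call if call.count(self.end_id) == 1 else None

    @staticmethod
    def _force(scores: List[float], token_id: int) -> List[float]:
        # Mask every other token so the forced one is chosen under any
        # temperature / top-p / top-k. With a torch.Tensor this would be:
        # scores.fill_(float("-inf")); scores[token_id] = 0.0
        for i in range(len(scores)):
            scores[i] = NEG_INF
        scores[token_id] = 0.0
        return scores

    def __call__(self, output_ids: List[int], scores: List[float]) -> List[float]:
        if not output_ids:
            return scores
        if output_ids[-1] == self.end_id:
            # A tool call just closed: run the tool once per distinct prefix.
            key = tuple(output_ids)
            if key not in self.results:
                call = self._last_call(output_ids)
                self.results[key] = self.run_tool(call) if call else []
            offset = 0
        else:
            # Possibly mid-injection: position relative to the latest closing tag.
            if self.end_id not in output_ids:
                return scores
            end_pos = len(output_ids) - 1 - output_ids[::-1].index(self.end_id)
            key = tuple(output_ids[: end_pos + 1])
            offset = len(output_ids) - (end_pos + 1)
        injected = self.results.get(key, [])
        if offset < len(injected):
            return self._force(scores, injected[offset])
        return scores
```

The sketch assumes the injected ids never contain the closing-tag id. If they could, the offset bookkeeping would restart at that token.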

Env

PyTorch version: 2.4.0+cu121
Is debug build: False
CUDA used to build PyTorch: 12.1
ROCM used to build PyTorch: N/A

OS: Ubuntu 20.04.5 LTS (x86_64)
GCC version: (GCC) 12.2.0
Clang version: 3.8.0 (tags/RELEASE_380/final)
CMake version: Could not collect
Libc version: glibc-2.31

Python version: 3.11.3 (main, Apr  5 2023, 14:15:06) [GCC 9.4.0] (64-bit runtime)
Python platform: Linux-5.10.0-2.0.0.2-x86_64-with-glibc2.31
Is CUDA available: True
CUDA runtime version: 12.4.99
CUDA_MODULE_LOADING set to: LAZY
GPU models and configuration: 
GPU 0: CF-NG-HZZ1-O
GPU 1: CF-NG-HZZ1-O
GPU 2: CF-NG-HZZ1-O
GPU 3: CF-NG-HZZ1-O
GPU 4: CF-NG-HZZ1-O
GPU 5: CF-NG-HZZ1-O
GPU 6: CF-NG-HZZ1-O
GPU 7: CF-NG-HZZ1-O

Nvidia driver version: 535.183.06
cuDNN version: Probably one of the following:
/usr/lib/x86_64-linux-gnu/libcudnn.so.9.0.0
/usr/lib/x86_64-linux-gnu/libcudnn_adv.so.9.0.0
/usr/lib/x86_64-linux-gnu/libcudnn_cnn.so.9.0.0
/usr/lib/x86_64-linux-gnu/libcudnn_engines_precompiled.so.9.0.0
/usr/lib/x86_64-linux-gnu/libcudnn_engines_runtime_compiled.so.9.0.0
/usr/lib/x86_64-linux-gnu/libcudnn_graph.so.9.0.0
/usr/lib/x86_64-linux-gnu/libcudnn_heuristic.so.9.0.0
/usr/lib/x86_64-linux-gnu/libcudnn_ops.so.9.0.0
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: True

CPU:
Architecture:                    x86_64
CPU op-mode(s):                  32-bit, 64-bit
Byte Order:                      Little Endian
Address sizes:                   52 bits physical, 57 bits virtual
CPU(s):                          192
On-line CPU(s) list:             0-191
Thread(s) per core:              2
Core(s) per socket:              48
Socket(s):                       2
NUMA node(s):                    2
Vendor ID:                       GenuineIntel
CPU family:                      6
Model:                           143
Model name:                      Intel(R) Xeon(R) Platinum 8468V
Stepping:                        8
CPU MHz:                         2900.000
CPU max MHz:                     3800.0000
CPU min MHz:                     800.0000
BogoMIPS:                        4800.00
Virtualization:                  VT-x
L1d cache:                       4.5 MiB
L1i cache:                       3 MiB
L2 cache:                        192 MiB
L3 cache:                        195 MiB
NUMA node0 CPU(s):               0-47,96-143
NUMA node1 CPU(s):               48-95,144-191
Vulnerability Itlb multihit:     Not affected
Vulnerability L1tf:              Not affected
Vulnerability Mds:               Not affected
Vulnerability Meltdown:          Not affected
Vulnerability Spec store bypass: Vulnerable
Vulnerability Spectre v1:        Vulnerable: __user pointer sanitization and usercopy barriers only; no swapgs barriers
Vulnerability Spectre v2:        Vulnerable, IBPB: disabled, STIBP: disabled
Vulnerability Srbds:             Not affected
Vulnerability Tsx async abort:   Not affected
Flags:                           fpu vme de pse tsc msr pae mce cx8 apic sep mtrr pge mca cmov pat pse36 clflush dts acpi mmx fxsr sse sse2 ss ht tm pbe syscall nx pdpe1gb rdtscp lm constant_tsc art arch_perfmon pebs bts rep_good nopl xtopology nonstop_tsc cpuid aperfmperf tsc_known_freq pni pclmulqdq dtes64 monitor ds_cpl vmx smx est tm2 ssse3 sdbg fma cx16 xtpr pdcm pcid dca sse4_1 sse4_2 x2apic movbe popcnt tsc_deadline_timer aes xsave avx f16c rdrand lahf_lm abm 3dnowprefetch cpuid_fault epb cat_l3 cat_l2 cdp_l3 invpcid_single cdp_l2 ssbd mba ibrs ibpb stibp ibrs_enhanced tpr_shadow vnmi flexpriority ept vpid ept_ad fsgsbase tsc_adjust bmi1 avx2 smep bmi2 erms invpcid cqm rdt_a avx512f avx512dq rdseed adx smap avx512ifma clflushopt clwb intel_pt avx512cd sha_ni avx512bw avx512vl xsaveopt xsavec xgetbv1 xsaves cqm_llc cqm_occup_llc cqm_mbm_total cqm_mbm_local split_lock_detect avx_vnni avx512_bf16 wbnoinvd dtherm ida arat pln pts hfi avx512vbmi umip pku waitpkg avx512_vbmi2 gfni vaes vpclmulqdq avx512_vnni avx512_bitalg tme avx512_vpopcntdq rdpid bus_lock_detect cldemote movdiri movdir64b enqcmd fsrm md_clear serialize tsxldtrk pconfig arch_lbr amx_bf16 avx512_fp16 amx_tile amx_int8 flush_l1d arch_capabilities

Versions of relevant libraries:
[pip3] numpy==1.26.4
[pip3] nvidia-cublas-cu12==12.1.3.1
[pip3] nvidia-cuda-cupti-cu12==12.1.105
[pip3] nvidia-cuda-nvrtc-cu12==12.1.105
[pip3] nvidia-cuda-runtime-cu12==12.1.105
[pip3] nvidia-cudnn-cu12==9.1.0.70
[pip3] nvidia-cufft-cu12==11.0.2.54
[pip3] nvidia-curand-cu12==10.3.2.106
[pip3] nvidia-cusolver-cu12==11.4.5.107
[pip3] nvidia-cusparse-cu12==12.1.0.106
[pip3] nvidia-ml-py==12.560.30
[pip3] nvidia-nccl-cu12==2.20.5
[pip3] nvidia-nvjitlink-cu12==12.4.127
[pip3] nvidia-nvtx-cu12==12.1.105
[pip3] pyzmq==26.2.1
[pip3] torch==2.4.0
[pip3] torchaudio==2.5.1
[pip3] torchvision==0.19.0
[pip3] transformers==4.47.1
[pip3] triton==3.0.0
[conda] Could not collect
ROCM Version: Could not collect
Neuron SDK Version: N/A
vLLM Version: N/A (dev)
vLLM Build Flags:
CUDA Archs: Not Set; ROCm: Disabled; Neuron: Disabled
GPU Topology:
GPU0    GPU1    GPU2    GPU3    GPU4    GPU5    GPU6    GPU7    NIC0    NIC1    NIC2    NIC3    NIC4    NIC5    NIC6    NIC7    NIC8    NIC9    CPU Affinity    NUMA Affinity   GPU NUMA ID
GPU0     X      NV18    NV18    NV18    NV18    NV18    NV18    NV18    NODE    NODE    NODE    NODE    NODE    PIX     SYS     SYS     SYS     SYS     0-47,96-143     0               N/A
GPU1    NV18     X      NV18    NV18    NV18    NV18    NV18    NV18    NODE    NODE    NODE    NODE    PIX     NODE    SYS     SYS     SYS     SYS     0-47,96-143     0               N/A
GPU2    NV18    NV18     X      NV18    NV18    NV18    NV18    NV18    NODE    NODE    NODE    PIX     NODE    NODE    SYS     SYS     SYS     SYS     0-47,96-143     0               N/A
GPU3    NV18    NV18    NV18     X      NV18    NV18    NV18    NV18    NODE    NODE    PIX     NODE    NODE    NODE    SYS     SYS     SYS     SYS     0-47,96-143     0               N/A
GPU4    NV18    NV18    NV18    NV18     X      NV18    NV18    NV18    SYS     SYS     SYS     SYS     SYS     SYS     NODE    PIX     NODE    NODE    48-95,144-191   1               N/A
GPU5    NV18    NV18    NV18    NV18    NV18     X      NV18    NV18    SYS     SYS     SYS     SYS     SYS     SYS     PIX     NODE    NODE    NODE    48-95,144-191   1               N/A
GPU6    NV18    NV18    NV18    NV18    NV18    NV18     X      NV18    SYS     SYS     SYS     SYS     SYS     SYS     NODE    NODE    NODE    PIX     48-95,144-191   1               N/A
GPU7    NV18    NV18    NV18    NV18    NV18    NV18    NV18     X      SYS     SYS     SYS     SYS     SYS     SYS     NODE    NODE    PIX     NODE    48-95,144-191   1               N/A
NIC0    NODE    NODE    NODE    NODE    SYS     SYS     SYS     SYS      X      PIX     NODE    NODE    NODE    NODE    SYS     SYS     SYS     SYS
NIC1    NODE    NODE    NODE    NODE    SYS     SYS     SYS     SYS     PIX      X      NODE    NODE    NODE    NODE    SYS     SYS     SYS     SYS
NIC2    NODE    NODE    NODE    PIX     SYS     SYS     SYS     SYS     NODE    NODE     X      NODE    NODE    NODE    SYS     SYS     SYS     SYS
NIC3    NODE    NODE    PIX     NODE    SYS     SYS     SYS     SYS     NODE    NODE    NODE     X      NODE    NODE    SYS     SYS     SYS     SYS
NIC4    NODE    PIX     NODE    NODE    SYS     SYS     SYS     SYS     NODE    NODE    NODE    NODE     X      NODE    SYS     SYS     SYS     SYS
NIC5    PIX     NODE    NODE    NODE    SYS     SYS     SYS     SYS     NODE    NODE    NODE    NODE    NODE     X      SYS     SYS     SYS     SYS
NIC6    SYS     SYS     SYS     SYS     NODE    PIX     NODE    NODE    SYS     SYS     SYS     SYS     SYS     SYS      X      NODE    NODE    NODE
NIC7    SYS     SYS     SYS     SYS     PIX     NODE    NODE    NODE    SYS     SYS     SYS     SYS     SYS     SYS     NODE     X      NODE    NODE
NIC8    SYS     SYS     SYS     SYS     NODE    NODE    NODE    PIX     SYS     SYS     SYS     SYS     SYS     SYS     NODE    NODE     X      NODE
NIC9    SYS     SYS     SYS     SYS     NODE    NODE    PIX     NODE    SYS     SYS     SYS     SYS     SYS     SYS     NODE    NODE    NODE     X 

Legend:

  X    = Self
  SYS  = Connection traversing PCIe as well as the SMP interconnect between NUMA nodes (e.g., QPI/UPI)
  NODE = Connection traversing PCIe as well as the interconnect between PCIe Host Bridges within a NUMA node
  PHB  = Connection traversing PCIe as well as a PCIe Host Bridge (typically the CPU)
  PXB  = Connection traversing multiple PCIe bridges (without traversing the PCIe Host Bridge)
  PIX  = Connection traversing at most a single PCIe bridge
  NV#  = Connection traversing a bonded set of # NVLinks

NIC Legend:

  NIC0: mlx5_0
  NIC1: mlx5_1
  NIC2: mlx5_2
  NIC3: mlx5_3
  NIC4: mlx5_4
  NIC5: mlx5_5
  NIC6: mlx5_6
  NIC7: mlx5_7
  NIC8: mlx5_8
  NIC9: mlx5_9

NVIDIA_VISIBLE_DEVICES=GPU-00c17004-1a68-8b2e-bf1f-ce4a849177c9,GPU-50c603ba-d9c5-5c24-dfa1-610ef45f5dfe,GPU-38cdf56a-4962-87b6-f54e-3c591c3b6f94,GPU-e667f094-50f7-6a86-0173-98bc170da5d4,GPU-2fe323cd-e3ae-7915-1299-542a62e79926,GPU-f2c15f28-fdbf-499d-4d81-26545c14c0fd,GPU-89fd4b88-9a9e-7e89-b4fd-820425524007,GPU-b6221a29-7a6f-d2e5-fc50-2806885da539
NCCL_P2P_DISABLE=0
NVIDIA_REQUIRE_CUDA=cuda>=12.0 brand=tesla,driver>=450,driver<451 brand=tesla,driver>=470,driver<471 brand=unknown,driver>=470,driver<471 brand=nvidia,driver>=470,driver<471 brand=nvidiartx,driver>=470,driver<471 brand=geforce,driver>=470,driver<471 brand=geforcertx,driver>=470,driver<471 brand=quadro,driver>=470,driver<471 brand=quadrortx,driver>=470,driver<471 brand=titan,driver>=470,driver<471 brand=titanrtx,driver>=470,driver<471 brand=tesla,driver>=510,driver<511 brand=unknown,driver>=510,driver<511 brand=nvidia,driver>=510,driver<511 brand=nvidiartx,driver>=510,driver<511 brand=geforce,driver>=510,driver<511 brand=geforcertx,driver>=510,driver<511 brand=quadro,driver>=510,driver<511 brand=quadrortx,driver>=510,driver<511 brand=titan,driver>=510,driver<511 brand=titanrtx,driver>=510,driver<511 brand=tesla,driver>=515,driver<516 brand=unknown,driver>=515,driver<516 brand=nvidia,driver>=515,driver<516 brand=nvidiartx,driver>=515,driver<516 brand=geforce,driver>=515,driver<516 brand=geforcertx,driver>=515,driver<516 brand=quadro,driver>=515,driver<516 brand=quadrortx,driver>=515,driver<516 brand=titan,driver>=515,driver<516 brand=titanrtx,driver>=515,driver<516
NCCL_IB_CUDA_SUPPORT=0
NVIDIA_LIB=/usr/local/nvidia/lib64
NCCL_VERSION=2.17.1-1
NCCL_SOCKET_IFNAME=xgbe0
NCCL_DEBUG_SUBSYS=INIT,ENV,GRAPH
NVIDIA_DRIVER_CAPABILITIES=compute,utility
NCCL_DEBUG=INFO
NVIDIA_PRODUCT_NAME=CUDA
NCCL_IB_GID_INDEX=3
CUDA_VERSION=12.0.1
NVIDIA_TOOLS=/home/opt/cuda_tools
NCCL_DEBUG_FILE=/root/workspace/log/nccl.%h.%p.log
NCCL_IB_QPS_PER_CONNECTION=2
NCCL_IB_CONNECT_RETRY_CNT=15
NCCL_ERROR_FILE=/root/workspace/log/err.%h.%p.log
NCCL_IB_TIMEOUT=22
CUDNN_VERSION=8.9.1
LD_LIBRARY_PATH=/root/venv/lib/python3.11/site-packages/cv2/../../lib64:/usr/local/lib:/usr/local/x86_64-pc-linux-gnu/lib:/home/opt/nvidia_lib:/usr/local/cuda/lib64:/usr/lib64:/usr/local/lib:/usr/lib/x86_64-linux-gnu/
NCCL_IB_DISABLE=0
NCCL_IB_ADAPTIVE_ROUTING=1
CUDA_MODULE_LOADING=LAZY

AIBionics • Feb 21 '25 13:02

Hi @AIBionics, could you enable CUDA_LAUNCH_BLOCKING=1 and use py-spy to find which lines of code cause vLLM to hang?

That should help identify the root cause.

PeterSH6 • Feb 21 '25 14:02

Here is the file. Thank you!

AIBionics • Feb 21 '25 14:02

Would it be convenient to add you on WeChat? Or is there a WeChat group for users or something like that?

I think the link can be found in the README of verl.

BearBiscuit05 • Feb 21 '25 15:02

thx

AIBionics • Feb 21 '25 16:02

@AIBionics The file seems to include too much information. Let me clarify a bit more.

Launch the job with CUDA_LAUNCH_BLOCKING=1. Once it hangs, find the PID of the ray::WorkerDict process, then run py-spy dump --pid <PID of the hung worker> to get the stack of the hanging step.
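The steps above can be scripted. The sketch below makes three assumptions:

- `pgrep` and `py-spy` are on PATH. py-spy usually needs root or the SYS_PTRACE capability inside containers.
- The rollout workers carry `ray::WorkerDict` in their process title.
- `find_pids` and `dump_stacks` are hypothetical helper names, not part of Ray or verl.

The `--native` flag makes py-spy include C/C++ frames, which may help when the Python frames alone are not informative. Ray's own `ray stack` command is a built-in alternative.

```python
import shutil
import subprocess
from typing import Dict, List


def find_pids(pattern: str = "ray::WorkerDict") -> List[int]:
    """PIDs whose command line matches `pattern`, via `pgrep -f` (hypothetical helper)."""
    if shutil.which("pgrep") is None:
        return []
    proc = subprocess.run(["pgrep", "-f", pattern], capture_output=True, text=True)
    # pgrep exits with status 1 and prints nothing when there is no match.
    return [int(tok) for tok in proc.stdout.split()]


def dump_stacks(pids: List[int], native: bool = False) -> Dict[int, str]:
    """Run `py-spy dump --pid <pid>` for each PID and collect the output."""
    if pids and shutil.which("py-spy") is None:
        raise RuntimeError("py-spy is not on PATH (pip install py-spy)")
    stacks: Dict[int, str] = {}
    for pid in pids:
        cmd = ["py-spy", "dump", "--pid", str(pid)]
        if native:
            cmd.append("--native")  # also show C/C++ frames
        res = subprocess.run(cmd, capture_output=True, text=True)
        # Keep stderr when stdout is empty (e.g. permission denied).
        stacks[pid] = res.stdout or res.stderr
    return stacks
```

To use it, run `dump_stacks(find_pids(), native=True)` on the node where the job is stuck. It returns a dict mapping each PID to its stack dump.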

PeterSH6 • Feb 22 '25 07:02

OK, got it. I ran this command on all the hung workers, and every stack looks like this:

    main_loop (ray/_private/worker.py:935)
    <module> (ray/_private/workers/default_worker.py:297)

Additionally, when I use the same logits_processors tool-call approach with vLLM's original LLM.generate interface, the hang does not occur. I tested for an hour and saw no hang, at least during that time.
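One hypothesis is worth ruling out, though nothing in this thread confirms it. Plain vLLM runs a single engine, whereas the verl rollout here uses tensor_model_parallel_size=4. If each tensor-parallel rank applies the logits processor itself, a tool that returns different results on different ranks could make the ranks generate different tokens and end up waiting on each other. The mock `get_tool_resp_from_tool_call_string` in the agent rollout demo later in this thread draws from `random`, so it would behave this way.

A cheap test is a mock whose output depends only on its input. `deterministic_mock_tool` is a hypothetical name. It uses `hashlib` rather than the built-in `hash()`, because `hash()` of a `str` is salted per process unless `PYTHONHASHSEED` is fixed. It also drops the demo's 1 to 3 s sleep; add a fixed delay back if latency matters for the repro.

```python
import hashlib
import json
import random


def deterministic_mock_tool(func_args_str: str) -> str:
    """Hypothetical mock tool: identical input gives identical output in every process."""
    # sha256 is stable across processes; built-in hash() on str is salted per process.
    digest = hashlib.sha256(func_args_str.encode("utf-8")).digest()
    # Private RNG seeded from the input, so the global random state is untouched.
    rng = random.Random(int.from_bytes(digest[:8], "big"))
    # Same shape as the demo's mock: a long, variable-length result string.
    tool_result_str = "tool result" * rng.randint(99, 999)
    return json.dumps({"result": tool_result_str})
```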

AIBionics • Feb 22 '25 12:02

Hi, I'm interested in this bug and plan to add support for it in the tests. Could we connect on WeChat to discuss further?

BearBiscuit05 • Feb 24 '25 07:02

Training command

python3 -m verl.trainer.main_ppo \
    algorithm.adv_estimator=grpo \
    data.train_files=$TRAIN \
    data.val_files=$TEST \
    data.train_batch_size=8 \
    data.val_batch_size=8 \
    data.max_prompt_length=1024 \
    data.max_response_length=4096 \
    actor_rollout_ref.model.path=/root/Qwen2.5-7B-Instruct \
    actor_rollout_ref.actor.optim.lr=3e-7 \
    actor_rollout_ref.model.use_remove_padding=True \
    actor_rollout_ref.actor.use_dynamic_bsz=True \
    actor_rollout_ref.actor.use_kl_loss=True \
    actor_rollout_ref.actor.kl_loss_coef=0.001 \
    actor_rollout_ref.actor.kl_loss_type=low_var_kl \
    actor_rollout_ref.model.enable_gradient_checkpointing=True \
    actor_rollout_ref.actor.fsdp_config.param_offload=True \
    actor_rollout_ref.actor.fsdp_config.grad_offload=True \
    actor_rollout_ref.actor.fsdp_config.optimizer_offload=True \
    actor_rollout_ref.rollout.disable_log_stats=False \
    actor_rollout_ref.rollout.log_prob_use_dynamic_bsz=True \
    actor_rollout_ref.rollout.tensor_model_parallel_size=4 \
    actor_rollout_ref.rollout.name=vllm \
    actor_rollout_ref.rollout.gpu_memory_utilization=0.5 \
    actor_rollout_ref.rollout.n=4 \
    actor_rollout_ref.ref.log_prob_use_dynamic_bsz=True \
    actor_rollout_ref.ref.fsdp_config.param_offload=True \
    algorithm.kl_ctrl.kl_coef=0.001 \
    trainer.default_local_dir=$save_path \
    trainer.critic_warmup=0 \
    trainer.logger=['console','mlflow'] \
    trainer.project_name='tool' \
    trainer.experiment_name='qwen_7b-8' \
    trainer.n_gpus_per_node=8 \
    trainer.nnodes=1 \
    trainer.save_freq=5000 \
    trainer.test_freq=2000 \
    trainer.total_epochs=15 $@ 
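For scale, here is a rough sketch of the rollout load this command implies. It rests on three assumptions:

- verl splits the 8 GPUs into data-parallel vLLM replicas of `tensor_model_parallel_size` GPUs each.
- `rollout.n` reaches vLLM as the per-prompt sample count.
- A blocking tool call stalls the whole replica's decode step. In vLLM 0.6.x, logits processors are called one sequence at a time in Python inside that step.

`rollout_load` and `worst_case_step_stall_s` are hypothetical helpers.

```python
def rollout_load(train_batch_size: int, n: int, n_gpus: int, tp: int) -> dict:
    """Rough per-replica rollout load (hypothetical helper)."""
    replicas = n_gpus // tp                     # data-parallel vLLM engines
    prompts_per_replica = train_batch_size // replicas
    return {
        "replicas": replicas,
        "prompts_per_replica": prompts_per_replica,
        "sequences_per_replica": prompts_per_replica * n,
    }


def worst_case_step_stall_s(sequences: int, max_tool_latency_s: float) -> float:
    # If every sequence closes a tool call in the same decode step and each
    # call blocks in turn, that step waits for all of them.
    return sequences * max_tool_latency_s


load = rollout_load(train_batch_size=8, n=4, n_gpus=8, tp=4)
print(load)  # {'replicas': 2, 'prompts_per_replica': 4, 'sequences_per_replica': 16}
print(worst_case_step_stall_s(load["sequences_per_replica"], 3.0))  # 48.0
```

A stall of this size would not by itself explain a permanent hang, but it helps to tell long pauses apart from a true deadlock when reading the logs.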

Agent Call Rollout Demo

replace: verl/workers/rollout/vllm_rollout/agent_rollout.py

# Copyright 2024 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
The vllm_rollout that can be applied in different backend
When working with FSDP:
- Use DTensor weight loader (recommended) or HF weight loader
- Utilize state_dict from the FSDP to synchronize the weights among tp ranks in vLLM
When working with Megatron:
- Use Megatron weight loader
- During training, only the current pp stage holds the parameters
- Before inference, broadcast the parameters of the current pp rank to all other pp ranks (all pp ranks hold all the parameters)
- Bind the parameters to the inference engine
- Do inference in tp. pp is treated as additional dp
- After inference, all the parameters that don't belong to this pp rank are freed.
"""
from typing import List
from contextlib import contextmanager
from omegaconf import DictConfig
import torch
import torch.distributed
from tensordict import TensorDict
from torch import nn

from verl import DataProto
from verl.utils.torch_functional import get_eos_mask, pad_sequence_to_length
from verl.workers.rollout.base import BaseRollout
from verl.third_party.vllm import LLM, vllm_version
from verl.third_party.vllm import parallel_state as vllm_ps
from vllm import SamplingParams

# TODO
# 1. support pp in vllm
# 2. passing tokenizer is not necessary? no encoding/decoding is happening here
# 3. simplify init logics



def get_tool_resp_from_tool_call_string(func_args_str):
    import time
    import random
    input_data = {
        "func_args_str": func_args_str,
        }
    try:
        # response = requests.post(url, json=input_data, timeout=3)
        # response_json = response.json()
        # resp_text = response_json.get('response', '{"result": "Request timed out."}')
        time.sleep(random.random() * 2 + 1)
        tool_result_str = "tool result" * random.randint(99, 999)
        resp_text = '{"result": "__TOOL_RESULT__"}'.replace("__TOOL_RESULT__", tool_result_str)
    except:
        resp_text = '{"result": "Request failed."}'

    return resp_text

class FunctionProcessor:
    def __init__(
        self,
        tokenizer,
        start_tag: str = "<tool_call>",
        end_tag: str = "</tool_call>",
        result_start: str = "\n<|quad_start|>\n",
        result_end: str = "\n<|quad_end|>\n<think>"
    ):
        self.tokenizer = tokenizer
        self.buffer = []
        self.in_function = False
        self.current_function = []
        
        # Pre-tokenize markers 
        self.start_marker = tokenizer.encode(start_tag, add_special_tokens=False)
        self.end_marker = tokenizer.encode(end_tag, add_special_tokens=False)
        self.result_start = tokenizer.encode(result_start, add_special_tokens=False)
        self.result_end = tokenizer.encode(result_end, add_special_tokens=False)
        self.max_marker_len = max(
                len(self.start_marker), 
                len(self.end_marker),
                len(self.result_start),
                len(self.result_end)
            )
        self.result_tokens = []
        self.state_dict = {}
    
    
    def evaluate_expression(self, expr: str) -> str:
        try:
            result = get_tool_resp_from_tool_call_string(expr)
            return str(result)
        except Exception as e:
            print(f"FunctionProcessor.evaluate_expression Error: {e}\n ---- expr ---\n{expr}\n-------")
            return f"Error: {str(e)}"
        

    def __call__(self, input_ids: List[int], scores: torch.Tensor) -> torch.Tensor:
        if len(input_ids) == 0:
            return scores
        try:
            # 151657 / 151658 are the token ids of Qwen2.5's <tool_call> / </tool_call>.
            if input_ids[-1] == 151658:
                idx = 1
                while idx <= len(input_ids):
                    if input_ids[-idx] == 151657:
                        if input_ids[-idx:].count(151657) > 1 \
                                or input_ids[-idx:].count(151658) > 1:
                            break
                        
                        current_function = input_ids[-idx:]
                        func_text = self.tokenizer.decode(current_function)
                        try:
                            result = self.evaluate_expression(func_text)
                        except:
                            result = "{'result': 'Error: Timeout!'}"

                        result_tokens = list(reversed(
                            self.result_start +
                            self.tokenizer.encode(result) +
                            self.result_end
                        ))
                        state_dict_key = tuple(input_ids)
                        self.state_dict[state_dict_key] = result_tokens
                        scores[self.state_dict[state_dict_key].pop()] = 100
                        break
                    idx += 1
            else:
                for idx in range(1, len(input_ids)):
                    if input_ids[-idx] == 151658:
                        state_dict_key = tuple(input_ids[:-idx + 1])
                        result_tokens = self.state_dict.get(state_dict_key, [])
                        if result_tokens:
                            scores[result_tokens.pop()] = 100
                        break
            

        except Exception as e:
            print(f"Error in FunctionProcessor.__call__: {e}")
        return scores
            

# NOTE(sgm): add for verl. We can optimize it by making the dataloader yield List[int] without padding.
def _pre_process_inputs(pad_token_id, prompt_token_ids: torch.Tensor) -> List[int]:
    # remove the left padding in the prompt token_id
    # pad_token_id = self.llm_engine.tokenizer.pad_token_id if self.llm_engine.tokenizer.pad_token_id is not None else self.llm_engine.tokenizer.eos_token_id
    non_pad_index = torch.nonzero(prompt_token_ids != pad_token_id, as_tuple=False)[0][0]
    token_ids = prompt_token_ids[non_pad_index:].tolist()
    return token_ids


class vLLMRollout(BaseRollout):

    def __init__(self, actor_module: nn.Module, config: DictConfig, tokenizer, model_hf_config, **kwargs):
        """A vLLM rollout. It requires the module is supported by the vllm.

        Args:
            module: module here follows huggingface APIs
            config: DictConfig
            tokenizer: the task/model tokenizer
            model_hf_config: the huggingface config to initialize the generating model in vllm
            **kwargs: train_tp, for Megatron Backend to initialize hybrid engine (zero redundancy) process group
        """
        super().__init__()
        self.config = config
        assert not (not config.enforce_eager and config.free_cache_engine), \
            "disable CUDA graph (enforce_eager = True) if free cache engine"

        tensor_parallel_size = self.config.get('tensor_model_parallel_size', 1)
        assert tensor_parallel_size <= torch.distributed.get_world_size(), \
            "tensor parallel size should be less than or equal to the world size"
        max_num_batched_tokens = self.config.get('max_num_batched_tokens', 8192)

        if kwargs.get('train_tp', None) is not None:
            # deployed with megatron
            import os
            os.environ['CUDA_TIMER_STREAM_KAFKA_ENABLE'] = '0'
            os.environ['MEGATRON_IMPORT_TIMERS'] = '0'
            train_tp = kwargs.get('train_tp', None)
            num_tp_per_train_tp = train_tp // tensor_parallel_size
            if vllm_version in ('0.4.2', '0.5.4', '0.6.3'):
                vllm_ps.initialize_parallel_state(tensor_model_parallel_size=tensor_parallel_size,
                                                  num_tp_per_train_tp=num_tp_per_train_tp)

        assert model_hf_config.max_position_embeddings >= config.prompt_length + config.response_length, \
            "model context length should be greater than total sequence length"
        self.inference_engine = LLM(
            actor_module,
            tokenizer=tokenizer,
            model_hf_config=model_hf_config,
            tensor_parallel_size=tensor_parallel_size,
            dtype=config.dtype,
            enforce_eager=config.enforce_eager,
            gpu_memory_utilization=config.gpu_memory_utilization,
            skip_tokenizer_init=False,
            max_model_len=config.prompt_length + config.response_length,
            load_format=config.load_format,
            disable_log_stats=config.disable_log_stats,
            max_num_batched_tokens=max_num_batched_tokens,
            enable_chunked_prefill=config.enable_chunked_prefill,
        )

        # Offload vllm model to reduce peak memory usage
        self.inference_engine.offload_model_weights()

        kwargs = dict(
            n=1,
            logprobs=1,  # can be set to 0 and let actor to recompute
            max_tokens=config.response_length,
        )

        # we may detokenize the result all together later
        if vllm_version in ('0.4.2', '0.5.4', '0.6.3'):
            kwargs['detokenize'] = False

        # supporting adding any sampling params from the config file
        for k in config.keys():
            if hasattr(SamplingParams(), str(k)):
                kwargs[k] = config.get(k)

        print(f"kwargs: {kwargs}")
        self.sampling_params = SamplingParams(**kwargs)

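        self.pad_token_id = tokenizer.pad_token_id
        # generate_sequences builds FunctionProcessor(self.tokenizer), so keep a reference.
        self.tokenizer = tokenizer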

    @contextmanager
    def update_sampling_params(self, **kwargs):
        # update sampling params
        old_sampling_params_args = {}
        if kwargs:
            for key, value in kwargs.items():
                if hasattr(self.sampling_params, key):
                    old_value = getattr(self.sampling_params, key)
                    old_sampling_params_args[key] = old_value
                    setattr(self.sampling_params, key, value)
        try:
            yield
        finally:
            # roll back to previous sampling params, even if generation raised
            for key, value in old_sampling_params_args.items():
                setattr(self.sampling_params, key, value)

    @torch.no_grad()
    def generate_sequences(self, prompts: DataProto, **kwargs) -> DataProto:
        # rebuild vllm cache engine
        if self.config.free_cache_engine:
            self.inference_engine.init_cache_engine()

        idx = prompts.batch['input_ids']  # (bs, prompt_length)
        # left-padded attention_mask
        attention_mask = prompts.batch['attention_mask']
        position_ids = prompts.batch['position_ids']

        # used to construct attention_mask
        eos_token_id = prompts.meta_info['eos_token_id']

        batch_size = idx.size(0)

        idx_list = []
        # parse idx from torch.Tensor to List[List[int]]
        for i in range(batch_size):
            idx_list.append(_pre_process_inputs(self.pad_token_id, idx[i]))

        do_sample = prompts.meta_info.get('do_sample', True)
        if not do_sample:
            kwargs = {
                'best_of': 1,
                'top_p': 1.0,
                'top_k': -1,
                'min_p': 0.0,
                'temperature': 0,
                'n': 1  # if greedy, only 1 response
            }
            
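        # tool call code
        # NOTE: this one FunctionProcessor instance is attached to every request in the
        # batch (and every sample when n > 1), so its mutable attributes (buffer,
        # in_function, result_tokens, ...) are shared by all sequences being decoded.
        my_tool_processor = FunctionProcessor(self.tokenizer)
        kwargs['logits_processors'] = [my_tool_processor]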
        
        # users can customize different sampling_params at different run
        with self.update_sampling_params(**kwargs):
            output = self.inference_engine.generate(
                prompts=None,  # because we have already converted it to prompt token ids
                sampling_params=self.sampling_params,
                prompt_token_ids=idx_list,
                use_tqdm=False)

        # TODO(sgm): disable logprob when recompute_log_prob is enable
        # if n = 1: (bs, response_length) ; if n > 1: (bs * n, response_length)
        response = output[0].to(idx.device)
        log_probs = output[1].to(idx.device)
        

        if response.shape[1] < self.config.response_length:
            response = pad_sequence_to_length(response, self.config.response_length, self.pad_token_id)
            log_probs = pad_sequence_to_length(log_probs, self.config.response_length, self.pad_token_id)

        if self.config.n > 1 and do_sample:
            idx = idx.repeat_interleave(self.config.n, dim=0)
            attention_mask = attention_mask.repeat_interleave(self.config.n, dim=0)
            position_ids = position_ids.repeat_interleave(self.config.n, dim=0)
            batch_size = batch_size * self.config.n
        seq = torch.cat([idx, response], dim=-1)

        response_length = response.size(1)
        delta_position_id = torch.arange(1, response_length + 1, device=position_ids.device)
        delta_position_id = delta_position_id.unsqueeze(0).repeat(batch_size, 1)

        # TODO(sgm): fix position_ids on right_pad
        # prompt: left pad + response: right pad
        # attention_mask: [0,0,0,0,1,1,1,1, | 1,1,1,0,0,0,0,0]
        # position_ids:   [0,0,0,0,0,1,2,3, | 4,5,6,7,8,9,10,11]
        response_position_ids = position_ids[:, -1:] + delta_position_id
        position_ids = torch.cat([position_ids, response_position_ids], dim=-1)
        response_attention_mask = get_eos_mask(response_id=response, eos_token=eos_token_id, dtype=attention_mask.dtype)
        attention_mask = torch.cat((attention_mask, response_attention_mask), dim=-1)

        # all the tp ranks should contain the same data here. data in all ranks are valid
        batch = TensorDict(
            {
                'prompts': idx,
                'responses': response,
                'input_ids': seq,  # here input_ids become the whole sentences
                # 'old_log_probs': log_probs, # we will recompute old log prob with actor
                'attention_mask': attention_mask,
                'position_ids': position_ids
            },
            batch_size=batch_size)

        # free vllm cache engine
        if self.config.free_cache_engine:
            self.inference_engine.free_cache_engine()

        return DataProto(batch=batch)
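
One thing worth ruling out in the tool-call code is shared state: a single `FunctionProcessor` instance, with mutable attributes such as `buffer`, `in_function`, and `result_tokens`, is attached to every request, while vLLM calls a two-argument logits processor with just the tokens generated so far and that step's logits, with no sequence ID. A safer pattern is to recompute every decision from the token history, so one shared instance cannot mix up sequences. The sketch below is hypothetical, not verl or vLLM API: `StatelessToolForcer`, the `run_tool` callback, and the single-token start/end markers are assumptions, `scores` is a plain Python list rather than a `torch.Tensor`, and the injected result is assumed not to contain the end marker.

```python
from typing import Callable, Dict, List, Optional, Sequence, Tuple


class StatelessToolForcer:
    """Hypothetical logits processor that forces a tool result after a tool call.

    Every decision is recomputed from `output_ids`, so one instance can be
    shared by many sequences without their states getting mixed.
    """

    def __init__(self, start_id: int, end_id: int,
                 run_tool: Callable[[Tuple[int, ...]], List[int]]):
        self.start_id = start_id    # assumed single-token "<tool_call>" id
        self.end_id = end_id        # assumed single-token "</tool_call>" id
        self.run_tool = run_tool    # maps call tokens -> result tokens
        # Keyed only by call content, so it is identical for every sequence
        # (unbounded here for brevity).
        self._cache: Dict[Tuple[int, ...], List[int]] = {}

    def forced_token(self, output_ids: Sequence[int]) -> Optional[int]:
        """Return the token that must come next, or None if sampling is free."""
        ids = list(output_ids)
        if self.end_id not in ids:
            return None  # no completed tool call yet
        end = len(ids) - 1 - ids[::-1].index(self.end_id)  # last end marker
        before = ids[:end]
        if self.start_id not in before:
            return None  # end marker without a matching start marker
        start = len(before) - 1 - before[::-1].index(self.start_id)
        call = tuple(ids[start + 1:end])
        if call not in self._cache:
            self._cache[call] = list(self.run_tool(call))
        result = self._cache[call]
        emitted = ids[end + 1:]  # tokens produced since the end marker
        if len(emitted) < len(result) and emitted == result[:len(emitted)]:
            return result[len(emitted)]
        return None  # result fully injected (or the model diverged)

    def __call__(self, output_ids: Sequence[int], scores: List[float]) -> List[float]:
        tok = self.forced_token(output_ids)
        if tok is None:
            return scores
        # Leave only the forced token selectable.
        return [s if i == tok else float("-inf") for i, s in enumerate(scores)]


START, END = 1, 2          # hypothetical marker token IDs
tool_calls = []


def fake_tool(call):
    tool_calls.append(call)
    return [7, 8]          # pretend the tool's answer tokenizes to [7, 8]


forcer = StatelessToolForcer(START, END, fake_tool)
print(forcer.forced_token([5, START, 3, 4, END]))        # 7
print(forcer.forced_token([5, START, 3, 4, END, 7]))     # 8
print(forcer.forced_token([5, START, 3, 4, END, 7, 8]))  # None
print(len(tool_calls))                                   # 1 (result was cached)
```

In vLLM itself `scores` is a `torch.Tensor`, so the masking step would build a tensor instead, e.g. `masked = torch.full_like(scores, float('-inf')); masked[tok] = scores[tok]`. Caching by call content also keeps a slow tool (up to 3 seconds per call in this issue) from being re-run on every decoding step, but it does not by itself make results agree across tensor-parallel ranks if the processor runs on more than one rank and the tool is nondeterministic.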
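
The bookkeeping around `generate` (stripping left padding from the prompts, right-padding the responses to `response_length`, repeating prompts when `n > 1`, and building the response attention mask and position IDs) can be illustrated on plain lists. This is a sketch, not verl's implementation: `strip_left_padding`, `right_pad`, `repeat_interleave`, `eos_mask`, and `response_positions` are hypothetical stand-ins for `_pre_process_inputs`, `pad_sequence_to_length`, `torch.repeat_interleave`, `get_eos_mask`, and the `delta_position_id` arithmetic, and the EOS mask is assumed to keep the first EOS token and zero out everything after it.

```python
from typing import List, Sequence


def strip_left_padding(pad: int, row: Sequence[int]) -> List[int]:
    """Drop leading pad tokens from a left-padded prompt row."""
    for i, tok in enumerate(row):
        if tok != pad:
            return list(row[i:])
    return []  # the row was all padding


def right_pad(rows: List[List[int]], length: int, pad: int) -> List[List[int]]:
    """Right-pad every response row to `length`; longer rows are left as-is."""
    return [row + [pad] * (length - len(row)) for row in rows]


def repeat_interleave(rows: Sequence, n: int) -> list:
    """[a, b] -> [a, a, b, b], so each prompt lines up with its n samples."""
    return [row for row in rows for _ in range(n)]


def eos_mask(response: Sequence[int], eos: int) -> List[int]:
    """1 up to and including the first EOS token, 0 after it."""
    mask, seen = [], False
    for tok in response:
        mask.append(0 if seen else 1)
        seen = seen or tok == eos
    return mask


def response_positions(last_prompt_pos: int, length: int) -> List[int]:
    """Continue position IDs after the prompt: last + 1, ..., last + length."""
    return [last_prompt_pos + d for d in range(1, length + 1)]


PAD, EOS = 0, 2
print(strip_left_padding(PAD, [0, 0, 11, 12]))          # [11, 12]
responses = right_pad([[5, 6, EOS], [7, EOS]], length=5, pad=PAD)
print(responses)                                        # [[5, 6, 2, 0, 0], [7, 2, 0, 0, 0]]
print(repeat_interleave(["p0", "p1"], 2))               # ['p0', 'p0', 'p1', 'p1']
print([eos_mask(r, EOS) for r in responses])            # [[1, 1, 1, 0, 0], [1, 1, 0, 0, 0]]
print(response_positions(3, 8))                         # [4, 5, 6, 7, 8, 9, 10, 11]
```

The last line reproduces the `position_ids` example in the code comment: a prompt whose last position is 3 continues with 4 through 11. Interleaving (rather than tiling the whole batch) keeps prompt *i* aligned with its samples, assuming the engine returns the `n` samples of each prompt consecutively.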

Test prompt demo

<|im_start|>system
You are a helpful assistant.

# Tools

You may call one or more functions to assist with the user query.

You are provided with function signatures within <tools></tools> XML tags:
<tools>
{"type": "function", "function": {"name": "search", "description": "A function that searches for a given query list.", "parameters": {"type": "object", "properties": {"query_list": {"type": "array", "items": {"type": "string"}, "description": "The query list to search for.  return search_result_text"}}, "required": ["query_list"]}}}
</tools>

For each function call, return a json object with function name and arguments within <tool_call></tool_call> XML tags:
<tool_call>
{"name": <function-name>, "arguments": <args-json-object>}
</tool_call>

You can invoke a tool once or multiple times as needed. 

You must provide the parameters for the tool call in JSON format enclosed by <tool_call></tool_call> tags. 

Afterwards, the tool will return the corresponding result wrapped with <|quad_start|><|quad_end|> tags.

The specific format is as follows:
<think>
Your reasoning process 1
</think>
<tool_call>
{"name": <function-name>, "arguments": <args-json-object>}
</tool_call>
<|quad_start|>
tool result
<|quad_end|>
<think>
Your thinking process 2
</think>
<tool_call>
{"name": <function-name>, "arguments": <args-json-object>}
</tool_call>
<|quad_start|>
tool result
<|quad_end|>
<think>
Your thinking process 3
</think>
...
<think>
Your thinking process n
</think>
<answer>
Final answer
</answer>
<|im_start|>user
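What's the weather like today? Recommend a few places that are best for travel. Describe the highlights of each and give a travel guide, then pick the place that suits me best. We are a family of three and need somewhere with pleasant scenery that is also safe. Ideally we could collect seashells or get close to small animals.<|im_end|>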
<|im_start|>assistant
<think>```
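
The system prompt asks the model to emit each call as a JSON object between `<tool_call>` and `</tool_call>` tags. A minimal sketch of extracting those calls from decoded text with a regular expression and `json.loads`; the function name `extract_tool_calls` is hypothetical, and malformed JSON is simply skipped here, which a real pipeline may prefer to report back to the model instead.

```python
import json
import re
from typing import Any, Dict, List

# Non-greedy match so consecutive calls are captured separately.
TOOL_CALL_RE = re.compile(r"<tool_call>\s*(.*?)\s*</tool_call>", re.DOTALL)


def extract_tool_calls(text: str) -> List[Dict[str, Any]]:
    """Return every well-formed {"name": ..., "arguments": ...} object in `text`."""
    calls = []
    for body in TOOL_CALL_RE.findall(text):
        try:
            obj = json.loads(body)
        except json.JSONDecodeError:
            continue  # skip malformed calls in this sketch
        if isinstance(obj, dict) and "name" in obj:
            calls.append(obj)
    return calls


sample = (
    "<think>\nI should search first.\n</think>\n"
    '<tool_call>\n{"name": "search", "arguments": {"query_list": ["beach"]}}\n</tool_call>\n'
    "<tool_call>\nnot json\n</tool_call>"
)
print(extract_tool_calls(sample))
# [{'name': 'search', 'arguments': {'query_list': ['beach']}}]
```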

AIBionics, Feb 24 '25 09:02

Any updates on this? I'm hitting the same issue.

I don't think it is caused by vLLM. It hangs in ray::WorkerDict, not in the rollout.

Thread 3500638 (idle): "MainThread"
    main_loop (ray/_private/worker.py:935)
    <module> (ray/_private/workers/default_worker.py:297)
Thread 3501472 (idle): "Thread-1"
    wait (threading.py:324)
    wait (threading.py:607)
    run (tqdm/_monitor.py:60)
    _bootstrap_inner (threading.py:1016)
    _bootstrap (threading.py:973)
Thread 3501743 (idle): "Thread-2 (_read_thread)"
    _recv_msg (torch/_inductor/compile_worker/subproc_pool.py:57)
    _read_thread (torch/_inductor/compile_worker/subproc_pool.py:123)
    run (threading.py:953)
    _bootstrap_inner (threading.py:1016)
    _bootstrap (threading.py:973)
Thread 3513352 (idle): "Thread-3"
    wait (threading.py:324)
    wait (threading.py:607)
    run (tqdm/_monitor.py:60)
    _bootstrap_inner (threading.py:1016)
    _bootstrap (threading.py:973)
Thread 3513762 (idle)

jxiw, Mar 19 '25 15:03

Are all TP ranks receiving the same inputs after tool calling?
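
One way to answer this is to fingerprint the token IDs each rank sees and compare them across the tensor-parallel group. A sketch under assumptions: `token_fingerprint` and `check_ranks_agree` are hypothetical helpers, `torch.distributed` is assumed to be initialized with the caller passing the tensor-parallel process group (or `None` for the default group), and `torch` is imported lazily so the hashing part runs without it.

```python
import hashlib
from typing import Any, Optional, Sequence


def token_fingerprint(rows: Sequence[Sequence[int]]) -> str:
    """Stable SHA-256 of a batch of token-ID rows (row boundaries included)."""
    h = hashlib.sha256()
    for row in rows:
        h.update(",".join(map(str, row)).encode())
        h.update(b";")  # row separator, so [[1, 2]] != [[1], [2]]
    return h.hexdigest()


def check_ranks_agree(rows: Sequence[Sequence[int]], group: Optional[Any] = None) -> bool:
    """All-gather each rank's fingerprint and report whether they all match."""
    import torch.distributed as dist  # lazy import: only needed for this check

    local = token_fingerprint(rows)
    gathered = [None] * dist.get_world_size(group)
    dist.all_gather_object(gathered, local, group=group)
    return all(fp == local for fp in gathered)
```

Because `all_gather_object` is itself a collective, it should be called at a point every rank is known to reach, for example on `idx_list` just before `generate`; ranks that are already blocked in different collectives will never get there.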

eric-haibin-lin, Apr 14 '25 16:04

Any update?

My py-spy dump looks exactly the same as the one above.

GeoffreyChen777, May 19 '25 21:05