Bug when using vllm async rollout
I encountered the following error when trying to use vLLM's async rollout.
env: Python 3.10, PyTorch 2.6, vLLM 0.8.5
(AsyncvLLMServer pid=5222) Process EngineCore_0:
(AsyncvLLMServer pid=5222) Traceback (most recent call last):
(AsyncvLLMServer pid=5222) File "/miniconda3/envs/verl/lib/python3.10/multiprocessing/process.py", line 314, in _bootstrap
(AsyncvLLMServer pid=5222) self.run()
(AsyncvLLMServer pid=5222) File "/miniconda3/envs/verl/lib/python3.10/multiprocessing/process.py", line 108, in run
(AsyncvLLMServer pid=5222) self._target(*self._args, **self._kwargs)
(AsyncvLLMServer pid=5222) File "/miniconda3/envs/verl/lib/python3.10/site-packages/vllm/v1/engine/core.py", line 400, in run_engine_core
(AsyncvLLMServer pid=5222) raise e
(AsyncvLLMServer pid=5222) File "/miniconda3/envs/verl/lib/python3.10/site-packages/vllm/v1/engine/core.py", line 387, in run_engine_core
(AsyncvLLMServer pid=5222) engine_core = EngineCoreProc(*args, **kwargs)
(AsyncvLLMServer pid=5222) File "/miniconda3/envs/verl/lib/python3.10/site-packages/vllm/v1/engine/core.py", line 329, in __init__
(AsyncvLLMServer pid=5222) super().__init__(vllm_config, executor_class, log_stats,
(AsyncvLLMServer pid=5222) File "/miniconda3/envs/verl/lib/python3.10/site-packages/vllm/v1/engine/core.py", line 64, in __init__
(AsyncvLLMServer pid=5222) self.model_executor = executor_class(vllm_config)
(AsyncvLLMServer pid=5222) File "/miniconda3/envs/verl/lib/python3.10/site-packages/vllm/executor/executor_base.py", line 52, in __init__
(AsyncvLLMServer pid=5222) self._init_executor()
(AsyncvLLMServer pid=5222) File "/cpfs01/user/zhoujiecheng/workload_rl_analyse/verl/verl/workers/rollout/vllm_rollout/vllm_async_server.py", line 78, in _init_executor
(AsyncvLLMServer pid=5222) self.collective_rpc("init_worker", args=([kwargs],))
(AsyncvLLMServer pid=5222) File "/cpfs01/user/zhoujiecheng/workload_rl_analyse/verl/verl/workers/rollout/vllm_rollout/vllm_async_server.py", line 98, in collective_rpc
(AsyncvLLMServer pid=5222) outputs = ray.get([worker.execute_method.remote(sent_method, *args, **(kwargs or {})) for worker in self.workers])
(AsyncvLLMServer pid=5222) File "/cpfs01/user/zhoujiecheng/workload_rl_analyse/verl/verl/workers/rollout/vllm_rollout/vllm_async_server.py", line 98, in <listcomp>
Traceback (most recent call last):
  File "/cpfs01/user/zhoujiecheng/workload_rl_analyse/verl/recipe/dapo/main_dapo.py", line 56, in main
    run_ppo(config)
  File "/cpfs01/user/zhoujiecheng/workload_rl_analyse/verl/recipe/dapo/main_dapo.py", line 68, in run_ppo
    ray.get(runner.run.remote(config))
  File "/miniconda3/envs/verl/lib/python3.10/site-packages/ray/_private/auto_init_hook.py", line 21, in auto_init_wrapper
    return fn(*args, **kwargs)
  File "/miniconda3/envs/verl/lib/python3.10/site-packages/ray/_private/client_mode_hook.py", line 103, in wrapper
    return func(*args, **kwargs)
  File "/miniconda3/envs/verl/lib/python3.10/site-packages/ray/_private/worker.py", line 2822, in get
    values, debugger_breakpoint = worker.get_objects(object_refs, timeout=timeout)
  File "/miniconda3/envs/verl/lib/python3.10/site-packages/ray/_private/worker.py", line 930, in get_objects
    raise value.as_instanceof_cause()
ray.exceptions.RayTaskError(RuntimeError): ray::TaskRunner.run() (pid=2516, ip=10.130.0.229, actor_id=454ae51394a6a91274ca3ea901000000, repr=<main_dapo.TaskRunner object at 0x7f5107765720>)
  File "/cpfs01/user/zhoujiecheng/workload_rl_analyse/verl/recipe/dapo/main_dapo.py", line 197, in run
    trainer.init_workers()
  File "/cpfs01/user/zhoujiecheng/workload_rl_analyse/verl/verl/trainer/ppo/ray_trainer.py", line 743, in init_workers
    self.async_rollout_manager = AsyncLLMServerManager(
  File "/cpfs01/user/zhoujiecheng/workload_rl_analyse/verl/verl/workers/rollout/async_server.py", line 272, in __init__
    ray.get([server.init_engine.remote() for server in self.async_llm_servers])
ray.exceptions.RayTaskError(RuntimeError): ray::AsyncvLLMServer.init_engine() (pid=5222, ip=10.130.0.229, actor_id=de9bd5a83310a0af03efbd0f01000000, repr=<verl.workers.rollout.vllm_rollout.vllm_async_server.AsyncvLLMServer object at 0x7f03b9dc0670>)
  File "/miniconda3/envs/verl/lib/python3.10/concurrent/futures/_base.py", line 458, in result
    return self.__get_result()
  File "/miniconda3/envs/verl/lib/python3.10/concurrent/futures/_base.py", line 403, in __get_result
    raise self._exception
  File "/cpfs01/user/zhoujiecheng/workload_rl_analyse/verl/verl/workers/rollout/vllm_rollout/vllm_async_server.py", line 190, in init_engine
    self.engine = AsyncLLM.from_vllm_config(vllm_config)
  File "/miniconda3/envs/verl/lib/python3.10/site-packages/vllm/v1/engine/async_llm.py", line 150, in from_vllm_config
    return cls(
  File "/miniconda3/envs/verl/lib/python3.10/site-packages/vllm/v1/engine/async_llm.py", line 118, in __init__
    self.engine_core = core_client_class(
  File "/miniconda3/envs/verl/lib/python3.10/site-packages/vllm/v1/engine/core_client.py", line 642, in __init__
    super().__init__(
  File "/miniconda3/envs/verl/lib/python3.10/site-packages/vllm/v1/engine/core_client.py", line 398, in __init__
    self._wait_for_engine_startup()
  File "/miniconda3/envs/verl/lib/python3.10/site-packages/vllm/v1/engine/core_client.py", line 430, in _wait_for_engine_startup
    raise RuntimeError("Engine core initialization failed. "
RuntimeError: Engine core initialization failed. See root cause above.
How did you use the vllm async rollout? Could you post a test script?
I only made some modifications to verl/recipe/dapo/test_dapo_7b.sh: I set actor_rollout_ref.rollout.mode=async and changed NNODES to 1 for debugging.
set -xeuo pipefail
project_name='DAPO'
exp_name='DAPO-Qwen2.5-7B-Math-Test'
export adv_estimator=grpo

use_kl_in_reward=False
kl_coef=0.0
use_kl_loss=False
kl_loss_coef=0.0

clip_ratio_low=0.2
clip_ratio_high=0.28

max_prompt_length=$((1024 * 2))
max_response_length=$((1024 * 2))
enable_overlong_buffer=True
overlong_buffer_len=512
overlong_penalty_factor=1.0

loss_agg_mode="token-mean"

enable_filter_groups=False
filter_groups_metric=acc
max_num_gen_batches=10
train_prompt_bsz=16
gen_prompt_bsz=$((train_prompt_bsz * 3))
train_prompt_mini_bsz=32
n_resp_per_prompt=16

RAY_ADDRESS=${RAY_ADDRESS:-"http://localhost:8265"}
WORKING_DIR=${WORKING_DIR:-"${PWD}"}
RUNTIME_ENV=${RUNTIME_ENV:-"${WORKING_DIR}/verl/trainer/runtime_env.yaml"}
NNODES=${NNODES:-1}

RAY_DATA_HOME=${RAY_DATA_HOME:-"/cpfs01/user/zhoujiecheng/verl"}
MODEL_PATH=${MODEL_PATH:-"/cpfs01/user/zhoujiecheng/workload_rl_analyse/models/Qwen2.5-Math-7B"}
CKPTS_DIR=${CKPTS_DIR:-"${RAY_DATA_HOME}/ckpts/${project_name}/${exp_name}"}
TRAIN_FILE=${TRAIN_FILE:-"${RAY_DATA_HOME}/data/dapo-math-17k.parquet"}
TEST_FILE=${TEST_FILE:-"${RAY_DATA_HOME}/data/aime-2024.parquet"}

temperature=1.0
top_p=1.0
top_k=-1 # 0 for HF rollout, -1 for vLLM rollout

use_dynamic_bsz=True
infer_micro_batch_size=null
train_micro_batch_size=null
offload=False
python3 -m recipe.dapo.main_dapo \
    data.train_files="${TRAIN_FILE}" \
    data.val_files="${TEST_FILE}" \
    data.prompt_key=prompt \
    data.truncation='left' \
    data.max_prompt_length=${max_prompt_length} \
    data.max_response_length=${max_response_length} \
    data.gen_batch_size=${gen_prompt_bsz} \
    data.train_batch_size=${train_prompt_bsz} \
    actor_rollout_ref.rollout.n=${n_resp_per_prompt} \
    actor_rollout_ref.actor.use_kl_loss=${use_kl_loss} \
    actor_rollout_ref.actor.kl_loss_coef=${kl_loss_coef} \
    actor_rollout_ref.actor.clip_ratio_low=${clip_ratio_low} \
    actor_rollout_ref.actor.clip_ratio_high=${clip_ratio_high} \
    actor_rollout_ref.actor.clip_ratio_c=10.0 \
    algorithm.adv_estimator=${adv_estimator} \
    algorithm.use_kl_in_reward=${use_kl_in_reward} \
    algorithm.kl_ctrl.kl_coef=${kl_coef} \
    algorithm.filter_groups.enable=${enable_filter_groups} \
    algorithm.filter_groups.metric=${filter_groups_metric} \
    algorithm.filter_groups.max_num_gen_batches=${max_num_gen_batches} \
    actor_rollout_ref.model.use_remove_padding=True \
    actor_rollout_ref.actor.use_dynamic_bsz=${use_dynamic_bsz} \
    actor_rollout_ref.ref.log_prob_use_dynamic_bsz=${use_dynamic_bsz} \
    actor_rollout_ref.rollout.log_prob_use_dynamic_bsz=${use_dynamic_bsz} \
    actor_rollout_ref.actor.ppo_max_token_len_per_gpu=$((max_prompt_length + max_response_length)) \
    actor_rollout_ref.ref.log_prob_max_token_len_per_gpu=$((max_prompt_length + max_response_length)) \
    actor_rollout_ref.rollout.log_prob_max_token_len_per_gpu=$((max_prompt_length + max_response_length)) \
    actor_rollout_ref.model.path="${MODEL_PATH}" \
    actor_rollout_ref.model.enable_gradient_checkpointing=True \
    actor_rollout_ref.actor.optim.lr=1e-6 \
    actor_rollout_ref.actor.optim.lr_warmup_steps=10 \
    actor_rollout_ref.actor.optim.weight_decay=0.1 \
    actor_rollout_ref.actor.ppo_mini_batch_size=${train_prompt_mini_bsz} \
    actor_rollout_ref.actor.ppo_micro_batch_size=${train_micro_batch_size} \
    actor_rollout_ref.actor.fsdp_config.param_offload=${offload} \
    actor_rollout_ref.actor.fsdp_config.optimizer_offload=${offload} \
    actor_rollout_ref.actor.entropy_coeff=0 \
    actor_rollout_ref.actor.grad_clip=1.0 \
    actor_rollout_ref.actor.loss_agg_mode=${loss_agg_mode} \
    actor_rollout_ref.actor.ulysses_sequence_parallel_size=1 \
    actor_rollout_ref.rollout.gpu_memory_utilization=0.8 \
    actor_rollout_ref.rollout.log_prob_micro_batch_size=${infer_micro_batch_size} \
    actor_rollout_ref.rollout.tensor_model_parallel_size=2 \
    actor_rollout_ref.rollout.name=vllm \
    actor_rollout_ref.rollout.mode=async \
    actor_rollout_ref.rollout.enable_chunked_prefill=True \
    actor_rollout_ref.rollout.max_num_batched_tokens=$((max_prompt_length + max_response_length)) \
    actor_rollout_ref.rollout.temperature=${temperature} \
    actor_rollout_ref.rollout.top_p=${top_p} \
    actor_rollout_ref.rollout.top_k="${top_k}" \
    actor_rollout_ref.rollout.val_kwargs.temperature=${temperature} \
    actor_rollout_ref.rollout.val_kwargs.top_p=${top_p} \
    actor_rollout_ref.rollout.val_kwargs.top_k=${top_k} \
    actor_rollout_ref.rollout.val_kwargs.do_sample=True \
    actor_rollout_ref.rollout.val_kwargs.n=1 \
    actor_rollout_ref.ref.log_prob_micro_batch_size=${infer_micro_batch_size} \
    actor_rollout_ref.ref.fsdp_config.param_offload=${offload} \
    actor_rollout_ref.ref.ulysses_sequence_parallel_size=1 \
    actor_rollout_ref.actor.fsdp_config.fsdp_size=-1 \
    reward_model.reward_manager=dapo \
    reward_model.overlong_buffer.enable=${enable_overlong_buffer} \
    reward_model.overlong_buffer.len=${overlong_buffer_len} \
    reward_model.overlong_buffer.penalty_factor=${overlong_penalty_factor} \
    trainer.logger=['console'] \
    trainer.project_name="${project_name}" \
    trainer.experiment_name="${exp_name}" \
    trainer.n_gpus_per_node=8 \
    trainer.nnodes="${NNODES}" \
    trainer.val_before_train=True \
    trainer.test_freq=10 \
    trainer.save_freq=10 \
    trainer.total_epochs=1 \
    trainer.default_local_dir="${CKPTS_DIR}" \
    trainer.resume_mode=disable
It uses recipe/dapo/dapo_ray_trainer.py which does not support async_rollout_mode.
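For context, a trainer that does support async rollout branches on the rollout mode around generation, roughly as in the sketch below. The method names (async_rollout_manager.wake_up() / generate_sequences() / sleep()) follow verl's PPO trainer from memory, so treat this as an assumption-laden sketch rather than the exact dapo_ray_trainer.py fix:

```python
# Sketch of the async-aware generation branch missing from dapo_ray_trainer.py
# (assumed to mirror verl/trainer/ppo/ray_trainer.py; verify against your verl version).
if self.config.actor_rollout_ref.rollout.mode == "async":
    # Async mode: the server-based rollout manager owns the vLLM engines.
    self.async_rollout_manager.wake_up()
    gen_batch_output = self.async_rollout_manager.generate_sequences(gen_batch)
    self.async_rollout_manager.sleep()
else:
    # Sync mode: call generate_sequences directly on the rollout worker group.
    gen_batch_output = self.actor_rollout_wg.generate_sequences(gen_batch)
```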
I can reproduce this error using the latest image: whatcanyousee/verl:ngc-cu124-vllm0.8.5-sglang0.4.6.post5-mcore0.12.0-te2.3
It can also be reproduced on my host after the sglang and megatron dependencies were installed. The problem may come from a package conflict, but I still cannot figure out which package causes the error.
However, there is a workaround that gets it working: run the command below before executing the training script:
export VLLM_USE_V1=1 && ray start --head
cc @wuxibin89
May I ask if there is a solution to this issue? I still have this problem with the latest version of Verl
Does this help?
export VLLM_USE_V1=1 && ray start --head
It is a temporary solution, but it works for now.
@wuxibin89 Do you know when this problem can be fixed?
export VLLM_USE_V1=1 && ray start --head doesn't work for me. However, I have found a solution. When I use the main_dapo.py file for training, this error is reported. I replaced the relevant code in main_dapo.py with the worker-class definition code from main_ppo.py, and the code now runs normally. I wonder whether the verl repository forgot to update main_dapo.py; the worker-class definition code from main_ppo.py, listed below, works:
# Define worker classes based on the actor strategy.
if config.actor_rollout_ref.actor.strategy in ["fsdp", "fsdp2"]:
    assert config.critic.strategy in ["fsdp", "fsdp2"]
    from verl.single_controller.ray import RayWorkerGroup
    from verl.workers.fsdp_workers import ActorRolloutRefWorker, AsyncActorRolloutRefWorker, CriticWorker

    actor_rollout_cls = AsyncActorRolloutRefWorker if config.actor_rollout_ref.rollout.mode == "async" else ActorRolloutRefWorker
    ray_worker_group_cls = RayWorkerGroup

elif config.actor_rollout_ref.actor.strategy == "megatron":
    assert config.actor_rollout_ref.actor.strategy == config.critic.strategy
    from verl.single_controller.ray.megatron import NVMegatronRayWorkerGroup
    from verl.workers.megatron_workers import ActorRolloutRefWorker, AsyncActorRolloutRefWorker, CriticWorker

    actor_rollout_cls = AsyncActorRolloutRefWorker if config.actor_rollout_ref.rollout.mode == "async" else ActorRolloutRefWorker
    ray_worker_group_cls = NVMegatronRayWorkerGroup

else:
    raise NotImplementedError

from verl.trainer.ppo.ray_trainer import ResourcePoolManager, Role

# Map roles to their corresponding remote worker classes.
role_worker_mapping = {
    Role.ActorRollout: ray.remote(actor_rollout_cls),
    Role.Critic: ray.remote(CriticWorker),
}

# Define the resource pool specification.
# Map roles to the resource pool.
global_pool_id = "global_pool"
resource_pool_spec = {
    global_pool_id: [config.trainer.n_gpus_per_node] * config.trainer.nnodes,
}
mapping = {
    Role.ActorRollout: global_pool_id,
    Role.Critic: global_pool_id,
}

# We should adopt a multi-source reward function here:
# - for rule-based rm, we directly call a reward score
# - for model-based rm, we call a model
# - for code related prompt, we send to a sandbox if there are test cases
# finally, we combine all the rewards together
# The reward type depends on the tag of the data
if config.reward_model.enable:
    if config.reward_model.strategy in ["fsdp", "fsdp2"]:
        from verl.workers.fsdp_workers import RewardModelWorker
    elif config.reward_model.strategy == "megatron":
        from verl.workers.megatron_workers import RewardModelWorker
    else:
        raise NotImplementedError
    role_worker_mapping[Role.RewardModel] = ray.remote(RewardModelWorker)
    mapping[Role.RewardModel] = global_pool_id

# Add a reference policy worker if KL loss or KL reward is used.
if config.algorithm.use_kl_in_reward or config.actor_rollout_ref.actor.use_kl_loss:
    role_worker_mapping[Role.RefPolicy] = ray.remote(ActorRolloutRefWorker)
    mapping[Role.RefPolicy] = global_pool_id
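With those definitions in place, main_dapo.py would then construct the DAPO trainer from the mapping, roughly as below. The constructor arguments are assumed to mirror RayPPOTrainer, and the tokenizer/reward-function arguments are omitted, so check recipe/dapo/main_dapo.py for the exact signature:

```python
# Hypothetical wiring of the worker mapping into the trainer; argument names are
# assumed to follow RayPPOTrainer and may differ in recipe/dapo/main_dapo.py.
from recipe.dapo.dapo_ray_trainer import RayDAPOTrainer

resource_pool_manager = ResourcePoolManager(resource_pool_spec=resource_pool_spec, mapping=mapping)
trainer = RayDAPOTrainer(
    config=config,
    role_worker_mapping=role_worker_mapping,
    resource_pool_manager=resource_pool_manager,
    ray_worker_group_cls=ray_worker_group_cls,
    # tokenizer=..., reward_fn=..., val_reward_fn=...  (omitted here)
)
trainer.init_workers()
trainer.fit()
```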
Yes, you are right. We need a better method to enable async rollout.
@chenhaiq @2645283289 I have replaced ExternalRayDistributedExecutor with ExternalZeroMQDistributedExecutor in https://github.com/volcengine/verl/pull/2246. This problem should be resolved.
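For readers curious about the change: the executor's collective RPC fan-out moves from Ray actor calls to ZeroMQ sockets, presumably so the engine-core process spawned by vLLM v1 no longer needs to issue Ray calls itself. Below is a minimal, illustrative sketch of that idea only; it is not the PR's actual code, and the class name, addresses, and pickle-based wire format are assumptions for the example:

```python
import pickle
import zmq


class ZeroMQCollectiveRpcClient:
    """Illustrative only: broadcast a method call to worker endpoints over ZeroMQ."""

    def __init__(self, worker_addresses):
        # worker_addresses is assumed, e.g. ["tcp://10.0.0.1:5555", "tcp://10.0.0.2:5555"]
        self.context = zmq.Context()
        self.sockets = []
        for addr in worker_addresses:
            sock = self.context.socket(zmq.REQ)  # one request socket per worker
            sock.connect(addr)
            self.sockets.append(sock)

    def collective_rpc(self, method, args=(), kwargs=None):
        # Serialize the call once, send it to every worker, then gather the replies.
        payload = pickle.dumps((method, args, kwargs or {}))
        for sock in self.sockets:
            sock.send(payload)
        return [pickle.loads(sock.recv()) for sock in self.sockets]
```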