[FEAT] Support async multi-turn rollout with simulation feedback
Checklist Before Starting
- [x] Search for similar PR(s).
What does this PR do?
- This PR adds support for simulated user feedback in multi-turn interactions, a capability needed in many interactive training scenarios.
High-Level Design
- The design follows the existing multi-turn tool-call flow. A new request state is added in which the rollout waits for user feedback; the simulated-feedback logic can be customized in the same way tools are, and the resulting feedback is appended directly to the message history as a user turn (a hedged sketch of such a feedback class follows).
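For concreteness, here is a minimal sketch of what such a customizable feedback class could look like, assuming an interface analogous to the custom tool interface. The class name, method names, and the `(should_terminate, response, score, metadata)` return convention below are illustrative assumptions, not the exact API:

```python
# Hypothetical sketch of a simulated-user feedback ("interaction") class for GSM8K.
# Method names and return signature are assumptions modeled on the tool pattern.
from typing import Any


class Gsm8kSimulatedUser:
    """Checks the assistant's latest answer and replies as a simulated user."""

    def __init__(self) -> None:
        # Per-request state, keyed by the rollout request's instance id.
        self._ground_truths: dict[str, str] = {}

    async def start_interaction(self, instance_id: str, ground_truth: str, **kwargs: Any) -> None:
        # Remember whatever is needed to generate feedback for this request later.
        self._ground_truths[instance_id] = ground_truth

    async def generate_response(
        self, instance_id: str, messages: list[dict], **kwargs: Any
    ) -> tuple[bool, str, float, dict]:
        # Inspect the latest assistant turn and produce a simulated user reply,
        # which the rollout appends to the message history as a user turn.
        last_assistant = messages[-1]["content"]
        if self._ground_truths[instance_id] in last_assistant:
            # Terminate the interaction and assign a turn-level score.
            return True, "Correct, well done.", 1.0, {}
        return False, "That doesn't look right, please try again.", 0.0, {}
```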
Specific Changes
- Custom (simulated user) feedback function, customizable in the same way as tools
- New `AsyncRolloutRequestState` for requests awaiting simulated user feedback (see the sketch after this list)
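To illustrate the second item, a rough sketch of how the request state could be extended; the member names here are assumptions for illustration only:

```python
# Illustrative only: a request-state enum extended with a state in which the
# rollout waits for simulated user feedback before continuing generation.
from enum import Enum


class AsyncRolloutRequestState(str, Enum):
    PENDING = "pending"
    RUNNING = "running"
    TOOL_CALLING = "tool_calling"
    # New: paused until the simulated user (interaction) returns feedback,
    # which is then appended to the message history as a user turn.
    INTERACTING = "interacting"
    COMPLETED = "completed"
```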
API
- `multi_turn.max_turns` is renamed to `multi_turn.max_assistant_turns` (same semantics), and a new `multi_turn.max_user_turns` option limits the number of simulated user turns.
- A new `multi_turn.interaction_config_path` option points to the interaction (feedback) configuration file.
Usage Example
```bash
python3 -m verl.trainer.main_ppo \
--config-path="$CONFIG_PATH" \
--config-name='gsm8k_multiturn_grpo' \
algorithm.adv_estimator=grpo \
data.train_batch_size=256 \
data.max_prompt_length=1024 \
data.max_response_length=1024 \
data.filter_overlong_prompts=True \
data.truncation='error' \
data.return_raw_chat=True \
actor_rollout_ref.model.path=Qwen/Qwen2.5-3B-Instruct \
actor_rollout_ref.actor.optim.lr=1e-6 \
actor_rollout_ref.model.use_remove_padding=True \
actor_rollout_ref.actor.ppo_mini_batch_size=256 \
actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=32 \
actor_rollout_ref.actor.use_kl_loss=True \
actor_rollout_ref.actor.kl_loss_coef=0.001 \
actor_rollout_ref.actor.kl_loss_type=low_var_kl \
actor_rollout_ref.actor.entropy_coeff=0 \
actor_rollout_ref.model.enable_gradient_checkpointing=True \
actor_rollout_ref.actor.fsdp_config.param_offload=False \
actor_rollout_ref.actor.fsdp_config.optimizer_offload=False \
actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=32 \
actor_rollout_ref.rollout.tensor_model_parallel_size=2 \
actor_rollout_ref.rollout.name=sglang_async \
actor_rollout_ref.rollout.gpu_memory_utilization=0.5 \
actor_rollout_ref.rollout.n=16 \
actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu=32 \
actor_rollout_ref.ref.fsdp_config.param_offload=True \
algorithm.use_kl_in_reward=False \
trainer.critic_warmup=0 \
trainer.logger=['console','wandb'] \
trainer.project_name='gsm8k_async_rl' \
trainer.experiment_name='qwen2.5-3b_function_rm-gsm8k-async-sgl-multi-w-tool-verify-n16-4cards' \
trainer.n_gpus_per_node=4 \
trainer.nnodes=1 \
trainer.save_freq=-1 \
trainer.test_freq=20 \
trainer.total_epochs=15 \
actor_rollout_ref.actor.ppo_max_token_len_per_gpu=8192 \
actor_rollout_ref.rollout.log_prob_max_token_len_per_gpu=8192 \
actor_rollout_ref.ref.log_prob_max_token_len_per_gpu=8192 \
critic.ppo_max_token_len_per_gpu=8192 \
critic.forward_max_token_len_per_gpu=8192 \
data.train_files=$HOME/data/gsm8k/train.parquet \
data.val_files=$HOME/data/gsm8k/test.parquet \
actor_rollout_ref.rollout.multi_turn.interaction_config_path="$PROJECT_DIR/examples/sglang_multiturn/config/interaction_config/gsm8k_interaction_config.yaml" \
actor_rollout_ref.rollout.multi_turn.max_user_turns=1
```
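The `interaction_config_path` above points to a YAML file declaring which feedback class to use. Its exact schema is an assumption here, modeled on the tool config files, and the module path shown is hypothetical:

```yaml
# Assumed shape of gsm8k_interaction_config.yaml; field names mirror the
# tool-config convention and may differ from the actual schema.
interaction:
  - class_name: my_project.interactions.Gsm8kSimulatedUser  # path to your custom feedback class
    config: {}
```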
Test
For changes that cannot be tested by CI (e.g., algorithm implementation, new model support), validate by experiment(s) and show results such as training curve plots, evaluation results, etc.
Additional Info.
- Issue Number: Fixes issue # or discussion # if any.
- Training: [Note which backend this PR will affect: FSDP, Megatron, both, or none]
- Inference: [Note which backend this PR will affect: vLLM, SGLang, both, or none]
Checklist Before Submitting
- [x] Read the Contribute Guide.
- [x] Apply pre-commit checks.
- [x] Add `[BREAKING]` to the PR title if it breaks any API.
- [x] Update the documentation about your changes in the docs.
- [x] Add CI test(s) if necessary.
@zhaochenyang20
nice work!
We will take a look in the next few days; stay tuned, expect a response within 24h.
@kinza99 my WeChat is 18015766633. Feel free to reach out to discuss, thanks!
To-do:
- Review new examples in wandb log.
- Refactor duplicated unit tests.
There seem to be a lot of changes. I'll take some time to review as well.
It finally works; I learned a lot working with He Du. https://wandb.ai/swordfaith/gsm8k_async_rl/runs/8n409ugt?nw=nwuserswordfaith
Is there a reference paper on the "simulation feedback" that this PR implements?
You can refer to https://arxiv.org/abs/2409.12917.
Before merging, please note that this PR introduces breaking changes to the multi_turn configuration. The max_turns parameter has been renamed to max_assistant_turns, retaining the same semantic meaning. Additionally, a new parameter max_user_turns has been added to control the number of user interactions.
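In terms of the CLI overrides from the usage example above, the migration looks roughly like this (the values shown are placeholders):

```bash
# Before: one knob controlled the number of turns
actor_rollout_ref.rollout.multi_turn.max_turns=2

# After: renamed with the same meaning, plus a separate limit on simulated user turns
actor_rollout_ref.rollout.multi_turn.max_assistant_turns=2
actor_rollout_ref.rollout.multi_turn.max_user_turns=1
```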
Why is the user message used to calculate the reward instead of the assistant message?
> Why is the user message used to calculate the reward instead of the assistant message?
We are developing a message- and turn-level reward system. Calculating a reward at the user's turn doesn't mean the user's message itself needs a reward; rather, it lets the user-turn reward reallocate credit across the assistant's turns.