[recipe] feat: add support for Single-stream Policy Optimization (SPO)
What does this PR do?
Add concise overview of what this PR aims to achieve or accomplish. Reference related GitHub issues and PRs that help with the review.
Checklist Before Starting
- [x] Search for similar PRs. Paste at least one query link here: New feature; no similar PRs found.
- [x] Format the PR title as [{modules}] {type}: {description} (this will be checked by the CI).
  - {modules} include fsdp, megatron, sglang, vllm, rollout, trainer, ci, training_utils, recipe, hardware, deployment, ray, worker, single_controller, misc, perf, model, algo, env, tool, ckpt, doc, data.
  - If this PR involves multiple modules, separate them with , like [megatron, fsdp, doc].
  - {type} is in feat, fix, refactor, chore, test.
  - If this PR breaks any API (CLI arguments, config, function signature, etc.), add [BREAKING] to the beginning of the title.
  - Example: [BREAKING][fsdp, megatron] feat: dynamic batching
Test
For changes that can not be tested by CI (e.g., algorithm implementation, new model support), validate by experiment(s) and show results like training curve plots, evaluation results, etc.
API and Usage Example
Demonstrate how the API changes if any, and provide usage example(s) if possible.
Here's a complete example combining all preprocessing and training steps:
# Step 1: Split dataset into subsets
python recipe/spo/estimate_offline_values/split_dapo_into_subsets.py \
--dataset open-r1/DAPO-Math-17k-Processed \
--output_dir DAPO-Math-17k-Processed_Splits \
--num_subsets 5
# Step 2: Generate offline value estimates for each subset
for i in {0..4}; do
OUTPUT_DIR=spo_verl_pr \
DATA_FILE=DAPO-Math-17k-Processed_Splits/subset_${i}.parquet \
MODEL_PATH=Qwen/Qwen3-8B \
EXP_NAME=offline_value_estimation_subset_${i} \
sh recipe/spo/estimate_offline_values/eval.sh
done
# Step 3: Merge offline value estimates
python recipe/spo/estimate_offline_values/merge_offline_values.py \
--input_dir spo_verl_pr/offline_value_estimation \
--output_file DAPO-Math-17k-Processed_Splits/offline_values.json
# Step 4: Train with SPO
OUTPUT_DIR=spo_verl_pr \
TRAIN_DATA_DIR=DAPO-Math-17k-Processed_Splits \
MODEL_PATH=Qwen/Qwen3-8B \
EXP_NAME=spo_training \
METHOD=SPO \
OFFLINE_VALUES=DAPO-Math-17k-Processed_Splits/offline_values.json \
sh recipe/spo/train.sh
Design & Code Changes
Demonstrate the high-level design if this PR is complex, and list the specific changes.
This PR implements Single-stream Policy Optimization (SPO), an efficient reinforcement learning algorithm that reduces GPU memory consumption by 8x compared to standard GRPO while maintaining comparable performance. The implementation is built on top of the verl framework and includes a complete pipeline for math problem solving with tool-augmented reasoning.
High-Level Design
SPO introduces two key innovations over standard GRPO (Group Relative Policy Optimization):
- Single-Response Generation: Generates 1 response per prompt instead of 8, reducing memory requirements from 768 to 96 tokens per batch
- Sampling with Offline Values: Uses pretrained model estimates to intelligently select prompts for training, maintaining sample efficiency
The architecture consists of three main components:
- Offline Value Estimation Pipeline: Preprocesses training data to estimate response quality
- SPO Training Loop: Implements Thompson Sampling-based prompt selection and advantage estimation
- Tool-Augmented Agent: Multi-turn reasoning with Python code execution capability
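To make the single-stream idea concrete, here is a minimal, hypothetical sketch (not the recipe code itself) of how a per-prompt baseline derived from offline scores could turn a single binary reward into a normalized advantage; the names `offline_values`, `p_hat`, and the example data are illustrative only:

```python
import numpy as np

# Hypothetical per-prompt offline score lists (1 = solved, 0 = not solved).
offline_values = {"prompt_a": [1, 0, 1, 1], "prompt_b": [0, 0, 1, 0]}
p_hat = {k: float(np.mean(v)) for k, v in offline_values.items()}

# One sampled response per prompt -> one binary reward per prompt.
rewards = {"prompt_a": 1.0, "prompt_b": 0.0}

# Single-stream advantage: reward minus the prompt's baseline, normalized over the batch.
adv = np.array([rewards[k] - p_hat[k] for k in rewards])
adv = (adv - adv.mean()) / (adv.std() + 1e-8)
print(dict(zip(rewards, adv)))
```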
Specific Code Changes
1. Core SPO Training Implementation (spo_ray_trainer.py:1160-1496)
Sampling with Adaptive Weighting:
- Maintains Beta distribution parameters α and β for each prompt to model its success probability
- Samples prompts proportionally to uncertainty: weight ∝ √(p̂(1 − p̂)) + ε
- Updates the distributions using an adaptive decay factor ρ based on KL divergence
# Weighted sampling (lines 1184-1230)
prompt2phat = {k: alpha[k] / (alpha[k] + beta[k]) for k in prompts}
prompt2weight = {k: np.sqrt(p * (1 - p)) + 0.05 for k, p in prompt2phat.items()}
weights = np.array(list(prompt2weight.values()))
weights = weights / weights.sum()
selected_prompts = np.random.choice(prompts, size=batch_size, p=weights)

# Advantage estimation (lines 1346-1378)
advantages = reward - p_hats                     # SPO advantage
advantages = (advantages - mean) / (std + 1e-8)  # Normalize

# Distribution updates (lines 1474-1489)
rho = 2 ** (-D / D_half)          # Decay based on KL divergence
alpha_new = rho * alpha + reward
beta_new = rho * beta + (1 - reward)
Key Parameters:
- offline_N=8: Number of offline samples used for prior estimation
- rho.type="kl": Adaptive decay based on policy drift
- clip_lower=0.875: Minimum decay factor to prevent over-reliance on old data
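For illustration, a minimal sketch of how these parameters could interact, assuming the Beta prior is seeded from up to offline_N offline scores and the KL-based decay is clamped at clip_lower; the helper names and the d_half value are hypothetical, not the trainer's actual identifiers:

```python
def init_prior(offline_scores, offline_N=8):
    """Seed Beta(alpha, beta) from up to offline_N binary offline scores."""
    scores = offline_scores[:offline_N]
    alpha = float(sum(scores))            # observed successes
    beta = float(len(scores)) - alpha     # observed failures
    return alpha, beta

def decay_factor(kl_div, d_half=0.02, clip_lower=0.875):
    """rho = 2^(-KL / D_half), clamped so old evidence never decays too fast."""
    return max(2.0 ** (-kl_div / d_half), clip_lower)

alpha, beta = init_prior([1, 0, 1, 1, 0, 1, 1, 0])
rho = decay_factor(kl_div=0.01)
reward = 1.0
alpha, beta = rho * alpha + reward, rho * beta + (1.0 - reward)
```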
2. Offline Value Estimation Pipeline
Data Preprocessing (estimate_offline_values/split_dapo_into_subsets.py):
- Splits large datasets (17k samples) into 5 subsets for parallel processing
- Outputs .parquet files for distributed evaluation
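A rough, hypothetical equivalent of this split step, assuming a standard Hugging Face datasets workflow and a "train" split (the function name and split choice are assumptions, not the script's actual code):

```python
import os
from datasets import load_dataset

def split_into_subsets(dataset_name, output_dir, num_subsets=5):
    """Shard a HF dataset into num_subsets parquet files for parallel evaluation."""
    ds = load_dataset(dataset_name, split="train")
    os.makedirs(output_dir, exist_ok=True)
    for i in range(num_subsets):
        ds.shard(num_shards=num_subsets, index=i).to_parquet(
            os.path.join(output_dir, f"subset_{i}.parquet")
        )

split_into_subsets("open-r1/DAPO-Math-17k-Processed", "DAPO-Math-17k-Processed_Splits")
```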
Value Generation (estimate_offline_values/eval.sh):
- Runs pretrained model on each subset to generate offline rewards
- Uses same reward function as training for consistency
- Stores results in validation_data/0.jsonl format
Merging (estimate_offline_values/merge_offline_values.py:44-134):
- Aggregates scores from all subsets by prompt
- Subsamples to max 8 scores per prompt using random selection
- Outputs offline_values.json mapping prompts to score lists
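A simplified sketch of the merge logic described above, assuming each subset directory contains JSONL records with "prompt" and "score" fields (the field names are assumptions):

```python
import glob
import json
import random
from collections import defaultdict

def merge_offline_values(input_dir, output_file, max_scores=8):
    """Aggregate per-prompt scores from all subsets and subsample to max_scores."""
    prompt2scores = defaultdict(list)
    for path in glob.glob(f"{input_dir}/**/*.jsonl", recursive=True):
        with open(path) as f:
            for line in f:
                record = json.loads(line)
                prompt2scores[record["prompt"]].append(record["score"])
    merged = {
        prompt: random.sample(scores, max_scores) if len(scores) > max_scores else scores
        for prompt, scores in prompt2scores.items()
    }
    with open(output_file, "w") as f:
        json.dump(merged, f)
```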
3. Custom Dataset and Reward Function (spo_retool.py)
CustomRLHFDataset (lines 57-126):
- Processes multiple math datasets: AIME 2024/2025, DAPO-Math-17k, BeyondAIME
- Appends answer format requirement: \boxed{answer} to all prompts
- Maps datasets to consistent schema with data_source, reward_model, agent_name
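An illustrative mapping of a raw math sample onto that schema; the raw field names ("problem", "answer"), the instruction text, and the agent identifier are assumptions for the sketch:

```python
BOXED_INSTRUCTION = "Please put your final answer within \\boxed{}."

def map_to_schema(example, data_source):
    """Map a raw math sample onto the consistent schema described above."""
    return {
        "data_source": data_source,
        "prompt": [{"role": "user",
                    "content": example["problem"] + "\n" + BOXED_INSTRUCTION}],
        "reward_model": {"style": "rule", "ground_truth": example["answer"]},
        "agent_name": "spo_tool_agent",  # assumed agent identifier
    }
```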
compute_score (lines 128-151):
- Validates reasoning format: exactly one closing thinking tag, no code blocks after the thinking section
- Extracts boxed answer using math_dapo.compute_score with strict verification
- Returns binary reward (0 or 1) and predicted answer
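A hedged sketch of the reward shape described above; the </think> tag and the plain string comparison are simplifying assumptions, whereas the recipe delegates strict verification to math_dapo.compute_score:

```python
import re

def compute_score(response: str, ground_truth: str) -> dict:
    # Format check: exactly one closing thinking tag and no code block after it.
    if response.count("</think>") != 1:
        return {"score": 0.0, "pred": None}
    after_think = response.split("</think>", 1)[1]
    if "```" in after_think:
        return {"score": 0.0, "pred": None}
    # Extract the boxed answer; real verification uses strict math equivalence checks.
    match = re.search(r"\\boxed\{([^}]*)\}", after_think)
    pred = match.group(1).strip() if match else None
    score = 1.0 if pred is not None and pred == ground_truth.strip() else 0.0
    return {"score": score, "pred": pred}
```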
4. Tool-Augmented Agent Loop (agent_loop/spo_tool_agent_loop.py)
State Machine Architecture (lines 156-168):
- PENDING: Prepares initial prompt with tool instructions
- GENERATING: LLM generates reasoning/code
- PROCESSING_TOOLS: Executes Python code in sandbox
- TERMINATED: Completes trajectory
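The state machine can be sketched roughly as follows; the agent methods (prepare_prompt, generate, has_code, execute_code, trajectory) are illustrative stand-ins, not the actual class API:

```python
from enum import Enum, auto

class AgentState(Enum):
    PENDING = auto()
    GENERATING = auto()
    PROCESSING_TOOLS = auto()
    TERMINATED = auto()

async def run_agent_loop(agent):
    state = AgentState.PENDING
    while state != AgentState.TERMINATED:
        if state == AgentState.PENDING:
            agent.prepare_prompt()              # add tool-use instructions
            state = AgentState.GENERATING
        elif state == AgentState.GENERATING:
            output = await agent.generate()     # LLM produces reasoning and/or code
            state = (AgentState.PROCESSING_TOOLS
                     if agent.has_code(output) else AgentState.TERMINATED)
        elif state == AgentState.PROCESSING_TOOLS:
            await agent.execute_code(output)    # run extracted code in the sandbox
            state = AgentState.GENERATING       # feed the result back to the LLM
    return agent.trajectory()
```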
Multi-Turn Code Execution (lines 253-278):
- Parses code blocks from LLM output using ... markers
- Executes them in an isolated sandbox with timeout/security controls
- Returns output wrapped in ... tags
- Updates response_mask: 1 for LLM tokens, 0 for tool outputs
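For the masking step, a simplified, hypothetical sketch of how the response mask could be extended turn by turn (token ids here are placeholders):

```python
def extend_response_mask(response_mask, llm_token_ids, tool_token_ids):
    """Mark LLM-generated tokens with 1 and injected tool-output tokens with 0."""
    response_mask.extend([1] * len(llm_token_ids))
    response_mask.extend([0] * len(tool_token_ids))
    return response_mask

mask = extend_response_mask([], llm_token_ids=[101, 102, 103], tool_token_ids=[201, 202])
# mask == [1, 1, 1, 0, 0]; only positions marked 1 contribute to the policy loss
```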
Key Features:
- Stateless sandbox: each execution is independent
- Truncation handling: limits tool outputs to max_tool_response_length (configurable)
- Error recovery: gracefully handles execution failures
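Truncation of sandbox output could look like the sketch below; only the max_tool_response_length name comes from the description above, the character-level cut and suffix are assumptions:

```python
def truncate_tool_response(text: str, max_tool_response_length: int = 2048) -> str:
    """Clip overly long sandbox output before it is fed back to the LLM."""
    if len(text) <= max_tool_response_length:
        return text
    return text[:max_tool_response_length] + "\n...(tool output truncated)"
```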
5. Training Configuration (train.sh)
GRPO Configuration:
train_batch_size=128        # 8 responses × 16 prompts
ppo_mini_batch_size=16
gen_batch_size=128
n_resp_per_prompt=8

SPO Configuration:
train_batch_size=1024       # 1 response × 1024 prompts (8× larger)
ppo_mini_batch_size=128     # 8× larger mini-batches
gen_batch_size=14000        # Large batch for efficient generation
n_resp_per_prompt=1         # Single response per prompt
offline_values=path/to/offline_values.json
6. Additional Modifications
Agent Loop Manager (spo_agent_loop.py:663-813):
- Manages async rollout workers for parallel trajectory generation
- Implements LLM server load balancing with sticky sessions
- Computes performance metrics: generation time, tool call latency
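One common way to realize sticky sessions is to hash a stable request identifier onto a fixed server; the sketch below is a generic illustration of that idea, not the manager's actual routing code:

```python
import hashlib

def pick_server(request_id: str, servers: list) -> str:
    """Route every turn of the same trajectory to the same LLM server."""
    digest = int(hashlib.md5(request_id.encode()).hexdigest(), 16)
    return servers[digest % len(servers)]

servers = ["server-0", "server-1", "server-2"]
assert pick_server("traj-42", servers) == pick_server("traj-42", servers)
```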
Configuration System (config/spo_trainer.yaml):
- Extends base PPO config with SPO-specific parameters
- Configurable rho decay strategies: constant or KL-based
- Integrated with Hydra for flexible experiment management
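An illustrative reading of the rho configuration, assuming key names that mirror the parameters listed earlier (the exact YAML keys and the d_half value are assumptions):

```python
from omegaconf import OmegaConf

cfg = OmegaConf.create({"rho": {"type": "kl", "value": 0.9, "clip_lower": 0.875}})

def resolve_rho(cfg, kl_div, d_half=0.02):
    if cfg.rho.type == "constant":
        return cfg.rho.value
    # "kl": decay shrinks as the policy drifts further from the data-collection policy
    return max(2.0 ** (-kl_div / d_half), cfg.rho.clip_lower)
```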
Checklist Before Submitting
[!IMPORTANT] Please check all the following items before requesting a review, otherwise the reviewer might deprioritize this PR for review.
- [x] Read the Contribute Guide.
- [x] Apply pre-commit checks: pre-commit install && pre-commit run --all-files --show-diff-on-failure --color=always
- [x] Add / Update the documentation.
- [ ] Add unit or end-to-end test(s) to the CI workflow to cover all the code. If not feasible, explain why: ...
- [ ] Once your PR is ready for CI, send a message in the ci-request channel in the verl Slack workspace. (If not accessible, please try the Feishu group (飞书群).)
@wuxibin89 @PeterSH6 @vermouth1992 Could you please review this PR? 🙏