[fsdp] feat: integrate PrefixGrouper for GRPO training acceleration
What does this PR do?
Integrate PrefixGrouper into verl's FSDP worker to accelerate GRPO training by reducing redundant prefix computations.
In GRPO training, each prompt is copied G times (rollout.n), leading to redundant self-attention computation on shared prefixes. PrefixGrouper decomposes this into prefix self-attention + suffix concat-attention, significantly reducing computation and memory usage.
Key changes:
- Add `use_prefix_grouper` config option in `ActorConfig`
- Implement the PG forward path in `DataParallelPPOActor._forward_micro_batch`
- Add utility functions in `verl/trainer/ppo/prefix_grouper_utils.py`
- Add example scripts and documentation in `examples/prefix_grouper_examples/`
Test
Benchmark Results (Qwen3-4B, 4×H800, rollout.n=4):
| Context Length | Metric | PG | No PG | Speedup |
|---|---|---|---|---|
| 4K | `old_log_prob` | 1.31s | 1.70s | 1.30x |
| 4K | `update_actor` | 4.80s | 6.07s | 1.26x |
| 4K | `step` | 17.08s | 19.40s | 1.14x |
| 8K | `old_log_prob` | 1.69s | 2.63s | 1.56x |
| 8K | `update_actor` | 5.98s | 10.18s | 1.70x |
| 8K | `step` | 19.48s | 24.71s | 1.27x |
As context length increases, the speedup becomes more pronounced.
API and Usage Example
```bash
# Enable PrefixGrouper in training config
actor_rollout_ref.actor.use_prefix_grouper=True
trainer.balance_batch=False                        # Required: PG is incompatible with balance_batch
actor_rollout_ref.model.use_remove_padding=False   # Required: PG is incompatible with remove_padding

# Run example script
bash examples/prefix_grouper_examples/run_qwen3_pg.sh
```
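For reference, here is a minimal sketch of the kind of guard these requirements imply (the actual check described below lives in `verl/trainer/ppo/ray_trainer.py`; the function name and exact config paths here are illustrative, not the real implementation):

```python
from omegaconf import DictConfig


def check_prefix_grouper_compat(config: DictConfig) -> None:
    """Illustrative guard: fail fast when PrefixGrouper is combined with
    settings it cannot support (hypothetical helper, not verl's code)."""
    if not config.actor_rollout_ref.actor.get("use_prefix_grouper", False):
        return  # PrefixGrouper disabled, nothing to validate

    if config.trainer.get("balance_batch", True):
        # balance_batch reorders samples across ranks and breaks uid grouping
        raise ValueError("use_prefix_grouper=True requires trainer.balance_batch=False")

    if config.actor_rollout_ref.model.get("use_remove_padding", False):
        # remove_padding (FA2 varlen) bypasses the grouped attention path
        raise ValueError("use_prefix_grouper=True requires use_remove_padding=False")
```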
Design & Code Changes
High-level Design:
PrefixGrouper optimizes GRPO training by avoiding redundant computation on shared prefixes. When rollout.n > 1, multiple responses share the same prompt, but standard attention computes the prefix n times. PrefixGrouper decomposes this into:
- Prefix self-attention: Compute once per unique prompt
- Suffix concat-attention: Each response attends to the shared prefix output
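For intuition, a minimal PyTorch sketch of this attention pattern for a single prompt group follows; it only illustrates the idea and is not the PrefixGrouper implementation:

```python
import torch
import torch.nn.functional as F


def grouped_prefix_attention(q_p, k_p, v_p, q_s, k_s, v_s):
    """Toy illustration of the decomposition for one prompt group.

    q_p/k_p/v_p: (1, H, P, D)  shared prefix, stored once
    q_s/k_s/v_s: (G, H, S, D)  G response suffixes of length S
    """
    # 1) Prefix self-attention: computed a single time per unique prompt,
    #    instead of once per response as in the tiled GRPO layout.
    prefix_out = F.scaled_dot_product_attention(q_p, k_p, v_p, is_causal=True)

    # 2) Suffix concat-attention: each response attends to the shared prefix
    #    plus (causally) its own tokens.
    G, P, S = q_s.shape[0], q_p.shape[2], q_s.shape[2]
    k_cat = torch.cat([k_p.expand(G, -1, -1, -1), k_s], dim=2)
    v_cat = torch.cat([v_p.expand(G, -1, -1, -1), v_s], dim=2)
    mask = torch.ones(S, P + S, dtype=torch.bool, device=q_s.device)
    mask[:, P:] = torch.tril(torch.ones(S, S, dtype=torch.bool, device=q_s.device))
    suffix_out = F.scaled_dot_product_attention(q_s, k_cat, v_cat, attn_mask=mask)

    # Unlike this toy version (which copies the prefix K/V into k_cat/v_cat),
    # the actual library also avoids duplicating prefix memory per response.
    return prefix_out, suffix_out
```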
Code Changes:
| File | Change |
|---|---|
| `verl/workers/config/actor.py` | Add `use_prefix_grouper: bool = False` config option |
| `verl/trainer/config/actor/actor.yaml` | Add `use_prefix_grouper: false` default config |
| `verl/workers/actor/dp_actor.py` | (1) Add `self.use_prefix_grouper` and `self.use_dynamic_bsz` attributes in `__init__`; (2) Add PG forward path in `_forward_micro_batch` with lazy import and incompatibility checks; (3) Select extra keys (`prompts`, `response_mask`, `uid`) for PG in `compute_log_prob`; (4) Select extra keys (`prompts`, `uid`) for PG in `update_policy` |
| `verl/trainer/ppo/prefix_grouper_utils.py` | New file with: `build_position_ids_for_prefix_grouper()` for position encoding, `build_pg_from_micro_batch()` to construct PrefixGrouper from a micro batch, `pg_forward()` to execute the PG-optimized forward pass |
| `verl/workers/fsdp_workers.py` | Sync `use_prefix_grouper` config from actor to ref policy in `init_model` to ensure both use the same forward path |
| `verl/trainer/ppo/ray_trainer.py` | Add `ValueError` check for the `use_prefix_grouper` + `balance_batch` incompatibility at initialization |
| `examples/prefix_grouper_examples/` | New directory with: `README.md` documentation, `run_qwen3_prefix_grouper.sh` example script, `qwen3/modeling_qwen3.py` modified model supporting PrefixGrouper |
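To make the position-encoding utility concrete, here is a simplified stand-in for what `build_position_ids_for_prefix_grouper()` has to produce for one group; the helper below is hypothetical and only shows the layout, not the actual function:

```python
import torch


def build_position_ids_sketch(prefix_len: int, response_lens: list[int]) -> torch.Tensor:
    """Position ids for one group laid out as [prefix, response_1, ..., response_G].

    The shared prefix occupies positions 0..prefix_len-1 exactly once, and each
    response restarts at prefix_len, so rotary embeddings match what the model
    would see if the prompt were tiled per response.
    """
    pieces = [torch.arange(prefix_len)]
    for resp_len in response_lens:
        pieces.append(torch.arange(prefix_len, prefix_len + resp_len))
    return torch.cat(pieces)


# e.g. a 5-token prefix and two responses of 3 and 2 tokens:
# tensor([0, 1, 2, 3, 4, 5, 6, 7, 5, 6])
print(build_position_ids_sketch(5, [3, 2]))
```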
Limitations
- FSDP worker only: Megatron worker is not supported yet
- Incompatible configurations:
  - `trainer.balance_batch=True` (reorders data, breaks uid grouping)
  - `use_dynamic_bsz=True`
  - `use_remove_padding=True` (Flash Attention V2 variable-length path)
  - `use_fused_kernels=True`
  - `use_ulysses_sp=True` (Ulysses sequence parallelism)
- Model modification required: the model must accept a `prefix_grouper` argument in its `forward` method (see the sketch below)
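A rough illustration of that requirement follows; everything in it, including the `prefix_grouper.attend` call, is a hypothetical placeholder rather than the actual PrefixGrouper API (the real patch is in `examples/prefix_grouper_examples/qwen3/modeling_qwen3.py`):

```python
import torch.nn.functional as F
from torch import nn


class AttentionWithPrefixGrouper(nn.Module):
    """Sketch of the required model change: forward takes an optional
    prefix_grouper handle and dispatches to the grouped path when present."""

    def __init__(self, hidden_size: int, num_heads: int):
        super().__init__()
        self.num_heads = num_heads
        self.qkv = nn.Linear(hidden_size, 3 * hidden_size)
        self.o_proj = nn.Linear(hidden_size, hidden_size)

    def forward(self, hidden_states, prefix_grouper=None):
        q, k, v = self.qkv(hidden_states).chunk(3, dim=-1)
        # (batch, seq, hidden) -> (batch, heads, seq, head_dim)
        q, k, v = (t.unflatten(-1, (self.num_heads, -1)).transpose(1, 2) for t in (q, k, v))
        if prefix_grouper is not None:
            # Grouped prefix/suffix attention; placeholder for the real API.
            attn = prefix_grouper.attend(q, k, v)
        else:
            # Standard causal attention path.
            attn = F.scaled_dot_product_attention(q, k, v, is_causal=True)
        return self.o_proj(attn.transpose(1, 2).flatten(-2))
```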
Checklist Before Submitting
- [x] Read the Contribute Guide.
- [x] Apply pre-commit checks:
pre-commit install && pre-commit run --all-files --show-diff-on-failure --color=always - [x] Add / Update the documentation. (Added examples/prefix_grouper_examples/README.md)
- [ ] Add unit or end-to-end test(s) to the CI workflow to cover all the code. If not feasible, explain why: PrefixGrouper requires modified model files and specific hardware setup, tested manually with benchmark results above.
- [ ] Once your PR is ready for CI, send a message in the
ci-requestchannel.
/gemini review
Please format code: https://github.com/volcengine/verl/blob/main/CONTRIBUTING.md#code-linting-and-formatting
Fixed, thanks!