
[fsdp] feat: integrate PrefixGrouper for GRPO training acceleration

Open · kevssim opened this pull request 1 month ago · 1 comment

What does this PR do?

Integrate PrefixGrouper into verl's FSDP worker to accelerate GRPO training by reducing redundant prefix computations.

In GRPO training, each prompt is duplicated G = rollout.n times, leading to redundant self-attention computation over the shared prefix. PrefixGrouper decomposes this into prefix self-attention plus suffix concat-attention, significantly reducing computation and memory usage.
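For example, with rollout.n=4 and a 4K-token shared prompt, the standard path re-encodes that 4K prefix four times per group, whereas PrefixGrouper encodes it once and lets the four responses attend to the single shared prefix output.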

Key changes:

  • Add use_prefix_grouper config option in ActorConfig
  • Implement PG forward path in DataParallelPPOActor._forward_micro_batch
  • Add utility functions in verl/trainer/ppo/prefix_grouper_utils.py
  • Add example scripts and documentation in examples/prefix_grouper_examples/

Test

Benchmark Results (Qwen3-4B, 4×H800, rollout.n=4):

| Context Length | Metric | PG | No PG | Speedup |
|---|---|---|---|---|
| 4K | old_log_prob | 1.31s | 1.70s | 1.30x |
| 4K | update_actor | 4.80s | 6.07s | 1.26x |
| 4K | step | 17.08s | 19.40s | 1.14x |
| 8K | old_log_prob | 1.69s | 2.63s | 1.56x |
| 8K | update_actor | 5.98s | 10.18s | 1.70x |
| 8K | step | 19.48s | 24.71s | 1.27x |

(Figure: timing_comparison_combined — per-phase timing with and without PrefixGrouper)

As context length increases, the speedup becomes more pronounced.

API and Usage Example

```bash
# Enable PrefixGrouper in training config
actor_rollout_ref.actor.use_prefix_grouper=True
trainer.balance_batch=False                        # Required: PG is incompatible with balance_batch
actor_rollout_ref.model.use_remove_padding=False   # Required: PG is incompatible with remove_padding

# Run example script
bash examples/prefix_grouper_examples/run_qwen3_pg.sh
```

Design & Code Changes

High-level Design:

PrefixGrouper optimizes GRPO training by avoiding redundant computation on shared prefixes. When rollout.n > 1, multiple responses share the same prompt, but standard attention computes the prefix n times. PrefixGrouper decomposes this into:

  1. Prefix self-attention: Compute once per unique prompt
  2. Suffix concat-attention: Each response attends to the shared prefix output
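To make the decomposition concrete, here is a minimal, self-contained sketch of the two attention stages for a single group. This is illustrative only — not the PrefixGrouper package or the code added in this PR; the helper name, tensor shapes, and single-group setup are assumptions:

```python
import torch
import torch.nn.functional as F

def grouped_attention_sketch(prefix_qkv, suffix_qkvs):
    """prefix_qkv: (q, k, v) of the shared prompt, each [1, heads, P, dim].
    suffix_qkvs: list of (q, k, v) tuples, one per response, each [1, heads, S, dim]."""
    pq, pk, pv = prefix_qkv

    # Stage 1: prefix self-attention, computed once per unique prompt.
    prefix_out = F.scaled_dot_product_attention(pq, pk, pv, is_causal=True)

    suffix_outs = []
    for sq, sk, sv in suffix_qkvs:
        # Stage 2: suffix concat-attention. Each response's queries attend to the
        # shared prefix keys/values concatenated with its own keys/values.
        k = torch.cat([pk, sk], dim=2)
        v = torch.cat([pv, sv], dim=2)
        P, S = pk.size(2), sq.size(2)
        # Causal mask restricted to the suffix queries: query i (absolute position
        # P + i) may attend to columns 0 .. P + i of the concatenated sequence.
        mask = torch.ones(S, P + S, dtype=torch.bool, device=sq.device).tril(diagonal=P)
        suffix_outs.append(F.scaled_dot_product_attention(sq, k, v, attn_mask=mask))
    return prefix_out, suffix_outs
```

Under this decomposition the prefix hidden states are computed once and shared by all G responses in a group, which is where the savings in the benchmarks above come from.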

Design & Code Changes

High-level Design:

PrefixGrouper optimizes GRPO training by avoiding redundant computation on shared prefixes. When rollout.n > 1, multiple responses share the same prompt, but standard attention computes the prefix n times. PrefixGrouper decomposes this into:

  1. Prefix self-attention: Compute once per unique prompt
  2. Suffix concat-attention: Each response attends to the shared prefix output

Code Changes:

| File | Change |
|---|---|
| verl/workers/config/actor.py | Add `use_prefix_grouper: bool = False` config option |
| verl/trainer/config/actor/actor.yaml | Add `use_prefix_grouper: false` default config |
| verl/workers/actor/dp_actor.py | (1) Add `self.use_prefix_grouper` and `self.use_dynamic_bsz` attributes in `__init__`; (2) add the PG forward path in `_forward_micro_batch` with lazy import and incompatibility checks; (3) select extra keys (prompts, response_mask, uid) for PG in `compute_log_prob`; (4) select extra keys (prompts, uid) for PG in `update_policy` |
| verl/trainer/ppo/prefix_grouper_utils.py | New file with `build_position_ids_for_prefix_grouper()` for position encoding, `build_pg_from_micro_batch()` to construct a PrefixGrouper from a micro batch, and `pg_forward()` to execute the PG-optimized forward pass (see the sketch below the table) |
| verl/workers/fsdp_workers.py | Sync `use_prefix_grouper` config from actor to ref policy in `init_model` so both use the same forward path |
| verl/trainer/ppo/ray_trainer.py | Raise `ValueError` at initialization if `use_prefix_grouper` and `balance_batch` are both enabled |
| examples/prefix_grouper_examples/ | New directory with README.md documentation, run_qwen3_prefix_grouper.sh example script, and qwen3/modeling_qwen3.py (modified model supporting PrefixGrouper) |
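For intuition on the position encoding mentioned above, the following is a hypothetical sketch of how position ids can be built for a grouped micro batch, assuming left-padded prompts, right-padded responses, and responses ordered consecutively by group. The actual `build_position_ids_for_prefix_grouper()` in prefix_grouper_utils.py may take different arguments and return a different layout:

```python
import torch

def build_position_ids_sketch(prompt_mask: torch.Tensor, response_mask: torch.Tensor):
    """prompt_mask: [num_groups, P] attention mask of the shared, left-padded prompts.
    response_mask: [num_groups * G, S] attention mask of the right-padded responses.
    Returns prompt and response position ids so that each response continues
    counting from the end of its group's prompt."""
    num_groups = prompt_mask.size(0)
    G = response_mask.size(0) // num_groups

    # Prompt positions: left pads collapse to 0, valid tokens count 0, 1, 2, ...
    prompt_pos = torch.clamp(prompt_mask.long().cumsum(dim=-1) - 1, min=0)
    prompt_len = prompt_mask.long().sum(dim=-1)  # [num_groups]

    # Response positions: continue from the prompt length of the owning group.
    offsets = prompt_len.repeat_interleave(G).unsqueeze(-1)  # [num_groups * G, 1]
    resp_pos = offsets + torch.clamp(response_mask.long().cumsum(dim=-1) - 1, min=0)
    resp_pos = resp_pos * response_mask.long()  # zero out padded response slots
    return prompt_pos, resp_pos
```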

Limitations

  • FSDP worker only: Megatron worker is not supported yet
  • Incompatible configurations (a validation sketch follows this list):
    • trainer.balance_batch=True (reorders data, breaks uid grouping)
    • use_dynamic_bsz=True
    • use_remove_padding=True (Flash Attention V2 variable length)
    • use_fused_kernels=True
    • use_ulysses_sp=True (Ulysses sequence parallelism)
  • Model modification required: The model must accept a prefix_grouper argument in its forward method
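The incompatible configurations above are rejected at startup rather than failing mid-training. A hypothetical sketch of such a check follows; the actual validation lives in ray_trainer.py and dp_actor.py, and the config objects and attribute paths here are assumptions:

```python
def validate_prefix_grouper_config(actor_cfg, model_cfg, trainer_cfg):
    """Raise early if PrefixGrouper is combined with an unsupported option."""
    if not getattr(actor_cfg, "use_prefix_grouper", False):
        return
    incompatible = {
        "trainer.balance_batch": getattr(trainer_cfg, "balance_batch", False),
        "actor.use_dynamic_bsz": getattr(actor_cfg, "use_dynamic_bsz", False),
        "model.use_remove_padding": getattr(model_cfg, "use_remove_padding", False),
        "model.use_fused_kernels": getattr(model_cfg, "use_fused_kernels", False),
        "ulysses sequence parallelism": getattr(actor_cfg, "ulysses_sequence_parallel_size", 1) > 1,
    }
    enabled = [name for name, on in incompatible.items() if on]
    if enabled:
        raise ValueError("use_prefix_grouper is incompatible with: " + ", ".join(enabled))
```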

Checklist Before Submitting

  • [x] Read the Contribute Guide.
  • [x] Apply pre-commit checks: pre-commit install && pre-commit run --all-files --show-diff-on-failure --color=always
  • [x] Add / Update the documentation. (Added examples/prefix_grouper_examples/README.md)
  • [ ] Add unit or end-to-end test(s) to the CI workflow to cover all the code. If not feasible, explain why: PrefixGrouper requires modified model files and specific hardware setup, tested manually with benchmark results above.
  • [ ] Once your PR is ready for CI, send a message in the ci-request channel.

kevssim avatar Dec 01 '25 07:12 kevssim

CLA assistant check
All committers have signed the CLA.

CLAassistant avatar Dec 01 '25 07:12 CLAassistant

/gemini review

wuxibin89 avatar Dec 17 '25 08:12 wuxibin89

Please format code: https://github.com/volcengine/verl/blob/main/CONTRIBUTING.md#code-linting-and-formatting

wuxibin89 avatar Dec 17 '25 08:12 wuxibin89

> Please format code: https://github.com/volcengine/verl/blob/main/CONTRIBUTING.md#code-linting-and-formatting

Fixed, thanks!

kevssim avatar Dec 17 '25 09:12 kevssim