
[Feature Request] Implement GRPO

Open · floribau opened this issue 5 months ago · 2 comments

🚀 Feature

Group Relative Policy Optimization (GRPO) is a reinforcement learning algorithm introduced in https://arxiv.org/pdf/2402.03300, which has gained a lot of attention following its use in fine-tuning DeepSeek-R1. GRPO is an extension of PPO with a similar clipped objective. The main difference is that GRPO eliminates the need to train a value model, reducing both memory usage and computational cost. Instead, for each training iteration, GRPO samples a group of trajectories sharing the same initial state and uses the group's average reward as the baseline, replacing the value-estimate baseline.
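For reference, the outcome-supervision form of this group baseline from the DeepSeekMath paper can be written as follows (notation lightly adapted here): given a group of G trajectories sampled from the same initial state with scalar rewards r_1, ..., r_G, every step of trajectory i receives the advantage

$$\hat{A}_i \;=\; \frac{r_i - \operatorname{mean}(r_1, \dots, r_G)}{\operatorname{std}(r_1, \dots, r_G)},$$

which replaces the GAE advantage inside the usual PPO clipped objective.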

DeepSeekMath introduced two variants of GRPO:

  • Outcome Supervision, which assigns a single scalar reward to the final outcome of a trajectory. The advantage of each trajectory (shared by all of its steps) is the trajectory's reward normalized within the group (group mean subtracted, divided by the group standard deviation). Note: While this variant can be applied directly to classic Gym environments, it doesn't make use of the per-step rewards returned after each action, which leads to sparse rewards and slower learning.
  • Process Supervision, which uses per-step advantages. Each per-step reward is first normalized by the mean per-step reward across the whole group. Then, for each step, the advantage is the return-to-go of these normalized per-step rewards (a code sketch follows this list). Note: The per-step normalization poses the following issue: in environments that assign a fixed reward at every step (e.g., CartPole's +1 per step), the mean reward across all steps equals every single per-step reward, so all normalized rewards, and therefore all advantages, are 0.
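To make the two variants above concrete, here is a minimal NumPy sketch of the corresponding advantage computations. The function names and the eps term are illustrative only, not taken from the paper or from any existing implementation; group_rewards is assumed to be a list of per-step reward lists, one per trajectory in the group.

    import numpy as np

    def outcome_supervision_advantages(group_rewards, eps=1e-8):
        """Outcome supervision: one normalized trajectory reward, shared by
        every step of that trajectory."""
        returns = np.array([np.sum(r) for r in group_rewards])
        normalized = (returns - returns.mean()) / (returns.std() + eps)
        return [np.full(len(r), normalized[i]) for i, r in enumerate(group_rewards)]

    def process_supervision_advantages(group_rewards, eps=1e-8):
        """DeepSeek-style process supervision: normalize per-step rewards across
        the whole group, then take the return-to-go of the normalized rewards."""
        flat = np.concatenate([np.asarray(r, dtype=float) for r in group_rewards])
        mean, std = flat.mean(), flat.std()
        advantages = []
        for r in group_rewards:
            normalized = (np.asarray(r, dtype=float) - mean) / (std + eps)
            # return-to-go: sum of normalized rewards from the current step to the end
            advantages.append(np.cumsum(normalized[::-1])[::-1])
        return advantages

With CartPole's constant +1 rewards, every per-step reward equals the group mean, so process_supervision_advantages returns all zeros, which is exactly the degenerate case described in the note above.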

To overcome this issue with DeepSeek's process supervision, another process-supervision variant was proposed at https://www.kaggle.com/discussions/general/573162. This variant first computes the return-to-go for each time step from the per-step rewards. Then, for each step, the return-to-go is normalized by the average return-to-go across the whole group to obtain the per-step advantage.
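A rough sketch of this alternative, under the same conventions as the snippet above (whether "normalized" means mean-centering only or also dividing by the group standard deviation is an implementation detail; the sketch does both):

    def rtg_normalized_advantages(group_rewards, eps=1e-8):
        """Alternative process supervision: compute raw returns-to-go first,
        then normalize them across all steps of the whole group."""
        rtgs = [np.cumsum(np.asarray(r, dtype=float)[::-1])[::-1] for r in group_rewards]
        flat = np.concatenate(rtgs)
        mean, std = flat.mean(), flat.std()
        return [(rtg - mean) / (std + eps) for rtg in rtgs]

Because returns-to-go differ across steps even when per-step rewards are constant, this variant avoids the all-zero advantages of the previous one on environments like CartPole.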

Probably not all of these variants can or should be added to SB3-contrib; which one(s) to contribute is open for discussion.

Motivation

I implemented the GRPO algorithm out of curiosity, since it didn't yet exist in SB3. While originally developed for LLM fine-tuning, I hypothesize that the algorithm isn't limited to that type of task and should work on classical reinforcement learning tasks as well. So I wanted to apply the different variants of GRPO to Gym environments and compare their performance to other RL baselines.

I know there is already a similar open issue, #273, but the algorithm suggested there doesn't strictly implement GRPO; it is rather a hybrid closer to a multi-sample PPO. Therefore, I implemented a proof-of-concept version myself and opened this new issue.

Pitch

A first implementation is in my forked repository on the 'contrib-grpo' branch: https://github.com/floribau/stable-baselines3-contrib-grpo/tree/contrib-grpo. The repo contains the GRPO class in https://github.com/floribau/stable-baselines3-contrib-grpo/blob/contrib-grpo/sb3_contrib/grpo/grpo.py; the three variants described above are realized by initializing the GRPO class with different GroupBuffer classes (https://github.com/floribau/stable-baselines3-contrib-grpo/blob/contrib-grpo/sb3_contrib/common/buffers.py).

However, this should be understood only as a proof of concept and is by no means ready for a PR (no tests, not enough evaluation on benchmarks). Before working on it further to make it PR-ready, I want to discuss here the general algorithm and its suitability for SB3. Since my code structure isn't optimal either, I'd also welcome suggestions for structural changes. After the discussion, I'll implement a version that follows the SB3 standards more closely.

Alternatives

To my knowledge, SB3 doesn't include any multi-sample algorithm that constructs its baseline from a group average. I'm happy to discuss alternatives in the implementation details.

Additional context

As mentioned above, I'm aware of issue #273 and want to address some of the questions raised there:

  • My proposed algorithm follows DeepSeek's formulation more strictly than the hybrid approach, and only adds one more possible variant on top of implementing theirs.
  • I acknowledge that GRPO was originally developed for LLM fine-tuning. My implementation elegantly uses seeds to enable sampling of multiple trajectories, which works for Gym envs that support seeded resets. Since a group's trajectories should all start from the same initial state, I simply reset the environment with a seed kept fixed for that group's rollout collection (see the sketch after this list). Thus, there's no need to cumbersomely deepcopy env states.
  • Since GRPO was originally developed for LLM fine-tuning, there aren't any performance baselines on classical RL tasks in the paper. However, I ran the different variants in small experiments on CartPole and compared them to PPO. I plan to compare them on more complex benchmarks and larger experiments. An initial small comparison (with only 4 runs on CartPole) showed promising results, with the alternative process supervision variant achieving performance similar to PPO, motivating further work. [Figure: CartPole learning curves of the GRPO variants compared to PPO]
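For illustration, the seeding trick amounts to something like the following plain-Gymnasium sketch (this is not the exact code from my branch; collect_group and the policy argument are placeholder names):

    import gymnasium as gym

    def collect_group(env, policy, group_size, group_seed):
        """Roll out `group_size` episodes that all start from the same initial
        state by reusing a single reset seed for the whole group."""
        group_rewards = []
        for _ in range(group_size):
            obs, _ = env.reset(seed=group_seed)  # same seed -> same initial state
            done, rewards = False, []
            while not done:
                action = policy(obs)
                obs, reward, terminated, truncated, _ = env.step(action)
                rewards.append(reward)
                done = terminated or truncated
            group_rewards.append(rewards)
        return group_rewards

    # usage example with a random policy
    env = gym.make("CartPole-v1")
    rollouts = collect_group(env, policy=lambda obs: env.action_space.sample(),
                             group_size=8, group_seed=42)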

Checklist

  • [x] I have checked that there is no similar issue in the repo
  • [x] If I'm requesting a new feature, I have proposed alternatives

floribau · Jul 11 '25

Hello, thanks for the proposal.

I acknowledge that GRPO was originally developed for LLM fine-tuning. My implementation elegantly uses seeds to enable sampling of multiple trajectories,

Similar to what I wrote in https://github.com/Stable-Baselines-Team/stable-baselines3-contrib/issues/273, GRPO seems to be specific to LLM training (which SB3 doesn't really support).

Using

            env.seed(random_seed)
            self._last_obs = env.reset()  # always reset envs to the same state in the same group

is indeed a hack that might work in some environments (the ones that can be reset to an exact state) and seems valid only when using a single env (while not dones: is harder to interpret when you have multiple envs running at the same time and the episodes don't terminate at the same timestep). For other envs this cannot be used, for instance a real robot or a simulation with no access to a seed/reset-state method.

There aren't any performance baselines on classical RL tasks in the paper. However, I ran the different variants in small experiments on CartPole and compared them to PPO. I plan to compare them on more complex benchmarks and larger experiments.

That's the main thing missing indeed, proper benchmarking and evaluation.

EDIT: if you want, we could point to your implementation on the SB3 project page in the documentation

araffin · Aug 01 '25

Hi @araffin,

thanks for your review.

is indeed a hack that might work in some environments (the ones that can be reset to an exact state) and seems valid only when using a single env (while not dones: is harder to interpret when you have multiple envs running at the same time and the episodes don't terminate at the same timestep).

You are correct that my implementation currently only supports a single env, for the reason you mention (while not dones: is hard to implement correctly with multiple envs).

Similar to what I wrote in #273, GRPO seems to be specific to LLM training (which SB3 doesn't really support).

It is true that GRPO was initially developed for LLM fine-tuning, yet my first comparison on CartPole showed promising potential on Gym envs as well. The only constraint is that the env must be resettable to an exact state via a seed, as you pointed out. However, this seems to be the case for most Gym envs in the Farama project.

EDIT: I do understand, however, that the proposed solution doesn't work for other envs such as real robots.

That's the main thing missing indeed, proper benchmarking and evaluation.

If you are still open to including the implementation in SB3-contrib given the mentioned constraints (single env only, and only envs that allow exact resets), I will run more experiments and present a proper evaluation.

EDIT: if you want, we could point to your implementation on the SB3 project page in the documentation

In any case, yes, I'd be happy if you point to my implementation in the SB3 project documentation :)

floribau · Dec 02 '25