GRPO implementation update
Let's use this issue to share the latest GRPO development updates. CC @gauravpandeyamu thanks for your fix.
The command below (without @gauravpandeyamu's fix) yields the charts below. Overall, the training score and sequence length go up, but the downstream eval in MATH seems to suffer. I am gonna try out @gauravpandeyamu's fix on KL regularization to see if it helps.
for beta in 0.03 0.01 0.0; do
for nspp in 16 32; do
for m in half-m ; do
local_rollout_batch_size=8
if [ $m == "half-m" ]; then
local_mini_batch_size=$(($local_rollout_batch_size * $nspp / 2))
else
local_mini_batch_size=$(($local_rollout_batch_size * $nspp))
fi
exp_name="0128_grpo_math_zs_${beta}_${nspp}_${m}_${RANDOM}"
full_bsz=$(($local_rollout_batch_size * nspp * (8 + 8 + 8 + 7) * 2))
echo $exp_name:
echo --- local_mini_batch_size=$local_mini_batch_size
echo --- full_bsz=$full_bsz
echo --- num_gradient_updates=$(($local_rollout_batch_size * $nspp / $local_mini_batch_size))
exp_name="0128_grpo_math_zs_${beta}_${nspp}_${RANDOM}"
python mason.py \
--cluster ai2/jupiter-cirrascale-2 \
--workspace ai2/tulu-3-dev \
--priority high \
--preemptible \
--num_nodes 1 \
--max_retries 1 \
--budget ai2/oe-adapt \
--gpus 8 -- source configs/beaker_configs/ray_node_setup.sh \&\& uv run python open_instruct/grpo_vllm_thread_ray_gtrl.py \
--exp_name $exp_name \
--beta $beta \
--local_mini_batch_size $local_mini_batch_size \
--number_samples_per_prompt $nspp \
--output_dir /weka/oe-adapt-default/costah/models/$exp_name \
--local_rollout_batch_size $local_rollout_batch_size \
--dataset_mixer "{\"ai2-adapt-dev/math_ground_truth_zs\": 1.0}" \
--dataset_train_splits train \
--dataset_eval_mixer "{\"ai2-adapt-dev/math_ground_truth_zs\": 32}" \
--dataset_eval_splits train \
--max_token_length 2048 \
--max_prompt_token_length 2048 \
--response_length 2048 \
--model_name_or_path allenai/Llama-3.1-Tulu-3-8B-DPO \
--non_stop_penalty \
--stop_token eos \
--temperature 1.0 \
--ground_truths_key ground_truth \
--chat_template tulu \
--sft_messages_key messages \
--learning_rate 5e-7 \
--total_episodes 1000000 \
--penalty_reward_value 0.0 \
--deepspeed_stage 2 \
--per_device_train_batch_size 2 \
--local_rollout_forward_batch_size 2 \
--actor_num_gpus_per_node 6 \
--num_epochs 1 \
--vllm_tensor_parallel_size 2 \
--lr_scheduler_type constant \
--apply_verifiable_reward true \
--seed 1 \
--num_evals 1000 \
--save_freq 40 \
--reward_model_multiplier 0.0 \
--no_try_launch_beaker_eval_jobs \
--try_launch_beaker_eval_jobs_on_weka \
--gradient_checkpointing \
--with_tracking
done
done
done
The KL constraint didn't work well. 0129_grpo_math_kl_fix_zs_0.03_16_half-m_30414__1__1738185952, which includes the KL fix, seems to have even higher KL than the previously incorrect run.
That's a good catch. While the kl1 estimator is a "decent" estimator of the KL divergence, its gradient is not the "correct" estimator of the gradient of the KL divergence. I suspect that you are using the kl1 estimator.
The gradient of the KL divergence between the policy $\pi_\theta$ and the reference policy $\pi_{ref}$ is given by $$\nabla_\theta \sum_{y} \pi_\theta (y|x) \log \frac{\pi_\theta(y|x)}{\pi_{ref}(y|x)}$$
If the samples $y_1, ..., y_n$ come from $\pi_t(y|x)$, the gradient can be estimated as $$\nabla_\theta \frac{1}{n} \sum_{i=1}^n \frac{\pi_\theta (y_i|x)}{\pi_t(y_i|x)} \log \frac{\pi_\theta(y_i|x)}{\pi_{ref}(y_i|x)}$$
This is what kl1 should have been. kl2 and kl3 are still fine.
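To make this concrete, here is a small self-contained check (illustrative only, not from the repo) that the ratio-weighted estimator above matches both the value and the gradient of KL(pi_theta || pi_ref) on a toy categorical distribution:
import torch
torch.manual_seed(0)
theta = torch.randn(6, requires_grad=True)
log_ref = torch.log_softmax(torch.randn(6), dim=0)  # reference policy pi_ref (fixed)
log_t = torch.log_softmax(torch.randn(6), dim=0)    # sampling policy pi_t (fixed)
log_pi = torch.log_softmax(theta, dim=0)            # current policy pi_theta
# exact KL(pi_theta || pi_ref) and its gradient
exact = (log_pi.exp() * (log_pi - log_ref)).sum()
(exact_grad,) = torch.autograd.grad(exact, theta, retain_graph=True)
# Monte Carlo version: samples from pi_t, weighted by the importance ratio pi_theta / pi_t
y = torch.distributions.Categorical(logits=log_t).sample((200_000,))
ratio = (log_pi[y] - log_t[y]).exp()
mc = (ratio * (log_pi[y] - log_ref[y])).mean()
(mc_grad,) = torch.autograd.grad(mc, theta)
print((exact - mc).abs().item())                  # small (Monte Carlo noise)
print((exact_grad - mc_grad).abs().max().item())  # small (Monte Carlo noise)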
BTW, the GRPO paper recommends using kl3 estimator (Equation 4 in https://arxiv.org/pdf/2402.03300)
Yeah, so I thought about the kl3. The issue is that kl3 blows up: see
https://github.com/huggingface/trl/pull/423#issuecomment-1584590437
but this only happens when using KL as part of the reward. I am not sure what happens when we put it in the loss directly...
Interesting. I agree that it might work if added in the loss directly.
I have modified my local fork with the following changes:
# kl loss should be computed without torch.no_grad()
ref_logprobs_diff = new_logprobs - ref_logprobs[micro_batch_inds]  # per-token log pi_theta - log pi_ref
kl1 = ratio * ref_logprobs_diff  # importance-weighted, so its gradient is unbiased (see above)
kl2 = (ref_logprobs_diff) ** 2 / 2
kl3 = (-ref_logprobs_diff).exp() - 1 + ref_logprobs_diff
if args.kl_estimator == "kl1":
kl = kl1
elif args.kl_estimator == "kl2":
kl = kl2
elif args.kl_estimator == "kl3":
kl = kl3
kl_loss = masked_mean(kl, ~padding_mask[micro_batch_inds])
pg_loss = pg_loss + args.beta * kl_loss
Basically, it gives a kl1 estimator whose gradient is unbiased. I haven't created a new PR since it has conflicting changes with your PR.
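For readers outside the codebase, here is a rough sketch of the pieces the snippet assumes (the names mb_logprobs and the masked_mean definition are guesses for illustration, not copied from open_instruct), with dummy tensors so it runs standalone:
import torch
# dummy stand-ins for the per-token log-prob tensors used above, shape [micro_batch, response_len]
new_logprobs = torch.randn(2, 5)          # current policy pi_theta
mb_logprobs = torch.randn(2, 5)           # policy that generated the rollout (pi_old); placeholder name
padding_mask = torch.zeros(2, 5, dtype=torch.bool)  # True where the token is padding
ratio = (new_logprobs - mb_logprobs).exp()  # pi_theta / pi_old, per token
def masked_mean(values, mask, dim=None):
    # mean of `values` over positions where mask is True (i.e. non-padding tokens)
    if dim is not None:
        return (values * mask).sum(dim=dim) / mask.sum(dim=dim)
    return (values * mask).sum() / mask.sum()
print(masked_mean(ratio, ~padding_mask))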
I can help test it out. I modified it to the following and launched a scan. Will report back with results.
kl1 = new_logprobs - mb_reflogprobs  # per-token log pi_theta - log pi_ref
kl2 = (kl1) ** 2 / 2
kl3 = (-kl1).exp() - 1 + kl1
kl4 = ratio * kl1  # the ratio-weighted variant suggested above
if args.kl_estimator == "kl1":
kl = kl1
elif args.kl_estimator == "kl2":
kl = kl2
elif args.kl_estimator == "kl3":
kl = kl3
elif args.kl_estimator == "kl4":
kl = kl4
if epoch_idx == 0:
kl_stats[micro_batch_inds] = kl.sum(1).float()
# grpo change: add the KL penalty directly to the loss (instead of folding it into the reward)
pg_loss = masked_mean(pg_loss_max + (args.beta * kl), ~padding_mask[micro_batch_inds])
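As a quick sanity check on these four definitions before the scan (made-up numbers, purely illustrative): kl1 can be negative per token, kl2 and kl3 are non-negative, and kl4 is kl1 rescaled by the importance ratio.
import torch
logprob_diff = torch.tensor([-0.5, 0.0, 0.5])  # log pi_theta - log pi_ref (made up)
ratio = torch.tensor([1.2, 1.0, 0.8])          # pi_theta / pi_old (made up)
kl1 = logprob_diff
kl2 = logprob_diff ** 2 / 2
kl3 = (-logprob_diff).exp() - 1 + logprob_diff
kl4 = ratio * logprob_diff
print(kl1)  # can be negative
print(kl2)  # non-negative
print(kl3)  # non-negative
print(kl4)  # kl1 rescaled by the ratio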
for beta in 0.03; do
for nspp in 16; do
for m in half-m ; do
for kl_estimator in k1 k2 k3 k4; do
local_rollout_batch_size=8
if [ $m == "half-m" ]; then
local_mini_batch_size=$(($local_rollout_batch_size * $nspp / 2))
else
local_mini_batch_size=$(($local_rollout_batch_size * $nspp))
fi
exp_name="0130_kl_scan_grpo_math_zs_${kl_estimator}_${beta}_${nspp}_${m}_${RANDOM}"
full_bsz=$(($local_rollout_batch_size * nspp * (8 + 8 + 8 + 7) * 2))
echo $exp_name:
echo --- local_mini_batch_size=$local_mini_batch_size
echo --- full_bsz=$full_bsz
echo --- num_gradient_updates=$(($local_rollout_batch_size * $nspp / $local_mini_batch_size))
python mason.py \
--cluster ai2/jupiter-cirrascale-2 \
--workspace ai2/tulu-3-dev \
--priority high \
--preemptible \
--num_nodes 1 \
--max_retries 1 \
--budget ai2/oe-adapt \
--gpus 8 -- source configs/beaker_configs/ray_node_setup.sh \&\& uv run python open_instruct/grpo_vllm_thread_ray_gtrl.py \
--exp_name $exp_name \
--beta $beta \
--local_mini_batch_size $local_mini_batch_size \
--number_samples_per_prompt $nspp \
--output_dir /weka/oe-adapt-default/costah/models/$exp_name \
--local_rollout_batch_size $local_rollout_batch_size \
--dataset_mixer "{\"ai2-adapt-dev/math_ground_truth_zs\": 1.0}" \
--dataset_train_splits train \
--dataset_eval_mixer "{\"ai2-adapt-dev/math_ground_truth_zs\": 32}" \
--dataset_eval_splits train \
--max_token_length 2048 \
--max_prompt_token_length 2048 \
--response_length 2048 \
--model_name_or_path allenai/Llama-3.1-Tulu-3-8B-DPO \
--non_stop_penalty \
--stop_token eos \
--temperature 1.0 \
--ground_truths_key ground_truth \
--chat_template tulu \
--sft_messages_key messages \
--learning_rate 5e-7 \
--total_episodes 1000000 \
--penalty_reward_value 0.0 \
--deepspeed_stage 2 \
--per_device_train_batch_size 2 \
--local_rollout_forward_batch_size 2 \
--actor_num_gpus_per_node 6 \
--num_epochs 1 \
--vllm_tensor_parallel_size 2 \
--lr_scheduler_type constant \
--apply_verifiable_reward true \
--seed 1 \
--num_evals 1000 \
--save_freq 40 \
--reward_model_multiplier 0.0 \
--no_try_launch_beaker_eval_jobs \
--try_launch_beaker_eval_jobs_on_weka \
--gradient_checkpointing \
--with_tracking
done
done
done
done
Great. Thanks
Eh, this is a bit awkward: none of the KL estimators seems to control the KL well.
Since RLVR looks only at the final answer, completely ignoring the generated text, the model can be motivated to generate text that deviates from the reference policy as long as it reaches the correct final answer.
Perhaps a stronger KL term (a higher beta value) is desirable in RLVR to ensure that the generated text remains sensible.
FYI on KL3 estimator. https://x.com/vwxyzjn/status/1885329398821187633
Ok, so when I launched the experiments I accidentally left out the --kl_estimator flag 🤡, so all exps were run using kl1.
Now, when using kl3, the KL looks much more reasonable.
Meanwhile:
- kl1 looks wrong (a larger beta induces a larger KL...???)
- kl2 also looks reasonable
- kl4 seems ok
@gauravpandeyamu why did you multiply the ref_logprobs_diff by ratio? I don't get it.
Ahh, yes. Now, the graphs of kl1, kl2, kl3 and kl4 make perfect sense.
As for why the kl1 estimator is a bad estimator, and how multiplying by the ratio fixes it (mainly, the bias is fixed), here is ChatGPT's response.
https://chatgpt.com/share/679dfb75-cdec-800f-9078-f838d3925f9e
While the kl1 estimator is a "decent" estimator of the KL divergence, its gradient is not the "correct" estimator of the gradient of the KL divergence.
If I am not mistaken, I believe that the gradient of KL3 is also not an unbiased estimator of the gradient of the KL divergence.
$$\mathbb{E}_{\pi_{\theta}}\left[\nabla_{\theta}\left(\frac{\pi_{\mathrm{ref}}}{\pi_{\theta}} - \log \frac{\pi_{\mathrm{ref}}}{\pi_{\theta}} - 1\right)\right] = -\mathbb{E}_{\pi_{\theta}}\left[\nabla_{\theta}\log \pi_{\theta} \times \frac{\pi_{\mathrm{ref}}}{\pi_{\theta}}\right] \neq -\mathbb{E}_{\pi_{\theta}}\left[\nabla_{\theta} \log \pi_{\theta} \times \log \frac{\pi_{\mathrm{ref}}}{\pi_{\theta}}\right] = \nabla_{\theta}\mathrm{KL}(\pi_{\theta}\|\pi_{\mathrm{ref}})$$
Given this, I don't fully understand how KL3 can work for optimization (maybe the lower variance alone is good enough, despite this potential bias). It should be a good estimator to monitor the KL divergence (notably because it remains non-negative, is unbiased and has low variance), but using it for optimization is a bit mysterious to me.
I'm sorry to hijack this issue, but I haven't seen many discussions on the use of KL3 in GRPO.
You are right. It is biased but has low variance. Intuitively, another reason why it works is that each term of the kl3 estimator, -log p(x)/q(x) + p(x)/q(x) - 1, is lower bounded by 0 (same as kl2), with equality achieved only when p(x) = q(x). So, if you try to estimate the KL with just one sample and optimize the kl3 estimator for that sample, you will end up making the two probabilities equal for that sample.
On the other hand, if you try to minimize the kl1 estimator for a single sample, -log p(x)/q(x), which is unbounded below, you will just end up driving q(x) down instead of matching the two probabilities (the policy keeps suppressing its own samples). This is why optimizing kl1 leads to a blow-up in KL when beta is higher.
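A tiny optimization demo of this point (illustrative only, using the per-token formulas from the snippets above): gradient-descending the single-sample kl1 term keeps pushing pi_theta(y) down, while the single-sample kl3 term pulls pi_theta(y) onto pi_ref(y).
import torch
torch.manual_seed(0)
log_ref = torch.log_softmax(torch.randn(4), dim=0)  # fixed reference policy
y = 0  # a single fixed "sampled" token
for name in ["kl1", "kl3"]:
    theta = torch.zeros(4, requires_grad=True)
    opt = torch.optim.SGD([theta], lr=0.5)
    for _ in range(200):
        log_pi = torch.log_softmax(theta, dim=0)
        diff = log_pi[y] - log_ref[y]  # log pi_theta(y) - log pi_ref(y)
        loss = diff if name == "kl1" else (-diff).exp() - 1 + diff
        opt.zero_grad()
        loss.backward()
        opt.step()
    with torch.no_grad():
        pi = torch.softmax(theta, dim=0)
    # kl1 drives pi_theta(y) toward 0; kl3 drives it toward pi_ref(y)
    print(name, "pi_theta(y)=%.3f" % pi[y].item(), "pi_ref(y)=%.3f" % log_ref[y].exp().item())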
Also worth noting that $E_{\pi_t} \left[\frac{\pi_\theta}{\pi_{ref}} -\log \frac{\pi_\theta}{\pi_{ref}} - 1\right]$ is a valid divergence and the kl3 estimator and its gradient are unbiased estimators of this divergence and its gradient respectively.
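Spelling this out, writing $D(\pi_\theta, \pi_{ref})$ for the quantity above (the same argument goes through if the ratio is inverted, as in the kl3 formula in the code):
$$r - \log r - 1 \ge 0 \quad \text{for all } r > 0, \text{ with equality iff } r = 1,$$
so with $r(y) = \pi_\theta(y|x)/\pi_{ref}(y|x)$,
$$D(\pi_\theta, \pi_{ref}) := \mathbb{E}_{\pi_t}\!\left[r(y) - \log r(y) - 1\right] \ge 0,$$
with equality iff $\pi_\theta = \pi_{ref}$ on the support of $\pi_t$; and since $\pi_t$ does not depend on $\theta$, the per-sample term and its $\theta$-gradient are unbiased estimators of $D$ and $\nabla_\theta D$ respectively.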
That makes sense. In that case, I would find KL2 a more "principled" choice for optimization, since its gradient wrt. $\theta$ is an unbiased estimator of the gradient of the KL divergence (provided that it still has low variance, which seems to be validated by @vwxyzjn's experiments). But KL2 would still be a biased estimator of the KL divergence (to report metrics), so maybe optimizing KL2 (kl = kl2) and monitoring KL3 is the way to go?
But your comment on it being a valid divergence also makes a lot of sense, and maybe what the success of KL3 shows is that we should go beyond just the KL divergence.
I agree with KL2 being a more principled choice for optimization. There are works that explore f-divergences in the PPO objective https://arxiv.org/pdf/2309.16240
@gauravpandeyamu, I am reading this thread to understand the KL-divergence issue. You mention that $E_{\pi_t}\left[\frac{\pi_\theta}{\pi_{ref}} - \log \frac{\pi_\theta}{\pi_{ref}} - 1\right]$ is a valid divergence; can you please explain why? The expectation is taken w.r.t. $\pi_t$, and if I understand correctly, it is an unbiased estimator of $KL(\pi_\theta \| \pi_{ref})$ only if the expectation is w.r.t. $\pi_{ref}$, given that $\frac{\pi_\theta}{\pi_{ref}} - 1$ is a control variate (zero expectation). If we are taking an expectation w.r.t. $\pi_t$, don't we need to apply importance sampling to adjust?
Hi @vwxyzjn, you mentioned that
Overall, the training score and sequence length go up, but the downstream eval in MATH seems to suffer.
I wonder whether applying KL3 (instead of KL1) solves this issue?
I am closing this issue for being out of date.