open-instruct

GRPO implementation update

Open vwxyzjn opened this issue 10 months ago • 19 comments

Let's use this issue to share the latest GRPO development updates. CC @gauravpandeyamu, thanks for your fix.

The command below (without @gauravpandeyamu's fix) yields the charts below. Overall, the training score and sequence length go up, but the downstream eval in MATH seems to suffer. I am going to try out @gauravpandeyamu's KL regularization fix to see if it helps.

```bash
for beta in 0.03 0.01 0.0; do
for nspp in 16 32; do
for m in half-m ; do
local_rollout_batch_size=8
if [ $m == "half-m" ]; then
    local_mini_batch_size=$(($local_rollout_batch_size * $nspp / 2))
else
    local_mini_batch_size=$(($local_rollout_batch_size * $nspp))
fi
exp_name="0128_grpo_math_zs_${beta}_${nspp}_${m}_${RANDOM}"
full_bsz=$(($local_rollout_batch_size * nspp * (8 + 8 + 8 + 7) * 2))
echo $exp_name:
echo --- local_mini_batch_size=$local_mini_batch_size
echo --- full_bsz=$full_bsz
echo --- num_gradient_updates=$(($local_rollout_batch_size * $nspp / $local_mini_batch_size))
python mason.py \
    --cluster ai2/jupiter-cirrascale-2 \
    --workspace ai2/tulu-3-dev \
    --priority high \
    --preemptible \
    --num_nodes 1 \
    --max_retries 1 \
    --budget ai2/oe-adapt \
    --gpus 8 -- source configs/beaker_configs/ray_node_setup.sh \&\& uv run python open_instruct/grpo_vllm_thread_ray_gtrl.py \
    --exp_name $exp_name \
    --beta $beta \
    --local_mini_batch_size $local_mini_batch_size \
    --number_samples_per_prompt $nspp \
    --output_dir /weka/oe-adapt-default/costah/models/$exp_name \
    --local_rollout_batch_size $local_rollout_batch_size \
    --dataset_mixer "{\"ai2-adapt-dev/math_ground_truth_zs\": 1.0}" \
    --dataset_train_splits train \
    --dataset_eval_mixer "{\"ai2-adapt-dev/math_ground_truth_zs\": 32}" \
    --dataset_eval_splits train \
    --max_token_length 2048 \
    --max_prompt_token_length 2048 \
    --response_length 2048 \
    --model_name_or_path allenai/Llama-3.1-Tulu-3-8B-DPO \
    --non_stop_penalty \
    --stop_token eos \
    --temperature 1.0 \
    --ground_truths_key ground_truth \
    --chat_template tulu \
    --sft_messages_key messages \
    --learning_rate 5e-7 \
    --total_episodes 1000000 \
    --penalty_reward_value 0.0 \
    --deepspeed_stage 2 \
    --per_device_train_batch_size 2 \
    --local_rollout_forward_batch_size 2 \
    --actor_num_gpus_per_node 6 \
    --num_epochs 1 \
    --vllm_tensor_parallel_size 2 \
    --lr_scheduler_type constant \
    --apply_verifiable_reward true \
    --seed 1 \
    --num_evals 1000 \
    --save_freq 40 \
    --reward_model_multiplier 0.0 \
    --no_try_launch_beaker_eval_jobs \
    --try_launch_beaker_eval_jobs_on_weka \
    --gradient_checkpointing \
    --with_tracking
done
done
done
```
Image Image
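The batch-size arithmetic in the launch loop can be mirrored in Python to see what each configuration works out to. This is only a sketch that reproduces the shell formulas as written: `batch_sizes` is a hypothetical helper, and the `(8 + 8 + 8 + 7) * 2` factor in `full_bsz` is copied verbatim without interpretation, since the thread does not explain it.

```python
def batch_sizes(local_rollout_batch_size, nspp, half_minibatch):
    """Mirror the $((...)) integer arithmetic from the launch script."""
    rollouts = local_rollout_batch_size * nspp  # completions per local rollout batch
    if half_minibatch:  # m == "half-m"
        local_mini_batch_size = rollouts // 2
    else:
        local_mini_batch_size = rollouts
    # Constant copied verbatim from the script; its breakdown is not explained in the thread.
    full_bsz = rollouts * (8 + 8 + 8 + 7) * 2
    num_gradient_updates = rollouts // local_mini_batch_size
    return local_mini_batch_size, full_bsz, num_gradient_updates


for nspp in (16, 32):
    print(nspp, batch_sizes(8, nspp, half_minibatch=True))
# 16 (64, 7936, 2)
# 32 (128, 15872, 2)
```

With `half-m`, each local rollout batch is split into two mini-batches, so every batch of rollouts drives two gradient updates; without it, the whole batch is a single update.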

vwxyzjn, Jan 29 '25 21:01

The KL constraint didn't work well. The run `0129_grpo_math_kl_fix_zs_0.03_16_half-m_30414__1__1738185952`, which includes the KL fix, seems to have an even higher KL than the previous, incorrect run.

Image

vwxyzjn, Jan 30 '25 13:01

That's a good catch. While kl1 estimator is a "decent" estimator of kl divergence, its gradient is not the "correct" estimator of the gradient of kl divergence. I suspect that you are using the kl1 estimator.

The gradient of the kl divergence between the policy $\pi_\theta$ and $\pi_{ref}$, the reference policy, is given by $$\nabla_\theta \sum_{y} \pi_\theta (y|x) \log \frac{\pi_\theta(y|x)}{\pi_{ref}(y|x)}$$

If the samples $y_1, ..., y_n$ come from $\pi_t(y|x)$, the policy that generated the rollouts (held fixed, so no gradient flows through it), the gradient can be estimated as $$\nabla_\theta \frac{1}{n} \sum_{i=1}^n \frac{\pi_\theta (y_i|x)}{\pi_t(y_i|x)} \log \frac{\pi_\theta(y_i|x)}{\pi_{ref}(y_i|x)}$$

This is what kl1 should have been. kl2 and kl3 are still fine.

BTW, the GRPO paper recommends using the kl3 estimator (Equation 4 in https://arxiv.org/pdf/2402.03300).
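A toy numerical check of the ratio-weighted estimator. The sketch assumes a single categorical distribution with softmax logits rather than a token-level LLM policy, sets $\pi_t$ equal to the current $\pi_\theta$ but freezes it, and averages over all outcomes exactly instead of sampling. Gradients are central finite differences, and every function name is hypothetical.

```python
import math


def softmax(logits):
    m = max(logits)
    exps = [math.exp(z - m) for z in logits]
    total = sum(exps)
    return [e / total for e in exps]


def kl(p, q):
    """Exact KL(p || q) for two categorical distributions."""
    return sum(pi * math.log(pi / qi) for pi, qi in zip(p, q))


def plain_kl1(lp, lp_t, lp_ref):
    # log(pi_theta / pi_ref) for one sample; lp_t is unused
    return lp - lp_ref


def ratio_kl1(lp, lp_t, lp_ref):
    # (pi_theta / pi_t) * log(pi_theta / pi_ref), with pi_t frozen
    return math.exp(lp - lp_t) * (lp - lp_ref)


def expected_surrogate_grad(surrogate, logits, ref, eps=1e-6):
    """Exact E_{y ~ pi_t}[d surrogate(y) / d logits], with pi_t frozen at softmax(logits)."""
    p_t = softmax(logits)
    grad = []
    for j in range(len(logits)):
        up, dn = list(logits), list(logits)
        up[j] += eps
        dn[j] -= eps
        p_up, p_dn = softmax(up), softmax(dn)
        g = 0.0
        for y, w in enumerate(p_t):
            f_up = surrogate(math.log(p_up[y]), math.log(p_t[y]), math.log(ref[y]))
            f_dn = surrogate(math.log(p_dn[y]), math.log(p_t[y]), math.log(ref[y]))
            g += w * (f_up - f_dn) / (2 * eps)
        grad.append(g)
    return grad


def true_kl_grad(logits, ref, eps=1e-6):
    """Finite-difference gradient of KL(softmax(logits) || ref) w.r.t. the logits."""
    grad = []
    for j in range(len(logits)):
        up, dn = list(logits), list(logits)
        up[j] += eps
        dn[j] -= eps
        grad.append((kl(softmax(up), ref) - kl(softmax(dn), ref)) / (2 * eps))
    return grad


logits = [0.5, -1.0, 2.0, 0.0]        # pi_theta (toy values)
ref = softmax([0.0, 0.0, 1.0, -0.5])  # pi_ref (toy values)
target = true_kl_grad(logits, ref)
for name, fn in [("plain kl1", plain_kl1), ("ratio * kl1", ratio_kl1)]:
    est = expected_surrogate_grad(fn, logits, ref)
    print(name, "max abs error:", max(abs(a - b) for a, b in zip(est, target)))
```

On-policy, the expected gradient of the plain log-ratio surrogate is $\mathbb{E}_{\pi_\theta}[\nabla_\theta \log \pi_\theta] = 0$, so its error equals the size of the true gradient; the ratio-weighted surrogate matches the true gradient up to finite-difference error.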

gauravpandeyamu, Jan 30 '25 15:01

Yeah, I thought about kl3. The issue is that kl3 blows up; see

https://github.com/huggingface/trl/pull/423#issuecomment-1584590437

But this only happens when the KL is used as part of the reward. I am not sure what happens when we put it in the loss directly...

vwxyzjn, Jan 30 '25 16:01

Interesting. I agree that it might work if added to the loss directly.

I have modified my local fork with the following changes:

```python
# The KL loss must be computed outside torch.no_grad() so that it contributes gradients.
ref_logprobs_diff = new_logprobs - ref_logprobs[micro_batch_inds]  # log(pi_theta / pi_ref) per token
kl1 = ratio * ref_logprobs_diff  # ratio-weighted kl1, whose gradient is unbiased
kl2 = (ref_logprobs_diff) ** 2 / 2
kl3 = (-ref_logprobs_diff).exp() - 1 + ref_logprobs_diff
if args.kl_estimator == "kl1":
    kl = kl1
elif args.kl_estimator == "kl2":
    kl = kl2
elif args.kl_estimator == "kl3":
    kl = kl3
else:
    raise ValueError(f"unknown kl_estimator: {args.kl_estimator}")

kl_loss = masked_mean(kl, ~padding_mask[micro_batch_inds])
pg_loss = pg_loss + args.beta * kl_loss
```

Basically, this gives a kl1 estimator whose gradient is unbiased. I haven't created a new PR since it conflicts with your PR.
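For readers without the surrounding trainer, here is a framework-free sketch of the per-token KL term from the snippet above. It computes values only (no autograd), assumes `ratio` is $\exp(\log\pi_\theta - \log\pi_t)$ per token as in PPO-style trainers, and assumes `padding_mask` is True on padded positions; `kl_penalty` and its arguments are hypothetical names.

```python
import math


def masked_mean(values, keep):
    """Mean over positions where keep[i] is True, like masked_mean(x, ~padding_mask)."""
    kept = [v for v, k in zip(values, keep) if k]
    return sum(kept) / len(kept)


def kl_penalty(new_logprobs, old_logprobs, ref_logprobs, padding_mask, estimator="kl1"):
    """Per-token KL estimates averaged over non-padded tokens (values only, no autograd)."""
    keep = [not pad for pad in padding_mask]
    per_token = []
    for new, old, ref in zip(new_logprobs, old_logprobs, ref_logprobs):
        diff = new - ref              # log(pi_theta / pi_ref)
        ratio = math.exp(new - old)   # assumed: pi_theta / pi_t
        if estimator == "kl1":        # ratio-weighted kl1, as in the snippet
            kl = ratio * diff
        elif estimator == "kl2":
            kl = diff ** 2 / 2
        elif estimator == "kl3":
            kl = math.exp(-diff) - 1 + diff
        else:
            raise ValueError(f"unknown estimator: {estimator}")
        per_token.append(kl)
    return masked_mean(per_token, keep)


# Three real tokens and one padded position (toy values; the padded one is ignored).
new = [-1.0, -0.5, -2.0, 0.0]
old = [-1.1, -0.5, -1.9, 0.0]
ref = [-1.2, -0.7, -1.5, 0.0]
pad = [False, False, False, True]
beta = 0.03
pg_loss = 0.25  # placeholder for the policy-gradient loss
for est in ("kl1", "kl2", "kl3"):
    print(est, pg_loss + beta * kl_penalty(new, old, ref, pad, est))
```

Raising on an unknown name, rather than silently falling through, makes a mistyped estimator string fail loudly.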

gauravpandeyamu, Jan 30 '25 16:01

I can help test it out. I modified it to the following and launched a scan. Will report back with results.

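```python
kl1 = new_logprobs - mb_reflogprobs  # log(pi_theta / pi_ref) per token
kl2 = (kl1) ** 2 / 2
kl3 = (-kl1).exp() - 1 + kl1
kl4 = ratio * kl1  # the ratio-weighted variant proposed above
if args.kl_estimator == "kl1":
    kl = kl1
elif args.kl_estimator == "kl2":
    kl = kl2
elif args.kl_estimator == "kl3":
    kl = kl3
elif args.kl_estimator == "kl4":
    kl = kl4
else:
    raise ValueError(f"unknown kl_estimator: {args.kl_estimator}")

if epoch_idx == 0:
    # Record the per-sequence KL estimate (sum over tokens) during the first epoch only.
    kl_stats[micro_batch_inds] = kl.sum(1).float()

# GRPO change: add the KL penalty directly to the loss instead of subtracting it from the reward.
pg_loss = masked_mean(pg_loss_max + (args.beta * kl), ~padding_mask[micro_batch_inds])
```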
```bash
for beta in 0.03; do
for nspp in 16; do
for m in half-m ; do
for kl_estimator in k1 k2 k3 k4; do
local_rollout_batch_size=8
if [ $m == "half-m" ]; then
    local_mini_batch_size=$(($local_rollout_batch_size * $nspp / 2))
else
    local_mini_batch_size=$(($local_rollout_batch_size * $nspp))
fi
exp_name="0130_kl_scan_grpo_math_zs_${kl_estimator}_${beta}_${nspp}_${m}_${RANDOM}"
full_bsz=$(($local_rollout_batch_size * nspp * (8 + 8 + 8 + 7) * 2))
echo $exp_name:
echo --- local_mini_batch_size=$local_mini_batch_size
echo --- full_bsz=$full_bsz
echo --- num_gradient_updates=$(($local_rollout_batch_size * $nspp / $local_mini_batch_size))
python mason.py \
    --cluster ai2/jupiter-cirrascale-2 \
    --workspace ai2/tulu-3-dev \
    --priority high \
    --preemptible \
    --num_nodes 1 \
    --max_retries 1 \
    --budget ai2/oe-adapt \
    --gpus 8 -- source configs/beaker_configs/ray_node_setup.sh \&\& uv run python open_instruct/grpo_vllm_thread_ray_gtrl.py \
    --exp_name $exp_name \
    --beta $beta \
    --local_mini_batch_size $local_mini_batch_size \
    --number_samples_per_prompt $nspp \
    --output_dir /weka/oe-adapt-default/costah/models/$exp_name \
    --local_rollout_batch_size $local_rollout_batch_size \
    --dataset_mixer "{\"ai2-adapt-dev/math_ground_truth_zs\": 1.0}" \
    --dataset_train_splits train \
    --dataset_eval_mixer "{\"ai2-adapt-dev/math_ground_truth_zs\": 32}" \
    --dataset_eval_splits train \
    --max_token_length 2048 \
    --max_prompt_token_length 2048 \
    --response_length 2048 \
    --model_name_or_path allenai/Llama-3.1-Tulu-3-8B-DPO \
    --non_stop_penalty \
    --stop_token eos \
    --temperature 1.0 \
    --ground_truths_key ground_truth \
    --chat_template tulu \
    --sft_messages_key messages \
    --learning_rate 5e-7 \
    --total_episodes 1000000 \
    --penalty_reward_value 0.0 \
    --deepspeed_stage 2 \
    --per_device_train_batch_size 2 \
    --local_rollout_forward_batch_size 2 \
    --actor_num_gpus_per_node 6 \
    --num_epochs 1 \
    --vllm_tensor_parallel_size 2 \
    --lr_scheduler_type constant \
    --apply_verifiable_reward true \
    --seed 1 \
    --num_evals 1000 \
    --save_freq 40 \
    --reward_model_multiplier 0.0 \
    --no_try_launch_beaker_eval_jobs \
    --try_launch_beaker_eval_jobs_on_weka \
    --gradient_checkpointing \
    --with_tracking
done
done
done
done
```
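The two ways of adding the penalty that appear in this thread, a separate `masked_mean` of the KL added to `pg_loss`, and a single `masked_mean` over `pg_loss_max + beta * kl`, give the same value as long as both use the same padding mask (and assuming `pg_loss` in the first variant is itself the masked mean of the per-token policy loss), because a masked mean is linear. A small pure-Python check with made-up per-token numbers; `masked_mean` here is a hypothetical stand-in for the trainer's helper.

```python
def masked_mean(values, keep):
    """Mean over positions where keep[i] is True."""
    kept = [v for v, k in zip(values, keep) if k]
    return sum(kept) / len(kept)


pg_per_token = [0.3, -0.1, 0.7, 5.0]    # toy per-token policy loss; last position is padding
kl_per_token = [0.02, 0.01, 0.12, 9.0]  # toy per-token KL estimates
keep = [True, True, True, False]
beta = 0.03

# Variant 1: fold the KL into the per-token loss, then mask and average once.
folded = masked_mean([pg + beta * k for pg, k in zip(pg_per_token, kl_per_token)], keep)
# Variant 2: average each term separately, then add.
separate = masked_mean(pg_per_token, keep) + beta * masked_mean(kl_per_token, keep)
print(abs(folded - separate) < 1e-12)
```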

vwxyzjn, Jan 30 '25 17:01

Great. Thanks

gauravpandeyamu, Jan 30 '25 19:01

Eh, this is a bit awkward: none of the KL estimators seem to control the KL well.

Image

vwxyzjn, Jan 30 '25 19:01

Since RLVR looks only at the final answer and completely ignores the generated text, the model can be motivated to generate text that deviates from the reference policy as long as it reaches the correct final answer.

Perhaps a stronger KL term (a higher beta) is desirable in RLVR to ensure that the generated text stays sensible.

gauravpandeyamu, Jan 31 '25 06:01

FYI on the KL3 estimator: https://x.com/vwxyzjn/status/1885329398821187633

vwxyzjn, Jan 31 '25 14:01

OK, so when I launched the experiments I accidentally left out `--kl_estimator` 🤡, so all experiments were run with kl1.

Now, with kl3, it looks much more reasonable.

Image

vwxyzjn, Jan 31 '25 16:01

Meanwhile:

1. kl1 looks wrong (a larger beta induces a larger KL...???)
   Image
2. kl2 also looks reasonable.
   Image
3. kl4 seems OK.
   Image

@gauravpandeyamu, why did you multiply `ref_logprobs_diff` by `ratio`? I don't get it.

vwxyzjn, Jan 31 '25 16:01

Ahh, yes. Now the graphs of kl1, kl2, kl3 and kl4 make perfect sense.

As for why the plain kl1 estimator is a bad choice for optimization, and how multiplying by the ratio fixes it (mainly by removing the bias in its gradient), here is ChatGPT's response:

https://chatgpt.com/share/679dfb75-cdec-800f-9078-f838d3925f9e

gauravpandeyamu, Feb 01 '25 10:02

> While kl1 estimator is a "decent" estimator of kl divergence, its gradient is not the "correct" estimator of the gradient of kl divergence.

If I am not mistaken, the gradient of KL3 is also not an unbiased estimator of the gradient of the KL divergence.

$$\mathbb{E}_{\pi_{\theta}}\left[\nabla_{\theta}\left(\frac{\pi_{\mathrm{ref}}}{\pi_{\theta}} - \log \frac{\pi_{\mathrm{ref}}}{\pi_{\theta}} - 1\right)\right] = -\mathbb{E}_{\pi_{\theta}}\left[\nabla_{\theta}\log \pi_{\theta} \times \frac{\pi_{\mathrm{ref}}}{\pi_{\theta}}\right] \neq -\mathbb{E}_{\pi_{\theta}}\left[\nabla_{\theta} \log \pi_{\theta} \times \log \frac{\pi_{\mathrm{ref}}}{\pi_{\theta}}\right] = \nabla_{\theta}\mathrm{KL}(\pi_{\theta}\|\pi_{\mathrm{ref}})$$

Given this, I don't fully understand how KL3 can work for optimization (maybe the lower variance alone is good enough, despite this potential bias). It should be a good estimator to monitor the KL divergence (notably because it remains non-negative, is unbiased and has low variance), but using it for optimization is a bit mysterious to me.

I'm sorry to hijack this issue, but I haven't seen much discussion of the use of KL3 in GRPO.
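The bias is easy to see for a softmax-parameterized categorical distribution, where $\nabla_\theta \log \pi_\theta(y) = e_y - \pi_\theta$ (one-hot minus probabilities). On-policy, the expected kl3 gradient then collapses to $\pi_\theta - \pi_{\mathrm{ref}}$, whereas the true gradient is $\pi_\theta \odot (\log\frac{\pi_\theta}{\pi_{\mathrm{ref}}} - \mathrm{KL})$. The toy sketch below (arbitrary distributions, hypothetical helper names) evaluates both exactly, with kl2 included for comparison.

```python
import math


def softmax(logits):
    m = max(logits)
    exps = [math.exp(z - m) for z in logits]
    total = sum(exps)
    return [e / total for e in exps]


def expected_score_weighted(weight, p):
    """E_{y ~ p}[weight(y) * grad log p(y)] for softmax logits, using grad log p(y) = e_y - p."""
    k = len(p)
    grad = [0.0] * k
    for y in range(k):
        w = p[y] * weight(y)
        for j in range(k):
            grad[j] += w * ((1.0 if j == y else 0.0) - p[j])
    return grad


p = softmax([0.5, -1.0, 2.0, 0.0])  # pi_theta, also the sampling policy (toy values)
q = softmax([0.0, 0.0, 1.0, -0.5])  # pi_ref (toy values)
log_r = [math.log(a / b) for a, b in zip(p, q)]
kl_value = sum(a * lr for a, lr in zip(p, log_r))

# Exact gradient of KL(p || q) w.r.t. the logits: p * (log(p/q) - KL).
true_grad = [a * (lr - kl_value) for a, lr in zip(p, log_r)]
# Per-sample grad of kl2 = log(p/q) * grad log p
kl2_grad = expected_score_weighted(lambda y: log_r[y], p)
# Per-sample grad of kl3 = (1 - q/p) * grad log p
kl3_grad = expected_score_weighted(lambda y: 1.0 - q[y] / p[y], p)

print("grad KL:    ", [round(g, 4) for g in true_grad])
print("E[grad kl2]:", [round(g, 4) for g in kl2_grad])
print("E[grad kl3]:", [round(g, 4) for g in kl3_grad])
print("p - q:      ", [round(a - b, 4) for a, b in zip(p, q)])
```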

tristandeleu, Feb 04 '25 16:02

You are right. It is biased, but it has low variance. Intuitively, another reason why it works is that each term of the kl3 estimator, $-\log\frac{p(x)}{q(x)} + \frac{p(x)}{q(x)} - 1$ (with $p = \pi_{\mathrm{ref}}$ and $q = \pi_\theta$), is lower-bounded by 0 (same as kl2), with equality only when $p(x) = q(x)$. So, if you estimate the KL with just one sample and optimize the kl3 estimator for that sample, you will end up making the two probabilities equal for that sample.

On the other hand, if you minimize the kl1 estimator for a single sample, $-\log\frac{p(x)}{q(x)}$, you will just end up pushing $q(x)$ down without bound, since this term has no minimum at $p(x) = q(x)$. This is why optimizing kl1 leads to a blow-up in KL when beta is higher.
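A one-dimensional sketch of this argument: treat $d = \log q(x) - \log p(x)$ for a single fixed sample as the only free parameter and run plain gradient descent on each per-sample term. It deliberately ignores that $q$ must stay normalized over the other outputs, and the step size and step count are arbitrary choices.

```python
import math


def descend(grad_fn, d0, lr=0.1, steps=200):
    """Plain gradient descent on a scalar parameter d."""
    d = d0
    for _ in range(steps):
        d -= lr * grad_fn(d)
    return d


def kl3_grad(d):
    # kl3 term e^{-d} + d - 1 has derivative 1 - e^{-d}; minimum 0 at d = 0, i.e. q(x) = p(x)
    return 1.0 - math.exp(-d)


def kl1_grad(d):
    # kl1 term d has constant derivative 1 and no minimum
    return 1.0


for d0 in (1.0, -1.0):
    print(f"start d={d0:+.1f}: kl3 -> {descend(kl3_grad, d0):+.6f}, kl1 -> {descend(kl1_grad, d0):+.1f}")
```

Under kl3 the push on the sampled output vanishes once $q(x) = p(x)$, from either side; under kl1 the sampled output's log-probability keeps decreasing at a constant rate, no matter how far it already is from the reference.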

gauravpandeyamu, Feb 04 '25 16:02

Also worth noting that $E_{\pi_t} \left[\frac{\pi_\theta}{\pi_{ref}} -\log \frac{\pi_\theta}{\pi_{ref}} - 1\right]$ is a valid divergence and the kl3 estimator and its gradient are unbiased estimators of this divergence and its gradient respectively.
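The divergence property rests on the scalar inequality $u - \log u - 1 \ge 0$ for $u > 0$, with equality only at $u = 1$ (because $\log u \le u - 1$). Averaging it over any sampling distribution therefore gives a nonnegative number that is zero when the ratio equals 1 on that distribution's support. A quick numeric check on random toy distributions (hypothetical helper names):

```python
import math
import random


def g(u):
    """Per-sample term u - log(u) - 1; nonnegative for u > 0, zero only at u = 1."""
    return u - math.log(u) - 1.0


def random_dist(k, rng):
    w = [rng.random() + 1e-3 for _ in range(k)]
    s = sum(w)
    return [x / s for x in w]


def divergence(p_t, p_theta, p_ref):
    """E_{y ~ p_t}[g(p_theta(y) / p_ref(y))], computed exactly over all outcomes."""
    return sum(w * g(a / b) for w, a, b in zip(p_t, p_theta, p_ref))


rng = random.Random(0)
values = []
for _ in range(1000):
    p_t, p_theta, p_ref = (random_dist(5, rng) for _ in range(3))
    values.append(divergence(p_t, p_theta, p_ref))
print(f"smallest value over 1000 draws: {min(values):.3e}")
print(divergence(p_t, p_ref, p_ref))  # identical policies: prints 0.0
```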

gauravpandeyamu, Feb 04 '25 17:02

That makes sense. In that case, I would find KL2 a more "principled" choice for optimization, since its gradient wrt. $\theta$ is an unbiased estimator of the gradient of the KL divergence (provided that it still has low variance, which seems to be validated by @vwxyzjn's experiments). But KL2 would still be a biased estimator of the KL divergence (to report metrics), so maybe optimizing KL2 (kl = kl2) and monitoring KL3 is the way to go?

But your comment on it being a valid divergence also makes a lot of sense, and maybe what the success of KL3 shows is that we should go beyond just the KL divergence.
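Spelled out for the on-policy case ($\pi_t = \pi_\theta$, no gradient through the sampling distribution), the KL2 argument uses only $\nabla_\theta \pi_\theta = \pi_\theta \nabla_\theta \log \pi_\theta$ and $\sum_y \nabla_\theta \pi_\theta(y) = \nabla_\theta 1 = 0$:

$$
\begin{aligned}
\nabla_\theta \tfrac{1}{2}\Big(\log \tfrac{\pi_\theta(y)}{\pi_{\mathrm{ref}}(y)}\Big)^2
  &= \log \tfrac{\pi_\theta(y)}{\pi_{\mathrm{ref}}(y)}\,\nabla_\theta \log \pi_\theta(y),\\
\nabla_\theta \mathrm{KL}(\pi_\theta\|\pi_{\mathrm{ref}})
  &= \sum_y \nabla_\theta \pi_\theta(y)\,\log \tfrac{\pi_\theta(y)}{\pi_{\mathrm{ref}}(y)} + \sum_y \nabla_\theta \pi_\theta(y)\\
  &= \mathbb{E}_{\pi_\theta}\Big[\nabla_\theta \log \pi_\theta(y)\,\log \tfrac{\pi_\theta(y)}{\pi_{\mathrm{ref}}(y)}\Big],
\end{aligned}
$$

so the expected per-sample kl2 gradient equals the KL gradient. Once $\pi_\theta$ has moved away from $\pi_t$ within a batch, this holds only approximately unless the samples are importance-weighted.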

tristandeleu, Feb 04 '25 17:02

I agree with KL2 being a more principled choice for optimization. There are works that explore f-divergences in the PPO objective https://arxiv.org/pdf/2309.16240

gauravpandeyamu, Feb 04 '25 17:02

> Also worth noting that $E_{\pi_t} \left[\frac{\pi_\theta}{\pi_{ref}} -\log \frac{\pi_\theta}{\pi_{ref}} - 1\right]$ is a valid divergence and the kl3 estimator and its gradient are unbiased estimators of this divergence and its gradient respectively.

@gauravpandeyamu, I am reading this thread to understand the KL-divergence issue. You mention that $E_{\pi_t} \left[\frac{\pi_\theta}{\pi_{ref}} -\log \frac{\pi_\theta}{\pi_{ref}} - 1\right]$ is a valid divergence; can you please explain why? The expectation is taken w.r.t. $\pi_t$, and if I understand correctly, it is an unbiased estimator of $KL(\pi_\theta \| \pi_{ref})$ only if the expectation is w.r.t. $\pi_{ref}$, given that $(\pi_\theta/\pi_{ref} - 1)$ is a control variate (zero expectation). If we are taking an expectation w.r.t. $\pi_t$, don't we need to apply importance sampling to adjust?

shashankg7, Feb 13 '25 20:02

Hi @vwxyzjn, you mentioned that

> Overall the training score and sequence length go up, but the downstream eval in MATH seems to suffer.

Does applying KL3 (instead of KL1) solve this issue?

lkevinzc, Feb 15 '25 03:02

I am closing this issue for being out of date.

finbarrtimbers, Oct 29 '25 19:10