
Help understanding GPU memory spikes during GRPO training.

Open · ragingpandas opened this issue 9 months ago

Hello, I'm currently experimenting with the performance of various GRPO configurations on H200 GPUs (2 nodes, 8 H200s per node). Regardless of my configuration, I'm seeing large, intermittent memory spikes across some of the GPUs in a node. I was wondering if someone could provide some insight into what could be causing these large spikes.

I've attached an image of the GPU memory usage and my config:

[Attached images: per-GPU memory usage graph showing intermittent spikes, and the training config]

#!/bin/bash

BASE_MODEL="Qwen/Qwen2.5-3B-Instruct"
ROLLOUT_TP_SIZE=1
EXPERIMENT_NAME="countdown-qwen2.5-3b-16k"
WANDB_PROJECT="RLCountdown"
DATA_DIR="countdown/dataset"

source scripts/base_script.sh "$@"

python3 -m trainer.main_ppo \
    algorithm.adv_estimator=grpo \
    data.train_files=$DATA_DIR/train.parquet \
    data.val_files=$DATA_DIR/test.parquet \
    data.train_batch_size=1024 \
    data.max_prompt_length=256 \
    data.max_response_length=16384 \
    actor_rollout_ref.model.path=$BASE_MODEL \
    actor_rollout_ref.actor.optim.lr=1e-6 \
    actor_rollout_ref.model.use_remove_padding=True \
    actor_rollout_ref.actor.ppo_mini_batch_size=256 \
    actor_rollout_ref.actor.use_dynamic_bsz=True \
    actor_rollout_ref.actor.ppo_max_token_len_per_gpu=35000 \
    actor_rollout_ref.actor.use_kl_loss=True \
    actor_rollout_ref.actor.kl_loss_coef=0.001 \
    actor_rollout_ref.actor.kl_loss_type=low_var_kl \
    actor_rollout_ref.model.enable_gradient_checkpointing=True \
    actor_rollout_ref.actor.fsdp_config.param_offload=False \
    actor_rollout_ref.actor.fsdp_config.optimizer_offload=False \
    actor_rollout_ref.rollout.tensor_model_parallel_size=$ROLLOUT_TP_SIZE \
    actor_rollout_ref.rollout.name=vllm \
    actor_rollout_ref.rollout.gpu_memory_utilization=0.2 \
    actor_rollout_ref.rollout.n=5 \
    actor_rollout_ref.ref.fsdp_config.param_offload=True \
    algorithm.kl_ctrl.kl_coef=0.001 \
    trainer.critic_warmup=0 \
    trainer.logger=['wandb'] \
    trainer.experiment_name=$FULL_EXPERIMENT_NAME \
    +trainer.val_before_train=True \
    +trainer.raw_chat_template=False \
    trainer.n_gpus_per_node=$N_GPUS \
    trainer.nnodes=2 \
    trainer.save_freq=-1 \
    trainer.test_freq=20 \
    trainer.project_name=$WANDB_PROJECT \
    actor_rollout_ref.rollout.disable_log_stats=False \
    trainer.total_epochs=15 2>&1 | tee verl_demo.log

ragingpandas · Mar 03 '25

The spike is caused by the cross-entropy and entropy computation in the backward pass.
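
For a rough sense of scale (a back-of-envelope estimate, not verl's exact kernels): the logits needed for the cross-entropy/entropy terms form a tokens × vocab_size tensor, and the backward pass keeps softmax-sized buffers of the same shape alive on top of it.

# Rough estimate only; assumes fp32 logits and a ~152k vocabulary (check the model config).
# With ppo_max_token_len_per_gpu=35000 tokens packed into one dynamic micro-batch:
python3 -c "print(35000 * 152000 * 4 / 2**30, 'GiB per logits-sized buffer')"

That prints roughly 20 GiB, so a couple of such buffers coexisting during backward is enough to explain spikes of tens of GB on top of the FSDP-sharded weights and optimizer state.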

vermouth1992 · Mar 04 '25


Thank you for the reply. Do you have any suggestions on which of the parameters I've set causes the cross-entropy and entropy memory usage during the backward pass to spike so much?

Lowering the rollout GPU memory utilization, increasing the rollout TP size, and lowering ppo_max_token_len_per_gpu all still show these spikes and eventually result in a CUDA OOM error. Or is this memory graph expected?

ragingpandas · Mar 04 '25

I think you should lower ppo_max_token_len_per_gpu. The KV cache and rollout weights are offloaded during training, so they won't consume GPU memory during the backward stage.
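
One constraint worth checking before lowering it further (this is my reading of how the dynamic batching packs sequences, so treat it as an assumption): the per-GPU token budget still has to fit at least one full prompt + response.

# Minimum viable ppo_max_token_len_per_gpu for this run, assuming the budget must
# cover at least one full sequence (max_prompt_length + max_response_length):
python3 -c "print(256 + 16384)"
# -> 16640, so a value around 17500 effectively packs one sequence per micro-batch.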

PeterSH6 · Mar 04 '25

I see, thanks for your reply. I halved ppo_max_token_len_per_gpu (to 17k) and the spikes came down to about 91% of GPU memory. Given that vLLM rollout memory isn't a constraint, I'm guessing the recommendation would be to increase actor_rollout_ref.rollout.log_prob_max_token_len_per_gpu rather than just setting it to the same value as actor_rollout_ref.actor.ppo_max_token_len_per_gpu? Also, if I'd like to scale up ppo_max_token_len_per_gpu for a longer response length (or run with the same value on a larger model), what would you suggest? When I tried this same config but with 4 nodes, I still saw memory spikes and a CUDA OOM.
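
For reference, the decoupled form I have in mind would look something like this in the script above (illustrative values only, not something I've verified end to end):

    actor_rollout_ref.actor.ppo_max_token_len_per_gpu=17500 \
    actor_rollout_ref.rollout.log_prob_max_token_len_per_gpu=35000 \

The idea being that the log-prob recomputation is forward-only (no gradients), so its per-GPU token budget doesn't have to shrink along with the backward-constrained ppo_max_token_len_per_gpu.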

ragingpandas · Mar 04 '25

I encountered the same problem: the GPU memory usage shows spikes and then a CUDA OOM. Any solutions?

kaamosi · Apr 13 '25

I encountered the same problem. Adjusting the GPU utilization, i.e. setting actor_rollout_ref.rollout.gpu_memory_utilization=0.8, solved it in my case.

HorHang · May 08 '25

I encountered the same problem. I tried decreasing actor_rollout_ref.rollout.gpu_memory_utilization and decreasing actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu, but neither worked.

zizi0123 · Jul 18 '25

I have encountered the same problem, and not all actor updates cause it. When the memory spikes occurred, the prompt and response lengths were not necessarily higher than in steps where no spike occurred. Do you have any solutions?

Purewhiter · Oct 19 '25