Help understanding GPU memory spikes during GRPO training.
Hello, I'm currently experimenting with the performance of various GRPO configurations on H200 GPUs (2 nodes, 8 H200s per node). Regardless of the configuration, I'm seeing large, intermittent memory spikes on some of the GPUs in a node. Could someone provide some insight into what might be causing these spikes?
I've attached an image of the GPU memory usage, along with my config:
#!/bin/bash
BASE_MODEL="Qwen/Qwen2.5-3B-Instruct"
ROLLOUT_TP_SIZE=1
EXPERIMENT_NAME="countdown-qwen2.5-3b-16k"
WANDB_PROJECT="RLCountdown"
DATA_DIR="countdown/dataset"
source scripts/base_script.sh "$@"
python3 -m trainer.main_ppo \
    algorithm.adv_estimator=grpo \
    data.train_files=$DATA_DIR/train.parquet \
    data.val_files=$DATA_DIR/test.parquet \
    data.train_batch_size=1024 \
    data.max_prompt_length=256 \
    data.max_response_length=16384 \
    actor_rollout_ref.model.path=$BASE_MODEL \
    actor_rollout_ref.actor.optim.lr=1e-6 \
    actor_rollout_ref.model.use_remove_padding=True \
    actor_rollout_ref.actor.ppo_mini_batch_size=256 \
    actor_rollout_ref.actor.use_dynamic_bsz=True \
    actor_rollout_ref.actor.ppo_max_token_len_per_gpu=35000 \
    actor_rollout_ref.actor.use_kl_loss=True \
    actor_rollout_ref.actor.kl_loss_coef=0.001 \
    actor_rollout_ref.actor.kl_loss_type=low_var_kl \
    actor_rollout_ref.model.enable_gradient_checkpointing=True \
    actor_rollout_ref.actor.fsdp_config.param_offload=False \
    actor_rollout_ref.actor.fsdp_config.optimizer_offload=False \
    actor_rollout_ref.rollout.tensor_model_parallel_size=$ROLLOUT_TP_SIZE \
    actor_rollout_ref.rollout.name=vllm \
    actor_rollout_ref.rollout.gpu_memory_utilization=0.2 \
    actor_rollout_ref.rollout.n=5 \
    actor_rollout_ref.ref.fsdp_config.param_offload=True \
    algorithm.kl_ctrl.kl_coef=0.001 \
    trainer.critic_warmup=0 \
    trainer.logger=['wandb'] \
    trainer.experiment_name=$FULL_EXPERIMENT_NAME \
    +trainer.val_before_train=True \
    +trainer.raw_chat_template=False \
    trainer.n_gpus_per_node=$N_GPUS \
    trainer.nnodes=2 \
    trainer.save_freq=-1 \
    trainer.test_freq=20 \
    trainer.project_name=$WANDB_PROJECT \
    actor_rollout_ref.rollout.disable_log_stats=False \
    trainer.total_epochs=15 2>&1 | tee verl_demo.log
The spike is caused by the cross-entropy and entropy computation in the backward pass.
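As a rough back-of-the-envelope check (assuming fp32 logits and a ~152k Qwen2.5 vocabulary, so treat the numbers as approximate):

# rough estimate only -- assumes fp32 logits and a ~152k vocabulary (Qwen2.5)
TOKENS=35000   # actor_rollout_ref.actor.ppo_max_token_len_per_gpu
VOCAB=152000   # approximate Qwen2.5 vocabulary size
BYTES=4        # fp32
echo "one logits tensor: $(( TOKENS * VOCAB * BYTES / 1024**3 )) GiB"
# prints ~19 GiB for a single full-vocabulary logits copy; the backward pass
# for cross entropy/entropy needs more tensors of that same shape (e.g. the
# gradient w.r.t. the logits), which is what shows up as the sudden spikes

So a dynamic-batch micro step that fills the 35k token budget can transiently allocate tens of GiB just for logits and their gradients.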
Thank you for the reply. Do you have any suggestions on which of the parameters I've set causes the cross-entropy and entropy memory usage to spike so much during the backward pass?
Lowering the rollout GPU memory utilization, increasing the rollout TP size, and lowering ppo_max_token_len_per_gpu all still show these spikes and eventually result in a CUDA OOM error. Or is this memory graph expected?
I think you should lower ppo_max_token_len_per_gpu. The KV cache and rollout weights are offloaded during training, so they don't consume GPU memory during the backward stage.
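A minimal sketch of what I mean, to splice into the launch command above (the numbers are placeholders to tune for your setup; the log-prob recompute is a gradient-free forward pass, so its budget doesn't have to shrink along with the actor's):

# sketch only -- shrink the gradient-carrying actor update budget,
# leave the gradient-free log-prob recompute budget where it was
actor_rollout_ref.actor.ppo_max_token_len_per_gpu=17000 \
actor_rollout_ref.rollout.log_prob_max_token_len_per_gpu=35000 \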
I see, thanks for your reply. I halved ppo_max_token_len_per_gpu (to ~17k) and the spikes dropped to about 91% of GPU memory. Given that vLLM rollout memory isn't a constraint here, I'm guessing the recommendation is to increase actor_rollout_ref.rollout.log_prob_max_token_len_per_gpu rather than just setting it to the same value as actor_rollout_ref.actor.ppo_max_token_len_per_gpu? Also, if I want to scale up ppo_max_token_len_per_gpu for longer response lengths (or run the same value on a larger model), what would you suggest? When I tried this same config with 4 nodes, I still saw memory spikes and a CUDA OOM.
I encountered the same problem: the GPU memory usage shows spikes followed by a CUDA OOM. Any solutions?
I encountered the same problem. Adjusting the GPU memory utilization, i.e. setting actor_rollout_ref.rollout.gpu_memory_utilization=0.8, solved it in my case.
I encountered the same problem. I tried decreasing actor_rollout_ref.rollout.gpu_memory_utilization and decreasing actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu, but neither worked.
I have encountered the same problem, and not every actor update triggers it. When the memory spikes occurred, the prompt and response lengths were not noticeably higher than in steps where the spikes didn't occur. Do you have any solutions?