
PPO Training Hangs at Step 0 when use_remove_padding is Enabled

Open maksimstw opened this issue 9 months ago • 4 comments

When training the Qwen 2.5 7B Math model using the example script below, the training process consistently hangs at step 0, with GPU utilization dropping to 0%. This issue occurs when using 8 A100 GPUs on a single node. However, if use_remove_padding is set to False, the training proceeds without any problems. What might be the issue?

set -x

gsm8k_train_path=$HOME/data/gsm8k/train.parquet
gsm8k_test_path=$HOME/data/gsm8k/test.parquet
math_train_path=$HOME/data/math/train.parquet
math_test_path=$HOME/data/math/test.parquet

train_files="['$gsm8k_train_path', '$math_train_path']"
test_files="['$gsm8k_test_path', '$math_test_path']"

python3 -m verl.trainer.main_ppo \
    data.train_files="$train_files" \
    data.val_files="$test_files" \
    data.train_batch_size=1024 \
    data.max_prompt_length=1024 \
    data.max_response_length=1024 \
    actor_rollout_ref.model.path=Qwen/Qwen2.5-7B-Instruct \
    actor_rollout_ref.actor.optim.lr=1e-6 \
    actor_rollout_ref.model.use_remove_padding=True \
    actor_rollout_ref.actor.ppo_mini_batch_size=256 \
    actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=8 \
    actor_rollout_ref.model.enable_gradient_checkpointing=True \
    actor_rollout_ref.actor.fsdp_config.param_offload=False \
    actor_rollout_ref.actor.fsdp_config.optimizer_offload=False \
    actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=16 \
    actor_rollout_ref.rollout.tensor_model_parallel_size=4 \
    actor_rollout_ref.rollout.name=vllm \
    actor_rollout_ref.rollout.gpu_memory_utilization=0.5 \
    actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu=16 \
    actor_rollout_ref.ref.fsdp_config.param_offload=True \
    critic.optim.lr=1e-5 \
    critic.model.use_remove_padding=True \
    critic.model.path=Qwen/Qwen2.5-7B-Instruct \
    critic.model.enable_gradient_checkpointing=False \
    critic.ppo_micro_batch_size_per_gpu=8 \
    critic.model.fsdp_config.param_offload=False \
    critic.model.fsdp_config.optimizer_offload=False \
    algorithm.kl_ctrl.kl_coef=0.0001 \
    trainer.critic_warmup=0 \
    trainer.logger=['console','wandb'] \
    trainer.project_name='verl_example' \
    trainer.experiment_name='Qwen2.5-7B-Instruct_function_rm' \
    trainer.n_gpus_per_node=8 \
    trainer.nnodes=4 \
    trainer.save_freq=-1 \
    trainer.test_freq=10 \
    trainer.total_epochs=15 "$@"

maksimstw avatar Feb 26 '25 03:02 maksimstw

Could you check with py-spy dump --pid xxx, or run with a breakpoint and ray debug (see the FAQ page), to see where the program hangs?

eric-haibin-lin avatar Feb 27 '25 21:02 eric-haibin-lin

After some debugging, I found that enabling use_remove_padding for the critic does not hang the training, but enabling it for the actor does. It hangs at this line:

[Screenshot pointing to the line in the actor code where the program hangs]

maksimstw avatar Mar 02 '25 00:03 maksimstw

After even more debugging, I found that if we modify

self.compute_entropy_from_logits = torch.compile(verl_F.entropy_from_logits, dynamic=True)

to

self.compute_entropy_from_logits = verl_F.entropy_from_logits

the program runs with no issues. I also tried setting dynamic=False, but it still hangs. How much speedup do we get from torch.compile?
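For context, the change amounts to the following minimal sketch; the import path behind verl_F is my assumption, so verify it against the actual verl source:

import torch
import verl.utils.torch_functional as verl_F  # assumed module behind verl_F

# original: compile the entropy kernel (the line the training hangs on in this setup)
compiled_entropy_from_logits = torch.compile(verl_F.entropy_from_logits, dynamic=True)

# workaround: call the eager implementation directly
compute_entropy_from_logits = verl_F.entropy_from_logits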

maksimstw avatar Mar 02 '25 01:03 maksimstw

Hi! I think the primary purpose of using torch.compile here is to reduce intermediate memory usage. It does provide a speedup as well, but the impact on the overall program is minor, since this code is not called very often. If memory overflow isn't an issue for you, you can safely remove it. It would definitely be preferable to use an alternative operator rather than relying on torch.compile.
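To make the memory point concrete, here is an illustrative sketch (not verl's actual code) that assumes entropy_from_logits computes logsumexp(logits) - sum(softmax(logits) * logits) over the vocabulary dimension, together with a hand-chunked variant that bounds the intermediate softmax buffer without relying on torch.compile:

import torch

def entropy_from_logits(logits: torch.Tensor) -> torch.Tensor:
    # Eager version: materializes a full (tokens, vocab) softmax tensor,
    # which is the intermediate that torch.compile would fuse away.
    probs = torch.nn.functional.softmax(logits, dim=-1)
    return torch.logsumexp(logits, dim=-1) - torch.sum(probs * logits, dim=-1)

def entropy_from_logits_chunked(logits: torch.Tensor, chunk_size: int = 2048) -> torch.Tensor:
    # Process the token dimension in chunks so the softmax intermediate stays small,
    # trading a little speed for bounded memory.
    flat = logits.reshape(-1, logits.shape[-1])
    out = torch.empty(flat.shape[0], dtype=logits.dtype, device=logits.device)
    for start in range(0, flat.shape[0], chunk_size):
        chunk = flat[start:start + chunk_size]
        probs = torch.nn.functional.softmax(chunk, dim=-1)
        out[start:start + chunk_size] = torch.logsumexp(chunk, dim=-1) - torch.sum(probs * chunk, dim=-1)
    return out.reshape(logits.shape[:-1])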

Yifei-Zuo avatar Mar 04 '25 02:03 Yifei-Zuo

We can now disable torch.compile by setting a flag in the config file. #554

maksimstw avatar Apr 09 '25 05:04 maksimstw