[BUG] Qwen3 MoE with FSDP2 hits `torch.utils.checkpoint.CheckpointError` when `offload_policy=True`
Hi, I encountered a problem while trying to run the following command:
export MODEL_PATH=/root/Qwen3-30B-A3B
PYTHONUNBUFFERED=1 python3 -m verl.trainer.main_ppo \
data.train_files=$HOME/data/gsm8k/train.parquet \
data.val_files=$HOME/data/gsm8k/test.parquet \
data.train_batch_size=256 \
data.max_prompt_length=512 \
data.max_response_length=256 \
actor_rollout_ref.model.path=$MODEL_PATH \
actor_rollout_ref.actor.strategy=fsdp2 \
actor_rollout_ref.actor.fsdp_config.offload_policy=True \
actor_rollout_ref.actor.optim.lr=1e-6 \
actor_rollout_ref.actor.ppo_mini_batch_size=64 \
actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=4 \
actor_rollout_ref.rollout.name=vllm \
actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=8 \
actor_rollout_ref.rollout.tensor_model_parallel_size=4 \
actor_rollout_ref.rollout.gpu_memory_utilization=0.4 \
actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu=4 \
critic.strategy=fsdp2 \
critic.model.fsdp_config.offload_policy=True \
critic.optim.lr=1e-5 \
critic.model.path=$MODEL_PATH \
critic.ppo_micro_batch_size_per_gpu=4 \
algorithm.kl_ctrl.kl_coef=0.001 \
trainer.logger=console \
trainer.val_before_train=False \
trainer.n_gpus_per_node=8 \
trainer.nnodes=1 \
trainer.save_freq=2 \
trainer.test_freq=100 \
trainer.total_epochs=15 2>&1 | tee verl_demo.log
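For reference, my understanding is that setting offload_policy=True with the fsdp2 strategy turns on FSDP2's CPU offload, roughly along the lines of the sketch below (this is an assumption about what verl does internally, not its actual code; the mesh setup and layer structure here are hypothetical and it needs an initialized process group plus CUDA devices to actually run):

import torch
from torch.distributed.fsdp import CPUOffloadPolicy, MixedPrecisionPolicy, fully_shard

def shard_with_cpu_offload(model: torch.nn.Module, mesh):
    # Mixed precision and CPU offload policies for FSDP2 (fully_shard).
    mp = MixedPrecisionPolicy(param_dtype=torch.bfloat16, reduce_dtype=torch.float32)
    offload = CPUOffloadPolicy()  # keep sharded params/grads/optimizer state on CPU
    # Hypothetical Qwen3-style decoder layout: shard each layer, then the root module.
    for layer in model.model.layers:
        fully_shard(layer, mesh=mesh, mp_policy=mp, offload_policy=offload)
    fully_shard(model, mesh=mesh, mp_policy=mp, offload_policy=offload)
    return model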
The error message is as follows:
(TaskRunner pid=2992930) Unhandled error (suppress with 'RAY_IGNORE_UNHANDLED_ERRORS=1'): ray::WorkerDict.critic_update_critic() (pid=2993413, ip=10.254.251.252, actor_id=f7b99069e6cf20b344806be602000000, repr=<verl.single_controller.ray.base.WorkerDict object at 0x7ee4533c6120>)
(TaskRunner pid=2992930) ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
(TaskRunner pid=2992930) ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
(TaskRunner pid=2992930) File "/mnt/data/chenyushuo.cys/verl/verl/single_controller/ray/base.py", line 705, in func
(TaskRunner pid=2992930) return getattr(self.worker_dict[key], name)(*args, **kwargs)
(TaskRunner pid=2992930) ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
(TaskRunner pid=2992930) File "/mnt/data/chenyushuo.cys/verl/verl/single_controller/base/decorator.py", line 514, in inner
(TaskRunner pid=2992930) return func(*args, **kwargs)
(TaskRunner pid=2992930) ^^^^^^^^^^^^^^^^^^^^^
(TaskRunner pid=2992930) File "/mnt/data/chenyushuo.cys/verl/verl/workers/fsdp_workers.py", line 1244, in update_critic
(TaskRunner pid=2992930) metrics = self.critic.update_critic(data=data)
(TaskRunner pid=2992930) ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
(TaskRunner pid=2992930) File "/mnt/data/chenyushuo.cys/verl/verl/utils/profiler/performance.py", line 89, in f
(TaskRunner pid=2992930) return self.log(decorated_function, *args, **kwargs)
(TaskRunner pid=2992930) ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
(TaskRunner pid=2992930) File "/mnt/data/chenyushuo.cys/verl/verl/utils/profiler/performance.py", line 102, in log
(TaskRunner pid=2992930) output = func(*args, **kwargs)
(TaskRunner pid=2992930) ^^^^^^^^^^^^^^^^^^^^^
(TaskRunner pid=2992930) File "/mnt/data/chenyushuo.cys/verl/verl/workers/critic/dp_critic.py", line 240, in update_critic
(TaskRunner pid=2992930) loss.backward()
(TaskRunner pid=2992930) File "/root/miniforge3/envs/megatron/lib/python3.12/site-packages/torch/_tensor.py", line 648, in backward
(TaskRunner pid=2992930) torch.autograd.backward(
(TaskRunner pid=2992930) File "/root/miniforge3/envs/megatron/lib/python3.12/site-packages/torch/autograd/__init__.py", line 353, in backward
(TaskRunner pid=2992930) _engine_run_backward(
(TaskRunner pid=2992930) File "/root/miniforge3/envs/megatron/lib/python3.12/site-packages/torch/autograd/graph.py", line 824, in _engine_run_backward
(TaskRunner pid=2992930) return Variable._execution_engine.run_backward( # Calls into the C++ engine to run the backward pass
(TaskRunner pid=2992930) ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
(TaskRunner pid=2992930) File "/root/miniforge3/envs/megatron/lib/python3.12/site-packages/torch/utils/checkpoint.py", line 1128, in unpack_hook
(TaskRunner pid=2992930) frame.check_recomputed_tensors_match(gid)
(TaskRunner pid=2992930) File "/root/miniforge3/envs/megatron/lib/python3.12/site-packages/torch/utils/checkpoint.py", line 864, in check_recomputed_tensors_match
(TaskRunner pid=2992930) raise CheckpointError(
(TaskRunner pid=2992930) torch.utils.checkpoint.CheckpointError: torch.utils.checkpoint: A different number of tensors was saved during the original forward and recomputation.
(TaskRunner pid=2992930) Number of tensors saved during forward: 1750
(TaskRunner pid=2992930) Number of tensors saved during recomputation: 310
(WorkerDict pid=2993220) kwargs: {'n': 1, 'logprobs': 0, 'max_tokens': 256, 'detokenize': False, 'temperature': 1.0, 'top_k': -1, 'top_p': 1, 'ignore_eos': False} [repeated 7x across cluster]
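For what it's worth, the error itself comes from non-reentrant activation checkpointing: during backward the checkpointed region is recomputed, and if the recomputation saves a different number of tensors than the original forward did, `check_recomputed_tensors_match` raises exactly this `CheckpointError`. Below is a toy repro of the same failure mode (not the verl code path; here the divergence is forced with explicit control flow, whereas in this issue it presumably comes from FSDP2 CPU offload interacting with the recompute pass):

import torch
from torch.utils.checkpoint import checkpoint

calls = {"n": 0}

def flaky(x):
    # Toy stand-in for a checkpointed region whose recomputation diverges
    # from the original forward.
    calls["n"] += 1
    if calls["n"] == 1:
        return torch.matmul(x, x).relu()   # original forward: saves more tensors
    return x.relu()                        # recomputation: saves fewer tensors

x = torch.randn(8, 8, requires_grad=True)
y = checkpoint(flaky, x, use_reentrant=False)
# Backward triggers recomputation; the saved-tensor counts differ, so torch
# raises: "A different number of tensors was saved during the original
# forward and recomputation."
y.sum().backward()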
The package versions are as follows:
$ python -V
Python 3.12.11
$ pip list | grep -E 'verl|torch|transformers|flash'
flash_attn 2.8.1
torch 2.7.1
torchaudio 2.7.1
torchdata 0.11.0
torchvision 0.22.1
transformer_engine_torch 2.6.0.post1
transformers 4.53.3
verl 0.5.0
+1, ran into the same issue with FSDP2.
Also hit the same problem when running SFT on Qwen3-30B MoE with FSDP2.
+1, same issue with FSDP2 + activation checkpointing.
+1, hit the same issue during SFT.
Any solution for this, other than disabling activation checkpointing (which causes OOM)?
Same issue here, any update?
I think there is a chance that setting use_reentrant=True in the FSDP worker could fix it, but I have personally seen really bad performance with FSDP2 for large MoE models. Using Megatron is probably your best bet for now.
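In case anyone wants to try it, with an HF model the flag is usually passed when enabling gradient checkpointing; a minimal sketch is below (where exactly verl's FSDP worker sets this is an assumption on my part, and verl may pass its own gradient_checkpointing_kwargs):

from transformers import AutoModelForCausalLM

# Hedged sketch: switch activation checkpointing to the reentrant variant on an
# HF model. Whether this override belongs here or inside verl's FSDP worker is
# an assumption.
model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen3-30B-A3B")
model.gradient_checkpointing_enable(
    gradient_checkpointing_kwargs={"use_reentrant": True}
)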
Also experiencing similar issues
I think there is a chance that setting use_reentrant=True in the FSDP worker could fix it, but I have personally seen really bad performance with FSDP2 for large MoE models. Using Megatron is probably your best bet for now.
Hi @koceja! What do you mean by "bad performance"? Are you referring to precision, speed, or something else?