[Bug] FSDP2 fails to load large model state_dict
I found that FSDP2 fails to load the state_dict of large (32B or 72B) models. It works after I changed the "fsdp2" settings in the command below to "fsdp":
actor_rollout_ref.actor.strategy=fsdp \
actor_rollout_ref.ref.strategy=fsdp \
Config
- GPU: 64 H20-96GB cards
- model: Qwen2.5-VL-32B
- dataset: geo3k
Command w/ FSDP2
MODEL_PATH=/path/to/Qwen2.5-VL-32B
DATA_PATH=/path/to/geo3k
python3 -m verl.trainer.main_ppo \
algorithm.adv_estimator=grpo \
data.train_files=$DATA_PATH/train.parquet \
data.val_files=$DATA_PATH/test.parquet \
data.train_batch_size=512 \
data.max_prompt_length=1024 \
data.max_response_length=2048 \
data.filter_overlong_prompts=True \
data.truncation='error' \
data.image_key=images \
actor_rollout_ref.model.path=$MODEL_PATH \
actor_rollout_ref.model.use_remove_padding=True \
actor_rollout_ref.model.enable_gradient_checkpointing=True \
actor_rollout_ref.actor.strategy=fsdp2 \
actor_rollout_ref.actor.optim.lr=1e-6 \
actor_rollout_ref.actor.ppo_mini_batch_size=128 \
actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=10 \
actor_rollout_ref.actor.use_kl_loss=True \
actor_rollout_ref.actor.kl_loss_coef=0.01 \
actor_rollout_ref.actor.kl_loss_type=low_var_kl \
actor_rollout_ref.actor.entropy_coeff=0 \
actor_rollout_ref.rollout.name=vllm \
actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=20 \
actor_rollout_ref.rollout.tensor_model_parallel_size=4 \
actor_rollout_ref.rollout.gpu_memory_utilization=0.6 \
actor_rollout_ref.rollout.enable_chunked_prefill=True \
actor_rollout_ref.rollout.enforce_eager=False \
actor_rollout_ref.rollout.free_cache_engine=False \
actor_rollout_ref.rollout.n=5 \
actor_rollout_ref.ref.strategy=fsdp2 \
actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu=20 \
actor_rollout_ref.ref.fsdp_config.param_offload=True \
trainer.critic_warmup=0 \
trainer.logger=['console','tensorboard'] \
trainer.project_name='verl_grpo_example_geo3k' \
trainer.experiment_name='qwen2_5_vl_32b_function_rm' \
trainer.n_gpus_per_node=8 \
trainer.nnodes=8 \
trainer.save_freq=-1 \
trainer.test_freq=5 \
trainer.total_epochs=5 $@
Error w/ FSDP2
File "/code/verl/single_controller/ray/base.py", line 466, in func
return getattr(self.worker_dict[key], name)(*args, **kwargs)
File "/code/verl/single_controller/base/decorator.py", line 501, in inner
return func(*args, **kwargs)
File "/code/verl/workers/fsdp_workers.py", line 487, in init_model
self.actor_module_fsdp, self.actor_optimizer, self.actor_lr_scheduler, self.actor_model_config = self._build_model_optimizer(
File "/code/verl/workers/fsdp_workers.py", line 301, in _build_model_optimizer
fsdp2_load_full_state_dict(actor_module, full_state, fsdp_mesh, cpu_offload)
File "/code/verl/utils/fsdp_utils.py", line 411, in fsdp2_load_full_state_dict
set_model_state_dict(model, full_state, options=options)
File "/usr/local/lib/python3.10/dist-packages/torch/distributed/checkpoint/state_dict.py", line 1207, in set_model_state_dict
return _load_model_state_dict(model, model_state_dict, info)
File "/usr/local/lib/python3.10/dist-packages/torch/utils/_contextlib.py", line 116, in decorate_context
return func(*args, **kwargs)
File "/usr/local/lib/python3.10/dist-packages/torch/distributed/checkpoint/state_dict.py", line 569, in _load_model_state_dict
_broadcast_state_dict(
File "/usr/local/lib/python3.10/dist-packages/torch/distributed/_state_dict_utils.py", line 616, in _broadcast_state_dict
_broadcast_tensors(ret, local_state_dict, keys, device, pg)
File "/usr/local/lib/python3.10/dist-packages/torch/distributed/_state_dict_utils.py", line 514, in _broadcast_tensors
full_tensor = torch.empty(
torch.OutOfMemoryError: CUDA out of memory. Tried to allocate 540.00 MiB. GPU 0 has a total capacity of 95.22 GiB of which 385.56 MiB is free. Process 2432532 has 94.84 GiB memory in use. Of the allocated memory 92.26 GiB is allocated by PyTorch, and 1.39 GiB is reserved by PyTorch but unallocated. If reserved but unallocated memory is large try setting PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True to avoid fragmentation. See documentation for Memory Management (https://pytorch.org/docs/stable/notes/cuda.html#environment-variables)
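For context, the allocation that fails here is in torch's _broadcast_tensors: every full, unsharded parameter gets materialized on each GPU so that rank 0 can broadcast it. Roughly, the pattern looks like the sketch below (illustrative only, not torch's actual code; the function name is made up):

```python
import torch
import torch.distributed as dist

def broadcast_full_param(shape, dtype, src_tensor=None):
    # Every rank allocates a buffer for the *full*, unsharded parameter on GPU,
    # then rank 0 broadcasts its copy into it. With a 32B model there are many
    # such full-size buffers; if they are not released promptly, the ranks run
    # out of memory exactly as in the traceback above.
    full_tensor = torch.empty(shape, dtype=dtype, device=torch.cuda.current_device())
    if dist.get_rank() == 0:
        full_tensor.copy_(src_tensor)
    dist.broadcast(full_tensor, src=0)
    return full_tensor
```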
Attached GPU memory profiling:
- FSDP2: [memory profile screenshot]
- FSDP1: [memory profile screenshot]
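For anyone who wants to reproduce these curves, a minimal way to sample per-rank GPU memory around the loading step is sketched below (the helper is illustrative and not part of verl):

```python
import torch
import torch.distributed as dist

def log_gpu_mem(tag: str):
    # Print current and peak allocated GPU memory on this rank, e.g. right
    # before and after the FSDP/FSDP2 state-dict loading step.
    torch.cuda.synchronize()
    alloc = torch.cuda.memory_allocated() / 1024**3
    peak = torch.cuda.max_memory_allocated() / 1024**3
    print(f"[rank {dist.get_rank()}] {tag}: allocated={alloc:.2f} GiB, peak={peak:.2f} GiB")

# torch.cuda.reset_peak_memory_stats()
# ... load the model state dict ...
# log_gpu_mem("after state_dict load")
```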
When I use fsdp2, I found that the same script loads the checkpoint fine with fsdp, but with fsdp2 it leads to the OOM error. Maybe it has the same root cause as this issue; I will track it and try to fix it.
I encountered similar issues when launching and building the optimizer.
same issue
Located at:
https://github.com/volcengine/verl/blob/15b1b15f9963e178b68368c9b3996c60637a5156/verl/utils/fsdp_utils.py#L392-L420
This function consumes a large amount of GPU memory after loading the state dict. Taking Qwen2.5-coder-instruct-7B on four GPUs as an example, FSDP1 uses 6.59G on rank 0 after the critic finishes loading, while FSDP2 uses ~32G on rank 0 after this function runs. I'm not sure why yet.
Additionally, FSDP2's offload doesn't seem to change GPU memory usage: https://github.com/volcengine/verl/blob/15b1b15f9963e178b68368c9b3996c60637a5156/verl/utils/fsdp_utils.py#L150-L155
I made a temporary modification to the current code (https://github.com/0x404/verl/commit/88cf2daa2d8bed9439e6bc2a314f73c2aee77791), and after the change, FSDP2's loading process uses the same amount of GPU memory as FSDP1.
However, this change is only a stopgap, because it requires each rank to load the HF model to CPU, which consumes a large amount of CPU memory. I will keep tracking this and try to find out why fsdp2_load_full_state_dict consumes so much GPU memory. If anyone has new findings or solutions, please let me know.
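For reference, the rough idea behind that kind of workaround is sketched below (not the actual commit; it assumes every rank already holds the full state dict on CPU, which is exactly the extra host-memory cost mentioned above):

```python
from torch.distributed.checkpoint.state_dict import StateDictOptions, set_model_state_dict

def load_full_state_dict_without_broadcast(model, full_state):
    # Assumes *every* rank holds the full state dict on CPU. With
    # broadcast_from_rank0=False, each rank copies its own shards out of the
    # CPU state dict, so no full-size GPU buffers are allocated for broadcasting.
    options = StateDictOptions(full_state_dict=True, broadcast_from_rank0=False)
    set_model_state_dict(model, full_state, options=options)
```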
Hi all, I made a fix in https://github.com/volcengine/verl/pull/1667 that should solve the OOM issue; you may want to give it a try.
Same issue, but the OOM happens at set_model_state_dict. The temporary modification above does not work in my scenario (32B model, NPUs with 64GB of memory). Sadness.
same issue
@wconstab @mori360 @weifengpy could you help advise?
@eric-haibin-lin I will take a look. Thanks for mentioning us.
We have a bug in torch==2.6.0 in torch.distributed.checkpoint.state_dict.set_model_state_dict (see the memory usage in the right figure): https://github.com/pytorch/pytorch/pull/134025
- GPU memory accumulates during broadcasting
- CPU offloading does not work
Switching to torch==2.7.0 resolved the problem (see the memory usage in the left figure).
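For projects pinned to 2.6.0, one option is to gate the broadcast-from-rank0 path on the PyTorch version, roughly as sketched below (the helper name is made up and this is not the fix that eventually landed):

```python
import torch
from packaging.version import Version

def can_use_broadcast_from_rank0() -> bool:
    # torch 2.6.0's set_model_state_dict accumulates GPU memory while
    # broadcasting and ignores cpu_offload, so only rely on that path on >= 2.7.0.
    return Version(torch.__version__).release >= (2, 7, 0)
```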
@eric-haibin-lin would it be possible to upgrade PyTorch to 2.7.0? If it has to stay on 2.6.0, we can land a fixed set_model_state_dict in verl. cc @mori360 @fegin
Fixed in https://github.com/volcengine/verl/pull/2606