
[Bug] FSDP2 failed to load large model state_dict

Open GHGmc2 opened this issue 6 months ago • 6 comments

I found that FSDP2 fails to load a large (32B or 72B) model state_dict. It works after I change the "fsdp2" parts of the command below to "fsdp":

actor_rollout_ref.actor.strategy=fsdp \
actor_rollout_ref.ref.strategy=fsdp \

Config

  • GPU: 64 H20-96GB cards
  • model: Qwen2.5-VL-32B
  • dataset: geo3k

Command w/ FSDP2

MODEL_PATH=/path/to/Qwen2.5-VL-32B
DATA_PATH=/path/to/geo3k

python3 -m verl.trainer.main_ppo \
      algorithm.adv_estimator=grpo \
      data.train_files=$DATA_PATH/train.parquet \
      data.val_files=$DATA_PATH/test.parquet \
      data.train_batch_size=512 \
      data.max_prompt_length=1024 \
      data.max_response_length=2048 \
      data.filter_overlong_prompts=True \
      data.truncation='error' \
      data.image_key=images \
      actor_rollout_ref.model.path=$MODEL_PATH \
      actor_rollout_ref.model.use_remove_padding=True \
      actor_rollout_ref.model.enable_gradient_checkpointing=True \
      actor_rollout_ref.actor.strategy=fsdp2 \
      actor_rollout_ref.actor.optim.lr=1e-6 \
      actor_rollout_ref.actor.ppo_mini_batch_size=128 \
      actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=10 \
      actor_rollout_ref.actor.use_kl_loss=True \
      actor_rollout_ref.actor.kl_loss_coef=0.01 \
      actor_rollout_ref.actor.kl_loss_type=low_var_kl \
      actor_rollout_ref.actor.entropy_coeff=0 \
      actor_rollout_ref.rollout.name=vllm \
      actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=20 \
      actor_rollout_ref.rollout.tensor_model_parallel_size=4 \
      actor_rollout_ref.rollout.gpu_memory_utilization=0.6 \
      actor_rollout_ref.rollout.enable_chunked_prefill=True \
      actor_rollout_ref.rollout.enforce_eager=False \
      actor_rollout_ref.rollout.free_cache_engine=False \
      actor_rollout_ref.rollout.n=5 \
      actor_rollout_ref.ref.strategy=fsdp2 \
      actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu=20 \
      actor_rollout_ref.ref.fsdp_config.param_offload=True \
      trainer.critic_warmup=0 \
      trainer.logger=['console','tensorboard'] \
      trainer.project_name='verl_grpo_example_geo3k' \
      trainer.experiment_name='qwen2_5_vl_32b_function_rm' \
      trainer.n_gpus_per_node=8 \
      trainer.nnodes=8 \
      trainer.save_freq=-1 \
      trainer.test_freq=5 \
      trainer.total_epochs=5 $@

Error w/ FSDP2

   File "/code/verl/single_controller/ray/base.py", line 466, in func
     return getattr(self.worker_dict[key], name)(*args, **kwargs)
   File "/code/verl/single_controller/base/decorator.py", line 501, in inner
     return func(*args, **kwargs)
   File "/code/verl/workers/fsdp_workers.py", line 487, in init_model
     self.actor_module_fsdp, self.actor_optimizer, self.actor_lr_scheduler, self.actor_model_config = self._build_model_optimizer(
   File "/code/verl/workers/fsdp_workers.py", line 301, in _build_model_optimizer
     fsdp2_load_full_state_dict(actor_module, full_state, fsdp_mesh, cpu_offload)
   File "/code/verl/utils/fsdp_utils.py", line 411, in fsdp2_load_full_state_dict
     set_model_state_dict(model, full_state, options=options)
   File "/usr/local/lib/python3.10/dist-packages/torch/distributed/checkpoint/state_dict.py", line 1207, in set_model_state_dict
     return _load_model_state_dict(model, model_state_dict, info)
   File "/usr/local/lib/python3.10/dist-packages/torch/utils/_contextlib.py", line 116, in decorate_context
     return func(*args, **kwargs)
   File "/usr/local/lib/python3.10/dist-packages/torch/distributed/checkpoint/state_dict.py", line 569, in _load_model_state_dict
     _broadcast_state_dict(
   File "/usr/local/lib/python3.10/dist-packages/torch/distributed/_state_dict_utils.py", line 616, in _broadcast_state_dict
     _broadcast_tensors(ret, local_state_dict, keys, device, pg)
   File "/usr/local/lib/python3.10/dist-packages/torch/distributed/_state_dict_utils.py", line 514, in _broadcast_tensors
     full_tensor = torch.empty(
 torch.OutOfMemoryError: CUDA out of memory. Tried to allocate 540.00 MiB. GPU 0 has a total capacity of 95.22 GiB of which 385.56 MiB is free. Process 2432532 has 94.84 GiB memory in use. Of the allocated memory 92.26 GiB is allocated by PyTorch, and 1.39 GiB is reserved by PyTorch but unallocated. If reserved but unallocated memory is large try setting PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True to avoid fragmentation.  See documentation for Memory Management  (https://pytorch.org/docs/stable/notes/cuda.html#environment-variables)
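
For a sense of scale (assuming bf16 weights, which the issue does not state), the full unsharded weights of a 32B model already fill most of a single H20 if they are ever materialized on one device:

# Back-of-the-envelope, assuming bf16 (2 bytes per parameter).
num_params = 32e9
full_weights_gib = num_params * 2 / 2**30
print(f"full bf16 weights: ~{full_weights_gib:.0f} GiB")  # ~60 GiB on a 95.22 GiB H20
# Add the rank's own sharded copy plus any full tensors the broadcast path keeps
# alive, and even the 540 MiB allocation in the traceback can fail.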

Attach the GPU memory profiling:

  • FSDP2: [GPU memory profile image]

  • FSDP1: [GPU memory profile image]

GHGmc2 avatar May 14 '25 08:05 GHGmc2

I found the same thing: the same script loads the checkpoint fine with fsdp, but with fsdp2 it leads to the OOM error. It may share the same cause as this issue; I will track it and try to fix it.

0x404 avatar May 19 '25 16:05 0x404

I encountered similar issues when launching and building the optimizer.

SparkJiao avatar May 20 '25 02:05 SparkJiao

same issue

puppet101 avatar May 20 '25 07:05 puppet101

Located at:

https://github.com/volcengine/verl/blob/15b1b15f9963e178b68368c9b3996c60637a5156/verl/utils/fsdp_utils.py#L392-L420

This function consumes a large amount of GPU memory after loading the state dict. Taking Qwen2.5-coder-instruct-7B as an example on four GPUs, FSDP1 uses 6.59 GB on rank 0 after the critic finishes loading, while FSDP2 uses ~32 GB on rank 0 after calling this function. I'm not sure why yet.

Additionally, FSDP2's offload doesn't seem to change GPU memory usage: https://github.com/volcengine/verl/blob/15b1b15f9963e178b68368c9b3996c60637a5156/verl/utils/fsdp_utils.py#L150-L155
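
For anyone checking this locally, a minimal sketch (importing the helper from the file linked above) to see whether the offload helper actually lowers allocated GPU memory:

import torch
from verl.utils.fsdp_utils import offload_fsdp2_model_to_cpu

def report_offload_effect(model: torch.nn.Module) -> None:
    # Measure allocated (not reserved) GPU memory before and after offloading.
    before = torch.cuda.memory_allocated() / 2**30
    offload_fsdp2_model_to_cpu(model)
    torch.cuda.synchronize()  # the copies are non_blocking, so wait before measuring
    after = torch.cuda.memory_allocated() / 2**30
    print(f"allocated: {before:.2f} GiB -> {after:.2f} GiB")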

I made a temporary modification to the current code (https://github.com/0x404/verl/commit/88cf2daa2d8bed9439e6bc2a314f73c2aee77791), and after the change, FSDP2's loading process uses the same amount of GPU memory as FSDP1.

However, this change is temporary because it requires each rank to load the hf model to CPU, which consumes a large amount of CPU memory. I will keep tracking and find out why fsdp2_load_full_state_dict consumes so much GPU memory. If anyone has new findings or a solution, please let me know.
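
For illustration only, a rough sketch of that idea (not the linked commit; the function name here is made up): if every rank holds the full state dict on CPU, set_model_state_dict no longer needs to broadcast full tensors from rank 0, trading GPU memory during loading for host memory.

import torch
from torch.distributed.checkpoint.state_dict import StateDictOptions, set_model_state_dict

def load_full_state_dict_no_broadcast(model: torch.nn.Module, full_state: dict) -> None:
    # Assumes full_state is the complete HF state dict, loaded to CPU on every rank,
    # and that the model is already sharded with FSDP2 (DTensor parameters).
    model.to_empty(device=torch.cuda.current_device())
    options = StateDictOptions(full_state_dict=True, broadcast_from_rank0=False, cpu_offload=False)
    set_model_state_dict(model, full_state, options=options)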

0x404 avatar May 20 '25 11:05 0x404

Hi all, I made a fix in https://github.com/volcengine/verl/pull/1667 which should solve the OOM issue; you may want to give it a try.

0x404 avatar May 24 '25 08:05 0x404

Located at verl/verl/utils/fsdp_utils.py, lines 392 to 420 in 15b1b15:

def fsdp2_load_full_state_dict(model: torch.nn.Module, full_state: dict, device_mesh=None, cpu_offload=None):
    """
    Loads the full state dict (could be only on rank 0) into the sharded model. This is done by broadcasting
    the parameters from rank 0 to all other ranks. This function modifies the model in-place.

    Args:
        model (`torch.nn.Module`): The model to load the state dict into
        full_state (`dict`): The full state dict to load, can only be on rank 0
    """
    from torch.distributed.checkpoint.state_dict import StateDictOptions, set_model_state_dict

    # To broadcast, it needs to be instantiated in the GPU.
    if dist.get_rank() == 0:
        model = model.to(device=torch.cuda.current_device(), non_blocking=True)
    else:
        model = model.to_empty(device=torch.cuda.current_device())

    cpu_offload = cpu_offload is not None
    options = StateDictOptions(full_state_dict=True, cpu_offload=cpu_offload, broadcast_from_rank0=True)
    set_model_state_dict(model, full_state, options=options)

    # rotary_emb is not in state_dict, so we need to broadcast it manually
    for name, buf in model.named_buffers():
        dist.broadcast(buf, src=0)

    if cpu_offload:
        model.to("cpu", non_blocking=True)
        for buf in model.buffers():
            buf.data = buf.data.to(torch.cuda.current_device())

This function consumes a large amount of GPU memory after loading the state dict. Taking Qwen2.5-coder-instruct-7B as an example on four GPUs, FSDP1 uses 6.59 GB on rank 0 after the critic finishes loading, while FSDP2 uses ~32 GB on rank 0 after calling this function. I'm not sure why yet.

Additionally, FSDP2's offload doesn't seem to change GPU memory usage:

In verl/verl/utils/fsdp_utils.py, lines 150 to 155 in 15b1b15:

@torch.no_grad()
def offload_fsdp2_model_to_cpu(model, empty_cache: bool = True):
    for param in model.parameters():
        param.data = param.data.to(torch.device("cpu"), non_blocking=True)
    if empty_cache:
        torch.cuda.empty_cache()

I made a temporary modification to the current code (0x404@88cf2da), and after the change, FSDP2's loading process uses the same amount of GPU memory as FSDP1.

However, this change is temporary because it requires each rank to load the hf model to CPU, which consumes a large amount of CPU memory. I will keep tracking and find out why fsdp2_load_full_state_dict consumes so much GPU memory. If anyone has new findings or a solution, please let me know.

Same issue, but OOM at set_model_state_dict. Your code does not work in my scenario (32B model, NPUs with 64 GB memory). Sadness.

PrometheusComing avatar May 26 '25 08:05 PrometheusComing

same issue

lpc-eol avatar Jun 29 '25 07:06 lpc-eol

@wconstab @mori360 @weifengpy could you help advise?

eric-haibin-lin avatar Jul 15 '25 00:07 eric-haibin-lin

@eric-haibin-lin I will take a look. thanks for mentioning us

weifengpy avatar Jul 15 '25 06:07 weifengpy

[GPU memory profiling: torch 2.7.0 (left) vs torch 2.6.0 (right)]

We have a bug in torch==2.6.0 in torch.distributed.checkpoint.state_dict.set_model_state_dict (see memory usage in the right figure): https://github.com/pytorch/pytorch/pull/134025

  • GPU memory accumulates during broadcasting
  • CPU offloading does not work

Switching to torch==2.7.0 resolved the problem (see memory usage in the left figure).

@eric-haibin-lin would it be possible to upgrade PyTorch to 2.7.0? If it has to stay on 2.6.0, we can land a fixed set_model_state_dict in verl. cc @mori360 @fegin
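
For reference, a rough sketch of what a fixed loader in verl could look like if it had to stay on torch 2.6.0 (illustrative only, not the actual PR; buffers and key-name mismatches are ignored): broadcast one full tensor at a time and shard it immediately, so full tensors never accumulate on the GPU.

import torch
import torch.distributed as dist
from torch.distributed.tensor import DTensor, distribute_tensor

@torch.no_grad()
def broadcast_and_shard_state_dict(model: torch.nn.Module, full_state: dict) -> None:
    # full_state only needs to be populated on rank 0; other ranks can pass an empty dict.
    device = torch.cuda.current_device()
    for name, param in model.named_parameters():
        if dist.get_rank() == 0:
            full = full_state[name].to(device=device, dtype=param.dtype)
        else:
            full = torch.empty(param.shape, dtype=param.dtype, device=device)
        dist.broadcast(full, src=0)
        if isinstance(param, DTensor):
            # Chunk the full tensor the same way FSDP2 shards the parameter,
            # then copy into the existing local shard.
            sharded = distribute_tensor(full, param.device_mesh, param.placements)
            param.to_local().copy_(sharded.to_local())
        else:
            param.copy_(full)
        del full  # free the full tensor before the next parameter so memory never piles up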

weifengpy avatar Jul 16 '25 00:07 weifengpy

Fixed in https://github.com/volcengine/verl/pull/2606

eric-haibin-lin avatar Jul 23 '25 02:07 eric-haibin-lin