[BUG] OOM when training 70B models using DeepSpeed 0.16.4
We found that with OpenRLHF + DeepSpeed 0.15.0, SFT + Adam offload can train a 70B model on 8x A100 80G GPUs with ZeRO-3, whereas DeepSpeed 0.16.4 results in OOM. You can reproduce the issue with the script https://github.com/OpenRLHF/OpenRLHF/blob/main/examples/scripts/train_sft_llama.sh using the 70B model + Adam offload. This looks like a serious bug: DeepSpeed 0.16.4 cannot train 70B models.
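For reference, the ZeRO-3 + Adam-offload setup described here corresponds roughly to a DeepSpeed config like the following; this is a minimal sketch with illustrative values, not the exact config that train_sft_llama.sh generates:

# Minimal sketch of a ZeRO-3 + Adam CPU-offload config (illustrative values,
# not the exact config produced by the OpenRLHF script):
ds_config = {
    "train_micro_batch_size_per_gpu": 2,
    "gradient_accumulation_steps": 8,
    "bf16": {"enabled": True},
    "gradient_clipping": 1.0,
    "zero_optimization": {
        "stage": 3,
        "offload_optimizer": {"device": "cpu", "pin_memory": True},
        "stage3_param_persistence_threshold": 1e4,
        "stage3_prefetch_bucket_size": 5e8,
        "stage3_max_live_parameters": 1e9,
    },
}

With this layout the optimizer states live in host memory and the parameters/gradients are sharded across the 8 ranks, so the remaining GPU headroom is mostly consumed by activations and the ZeRO-3 gather/reduce buffers.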
Reproduced the issue. In OpenRLHF 0.5.7, after upgrading DeepSpeed from 0.15.0 to 0.16.0, my PPO demo encounters OOM:
File "/root/miniconda3/lib/python3.10/site-packages/openrlhf/utils/deepspeed/deepspeed.py", line 129, in backward
model.backward(loss)
File "/root/miniconda3/lib/python3.10/site-packages/deepspeed/utils/nvtx.py", line 18, in wrapped_fn
ret_val = func(*args, **kwargs)
File "/root/miniconda3/lib/python3.10/site-packages/deepspeed/runtime/engine.py", line 2048, in backward
self.optimizer.backward(loss, retain_graph=retain_graph)
File "/root/miniconda3/lib/python3.10/site-packages/deepspeed/utils/nvtx.py", line 18, in wrapped_fn
ret_val = func(*args, **kwargs)
File "/root/miniconda3/lib/python3.10/site-packages/deepspeed/runtime/zero/stage3.py", line 2263, in backward
self.loss_scaler.backward(loss.float(), retain_graph=retain_graph)
File "/root/miniconda3/lib/python3.10/site-packages/deepspeed/runtime/fp16/loss_scaler.py", line 63, in backward
scaled_loss.backward(retain_graph=retain_graph)
File "/root/miniconda3/lib/python3.10/site-packages/torch/_tensor.py", line 581, in backward
torch.autograd.backward(
File "/root/miniconda3/lib/python3.10/site-packages/torch/autograd/__init__.py", line 347, in backward
_engine_run_backward(
File "/root/miniconda3/lib/python3.10/site-packages/torch/autograd/graph.py", line 825, in _engine_run_backward
return Variable._execution_engine.run_backward( # Calls into the C++ engine to run the backward pass
File "/root/miniconda3/lib/python3.10/site-packages/deepspeed/utils/nvtx.py", line 18, in wrapped_fn
ret_val = func(*args, **kwargs)
File "/root/miniconda3/lib/python3.10/site-packages/deepspeed/runtime/zero/stage3.py", line 1177, in reduce_partition_and_remove_grads
self.reduce_ready_partitions_and_remove_grads(param)
File "/root/miniconda3/lib/python3.10/site-packages/deepspeed/runtime/zero/stage3.py", line 1532, in reduce_ready_partitions_and_remove_grads
self.reduce_independent_p_g_buckets_and_remove_grads(param)
File "/root/miniconda3/lib/python3.10/site-packages/deepspeed/runtime/zero/stage3.py", line 1269, in reduce_independent_p_g_buckets_and_remove_grads
self.__reduce_and_partition_ipg_grads()
File "/root/miniconda3/lib/python3.10/site-packages/deepspeed/utils/nvtx.py", line 18, in wrapped_fn
ret_val = func(*args, **kwargs)
File "/root/miniconda3/lib/python3.10/site-packages/torch/utils/_contextlib.py", line 116, in decorate_context
return func(*args, **kwargs)
File "/root/miniconda3/lib/python3.10/site-packages/deepspeed/runtime/zero/stage3.py", line 1319, in __reduce_and_partition_ipg_grads
grad_partitions = self.__avg_scatter_grads(self.params_in_ipg_bucket)
File "/root/miniconda3/lib/python3.10/site-packages/deepspeed/utils/nvtx.py", line 18, in wrapped_fn
ret_val = func(*args, **kwargs)
File "/root/miniconda3/lib/python3.10/site-packages/deepspeed/runtime/zero/stage3.py", line 1388, in __avg_scatter_grads
grad_partitions_for_rank = reduce_scatter_coalesced(full_grads_for_rank, self.dp_process_group)
File "/root/miniconda3/lib/python3.10/site-packages/deepspeed/utils/nvtx.py", line 18, in wrapped_fn
ret_val = func(*args, **kwargs)
File "/root/miniconda3/lib/python3.10/site-packages/torch/utils/_contextlib.py", line 116, in decorate_context
return func(*args, **kwargs)
File "/root/miniconda3/lib/python3.10/site-packages/deepspeed/runtime/comm/coalesced_collectives.py", line 122, in reduce_scatter_coalesced
tensor_partition_flat_buffer = instrument_w_nvtx(torch.cat)(tensor_partitions_lst_with_padding)
File "/root/miniconda3/lib/python3.10/site-packages/deepspeed/utils/nvtx.py", line 18, in wrapped_fn
ret_val = func(*args, **kwargs)
torch.OutOfMemoryError: CUDA out of memory. Tried to allocate 932.00 MiB. GPU 0 has a total capacity of 79.33 GiB of which 524.44 MiB is free. Process 99710 has 62.04 GiB memory in use. Process 100513 has 16.65 GiB memory in use. Of the allocated memory 59.29 GiB is allocated by PyTorch, and 1.12 GiB is reserved by PyTorch but unallocated. If reserved but unallocated memory is large try setting PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True to avoid fragmentation. See documentation for Memory Management (https://pytorch.org/docs/stable/notes/cuda.html#environment-variables)
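As an aside, the allocator hint in the message above only mitigates fragmentation and does not address the regression discussed below; if set from Python, it has to happen before the first CUDA allocation in the process:

import os
# Must be set before the CUDA caching allocator is initialized
# (i.e. before the first CUDA allocation in the process).
os.environ.setdefault("PYTORCH_CUDA_ALLOC_CONF", "expandable_segments:True")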
This commit introduced the issue: https://github.com/deepspeedai/DeepSpeed/commit/cd20a3bbc7713908d7fb5fd7af4a91d52f126370
Open-Reasoner-Zero identified the same issue: https://github.com/Open-Reasoner-Zero/Open-Reasoner-Zero/issues/13
@loadams
I found that this issue occurs because ds_grads_remaining is not counted correctly during the forward pass when gradient checkpointing is used.
Below is the log with gradient checkpointing enabled:
Thanks @hijkzzz and @thomZ1 for narrowing down the problematic commit.
Tagging @tjruwase and @wenbinc-Bin for the original PR to take a look as well.
@loadams I'd like to work on this.
Any update on this? Lots of people are waiting for this to be resolved so they can upgrade and use the new AutoTP for additional optimization in their training.
Any update on this?
I encountered the same issue during 14B-MoE training. It appears that DeepSpeed versions 0.16.0 and above require more GPU memory.
@loadams Any update on this?
@hijkzzz - Unfortunately I haven't had time to work on this further. @delock - @wenbinc-Bin's PR may be the culprit; could you help take a look too?
Sorry, I don't have time to look at this issue right now; I will look at it when I can. If any of you have time, please feel free to take a look first.
Would it be best to revert the breaking change, since DPO is likely a smaller use case than the set of users affected by this?
I think part of the problem is that the backward hook is registered as a forward pre-hook, so ds_grads_remaining gets cleared during the forward pass:
self.backward_hooks.append(module.register_forward_pre_hook(_post_backward_module_hook))
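To make the interaction concrete, here is a standalone sketch (a toy module, not DeepSpeed code) showing that a forward pre-hook, which is where the quoted line wires the post-backward logic, fires again when gradient checkpointing recomputes the forward during backward, so anything it clears gets reset mid-backward:

# Standalone sketch (toy module, not DeepSpeed code): a forward pre-hook fires
# again when gradient checkpointing recomputes the forward during backward.
import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint

calls = {"pre_forward": 0}

def pre_hook(module, inputs):
    # Stand-in for the quoted _post_backward_module_hook: in DeepSpeed this is
    # where ds_grads_remaining gets cleared.
    calls["pre_forward"] += 1

layer = nn.Linear(8, 8)
layer.register_forward_pre_hook(pre_hook)
x = torch.randn(4, 8, requires_grad=True)

# Plain forward/backward: the pre-hook runs once per step.
layer(x).sum().backward()
print(calls["pre_forward"])  # 1

# Checkpointed forward/backward: the forward is re-executed during backward,
# so the pre-hook (and any counter reset it performs) runs a second time.
calls["pre_forward"] = 0
checkpoint(layer, x, use_reentrant=False).sum().backward()
print(calls["pre_forward"])  # 2

Running this prints 1 for the plain case and 2 for the checkpointed case; it is only a toy reproduction of the hook behavior, since the actual counting in DeepSpeed also involves backward-side hooks.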
Let me talk with Wenbin to understand the original PR. @loadams
I looked at the place where __n_available_params is set to zero. The loop releases params and decreases this variable accordingly. In other words, if the loop did not release all params, setting this variable to zero will make the prefetcher fetch more params than planned. It would help to know the value of this variable right before line 454 when the OOM happens.
If the logic in https://github.com/deepspeedai/DeepSpeed/blob/master/deepspeed/runtime/zero/partitioned_param_coordinator.py#L481 and https://github.com/deepspeedai/DeepSpeed/blob/master/deepspeed/runtime/zero/partitioned_param_coordinator.py#L520 is comprehensive, then I don't see the need to reset __n_available_params in the code. I saw this reset was originally in reset_step and then moved to release_and_reset_all. Is there a reason this reset was added in the first place?
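To illustrate the concern with a toy model of the accounting (hypothetical names, not the real partitioned_param_coordinator code): if __n_available_params is the counter that bounds prefetching, resetting it to zero while parameters are still materialized lets the next prefetch believe it has the full budget again:

# Toy model of the accounting (hypothetical names, not the real coordinator):
MAX_AVAILABLE_NUMEL = 1_000_000   # stands in for the prefetch/live-param budget

class ToyCoordinator:
    def __init__(self):
        self.n_available = 0   # models __n_available_params
        self.live_numel = 0    # parameter memory actually resident on GPU

    def fetch(self, numel):
        # Prefetch is only allowed while the counter stays under the budget.
        if self.n_available + numel <= MAX_AVAILABLE_NUMEL:
            self.n_available += numel
            self.live_numel += numel

    def release(self, numel):
        self.n_available -= numel
        self.live_numel -= numel

    def buggy_reset(self):
        # Resetting the counter without releasing the underlying params makes
        # the budget check optimistic while memory is still held.
        self.n_available = 0

coord = ToyCoordinator()
coord.fetch(900_000)
coord.buggy_reset()
coord.fetch(900_000)        # passes the budget check again...
print(coord.live_numel)     # ...but 1,800,000 numel are now resident -> OOM risk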
One behavior I noticed is that module.ds_grads_remaining doesn't decrement when using gradient checkpointing with use_reentrant=False. With use_reentrant=True, most of the decoder layers have the correct module.ds_grads_remaining value, which properly triggers self.post_sub_module_backward_function; the exception is the nn.CrossEntropyLoss module, which isn't typically wrapped with checkpointing, so its module.ds_grads_remaining value keeps climbing each step.
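For anyone who wants to compare the two modes described above with a Hugging Face model, the reentrant flag is typically selected like this, assuming a transformers version that supports gradient_checkpointing_kwargs (the model name is only an example):

from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-2-7b-hf")
# Switch between the reentrant and non-reentrant torch.utils.checkpoint paths.
model.gradient_checkpointing_enable(
    gradient_checkpointing_kwargs={"use_reentrant": False}  # or True
)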
@loadams I have talked to @wenbinc-Bin offline, and he agreed to roll back his PR first to address this issue. He will then submit a more comprehensive fix for the OOM issue. I'll create a rollback PR as a temporary fix.