
An OOM error occurred while saving at the last step of training a 13B model on A100 (40G) * 8.


A CUDA out-of-memory error occurred while saving at the last step of training a 13B model on A100 (40G) * 8. Can someone help me solve this?

Training Details

  • Hardware devices: GPU A100 (40G) * 4 in one node.
  • Technologies:
    • FSDP: full_shard + auto_wrap + offload
    • Flash-Attention
    • BF16 + TF32
    • Batch size per GPU: 2
  • Dataset: dummy.json, 910 Q&A pairs.

Description

When offload is not used, an OOM error is reported at the beginning of training. Following @merrymercy's reply in #346, after adding --fsdp "full_shard auto_wrap offload" to the script, training itself runs well.

The training script is as follows:

torchrun --nnodes=1 --nproc_per_node=8 --master_port=1993 \
    fastchat/train/train_mem.py \
    --model_name_or_path /data/vicuna/vicuna-13b \
    --data_path /data/dummy.json \
    --bf16 True \
    --output_dir ./outputs \
    --num_train_epochs 1 \
    --per_device_train_batch_size 2 \
    --per_device_eval_batch_size 2 \
    --gradient_accumulation_steps 16 \
    --evaluation_strategy "no" \
    --save_strategy "steps" \
    --save_steps 1200 \
    --save_total_limit 8 \
    --learning_rate 2e-5 \
    --weight_decay 0. \
    --warmup_ratio 0.03 \
    --lr_scheduler_type "cosine" \
    --logging_steps 1 \
    --fsdp "full_shard auto_wrap offload" \
    --fsdp_transformer_layer_cls_to_wrap 'LlamaDecoderLayer' \
    --tf32 True \
    --model_max_length 2048 \
    --gradient_checkpointing True \
    --lazy_preprocess True

However, the following error was reported at the end of training:

OutOfMemoryError: CUDA out of memory. Tried to allocate 1.18 GiB (GPU 2; 39.45 GiB total capacity; 36.87 GiB already allocated; 82.25 MiB free; 38.00 GiB reserved in total by PyTorch) If reserved memory is >> allocated memory try setting max_split_size_mb to avoid fragmentation. See documentation for Memory Management and PYTORCH_CUDA_ALLOC_CONF
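As the error text itself suggests, one low-effort thing to try (a hedged sketch; the value 128 is an arbitrary example, and this only mitigates allocator fragmentation, not the full-model unsharding discussed below) is to set the allocator option before any CUDA allocation, e.g. near the top of the entry script or exported in the shell before torchrun:

import os

# Must be set before the first CUDA allocation in the process;
# equivalently: export PYTORCH_CUDA_ALLOC_CONF=max_split_size_mb:128
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "max_split_size_mb:128"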

The traceback is as follows:


╭─────────────────────────────── Traceback (most recent call last) ────────────────────────────────╮
│ /ihepbatch/cc/zdzhang/VSProjects/hepai-scientist/HaiST/scripts/../repos/FastChat/fastchat/train/ │
│ train_mem.py:14 in <module>                                                                      │
│                                                                                                  │
│   11 from fastchat.train.train import train                                                      │
│   12                                                                                             │
│   13 if __name__ == "__main__":                                                                  │
│ ❱ 14 │   train()                                                                                 │
│   15                                                                                             │
│                                                                                                  │
│ /ihepbatch/cc/zdzhang/VSProjects/hepai-scientist/HaiST/repos/FastChat/fastchat/train/train.py:28 │
│ 2 in train                                                                                       │
│                                                                                                  │
│   279 │   else:                                                                                  │
│   280 │   │   trainer.train()                                                                    │
│   281 │   trainer.save_state()                                                                   │
│ ❱ 282 │   safe_save_model_for_hf_trainer(trainer=trainer, output_dir=training_args.output_dir)   │
│   283                                                                                            │
│   284                                                                                            │
│   285 if __name__ == "__main__":                                                                 │
│                                                                                                  │
│ /ihepbatch/cc/zdzhang/VSProjects/hepai-scientist/HaiST/repos/FastChat/fastchat/train/train.py:69 │
│ in safe_save_model_for_hf_trainer                                                                │
│                                                                                                  │
│    66                                                                                            │
│    67 def safe_save_model_for_hf_trainer(trainer: transformers.Trainer, output_dir: str):        │
│    68 │   """Collects the state dict and dump to disk."""                                        │
│ ❱  69 │   state_dict = trainer.model.state_dict()                                                │
│    70 │   if trainer.args.should_save:                                                           │
│    71 │   │   cpu_state_dict = {key: value.cpu() for key, value in state_dict.items()}           │
│    72 │   │   del state_dict                                                                     │
│                                                                                                  │
│ /ihepbatch/cc/zdzhang/anaconda3/envs/haist/lib/python3.10/site-packages/torch/nn/modules/module. │
│ py:1818 in state_dict                                                                            │
│                                                                                                  │
│   1815 │   │   self._save_to_state_dict(destination, prefix, keep_vars)                          │
│   1816 │   │   for name, module in self._modules.items():                                        │
│   1817 │   │   │   if module is not None:                                                        │
│ ❱ 1818 │   │   │   │   module.state_dict(destination=destination, prefix=prefix + name + '.', k  │
│   1819 │   │   for hook in self._state_dict_hooks.values():                                      │
│   1820 │   │   │   hook_result = hook(self, destination, prefix, local_metadata)                 │
│   1821 │   │   │   if hook_result is not None:                                                   │
│                                                                                                  │
│ /ihepbatch/cc/zdzhang/anaconda3/envs/haist/lib/python3.10/site-packages/torch/nn/modules/module. │
│ py:1818 in state_dict                                                                            │
│                                                                                                  │
│   1815 │   │   self._save_to_state_dict(destination, prefix, keep_vars)                          │
│   1816 │   │   for name, module in self._modules.items():                                        │
│   1817 │   │   │   if module is not None:                                                        │
│ ❱ 1818 │   │   │   │   module.state_dict(destination=destination, prefix=prefix + name + '.', k  │
│   1819 │   │   for hook in self._state_dict_hooks.values():                                      │
│   1820 │   │   │   hook_result = hook(self, destination, prefix, local_metadata)                 │
│   1821 │   │   │   if hook_result is not None:                                                   │
│                                                                                                  │
│ /ihepbatch/cc/zdzhang/anaconda3/envs/haist/lib/python3.10/site-packages/torch/nn/modules/module. │
│ py:1818 in state_dict                                                                            │
│                                                                                                  │
│   1815 │   │   self._save_to_state_dict(destination, prefix, keep_vars)                          │
│   1816 │   │   for name, module in self._modules.items():                                        │
│   1817 │   │   │   if module is not None:                                                        │
│ ❱ 1818 │   │   │   │   module.state_dict(destination=destination, prefix=prefix + name + '.', k  │
│   1819 │   │   for hook in self._state_dict_hooks.values():                                      │
│   1820 │   │   │   hook_result = hook(self, destination, prefix, local_metadata)                 │
│   1821 │   │   │   if hook_result is not None:                                                   │
│                                                                                                  │
│ /ihepbatch/cc/zdzhang/anaconda3/envs/haist/lib/python3.10/site-packages/torch/nn/modules/module. │
│ py:1818 in state_dict                                                                            │
│                                                                                                  │
│   1815 │   │   self._save_to_state_dict(destination, prefix, keep_vars)                          │
│   1816 │   │   for name, module in self._modules.items():                                        │
│   1817 │   │   │   if module is not None:                                                        │
│ ❱ 1818 │   │   │   │   module.state_dict(destination=destination, prefix=prefix + name + '.', k  │
│   1819 │   │   for hook in self._state_dict_hooks.values():                                      │
│   1820 │   │   │   hook_result = hook(self, destination, prefix, local_metadata)                 │
│   1821 │   │   │   if hook_result is not None:                                                   │
│                                                                                                  │
│ /ihepbatch/cc/zdzhang/anaconda3/envs/haist/lib/python3.10/site-packages/torch/nn/modules/module. │
│ py:1815 in state_dict                                                                            │
│                                                                                                  │
│   1812 │   │   if hasattr(destination, "_metadata"):                                             │
│   1813 │   │   │   destination._metadata[prefix[:-1]] = local_metadata                           │
│   1814 │   │                                                                                     │
│ ❱ 1815 │   │   self._save_to_state_dict(destination, prefix, keep_vars)                          │
│   1816 │   │   for name, module in self._modules.items():                                        │
│   1817 │   │   │   if module is not None:                                                        │
│   1818 │   │   │   │   module.state_dict(destination=destination, prefix=prefix + name + '.', k  │
│                                                                                                  │
│ /ihepbatch/cc/zdzhang/anaconda3/envs/haist/lib/python3.10/site-packages/torch/nn/modules/module. │
│ py:1722 in _save_to_state_dict                                                                   │
│                                                                                                  │
│   1719 │   │   │   │   module                                                                    │
│   1720 │   │   """                                                                               │
│   1721 │   │   for hook in self._state_dict_pre_hooks.values():                                  │
│ ❱ 1722 │   │   │   hook(self, prefix, keep_vars)                                                 │
│   1723 │   │                                                                                     │
│   1724 │   │   for name, param in self._parameters.items():                                      │
│   1725 │   │   │   if param is not None:                                                         │
│                                                                                                  │
│ /ihepbatch/cc/zdzhang/anaconda3/envs/haist/lib/python3.10/site-packages/torch/utils/_contextlib. │
│ py:115 in decorate_context                                                                       │
│                                                                                                  │
│   112 │   @functools.wraps(func)                                                                 │
│   113 │   def decorate_context(*args, **kwargs):                                                 │
│   114 │   │   with ctx_factory():                                                                │
│ ❱ 115 │   │   │   return func(*args, **kwargs)                                                   │
│   116 │                                                                                          │
│   117 │   return decorate_context                                                                │
│   118                                                                                            │
│                                                                                                  │
│ /ihepbatch/cc/zdzhang/anaconda3/envs/haist/lib/python3.10/site-packages/torch/distributed/fsdp/_ │
│ state_dict_utils.py:669 in _pre_state_dict_hook                                                  │
│                                                                                                  │
│   666 │   │   StateDictType.LOCAL_STATE_DICT: _local_pre_state_dict_hook,                        │
│   667 │   │   StateDictType.SHARDED_STATE_DICT: _sharded_pre_state_dict_hook,                    │
│   668 │   }                                                                                      │
│ ❱ 669 │   _pre_state_dict_hook_fn[fsdp_state._state_dict_type](                                  │
│   670 │   │   fsdp_state,                                                                        │
│   671 │   │   module,                                                                            │
│   672 │   │   *args,                                                                             │
│                                                                                                  │
│ /ihepbatch/cc/zdzhang/anaconda3/envs/haist/lib/python3.10/site-packages/torch/distributed/fsdp/_ │
│ state_dict_utils.py:271 in _full_pre_state_dict_hook                                             │
│                                                                                                  │
│   268 │   in ``nn.Module``.                                                                      │
│   269 │   """                                                                                    │
│   270 │   _common_pre_state_dict_hook(module, fsdp_state)                                        │
│ ❱ 271 │   _common_unshard_pre_state_dict_hook(                                                   │
│   272 │   │   module,                                                                            │
│   273 │   │   fsdp_state,                                                                        │
│   274 │   │   offload_to_cpu=fsdp_state._state_dict_config.offload_to_cpu,                       │
│                                                                                                  │
│ /ihepbatch/cc/zdzhang/anaconda3/envs/haist/lib/python3.10/site-packages/torch/distributed/fsdp/_ │
│ state_dict_utils.py:143 in _common_unshard_pre_state_dict_hook                                   │
│                                                                                                  │
│   140 │   Performs the pre-state_dict tasks shared by all state_dict types that require          │
│   141 │   ``_unshard_fsdp_state_params()``. FULL_STATE_DICT and SHARDED_STATE_DICT use this ho   │
│   142 │   """                                                                                    │
│ ❱ 143 │   _enter_unshard_params_ctx(                                                             │
│   144 │   │   module,                                                                            │
│   145 │   │   fsdp_state,                                                                        │
│   146 │   │   writeback=False,                                                                   │
│                                                                                                  │
│ /ihepbatch/cc/zdzhang/anaconda3/envs/haist/lib/python3.10/site-packages/torch/distributed/fsdp/_ │
│ state_dict_utils.py:109 in _enter_unshard_params_ctx                                             │
│                                                                                                  │
│   106 │   │   offload_to_cpu=offload_to_cpu,                                                     │
│   107 │   │   with_grads=with_grads,                                                             │
│   108 │   )                                                                                      │
│ ❱ 109 │   fsdp_state._unshard_params_ctx[module].__enter__()                                     │
│   110                                                                                            │
│   111                                                                                            │
│   112 @no_type_check                                                                             │
│                                                                                                  │
│ /ihepbatch/cc/zdzhang/anaconda3/envs/haist/lib/python3.10/contextlib.py:135 in __enter__         │
│                                                                                                  │
│   132 │   │   # they are only needed for recreation, which is not possible anymore               │
│   133 │   │   del self.args, self.kwds, self.func                                                │
│   134 │   │   try:                                                                               │
│ ❱ 135 │   │   │   return next(self.gen)                                                          │
│   136 │   │   except StopIteration:                                                              │
│   137 │   │   │   raise RuntimeError("generator didn't yield") from None                         │
│   138                                                                                            │
│                                                                                                  │
│ /ihepbatch/cc/zdzhang/anaconda3/envs/haist/lib/python3.10/site-packages/torch/distributed/fsdp/_ │
│ unshard_param_utils.py:198 in _unshard_fsdp_state_params                                         │
│                                                                                                  │
│   195 │   # No need to call `wait_stream()` since we unshard in the computation                  │
│   196 │   # stream directly                                                                      │
│   197 │   computation_stream = torch.cuda.current_stream()                                       │
│ ❱ 198 │   _unshard(state, handles, computation_stream, computation_stream)                       │
│   199 │   if with_grads:                                                                         │
│   200 │   │   _unshard_grads(handles)                                                            │
│   201                                                                                            │
│                                                                                                  │
│ /ihepbatch/cc/zdzhang/anaconda3/envs/haist/lib/python3.10/site-packages/torch/distributed/fsdp/_ │
│ runtime_utils.py:329 in _unshard                                                                 │
│                                                                                                  │
│    326 │   │   │   event.synchronize()                                                           │
│    327 │   with torch.cuda.stream(unshard_stream):                                               │
│    328 │   │   for handle in handles:                                                            │
│ ❱  329 │   │   │   handle.unshard()                                                              │
│    330 │   │   │   handle.post_unshard()                                                         │
│    331                                                                                           │
│    332                                                                                           │
│                                                                                                  │
│ /ihepbatch/cc/zdzhang/anaconda3/envs/haist/lib/python3.10/site-packages/torch/distributed/fsdp/f │
│ lat_param.py:918 in unshard                                                                      │
│                                                                                                  │
│    915 │   │   │   )                                                                             │
│    916 │   │   │   self._use_unsharded_flat_param(unsharded_flat_param)                          │
│    917 │   │   │   return                                                                        │
│ ❱  918 │   │   unsharded_flat_param = self._alloc_padded_unsharded_flat_param()                  │
│    919 │   │   padded_unsharded_flat_param = self._all_gather_flat_param(unsharded_flat_param)   │
│    920 │   │   self._use_unsharded_flat_param(padded_unsharded_flat_param)                       │
│    921                                                                                           │
│                                                                                                  │
│ /ihepbatch/cc/zdzhang/anaconda3/envs/haist/lib/python3.10/site-packages/torch/distributed/fsdp/f │
│ lat_param.py:944 in _alloc_padded_unsharded_flat_param                                           │
│                                                                                                  │
│    941 │   │   flat_param = self.flat_param                                                      │
│    942 │   │   unsharded_flat_param = self._get_padded_unsharded_flat_param()                    │
│    943 │   │   self._check_storage_freed(unsharded_flat_param)                                   │
│ ❱  944 │   │   _alloc_storage(unsharded_flat_param, flat_param._padded_unsharded_size)  # type:  │
│    945 │   │   return unsharded_flat_param                                                       │
│    946 │                                                                                         │
│    947 │   def _get_padded_unsharded_flat_param(self) -> torch.Tensor:                           │
│                                                                                                  │
│ /ihepbatch/cc/zdzhang/anaconda3/envs/haist/lib/python3.10/site-packages/torch/utils/_contextlib. │
│ py:115 in decorate_context                                                                       │
│                                                                                                  │
│   112 │   @functools.wraps(func)                                                                 │
│   113 │   def decorate_context(*args, **kwargs):                                                 │
│   114 │   │   with ctx_factory():                                                                │
│ ❱ 115 │   │   │   return func(*args, **kwargs)                                                   │
│   116 │                                                                                          │
│   117 │   return decorate_context                                                                │
│   118                                                                                            │
│                                                                                                  │
│ /ihepbatch/cc/zdzhang/anaconda3/envs/haist/lib/python3.10/site-packages/torch/distributed/fsdp/_ │
│ utils.py:79 in _alloc_storage                                                                    │
│                                                                                                  │
│    76 │   │   │   tensor_storage_size == 0,                                                      │
│    77 │   │   │   f"Tensor storage should have been resized to be 0 but got {tensor_storage_si   │
│    78 │   │   )                                                                                  │
│ ❱  79 │   │   tensor._typed_storage()._resize_(size.numel())                                     │
│    80 │   return not already_allocated                                                           │
│    81                                                                                            │
│    82                                                                                            │
│                                                                                                  │
│ /ihepbatch/cc/zdzhang/anaconda3/envs/haist/lib/python3.10/site-packages/torch/storage.py:764 in  │
│ _resize_                                                                                         │
│                                                                                                  │
│    761 │                                                                                         │
│    762 │   # For internal use only, to avoid deprecation warning                                 │
│    763 │   def _resize_(self, size):                                                             │
│ ❱  764 │   │   self._untyped_storage.resize_(size * self._element_size())                        │
│    765 │                                                                                         │
│    766 │   @classmethod                                                                          │
│    767 │   def _free_weak_ref(cls, *args, **kwargs):                                             │
╰──────────────────────────────────────────────────────────────────────────────────────────────────╯

Error Analysis:

According to the traceback, the call state_dict = trainer.model.state_dict() retrieves a dictionary containing the whole state of the model. The returned object is a shallow copy.

My guess at the possible reason: although the shallow copy only holds references to the same memory as the original object, which should save memory, FSDP still has to all-gather (unshard) the full parameters on the GPU to build the state dict (see the _unshard / _alloc_storage frames above), and the GPU is already too full to allocate even that, resulting in the error.
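For reference, a possible mitigation (a minimal sketch I have not verified on this setup; the helper name save_fsdp_model is made up, and it assumes the existing FastChat helper ultimately hands the dict to Transformers' Trainer._save) is to gather the full state dict through FSDP's FullStateDictConfig with CPU offload, so the assembled weights land in host memory instead of accumulating on the GPU:

from torch.distributed.fsdp import (
    FullyShardedDataParallel as FSDP,
    StateDictType,
    FullStateDictConfig,
)


def save_fsdp_model(trainer, output_dir: str):
    # Offload each gathered shard to CPU and materialize the dict only on
    # rank 0, which is the rank that actually writes the checkpoint.
    cfg = FullStateDictConfig(offload_to_cpu=True, rank0_only=True)
    with FSDP.state_dict_type(trainer.model, StateDictType.FULL_STATE_DICT, cfg):
        cpu_state_dict = trainer.model.state_dict()  # tensors already on CPU
    if trainer.args.should_save:
        # Assumption: reuse the Trainer's own saving path with the gathered dict.
        trainer._save(output_dir, state_dict=cpu_state_dict)

If I understand FSDP's behavior correctly, without offload_to_cpu the gathered tensors stay on the GPU, so the whole 13B model accumulates there; with it, each auto-wrapped LlamaDecoderLayer is unsharded only briefly before being copied to CPU.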

Is there any expert guidance on how to handle this issue? Thank you.

zhangzhengde0225 · May 02 '23

@zhangzhengde0225 I worked around the issue by using 8 * A100 (80G). I saw the same behavior as you: the training process went smoothly, but the error happened while persisting the model weights. Check this issue for more details: https://github.com/lm-sys/FastChat/issues/256

Jeffwan · May 02 '23

See also https://github.com/pytorch/pytorch/issues/98823

merrymercy · May 05 '23