
OOM when saving the model; it seems the FSDP(model, ...) wrapping causes the OOM.

Open cyrishe opened this issue 1 year ago • 8 comments

System Info

8x A100-40G, Ubuntu 20.04, peft 0.4.0, torch 2.2.0.dev20231012+cu118

Information

  • [X] The official example scripts
  • [ ] My own modified scripts

🐛 Describe the bug

After training 34B CodeLlama with LoRA (the training itself goes well), OOM occurs when calling save_pretrained. I tried different PEFT versions, from v0.3.0 up to source, and none of them work. It is said that PEFT 0.2.0 works, but llama-recipes can't work with PEFT 0.2.0.

I modified the code and added save_pretrained just after the model is loaded and again after wrapping. It shows that the OOM occurs after the FSDP wrapping.

    mixed_precision_policy, wrapping_policy = get_policies(fsdp_config, rank)
    my_auto_wrapping_policy = fsdp_auto_wrap_policy(model, LlamaDecoderLayer)
    ### add save model before FSDP
    print("before FSDP save model")
    model.save_pretrained("./test_pretrained1")
    model = FSDP(
        model,
        auto_wrap_policy= my_auto_wrapping_policy if train_config.use_peft else wrapping_policy,
        cpu_offload=CPUOffload(offload_params=True) if fsdp_config.fsdp_cpu_offload else None,
        mixed_precision=mixed_precision_policy if not fsdp_config.pure_bf16 else None,
        sharding_strategy=fsdp_config.sharding_strategy,
        device_id=torch.cuda.current_device(),
        limit_all_gathers=True,
        #add use_orig_params=True for adapter#
        #use_orig_params=True,
        sync_module_states=train_config.low_cpu_fsdp,
        param_init_fn=lambda module: module.to_empty(device=torch.device("cuda"), recurse=False)
        if train_config.low_cpu_fsdp and rank != 0 else None,
    )
    ### add save model after FSDP
    print("after FSDP save model")
    model.save_pretrained("./test_pretrained2")

Error logs

before FSDP save model
trainable params: 19,660,800 || all params: 33,763,631,104 || trainable%: 0.05823070373989121
before FSDP save model
trainable params: 19,660,800 || all params: 33,763,631,104 || trainable%: 0.05823070373989121
before FSDP save model
trainable params: 19,660,800 || all params: 33,763,631,104 || trainable%: 0.05823070373989121
before FSDP save model
Loading checkpoint shards: 100%|██████████| 7/7 [00:52<00:00, 7.49s/it]
--> Model /mnt/shhg01/cyris/model_hub/codellama34B_chat/

--> /mnt/shhg01/cyris/model_hub/codellama34B_chat/ has 33743.970304 Million params

paras distr: 435 0
LoraConfig(peft_type=<PeftType.LORA: 'LORA'>, auto_mapping=None, base_model_name_or_path=None, revision=None, task_type='CAUSAL_LM', inference_mode=False, r=16, target_modules=['q_proj', 'v_proj'], lora_alpha=32, lora_dropout=0.05, fan_in_fan_out=False, bias='none', modules_to_save=None, init_lora_weights=True, layers_to_transform=None, layers_pattern=None)
trainable params: 19,660,800 || all params: 33,763,631,104 || trainable%: 0.05823070373989121
bFloat16 enabled for mixed precision - using bfSixteen policy
before FSDP save model
after FSDP save model
after FSDP save model
Traceback (most recent call last):
  File "/mnt/shhg01/cyris/univista-llm-train/ft/lora/finetuning.py", line 8, in <module>
    fire.Fire(main)
  File "/mnt/shhg01/env/llama-recipes/lib/python3.10/site-packages/fire/core.py", line 141, in Fire
    component_trace = _Fire(component, args, parsed_flag_args, context, name)
  File "/mnt/shhg01/env/llama-recipes/lib/python3.10/site-packages/fire/core.py", line 475, in _Fire
    component, remaining_args = _CallAndUpdateTrace(
  File "/mnt/shhg01/env/llama-recipes/lib/python3.10/site-packages/fire/core.py", line 691, in _CallAndUpdateTrace
    component = fn(*varargs, **kwargs)
  File "/mnt/shhg01/cyris/univista-llm-train/ft/lora/llama_recipes/finetuning.py", line 204, in main
    model.save_pretrained("./test_pretrained2")

... ... torch.cuda.OutOfMemoryError: CUDA out of memory. Tried to allocate 2.00 MiB. GPU 6 has a total capacty of 39.59 GiB of which 1.12 MiB is free. Process 3377172 has 39.58 GiB memory in use. Of the allocated memory 38.18 GiB is allocated by PyTorch, and 40.47 MiB is reserved by PyTorch but unallocated. If reserved but unallocated memory is large try setting max_split_size_mb to avoid fragmentation. See documentation for Memory Management and PYTORCH_CUDA_ALLOC_CONF Exception raised from malloc at ../c10/cuda/CUDACachingAllocator.cpp:1114 (most recent call first):

Expected behavior

I used llama-recipes to finetune 7B and 13B models and everything went well. But OOM occurs when finetuning 34B, and not during training but during saving. Hope we can fix this problem.

cyrishe avatar Oct 23 '23 08:10 cyrishe

@cyrishe from the code you shared above, it seems you are using FSDP only, not with PEFT. It also seems you are trying to save an FSDP-wrapped model with model.save_pretrained(). FSDP checkpoint saving is different; you can find it here: https://github.com/facebookresearch/llama-recipes/blob/main/src/llama_recipes/utils/train_utils.py#L155-L174
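
For reference, here is a minimal sketch of what that FSDP checkpoint path does (simplified; the actual helper in llama-recipes has a different signature, and the function name and output path here are only illustrative): the full state dict is gathered to CPU on rank 0 only, then written from rank 0.

    # Minimal sketch of FSDP full-state-dict saving, simplified from the linked
    # train_utils.py (helper name and output path here are illustrative only).
    import torch
    import torch.distributed as dist
    from torch.distributed.fsdp import (
        FullyShardedDataParallel as FSDP,
        FullStateDictConfig,
        StateDictType,
    )

    def save_fsdp_full_checkpoint(model, rank, save_path="model_checkpoint.pt"):
        # Gather the unsharded state dict onto CPU, and only on rank 0, so the
        # other ranks never hold a full copy of the weights in GPU memory.
        cfg = FullStateDictConfig(offload_to_cpu=True, rank0_only=True)
        with FSDP.state_dict_type(model, StateDictType.FULL_STATE_DICT, cfg):
            cpu_state = model.state_dict()  # collective: every rank must call it
        if rank == 0:
            torch.save(cpu_state, save_path)
        dist.barrier()  # keep ranks in sync before training continues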

HamidShojanazeri avatar Oct 29 '23 19:10 HamidShojanazeri

@HamidShojanazeri Thanks for the information. I think I am using both FSDP and PEFT for the training task. The training arguments are as follows:

    CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 torchrun --nnodes 1 --nproc_per_node 8 finetuning.py \
        --enable_fsdp --pure_bf16 --low_cpu_fsdp --use_peft \
        --peft_method lora \
        --batch_size_training 1 --gradient_accumulation_steps 4 --num_workers_dataloader 4 \
        --num_epochs 2 --model_name ${local_path_of_the_weights} \
        --dist_checkpoint_root_folder model_checkpoints --dist_checkpoint_folder fine-tuned \
        --output_dir ${output}

I successfully trained and saved a LoRA adapter for a 7B model using the same code, but OOM occurred when I did the same with the 34B model. Later this week I will train the 7B model again and record the GPU status while saving the adapter, to see whether there is a spike in memory usage.
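
A minimal sketch for recording that (the helper name and its placement in the training script are illustrative, not part of llama-recipes):

    # Log peak GPU memory around a block of code, e.g. the adapter save.
    import contextlib
    import torch

    @contextlib.contextmanager
    def log_peak_gpu_memory(tag):
        torch.cuda.reset_peak_memory_stats()
        start = torch.cuda.memory_allocated() / 2**30
        try:
            yield
        finally:
            peak = torch.cuda.max_memory_allocated() / 2**30
            print(f"[{tag}] allocated before: {start:.2f} GiB, peak during: {peak:.2f} GiB")

    # usage inside the finetuning script:
    #   with log_peak_gpu_memory("adapter save"):
    #       model.save_pretrained(train_config.output_dir)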

cyrishe avatar Oct 30 '23 01:10 cyrishe

Hey @cyrishe ,

Did you manage to solve this issue? I am encountering the same problem while trying to finetune 70B with FSDP and PEFT on 5x A100 80GB. I successfully trained the 7B model before with the same parameters and scripts.

Mugheeera avatar Mar 18 '24 10:03 Mugheeera

@Mugheeera this sounds like a transformers bug; I need to do more investigation, but to unblock you (similar to this issue):

I could repro this issue with the transformers 4.38.1 release from pip; installing from source (transformers 4.39.0.dev0) resolves it:

    git clone https://github.com/huggingface/transformers.git
    cd transformers/
    pip install -e .

Can you please give it a try?
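
A quick check that the training environment actually picks up the source install (it should report the 4.39.0.dev0 build mentioned above):

    # Confirm which transformers build is being imported.
    import transformers
    print(transformers.__version__, transformers.__file__)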

HamidShojanazeri avatar Mar 18 '24 15:03 HamidShojanazeri

Hi @HamidShojanazeri, I also got the same OOM error when using 8x H100. I'm using transformers 4.41.0.dev0 and built llama-recipes from source. I observed that this error happens when training the 70B model (in my case, Llama 3 70B) and does not happen for Llama 3 8B. I'm wondering why model.save_pretrained is used when using both PEFT and FSDP, instead of save_model_and_optimizer_sharded as is done when not using PEFT? (https://github.com/meta-llama/llama-recipes/blob/c1f8de216740712eb0ed21f092c9fb9f47fd74ed/src/llama_recipes/utils/train_utils.py#L228-L259)

Could the reason for the OOM be that the model needs to gather the weights across ranks before model.save_pretrained?

yueyugua avatar May 13 '24 14:05 yueyugua

@yueyugua the reason for that is that if you are finetuning with LoRA, you only need to save the LoRA checkpoints, which is doable with model.save_pretrained(train_config.output_dir); otherwise, save_model_checkpoint will try to save the full FSDP checkpoints, which is where it runs into the OOM issue.
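
In case it helps, a sketch of an adapter-only save that avoids holding the fully gathered weights in GPU memory (this is not the exact llama-recipes code: it assumes model is the FSDP-wrapped PeftModel, that rank and train_config come from the surrounding training loop, and that your PEFT version's save_pretrained accepts a state_dict keyword, which recent versions do):

    # Gather the full state dict to CPU on rank 0 only, then let save_pretrained
    # keep just the LoRA adapter weights out of it.
    from torch.distributed.fsdp import (
        FullyShardedDataParallel as FSDP,
        FullStateDictConfig,
        StateDictType,
    )

    save_policy = FullStateDictConfig(offload_to_cpu=True, rank0_only=True)
    with FSDP.state_dict_type(model, StateDictType.FULL_STATE_DICT, save_policy):
        cpu_state = model.state_dict()  # collective call: run it on every rank
    if rank == 0:
        # save_pretrained filters the LoRA weights out of the full state dict
        model.save_pretrained(train_config.output_dir, state_dict=cpu_state)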

HamidShojanazeri avatar May 15 '24 20:05 HamidShojanazeri

Hi @HamidShojanazeri, thank you for the reply. I was using LoRA to finetune the 70B Llama 3 model and ran into this OOM issue, which is why I thought this might be the root cause.

If LoRA takes too much memory, would QLoRA be the right way to go? However, I found that QLoRA is not supported per this ticket https://github.com/meta-llama/llama-recipes/issues/240, and I'm wondering whether there is a plan to support it?

yueyugua avatar May 17 '24 20:05 yueyugua

Hi @yueyugua and @HamidShojanazeri, I'm running into the same issue. I'm finetuning Llama 3 70B with FSDP and PEFT LoRA on 12x H100 and run into OOM (torch.OutOfMemoryError: CUDA out of memory). The error occurs after the first epoch when trying to save the model via state_dict = model.state_dict().

My command:

    srun torchrun --nnodes=3 --nproc_per_node=4 --rdzv_backend=c10d --rdzv_endpoint=$RDZV_HOST:$RDZV_PORT finetuning.py \
        --model_name "Meta-Llama-3-70B" --dataset custom_dataset --custom_dataset.file my_data.py \
        --low_cpu_fsdp --fsdp_config.pure_bf16 --batch_size_training 1 --enable_fsdp True \
        --fsdp_config.checkpoint_type StateDictType.FULL_STATE_DICT --dist_checkpoint_root_folder \
        --use_peft True --output_dir "model" --save_model True --num_epochs 1 --peft_method lora \
        --use_fp16 --gradient_accumulation_steps 8 --fsdp_config.fsdp_cpu_offload True

Details about the error:

    [rank4]:     state_dict = model.state_dict()
    [rank4]:                  ^^^^^^^^^^^^^^^^^^
    [rank4]:   File "/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1915, in state_dict
    [rank4]:     module.state_dict(destination=destination, prefix=prefix + name + '.', keep_vars=keep_vars)
    [rank4]:   File "/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1915, in state_dict
    [rank4]:     module.state_dict(destination=destination, prefix=prefix + name + '.', keep_vars=keep_vars)
    [rank4]:   File "/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1915, in state_dict
    [rank4]:     module.state_dict(destination=destination, prefix=prefix + name + '.', keep_vars=keep_vars)
    [rank4]:   [Previous line repeated 2 more times]
    [rank4]:   File "/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1911, in state_dict
    [rank4]:     hook(self, prefix, keep_vars)
    [rank4]:   File "/lib/python3.11/site-packages/torch/utils/_contextlib.py", line 115, in decorate_context
    [rank4]:     return func(*args,
    [rank4]:            ^^^^^^^^^^^^^^^^^^^^^
    [rank4]:   File "/lib/python3.11/site-packages/torch/distributed/fsdp/_state_dict_utils.py", line 786, in _pre_state_dict_hook
    [rank4]:     _pre_state_dict_hook_fn[fsdp_state._state_dict_type](
    [rank4]:   File "/lib/python3.11/site-packages/torch/distributed/fsdp/_state_dict_utils.py", line 307, in _full_pre_state_dict_hook
    [rank4]:     _common_unshard_pre_state_dict_hook(
    [rank4]:   File "/lib/python3.11/site-packages/torch/distributed/fsdp/_state_dict_utils.py", line 174, in _common_unshard_pre_state_dict_hook
    [rank4]:     _enter_unshard_params_ctx(
    [rank4]:   File "/lib/python3.11/site-packages/torch/distributed/fsdp/_state_dict_utils.py", line 138, in _enter_unshard_params_ctx
    [rank4]:     fsdp_state._unshard_params_ctx[module].__enter__()
    [rank4]:   File "/apps/miniconda3/23.5.2/lib/python3.11/contextlib.py", line 137, in __enter__
    [rank4]:     return next(self.gen)
    [rank4]:            ^^^^^^^^^^^^^^
    [rank4]:   File "/lib/python3.11/site-packages/torch/distributed/fsdp/_unshard_param_utils.py", line 196, in _unshard_fsdp_state_params
    [rank4]:     _unshard(state, handle, computation_stream, computation_stream)
    [rank4]:   File "/lib/python3.11/site-packages/torch/distributed/fsdp/_runtime_utils.py", line 299, in _unshard
    [rank4]:     handle.unshard()
    [rank4]:   File "/lib/python3.11/site-packages/torch/distributed/fsdp/_flat_param.py", line 1309, in unshard
    [rank4]:     unsharded_flat_param = self._alloc_padded_unsharded_flat_param()
    [rank4]:                            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
    [rank4]:   File "/lib/python3.11/site-packages/torch/distributed/fsdp/_flat_param.py", line 1336, in _alloc_padded_unsharded_flat_param
    [rank4]:     _alloc_storage(unsharded_flat_param, flat_param._padded_unsharded_size)  # type: ignore[attr-defined]
    [rank4]:     ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
    [rank4]:   File "/lib/python3.11/site-packages/torch/distributed/utils.py", line 168, in _alloc_storage
    [rank4]:     tensor._typed_storage()._resize_(size.numel())
    [rank4]:   File "/lib/python3.11/site-packages/torch/storage.py", line 989, in _resize_
    [rank4]:     self._untyped_storage.resize_(size * self._element_size())
    [rank4]: torch.OutOfMemoryError: CUDA out of memory. Tried to allocate 1.60 GiB. GPU 0 has a total capacity of 93.23 GiB of which 190.56 MiB is free. Including non-PyTorch memory, this process has 92.95 GiB memory in use. Of the allocated memory 90.26 GiB is allocated by PyTorch, and 1.60 GiB is reserved by PyTorch but unallocated.
    If reserved but unallocated memory is large try setting PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True to avoid fragmentation. See documentation for Memory Management (https://pytorch.org/docs/stable/notes/cuda.html#environment-variables)

Installed modules:

Package                  Version
------------------------ ------------------------
absl-py                  2.1.0
accelerate               0.30.1
aiohttp                  3.9.5
aiosignal                1.3.1
asttokens                2.4.1
attrs                    23.2.0
bert-score               0.3.13
certifi                  2024.2.2
charset-normalizer       3.3.2
click                    8.1.7
comm                     0.2.2
contourpy                1.2.1
cycler                   0.12.1
datasets                 2.19.1
debugpy                  1.8.1
decorator                5.1.1
dill                     0.3.8
executing                2.0.1
filelock                 3.13.1
fire                     0.6.0
fonttools                4.53.0
frozenlist               1.4.1
fsspec                   2024.2.0
huggingface-hub          0.23.2
idna                     3.7
ipykernel                6.29.4
ipython                  8.25.0
jedi                     0.19.1
Jinja2                   3.1.3
joblib                   1.4.2
jupyter_client           8.6.2
jupyter_core             5.7.2
kiwisolver               1.4.5
MarkupSafe               2.1.5
matplotlib               3.9.0
matplotlib-inline        0.1.7
mpmath                   1.2.1
multidict                6.0.5
multiprocess             0.70.16
nest-asyncio             1.6.0
networkx                 3.2.1
nltk                     3.8.1
numpy                    1.26.4
nvidia-cublas-cu12       12.4.2.65
nvidia-cuda-cupti-cu12   12.4.99
nvidia-cuda-nvrtc-cu12   12.4.99
nvidia-cuda-runtime-cu12 12.4.99
nvidia-cudnn-cu12        8.9.7.29
nvidia-cufft-cu12        11.2.0.44
nvidia-curand-cu12       10.3.5.119
nvidia-cusolver-cu12     11.6.0.99
nvidia-cusparse-cu12     12.3.0.142
nvidia-nccl-cu12         2.20.5
nvidia-nvjitlink-cu12    12.4.99
nvidia-nvtx-cu12         12.4.99
packaging                24.0
pandas                   2.2.2
parso                    0.8.4
peft                     0.11.0
pexpect                  4.9.0
Pillow                   9.3.0
pip                      24.0
platformdirs             4.2.2
prompt_toolkit           3.0.45
psutil                   5.9.8
ptyprocess               0.7.0
pure-eval                0.2.2
pyarrow                  16.1.0
pyarrow-hotfix           0.6
Pygments                 2.18.0
pyparsing                3.1.2
python-dateutil          2.9.0.post0
pytorch-triton           3.0.0+45fff310c8
pytz                     2024.1
PyYAML                   6.0.1
pyzmq                    26.0.3
regex                    2024.5.15
requests                 2.32.3
rouge_score              0.1.2
safetensors              0.4.3
setuptools               65.5.0
six                      1.16.0
stack-data               0.6.3
sympy                    1.12
termcolor                2.4.0
tokenizers               0.15.2
torch                    2.4.0.dev20240515+cu124
torchaudio               2.2.0.dev20240515+cu124
torchvision              0.19.0.dev20240515+cu124
tornado                  6.4
tqdm                     4.66.4
traitlets                5.14.3
transformers             4.38.2
typing_extensions        4.8.0
tzdata                   2024.1
urllib3                  2.2.1
wcwidth                  0.2.13
xxhash                   3.4.1
yarl                     1.9.4

Do you have any suggestions?

RoelTim avatar Jun 10 '24 23:06 RoelTim