
How to fine-tune vicuna-7b with an A40

Open yqh984638220 opened this issue 1 year ago • 8 comments

How to fine-tune vicuna-7b with an A40

yqh984638220 avatar May 17 '23 12:05 yqh984638220

FlashAttention backward for head dim > 64 requires A100 or H100
GPUs as the implementation needs a large amount of shared memory.

This might be related; I got this error while using train_mem.py with A40s.

gabinguo avatar May 21 '23 23:05 gabinguo
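(Note: the check trips on A40s because LLaMA's per-head dimension is 128, above FlashAttention v1's 64-dim limit for the backward pass on non-A100/H100 cards. A minimal sketch to confirm this from the model config; "path/to/vicuna-7b" is a stand-in for your local checkpoint:)

from transformers import AutoConfig

# Inspect the per-head dimension of the model being fine-tuned.
config = AutoConfig.from_pretrained("path/to/vicuna-7b")
head_dim = config.hidden_size // config.num_attention_heads
print(head_dim)  # 4096 // 32 = 128 for the 7B model, well above the 64 limit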

Same problem. I'm running train_mem.py with the following args:

torchrun --nproc_per_node=2 --master_port=20001 /data/ljn/Vicuna-13B/model/FastChat/fastchat/train/train_mem.py \
    --model_name_or_path /data/ljn/Vicuna-13B/model/FastChat/data/hf \
    --data_path /data/ljn/Vicuna-13B/model/FastChat/playground/data/leecode_new.json \
    --bf16 True \
    --output_dir /data/ljn/Vicuna-13B/model/FastChat/output \
    --num_train_epochs 3 \
    --per_device_train_batch_size 2 \
    --per_device_eval_batch_size 2 \
    --gradient_accumulation_steps 16 \
    --evaluation_strategy "no" \
    --save_strategy "steps" \
    --save_steps 1200 \
    --save_total_limit 10 \
    --learning_rate 2e-5 \
    --weight_decay 0. \
    --warmup_ratio 0.03 \
    --lr_scheduler_type "cosine" \
    --logging_steps 1 \
    --fsdp "full_shard auto_wrap" \
    --fsdp_transformer_layer_cls_to_wrap 'LlamaDecoderLayer' \
    --tf32 True \
    --model_max_length 2048 \
    --gradient_checkpointing True \
    --lazy_preprocess True

Error:

Loading checkpoint shards: 100%|████████████████████████████████████████████████████████████████████| 3/3 [03:19<00:00, 66.41s/it]
wandb: (1) Create a W&B account
wandb: (2) Use an existing W&B account
wandb: (3) Don't visualize my results
wandb: Enter your choice: 3
wandb: You chose "Don't visualize my results"
wandb: Tracking run with wandb version 0.15.3
wandb: W&B syncing is set to offline in this directory.
wandb: Run wandb online or set WANDB_MODE=online to enable cloud syncing.
  0%|          | 0/537 [00:00<?, ?it/s]
Traceback (most recent call last):
  File "/data/ljn/Vicuna-13B/model/FastChat/fastchat/train/train_mem.py", line 13, in <module>
    train()
  File "/data/ljn/Vicuna-13B/model/FastChat/fastchat/train/train.py", line 263, in train
    trainer.train()
  File "/data/ljn/Vicuna-13B/lib/python3.9/site-packages/transformers/trainer.py", line 1662, in train
    return inner_training_loop(
  File "/data/ljn/Vicuna-13B/lib/python3.9/site-packages/transformers/trainer.py", line 1927, in _inner_training_loop
    tr_loss_step = self.training_step(model, inputs)
  File "/data/ljn/Vicuna-13B/lib/python3.9/site-packages/transformers/trainer.py", line 2717, in training_step
    loss.backward()
  File "/data/ljn/Vicuna-13B/lib/python3.9/site-packages/torch/_tensor.py", line 487, in backward
    torch.autograd.backward(
  File "/data/ljn/Vicuna-13B/lib/python3.9/site-packages/torch/autograd/__init__.py", line 200, in backward
    Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass
  File "/data/ljn/Vicuna-13B/lib/python3.9/site-packages/torch/autograd/function.py", line 274, in apply
    return user_fn(self, *args)
  File "/data/ljn/Vicuna-13B/lib/python3.9/site-packages/torch/utils/checkpoint.py", line 157, in backward
    torch.autograd.backward(outputs_with_grad, args_with_grad)
  File "/data/ljn/Vicuna-13B/lib/python3.9/site-packages/torch/autograd/__init__.py", line 200, in backward
    Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass
  File "/data/ljn/Vicuna-13B/lib/python3.9/site-packages/torch/autograd/function.py", line 274, in apply
    return user_fn(self, *args)
  File "/data/ljn/Vicuna-13B/lib/python3.9/site-packages/flash_attn/flash_attn_interface.py", line 75, in backward
    _flash_attn_backward(
  File "/data/ljn/Vicuna-13B/lib/python3.9/site-packages/flash_attn/flash_attn_interface.py", line 42, in _flash_attn_backward
    _, _, _, softmax_d = flash_attn_cuda.bwd(
RuntimeError: FlashAttention backward for head dim > 64 requires A100 or H100 GPUs as the implementation needs a large amount of shared memory.

I'm using 2 × A40s on my machine. Are there any suggestions to solve this?

jhu10 avatar May 29 '23 09:05 jhu10

Maybe you can remove the flash_attn() patch: [screenshot]

Ted8000 avatar May 30 '23 02:05 Ted8000
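(For anyone unsure what "removing flash_attn" means in practice: a minimal sketch, assuming train_mem.py resembles the upstream FastChat version, where its only difference from train.py is applying the FlashAttention monkey patch before transformers is imported:)

# fastchat/train/train_mem.py (sketch)
# Commenting out these two lines skips the FlashAttention monkey patch and
# falls back to the standard PyTorch attention (slower, but runs on any GPU):
# from fastchat.train.llama_flash_attn_monkey_patch import replace_llama_attn_with_flash_attn
# replace_llama_attn_with_flash_attn()

from fastchat.train.train import train

if __name__ == "__main__":
    train()

Equivalently, point torchrun at fastchat/train/train.py instead of train_mem.py.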

I am facing this similar issues, anyone figured out how to solve this?

prateeky2806 avatar Jun 03 '23 08:06 prateeky2806

I seem to have run into the same problem. [screenshot]

Hzzhang-nlp avatar Jun 08 '23 06:06 Hzzhang-nlp

How can I solve it? [screenshot]

Hzzhang-nlp avatar Jun 08 '23 06:06 Hzzhang-nlp

@Hzzhang-nlp If your CUDA is 12.x, you can install the PyTorch nightly build for CUDA 12.1 and build flash-attention from source:

pip install --pre torch torchvision torchaudio --index-url https://download.pytorch.org/whl/nightly/cu121

git clone https://github.com/HazyResearch/flash-attention.git
cd flash-attention
python setup.py install

aduan avatar Jun 14 '23 02:06 aduan
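(A quick sanity check after the rebuild; this sketch just confirms which CUDA version PyTorch was built against and that the flash_attn extension imports:)

import torch
import flash_attn  # raises ImportError if the source build failed

print(torch.__version__, torch.version.cuda)   # expect a nightly built for cu121
print(torch.cuda.get_device_name(0))           # e.g. "NVIDIA A40"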

It works for me, but I still hit OOM when fine-tuning on a single A6000.

Len-Li avatar Jun 15 '23 15:06 Len-Li

We can't do much about this, I fear. If the GPU doesn't have enough memory, it doesn't have enough memory...

surak avatar Oct 23 '23 09:10 surak
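(That said, a few standard knobs can shrink the footprint before giving up. A hedged sketch based on the torchrun invocation above; the values are illustrative, the paths are stand-ins, and FSDP's CPU offload trades training speed for memory:)

torchrun --nproc_per_node=1 --master_port=20001 fastchat/train/train_mem.py \
    --model_name_or_path /path/to/model \
    --data_path /path/to/data.json \
    --bf16 True \
    --output_dir ./output \
    --num_train_epochs 3 \
    --per_device_train_batch_size 1 \
    --gradient_accumulation_steps 32 \
    --model_max_length 1024 \
    --fsdp "full_shard offload auto_wrap" \
    --fsdp_transformer_layer_cls_to_wrap 'LlamaDecoderLayer' \
    --gradient_checkpointing True \
    --lazy_preprocess True

FastChat's LoRA variant (fastchat/train/train_lora.py) is also worth a look for single-GPU setups, since it trains far fewer parameters.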