FastChat
How to fine tune vicuna-7b with A40
FlashAttention backward for head dim > 64 requires A100 or H100
GPUs as the implementation needs a large amount of shared memory.
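For context, this is a hard constraint in the flash-attn kernels rather than a FastChat bug: the backward pass for head dimensions above 64 only runs on A100/H100-class GPUs (compute capability 8.0/9.0), LLaMA/Vicuna uses a head dimension of 128, and an A40 reports compute capability 8.6. A minimal sketch to confirm this on your machine (the checkpoint path is the one used elsewhere in this thread):

```python
# Sketch: confirm why flash-attn's backward refuses to run on an A40.
import torch
from transformers import AutoConfig

cfg = AutoConfig.from_pretrained("/data/ljn/Vicuna-13B/model/FastChat/data/hf")
head_dim = cfg.hidden_size // cfg.num_attention_heads  # 128 for LLaMA-7B/13B
major, minor = torch.cuda.get_device_capability(0)      # A40 -> (8, 6)

print(f"head_dim={head_dim}, compute capability={major}.{minor}")
if head_dim > 64 and (major, minor) not in [(8, 0), (9, 0)]:
    print("flash-attn 1.x backward for this head_dim needs an A100/H100")
```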
This might be related: I got this error while using train_mem.py with A40s.
same problem.
Running train_mem.py with the following args:
torchrun --nproc_per_node=2 --master_port=20001 /data/ljn/Vicuna-13B/model/FastChat/fastchat/train/train_mem.py \
    --model_name_or_path /data/ljn/Vicuna-13B/model/FastChat/data/hf \
    --data_path /data/ljn/Vicuna-13B/model/FastChat/playground/data/leecode_new.json \
    --bf16 True \
    --output_dir /data/ljn/Vicuna-13B/model/FastChat/output \
    --num_train_epochs 3 \
    --per_device_train_batch_size 2 \
    --per_device_eval_batch_size 2 \
    --gradient_accumulation_steps 16 \
    --evaluation_strategy "no" \
    --save_strategy "steps" \
    --save_steps 1200 \
    --save_total_limit 10 \
    --learning_rate 2e-5 \
    --weight_decay 0. \
    --warmup_ratio 0.03 \
    --lr_scheduler_type "cosine" \
    --logging_steps 1 \
    --fsdp "full_shard auto_wrap" \
    --fsdp_transformer_layer_cls_to_wrap 'LlamaDecoderLayer' \
    --tf32 True \
    --model_max_length 2048 \
    --gradient_checkpointing True \
    --lazy_preprocess True
Error:
Loading checkpoint shards: 100%|████████████████████████████████████████████████████████████████████| 3/3 [03:19<00:00, 66.41s/it]
wandb: (1) Create a W&B account
wandb: (2) Use an existing W&B account
wandb: (3) Don't visualize my results
wandb: Enter your choice: 3
wandb: You chose "Don't visualize my results"
wandb: Tracking run with wandb version 0.15.3
wandb: W&B syncing is set to `offline` in this directory.
wandb: Run `wandb online` or set WANDB_MODE=online to enable cloud syncing.
0%| | 0/537 [00:00<?, ?it/s]Traceback (most recent call last):
File "/data/ljn/Vicuna-13B/model/FastChat/fastchat/train/train_mem.py", line 13, in
I used 2 × A40 on my machine; are there any suggestions to solve this?
Maybe you can remove flash_attn().
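For reference, here is a sketch of what that amounts to, assuming train_mem.py is just a thin wrapper that applies the LLaMA flash-attn monkey patch before delegating to the regular trainer (equivalently, launch the same torchrun command against fastchat/train/train.py):

```python
# Sketch of train_mem.py with the flash-attn patch removed (assumption:
# train_mem.py only adds the monkey patch on top of fastchat/train/train.py).

# from fastchat.train.llama_flash_attn_monkey_patch import (
#     replace_llama_attn_with_flash_attn,
# )
# replace_llama_attn_with_flash_attn()  # skipped -> standard HF LLaMA attention,
#                                       # no A100/H100 requirement

from fastchat.train.train import train

if __name__ == "__main__":
    train()
```

Note that without flash-attn the attention memory footprint grows, so a shorter model_max_length or smaller batch size may be needed.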
I am facing a similar issue; has anyone figured out how to solve this?
I seem to be running into the same problem.
How can this be solved?
@Hzzhang-nlp if your CUDA is 12.x, you can install a PyTorch nightly built for CUDA 12.1 and build flash-attention from source:
pip install --pre torch torchvision torchaudio --index-url https://download.pytorch.org/whl/nightly/cu121
git clone https://github.com/HazyResearch/flash-attention.git
cd flash-attention
python setup.py install
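After the build, a quick sanity check (a sketch) that the nightly wheel and the source build actually line up:

```python
# Verify the CUDA 12.1 nightly and the locally built flash-attn import together.
import torch
import flash_attn

print(torch.__version__, torch.version.cuda)  # expect a 2.x dev build for cu121
print(flash_attn.__version__)                 # the version you just built
print(torch.cuda.get_device_capability(0))    # A40 reports (8, 6)
```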
It works for me, but I still get OOM when fine-tuning on a single A6000.
We can't do much about this, I fear. If the GPU doesn't have enough memory, it doesn't have enough memory...
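For a rough sense of scale, here is a back-of-envelope sketch for full fine-tuning of a 7B-parameter model, assuming AdamW with bf16 mixed precision and fp32 optimizer states; activations and fragmentation come on top of this:

```python
# Rough memory estimate for the persistent training state of a 7B model.
params = 7e9
bytes_per_param = (
    2    # bf16 weights
    + 2  # bf16 gradients
    + 4  # fp32 Adam first moment
    + 4  # fp32 Adam second moment
    + 4  # fp32 master copy of the weights (assumption: standard mixed precision)
)
print(f"~{params * bytes_per_param / 1024**3:.0f} GiB")  # ~104 GiB vs. 48 GB on an A6000
```

That estimate is already more than double the memory of a single 48 GB card, which is why the FSDP sharding in the command above needs multiple GPUs to hold the optimizer state.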