FastChat
CUDA OOM When Using Flash Attention
Hello,
Thank you for sharing your awesome work!
I'm trying to train Vicuna on my own dataset. I went through the installation from source, but I had to install PyTorch with CUDA 11.7.0 support instead of 11.6, since my server only provides CUDA 11.2.2/11.4.4/11.5.2/11.7.0/11.8.0 and not 11.6.
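As a quick sanity check on the CUDA mismatch, a snippet like the following confirms which CUDA toolkit the installed PyTorch build was compiled against and whether the GPUs are visible (the exact versions printed will of course depend on the environment):

```python
import torch

# Which CUDA version this PyTorch build was compiled against,
# and whether the runtime can actually see the GPUs.
print("torch:", torch.__version__, "| built with CUDA:", torch.version.cuda)
print("CUDA available:", torch.cuda.is_available())
if torch.cuda.is_available():
    props = torch.cuda.get_device_properties(0)
    print("device:", props.name, "| memory (GB):", round(props.total_memory / 2**30, 1))
```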
When I try to train the 13B model with flash attention, I get a CUDA OOM error even when per_device_train_batch_size is set to 1. I think there might be a memory leak. I also tried building flash attention from source and still got the same error.
I know this is probably a flash attention problem, but do you have any insights? Any guidance will be very much appreciated.
Best regards, Hani
I just tried training the 7B model and it works fine both with and without flash attention. It would be great if I could train models larger than 7B with flash attention. I think flash attention alone will only get me as far as the 13B model; beyond that I'll need to look into memory-efficient tuning methods such as LoRA.
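In case it's useful later, a memory-efficient setup along those lines would look roughly like this with Hugging Face PEFT. This is only a sketch: the model path is a placeholder, and the target module names are an assumption based on LLaMA's attention projections rather than anything from the FastChat training code.

```python
import torch
from transformers import AutoModelForCausalLM
from peft import LoraConfig, get_peft_model

# Placeholder path; point this at the actual base checkpoint.
model = AutoModelForCausalLM.from_pretrained(
    "path/to/vicuna-13b", torch_dtype=torch.float16
)

# Assumed target modules (LLaMA-style attention projections).
lora_config = LoraConfig(
    r=8,
    lora_alpha=16,
    lora_dropout=0.05,
    target_modules=["q_proj", "v_proj"],
    task_type="CAUSAL_LM",
)

model = get_peft_model(model, lora_config)
model.print_trainable_parameters()  # only the low-rank adapters are trainable
```

Only the low-rank adapter weights are trained, which cuts gradient and optimizer-state memory dramatically compared to full fine-tuning.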
How many GPUs do you use?
@zhisbug, thank you for your reply.
I was using 4 A100-80GB GPUs but then realized the repository uses 8.
I have a question regarding the batch size. The provided script uses 8 GPUs with per_device_train_batch_size set to 4 and gradient_accumulation_steps set to 1. Correct me if I'm wrong, but that makes the total batch size 8 x 4 x 1 = 32, yet the repo says it uses a global batch size of 128.
I checked the Alpaca repo, and they set gradient_accumulation_steps to 8 for their 7B model, so their global batch size would be 4 x 4 x 8 = 128. Am I missing something?
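Spelling out the arithmetic I'm using (this is just the usual effective-batch-size formula, with the numbers quoted above):

```python
def global_batch_size(n_gpus: int, per_device_train_batch_size: int,
                      gradient_accumulation_steps: int) -> int:
    # Effective number of samples contributing to each optimizer update.
    return n_gpus * per_device_train_batch_size * gradient_accumulation_steps

print(global_batch_size(8, 4, 1))  # FastChat 13B script as I read it: 32
print(global_batch_size(4, 4, 8))  # Alpaca 7B config: 128
```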
Your input will be very much appreciated!
@Michaelvll might be the right person to clarify the training config
@HaniItani is your problem solved?
Yes, I figured it out. Thank you very much.
@HaniItani what did you do to fix it? I'm also running into OOM training 13B models with flash attention and wondering if it's the same problem. Thanks!
Hi @alwayshalffull, flash attention was not the problem; I had the wrong parameters. The 13B model requires 8 A100-80GB GPUs to fine-tune with a per-device train batch size of 4, as reflected in scripts/train_vicuna_13b.sh. If you're getting OOM when saving the model, please check this issue. Hope this helps.
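For anyone who lands here later, a back-of-the-envelope estimate makes the GPU requirement plausible. Assuming mixed-precision Adam (roughly 16 bytes of model and optimizer state per parameter) and that these states are fully sharded across GPUs, as FSDP/ZeRO-style training does, the per-GPU footprint before counting any activations looks like this (a rough sketch, not an exact accounting):

```python
# Rough lower bound on per-GPU memory for full fine-tuning of a 13B model.
# Assumes fp16 weights + fp16 grads + fp32 master weights + fp32 Adam
# moments (~16 bytes/param), fully sharded across GPUs. Activations,
# temporary all-gather buffers, and fragmentation are NOT included.
params = 13e9
bytes_per_param = 2 + 2 + 4 + 4 + 4  # weights, grads, master weights, Adam m, Adam v
total_gb = params * bytes_per_param / 2**30

for n_gpus in (4, 8):
    print(f"{n_gpus} x A100-80GB: ~{total_gb / n_gpus:.0f} GB of sharded state per GPU")
```

Activation memory at the script's sequence length and per-device batch size comes on top of those numbers, which is consistent with the 4-GPU setup running out of memory while 8 GPUs leave enough headroom.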