
CUDA OOM When Using Flash Attention

Open HaniItani opened this issue 2 years ago • 4 comments

Hello,

Thank you for sharing your awesome work!

I'm trying to train Vicuna on my own dataset. I went through the installation process from source. I had to install PyTorch with CUDA 11.7.0 support instead of 11.6, since my server only supports CUDA 11.2.2/11.4.4/11.5.2/11.7.0/11.8.0 but not 11.6.
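
For reference, this is how I double-checked which PyTorch/CUDA build actually got installed (plain PyTorch calls, nothing FastChat-specific):

```python
import torch

# Quick sanity check of the installed PyTorch/CUDA build and visible GPUs.
print(torch.__version__)          # e.g. "2.0.0+cu117"
print(torch.version.cuda)         # CUDA version PyTorch was compiled against
print(torch.cuda.is_available())  # True if the driver/runtime are usable
print(torch.cuda.device_count())  # number of visible GPUs
```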

When I try to train the 13B model with flash attention, I get a CUDA OOM error even when per_device_train_batch_size is set to 1. I think there might be a memory leak. I also tried building flash attention from source and still got the same error.

I know this is probably a flash attention problem, but do you have any insights? Any guidance will be very much appreciated.

Best regards, Hani

HaniItani avatar Apr 04 '23 01:04 HaniItani

I just tried training the 7B model and it works fine both with and without flash attention. It would be great if I could train models larger than 7B with flash attention. I think flash attention will only get me as far as the 13B model; beyond that I'll need to look into memory-efficient tuning methods such as LoRA (see the sketch below).
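
For context, a minimal sketch of what LoRA fine-tuning could look like with the PEFT library. This is not part of FastChat's training scripts; the checkpoint path is a placeholder and the rank/alpha/target-module values are illustrative:

```python
from transformers import AutoModelForCausalLM
from peft import LoraConfig, TaskType, get_peft_model

# "path/to/llama-13b" is a placeholder for a local LLaMA-13B checkpoint.
model = AutoModelForCausalLM.from_pretrained("path/to/llama-13b")

# Illustrative LoRA settings; rank, alpha, and target modules would need tuning.
lora_config = LoraConfig(
    task_type=TaskType.CAUSAL_LM,
    r=8,
    lora_alpha=16,
    lora_dropout=0.05,
    target_modules=["q_proj", "v_proj"],
)
model = get_peft_model(model, lora_config)
model.print_trainable_parameters()  # only the adapter weights are trainable
```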

HaniItani avatar Apr 04 '23 01:04 HaniItani

How many GPUs do you use?

zhisbug avatar Apr 08 '23 02:04 zhisbug

@zhisbug, thank you for your reply.

I was using 4 A100-80GB GPUs, but then realized the repository uses 8.

I have a question regarding the batch size. The training command uses 8 GPUs with per_device_train_batch_size set to 4 and gradient_accumulation_steps set to 1. Correct me if I'm wrong, but that gives a total batch size of 8x4x1 = 32, yet the repo says it uses a global batch size of 128.

I checked the Alpaca repo, and they set gradient_accumulation_steps to 8 for their 7B model, so their global batch size would be 4x4x8 = 128. Am I missing something?
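
For concreteness, here is the arithmetic I'm assuming (a plain Python sketch, not FastChat code):

```python
# Global batch size = num_gpus * per_device_train_batch_size * gradient_accumulation_steps
def global_batch_size(num_gpus, per_device_bs, grad_accum_steps):
    return num_gpus * per_device_bs * grad_accum_steps

print(global_batch_size(8, 4, 1))  # 32  -- FastChat 13B command as I read it
print(global_batch_size(4, 4, 8))  # 128 -- Alpaca 7B config
```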

Your input will be very much appreciated!

HaniItani avatar Apr 09 '23 16:04 HaniItani

@Michaelvll might be the right person to clarify the training config

zhisbug avatar Apr 20 '23 23:04 zhisbug

@HaniItani is your problem solved?

zhisbug avatar Jul 05 '23 19:07 zhisbug

Yes, I figured it out. Thank you very much.

HaniItani avatar Jul 05 '23 19:07 HaniItani

@HaniItani what did you do to fix it? I'm also running into OOM training 13B models with flash attention, and I'm wondering if it's the same problem. Thanks!

alwayshalffull avatar Jul 05 '23 23:07 alwayshalffull

Hi @alwayshalffull, flash attention was not the problem; I had the wrong parameters. The 13B model requires 8 A100-80GB GPUs to fine-tune with a per-device train batch size of 4, as reflected in scripts/train_vicuna_13b.sh. If you're getting OOM when saving the model, please check this issue. Hope this helps.
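
For anyone hitting the same thing, a rough sketch of the batch-size-related Trainer arguments implied by the discussion above. The values and the extra memory-saving flags are assumptions on my part; the authoritative settings are in scripts/train_vicuna_13b.sh:

```python
from transformers import TrainingArguments

# Illustrative values only; see scripts/train_vicuna_13b.sh for the actual flags.
training_args = TrainingArguments(
    output_dir="./vicuna-13b-finetune",  # hypothetical output path
    per_device_train_batch_size=4,       # per GPU; with 8 GPUs -> 32 samples per step
    gradient_accumulation_steps=1,
    bf16=True,                           # assumption: bf16 mixed precision on A100s
    gradient_checkpointing=True,         # assumption: trades compute for activation memory
)
```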

HaniItani avatar Jul 07 '23 14:07 HaniItani