During pre-training, using FA2 consumes more memory than using SDPA
As described in the title
When performing pre-training, using FA2 will consume more GPU memory than using SDPA.
I am using the Trainer from transformers; the simplified code is roughly as follows:
```python
import torch
from transformers import (
    AutoModelForMaskedLM,
    DataCollatorForLanguageModeling,
    Trainer,
    TrainingArguments,
)
from transformers.utils import is_torch_bf16_gpu_available

# MODEL_PATH, tokenizer, train_dataset and eval_dataset are defined elsewhere (data logic omitted)
model = AutoModelForMaskedLM.from_pretrained(
    MODEL_PATH,
    torch_dtype=torch.bfloat16 if is_torch_bf16_gpu_available() else torch.float16,
    attn_implementation="flash_attention_2",  # or "sdpa"
).to("cuda" if torch.cuda.is_available() else "cpu")

training_args = TrainingArguments(
    bf16=True,
    optim="paged_adamw_8bit",
    warmup_ratio=0.1,
    weight_decay=5e-5,
    learning_rate=5e-5,
    num_train_epochs=1,
    per_device_eval_batch_size=16,
    per_device_train_batch_size=8,
    gradient_checkpointing=False,
)

trainer = Trainer(
    args=training_args,
    model=model,
    data_collator=DataCollatorForLanguageModeling(
        tokenizer=tokenizer,
        mlm=True,
        mlm_probability=0.30,
        pad_to_multiple_of=8,
    ),
    eval_dataset=eval_dataset,
    train_dataset=train_dataset,
    processing_class=tokenizer,
)
```
Keeping all other parameters identical and changing only attn_implementation, GPU memory usage is 48% with SDPA and 88% with FA2.
With FA2, GPU memory usage is significantly higher than with SDPA, and also much higher than with other traditional BERT-like models, with no improvement in speed.
The same phenomenon has been observed on both Windows 11 24H2 and Ubuntu@WSL2.
Environment: PyTorch 2.5.1, Python 3.12.8, flash_attn v2.7.2.post1
So is this a problem with my workflow or a bug?
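For comparing setups across Windows, WSL2 and native Linux, it can help to print the environment the same way on every machine. A minimal sketch using only the standard library; `collect_env` is a hypothetical helper name, and it assumes a recent Python (3.10+), where `importlib.metadata` treats `flash_attn` and `flash-attn` as the same distribution name.

```python
import platform
from importlib.metadata import PackageNotFoundError, version


def collect_env(packages=("torch", "transformers", "flash_attn")) -> dict:
    """Return Python, OS and package versions as strings for a bug report."""
    info = {"python": platform.python_version(), "os": platform.platform()}
    for name in packages:
        try:
            info[name] = version(name)
        except PackageNotFoundError:
            # The package is not installed in this interpreter
            info[name] = "not installed"
    return info


if __name__ == "__main__":
    for key, value in collect_env().items():
        print(f"{key}: {value}")
```

Running it once per machine gives output that can be pasted next to the memory numbers.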
In general I think the FA2 support on Windows is not well tested (https://github.com/Dao-AILab/flash-attention?tab=readme-ov-file#installation-and-features). We only ever used Linux machines for the pre-training part.
Yes, it's possible. But as described in the original post, I got the same results when running the training script in a WSL2 environment. Since FA2 works correctly when training other models, I suspect this might be a bug specific to the HF Trainer, or a configuration error. Could anyone else try to verify this?
IIRC, compilation does not work properly on Windows/WSL, but this should not cause such a gap and should affect both paths.
Could you try verifying that using attn_implementation = "flash_attention_2" indeed uses the FA2 path?
As raised by @staghado, we mainly use Linux machines (and WSL is not totally equivalent to Linux machines), so it is a bit hard for us to debug.
Since FA2 works correctly during training of other models,
Could you give more information about which setup you are referring to? I believe Jina is the only encoder with FA. Also, could you share the full boilerplate? If I recall correctly, @tomaarsen sometimes uses WSL, so maybe he can try running it and see whether he experiences the same behavior.
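One quick way to check which attention backend a loaded model was configured with is to read it back from the model's config. A minimal sketch, assuming the model exposes `config._attn_implementation`, a private attribute that recent transformers versions populate from `attn_implementation=...`; the helper names are hypothetical. This only confirms the configured backend, not that the FA2 kernels are actually invoked at runtime.

```python
from typing import Optional


def attn_implementation(model) -> Optional[str]:
    """Return the attention backend recorded on a transformers model's config.

    Assumes `model.config._attn_implementation` exists (private attribute);
    returns None if it is missing, e.g. on very old transformers versions.
    """
    return getattr(model.config, "_attn_implementation", None)


def check_attn_implementation(model, expected: str) -> bool:
    """True if the loaded model reports the expected attention backend."""
    actual = attn_implementation(model)
    if actual != expected:
        print(f"expected {expected!r}, model reports {actual!r}")
    return actual == expected
```

Called right after `from_pretrained`, for example `assert check_attn_implementation(model, "flash_attention_2")`, it fails fast if the model silently ended up on another backend.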
Could you try verifying that using attn_implementation = "flash_attention_2" indeed uses the FA2 path?
Yes, the numbers in the original post were obtained with attn_implementation = "flash_attention_2" specified.
Could you give more information about which setup you are referring to? I believe Jina is the only encoder with FA.
I may not have expressed this clearly: when training other types of models (such as Qwen2.5) in the same dependency environment, FA2 works normally and a significant reduction in memory usage can be observed. However, I have not tried FA2 when training other BERT-like models.
Also, could you share the full boilerplate? If I recall correctly, @tomaarsen sometimes uses WSL, so maybe he can try running it and see whether he experiences the same behavior.
This is a minimal script, with the data logic stripped out, that reproduces the issue described above. Place the plain-text corpus file sample.txt and the model folder modern_bert in the same directory as sample.py, then run python sample.py.
```
.
├── modern_bert
│   ├── config.json
│   ├── model.safetensors
│   ├── special_tokens_map.json
│   ├── tokenizer.json
│   └── tokenizer_config.json
├── sample.py
└── sample.txt
```
Switch between FA2 and SDPA by modifying the constant ATTN_IMPLEMENTATION at the beginning of the script. The script contains some Chinese comments and logs, but I think they should not have any actual impact :)
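As an alternative to editing the constant between runs, the backend could be chosen on the command line so that both runs use a byte-identical script. A minimal sketch with argparse; the flag name `--attn-implementation` and the function `parse_args` are hypothetical and not part of the attached script.

```python
import argparse

# The two backends compared in this thread; argparse rejects any other value
ATTN_CHOICES = ("flash_attention_2", "sdpa")


def parse_args(argv=None) -> argparse.Namespace:
    parser = argparse.ArgumentParser(description="Compare attention backends for MLM pre-training")
    parser.add_argument(
        "--attn-implementation",
        choices=ATTN_CHOICES,
        default="flash_attention_2",
        help="value passed as attn_implementation to from_pretrained",
    )
    return parser.parse_args(argv)
```

In sample.py, `ATTN_IMPLEMENTATION = parse_args().attn_implementation` would replace the hard-coded value, and the two runs become `python sample.py --attn-implementation flash_attention_2` and `python sample.py --attn-implementation sdpa`.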
A new finding: with the same weights and system environment, fine-tuning on a downstream task (NER) with FA2 shows a significant speed-up and lower memory usage. Compared with SDPA, speed roughly doubles (+100%) and memory usage drops by about 50%. I think this is the normal effect of FA2. My guess is that some setting or step that differs between the MLM and token-classification tasks causes this difference. @NohTow
I still haven't identified the root cause, but I have found a "silly" workaround. When FA2 is enabled and memory usage becomes abnormal, memory is not maxed out at the first step; instead, there are several abnormal memory spikes over the following few steps. I used the following code to clear the GPU cache manually, and it worked: after the cache is cleared a few times at the start of training, memory usage stabilizes.
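```python
import os

import torch
from transformers import TrainerCallback, TrainerControl, TrainerState, TrainingArguments


class ClearMemoryCallback(TrainerCallback):  # class name is hypothetical
    def clear_memory(self, threshold: float) -> None:
        # Total, driver-reserved and used memory (MiB) of the first GPU listed by nvidia-smi
        result = os.popen("nvidia-smi --query-gpu=memory.total,memory.reserved,memory.used --format=csv,noheader,nounits").readlines()
        total, reserved, used = (int(x) for x in result[0].strip().split(", "))
        # Release cached but unused allocator blocks once usage exceeds the threshold
        if (reserved + used) / total > threshold:
            torch.cuda.empty_cache()

    def on_step_begin(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs):
        self.clear_memory(0.92)
```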
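The threshold rule in the callback above can be pulled out into a pure function so it can be tested without a GPU. A minimal sketch, assuming the nvidia-smi query format used above (one line of three comma-separated MiB values: total, reserved, used); the function names are hypothetical.

```python
def parse_nvidia_smi_line(line: str) -> tuple:
    """Split one 'total, reserved, used' CSV line (MiB, no header or units) into ints."""
    # int() tolerates the spaces nvidia-smi puts after each comma
    total, reserved, used = (int(field) for field in line.strip().split(","))
    return total, reserved, used


def should_clear_cache(line: str, threshold: float) -> bool:
    """Same rule as the callback: clear when (reserved + used) / total exceeds threshold."""
    total, reserved, used = parse_nvidia_smi_line(line)
    return (reserved + used) / total > threshold
```

If spawning nvidia-smi on every step is a concern, `torch.cuda.mem_get_info()` returns the device's free and total memory in bytes from inside the training process, which could feed the same kind of ratio check.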
Overall, I believe the abnormal GPU memory usage on the FA2 path is real, but with the workaround above, training can now proceed relatively normally. I can continue to provide information to help you fully resolve the issue. Alternatively, if you consider this workaround sufficient and no further follow-up is needed, feel free to close this issue.
Hello, thanks for the thorough investigation! Clearing memory should not be required during training, so I think a cleaner fix is indeed needed. IIRC, I saw somewhere that you mentioned getting OK results with MLM on a Linux box? Or does MLM with FA behave badly regardless of the setup? Either it is an issue with MLM on WSL/Windows, or an issue with MLM on HF, period. The latter would be more urgent to fix (and easier to debug, since our setup should be able to reproduce it).
Yes, I previously shared some test results, but because they couldn't be reliably reproduced, I later deleted that response.
Here's a summary of the information that can be stably reproduced:
- When debugging the training script on my personal computer (4070 Ti Super), whether on Windows, WSL, or native Ubuntu, the VRAM anomaly described above occurs during MLM tasks using FA2 and the HF Trainer.
- The issue does not depend on the FA2 version: it persists from 2.6.3 through the latest 2.7.3.
- Across various configuration combinations, FA2's VRAM usage is consistently higher than SDPA's (+10% to +50%), while speed is similar or slightly better. This does not match the expected significant VRAM savings and speed-up.
- If torch_compile = True is explicitly set in TrainingArguments, the VRAM anomaly does not occur and a significant speed-up (20%-30%) can be observed. NOTE: this only works if gradient_checkpointing = True is also set; otherwise, the error mentioned in this ISSUE occurs.

```python
training_args = TrainingArguments(
    ...
    torch_compile=True,
    gradient_checkpointing=True,
    ...
)
```
Overall, I am inclined to believe that this issue is less likely to be system-related and more likely due to an oversight in the Hugging Face transformers implementation.
Since ModernBERT's default vocabulary is relatively small, the VRAM anomaly may not be very noticeable with it. It becomes more pronounced with my expanded East Asian vocabulary, which I am also uploading for reference: modern_bert_cjk.zip
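For a rough sense of why vocabulary size matters at all, a dense MLM logits tensor of shape (batch, sequence length, vocabulary) grows linearly with the vocabulary. A back-of-the-envelope sketch; `mlm_logits_gib` is a hypothetical helper, the batch size of 8 comes from the config in the original post, and the sequence length of 1024 and the 100,000-token vocabulary are assumed values for illustration (50,368 is, as far as I know, the vocab_size in ModernBERT's released config). It estimates one tensor only and does not locate the extra FA2 memory.

```python
def mlm_logits_gib(batch_size: int, seq_len: int, vocab_size: int, bytes_per_elem: int = 2) -> float:
    """Size in GiB of a dense (batch, seq_len, vocab) logits tensor.

    bytes_per_elem: 2 for bf16/fp16, 4 for fp32.
    """
    return batch_size * seq_len * vocab_size * bytes_per_elem / 2**30


if __name__ == "__main__":
    # batch 8 from the original config; seq_len 1024 is an assumed value
    print(f"vocab 50368:  {mlm_logits_gib(8, 1024, 50_368):.3f} GiB")
    # 100_000 is a hypothetical size for an expanded CJK vocabulary
    print(f"vocab 100000: {mlm_logits_gib(8, 1024, 100_000):.3f} GiB")
```

Doubling the vocabulary doubles this tensor, and storing it in fp32 (bytes_per_elem=4) doubles it again.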
Hello everyone, are there any updates? I eventually completed the training (modern_bert_multilingual) using SDPA, but it would be even better if the FA2 memory issue could be resolved.