[performance] ensure `causal_mask` is created directly on device
What does this PR do?
@tjruwase and @tohtana discovered that `causal_mask` is currently being created on CPU and then moved to GPU during the forward pass of OPT (and, we think, other models). This appears to cause a significant performance degradation in multi-GPU environments due to the parallel host-to-device copies going on. It's not 100% clear to us why this is so bad, but here is what we observe before and after this patch:
Before this patch with OPT-125m on 8x A100s:

After the patch:

These numbers were gathered from a modified version of https://github.com/huggingface/transformers/blob/main/examples/pytorch/language-modeling/run_clm.py with `wall_clock_breakdown: true` turned on in our DeepSpeed config.
One major complication we see in accepting this PR is that the two functions being modified are copied across lots of different models, and the `make fix-copies` script doesn't seem to address all of them correctly for both `_make_causal_mask` and `_prepare_decoder_attention_mask`.
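For reference, the only DeepSpeed-specific change needed for these measurements is enabling the wall-clock breakdown. A minimal sketch of such a config follows; everything except `wall_clock_breakdown` is a hypothetical placeholder, not the config actually used for the numbers above.

```python
# Sketch of a DeepSpeed config fragment with per-step timing enabled.
# Only "wall_clock_breakdown" is taken from this PR; the batch size is a
# hypothetical placeholder value.
import json

ds_config = {
    "train_micro_batch_size_per_gpu": 8,  # hypothetical value
    "wall_clock_breakdown": True,         # print forward/backward/step wall-clock timings
}

with open("ds_config.json", "w") as f:
    json.dump(ds_config, f, indent=2)

# e.g. launch with:
#   deepspeed run_clm.py --deepspeed ds_config.json --model_name_or_path facebook/opt-125m ...
```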
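To make the underlying issue concrete, the allocation pattern this patch changes is roughly the following (an illustration only, not the exact transformers code; the real diff appears further down in this thread):

```python
# Illustration of the before/after allocation pattern (assumes a CUDA device is available).
import torch

tgt_len = 1024
dtype = torch.float16
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# before: the mask is materialized on CPU, then copied host-to-device on every forward pass
mask_cpu = torch.full((tgt_len, tgt_len), torch.finfo(dtype).min)
mask_before = mask_cpu.to(device)

# after: the mask is created directly on the target device, so no host-to-device copy occurs
mask_after = torch.full((tgt_len, tgt_len), torch.finfo(dtype).min, device=device)
```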
Who can review?
Tagging @sgugger and @stas00 to help triage to the right people
cc @thomasw21 @NouamaneTazi since both of you are experts on this kind of thing - to see if you have any general opinion and/or would like to review this PR too.
@jeffra Would it be possible for you (and/or @tjruwase and @tohtana) to provide the script that finds/measures/profiles the running time for this issue 🙏 ? It would be super helpful for us to dig into internally too.
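In the meantime, a minimal sketch of one way to time the OPT-125m forward pass on GPU (not the script being requested here; the numbers in this PR came from DeepSpeed's `wall_clock_breakdown` instead):

```python
# Rough forward-pass timing sketch; assumes a CUDA GPU is available.
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model_name = "facebook/opt-125m"
device = torch.device("cuda")

tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name).to(device).eval()

inputs = tokenizer("DeepSpeed makes distributed training easier.", return_tensors="pt").to(device)

start = torch.cuda.Event(enable_timing=True)
end = torch.cuda.Event(enable_timing=True)

with torch.no_grad():
    for _ in range(3):        # warmup iterations
        model(**inputs)
    torch.cuda.synchronize()
    start.record()
    for _ in range(20):       # timed iterations
        model(**inputs)
    end.record()
    torch.cuda.synchronize()

print(f"avg forward time: {start.elapsed_time(end) / 20:.2f} ms")
```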
LGTM, thanks a lot for the fix! Note that the same modification needs to be applied to BART (since OPT copies from BART) in order for all quality checks to pass.
FYI (@sgugger): @stas00 mentioned on Slack:
I tried to support Jeff and tell him how to make copies, but he found that many copies are either not tagged properly or the copied functions were completely renamed, and thus it's very difficult to make an automated transformers-wide fix
and in this PR description, the author(s) wrote:
One major complication we see in accepting this PR is that the two functions being modified are copied across lots of different models, and the `make fix-copies` script doesn't seem to address all of them correctly for both `_make_causal_mask` and `_prepare_decoder_attention_mask`.
It's likely that they expect us to help with this part. I can help (I was waiting for approval of the fix in OPT, which is done now).
I think just copying the same fix to BART and then applying `make fix-copies` is simple enough for this PR. Dealing with functions that are not copies or are named differently can indeed be done in follow-up PRs.
Ok, I've updated the BART implementation and attempted to get `make fix-copies` to work for me, but I think I might be doing something wrong. Some of the original issues I saw are now fixed in other models (e.g., https://github.com/huggingface/transformers/pull/22382 adds a `# Copied from` tag for llama). However, I am still seeing issues, I think coming from the fix-up scripts getting confused by the function signature change of `_make_causal_mask`. Also, I added the `# Copied from` tag into OPT for `_make_causal_mask`, which I think was part of my previous issue.
Can someone try `make fix-copies` on their side with this? You should be able to push to my branch.
For example, here's the diff of src/transformers/models/xglm/modeling_xglm.py after applying `make fix-copies` in this branch; it does not add `device` as an argument to `_make_causal_mask`:
```diff
diff --git a/src/transformers/models/xglm/modeling_xglm.py b/src/transformers/models/xglm/modeling_xglm.py
index 8a1955793..59851bd85 100755
--- a/src/transformers/models/xglm/modeling_xglm.py
+++ b/src/transformers/models/xglm/modeling_xglm.py
@@ -119,13 +119,13 @@ def _make_causal_mask(input_ids_shape: torch.Size, dtype: torch.dtype, past_key_
     Make causal mask used for bi-directional self-attention.
     """
     bsz, tgt_len = input_ids_shape
-    mask = torch.full((tgt_len, tgt_len), torch.tensor(torch.finfo(dtype).min))
-    mask_cond = torch.arange(mask.size(-1))
+    mask = torch.full((tgt_len, tgt_len), torch.tensor(torch.finfo(dtype).min, device=device), device=device)
+    mask_cond = torch.arange(mask.size(-1), device=device)
     mask.masked_fill_(mask_cond < (mask_cond + 1).view(mask.size(-1), 1), 0)
     mask = mask.to(dtype)
     if past_key_values_length > 0:
-        mask = torch.cat([torch.zeros(tgt_len, past_key_values_length, dtype=dtype), mask], dim=-1)
+        mask = torch.cat([torch.zeros(tgt_len, past_key_values_length, dtype=dtype, device=device), mask], dim=-1)
     return mask[None, None, :, :].expand(bsz, 1, tgt_len, tgt_len + past_key_values_length)
```
`make fix-copies` modifies all of these models, so ideally I don't want to edit them manually :)
modified: src/transformers/models/bigbird_pegasus/modeling_bigbird_pegasus.py
modified: src/transformers/models/biogpt/modeling_biogpt.py
modified: src/transformers/models/blenderbot/modeling_blenderbot.py
modified: src/transformers/models/blenderbot_small/modeling_blenderbot_small.py
modified: src/transformers/models/informer/modeling_informer.py
modified: src/transformers/models/llama/modeling_llama.py
modified: src/transformers/models/m2m_100/modeling_m2m_100.py
modified: src/transformers/models/marian/modeling_marian.py
modified: src/transformers/models/mbart/modeling_mbart.py
modified: src/transformers/models/mvp/modeling_mvp.py
modified: src/transformers/models/nllb_moe/modeling_nllb_moe.py
modified: src/transformers/models/pegasus/modeling_pegasus.py
modified: src/transformers/models/pegasus_x/modeling_pegasus_x.py
modified: src/transformers/models/plbart/modeling_plbart.py
modified: src/transformers/models/speech_to_text/modeling_speech_to_text.py
modified: src/transformers/models/speech_to_text_2/modeling_speech_to_text_2.py
modified: src/transformers/models/speecht5/modeling_speecht5.py
modified: src/transformers/models/time_series_transformer/modeling_time_series_transformer.py
modified: src/transformers/models/trocr/modeling_trocr.py
modified: src/transformers/models/whisper/modeling_whisper.py
modified: src/transformers/models/xglm/modeling_xglm.py
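For reference, after editing the signature by hand, each of those helpers should end up looking roughly like this (a sketch based on the diff above; the `# Copied from` target shown is the BART one used by several of these models, others may differ):

```python
# Sketch of the fully updated helper; it mirrors the body change from the diff above,
# with `device` added to the signature (the part `make fix-copies` does not update).
import torch


# Copied from transformers.models.bart.modeling_bart._make_causal_mask
def _make_causal_mask(
    input_ids_shape: torch.Size, dtype: torch.dtype, device: torch.device, past_key_values_length: int = 0
):
    """
    Make causal mask used for bi-directional self-attention.
    """
    bsz, tgt_len = input_ids_shape
    mask = torch.full((tgt_len, tgt_len), torch.tensor(torch.finfo(dtype).min, device=device), device=device)
    mask_cond = torch.arange(mask.size(-1), device=device)
    mask.masked_fill_(mask_cond < (mask_cond + 1).view(mask.size(-1), 1), 0)
    mask = mask.to(dtype)

    if past_key_values_length > 0:
        mask = torch.cat([torch.zeros(tgt_len, past_key_values_length, dtype=dtype, device=device), mask], dim=-1)
    return mask[None, None, :, :].expand(bsz, 1, tgt_len, tgt_len + past_key_values_length)
```

Callers such as `_prepare_decoder_attention_mask` then need to pass the device in explicitly (e.g. `device=inputs_embeds.device`).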
Ah yes, `make fix-copies` does not change the signature of the function, so that is indeed something to edit manually. If it's too much work I can try to push this to your branch tomorrow.
Sounds good, I might have some time this afternoon for this. Otherwise feel free to do it :) Just wasn't sure if this was an expected issue with the copy scripts or not.
Okay, all the models should be fixed now; `make fixup` is clean on my local tests.