[performance] ensure `causal_mask` is created directly on device

What does this PR do?

@tjruwase and @tohtana discovered that `causal_mask` is currently created on CPU and then moved to GPU during the forward pass of OPT (and, we think, other models). This appears to cause a significant performance degradation in multi-GPU environments due to the parallel host-to-device copies involved. It's not 100% clear to us why the impact is so large; a rough sketch of the allocation pattern being changed is below, followed by what we observe before and after this patch.
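
A minimal illustrative sketch of the two allocation patterns (not the actual `transformers` code; the function names here are ours):

```python
import torch

def make_mask_via_cpu(tgt_len: int, dtype: torch.dtype, device: torch.device) -> torch.Tensor:
    # Current pattern: the mask is built on CPU and only moved to the GPU
    # afterwards, so every forward pass pays for a host-to-device copy.
    mask = torch.full((tgt_len, tgt_len), torch.finfo(dtype).min)
    mask = torch.triu(mask, diagonal=1)  # mask future positions, keep the rest at 0
    return mask.to(device=device, dtype=dtype)

def make_mask_on_device(tgt_len: int, dtype: torch.dtype, device: torch.device) -> torch.Tensor:
    # Patched pattern: every tensor is allocated directly on the target
    # device, so no host-to-device copy happens in the forward pass.
    mask = torch.full((tgt_len, tgt_len), torch.finfo(dtype).min, device=device)
    mask = torch.triu(mask, diagonal=1)
    return mask.to(dtype)
```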

Before this patch, with OPT-125m on 8x A100s: [image]

After the patch: [image]

These numbers were gathered from a modified version of https://github.com/huggingface/transformers/blob/main/examples/pytorch/language-modeling/run_clm.py with `wall_clock_breakdown: true` turned on in our DeepSpeed config.
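
For reference, that flag sits at the top level of the DeepSpeed JSON config; a minimal sketch with every other field omitted:

```json
{
  "wall_clock_breakdown": true
}
```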

One major complication we see in accepting this PR is that the two functions being modified are copied across lots of different models, and the `make fix-copies` script doesn't seem to address all of them correctly for both `_make_causal_mask` and `_prepare_decoder_attention_mask`.

Who can review?

Tagging @sgugger and @stas00 to help triage this to the right people.

jeffra · Mar 25 '23

cc @thomasw21 @NouamaneTazi since both of you are experts on this kind of thing: to see if you have any general opinion and/or if you would like to review this PR too.

ydshieh · Mar 25 '23

@jeffra Would it be possible for you (and/or @tjruwase and @tohtana) to provide the script you used to measure/profile the running time for this issue 🙏? It would be super helpful for us to dig into this internally too.

ydshieh · Mar 25 '23

LGTM, thanks a lot for the fix! Note that the same modification needs to be applied to BART (since OPT copies from BART) in order for all quality checks to pass.
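
For anyone unfamiliar with the copy mechanism: shared functions carry a `# Copied from` marker that `make fix-copies` uses to keep them in sync with their source. Roughly, the convention looks like this (illustrative paraphrase, not the exact OPT source):

```python
import torch

# Copied from transformers.models.bart.modeling_bart._make_causal_mask
def _make_causal_mask(input_ids_shape: torch.Size, dtype: torch.dtype, past_key_values_length: int = 0):
    # The body is kept identical to the BART version; `make fix-copies`
    # rewrites it from the source whenever the two diverge.
    ...
```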

FYI (@sgugger): @stas00 mentioned on Slack:

I tried to support Jeff by telling him how to make copies, but he found that many copies are either not tagged properly or the copied functions were completely renamed, and thus it's very difficult to make an automated transformers-wide fix

and in this PR description, the author(s) wrote:

One major complication we see in accepting this PR is that the two functions being modified are copied across lots of different models, and the `make fix-copies` script doesn't seem to address all of them correctly for both `_make_causal_mask` and `_prepare_decoder_attention_mask`.

It's likely that they expect us to help with this part. I can help (I was waiting for approval of the fix in OPT, which is done now).

ydshieh · Mar 27 '23

I think just copying the same fix to BART and then applying `make fix-copies` is simple enough for this PR. Dealing with functions that are not tagged as copies or are named differently can indeed be done in follow-up PRs.

sgugger · Mar 27 '23

Ok, I've updated the BART implementation and attempted to get `make fix-copies` to work for me, but I think I might be doing something wrong. Some of the original issues I saw are now fixed on other models (e.g., https://github.com/huggingface/transformers/pull/22382 adds a `# Copied from` tag for llama). However, I am still seeing issues that I think come from the fix-up scripts getting confused by the function signature change of `_make_causal_mask`. Also, I added the `# Copied from` tag for `_make_causal_mask` into OPT, which I think was part of my previous issue.

Can someone try `make fix-copies` on their side with this? You should be able to push to my branch.

For example, here's the diff of `src/transformers/models/xglm/modeling_xglm.py` after applying `make fix-copies` in this branch; note that it does not add `device` as an argument to `_make_causal_mask`:

```diff
diff --git a/src/transformers/models/xglm/modeling_xglm.py b/src/transformers/models/xglm/modeling_xglm.py
index 8a1955793..59851bd85 100755
--- a/src/transformers/models/xglm/modeling_xglm.py
+++ b/src/transformers/models/xglm/modeling_xglm.py
@@ -119,13 +119,13 @@ def _make_causal_mask(input_ids_shape: torch.Size, dtype: torch.dtype, past_key_
     Make causal mask used for bi-directional self-attention.
     """
     bsz, tgt_len = input_ids_shape
-    mask = torch.full((tgt_len, tgt_len), torch.tensor(torch.finfo(dtype).min))
-    mask_cond = torch.arange(mask.size(-1))
+    mask = torch.full((tgt_len, tgt_len), torch.tensor(torch.finfo(dtype).min, device=device), device=device)
+    mask_cond = torch.arange(mask.size(-1), device=device)
     mask.masked_fill_(mask_cond < (mask_cond + 1).view(mask.size(-1), 1), 0)
     mask = mask.to(dtype)

     if past_key_values_length > 0:
-        mask = torch.cat([torch.zeros(tgt_len, past_key_values_length, dtype=dtype), mask], dim=-1)
+        mask = torch.cat([torch.zeros(tgt_len, past_key_values_length, dtype=dtype, device=device), mask], dim=-1)
     return mask[None, None, :, :].expand(bsz, 1, tgt_len, tgt_len + past_key_values_length)
```

It modifies all of these models, so ideally I don't want to edit them all manually :)

        modified:   src/transformers/models/bigbird_pegasus/modeling_bigbird_pegasus.py
        modified:   src/transformers/models/biogpt/modeling_biogpt.py
        modified:   src/transformers/models/blenderbot/modeling_blenderbot.py
        modified:   src/transformers/models/blenderbot_small/modeling_blenderbot_small.py
        modified:   src/transformers/models/informer/modeling_informer.py
        modified:   src/transformers/models/llama/modeling_llama.py
        modified:   src/transformers/models/m2m_100/modeling_m2m_100.py
        modified:   src/transformers/models/marian/modeling_marian.py
        modified:   src/transformers/models/mbart/modeling_mbart.py
        modified:   src/transformers/models/mvp/modeling_mvp.py
        modified:   src/transformers/models/nllb_moe/modeling_nllb_moe.py
        modified:   src/transformers/models/pegasus/modeling_pegasus.py
        modified:   src/transformers/models/pegasus_x/modeling_pegasus_x.py
        modified:   src/transformers/models/plbart/modeling_plbart.py
        modified:   src/transformers/models/speech_to_text/modeling_speech_to_text.py
        modified:   src/transformers/models/speech_to_text_2/modeling_speech_to_text_2.py
        modified:   src/transformers/models/speecht5/modeling_speecht5.py
        modified:   src/transformers/models/time_series_transformer/modeling_time_series_transformer.py
        modified:   src/transformers/models/trocr/modeling_trocr.py
        modified:   src/transformers/models/whisper/modeling_whisper.py
        modified:   src/transformers/models/xglm/modeling_xglm.py

jeffra · Mar 27 '23

Ah yes, `make fix-copies` does not change the signature of the function, so that is indeed something to edit manually. If it's too much work, I can try to push this to your branch tomorrow.
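
Concretely, the manual edit in each affected file is the signature plus the `device=` arguments that `make fix-copies` won't generate on its own. A sketch of the intended end state, assuming each model mirrors the OPT fix (not a verbatim copy of the final code):

```python
import torch

def _make_causal_mask(
    input_ids_shape: torch.Size, dtype: torch.dtype, device: torch.device, past_key_values_length: int = 0
):
    # Build the causal (lower-triangular) additive mask directly on `device`
    # so the forward pass never pays for a host-to-device copy.
    bsz, tgt_len = input_ids_shape
    mask = torch.full((tgt_len, tgt_len), torch.finfo(dtype).min, device=device)
    mask_cond = torch.arange(mask.size(-1), device=device)
    mask.masked_fill_(mask_cond < (mask_cond + 1).view(mask.size(-1), 1), 0)
    mask = mask.to(dtype)

    if past_key_values_length > 0:
        mask = torch.cat([torch.zeros(tgt_len, past_key_values_length, dtype=dtype, device=device), mask], dim=-1)
    return mask[None, None, :, :].expand(bsz, 1, tgt_len, tgt_len + past_key_values_length)
```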

sgugger · Mar 27 '23

Ah yes, `make fix-copies` does not change the signature of the function, so that is indeed something to edit manually. If it's too much work, I can try to push this to your branch tomorrow.

Sounds good, I might have some time this afternoon for this; otherwise feel free to do it :) I just wasn't sure whether this was an expected issue with the copy scripts or not.

jeffra · Mar 27 '23

Okay, all the models should be fixed now; `make fixup` comes back clean in my local tests.

jeffra · Mar 28 '23