[PyTorch] FlashAttention: causal masking enforced in cross attention due to sliding window attention

Open · Marks101 opened this issue 1 year ago · 1 comment

Hi transformer-engine team,

we noticed that in a decoder layer with self_attn_mask_type="causal", the following line https://github.com/NVIDIA/TransformerEngine/blob/main/transformer_engine/pytorch/transformer.py#L594 sets window_size = (-1, 0). This window_size is passed to both the self attention and the cross attention. In the cross attention, this enforces causal masking, although only padding masking is intended there. This bug was probably introduced in https://github.com/NVIDIA/TransformerEngine/pull/551.
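To illustrate why window_size = (-1, 0) amounts to causal masking, here is a minimal sketch of the sliding-window rule as documented for FlashAttention 2: query `i` attends to keys in `[i + seqlen_k - seqlen_q - left, i + seqlen_k - seqlen_q + right]` (inclusive), where `-1` means unbounded. The helper `window_mask` is hypothetical. It only builds a boolean mask in plain Python and does not call Transformer Engine or FlashAttention.

```python
def window_mask(seqlen_q, seqlen_k, window_size):
    """Return mask[i][j] == True if query i may attend to key j.

    Hypothetical helper mirroring FlashAttention's documented sliding-window
    rule (bottom-right aligned); -1 on either side means no limit.
    """
    left, right = window_size
    offset = seqlen_k - seqlen_q  # queries are aligned to the last keys
    mask = []
    for i in range(seqlen_q):
        row = []
        for j in range(seqlen_k):
            lower_ok = left < 0 or j >= i + offset - left
            upper_ok = right < 0 or j <= i + offset + right
            row.append(lower_ok and upper_ok)
        mask.append(row)
    return mask


# Cross attention: 2 decoder queries attending to 4 encoder keys.
for row in window_mask(2, 4, (-1, 0)):
    print([int(v) for v in row])
# [1, 1, 1, 0]
# [1, 1, 1, 1]
```

With (-1, 0), the first decoder query cannot see the last encoder token, even though cross attention should only hide padded keys. Only (-1, -1), meaning no window, leaves the cross-attention mask to padding alone.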

I am unsure what a proper fix might be. Does windowed attention really make sense for cross attention with padding? Considering that the FlashAttention backend re-packs the sequences with PackTensors, this might have unintended side effects.
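One possible side effect can be sketched as follows. This assumes that a packed (varlen) path applies the window to each sequence's unpadded key length, as FlashAttention's varlen interface does with per-sequence lengths. For comparison, it also assumes a hypothetical padded path that applies the window to the padded length and then masks padded keys. `window_mask`, `padded_mask` and `packed_mask` are illustrative names, not Transformer Engine APIs.

```python
def window_mask(seqlen_q, seqlen_k, window_size):
    # Hypothetical helper: FlashAttention's documented bottom-right-aligned
    # sliding-window rule; -1 on either side means no limit.
    left, right = window_size
    offset = seqlen_k - seqlen_q
    return [
        [(left < 0 or j >= i + offset - left) and (right < 0 or j <= i + offset + right)
         for j in range(seqlen_k)]
        for i in range(seqlen_q)
    ]


def padded_mask(seqlen_q, padded_k, actual_k, window_size):
    # Hypothetical padded path: window uses the padded key length,
    # then keys at positions >= actual_k are masked as padding.
    win = window_mask(seqlen_q, padded_k, window_size)
    return [[win[i][j] and j < actual_k for j in range(padded_k)]
            for i in range(seqlen_q)]


def packed_mask(seqlen_q, padded_k, actual_k, window_size):
    # Hypothetical packed path: window uses the actual key length;
    # padded keys are never part of the packed sequence.
    win = window_mask(seqlen_q, actual_k, window_size)
    return [[j < actual_k and win[i][j] for j in range(padded_k)]
            for i in range(seqlen_q)]


# 2 decoder queries; encoder output padded to 4 keys, only 2 of them real.
for name, fn in (("padded", padded_mask), ("packed", packed_mask)):
    print(name, [[int(v) for v in row] for row in fn(2, 4, 2, (-1, 0))])
# padded [[1, 1, 0, 0], [1, 1, 0, 0]]
# packed [[1, 0, 0, 0], [1, 1, 0, 0]]
```

Under these assumptions, the same window_size gives the first query a different set of visible keys depending on whether sequences are padded or packed. This is the kind of backend-dependent behavior the issue is concerned about.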

Cheers Markus

Marks101, Jan 25 '24 09:01

@cyanguwa please take a look at this issue.

ptrendx, Jan 25 '24 21:01