TransformerEngine
[PyTorch] FlashAttention: causal masking enforced in cross attention due to sliding window attention
Hi transformer-engine team,
we noticed that in a decoder layer with self_attn_mask_type="causal", the following line https://github.com/NVIDIA/TransformerEngine/blob/main/transformer_engine/pytorch/transformer.py#L594 sets window_size = (-1, 0). This window_size is passed both to self attention and to cross attention. In cross attention this enforces causal masking even though padding masking is intended. This bug was probably introduced in https://github.com/NVIDIA/TransformerEngine/pull/551.
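For reference, here is roughly the setup we are talking about (a minimal sketch only; the shapes, dtype and mask layout are illustrative assumptions, not our actual code):

```python
import torch
import transformer_engine.pytorch as te

hidden_size, num_heads, ffn_hidden_size = 256, 8, 1024
seq_len_q, seq_len_kv, batch = 16, 32, 2

# Decoder layer with causal self attention; transformer.py#L594 derives
# window_size = (-1, 0) from self_attn_mask_type and reuses it for cross attention.
layer = te.TransformerLayer(
    hidden_size,
    ffn_hidden_size,
    num_heads,
    layer_type="decoder",
    self_attn_mask_type="causal",
    params_dtype=torch.bfloat16,
)

# TE default layout is [seq, batch, hidden]
hidden_states = torch.randn(seq_len_q, batch, hidden_size, device="cuda", dtype=torch.bfloat16)
encoder_output = torch.randn(seq_len_kv, batch, hidden_size, device="cuda", dtype=torch.bfloat16)

# Padding mask for the encoder sequence (assumed layout: True marks padded keys).
enc_dec_attn_mask = torch.zeros(batch, 1, 1, seq_len_kv, device="cuda", dtype=torch.bool)
enc_dec_attn_mask[..., seq_len_kv // 2:] = True  # second half of the encoder output is padding

out = layer(
    hidden_states,
    encoder_output=encoder_output,
    enc_dec_attn_mask=enc_dec_attn_mask,
)
# With the FlashAttention backend, the cross attention here is additionally masked
# causally, because the self-attention window_size = (-1, 0) is reused for it.
```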
I am unsure what a proper fix might be. Does windowed attention really make sense for cross attention with padding? Considering that the FlashAttention backend re-packs the sequences with PackTensors, this might have unintended side effects.
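One direction I could imagine (just a sketch of the intent, not a tested patch; the helper name and its arguments are made up for illustration) would be to derive the window from the self-attention mask type only and always keep the unrestricted window (-1, -1) for cross attention:

```python
from typing import Optional, Tuple

def resolve_window_size(
    attention_type: str,                            # "self" or "cross"
    attn_mask_type: str,                            # e.g. "causal", "padding", "padding_causal"
    user_window_size: Optional[Tuple[int, int]] = None,
) -> Tuple[int, int]:
    """Hypothetical helper: pick the FlashAttention window per attention block.

    Only causal self attention gets the causal window (-1, 0); cross attention
    keeps the unrestricted window (-1, -1) so that a padding mask behaves as intended.
    """
    if user_window_size is not None:
        return user_window_size
    if attention_type == "self" and "causal" in attn_mask_type:
        return (-1, 0)
    return (-1, -1)

assert resolve_window_size("self", "causal") == (-1, 0)
assert resolve_window_size("cross", "padding") == (-1, -1)
```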
Cheers,
Markus
@cyanguwa, please take a look at this issue.