TransformerEngine
[PyTorch] Definition of attention mask
Hi team,
I am wondering about the definition of the attention mask in transformer-engine. I did not find an explanation in the docs. Does True mean that the position takes part in attention or that it is masked out?
These two code locations suggest that True means 'masked out': https://github.com/NVIDIA/TransformerEngine/blob/main/tests/pytorch/test_numerics.py#L71 https://github.com/NVIDIA/TransformerEngine/blob/main/transformer_engine/pytorch/utils.py#L34
But for flash attention with padding, the cumulative sequence lengths are computed from a sum over the mask, which suggests the opposite: https://github.com/NVIDIA/TransformerEngine/blob/main/transformer_engine/pytorch/attention.py#L180
Thanks for clarifying this!
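For concreteness, here is a minimal sketch of the two possible readings of a [batch, seq_len] padding mask when deriving cumulative sequence lengths. The tensor values are illustrative and this is not the actual TE code, just the ambiguity in question:

```python
import torch

# Illustrative [batch, seq_len] boolean padding mask.
mask = torch.tensor([[False, False, True],
                     [False, True,  True]])

# Reading 1: True means "masked out", so valid tokens are where the mask is False.
seqlens_true_masked = (~mask).sum(dim=1)   # tensor([2, 1])
# Reading 2: True means "takes part in attention", so the sum is taken directly.
seqlens_true_attends = mask.sum(dim=1)     # tensor([1, 2])

# Cumulative sequence lengths in the layout varlen flash-attention kernels expect.
cu_seqlens = torch.cat([torch.zeros(1, dtype=torch.int64),
                        seqlens_true_masked.cumsum(dim=0)])
print(cu_seqlens)  # tensor([0, 2, 3])
```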
@ptrendx @ksivaman It would be really great if you could give us some feedback on this. We are currently noticing that, because of this, UnfusedDotProductAttention (which uses attention_mask_func) and FlashAttention behave differently when a mask is provided.
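To make the discrepancy concrete, here is a minimal sketch of Megatron-style masking under the 'True means masked out' reading, which the linked attention_mask_func appears to follow (shapes and the fill value are illustrative):

```python
import torch

def attention_mask_func(attention_scores, attention_mask):
    # Megatron-style convention: True positions are filled with a large
    # negative value, i.e. True means "masked out".
    return attention_scores.masked_fill(attention_mask, -10000.0)

scores = torch.randn(1, 1, 4, 4)                  # [batch, heads, q_len, k_len]
mask = torch.zeros(1, 1, 4, 4, dtype=torch.bool)
mask[..., -1] = True                              # last key position is padding

probs = torch.softmax(attention_mask_func(scores, mask), dim=-1)
print(probs[..., -1])  # ~0: the True positions receive no attention weight
```

If the flash-attention path instead sums the mask directly to obtain the valid lengths, the same boolean tensor selects the opposite set of tokens, which would explain the differing outputs.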
I think True in a mask should mean 'masked out', but maybe we are not doing that consistently in our code. I can take a look at this and change the sum to the opposite. I will also change anywhere else that calls get_cu_seqlens_and_indices() or get_cu_seqlens(). Thanks.
Hi @cyanguwa. Thanks for looking into this 😃
Just one additional note: it is actually not entirely clear that True means 'masked out'. This seems to be the convention in Megatron-LM, but torch.nn.functional.scaled_dot_product_attention, for example, defines it the other way round.
But we would be totally fine either way.
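For comparison, a minimal sketch of the torch.nn.functional.scaled_dot_product_attention convention, where a boolean attn_mask value of True means the position does take part in attention (shapes are illustrative):

```python
import torch
import torch.nn.functional as F

q = k = v = torch.randn(1, 1, 4, 8)                # [batch, heads, seq, head_dim]

# Boolean attn_mask: True = attend, False = masked out (the opposite of the
# Megatron-LM convention discussed above).
attn_mask = torch.ones(1, 1, 4, 4, dtype=torch.bool)
attn_mask[..., -1] = False                         # exclude the last key position

out = F.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask)
print(out.shape)  # torch.Size([1, 1, 4, 8])
```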