TransformerEngine
[PyTorch/Jax] Fix attention mask definition, and sliding window for decoder
Description
This PR helps resolve issues #614 and #629.
Moving forward, we'd like to define the attention mask consistently in PyTorch, Jax and Paddle: `True` masks out the corresponding position and `False` allows that position to participate in attention. A typical causal mask would look like:
0, 1, 1, 1
0, 0, 1, 1
0, 0, 0, 1
0, 0, 0, 0
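For concreteness, a minimal sketch (plain PyTorch, not TE's internal helpers) of building such a mask under this convention:

```python
import torch

# Sketch only: build a causal mask where True (1) = mask out
# and False (0) = participate in attention.
seq_len = 4
causal_mask = torch.triu(torch.ones(seq_len, seq_len, dtype=torch.bool), diagonal=1)
print(causal_mask.int())
# tensor([[0, 1, 1, 1],
#         [0, 0, 1, 1],
#         [0, 0, 0, 1],
#         [0, 0, 0, 0]], dtype=torch.int32)
```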
Type of change
- [ ] Documentation change (change only to the documentation, either a fix or new content)
- [ ] Bug fix (non-breaking change which fixes an issue)
- [ ] New feature (non-breaking change which adds functionality)
- [x] Breaking change (fix or feature that would cause existing functionality to not work as expected)
Changes
- Unify the implementation and unit tests regarding the attention mask definition across TE PyTorch, Jax and Paddle
- Differentiate sliding window initialization for the decoder and encoder
- Save `aux_ctx_tensors` and similar tensors using the `save_for_backward` call instead of attaching them to `ctx` (see the sketch after this list)
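A minimal, hypothetical sketch of the `save_for_backward` pattern referenced in the last bullet (the class name and computation here are illustrative, not TE's actual attention backend):

```python
import torch

class _ExampleAttnFunc(torch.autograd.Function):
    """Illustrative autograd function showing the save_for_backward pattern."""

    @staticmethod
    def forward(ctx, q, k, v):
        out = q * k + v  # placeholder computation
        # Instead of stashing tensors as attributes (e.g. ctx.aux_ctx_tensors = ...),
        # register them with autograd so versioning and lifetime are handled correctly.
        ctx.save_for_backward(q, k, v)
        return out

    @staticmethod
    def backward(ctx, grad_out):
        q, k, v = ctx.saved_tensors
        return grad_out * k, grad_out * q, grad_out
```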
Checklist:
- [x] I have read and followed the contributing guidelines
- [x] The functionality is complete
- [x] I have commented my code, particularly in hard-to-understand areas
- [x] I have made corresponding changes to the documentation
- [x] My changes generate no new warnings
- [x] I have added tests that prove my fix is effective or that my feature works
- [x] New and existing unit tests pass locally with my changes
For reference, most frameworks treat `True` entries in the attention mask as inclusion:

| Implementation | `True` in attention mask implies inclusion |
|---|---|
| `torch.nn.MultiheadAttention` | No |
| `torch.nn.functional.scaled_dot_product_attention` | Yes |
| `flax.linen.MultiHeadAttention` | Yes |
| `paddle.nn.MultiHeadAttention` | Yes |
| `keras.layers.Attention` | Yes |
| `tf.keras.layers.Attention` | Yes |
| cuDNN fused attention | Yes |
| `apex.contrib.multihead_attn.SelfMultiheadAttn` | No |
| `flash_attn.fused_softmax.ScaledMaskedSoftmax` | No |
@timmoon10 thanks for the table :smiley_cat: I was doing something like that myself (OCD? not if someone else is also doing it!).
After looking through some of the other frameworks' implementations (PyTorch/Jax/Paddle), and our own implementation and unit tests, I'm leaning towards True = masking out; False = inclusion in attention actually.
I don't think Keras is a popular framework these days, so I don't think we should follow its convention. Also, PyTorch's scaled_dot_product_attention is a beta feature, while its MultiheadAttention is stable. cuDNN's attention doesn't really use a mask in my opinion, not the True/False kind anyway. For causal, it uses an is_causal=True/False flag, and for padding, a cu_seqlens tensor (which we converted the mask to in TE).
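(As a side note, here is a sketch of that padding-mask → cu_seqlens conversion; `padding_mask_to_cu_seqlens` is an illustrative name, not TE's actual helper.)

```python
import torch

def padding_mask_to_cu_seqlens(padding_mask: torch.Tensor) -> torch.Tensor:
    """padding_mask: [batch, seq_len] bool, True = padded token (masked out)."""
    seqlens = (~padding_mask).sum(dim=1)
    cu_seqlens = torch.zeros(padding_mask.shape[0] + 1, dtype=torch.int32)
    cu_seqlens[1:] = torch.cumsum(seqlens, dim=0)
    return cu_seqlens

mask = torch.tensor([[False, False, True, True],
                     [False, False, False, True]])
print(padding_mask_to_cu_seqlens(mask))  # tensor([0, 2, 5], dtype=torch.int32)
```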
Our fused softmax implementation uses True = masking out; False = inclusion in attention, which makes all the softmax calls in TE PyTorch, Jax and Paddle subscribe to that definition. They are not the fastest attention paths, so maybe we don't have to care about them as much. But I think the original term for causal masking, when it came out, was upper-triangle masking, which implies the upper triangle (rather than the lower) gets masked, with 1 meaning masking out. This is probably also why apex's and Tri Dao's attention softmax both adopted that definition.
Any thoughts? :smiley: thanks!
/te-ci
Hi @cyanguwa, thanks for making things clear. TE-JAX already treats mask_bit = True as masking out, and this is documented, so I think JAX's unit tests don't need to be changed in this PR.
@cyanguwa Thank you so much for looking into this and clarifying the mask definition ❤️
/te-ci pytorch
/te-ci jax
/te-ci pytorch
/te-ci jax
/te-ci pytorch