
[PyTorch/Jax] Fix attention mask definition and sliding window for decoder

Open · cyanguwa opened this pull request 1 year ago • 7 comments

Description

This PR helps resolve issues #614 and #629.

Moving forward, we'd like to define the attention mask consistently in PyTorch, JAX and Paddle: True means the corresponding position is masked out, and False means that position is allowed to participate in attention. A typical causal mask would look like:

0, 1, 1, 1
0, 0, 1, 1
0, 0, 0, 1
0, 0, 0, 0
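
For illustration, here is a minimal sketch (not TE code) of building that 4x4 causal mask in plain PyTorch under the proposed convention, where True marks a position as masked out:

```python
import torch

seqlen = 4
# True strictly above the diagonal = masked out (cannot be attended to);
# False on and below the diagonal = allowed to participate in attention.
causal_mask = torch.triu(torch.ones(seqlen, seqlen, dtype=torch.bool), diagonal=1)
# tensor([[False,  True,  True,  True],
#         [False, False,  True,  True],
#         [False, False, False,  True],
#         [False, False, False, False]])
```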

Type of change

  • [ ] Documentation change (change only to the documentation, either a fix or new content)
  • [ ] Bug fix (non-breaking change which fixes an issue)
  • [ ] New feature (non-breaking change which adds functionality)
  • [x] Breaking change (fix or feature that would cause existing functionality to not work as expected)

Changes

  • Unify the implementation and unit tests regarding attention mask in TE PyTorch, Jax and Paddle.
  • Differentiate sliding window initialization for decoder and encoder
  • Save aux_ctx_tensors and similar tensors via the save_for_backward call instead of attaching them to ctx (see the sketch after this list)
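
As a rough illustration of the save_for_backward change in the last bullet, here is a minimal sketch of the pattern (not TE's actual code; the function and tensor names are hypothetical placeholders):

```python
import torch

class FusedAttnFunc(torch.autograd.Function):  # hypothetical, heavily simplified
    @staticmethod
    def forward(ctx, q, k, v):
        out = q                      # stand-in for the real fused attention output
        aux_ctx_tensor = k + v       # stand-in for softmax stats, rng state, etc.
        # Preferred: register tensors via save_for_backward so autograd can track
        # them and flag in-place modifications between forward and backward.
        ctx.save_for_backward(q, k, v, aux_ctx_tensor)
        # Avoided: stashing them directly on ctx (e.g. ctx.aux = aux_ctx_tensor),
        # which bypasses those checks.
        return out

    @staticmethod
    def backward(ctx, grad_out):
        q, k, v, aux_ctx_tensor = ctx.saved_tensors
        return grad_out, grad_out, grad_out  # placeholder gradients
```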

Checklist:

  • [x] I have read and followed the contributing guidelines
  • [x] The functionality is complete
  • [x] I have commented my code, particularly in hard-to-understand areas
  • [x] I have made corresponding changes to the documentation
  • [x] My changes generate no new warnings
  • [x] I have added tests that prove my fix is effective or that my feature works
  • [x] New and existing unit tests pass locally with my changes

cyanguwa avatar Apr 26 '24 22:04 cyanguwa

@timmoon10 thanks for the table :smiley_cat: I was doing something like that myself (OCD? not if someone else is also doing it!).

After looking through some of the other frameworks' implementations (PyTorch/Jax/Paddle), and our own implementation and unit tests, I'm actually leaning towards True = masking out; False = inclusion in attention.

I don't think Keras is a popular framework these days, so I don't think we should follow its convention. Also, PyTorch's scaled_dot_product_attention is a beta feature, but its MultiheadAttention is stable. cuDNN's attention doesn't really use a mask in my opinion, not the True/False kind anyway: for causal it uses an is_causal=True/False flag, and for padding a cu_seqlens tensor (which we convert the mask to in TE).
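
For reference, a minimal sketch of the two PyTorch-side interfaces mentioned above, based on my reading of the current PyTorch docs (not TE code):

```python
import torch
import torch.nn.functional as F

q = k = v = torch.randn(1, 2, 4, 8)   # (batch, heads, seqlen, head_dim)

# scaled_dot_product_attention expresses causality with a flag, no explicit mask.
out = F.scaled_dot_product_attention(q, k, v, is_causal=True)

# nn.MultiheadAttention takes a boolean attn_mask where True means
# "not allowed to attend", i.e. True = masked out.
mha = torch.nn.MultiheadAttention(embed_dim=16, num_heads=2, batch_first=True)
x = torch.randn(1, 4, 16)
causal_mask = torch.triu(torch.ones(4, 4, dtype=torch.bool), diagonal=1)
out2, _ = mha(x, x, x, attn_mask=causal_mask)
```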

Our fused softmax implementation uses True = masking out; False = inclusion in attention, which means all the softmax calls in TE PyTorch, JAX and Paddle already subscribe to that definition. They are not the fastest attention paths, so maybe we don't have to weight them as heavily. But I think the original term for this, for causal, was upper-triangle masking, which suggests the mask marks the upper triangle rather than the lower one, and that 1 means masking out. This is probably also why both apex's and Tri Dao attention's softmax adopted that definition.
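
To make that convention concrete, here is a tiny masked-softmax sketch (illustrative only, not TE's fused kernel):

```python
import torch

scores = torch.randn(4, 4)                                          # raw attention scores
mask = torch.triu(torch.ones(4, 4, dtype=torch.bool), diagonal=1)   # True = mask out
probs = torch.softmax(scores.masked_fill(mask, float("-inf")), dim=-1)
# Each row only assigns probability to positions where mask == False.
```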

Any thoughts? :smiley: thanks!

cyanguwa avatar Apr 30 '24 23:04 cyanguwa

/te-ci

cyanguwa avatar May 02 '24 22:05 cyanguwa

Hi @cyanguwa, thanks for making things clear. TE-JAX already treats mask_bit = True as masking out, and it is documented, so I think JAX's unit tests don't need to be changed in this PR.

zlsh80826 avatar May 03 '24 03:05 zlsh80826

@cyanguwa Thank you so much for looking into this and clarifying the mask definition ❤️

Marks101 avatar May 07 '24 06:05 Marks101

/te-ci pytorch

cyanguwa avatar May 14 '24 02:05 cyanguwa

/te-ci jax

cyanguwa avatar May 14 '24 22:05 cyanguwa

/te-ci pytorch

cyanguwa avatar May 15 '24 17:05 cyanguwa

/te-ci jax

cyanguwa avatar May 15 '24 17:05 cyanguwa

/te-ci pytorch

cyanguwa avatar May 16 '24 19:05 cyanguwa