
[C/PyTorch] Add THD support for cuDNN attention

Open · cyanguwa opened this pull request 9 months ago · 7 comments

Description

This PR adds THD support for fused attention (the F16_arbitrary_seqlen backend). This feature allows users to run attention in two additional cases:

case 1: no padding between sequences
    8 sequences are packed as abcddddeeeffgh
    total sequence length t = 14
    cumulative sequence lengths: cu_seqlens = [0, 1, 2, 3, 7, 10, 12, 13, 14]
    offset tensor that helps cuDNN step into qkv tensors: offset_tensor = [0, 1, 2, 3, 7, 10, 12, 13]

case 2: with padding between sequences
    8 sequences are packed as abc0ddddeee0ffgh
    total sequence length t = 16
    cumulative sequence lengths: cu_seqlens = [0, 1, 2, 3, 7, 10, 12, 13, 14]
    offset tensor that helps cuDNN step into qkv tensors: offset_tensor = [0, 1, 2, 4, 8, 12, 14, 15]
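
For concreteness, here is a minimal sketch (illustrative Python/PyTorch, not code from this PR) of how cu_seqlens and the offset tensor relate in the two cases above. The pad_before tensor is a hypothetical helper encoding how many pad tokens sit in front of each sequence:

```python
import torch

# Actual lengths of the 8 packed sequences a..h from the examples above.
seqlens = torch.tensor([1, 1, 1, 4, 3, 2, 1, 1], dtype=torch.int32)

# cu_seqlens: cumulative actual lengths, prefixed with 0 (same in both cases).
cu_seqlens = torch.cat([torch.zeros(1, dtype=torch.int32),
                        torch.cumsum(seqlens, 0).to(torch.int32)])
# tensor([ 0,  1,  2,  3,  7, 10, 12, 13, 14])

# Case 1 (no padding): each sequence starts where the previous one ends,
# so the offsets are just cu_seqlens without its last entry.
offsets_no_pad = cu_seqlens[:-1]
# tensor([ 0,  1,  2,  3,  7, 10, 12, 13])

# Case 2 (with padding): one pad token sits before "dddd" and one before "ff",
# so each offset is shifted by the padding accumulated in front of it.
pad_before = torch.tensor([0, 0, 0, 1, 0, 1, 0, 0], dtype=torch.int32)
offsets_padded = cu_seqlens[:-1] + torch.cumsum(pad_before, 0).to(torch.int32)
# tensor([ 0,  1,  2,  4,  8, 12, 14, 15])
```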

Type of change

  • [ ] Documentation change (change only to the documentation, either a fix or new content)
  • [ ] Bug fix (non-breaking change which fixes an issue)
  • [x] New feature (non-breaking change which adds functionality)
  • [ ] Breaking change (fix or feature that would cause existing functionality to not work as expected)

Changes

  • Added THD support to cuDNN attention at the C level and in the PyTorch modules (see the usage sketch after this list). JAX and Paddle support is still to come.
  • Updated cudnn-frontend to v1.4.0.
  • Added tests for max_seqlen_q=1 and head_dim=256 inference cases for thd, bshd, and sbhd formats.
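
As referenced above, here is a hypothetical PyTorch usage sketch for THD attention with the case 1 layout. This is not code from the PR: the DotProductAttention module and its argument names follow Transformer Engine's PyTorch API around this time and may differ between releases, and the head count and head dim are arbitrary:

```python
import torch
from transformer_engine.pytorch import DotProductAttention

# Pack the 8 sequences from case 1 into a single THD tensor:
# t = 14 total tokens, h heads, d head dim (h and d chosen arbitrarily here).
t, h, d = 14, 16, 64
q = torch.randn(t, h, d, dtype=torch.bfloat16, device="cuda")
k = torch.randn(t, h, d, dtype=torch.bfloat16, device="cuda")
v = torch.randn(t, h, d, dtype=torch.bfloat16, device="cuda")
cu_seqlens = torch.tensor([0, 1, 2, 3, 7, 10, 12, 13, 14],
                          dtype=torch.int32, device="cuda")

attn = DotProductAttention(num_attention_heads=h, kv_channels=d)
out = attn(q, k, v,
           qkv_format="thd",
           attn_mask_type="padding",   # THD requires a padding mask type
           cu_seqlens_q=cu_seqlens,
           cu_seqlens_kv=cu_seqlens,
           max_seqlen_q=4,             # longest packed sequence ("dddd")
           max_seqlen_kv=4)
```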

Checklist:

  • [x] I have read and followed the contributing guidelines
  • [x] The functionality is complete
  • [x] I have commented my code, particularly in hard-to-understand areas
  • [x] I have made corresponding changes to the documentation
  • [x] My changes generate no new warnings
  • [x] I have added tests that prove my fix is effective or that my feature works
  • [x] New and existing unit tests pass locally with my changes

cyanguwa avatar May 02 '24 23:05 cyanguwa

This PR changes the fused attention API. I'm wondering if we can also change the cu_seqlens arguments to actual_seqlens, so that we can avoid the kernel that converts actual_seqlens to cu_seqlens.

Actually, using cu_seqlens as API arguments makes JAX launch two additional kernels: mask -> actual_seqlens -> sharding (DP) -> kernel to convert actual_seqlens to cu_seqlens -> call nvte_fmha -> kernel to convert cu_seqlens back to actual_seqlens.
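
For illustration only (a hypothetical sketch, not actual TE/JAX code), the two extra conversion kernels correspond roughly to a prefix sum on the way in and adjacent differences on the way out:

```python
import jax.numpy as jnp

def actual_to_cu(actual_seqlens):
    # extra kernel on the way in: prefix sum of per-sequence lengths
    return jnp.concatenate([jnp.zeros(1, jnp.int32),
                            jnp.cumsum(actual_seqlens, dtype=jnp.int32)])

def cu_to_actual(cu_seqlens):
    # extra kernel on the way out: adjacent differences
    return jnp.diff(cu_seqlens)
```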

zlsh80826 avatar May 03 '24 03:05 zlsh80826

I have pushed a commit to fix the JAX build issue caused by the new API.

zlsh80826 avatar May 03 '24 10:05 zlsh80826

This PR changes the fused attention API. I'm wondering if we can also change the cu_seqlens arguments to actual_seqlens, so that we can avoid the kernel that converts actual_seqlens to cu_seqlens.

Actually, using cu_seqlens as API arguments makes JAX launch two additional kernels: mask -> actual_seqlens -> sharding (DP) -> kernel to convert actual_seqlens to cu_seqlens -> call nvte_fmha -> kernel to convert cu_seqlens back to actual_seqlens.

@zlsh80826 If we did mask -> cu_seqlens instead of mask -> actual_seqlens, would there still be two extra kernels? @ptrendx suggested that we make the cu_seqlens/actual_seqlens-related changes in a different PR in order to keep this PR focused on THD. Is there any burning need to make the cu_seqlens/actual_seqlens changes now?

cyanguwa avatar May 15 '24 21:05 cyanguwa

/te-ci pytorch

cyanguwa avatar May 16 '24 22:05 cyanguwa

@zlsh80826 If we did mask -> cu_seqlens instead of mask -> actual_seqlens, would there still be two extra kernels? @ptrendx suggested that we make the cu_seqlens/actual_seqlens-related changes in a different PR in order to keep this PR focused on THD. Is there any burning need to make the cu_seqlens/actual_seqlens changes now?

If the flow changes to mask -> actual_seqlens directly, it will not have the two extra kernels. I profiled the performance impact, and the extra kernels account for less than 0.1% of e2e training time, so there is no hurry for now.
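
Either way, the idea is to derive the sequence-length representation the kernel needs straight from the mask, so no separate conversion kernel has to run. A hypothetical JAX sketch of the mask -> cu_seqlens direction (illustrative only, not TE code):

```python
import jax.numpy as jnp

def mask_to_cu_seqlens(mask):
    # mask: [batch, seqlen] boolean, True marks valid tokens
    actual_seqlens = jnp.sum(mask, axis=-1, dtype=jnp.int32)
    return jnp.concatenate([jnp.zeros(1, jnp.int32),
                            jnp.cumsum(actual_seqlens, dtype=jnp.int32)])
```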

zlsh80826 avatar May 17 '24 01:05 zlsh80826

/te-ci pytorch

cyanguwa avatar May 17 '24 20:05 cyanguwa

/te-ci pytorch

cyanguwa avatar May 17 '24 20:05 cyanguwa

/te-ci pytorch

cyanguwa avatar May 20 '24 20:05 cyanguwa

Hi @cyanguwa, the changes for common look good to me. I will fix the JAX build this week.

zlsh80826 avatar May 23 '24 11:05 zlsh80826

/te-ci pytorch

cyanguwa avatar May 23 '24 22:05 cyanguwa

/te-ci jax

cyanguwa avatar May 24 '24 22:05 cyanguwa

/te-ci paddle

cyanguwa avatar May 24 '24 22:05 cyanguwa

/te-ci jax

cyanguwa avatar May 28 '24 20:05 cyanguwa

@cyanguwa I have a few questions about exposing the THD format in JAX. Could you take a look at the cuDNN Slack channel before merging? Thanks.

zlsh80826 avatar May 29 '24 10:05 zlsh80826