TransformerEngine
[C/PyTorch] Add THD support for cuDNN attention
Description
This PR adds THD support for fused attention (the `F16_arbitrary_seqlen` backend). This feature allows users to run attention for two more cases:
Case 1: no padding between sequences
- 8 sequences are packed as `abcddddeeeffgh`
- total sequence length: `t = 14`
- cumulative sequence lengths: `cu_seqlens = [0, 1, 2, 3, 7, 10, 12, 13, 14]`
- offset tensor that helps cuDNN step into the QKV tensors: `offset_tensor = [0, 1, 2, 3, 7, 10, 12, 13]`

Case 2: with padding between sequences
- 8 sequences are packed as `abc0ddddeee0ffgh`
- total sequence length: `t = 16`
- cumulative sequence lengths: `cu_seqlens = [0, 1, 2, 3, 7, 10, 12, 13, 14]`
- offset tensor that helps cuDNN step into the QKV tensors: `offset_tensor = [0, 1, 2, 4, 8, 12, 14, 15]`
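For illustration, both tensors can be derived on the host from the per-sequence lengths (and, for case 2, the per-sequence padding). This is a minimal sketch only; the helper name and the `pad_lens` argument are hypothetical and not part of the API:

```python
import torch

def make_cu_seqlens_and_offsets(seq_lens, pad_lens=None):
    """Hypothetical helper: build cumulative sequence lengths and the
    per-sequence start offsets into a packed THD tensor."""
    seq_lens = torch.as_tensor(seq_lens, dtype=torch.int32)
    if pad_lens is None:
        pad_lens = torch.zeros_like(seq_lens)
    else:
        pad_lens = torch.as_tensor(pad_lens, dtype=torch.int32)
    # cu_seqlens: prefix sum of the actual (unpadded) lengths, with a leading 0.
    cu_seqlens = torch.cat([torch.zeros(1, dtype=torch.int32),
                            torch.cumsum(seq_lens, 0).to(torch.int32)])
    # offsets: where each sequence starts in the packed (possibly padded) layout.
    padded = seq_lens + pad_lens
    offsets = torch.cat([torch.zeros(1, dtype=torch.int32),
                         torch.cumsum(padded, 0).to(torch.int32)])[:-1]
    return cu_seqlens, offsets

# case 1: abcddddeeeffgh, no padding
print(make_cu_seqlens_and_offsets([1, 1, 1, 4, 3, 2, 1, 1]))
# cu_seqlens = [0, 1, 2, 3, 7, 10, 12, 13, 14], offsets = [0, 1, 2, 3, 7, 10, 12, 13]

# case 2: abc0ddddeee0ffgh, one padding token after c and after e
print(make_cu_seqlens_and_offsets([1, 1, 1, 4, 3, 2, 1, 1],
                                  [0, 0, 1, 0, 1, 0, 0, 0]))
# cu_seqlens = [0, 1, 2, 3, 7, 10, 12, 13, 14], offsets = [0, 1, 2, 4, 8, 12, 14, 15]
```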
Type of change
- [ ] Documentation change (change only to the documentation, either a fix or a new content)
- [ ] Bug fix (non-breaking change which fixes an issue)
- [x] New feature (non-breaking change which adds functionality)
- [ ] Breaking change (fix or feature that would cause existing functionality to not work as expected)
Changes
- Added THD support to cuDNN attention at the C level and in the PyTorch modules (see the usage sketch after this list); JAX and Paddle support will follow.
- Updated cudnn-frontend to v1.4.0.
- Added tests for `max_seqlen_q=1` and `head_dim=256` inference cases for `thd`, `bshd`, and `sbhd` formats.
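As a rough illustration of the PyTorch-level usage, here is a minimal sketch, not taken from this PR; the exact `DotProductAttention` constructor and forward keywords (`qkv_format="thd"`, `cu_seqlens_q`, `cu_seqlens_kv`, `attn_mask_type`) are assumptions and should be checked against the module documentation:

```python
import torch
import transformer_engine.pytorch as te

# Packed case 1 from the description: t = 14 tokens, 8 sequences, no padding.
t, h, d = 14, 16, 64
q = torch.randn(t, h, d, dtype=torch.bfloat16, device="cuda")
k = torch.randn(t, h, d, dtype=torch.bfloat16, device="cuda")
v = torch.randn(t, h, d, dtype=torch.bfloat16, device="cuda")
cu_seqlens = torch.tensor([0, 1, 2, 3, 7, 10, 12, 13, 14],
                          dtype=torch.int32, device="cuda")

# Assumed arguments; THD requires a padding-style mask type.
attn = te.DotProductAttention(num_attention_heads=h, kv_channels=d,
                              attn_mask_type="padding_causal")
out = attn(q, k, v,
           qkv_format="thd",          # assumed forward keyword
           cu_seqlens_q=cu_seqlens,
           cu_seqlens_kv=cu_seqlens)
```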
Checklist:
- [x] I have read and followed the contributing guidelines
- [x] The functionality is complete
- [x] I have commented my code, particularly in hard-to-understand areas
- [x] I have made corresponding changes to the documentation
- [x] My changes generate no new warnings
- [x] I have added tests that prove my fix is effective or that my feature works
- [x] New and existing unit tests pass locally with my changes
This PR changes the fused attention API. I'm wondering if we can also change the `cu_seqlens` arguments to `actual_seqlens`, so that we can avoid the kernel that converts `actual_seqlens` to `cu_seqlens`. Actually, taking `cu_seqlens` as API arguments makes JAX run two additional kernels: `mask -> actual_seqlens -> sharding (DP) -> kernel to convert actual_seqlens to cu_seqlens -> call nvte_fmha -> kernel to convert cu_seqlens to actual_seqlens`.
I had a commit to fix the JAX build issue with the new API.
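For reference, the two conversion kernels discussed above amount to a prefix sum and its inverse; a minimal illustrative sketch (not the JAX kernels themselves):

```python
import torch

actual_seqlens = torch.tensor([1, 1, 1, 4, 3, 2, 1, 1], dtype=torch.int32)

# actual_seqlens -> cu_seqlens: prefix sum with a leading zero
cu_seqlens = torch.cat([torch.zeros(1, dtype=torch.int32),
                        torch.cumsum(actual_seqlens, 0).to(torch.int32)])
# cu_seqlens = [0, 1, 2, 3, 7, 10, 12, 13, 14]

# cu_seqlens -> actual_seqlens: adjacent differences
recovered = cu_seqlens[1:] - cu_seqlens[:-1]
assert torch.equal(recovered, actual_seqlens)
```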
@zlsh80826 If we do `mask -> cu_seqlens` instead of `mask -> actual_seqlens`, would there still be two extra kernels? @ptrendx suggested that we make the `cu_seqlens`/`actual_seqlens` related changes in a separate PR in order to keep this PR focused on THD. Is there any burning need to make the `cu_seqlens`/`actual_seqlens` changes now?
/te-ci pytorch
If the flow changes to `mask -> actual_seqlens`, there will not be two extra kernels. I profiled the performance impact, and the extra kernels account for less than 0.1% of e2e training time, so it is not urgent for now.
/te-ci pytorch
/te-ci pytorch
/te-ci pytorch
Hi @cyanguwa, the changes for common look good to me. I will fix the JAX build this week.
/te-ci pytorch
/te-ci jax
/te-ci paddle
/te-ci jax
@cyanguwa I have a few questions about exposing the THD format in JAX. Could you take a look at the cuDNN Slack channel before merging? Thanks.