[Pytorch] CP + THD + chunked attention support.
Description
This PR introduces support for CP + THD + chunked attention.
Fixes # (issue)
Type of change
- [ ] Documentation change (change only to the documentation, either a fix or new content)
- [ ] Bug fix (non-breaking change which fixes an issue)
- [x] New feature (non-breaking change which adds functionality)
- [ ] Breaking change (fix or feature that would cause existing functionality to not work as expected)
- [ ] Infra/Build change
- [ ] Code refactoring
Checklist:
- [x] I have read and followed the contributing guidelines
- [x] The functionality is complete
- [x] I have commented my code, particularly in hard-to-understand areas
- [x] I have made corresponding changes to the documentation
- [x] My changes generate no new warnings
- [x] I have added tests that prove my fix is effective or that my feature works
- [x] New and existing unit tests pass locally with my changes
/te-ci pytorch L1
/te-ci pytorch L1
Hi @pggPL, could you please say what is left to get this merged? To make it into TE 2.10, it needs to be merged by 11/15.
Greptile Overview
Greptile Summary
This PR implements chunked attention support for Llama 4 in TransformerEngine-PyTorch. The implementation adds a chunk_size parameter that partitions sequences into fixed-size chunks to reduce memory usage during attention computation.
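For the non-CP bshd/sbhd path, the partitioning described above can be done purely by reshaping: folding `chunk_size`-token chunks into the batch dimension means a standard batched attention call can never look across a chunk boundary. A minimal shape sketch with toy sizes (illustration only, not TE code):

```python
import numpy as np

b, s, h, d, chunk = 2, 8, 3, 4, 4   # chunk_size must evenly divide b*s tokens
x = np.arange(b * s * h * d, dtype=np.float32).reshape(b, s, h, d)

# Fold chunks into the batch dimension: (b, s, h, d) -> (b*s/chunk, chunk, h, d).
# Batched attention treats each leading index independently, so tokens can
# only attend within their own chunk.
chunked = x.reshape(b * s // chunk, chunk, h, d)
assert chunked.shape == (4, 4, 3, 4)

# The inverse reshape restores the original layout exactly
assert np.array_equal(chunked.reshape(b, s, h, d), x)
```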
Key changes:
- Added CUDA kernel implementations for chunking sequences in THD format for context parallelism (`thd_chunkify_kernel`, `thd_chunkify_p2p_kernel`, and the diagonal/above/below tweak kernels)
- Modified attention forward/backward passes to support chunked computation for non-CP (bshd/sbhd/thd) and CP cases (thd with p2p only)
- Extended C++ and Python APIs to expose chunking utilities
- Added comprehensive test coverage for both CP and non-CP scenarios
Supported configurations:
- Non-CP: All three qkv_formats (bshd, sbhd, thd), FlashAttention/FusedAttention/UnfusedDotProductAttention backends
- CP: thd format only with p2p communication type, FlashAttention/FusedAttention backends
- Mask types: `no_mask`, `causal`, `causal_bottom_right` (though the code assertion differs from the PR description for thd + chunked)
- Bias types: `no_bias` only with chunked attention
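The THD-side chunking amounts to rewriting `cu_seqlens` so every variable-length sequence is split at chunk boundaries. A pure-Python mirror of what `thd_chunkify` is described as producing (hypothetical helper; the real implementation is a CUDA kernel operating on device tensors):

```python
def thd_chunkify(cu_seqlens, chunk_size):
    """Split each variable-length sequence into chunk_size pieces by
    inserting extra boundaries into cu_seqlens (pure-Python sketch of
    the behavior attributed to the CUDA kernel, not the TE code)."""
    out = [0]
    for start, end in zip(cu_seqlens, cu_seqlens[1:]):
        pos = start
        while pos < end:
            pos = min(pos + chunk_size, end)  # last chunk may be short
            out.append(pos)
    return out

# Two sequences of lengths 5 and 7, chunked at 4 tokens each
print(thd_chunkify([0, 5, 12], 4))  # → [0, 4, 5, 9, 12]
```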
Issues identified:
- A previous review thread already identified the `causal_bottom_right` mask inconsistency between the PR description and the implementation assertion
- The implementation is complex, with significant CUDA kernel code added (526 lines), requiring thorough testing
Confidence Score: 3/5
- This PR adds significant complexity with extensive CUDA kernel implementations and changes to critical attention paths; thorough testing and verification are needed before merge
- Score reflects: (1) large implementation with 1318 insertions including complex CUDA kernels that are difficult to verify without runtime testing, (2) previous review thread identified issues that should be addressed, (3) good test coverage added but the feature touches critical performance paths, (4) implementation appears sound architecturally but has intricate chunking logic for CP P2P cases
- Pay close attention to `transformer_engine/common/fused_attn/context_parallel.cu` (complex CUDA kernels) and `transformer_engine/pytorch/attention/dot_product_attention/context_parallel.py` (intricate CP logic with chunking)
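The CP P2P logic splits each ring step into diagonal, lower-triangle, and upper-triangle sections depending on which rank's KV arrives at that step. A simplified pure-Python model of that branching (hypothetical helper, assuming a plain ring schedule; TE's real schedule additionally uses dual-chunk load balancing):

```python
def p2p_section(rank, step, cp_size):
    """Classify which attention section a rank computes at a given ring
    step (simplified model, not the TE implementation)."""
    assert 0 <= step < cp_size
    if step == 0:
        return "diagonal"          # q and kv come from the same rank
    src = (rank - step) % cp_size  # rank whose KV arrives at this step
    # KV from an earlier rank lies below the diagonal; later rank, above
    return "lower" if src < rank else "upper"

# cp_size = 4, rank 1: steps receive KV from ranks 1, 0, 3, 2
assert [p2p_section(1, s, 4) for s in range(4)] == [
    "diagonal", "lower", "upper", "upper"]
```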
Important Files Changed
File Analysis
| Filename | Score | Overview |
|---|---|---|
| transformer_engine/common/fused_attn/context_parallel.cu | 4/5 | Added 526 lines of CUDA kernel implementations for chunked attention support in CP: thd_chunkify_kernel, thd_chunkify_p2p_kernel, thd_seq_tweak_below_diag_kernel, and thd_seq_tweak_above_diag_kernel with corresponding host wrapper functions |
| transformer_engine/pytorch/csrc/extensions/attention.cpp | 4/5 | Added PyTorch C++ extension wrappers for chunked attention functions and modified tensor allocation to support negative infinity initialization for softmax buffers |
| transformer_engine/pytorch/attention/dot_product_attention/utils.py | 4/5 | Added 146 lines including chunk_size to AttentionParams, chunked attention detection logic in get_attention_backend, and utility functions thd_chunkify, thd_chunkify_p2p, thd_chunkify_p2p_below_diagonal, and thd_chunkify_p2p_above_diagonal |
| transformer_engine/pytorch/attention/dot_product_attention/dot_product_attention.py | 3/5 | Added chunked attention parameter and implementation logic with tensor reshaping for bshd/sbhd formats and sequence chunkification for thd format, including validation assertions |
| transformer_engine/pytorch/attention/dot_product_attention/context_parallel.py | 3/5 | Modified CP P2P forward and backward passes to support chunked attention by introducing per-step padded cu_seqlens and calling chunked attention utility functions for diagonal, lower-triangle, and upper-triangle sections |
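The negative-infinity initialization noted for `attention.cpp` makes sense if the buffer holds running row maxima for online softmax: the identity state for merging partial statistics is `(max=-inf, sum=0)`. A sketch of that generic online-softmax merge rule (standard algebra, not the TE code):

```python
import math

def merge(m1, s1, m2, s2):
    """Combine two partial softmax statistics (row max m, exp-sum s).
    The initial state (m=-inf, s=0) is the identity for this merge."""
    m = max(m1, m2)
    if m == -math.inf:
        return m, 0.0
    # Rescale each partial sum to the new common maximum
    return m, s1 * math.exp(m1 - m) + s2 * math.exp(m2 - m)

# Identity check: merging into the -inf initial state is a no-op
m, s = merge(-math.inf, 0.0, 2.0, 3.5)
assert (m, s) == (2.0, 3.5)
```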
Sequence Diagram
```mermaid
sequenceDiagram
    participant User as User Code
    participant TL as TransformerLayer
    participant MHA as MultiheadAttention
    participant DPA as DotProductAttention
    participant Utils as dpa_utils
    participant CP as context_parallel
    participant Tex as C++ Extensions
    participant CUDA as CUDA Kernels
    User->>TL: forward(chunk_size=128)
    TL->>MHA: forward(chunk_size=128)
    alt Non-CP Case (bshd/sbhd)
        MHA->>MHA: reshape input to chunks
        Note over MHA: (b*s/128, 128, h, d)
    else CP Case (thd only)
        MHA->>Utils: thd_chunkify(cu_seqlens, chunk_size)
        Utils->>Tex: thd_chunkify()
        Tex->>CUDA: thd_chunkify_kernel
        CUDA-->>Utils: chunked cu_seqlens
    end
    MHA->>DPA: forward(chunk_size=128)
    alt CP with P2P
        DPA->>CP: cp_p2p_fwd_fused_attn()
        loop For each P2P step
            alt Diagonal section
                CP->>Utils: thd_chunkify_p2p()
                Utils->>CUDA: thd_chunkify_p2p_kernel
            else Lower-triangle section
                CP->>Utils: thd_chunkify_p2p_below_diagonal()
                Utils->>CUDA: thd_seq_tweak_below_diag_kernel
            else Upper-triangle section
                CP->>Utils: thd_chunkify_p2p_above_diagonal()
                Utils->>CUDA: thd_seq_tweak_above_diag_kernel
            end
            CP->>CP: fused_attn_fwd with chunked seqlens
        end
        CP-->>DPA: attention output
    else Non-CP Case
        DPA->>DPA: Regular attention with reshaped input
    end
    DPA-->>MHA: output
    alt Non-CP Case (bshd/sbhd)
        MHA->>MHA: reshape output back
        Note over MHA: (b, s, h, d)
    end
    MHA-->>TL: attention output
    TL-->>User: layer output
```
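As a sanity check on the semantics the diagram implies, chunked causal attention should equal full attention under a block-diagonal causal mask: folding chunks into the batch dimension and running causal attention per chunk restricts each query to causal, chunk-local keys. A self-contained NumPy sketch with toy single-head shapes (illustration only, not TE code):

```python
import numpy as np

def softmax(x, axis=-1):
    x = x - x.max(axis=axis, keepdims=True)
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)

def attn(q, k, v, mask):
    # q, k, v: (s, d); mask: (s_q, s_k) boolean, True = may attend
    scores = q @ k.T / np.sqrt(q.shape[-1])
    scores = np.where(mask, scores, -np.inf)
    return softmax(scores) @ v

rng = np.random.default_rng(0)
s, d, c = 8, 4, 4                                 # seq len, head dim, chunk size
q, k, v = (rng.standard_normal((s, d)) for _ in range(3))

# Reference: full attention, causal mask restricted to chunk-local keys
i, j = np.indices((s, s))
mask = (j <= i) & (i // c == j // c)
ref = attn(q, k, v, mask)

# Chunked: run plain causal attention independently per chunk
cm = np.tril(np.ones((c, c), dtype=bool))
out = np.concatenate([attn(q[t*c:(t+1)*c], k[t*c:(t+1)*c], v[t*c:(t+1)*c], cm)
                      for t in range(s // c)])
assert np.allclose(ref, out)
```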
@yaoyu-33 and @xrennvidia, could you please help take another look at this PR? There's been some delays, but we're making a push to get it merged in 2.10 (i.e. end of this week)! Thanks for your help!
/te-ci pytorch L1
/te-ci pytorch L1