TransformerEngine
Enable SWA with CP for THD input format
Description
Enables Sliding Window Attention (SWA) with context parallelism (CP) for the THD input format, using A2A communication.
Fixes # (issue)
Type of change
- [ ] Documentation change (change only to the documentation, either a fix or a new content)
- [ ] Bug fix (non-breaking change which fixes an issue)
- [x] New feature (non-breaking change which adds functionality)
- [ ] Breaking change (fix or feature that would cause existing functionality to not work as expected)
- [ ] Infra/Build change
- [ ] Code refactoring
Changes
Please list the changes introduced in this PR:
- SWA + THD + CP (using A2A); a hedged usage sketch follows this list
- Filters that allow such a config
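A hedged usage sketch of the new configuration, assuming the public TransformerEngine PyTorch API (`DotProductAttention`, `set_context_parallel_group`, `cp_comm_type="a2a"`). Argument names and the cu_seqlens convention under CP are assumptions based on recent TE versions and the CP tests, not something this PR confirms; launch with torchrun on two or more GPUs.

```python
# Hedged sketch: enabling SWA + CP (A2A) with THD inputs through TE's public API.
# Names and conventions below are assumed from recent TransformerEngine versions
# and tests/pytorch/attention/run_attention_with_cp.py; they may differ slightly.
import torch
import torch.distributed as dist
import transformer_engine.pytorch as te

dist.init_process_group(backend="nccl")
rank, world = dist.get_rank(), dist.get_world_size()
torch.cuda.set_device(rank % torch.cuda.device_count())

cp_group = dist.new_group(ranks=list(range(world)))  # use all ranks for CP here
attn = te.DotProductAttention(
    num_attention_heads=16,
    kv_channels=64,
    qkv_format="thd",                 # packed, variable-length sequences
    attn_mask_type="padding_causal",  # THD needs a padding-aware mask variant
    window_size=(128, 0),             # sliding window: 128 tokens to the left
)
attn.set_context_parallel_group(
    cp_group, list(range(world)), torch.cuda.Stream(), cp_comm_type="a2a"
)

total_tokens = 2048                   # full, padded length of all sequences
t_local = total_tokens // world       # this rank's dual-chunked shard
q = torch.randn(t_local, 16, 64, dtype=torch.bfloat16, device="cuda")
k, v = torch.randn_like(q), torch.randn_like(q)
# Assumed convention (as in the CP tests): cu_seqlens describe the full, padded
# sequences, while q/k/v hold only this rank's shard of each sequence. All
# sequence lengths are multiples of 2 * cp_size.
cu_seqlens = torch.tensor([0, 1024, 2048], dtype=torch.int32, device="cuda")

out = attn(
    q, k, v,
    cu_seqlens_q=cu_seqlens, cu_seqlens_kv=cu_seqlens,
    cu_seqlens_q_padded=cu_seqlens, cu_seqlens_kv_padded=cu_seqlens,
)
```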
/te-ci pytorch L0
/te-ci pytorch L0
I see you've been running L0 CI tests. Did you happen to test with all the CP tests in L1? Thanks.
/te-ci pytorch L1
/te-ci pytorch
/te-ci pytorch L1
Greptile Overview
Greptile Summary
This PR enables Sliding Window Attention (SWA) with Context Parallelism (CP) for the THD (total_tokens, num_heads, head_dim) input format using All-to-All (A2A) communication.
Key Changes:
- Implements two new reordering functions (`reorder_seq_chunks_before_a2a_after_attn_thd` and `reorder_seq_chunks_after_a2a_before_attn_thd`) that handle THD-specific sequence chunking using the DualChunking pattern (a rough illustration of the chunking follows this list)
- Updates `flash_attn_a2a_communicate` to branch on `qkv_format` and handle THD tensors differently from BSHD/SBHD formats
- Adds padding logic to ensure all sequences are divisible by `2*cp_size` (required for DualChunking across CP ranks)
- Enables THD format by adding the `cu_seqlens_padded` parameter to the A2A communication functions
- Updates test configurations to include new SWA configs (`cp_1_2` and `cp_1_4`) and removes skip conditions for THD+A2A combinations
- Dynamically converts the mask type to `padding` or `padding_causal` for THD format in tests
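As a rough illustration of the DualChunking pattern mentioned above (this is not the PR's code; the chunk-pairing rule is assumed from TransformerEngine's existing load-balanced CP sharding): each padded sequence is split into `2*cp_size` chunks and CP rank `r` keeps chunks `r` and `2*cp_size - 1 - r`, so causal-attention work stays balanced across ranks.

```python
# Minimal single-process sketch of the assumed DualChunking assignment.
import torch

def dual_chunk_shard(seq: torch.Tensor, cp_size: int, rank: int) -> torch.Tensor:
    """Return one rank's shard of a padded sequence, shape [seq_len // cp_size, ...]."""
    assert seq.shape[0] % (2 * cp_size) == 0, "pad sequence to a multiple of 2*cp_size"
    chunks = seq.chunk(2 * cp_size, dim=0)
    # Pair the i-th chunk with its mirror chunk from the end of the sequence.
    return torch.cat([chunks[rank], chunks[2 * cp_size - 1 - rank]], dim=0)

seq = torch.arange(16)   # one padded sequence of 16 tokens
cp_size = 4              # 2 * cp_size = 8 chunks of 2 tokens each
for r in range(cp_size):
    print(r, dual_chunk_shard(seq, cp_size, r).tolist())
# rank 0 -> [0, 1, 14, 15], rank 1 -> [2, 3, 12, 13], ...
```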
Implementation Approach:
The implementation extends existing A2A context parallelism to support variable-length sequences (THD format). Sequences are padded to be divisible by 2*cp_size, then chunks are exchanged via A2A communication. The reordering functions handle the complex index calculations needed to reconstruct sequences correctly after communication.
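A minimal sketch of the padding step described above, assuming a hypothetical helper `pad_cu_seqlens` (not a function from the PR): round each sequence length up to a multiple of `2*cp_size` and build cumulative offsets for both the raw and the padded lengths.

```python
# Hedged sketch of the padding described above; helper name is hypothetical.
import torch

def pad_cu_seqlens(seqlens: torch.Tensor, cp_size: int):
    multiple = 2 * cp_size
    padded = (seqlens + multiple - 1) // multiple * multiple   # round up per sequence
    cu_seqlens = torch.cat([seqlens.new_zeros(1), seqlens.cumsum(0)])
    cu_seqlens_padded = torch.cat([padded.new_zeros(1), padded.cumsum(0)])
    return cu_seqlens.int(), cu_seqlens_padded.int()

seqlens = torch.tensor([100, 37, 255])            # raw per-sequence token counts
cu, cu_pad = pad_cu_seqlens(seqlens, cp_size=4)   # 2 * cp_size = 8
print(cu.tolist())      # [0, 100, 137, 392]
print(cu_pad.tolist())  # [0, 104, 144, 400]  (104, 40, 256 after padding)
```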
Confidence Score: 4/5
- This PR is largely safe to merge with proper testing, though the complex index calculations warrant careful validation.
- The implementation is well-structured and follows the existing patterns for the BSHD/SBHD formats, and the padding logic ensures the divisibility constraints are met. The score is 4/5 rather than 5/5 because: (1) the reordering functions involve complex nested list comprehensions with intricate index calculations that are difficult to verify by inspection alone, (2) previous comments identified potential integer-division issues that could cause data loss if sequences aren't properly padded (a tiny numeric illustration follows this list), and (3) the tests dynamically modify shared config objects, which could create test interdependencies. The implementation appears correct based on the padding logic in `run_attention_with_cp.py:93`, but comprehensive integration testing is essential to validate correctness across various sequence lengths and CP sizes.
- Pay close attention to `context_parallel.py`: the new reordering functions have complex index arithmetic that should be validated with integration tests across different sequence lengths and CP sizes.
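The numeric illustration referenced in point (2), with made-up numbers:

```python
# Without padding, splitting a sequence into 2 * cp_size equal chunks via
# integer division silently drops the remainder tokens.
cp_size, seq_len = 3, 1000
chunk = seq_len // (2 * cp_size)                 # 166 tokens per chunk
print(chunk * 2 * cp_size)                       # 996 -> 4 tokens would be lost
padded = -(-seq_len // (2 * cp_size)) * (2 * cp_size)
print(padded)                                    # 1002 -> pad by 2, nothing dropped
```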
Important Files Changed
File Analysis
| Filename | Score | Overview |
|---|---|---|
| transformer_engine/pytorch/attention/dot_product_attention/context_parallel.py | 4/5 | Adds two new THD-specific reordering functions for A2A communication and updates existing A2A functions to support THD format with padding |
| tests/pytorch/attention/test_attention_with_cp.py | 4/5 | Enables THD+CP+A2A testing by adding new test configs, updating skip conditions, and dynamically switching the mask type to a padding variant for THD format (see the sketch after this table) |
| tests/pytorch/attention/run_attention_with_cp.py | 5/5 | Refactors THD input shape generation to properly pad sequences to be divisible by 2*cp_size and fixes cu_seqlens calculation |
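A hedged sketch of the mask-type switch noted for test_attention_with_cp.py: THD (packed, variable-length) inputs need a padding-aware mask, so non-padding mask types are mapped to their padding variants. The mapping and the helper `to_thd_mask_type` are assumptions for illustration, not the PR's code.

```python
def to_thd_mask_type(attn_mask_type: str) -> str:
    if "padding" in attn_mask_type:
        return attn_mask_type               # already padding-aware
    if attn_mask_type == "no_mask":
        return "padding"
    return "padding_" + attn_mask_type      # e.g. "causal" -> "padding_causal"

assert to_thd_mask_type("causal") == "padding_causal"
assert to_thd_mask_type("no_mask") == "padding"
```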
Sequence Diagram
sequenceDiagram
participant Input as Input Tensors (Q/K/V)
participant Reshape as Reshape & Split
participant A2A_Before as A2A Communication (Before Attn)
participant Reorder_Before as Reorder (Before Attn)
participant Attn as Attention Compute
participant Reorder_After as Reorder (After Attn)
participant A2A_After as A2A Communication (After Attn)
participant Output as Output Tensor
Note over Input,Output: THD Format: [total_tokens, num_heads, head_dim]
Input->>Reshape: Split heads across CP ranks
Note over Reshape: [t, np, hn] -> [t, cp, np//cp, hn] -> [cp, t, np//cp, hn]
Reshape->>A2A_Before: All-to-All exchange
Note over A2A_Before: Each rank gets chunks from all other ranks
A2A_Before->>Reorder_Before: Flatten [cp, t, np//cp, hn] -> [cp*t, np//cp, hn]
Note over Reorder_Before: reorder_seq_chunks_after_a2a_before_attn_thd()
Note over Reorder_Before: Reconstructs sequences using DualChunking pattern
Reorder_Before->>Attn: Compute attention with cu_seqlens_padded
Note over Attn: Sequences padded to be divisible by 2*cp_size
Attn->>Reorder_After: Output [cp*t, np//cp, hn]
Note over Reorder_After: reorder_seq_chunks_before_a2a_after_attn_thd()
Note over Reorder_After: Prepare chunks for reverse A2A
Reorder_After->>A2A_After: Reshape to [cp, t, np//cp, hn]
A2A_After->>Output: All-to-All exchange back
Note over Output: [t, cp, np//cp, hn] -> [t, np, hn]
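A single-process sketch of the reshapes shown in the diagram around the A2A exchange. The real code exchanges the leading `cp` dimension across ranks with torch.distributed; only the shape bookkeeping is reproduced here to show that the before/after transforms invert each other.

```python
import torch

cp, t, np_, hn = 2, 8, 4, 16
x = torch.randn(t, np_, hn)                      # THD tensor: [t, np, hn]

# Before attention: split heads into cp groups and move cp to the front, so
# each rank would send/receive one [t, np//cp, hn] slice.
before = x.view(t, cp, np_ // cp, hn).movedim(1, 0).contiguous()  # [cp, t, np//cp, hn]

# After attention: undo the split to recover the original [t, np, hn] layout.
after = before.movedim(0, 1).reshape(t, np_, hn)

assert torch.equal(after, x)
```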
/te-ci pytorch L1