TransformerEngine

Enable SWA with CP for THD input format

Open · sudhakarsingh27 opened this pull request 2 months ago · 11 comments

Description

Sliding Window Attention with CP for THD format is enabled with A2A communication.

Fixes # (issue)

Type of change

  • [ ] Documentation change (change only to the documentation, either a fix or new content)
  • [ ] Bug fix (non-breaking change which fixes an issue)
  • [x] New feature (non-breaking change which adds functionality)
  • [ ] Breaking change (fix or feature that would cause existing functionality to not work as expected)
  • [ ] Infra/Build change
  • [ ] Code refactoring

Changes

Please list the changes introduced in this PR:

  • SWA+THD+CP (using A2A)
  • Filters that allow this configuration

sudhakarsingh27 · Sep 30 '25 22:09

/te-ci pytorch L0

sudhakarsingh27 · Oct 09 '25 12:10

/te-ci pytorch L0

sudhakarsingh27 · Oct 09 '25 19:10

I see you've been running L0 CI tests. Did you happen to test with all the CP tests in L1? Thanks.

cyanguwa · Oct 24 '25 12:10

/te-ci pytorch L1

sudhakarsingh27 · Oct 27 '25 11:10

/te-ci pytorch

sudhakarsingh27 · Oct 27 '25 14:10

/te-ci pytorch L1

sudhakarsingh27 · Nov 12 '25 18:11

Greptile Overview

Greptile Summary

This PR enables Sliding Window Attention (SWA) with Context Parallelism (CP) for the THD (total_tokens, num_heads, head_dim) input format using All-to-All (A2A) communication.

Key Changes:

  • Implements two new reordering functions (reorder_seq_chunks_before_a2a_after_attn_thd and reorder_seq_chunks_after_a2a_before_attn_thd) that handle THD-specific sequence chunking using the DualChunking pattern (see the sketch after this list)
  • Updates flash_attn_a2a_communicate to branch on qkv_format and handle THD tensors differently from BSHD/SBHD formats
  • Adds padding logic to ensure all sequences are divisible by 2*cp_size (required for DualChunking across CP ranks)
  • Enables THD format by adding cu_seqlens_padded parameter to A2A communication functions
  • Updates test configurations to include new SWA configs (cp_1_2 and cp_1_4) and removes skip conditions for THD+A2A combinations
  • Dynamically converts mask type to padding or padding_causal for THD format in tests
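
For reference, here is a minimal sketch of the dual-chunk assignment assumed in the reordering functions above. The helper name is hypothetical and this is not the TE implementation; it only illustrates the load-balanced pattern where CP rank r holds chunks r and 2*cp_size - 1 - r of each padded sequence:

```python
# Hypothetical helper illustrating the assumed dual-chunk assignment: the padded
# sequence is split into 2*cp_size chunks and rank r keeps chunks r and
# (2*cp_size - 1 - r), which balances causal-attention work across CP ranks.
import torch

def dual_chunk_token_indices(seq_len: int, cp_size: int, rank: int) -> torch.Tensor:
    assert seq_len % (2 * cp_size) == 0, "sequence must be padded to a multiple of 2*cp_size"
    chunk = seq_len // (2 * cp_size)
    first = torch.arange(rank * chunk, (rank + 1) * chunk)
    second = torch.arange((2 * cp_size - 1 - rank) * chunk, (2 * cp_size - rank) * chunk)
    return torch.cat([first, second])

# seq_len=16, cp_size=2: rank 0 holds tokens [0..3] and [12..15]
print(dual_chunk_token_indices(16, cp_size=2, rank=0))
```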

Implementation Approach: The implementation extends existing A2A context parallelism to support variable-length sequences (THD format). Sequences are padded to be divisible by 2*cp_size, then chunks are exchanged via A2A communication. The reordering functions handle the complex index calculations needed to reconstruct sequences correctly after communication.
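
A minimal sketch of the padding step described above, assuming only per-sequence token counts as input; pad_seqlens_for_cp is an illustrative name, not TE's API:

```python
# Illustrative sketch (names are not TE's API): round each sequence length up to a
# multiple of 2*cp_size and build both the real and padded cumulative sequence lengths.
import torch

def pad_seqlens_for_cp(seqlens: torch.Tensor, cp_size: int):
    multiple = 2 * cp_size
    padded = (seqlens + multiple - 1) // multiple * multiple  # round up per sequence
    zero = torch.zeros(1, dtype=torch.int32)
    cu_seqlens = torch.cat([zero, seqlens.cumsum(0).to(torch.int32)])
    cu_seqlens_padded = torch.cat([zero, padded.cumsum(0).to(torch.int32)])
    return cu_seqlens, cu_seqlens_padded

# Two sequences of length 5 and 9 with cp_size=2 pad to 8 and 12 tokens:
# cu_seqlens = [0, 5, 14], cu_seqlens_padded = [0, 8, 20]
print(pad_seqlens_for_cp(torch.tensor([5, 9]), cp_size=2))
```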

Confidence Score: 4/5

  • This PR is largely safe to merge with proper testing, though the complex index calculations warrant careful validation
  • The implementation is well-structured and follows the existing patterns for the BSHD/SBHD formats, and the padding logic ensures the divisibility constraints are met. However, the score is 4/5 rather than 5/5 because: (1) the reordering functions involve complex nested list comprehensions with intricate index calculations that are difficult to verify by inspection alone; (2) previous comments identified potential integer-division issues that could cause data loss if sequences aren't properly padded; and (3) the tests dynamically modify shared config objects, which could create test interdependencies. The implementation appears correct based on the padding logic in run_attention_with_cp.py:93, but comprehensive integration testing is essential to validate correctness across various sequence lengths and CP sizes.
  • Pay close attention to context_parallel.py: the new reordering functions have complex index arithmetic that should be validated with integration tests across different sequence lengths and CP sizes

Important Files Changed

File Analysis

  • transformer_engine/pytorch/attention/dot_product_attention/context_parallel.py (score 4/5): Adds two new THD-specific reordering functions for A2A communication and updates existing A2A functions to support THD format with padding
  • tests/pytorch/attention/test_attention_with_cp.py (score 4/5): Enables THD+CP+A2A testing by adding new test configs, updating skip conditions, and dynamically changing mask type to padding for THD format
  • tests/pytorch/attention/run_attention_with_cp.py (score 5/5): Refactors THD input shape generation to properly pad sequences to be divisible by 2*cp_size and fixes the cu_seqlens calculation

Sequence Diagram

sequenceDiagram
    participant Input as Input Tensors (Q/K/V)
    participant Reshape as Reshape & Split
    participant A2A_Before as A2A Communication (Before Attn)
    participant Reorder_Before as Reorder (Before Attn)
    participant Attn as Attention Compute
    participant Reorder_After as Reorder (After Attn)
    participant A2A_After as A2A Communication (After Attn)
    participant Output as Output Tensor

    Note over Input,Output: THD Format: [total_tokens, num_heads, head_dim]
    
    Input->>Reshape: Split heads across CP ranks
    Note over Reshape: [t, np, hn] -> [t, cp, np//cp, hn] -> [cp, t, np//cp, hn]
    
    Reshape->>A2A_Before: All-to-All exchange
    Note over A2A_Before: Each rank gets chunks from all other ranks
    
    A2A_Before->>Reorder_Before: Flatten [cp, t, np//cp, hn] -> [cp*t, np//cp, hn]
    Note over Reorder_Before: reorder_seq_chunks_after_a2a_before_attn_thd()
    Note over Reorder_Before: Reconstructs sequences using DualChunking pattern
    
    Reorder_Before->>Attn: Compute attention with cu_seqlens_padded
    Note over Attn: Sequences padded to be divisible by 2*cp_size
    
    Attn->>Reorder_After: Output [cp*t, np//cp, hn]
    Note over Reorder_After: reorder_seq_chunks_before_a2a_after_attn_thd()
    Note over Reorder_After: Prepare chunks for reverse A2A
    
    Reorder_After->>A2A_After: Reshape to [cp, t, np//cp, hn]
    
    A2A_After->>Output: All-to-All exchange back
    Note over Output: [t, cp, np//cp, hn] -> [t, np, hn]
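
As a hedged illustration of the first reshape step in the diagram, the following single-process sketch mirrors the [t, np, hn] -> [cp, t, np//cp, hn] split; the actual A2A exchange in TE runs on an initialized CP process group, which is only indicated in comments here, and the shapes are illustrative:

```python
# Single-process sketch of the head-split reshape feeding the A2A exchange; shapes are
# illustrative and the collective call is shown only in comments because it needs an
# initialized process group.
import torch

t, num_heads, head_dim, cp = 32, 8, 64, 2   # total tokens, heads, head dim, CP size
q = torch.randn(t, num_heads, head_dim)

# [t, np, hn] -> [t, cp, np//cp, hn] -> [cp, t, np//cp, hn]
q_split = q.view(t, cp, num_heads // cp, head_dim).movedim(1, 0).contiguous()
print(q_split.shape)  # torch.Size([2, 32, 4, 64])

# With a CP process group initialized, each rank would then exchange its cp slices:
# out = torch.empty_like(q_split)
# torch.distributed.all_to_all_single(out, q_split, group=cp_group)
```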

greptile-apps[bot] · Nov 12 '25 18:11

/te-ci pytorch L1

sudhakarsingh27 · Nov 13 '25 01:11

/te-ci pytorch L1

sudhakarsingh27 · Nov 13 '25 18:11

/te-ci pytorch L1

sudhakarsingh27 · Nov 13 '25 22:11

/te-ci pytorch L1

sudhakarsingh27 · Nov 13 '25 23:11

/te-ci pytorch L1

sudhakarsingh27 · Nov 14 '25 23:11

/te-ci pytorch L1

sudhakarsingh27 · Nov 15 '25 02:11

/te-ci pytorch L1

sudhakarsingh27 · Nov 17 '25 17:11

/te-ci pytorch L1

sudhakarsingh27 · Nov 18 '25 21:11

/te-ci pytorch L1

sudhakarsingh27 · Nov 20 '25 21:11