
[Pytorch] CP + THD + chunked attention support.

Open pggPL opened this issue 9 months ago • 3 comments

Description

This PR introduces support for CP + THD + chunked attention.

Fixes # (issue)

Type of change

  • [ ] Documentation change (change only to the documentation, either a fix or a new content)
  • [ ] Bug fix (non-breaking change which fixes an issue)
  • [x] New feature (non-breaking change which adds functionality)
  • [ ] Breaking change (fix or feature that would cause existing functionality to not work as expected)
  • [ ] Infra/Build change
  • [ ] Code refactoring

Checklist:

  • [x] I have read and followed the contributing guidelines
  • [x] The functionality is complete
  • [x] I have commented my code, particularly in hard-to-understand areas
  • [x] I have made corresponding changes to the documentation
  • [x] My changes generate no new warnings
  • [x] I have added tests that prove my fix is effective or that my feature works
  • [x] New and existing unit tests pass locally with my changes

pggPL avatar Jun 17 '25 13:06 pggPL

/te-ci pytorch L1

pggPL avatar Jul 16 '25 18:07 pggPL

/te-ci pytorch L1

pggPL avatar Jul 17 '25 10:07 pggPL

Hi @pggPL - Could you say what is left to get this merged? To make it into TE 2.10, it needs to be merged by 11/15.

nvMelissa avatar Oct 30 '25 18:10 nvMelissa

Greptile Overview

Greptile Summary

This PR implements chunked attention support for Llama 4 in TransformerEngine-PyTorch. The implementation adds a chunk_size parameter that partitions sequences into fixed-size chunks to reduce memory usage during attention computation.
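For readers unfamiliar with the THD side of this, the core cu_seqlens transformation can be sketched in plain Python. This is only a conceptual illustration of what the PR's `thd_chunkify` utility does (the actual implementation is the on-device `thd_chunkify_kernel`; the function name below is illustrative, not the real API):

```python
def chunkify_cu_seqlens(cu_seqlens, chunk_size):
    """Split each variable-length sequence in a THD batch into
    fixed-size chunks, returning new cumulative sequence lengths.

    Illustrative sketch only: the PR performs this on-device
    via a CUDA kernel rather than in Python.
    """
    new_cu = [0]
    for start, end in zip(cu_seqlens, cu_seqlens[1:]):
        length = end - start
        # Emit full chunks, then the remainder (if any).
        while length > 0:
            step = min(chunk_size, length)
            new_cu.append(new_cu[-1] + step)
            length -= step
    return new_cu

# Two packed sequences of lengths 300 and 100, chunk_size=128:
# 300 -> 128 + 128 + 44, 100 -> 100
print(chunkify_cu_seqlens([0, 300, 400], 128))  # [0, 128, 256, 300, 400]
```

Each original sequence boundary is preserved; only extra boundaries are inserted every `chunk_size` tokens, so attention is then computed per chunk instead of over the full sequence.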

Key changes:

  • Added CUDA kernel implementations for chunking sequences in THD format for context parallelism (thd_chunkify_kernel, thd_chunkify_p2p_kernel, and diagonal/above/below kernels)
  • Modified attention forward/backward passes to support chunked computation for non-CP (bshd/sbhd/thd) and CP cases (thd with p2p only)
  • Extended C++ and Python APIs to expose chunking utilities
  • Added comprehensive test coverage for both CP and non-CP scenarios
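For the non-CP bshd/sbhd path, "chunked computation" amounts to folding the sequence dimension into the batch dimension so that attention runs independently per chunk. A minimal numpy sketch of that reshape (shapes assumed for illustration; the PR does the equivalent on PyTorch tensors):

```python
import numpy as np

b, s, h, d, chunk = 2, 256, 4, 8, 128
x = np.zeros((b, s, h, d))  # bshd layout

# Sequence length is assumed divisible by chunk_size here;
# real inputs may need padding first.
assert s % chunk == 0

# Fold the sequence dim into chunks: each chunk becomes an
# independent "batch" entry, so attention never crosses chunks.
x_chunked = x.reshape(b * s // chunk, chunk, h, d)
print(x_chunked.shape)  # (4, 128, 4, 8)
```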

Supported configurations:

  • Non-CP: All three qkv_formats (bshd, sbhd, thd), FlashAttention/FusedAttention/UnfusedDotProductAttention backends
  • CP: thd format only with p2p communication type, FlashAttention/FusedAttention backends
  • Mask types: no_mask, causal, causal_bottom_right (though the code assertion differs from the PR description for thd + chunked)
  • Bias types: no_bias only with chunked attention
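To make concrete what the causal mask types mean under chunking: a query attends only to keys within its own chunk, causally. A toy boolean-mask illustration of that semantics (my reading of chunked/local attention as in Llama 4, not code from this PR):

```python
import numpy as np

def chunked_causal_mask(seq_len, chunk_size):
    """True where query i may attend to key j: same chunk,
    and j <= i (causal within the chunk). Illustrative only."""
    i = np.arange(seq_len)[:, None]
    j = np.arange(seq_len)[None, :]
    same_chunk = (i // chunk_size) == (j // chunk_size)
    return same_chunk & (j <= i)

m = chunked_causal_mask(4, 2)
print(m.astype(int))
```

With `seq_len=4`, `chunk_size=2`, tokens 2 and 3 cannot see tokens 0 and 1 at all, which is what bounds the attention memory footprint.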

Issues identified:

  • Previous review thread already identified the causal_bottom_right mask inconsistency between PR description and implementation assertion
  • The implementation is complex with significant CUDA kernel code added (526 lines), requiring thorough testing

Confidence Score: 3/5

  • This PR adds significant complexity with extensive CUDA kernel implementations and changes to critical attention paths; thorough testing and verification are needed before merge
  • Score reflects: (1) large implementation with 1318 insertions including complex CUDA kernels that are difficult to verify without runtime testing, (2) previous review thread identified issues that should be addressed, (3) good test coverage added but the feature touches critical performance paths, (4) implementation appears sound architecturally but has intricate chunking logic for CP P2P cases
  • Pay close attention to transformer_engine/common/fused_attn/context_parallel.cu (complex CUDA kernels) and transformer_engine/pytorch/attention/dot_product_attention/context_parallel.py (intricate CP logic with chunking)

Important Files Changed

File Analysis

Filename Score Overview
transformer_engine/common/fused_attn/context_parallel.cu 4/5 Added 526 lines of CUDA kernel implementations for chunked attention support in CP: thd_chunkify_kernel, thd_chunkify_p2p_kernel, thd_seq_tweak_below_diag_kernel, and thd_seq_tweak_above_diag_kernel with corresponding host wrapper functions
transformer_engine/pytorch/csrc/extensions/attention.cpp 4/5 Added PyTorch C++ extension wrappers for chunked attention functions and modified tensor allocation to support negative infinity initialization for softmax buffers
transformer_engine/pytorch/attention/dot_product_attention/utils.py 4/5 Added 146 lines including chunk_size to AttentionParams, chunked attention detection logic in get_attention_backend, and utility functions thd_chunkify, thd_chunkify_p2p, thd_chunkify_p2p_below_diagonal, and thd_chunkify_p2p_above_diagonal
transformer_engine/pytorch/attention/dot_product_attention/dot_product_attention.py 3/5 Added chunked attention parameter and implementation logic with tensor reshaping for bshd/sbhd formats and sequence chunkification for thd format, including validation assertions
transformer_engine/pytorch/attention/dot_product_attention/context_parallel.py 3/5 Modified CP P2P forward and backward passes to support chunked attention by introducing per-step padded cu_seqlens and calling chunked attention utility functions for diagonal, lower-triangle, and upper-triangle sections

Sequence Diagram

sequenceDiagram
    participant User as User Code
    participant TL as TransformerLayer
    participant MHA as MultiheadAttention
    participant DPA as DotProductAttention
    participant Utils as dpa_utils
    participant CP as context_parallel
    participant Tex as C++ Extensions
    participant CUDA as CUDA Kernels

    User->>TL: forward(chunk_size=128)
    TL->>MHA: forward(chunk_size=128)
    
    alt Non-CP Case (bshd/sbhd)
        MHA->>MHA: reshape input to chunks
        Note over MHA: (b*s/128, 128, h, d)
    else CP Case (thd only)
        MHA->>Utils: thd_chunkify(cu_seqlens, chunk_size)
        Utils->>Tex: thd_chunkify()
        Tex->>CUDA: thd_chunkify_kernel
        CUDA-->>Utils: chunked cu_seqlens
    end
    
    MHA->>DPA: forward(chunk_size=128)
    
    alt CP with P2P
        DPA->>CP: cp_p2p_fwd_fused_attn()
        
        loop For each P2P step
            alt Diagonal section
                CP->>Utils: thd_chunkify_p2p()
                Utils->>CUDA: thd_chunkify_p2p_kernel
            else Lower-triangle section
                CP->>Utils: thd_chunkify_p2p_below_diagonal()
                Utils->>CUDA: thd_seq_tweak_below_diag_kernel
            else Upper-triangle section
                CP->>Utils: thd_chunkify_p2p_above_diagonal()
                Utils->>CUDA: thd_seq_tweak_above_diag_kernel
            end
            
            CP->>CP: fused_attn_fwd with chunked seqlens
        end
        
        CP-->>DPA: attention output
    else Non-CP Case
        DPA->>DPA: Regular attention with reshaped input
    end
    
    DPA-->>MHA: output
    
    alt Non-CP Case (bshd/sbhd)
        MHA->>MHA: reshape output back
        Note over MHA: (b, s, h, d)
    end
    
    MHA-->>TL: attention output
    TL-->>User: layer output

greptile-apps[bot] avatar Nov 17 '25 21:11 greptile-apps[bot]

@yaoyu-33 and @xrennvidia, could you please help take another look at this PR? There have been some delays, but we're making a push to get it merged in 2.10 (i.e. end of this week)! Thanks for your help!

cyanguwa avatar Nov 19 '25 01:11 cyanguwa

/te-ci pytorch L1

pggPL avatar Nov 20 '25 16:11 pggPL

/te-ci pytorch L1

pggPL avatar Nov 20 '25 22:11 pggPL