
Optimize the backend-switching logic in scaled_dot_product_attention

Open Qin-sx opened this issue 5 months ago • 4 comments

PR Category

User Experience

PR Types

Improvements

Description

The flash attention backend should be calling Paddle's fork of flash attention.
The interface that scaled_dot_product_attention invokes should be:

bool flash_attn_fwd(const void * const q,         // batch_size x seqlen_q x num_heads x head_size
                    const void * const k,         // batch_size x seqlen_k x num_heads_k x head_size
                    const void * const v,         // batch_size x seqlen_k x num_heads_k x head_size
                    void * const rng_state,
                    void * const out,
                    void * const softmax_ptr,
                    void * const softmax_lse_ptr,
                    const int batch_size,
                    const int seqlen_q,
                    const int seqlen_k,
                    const int seqlen_q_rounded,
                    const int seqlen_k_rounded,
                    const int num_heads,
                    const int num_heads_k,
                    const int head_size,
                    const int head_size_rounded,
                    const float p_dropout,
                    const float softmax_scale,
                    const float softmax_unscale,
                    const bool is_causal,
                    const bool return_softmax,
                    const bool is_bf16,
                    cudaStream_t stream,
                    uint64_t seed,
                    uint64_t offset,
                    const void * const attn_mask,
                    const int64_t * const mask_dims,
                    const void * const flashmask_downstart_ptr,
                    const int64_t * const flashmask_dims,
                    const void * const flashmask_upend_ptr,
                    const void * const flashmask_downend_ptr,
                    const void * const flashmask_upstart_ptr,
                    const void * const flashmask_maxmin_ptr,
                    const int q_row_stride,
                    const int k_row_stride,
                    const int v_row_stride,
                    const int q_head_stride,
                    const int k_head_stride,
                    const int v_head_stride,
                    const int o_row_stride,
                    const int o_head_stride,
                    const int q_batch_stride,
                    const int k_batch_stride,
                    const int v_batch_stride,
                    const int o_batch_stride);

The check inside it is:

#define CHECK_FWD_EXECTUABLE(__seqlen_q, __seqlen_k)                     \
      auto dprops = at::cuda::getCurrentDeviceProperties();              \
      const bool is_sm8x = dprops->major == 8 && dprops->minor >= 0;     \
      const bool is_sm90 = dprops->major == 9 && dprops->minor == 0;     \
      ASSERT_CHECK(is_sm8x || is_sm90);                                  \
      ASSERT_CHECK(batch_size > 0);                                      \
      ASSERT_CHECK(head_size % 8 == 0);                                  \
      ASSERT_CHECK(head_size <= 256);                                    \
      ASSERT_CHECK(num_heads % num_heads_k == 0);                        \
      if (attn_mask) {                                                   \
          ASSERT_CHECK(mask_dims[0] == batch_size);                      \
          ASSERT_CHECK(mask_dims[1] == 1 || mask_dims[1] == num_heads);  \
          ASSERT_CHECK(mask_dims[2] == 1 || mask_dims[2] == __seqlen_q); \
          ASSERT_CHECK(mask_dims[3] == __seqlen_k);                      \
      }

So the only constraint on head_size is head_size % 8 == 0 and head_size <= 256.
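The dispatch logic on the Python side could therefore mirror `CHECK_FWD_EXECTUABLE` before choosing the flash attention backend. The sketch below is a hypothetical, framework-free illustration of those conditions (function name and parameters are assumptions, not Paddle's actual API):

```python
def flash_attention_supported(head_size, sm_major, sm_minor,
                              batch_size=1, num_heads=8, num_heads_k=8):
    """Pure-Python mirror of the CHECK_FWD_EXECTUABLE conditions
    (illustrative sketch, not Paddle's real dispatch code)."""
    # Only Ampere (sm 8.x) and Hopper (sm 9.0) GPUs are accepted.
    is_sm8x = sm_major == 8 and sm_minor >= 0
    is_sm90 = sm_major == 9 and sm_minor == 0
    if not (is_sm8x or is_sm90):
        return False
    if batch_size <= 0:
        return False
    # head_size must be a multiple of 8 and at most 256.
    if head_size % 8 != 0 or head_size > 256:
        return False
    # Grouped-query attention: query heads must be a multiple of KV heads.
    if num_heads % num_heads_k != 0:
        return False
    return True
```

For instance, `head_size=128` on an A100 (sm 8.0) passes, while `head_size=264` fails the `<= 256` bound and any pre-Ampere GPU is rejected outright.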

Qin-sx avatar Jun 07 '25 14:06 Qin-sx

Your PR has been submitted. Thanks for your contribution! Please wait for the result of CI first. See the Paddle CI Manual for details.

paddle-bot[bot] avatar Jun 07 '25 14:06 paddle-bot[bot]

@Qin-sx This PR also needs to be updated

zhwesky2010 avatar Jun 13 '25 10:06 zhwesky2010

> @Qin-sx This PR also needs to be updated

OK, got it. But this part is somewhat complicated and may involve DCU testing, so I plan to handle it later.

Qin-sx avatar Jun 13 '25 15:06 Qin-sx

Sorry to inform you that the CIs for 1cb2356 passed more than 7 days ago. To prevent PR conflicts, you need to re-run all CIs manually.

paddle-ci-bot[bot] avatar Jun 16 '25 02:06 paddle-ci-bot[bot]