Optimize the backend switching logic in scaled_dot_product_attention
PR Category
User Experience
PR Types
Improvements
Description
Flash attention here should be calling Paddle's fork of the flash-attention library.
The interface that scaled_dot_product_attention ends up calling should be:
```cpp
bool flash_attn_fwd(const void * const q,  // batch_size x seqlen_q x num_heads x head_size
                    const void * const k,  // batch_size x seqlen_k x num_heads_k x head_size
                    const void * const v,  // batch_size x seqlen_k x num_heads_k x head_size
                    void * const rng_state,
                    void * const out,
                    void * const softmax_ptr,
                    void * const softmax_lse_ptr,
                    const int batch_size,
                    const int seqlen_q,
                    const int seqlen_k,
                    const int seqlen_q_rounded,
                    const int seqlen_k_rounded,
                    const int num_heads,
                    const int num_heads_k,
                    const int head_size,
                    const int head_size_rounded,
                    const float p_dropout,
                    const float softmax_scale,
                    const float softmax_unscale,
                    const bool is_causal,
                    const bool return_softmax,
                    const bool is_bf16,
                    cudaStream_t stream,
                    uint64_t seed,
                    uint64_t offset,
                    const void * const attn_mask,
                    const int64_t * const mask_dims,
                    const void * const flashmask_downstart_ptr,
                    const int64_t * const flashmask_dims,
                    const void * const flashmask_upend_ptr,
                    const void * const flashmask_downend_ptr,
                    const void * const flashmask_upstart_ptr,
                    const void * const flashmask_maxmin_ptr,
                    const int q_row_stride,
                    const int k_row_stride,
                    const int v_row_stride,
                    const int q_head_stride,
                    const int k_head_stride,
                    const int v_head_stride,
                    const int o_row_stride,
                    const int o_head_stride,
                    const int q_batch_stride,
                    const int k_batch_stride,
                    const int v_batch_stride,
                    const int o_batch_stride);
```
The checks inside it are:
```cpp
#define CHECK_FWD_EXECTUABLE(__seqlen_q, __seqlen_k)                   \
    auto dprops = at::cuda::getCurrentDeviceProperties();              \
    const bool is_sm8x = dprops->major == 8 && dprops->minor >= 0;     \
    const bool is_sm90 = dprops->major == 9 && dprops->minor == 0;     \
    ASSERT_CHECK(is_sm8x || is_sm90);                                  \
    ASSERT_CHECK(batch_size > 0);                                      \
    ASSERT_CHECK(head_size % 8 == 0);                                  \
    ASSERT_CHECK(head_size <= 256);                                    \
    ASSERT_CHECK(num_heads % num_heads_k == 0);                        \
    if (attn_mask) {                                                   \
        ASSERT_CHECK(mask_dims[0] == batch_size);                      \
        ASSERT_CHECK(mask_dims[1] == 1 || mask_dims[1] == num_heads);  \
        ASSERT_CHECK(mask_dims[2] == 1 || mask_dims[2] == __seqlen_q); \
        ASSERT_CHECK(mask_dims[3] == __seqlen_k);                      \
    }
```
So on the library side, a check of head_size <= 256 is sufficient for head_size.
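For illustration, here is a minimal sketch (a hypothetical helper, not Paddle's or the library's actual code) of the eligibility test the backend switch could mirror from CHECK_FWD_EXECTUABLE before dispatching to flash_attn_fwd:

```cpp
// Hypothetical sketch: mirrors the conditions in CHECK_FWD_EXECTUABLE so the
// caller can fall back to another backend instead of failing an ASSERT_CHECK.
#include <cuda_runtime.h>

bool CanUseFlashAttnFwd(int device_id, int batch_size, int head_size,
                        int num_heads, int num_heads_k) {
  cudaDeviceProp props;
  if (cudaGetDeviceProperties(&props, device_id) != cudaSuccess) {
    return false;  // cannot query the device; take the fallback path
  }
  // Device gate from the macro: SM 8.x (Ampere) or SM 9.0 (Hopper).
  const bool is_sm8x = props.major == 8 && props.minor >= 0;
  const bool is_sm90 = props.major == 9 && props.minor == 0;
  return (is_sm8x || is_sm90) &&
         batch_size > 0 &&
         head_size % 8 == 0 &&
         head_size <= 256 &&            // the head_size bound discussed above
         num_heads % num_heads_k == 0;  // query heads divisible by KV heads
}
```

A switch built on such a predicate can route unsupported shapes to another backend rather than hitting the library's assertion at call time.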
Your PR has been submitted. Thanks for your contribution! Please wait for the result of CI firstly. See Paddle CI Manual for details.
@Qin-sx This PR also needs to be modified.
OK, got it. But this part is a bit complicated and may involve DCU testing, so I plan to handle it later.
Sorry to inform you that 1cb2356's CIs have passed for more than 7 days. To prevent PR conflicts, you need to re-run all CIs manually.