Update cudnn flash attention
PR Category: Others
PR Types: Others
Description
cuDNN Flash Attention outperforms the open-source Flash Attention on Hopper GPUs. On Ampere GPUs some cases are slower than the open-source version; the cuDNN team is currently optimizing them.
Changes in this PR:
- Upgrade cudnn-frontend from v0.9 to v1.2.
- Refactor fused_dot_product_attention and add new features such as bias and GQA/MQA.
- Refactor the fuse_dot_product_attention_pass IR pass.
Explanation of the two APIs added in incubate/nn/functional/fused_dot_product_attention.py:
- cudnn_flash_attention: exposes the full functionality of the cuDNN flash attention implementation, with an interface similar to the open-source flash attention. Its inputs include q, k, v, cu_seqlen_q, cu_seqlen_kv, etc.; there is no mask input, but bias is supported.
- fused_dot_product_attention: also calls the cuDNN flash attention underneath and is essentially a specialized version, designed to align its parameters and behavior with nn.functional.scaled_dot_product_attention; its inputs are q, k, v, mask, etc. One detail to note: cudnn_flash_attention itself has no mask input, while this API supports an arbitrary mask; internally the mask is passed through as a post_scale_bias. A usage sketch follows this list.
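A minimal usage sketch of the two APIs, illustrative only: it assumes both functions are importable from paddle.incubate.nn.functional (based on the file path above) and that the positional order follows the parameter names listed above; the tensor layout, dtypes, and mask handling here are assumptions, not taken from the actual implementation.

```python
import paddle

# Assumed import path, based on incubate/nn/functional/fused_dot_product_attention.py.
from paddle.incubate.nn.functional import (
    cudnn_flash_attention,        # full-featured variant: varlen inputs + bias, no mask
    fused_dot_product_attention,  # SDPA-style variant: q, k, v, mask
)

batch, seqlen, num_heads, head_dim = 2, 128, 8, 64
# Assumed [batch, seqlen, num_heads, head_dim] layout; cast to fp16 for the fused kernel.
q = paddle.randn([batch, seqlen, num_heads, head_dim]).astype("float16")
k = paddle.randn([batch, seqlen, num_heads, head_dim]).astype("float16")
v = paddle.randn([batch, seqlen, num_heads, head_dim]).astype("float16")

# cudnn_flash_attention: variable-length interface similar to the open-source
# flash attention, driven by cumulative sequence lengths; no mask input, bias optional.
cu_seqlen_q = paddle.arange(0, (batch + 1) * seqlen, seqlen, dtype="int32")
cu_seqlen_kv = paddle.arange(0, (batch + 1) * seqlen, seqlen, dtype="int32")
out_varlen = cudnn_flash_attention(q, k, v, cu_seqlen_q, cu_seqlen_kv)

# fused_dot_product_attention: mirrors nn.functional.scaled_dot_product_attention.
# An arbitrary mask is accepted; per the description above it is forwarded to the
# cuDNN kernel as a post_scale_bias, since the kernel itself has no mask input.
mask = paddle.ones([batch, 1, seqlen, seqlen], dtype="bool")
out_sdpa = fused_dot_product_attention(q, k, v, mask)
```

In short, fused_dot_product_attention is meant as a drop-in counterpart to scaled_dot_product_attention, while cudnn_flash_attention exposes the varlen/bias interface directly.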
Your PR has been submitted. Thanks for your contribution! Please wait for the CI results first. See the Paddle CI Manual for details.
❌ The PR is not created using the PR template. You can refer to this Demo. Please use the PR template; it helps save our maintainers' time so that more developers get helped.
Sorry to inform you that 2977323's CIs have passed for more than 7 days. To prevent PR conflicts, you need to re-run all CIs manually.
This PR needs to be tested with cuDNN >= 9.0. Could you help upgrade the cuDNN version in GPUPS? @tianshuo78520a
The GPUPS manager encountered an issue while testing cuDNN 9.0 and is still working on it; we need to wait for that to be resolved before the CI can be upgraded.
Sorry to inform you that 6cb5064's CIs have passed for more than 7 days. To prevent PR conflicts, you need to re-run all CIs manually.
Sorry to inform you that 943ea79's CIs have passed for more than 7 days. To prevent PR conflicts, you need to re-run all CIs manually.
@tianshuo78520a Two questions:
- Two unit tests in this PR (test_fused_dot_product_attention_op and test_fused_dot_product_attention_op_static) were added to disable_ut (https://sys-p0.bj.bcebos.com/prec/disable_ut), so they are not run, which causes PR-CI-Coverage to fail. What is the reason for this? Can they be removed from the list?
- I'd also like to confirm: is WITH_CUDNN_FRONTEND enabled when Paddle is built for the PR-CI-Distribute-stable CI?
Regarding point 1: these two unit tests seem to fail randomly; once that is resolved they can be removed from the list. We don't have the bandwidth to look into it right now, so I'll ask whether the Coverage CI can be exempted. If you have time, you could also help look into the unit test issue.
Regarding point 2: you can find it in the PR-CI-Distribute-stable CI log; WITH_CUDNN_FRONTEND=ON is already set.
@tianshuo78520a Thanks for the reply. These two unit tests pass when I run them locally. Could they be removed from the list so we can try the CI? If any problem shows up I will fix it.
test_fused_dot_product_attention_op has been re-enabled. test_fused_dot_product_attention_op_static was never disabled; if you need it tested, you can submit a PR that adds the unit test to tools/gpups_test.sh.
@tianshuo78520a After it was re-enabled, the test_fused_dot_product_attention_op unit test passed in PR-CI-Distribute-stable, but PR-CI-Coverage still shows the relevant code as not covered. Does PR-CI-Distribute-stable not count toward the coverage check? If so, I'll have to ask for a coverage exemption.
Yes, Coverage runs on a V100 environment and probably does not run this unit test. I'll contact the owner to grant an exemption.
@onecatcn All CIs on this PR have passed. Could you help find the right people to approve the following items?
PR-CI-APPROVAL
0. You must have raindrops2sea or XiaoguangHu01 approval for changing 20+ files or adding 1000+ lines of content.
1. You must have one RD (XiaoguangHu01,chenwhql,zhiqiu,Xreki,luotao1,qili93,Aurelius84) approval for the usage of const_cast.
2. Unittest is not allowed to be disabled. You must have one RD (kolinwei(Recommend), wanghuancoder, luotao1, QingshuChen, qili93 or ZzSean or Aurelius84) approval for the usage of @unittest.skip or @unittest.skipIf.
3. The error message you wrote in PADDLE_ENFORCE{_**} or PADDLE_THROW does not meet our error message writing specification. Possible errors include 1. the error message is empty / 2. the error message is too short / 3. the error type is not specified. Please read the specification [ https://github.com/PaddlePaddle/Paddle/wiki/Paddle-Error-Message-Writing-Specification ], then refine the error message. If it is a mismatch, please request chenwhql (Recommend), luotao1 or lanxianghit or Aurelius84 review and approve.
PR-CI-Static-Check
1. You must have one RD (qingqing01(Recommend), heavengate) approval for the changes of Inputs/Output/Attrs of OPs.
For more details, please see [https://github.com/PaddlePaddle/Paddle/wiki/OP-Input-Output-Attribute-Compatibility-Modification].
* The added Input 'cu_seqlen_q' is `def`, need inference to review.
* The added Input 'bias' is `def`, need inference to review.
* The added Input 'cu_seqlen_kv' is `def`, need inference to review.
* The added attr 'bias_type_str' is `def`, need inference to review.
* The added attr 'mask_type_str' is `def`, need inference to review.
2. print or std::cout is not recommended for direct use, please use logging or VLOG. If it is necessary to use, please contact tianshuo78520a (Recommend) or zhangbo9674 review and approve.
Is there any performance difference between the new cudnn_flash_attention and the open-source flash_attention?
@vivienfanghuagood With cuDNN v9.0, it is faster than the open-source version on Hopper GPUs; on Ampere some cases are slower, and the cuDNN team is optimizing them.
Sorry to inform you that cc43071's CIs have passed for more than 7 days. To prevent PR conflicts, you need to re-run all CIs manually.
@vivienfanghuagood Do you have any further review comments?
All review comments have been addressed. @qingqing01 @kolinwei @chenwhql could you take a look and approve if everything is OK?
Sorry to inform you that f0578a0's CIs have passed for more than 7 days. To prevent PR conflicts, you need to re-run all CIs manually.
@onecatcn Could you help remind the reviewers below? All review comments have been addressed. @qingqing01 @kolinwei @chenwhql could you take a look and approve if everything is OK?
Sorry to inform you that ac9b059's CIs have passed for more than 7 days. To prevent PR conflicts, you need to re-run all CIs manually.