
Update cudnn flash attention

Open Wong4j opened this issue 1 year ago • 22 comments

PR Category

Others

PR Types

Others

Description

cuDNN Flash Attention outperforms the open-source Flash Attention on Hopper GPUs. On Ampere GPUs, some cases are slower than the open-source version; the cuDNN team is currently optimizing them.

Changes in this PR:

  1. Upgrade cudnn-frontend from v0.9 to v1.2.
  2. Refactor fused_dot_product_attention and add new features such as bias and GQA/MQA.
  3. Refactor the fuse_dot_product_attention_pass IR pass.
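GQA/MQA support (item 2) means K and V may have fewer heads than Q, with several query heads sharing one K/V head. The pure-Python sketch below only illustrates that head mapping; it does not call Paddle or cuDNN, the function names are hypothetical, and the "consecutive query heads share a K/V head" convention is an assumption (it is the common one, e.g. in the open-source flash attention).

```python
def kv_head_for_query_head(q_head: int, num_q_heads: int, num_kv_heads: int) -> int:
    """Index of the K/V head that query head `q_head` attends with.

    MHA: num_kv_heads == num_q_heads (one-to-one).
    GQA: 1 < num_kv_heads < num_q_heads (each group of query heads shares a K/V head).
    MQA: num_kv_heads == 1 (every query head shares the single K/V head).
    """
    if num_kv_heads <= 0 or num_q_heads % num_kv_heads != 0:
        raise ValueError("num_q_heads must be a positive multiple of num_kv_heads")
    group_size = num_q_heads // num_kv_heads
    # Assumed convention: consecutive query heads share the same K/V head.
    return q_head // group_size


def expand_kv_heads(kv_heads: list, num_q_heads: int) -> list:
    """Reference semantics only: repeat K/V heads so there is one per query head.

    A fused kernel would index the shared head directly instead of copying it.
    """
    num_kv_heads = len(kv_heads)
    return [kv_heads[kv_head_for_query_head(h, num_q_heads, num_kv_heads)]
            for h in range(num_q_heads)]


print([kv_head_for_query_head(h, 8, 2) for h in range(8)])  # GQA: [0, 0, 0, 0, 1, 1, 1, 1]
print(expand_kv_heads(["kv0"], 4))  # MQA: ['kv0', 'kv0', 'kv0', 'kv0']
```

Expanding K/V this way turns GQA/MQA back into ordinary multi-head attention, which makes it a convenient reference for checking a fused kernel's output, at the cost of extra memory that the fused path avoids.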

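An explanation of the two APIs added in incubate/nn/functional/fused_dot_product_attention.py:

  • cudnn_flash_attention
    • This API exposes the full functionality of cuDNN's flash attention implementation, with an interface similar to the open-source flash attention. Its inputs include q, k, v, cu_seqlen_q, cu_seqlen_kv, and so on. It has no mask input, but it does support bias.
  • fused_dot_product_attention
    • This API also calls cuDNN's flash attention underneath and is effectively a specialized version of it. It is designed mainly to align its parameters and functionality as closely as possible with nn.functional.scaled_dot_product_attention; its inputs are q, k, v, mask, and so on.
    • One detail to note: cudnn_flash_attention itself has no mask input, whereas this API supports an arbitrary mask. Internally, the mask is passed in as post_scale_bias.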
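The point that fused_dot_product_attention passes its mask in as post_scale_bias can be illustrated with a single-head, pure-Python sketch: a boolean mask becomes an additive bias (0 where attention is allowed, -inf where it is blocked) that is added to the scaled scores before softmax. This is reference math, not Paddle's or cuDNN's implementation; the function names are hypothetical, and treating True as "may attend" is an assumption about the mask convention.

```python
import math

NEG_INF = float("-inf")


def mask_to_post_scale_bias(mask):
    """Turn a boolean mask (True = may attend, assumed convention) into an additive bias."""
    return [[0.0 if keep else NEG_INF for keep in row] for row in mask]


def softmax(row):
    # Subtract the row max for numerical stability; exp(-inf) evaluates to 0.0,
    # so masked positions get zero weight. Each row must keep at least one key.
    m = max(row)
    exps = [math.exp(x - m) for x in row]
    total = sum(exps)
    return [e / total for e in exps]


def attention(q, k, v, post_scale_bias=None):
    """Single-head softmax(q @ k^T * scale + bias) @ v on lists of row vectors."""
    scale = 1.0 / math.sqrt(len(q[0]))
    out = []
    for i, qi in enumerate(q):
        # Scaled dot-product scores of query i against every key.
        scores = [sum(a * b for a, b in zip(qi, kj)) * scale for kj in k]
        if post_scale_bias is not None:
            # The bias is added after scaling, hence "post-scale".
            scores = [s + b for s, b in zip(scores, post_scale_bias[i])]
        weights = softmax(scores)
        out.append([sum(w * vj[c] for w, vj in zip(weights, v)) for c in range(len(v[0]))])
    return out


# Usage: a causal mask expressed as a post-scale bias.
q = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]
k = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]
v = [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]
causal = [[j <= i for j in range(3)] for i in range(3)]
masked_out = attention(q, k, v, mask_to_post_scale_bias(causal))
print(masked_out[0])  # [1.0, 2.0]: the first query only sees the first key
```

Each masked row matches unmasked attention computed over just the keys that row keeps, so an additive -inf bias reproduces an arbitrary boolean mask. That is how a kernel without a mask input can still serve a mask-taking API.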

Wong4j · Mar 26 '24 07:03

Your PR has been submitted. Thanks for your contribution! Please wait for the CI results first. See the Paddle CI Manual for details.

paddle-bot[bot] · Mar 26 '24 07:03

❌ The PR is not created using PR's template. You can refer to this Demo. Please use PR's template, it helps save our maintainers' time so that more developers get helped.

paddle-bot[bot] · Mar 26 '24 07:03

Sorry to inform you that 2977323's CIs have passed for more than 7 days. To prevent PR conflicts, you need to re-run all CIs manually.

paddle-ci-bot[bot] · Apr 03 '24 03:04

This PR needs to be tested with cuDNN >= 9.0. Could you help upgrade the cuDNN version in GPUPS? @tianshuo78520a

Wong4j · Apr 06 '24 13:04

The GPUPS maintainers ran into an issue while testing cuDNN 9.0 and are still working on it. We need to wait for them to resolve it before the CI can be upgraded.

tianshuo78520a · Apr 08 '24 06:04

Sorry to inform you that 6cb5064's CIs have passed for more than 7 days. To prevent PR conflicts, you need to re-run all CIs manually.

paddle-ci-bot[bot] · Apr 14 '24 03:04

Sorry to inform you that 943ea79's CIs have passed for more than 7 days. To prevent PR conflicts, you need to re-run all CIs manually.

paddle-ci-bot[bot] · May 04 '24 03:05

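@tianshuo78520a Two questions:

  1. Two unit tests in this PR (test_fused_dot_product_attention_op and test_fused_dot_product_attention_op_static) have been added to disable_ut, so they are never run, which makes PR-CI-Coverage fail. The list is here: https://sys-p0.bj.bcebos.com/prec/disable_ut . Why were they disabled? Can they be removed from the list?
  2. I'd like to confirm: does the PR-CI-Distribute-stable CI build Paddle with WITH_CUDNN_FRONTEND enabled?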

Wong4j · May 13 '24 01:05

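Regarding point 1: these two unit tests seem to fail intermittently. They can be removed from the list once that is fixed. Nobody has bandwidth to look into it right now, so I'll ask whether the Coverage CI can be waived. If you have time, you could also help look into the unit-test issue.

Regarding point 2: you can check the PR-CI-Distribute-stable CI log; it is already enabled (WITH_CUDNN_FRONTEND=ON).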

tianshuo78520a · May 13 '24 02:05

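@tianshuo78520a Thanks for the reply. These two unit tests pass when I run them locally. Could they be removed from the list so we can try them in CI? If any problems come up, I'll fix them.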

Wong4j · May 13 '24 07:05

test_fused_dot_product_attention_op has been re-enabled. test_fused_dot_product_attention_op_static was never disabled; if you need it tested, you can open a PR that adds the unit test to tools/gpups_test.sh.

tianshuo78520a · May 13 '24 08:05

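@tianshuo78520a After it was re-enabled, test_fused_dot_product_attention_op passed in PR-CI-Distribute-stable, but PR-CI-Coverage still reports that the relevant code is not covered. Does PR-CI-Distribute-stable not count toward the coverage check? If so, I'll have to ask you to help get the coverage requirement waived.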

Wong4j · May 14 '24 00:05

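Yes. Coverage runs in a V100 environment, where this unit test is not expected to run. I'll contact the person in charge to get it waived.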

tianshuo78520a · May 14 '24 07:05

@onecatcn All CI checks on this PR have passed. Could you help find the right people to approve the items below?

PR-CI-APPROVAL

0. You must have raindrops2sea or XiaoguangHu01 approval for change 20+ files or add than 1000+ lines of content.
1. You must have one RD (XiaoguangHu01,chenwhql,zhiqiu,Xreki,luotao1,qili93,Aurelius84) approval for the usage of const_cast.
2. Unittest is not allowed to be disabled. You must have one RD (kolinwei(Recommend), wanghuancoder, luotao1, QingshuChen, qili93 or ZzSean or Aurelius84) approval for the usage of @unittest.skip or @unittest.skipIf.
3. The error message you wrote in PADDLE_ENFORCE{_**} or PADDLE_THROW does not meet our error message writing specification. Possible errors include 1. the error message is empty / 2. the error message is too short / 3. the error type is not specified. Please read the specification [ https://github.com/PaddlePaddle/Paddle/wiki/Paddle-Error-Message-Writing-Specification ], then refine the error message. If it is a mismatch, please request chenwhql (Recommend), luotao1 or lanxianghit or Aurelius84 review and approve.

PR-CI-Static-Check

1. You must have one RD (qingqing01(Recommend), heavengate) approval for the changes of  Inputs/Output/Attrs of OPs. 
2024-05-13 17:15:02  For more details, please click [https://github.com/PaddlePaddle/Paddle/wiki/OP-Input-Output-Attribute-Compatibility-Modification].
  * The added Input 'cu_seqlen_q' is `def`, need inference to review.
  * The added Input 'bias' is `def`, need inference to review.
  * The added Input 'cu_seqlen_kv' is `def`, need inference to review.
  * The added attr 'bias_type_str' is `def`, need inference to review.
  * The added attr 'mask_type_str' is `def`, need inference to review.
2. print or std::cout is not recommended for direct use, please use logging or VLOG. If it is necessary to use, please contact tianshuo78520a (Recommend) or zhangbo9674 review and approve.

Wong4j · May 14 '24 08:05

Is there any performance difference between the new cudnn_flash_attention and the open-source flash_attention?

vivienfanghuagood · May 14 '24 08:05

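@vivienfanghuagood With cuDNN v9.0, it is currently faster than the open-source version on Hopper GPUs, while some cases are slower on Ampere. The cuDNN team is working on optimizing them.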

Wong4j · May 14 '24 08:05

Sorry to inform you that cc43071's CIs have passed for more than 7 days. To prevent PR conflicts, you need to re-run all CIs manually.

paddle-ci-bot[bot] · May 21 '24 03:05

@vivienfanghuagood Do you have any further review comments?

Wong4j · May 23 '24 07:05

All review comments have been addressed. @qingqing01 @kolinwei @chenwhql could you take a look and see whether this can be approved?

Wong4j · May 24 '24 07:05

Sorry to inform you that f0578a0's CIs have passed for more than 7 days. To prevent PR conflicts, you need to re-run all CIs manually.

paddle-ci-bot[bot] · May 30 '24 03:05

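@onecatcn Could you help remind the following people?

All review comments have been addressed. @qingqing01 @kolinwei @chenwhql could you take a look and see whether this can be approved?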

Wong4j · Jun 05 '24 06:06

Sorry to inform you that ac9b059's CIs have passed for more than 7 days. To prevent PR conflicts, you need to re-run all CIs manually.

paddle-ci-bot[bot] · Jun 19 '24 03:06