
[NVIDIA] Add config option to use cudnn flash attention

Open · kaixih opened this issue · 0 comments

This PR allows users to enable cuDNN flash attention. It depends on https://github.com/google/praxis/pull/53.

In preliminary results on GPT3-5B, we observe a ~30% performance improvement on 8xH100 GPUs.

With this PR, users only need to set USE_CUDNN_FLASH_ATTENTION=True in their config, and the attention computation is then replaced with cuDNN flash attention.

cc @nluehr @zhangqiaorjc

kaixih, Mar 22 '24 20:03