[NVIDIA] Add config option to use cudnn flash attention
This PR adds a config option that lets users enable cuDNN flash attention. It depends on https://github.com/google/praxis/pull/53.
In preliminary results for GPT3-5B, we observe a ~30% performance improvement on 8x H100 GPUs.
With this PR, users can simply set USE_CUDNN_FLASH_ATTENTION=True
in their config, and the attention layers will then be replaced with the cuDNN flash attention implementation.
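As a rough sketch of the intended usage, the flag would be set as a class attribute on an experiment config. The base class and module path below are illustrative only; see the PR diff for where the option is actually defined.

```python
# Hypothetical sketch: enabling cuDNN flash attention in a Paxml
# experiment config. The base class `c4.C4SpmdGpt3AdamOrgHP` is an
# assumed example; substitute your own experiment config.
from paxml.tasks.lm.params import c4


class Gpt5BCudnnFlashAttn(c4.C4SpmdGpt3AdamOrgHP):
  """GPT-style experiment with cuDNN flash attention enabled."""

  # Config option introduced by this PR.
  USE_CUDNN_FLASH_ATTENTION = True
```

No other model code changes should be needed; the attention module substitution happens inside Praxis when the flag is set.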
cc. @nluehr @zhangqiaorjc