FasterTransformer
Does FasterTransformer support flash attention
Hi, does FasterTransformer support flash attention?
Yes. The kernels whose names contain xxx_flash_attention_xxx in https://github.com/NVIDIA/FasterTransformer/tree/main/3rdparty/trt_fused_multihead_attention are the flash attention kernels.
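For reference, a quick way to see which of the precompiled kernels are the flash-attention variants is to glob that directory for the naming pattern mentioned above. This is just a minimal sketch and assumes a local clone of the FasterTransformer repo:

```python
from pathlib import Path

# Adjust this to wherever your FasterTransformer checkout lives.
kernel_dir = Path("FasterTransformer/3rdparty/trt_fused_multihead_attention")

# Flash-attention variants follow the *_flash_attention_* naming pattern;
# the remaining files are the regular fused multi-head attention kernels.
for kernel in sorted(kernel_dir.glob("*flash_attention*")):
    print(kernel.name)
```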
While using flash attention, I have two questions. First, I found that with FMHA_ENABLE=ON there may be a precision problem. Is there any way to solve that? Second, I found that flash attention is implemented in assembly code. Could you provide the source code in C++/CUDA?
No and no.
Is it because the kernels of flash attention cannot be open sourced?
@hujiaxin0 what precision problem are you referring to?
Also, @byshiue, it's a little unclear to me how to know whether flash attention is being used. Is this exposed in any of the models currently in the repo?
When I set the environment variable FMHA_ENABLE=ON, the model's output is useless; when I unset FMHA_ENABLE, the output is normal (see the comparison sketch below).
I found the reply in https://github.com/NVIDIA/FasterTransformer/issues/548#issuecomment-1499899742, which mentions that FMHA brings further accuracy concerns.
Since I can't see the internal implementation of the flash attention kernel, I can't debug this problem.
I hope you can give me some advice on debugging this problem. @ankit-db
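One way to narrow this down (a hedged sketch, not from the maintainers) is to run the same prompt twice, once with FMHA_ENABLE=ON and once with it unset, and compare the outputs numerically. `infer.py` below is a hypothetical placeholder for whatever FasterTransformer inference entry point you actually use; it is assumed to write its output logits to the path passed via `--output`:

```python
import os
import subprocess

import numpy as np

# "infer.py" is a hypothetical inference script that writes its output logits
# to the .npy path given by --output; substitute your own FT entry point.
BASE_CMD = ["python", "infer.py", "--input", "prompt.txt", "--output"]

def run(fmha_enabled: bool, out_path: str) -> np.ndarray:
    env = os.environ.copy()
    if fmha_enabled:
        env["FMHA_ENABLE"] = "ON"      # same setting that triggers the problem
    else:
        env.pop("FMHA_ENABLE", None)   # unset -> unfused attention path
    subprocess.run(BASE_CMD + [out_path], env=env, check=True)
    return np.load(out_path)

ref = run(False, "logits_ref.npy")     # baseline output (reported as normal)
fused = run(True, "logits_fmha.npy")   # FMHA output (reported as useless)

# FP16 fused attention is expected to drift slightly from the baseline;
# a very large max-abs-diff points at a real kernel issue rather than rounding.
print("max abs diff :", float(np.abs(ref - fused).max()))
print("mean abs diff:", float(np.abs(ref - fused).mean()))
```

If the difference is only on the order of normal FP16 rounding, it is the accuracy concern mentioned in the linked issue; if it is orders of magnitude larger, that at least confirms the fused-attention path is where the divergence starts.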
I'm not an expert on the FMHA environment variable, so I can't really help much here.
That being said, accuracy concerns are totally different from "the output is garbage", so I'm guessing you've run into a more problematic issue? @byshiue, do you mind clarifying what the accuracy concerns are?
How do I use the fused MHA layer in PyTorch?