
Does FasterTransformer support flash attention

Open • hujiaxin0 opened this issue on Apr 24, 2023 • 9 comments

Hi, does FasterTransformer support flash attention?

hujiaxin0 avatar Apr 24 '23 03:04 hujiaxin0

Yes. The kernels with xxx_flash_attention_xxx in https://github.com/NVIDIA/FasterTransformer/tree/main/3rdparty/trt_fused_multihead_attention are flash attention kernels.

byshiue avatar Apr 24 '23 03:04 byshiue
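
Since that naming pattern is the only public hint, a quick hedged way to see which of the shipped fused-MHA kernels are flash attention variants is to list the files matching it in a local checkout (the checkout location below is an assumption):

```python
# Hedged sketch: list the precompiled flash attention kernels referred to above.
# Assumes the FasterTransformer repo is checked out in the current directory.
from pathlib import Path

kernel_dir = Path("FasterTransformer/3rdparty/trt_fused_multihead_attention")
for path in sorted(kernel_dir.glob("*flash_attention*")):
    print(path.name)
```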

In the process of using flash attention, I have two questions. First, I found that with FMHA_ENABLE=ON there may be a precision problem. Is there any way to solve that? Second, I found that flash attention is implemented in assembly. Could you provide the source code in C++/CUDA?

hujiaxin0 avatar Apr 24 '23 12:04 hujiaxin0

> In the process of using flash attention, I have two questions. First, I found that with FMHA_ENABLE=ON there may be a precision problem. Is there any way to solve that? Second, I found that flash attention is implemented in assembly. Could you provide the source code in C++/CUDA?

No and no.

byshiue avatar Apr 24 '23 12:04 byshiue

> In the process of using flash attention, I have two questions. First, I found that with FMHA_ENABLE=ON there may be a precision problem. Is there any way to solve that? Second, I found that flash attention is implemented in assembly. Could you provide the source code in C++/CUDA?
>
> No and no.

Is it because the flash attention kernels cannot be open sourced?

hujiaxin0 avatar Apr 27 '23 03:04 hujiaxin0
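
The thread leaves this unanswered. Since the kernels ship as precompiled binaries, one hedged way to debug precision without the kernel source is to compare the model's attention output against a plain reference of the tiled online-softmax algorithm that flash attention is based on. Below is a minimal NumPy sketch of that published algorithm (single head, non-causal); it is a numerical ground truth only, not NVIDIA's kernel:

```python
# Reference implementation of tiled (online-softmax) attention, the algorithm
# behind flash attention. Useful only as a numerical ground truth; the real
# kernels additionally fuse masking, fp16 accumulation, etc.
import numpy as np

def flash_attention_reference(q, k, v, block_size=64):
    """q, k, v: [seq_len, head_dim] for a single head, non-causal."""
    seq_len, head_dim = q.shape
    scale = 1.0 / np.sqrt(head_dim)

    out = np.zeros_like(q)
    row_max = np.full(seq_len, -np.inf)   # running row-wise max of the scores
    row_sum = np.zeros(seq_len)           # running softmax denominator

    for start in range(0, seq_len, block_size):
        k_blk = k[start:start + block_size]
        v_blk = v[start:start + block_size]

        scores = (q @ k_blk.T) * scale                    # [seq_len, block]
        new_max = np.maximum(row_max, scores.max(axis=1))
        correction = np.exp(row_max - new_max)            # rescale old statistics
        p = np.exp(scores - new_max[:, None])

        row_sum = row_sum * correction + p.sum(axis=1)
        out = out * correction[:, None] + p @ v_blk
        row_max = new_max

    return out / row_sum[:, None]

# Sanity check against naive softmax(Q K^T / sqrt(d)) V.
rng = np.random.default_rng(0)
q, k, v = (rng.standard_normal((128, 64)) for _ in range(3))
scores = (q @ k.T) / np.sqrt(64)
weights = np.exp(scores - scores.max(axis=1, keepdims=True))
weights /= weights.sum(axis=1, keepdims=True)
assert np.allclose(flash_attention_reference(q, k, v), weights @ v, atol=1e-6)
```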

@hujiaxin0 what precision problem are you referring to?

ankit-db avatar May 16 '23 20:05 ankit-db

Also, @byshiue, it's a little unclear to me how to know whether flash attention is being used - is this exposed in any of the models that are currently in the repo?

ankit-db avatar May 16 '23 20:05 ankit-db

> @hujiaxin0 what precision problem are you referring to?

When I set the environment variable FMHA_ENABLE=ON, the model's output is useless; when I unset FMHA_ENABLE, the output is normal.

I found a reply in issue https://github.com/NVIDIA/FasterTransformer/issues/548#issuecomment-1499899742, which mentions that FMHA brings further accuracy concerns.

Since I can't see the internal implementation of the flash attention kernel, I can't debug this problem.

I hope you can give me some advice on debugging this problem. @ankit-db

hujiaxin0 avatar May 17 '23 02:05 hujiaxin0
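
For isolating the problem, one hedged approach is an A/B harness that runs the same inference twice, with and without FMHA_ENABLE, and diffs the outputs numerically. The driver script name and the logits dump below are hypothetical placeholders for whatever FasterTransformer example is actually being run:

```python
# Hedged sketch: an A/B harness for isolating the FMHA_ENABLE precision issue.
# "run_inference.py" and the saved "logits_*.npy" files are hypothetical
# placeholders for the actual FasterTransformer example script and its output.
import os
import subprocess
import numpy as np

def run_with_env(fmha_on: bool, out_path: str) -> np.ndarray:
    env = os.environ.copy()
    if fmha_on:
        env["FMHA_ENABLE"] = "ON"      # enable the fused / flash attention path
    else:
        env.pop("FMHA_ENABLE", None)   # fall back to the unfused attention path
    # Hypothetical driver script that dumps per-token logits to out_path.
    subprocess.run(["python", "run_inference.py", "--save-logits", out_path],
                   env=env, check=True)
    return np.load(out_path)

ref = run_with_env(False, "logits_ref.npy")
fmha = run_with_env(True, "logits_fmha.npy")

# A genuine precision gap shows up as a bounded max-abs difference; garbage
# output usually means NaN/Inf or differences on the scale of the logits.
diff = np.abs(ref - fmha)
print("max abs diff:", diff.max())
print("any NaN/Inf:", not np.isfinite(fmha).all())
```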

I'm not an expert on the FMHA environment variable, so I can't really help much here.

That being said, accuracy concerns are totally different from "the output is garbage", so I'm guessing you've run into a more problematic issue? @byshiue, do you mind clarifying what the accuracy concerns are?

ankit-db avatar May 17 '23 03:05 ankit-db

How do I use the fused MHA layer in PyTorch?

ranjiewwen avatar Jul 24 '23 11:07 ranjiewwen
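
The thread does not answer this. Outside of FasterTransformer, one hedged option is PyTorch's own fused attention, torch.nn.functional.scaled_dot_product_attention (PyTorch >= 2.0), which can dispatch to a flash-attention kernel on supported GPUs. Note this is the PyTorch built-in, not FasterTransformer's fused MHA layer:

```python
# Hedged sketch: fused attention via PyTorch's built-in scaled_dot_product_attention.
# On supported GPUs and dtypes, PyTorch may dispatch this to a flash-attention kernel.
import torch
import torch.nn.functional as F

batch, heads, seq_len, head_dim = 2, 16, 512, 64
device = "cuda" if torch.cuda.is_available() else "cpu"
dtype = torch.float16 if device == "cuda" else torch.float32

q = torch.randn(batch, heads, seq_len, head_dim, device=device, dtype=dtype)
k = torch.randn(batch, heads, seq_len, head_dim, device=device, dtype=dtype)
v = torch.randn(batch, heads, seq_len, head_dim, device=device, dtype=dtype)

# Fused scaled-dot-product attention; is_causal=True applies a causal (GPT-style) mask.
out = F.scaled_dot_product_attention(q, k, v, is_causal=True)
print(out.shape)  # torch.Size([2, 16, 512, 64])
```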