
[Feature]: FlashAttention 3 support

Open fan-niu opened this issue 1 year ago • 2 comments

As you know, FlashAttention-3 promises ~1.5x improvements. Is there any plan to support it? Thanks! https://github.com/Dao-AILab/flash-attention/commit/7ef24848cf2f855077cef88fe122775b727dcd74

fan-niu avatar Jul 15 '24 06:07 fan-niu

@byshiue @nv-guomingz @nv-hwoo @juney-nvidia @AdamzNV @kaiyux @Shixiaowei02

avianion avatar Jul 15 '24 19:07 avianion

This issue is stale because it has been open 30 days with no activity. Remove stale label or comment or this will be closed in 15 days.

github-actions[bot] avatar Aug 17 '24 01:08 github-actions[bot]

Is this feature under development? Has there been any progress?

amk9978 avatar Oct 13 '24 16:10 amk9978

It's currently under consideration, but development hasn't begun yet.

AdamzNV avatar Oct 14 '24 02:10 AdamzNV

@AdamzNV I don't understand how an algorithm you've already implemented the second version of, and which provides a 1.5x speed boost, can still be under "consideration".

avianion avatar Oct 14 '24 17:10 avianion

@avianion FA3 is nothing new except for utilizing Hopper features (i.e., warp-specialized kernels with TMA + GMMA), and these have already been implemented in TRT-LLM since the first public release.

  • We have also done some benchmarks, which show that TRT-LLM FMHA kernels are faster in most cases (especially for longer sequence lengths). You are free to compare the performance yourself, and let us know if you find cases where TRT-LLM is much worse. Thanks.
  • Besides, we have a much faster FP8 FMHA implementation (note that our implementation still uses per-tensor scales, so it might not be a fair comparison).

PerkzZheng avatar Oct 15 '24 05:10 PerkzZheng
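
For readers unfamiliar with the per-tensor-scale caveat above, here is a minimal illustrative sketch of per-tensor FP8 (e4m3) quantization: a single scalar scale per tensor, derived from its absolute maximum. This is not TRT-LLM's quantization code; the helper name and shapes are made up for illustration, and it assumes a PyTorch build with `torch.float8_e4m3fn` support.

```python
import torch

E4M3_MAX = 448.0  # largest finite value representable in float8 e4m3

def per_tensor_fp8(x: torch.Tensor):
    # One scale for the whole tensor (vs. per-block/per-channel scales).
    scale = x.abs().amax().float() / E4M3_MAX
    x_fp8 = (x / scale).to(torch.float8_e4m3fn)  # quantize
    return x_fp8, scale

q = torch.randn(1, 16, 16384, 128, device="cuda", dtype=torch.float16)
q_fp8, q_scale = per_tensor_fp8(q)

# A real kernel would consume q_fp8 and fold q_scale back in after the matmul;
# dequantizing here just shows the round-trip error of per-tensor scaling.
err = (q_fp8.to(torch.float16) * q_scale - q).abs().mean()
print(f"scale={q_scale.item():.4e}  mean abs error={err.item():.4e}")
```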

@PerkzZheng Thank you for your insights. Can you please share an open-source link for the benchmarking results you mentioned above? Also, when I compare the TRT-LLM FMHA kernel (fmha_v2_flash_attention_fp16_64_64_S_qkv_256_causal_tma_ws_sm90_kernel) at input=16384, BS=1, the TFLOP/s are almost the same as the original FA v3 on an H100. Can you please share a simple script to benchmark the TRT-LLM FMHA kernel, or some insights if I missed anything? Thank you! @byshiue @QiJune @AdamzNV

usajid14 avatar Nov 11 '24 20:11 usajid14
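
For the flash-attention side of such a comparison, a minimal timing sketch along these lines is often enough. It assumes the flash-attn 2.x Python API (`flash_attn_func`); the FA3 Hopper build exposes a similar function under `flash_attn_interface`. The shapes and iteration counts are illustrative, not an official benchmark, and measuring the TRT-LLM FMHA kernel itself has to go through its attention plugin/runtime, which is not shown here.

```python
import torch
from flash_attn import flash_attn_func

batch, seqlen, heads, head_dim = 1, 16384, 16, 128
q, k, v = (torch.randn(batch, seqlen, heads, head_dim,
                       device="cuda", dtype=torch.float16) for _ in range(3))

# Warm up, then time with CUDA events so we measure GPU time, not launch overhead.
for _ in range(5):
    flash_attn_func(q, k, v, causal=True)

start = torch.cuda.Event(enable_timing=True)
end = torch.cuda.Event(enable_timing=True)
iters = 20
start.record()
for _ in range(iters):
    flash_attn_func(q, k, v, causal=True)
end.record()
torch.cuda.synchronize()
latency_ms = start.elapsed_time(end) / iters

# Standard FLOP accounting: QK^T and PV each cost 2*b*h*s*s*d, halved for causal.
flops = 4 * batch * heads * seqlen * seqlen * head_dim / 2
print(f"{latency_ms:.3f} ms  ->  {flops / (latency_ms * 1e-3) / 1e12:.1f} TFLOP/s")
```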

@usajid14 you might want to try head sizes 128/256, which should have better performance. However, FA3 may have been updated since we collected our numbers, so they are a bit out of date. Let me know if there are cases where TRT-LLM kernels are much worse. Thanks.

PerkzZheng avatar Nov 14 '24 05:11 PerkzZheng
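
A rough way to act on the head-size suggestion is to sweep head_dim while holding the hidden size fixed, so total FLOPs stay constant and only kernel efficiency changes. The sketch below uses `torch.nn.functional.scaled_dot_product_attention` as a portable stand-in, not the TRT-LLM FMHA kernel; the hidden size and iteration counts are assumptions for illustration.

```python
import torch
import torch.nn.functional as F

def bench(seqlen=16384, hidden=4096, head_dim=128, iters=20):
    heads = hidden // head_dim  # keep hidden size (and thus FLOPs) constant
    q, k, v = (torch.randn(1, heads, seqlen, head_dim,
                           device="cuda", dtype=torch.float16) for _ in range(3))
    for _ in range(5):  # warm-up
        F.scaled_dot_product_attention(q, k, v, is_causal=True)
    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)
    start.record()
    for _ in range(iters):
        F.scaled_dot_product_attention(q, k, v, is_causal=True)
    end.record()
    torch.cuda.synchronize()
    ms = start.elapsed_time(end) / iters
    tflops = 4 * heads * seqlen * seqlen * head_dim / 2 / (ms * 1e-3) / 1e12
    print(f"head_dim={head_dim:4d}  heads={heads:3d}  {ms:7.3f} ms  {tflops:6.1f} TFLOP/s")

for d in (64, 128, 256):
    bench(head_dim=d)
```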