
Fused attention kernel for small batch sizes

Open JohannesGaessler opened this issue 11 months ago • 1 comment

This PR adds a fused attention kernel with almost the same interface as the FlashAttention kernel. It does not use tensor cores or any of the tricks in the FlashAttention paper, it simply calculates attention output for a single column of the output matrix as a fused operation. The performance on my RTX 3090 is as follows:

| model | test | n_batch | t/s jg/flash-attn | t/s fused_attn_ext_f16 | Speedup |
|---|---|---|---|---|---|
| llama 7B Q4_0 | pp 4096 | 1 | 105.98 | 118.68 | 1.12 |
| llama 7B Q4_0 | pp 4096 | 2 | 206.89 | 229.49 | 1.11 |
| llama 7B Q4_0 | pp 4096 | 3 | 295.08 | 282.98 | 0.96 |
| llama 7B Q4_0 | pp 4096 | 4 | 355.27 | 341.78 | 0.96 |
| llama 7B Q4_0 | pp 4096 | 5 | 416.21 | 401.74 | 0.97 |
| llama 7B Q4_0 | pp 4096 | 6 | 443.87 | 428.70 | 0.97 |
| llama 7B Q4_0 | pp 4096 | 7 | 472.82 | 457.57 | 0.97 |
| llama 7B Q4_0 | pp 4096 | 8 | 495.92 | 480.93 | 0.97 |

If you look at the runtime of only the attention, NVIDIA Nsight Systems reports a 1.83x speedup at a batch size of 1.

For the fused attention/FlashAttention kernels you need to do a softmax over the KQ columns, so the number of CUDA blocks is limited to the number of attention heads x the batch size. The problem is that at small batch sizes this leaves a significant portion of a typical GPU idle. This can be mitigated to some degree by using as many threads per block as possible. Unfortunately, tensor cores need a lot of registers and therefore limit the number of threads you can use, so for small batch sizes I don't think they make sense.
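To make the fused operation concrete, here is a minimal pure-Python sketch of the semantics the kernel fuses for one output column (one (head, batch) pair per CUDA block): a softmax over the full KQ column followed by a weighted sum of V rows. The function name and shapes are illustrative, not from the PR.

```python
import math

def fused_attn_single_column(q, K, V, scale):
    """Reference semantics of the fused kernel for a single output column.
    q: query vector of length head_dim; K, V: kv_len vectors of length head_dim."""
    # KQ column: one scaled dot product per KV position
    scores = [scale * sum(qi * ki for qi, ki in zip(q, k)) for k in K]
    m = max(scores)  # subtract the max for numerical stability
    exps = [math.exp(s - m) for s in scores]
    denom = sum(exps)
    probs = [e / denom for e in exps]  # softmax over the KQ column
    # weighted sum of V rows, fused with the softmax instead of a separate pass
    head_dim = len(q)
    return [sum(p * v[d] for p, v in zip(probs, V)) for d in range(head_dim)]
```

Since the whole column lives in one block, the softmax normalization and the V accumulation never touch VRAM between steps; that is the entire trick at batch size 1.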

The FlashAttention paper has some tricks that allow you to calculate tiles of KQ in SRAM without ever writing KQ to VRAM. However, at a batch size of 1 you need at most 2*context_size bytes of SRAM to store the entire KQ matrix. So even this very simple kernel works for up to 31232 context on Turing, and for up to 49664 context on Ampere. Of course, following the FlashAttention paper would make it possible to run the kernel at larger batch sizes, but I think this would make the performance worse.
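The context limits above follow directly from the 2 bytes per fp16 KQ entry; the per-block shared-memory budgets implied below are my inference, not stated in the PR:

```python
BYTES_PER_KQ_ENTRY = 2  # KQ is stored as fp16

def kq_bytes(context_size):
    # shared memory needed to hold one full KQ column at batch size 1
    return BYTES_PER_KQ_ENTRY * context_size

# The quoted context limits imply these shared-memory footprints,
# each just under the respective architecture's per-block limit:
assert kq_bytes(31232) == 62464   # Turing
assert kq_bytes(49664) == 99328   # Ampere
```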

Overall the kernel reaches ~70% of the theoretical maximum memory bandwidth on my RTX 3090; the best I have ever been able to achieve is something like 90%. So realistically I don't think there is much room left for improvement over this kernel (assuming you have sufficient SRAM to store the entire KQ matrix). Still, I'll try to optimize the pre-existing FlashAttention kernel, since it only reaches ~10% tensor core utilization and can presumably still be improved.
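For a rough sense of scale, the utilization figures can be converted to absolute throughput; the 936 GB/s spec figure for the RTX 3090 is my assumption, not from the PR:

```python
SPEC_BANDWIDTH_GB_S = 936.0  # RTX 3090 theoretical memory bandwidth (assumed spec value)

achieved = 0.70 * SPEC_BANDWIDTH_GB_S   # ~70% reported for this fused kernel
best_seen = 0.90 * SPEC_BANDWIDTH_GB_S  # ~90%, the best the author has achieved

# remaining headroom relative to the best observed utilization, ~1.29x
headroom = best_seen / achieved
```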

JohannesGaessler avatar Mar 20 '24 15:03 JohannesGaessler

The result from this kernel is incorrect:

```
LLAMA_CUBLAS=1 make -j tests && ./tests/test-backend-ops -o FLASH_ATTN_EXT -b CUDA0
```

```
ggml_init_cublas: GGML_CUDA_FORCE_MMQ:   no
ggml_init_cublas: CUDA_USE_TENSOR_CORES: yes
ggml_init_cublas: found 1 CUDA devices:
  Device 0: Tesla V100-PCIE-16GB, compute capability 7.0, VMM: yes
  Backend name: CUDA0
  FLASH_ATTN_EXT(hs=128,nh=32,kv=512,nb=1): [FLASH_ATTN_EXT] NMSE = 0.518331457 > 0.000500000 FAIL
  FLASH_ATTN_EXT(hs=128,nh=32,kv=512,nb=2): [FLASH_ATTN_EXT] NMSE = 0.565371118 > 0.000500000 FAIL
  FLASH_ATTN_EXT(hs=128,nh=32,kv=512,nb=4): OK
  FLASH_ATTN_EXT(hs=128,nh=32,kv=512,nb=8): OK
  FLASH_ATTN_EXT(hs=128,nh=32,kv=512,nb=512): OK
  FLASH_ATTN_EXT(hs=128,nh=32,kv=512,nb=1024): OK
  FLASH_ATTN_EXT(hs=128,nh=32,kv=512,nb=2048): OK
  FLASH_ATTN_EXT(hs=128,nh=32,kv=1024,nb=1): [FLASH_ATTN_EXT] NMSE = 0.523882472 > 0.000500000 FAIL
  FLASH_ATTN_EXT(hs=128,nh=32,kv=1024,nb=2): [FLASH_ATTN_EXT] NMSE = 0.539142266 > 0.000500000 FAIL
  FLASH_ATTN_EXT(hs=128,nh=32,kv=1024,nb=4): OK
  FLASH_ATTN_EXT(hs=128,nh=32,kv=1024,nb=8): OK
  FLASH_ATTN_EXT(hs=128,nh=32,kv=1024,nb=512): OK
  FLASH_ATTN_EXT(hs=128,nh=32,kv=1024,nb=1024): OK
  FLASH_ATTN_EXT(hs=128,nh=32,kv=1024,nb=2048): OK
  FLASH_ATTN_EXT(hs=128,nh=32,kv=2048,nb=1): [FLASH_ATTN_EXT] NMSE = 0.553847926 > 0.000500000 FAIL
  FLASH_ATTN_EXT(hs=128,nh=32,kv=2048,nb=2): [FLASH_ATTN_EXT] NMSE = 0.544283934 > 0.000500000 FAIL
```

ggerganov avatar Mar 24 '24 10:03 ggerganov

Obsolete now that https://github.com/ggerganov/llama.cpp/pull/5021 has been merged.

JohannesGaessler avatar Apr 30 '24 10:04 JohannesGaessler