llama.cpp
Fused attention kernel for small batch sizes
This PR adds a fused attention kernel with almost the same interface as the FlashAttention kernel. It does not use tensor cores or any of the tricks in the FlashAttention paper, it simply calculates attention output for a single column of the output matrix as a fused operation. The performance on my RTX 3090 is as follows:
model | test | n_batch | t/s jg/flash-attn | t/s fused_attn_ext_f16 | Speedup |
---|---|---|---|---|---|
llama 7B Q4_0 | pp 4096 | 1 | 105.98 | 118.68 | 1.12 |
llama 7B Q4_0 | pp 4096 | 2 | 206.89 | 229.49 | 1.11 |
llama 7B Q4_0 | pp 4096 | 3 | 295.08 | 282.98 | 0.96 |
llama 7B Q4_0 | pp 4096 | 4 | 355.27 | 341.78 | 0.96 |
llama 7B Q4_0 | pp 4096 | 5 | 416.21 | 401.74 | 0.97 |
llama 7B Q4_0 | pp 4096 | 6 | 443.87 | 428.70 | 0.97 |
llama 7B Q4_0 | pp 4096 | 7 | 472.82 | 457.57 | 0.97 |
llama 7B Q4_0 | pp 4096 | 8 | 495.92 | 480.93 | 0.97 |
If you look at the runtime of only the attention, NVIDIA Nsight Systems reports a 1.83x speedup for a batch size of 1.
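To make the shape of the kernel concrete, here is a minimal sketch of the kind of single-column fused attention described above. This is not the code from this PR: the tensor layouts, names, and the deliberately naive per-thread softmax are illustrative assumptions, and masking/ALiBi/GQA are left out.

```cpp
#include <cuda_fp16.h>
#include <math.h>

// One block computes the attention output for a single (batch, head) query column.
// The whole KQ column lives in dynamic shared memory, so the maximum context
// length is bounded by the shared memory available per block.
// Assumed launch: <<< dim3(n_heads, n_batch), n_threads, n_kv*sizeof(half) >>>
__global__ void attn_single_column(
        const half * __restrict__ Q,    // [n_batch, n_heads, head_dim]
        const half * __restrict__ K,    // [n_heads, n_kv, head_dim]
        const half * __restrict__ V,    // [n_heads, n_kv, head_dim]
        float      * __restrict__ dst,  // [n_batch, n_heads, head_dim]
        const int n_kv, const int head_dim, const float scale) {
    extern __shared__ half KQ[];        // n_kv elements, 2 bytes each

    const int head  = blockIdx.x;
    const int batch = blockIdx.y;

    const half * q = Q + (batch*gridDim.x + head)*head_dim;
    const half * k = K + head*n_kv*head_dim;
    const half * v = V + head*n_kv*head_dim;

    // 1. KQ[j] = scale * dot(q, k_j), one KV position per thread (strided).
    for (int j = threadIdx.x; j < n_kv; j += blockDim.x) {
        float sum = 0.0f;
        for (int d = 0; d < head_dim; ++d) {
            sum += __half2float(q[d]) * __half2float(k[j*head_dim + d]);
        }
        KQ[j] = __float2half(scale*sum);
    }
    __syncthreads();

    // 2. Softmax statistics over the KQ column (computed redundantly per thread
    //    for brevity; a real kernel would use a parallel reduction here).
    float kq_max = -INFINITY;
    for (int j = 0; j < n_kv; ++j) {
        kq_max = fmaxf(kq_max, __half2float(KQ[j]));
    }
    float kq_sum = 0.0f;
    for (int j = 0; j < n_kv; ++j) {
        kq_sum += expf(__half2float(KQ[j]) - kq_max);
    }

    // 3. dst = softmax(KQ) @ V, one output element per thread (strided).
    for (int d = threadIdx.x; d < head_dim; d += blockDim.x) {
        float acc = 0.0f;
        for (int j = 0; j < n_kv; ++j) {
            acc += expf(__half2float(KQ[j]) - kq_max)/kq_sum * __half2float(v[j*head_dim + d]);
        }
        dst[(batch*gridDim.x + head)*head_dim + d] = acc;
    }
}
```

Each block only ever touches one query column, which is what keeps the kernel simple but also what ties the block count to the number of heads times the batch size, as discussed next.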
For the fused attention/FlashAttention kernels you need to do a softmax over the KQ columns, so the number of CUDA blocks is limited by the number of attention heads x the batch size. The problem is that this leaves a significant portion of a typical GPU idle at small batch sizes. This can be mitigated to some degree by using as many threads per block as possible, but unfortunately tensor cores need a lot of registers and therefore limit the number of threads you can use. So for small batch sizes I don't think they make sense.
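For reference, a launch configuration for a one-column-per-block kernel like the sketch above ends up looking roughly like this (illustrative values, not the PR's code):

```cpp
// Grid size is fixed at n_heads * n_batch blocks, e.g. 32 blocks for llama 7B
// at batch size 1 -- far fewer than the 82 SMs of an RTX 3090 could keep busy.
const dim3 grid(n_heads, n_batch, 1);
// Many threads per block recover some of the lost parallelism, which is easier
// without tensor cores since their register pressure caps the thread count.
const dim3 block(256, 1, 1);
const size_t shmem = n_kv*sizeof(half);  // the entire KQ column stays in SRAM
attn_single_column<<<grid, block, shmem>>>(Q, K, V, dst, n_kv, head_dim, scale);
```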
The FlashAttention paper has some tricks which allow you to calculate tiles of KQ in SRAM without ever having to write KQ to VRAM. However, at a batch size of 1 you only need at most `2*context_size` bytes of SRAM to store the entire KQ matrix. So even this very simple kernel works on Turing for up to 31232 context, and for up to 49664 context on Ampere. Of course, if the FlashAttention paper were followed you should be able to run the kernel at larger batch sizes, but I think this would make the performance worse.
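As a rough sanity check of those context limits (assuming the KQ column is stored as FP16, i.e. 2 bytes per element): 31232 * 2 = 62464 bytes sits just under the 64 KiB of shared memory a block can use on Turing, and 49664 * 2 = 99328 bytes just under the ~99 KiB opt-in limit on consumer Ampere, with a small remainder presumably reserved for other per-block data. Anything beyond the default 48 KiB of dynamic shared memory also has to be requested explicitly; for the hypothetical sketch above that would look something like:

```cpp
// Hypothetical host-side setup for the sketch above: dynamic shared memory
// beyond the default 48 KiB must be requested explicitly before the launch.
const size_t shmem = (size_t) n_kv*sizeof(half);  // whole KQ column in FP16
cudaDeviceProp prop;
cudaGetDeviceProperties(&prop, 0);
if (shmem > prop.sharedMemPerBlockOptin) {
    // context too long for this GPU's shared memory -- fall back to another kernel
}
cudaFuncSetAttribute((const void *) attn_single_column,
        cudaFuncAttributeMaxDynamicSharedMemorySize, (int) shmem);
```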
Overall the kernel reaches ~70% of the theoretical maximum memory bandwidth on my RTX 3090. The best I have ever been able to achieve is something like 90%, so realistically I don't think there is much room left for improvement over this kernel (assuming you have sufficient SRAM to store the entire KQ matrix). Still, I'll try to optimize the pre-existing FlashAttention kernel since it's only reaching ~10% tensor core utilization and can presumably still be improved.
The result from this kernel is incorrect:
```
LLAMA_CUBLAS=1 make -j tests && ./tests/test-backend-ops -o FLASH_ATTN_EXT -b CUDA0

ggml_init_cublas: GGML_CUDA_FORCE_MMQ: no
ggml_init_cublas: CUDA_USE_TENSOR_CORES: yes
ggml_init_cublas: found 1 CUDA devices:
  Device 0: Tesla V100-PCIE-16GB, compute capability 7.0, VMM: yes
Backend name: CUDA0
FLASH_ATTN_EXT(hs=128,nh=32,kv=512,nb=1): [FLASH_ATTN_EXT] NMSE = 0.518331457 > 0.000500000 FAIL
FLASH_ATTN_EXT(hs=128,nh=32,kv=512,nb=2): [FLASH_ATTN_EXT] NMSE = 0.565371118 > 0.000500000 FAIL
FLASH_ATTN_EXT(hs=128,nh=32,kv=512,nb=4): OK
FLASH_ATTN_EXT(hs=128,nh=32,kv=512,nb=8): OK
FLASH_ATTN_EXT(hs=128,nh=32,kv=512,nb=512): OK
FLASH_ATTN_EXT(hs=128,nh=32,kv=512,nb=1024): OK
FLASH_ATTN_EXT(hs=128,nh=32,kv=512,nb=2048): OK
FLASH_ATTN_EXT(hs=128,nh=32,kv=1024,nb=1): [FLASH_ATTN_EXT] NMSE = 0.523882472 > 0.000500000 FAIL
FLASH_ATTN_EXT(hs=128,nh=32,kv=1024,nb=2): [FLASH_ATTN_EXT] NMSE = 0.539142266 > 0.000500000 FAIL
FLASH_ATTN_EXT(hs=128,nh=32,kv=1024,nb=4): OK
FLASH_ATTN_EXT(hs=128,nh=32,kv=1024,nb=8): OK
FLASH_ATTN_EXT(hs=128,nh=32,kv=1024,nb=512): OK
FLASH_ATTN_EXT(hs=128,nh=32,kv=1024,nb=1024): OK
FLASH_ATTN_EXT(hs=128,nh=32,kv=1024,nb=2048): OK
FLASH_ATTN_EXT(hs=128,nh=32,kv=2048,nb=1): [FLASH_ATTN_EXT] NMSE = 0.553847926 > 0.000500000 FAIL
FLASH_ATTN_EXT(hs=128,nh=32,kv=2048,nb=2): [FLASH_ATTN_EXT] NMSE = 0.544283934 > 0.000500000 FAIL
```
Obsolete now that https://github.com/ggerganov/llama.cpp/pull/5021 has been merged.