llama.cpp
FlashAttention: pragma unroll, use_mask template parameter
This PR adds two very simple optimizations for FlashAttention:
- Add `#pragma unroll` wherever possible so that the compiler unrolls more loops.
- Add a template parameter `use_mask` to eliminate some runtime checks.
Together these changes provide a ~2% speedup for small batch sizes on my system.
Quite honestly, I don't think the current code can be made to work well for small batch sizes (< 16), so whether or not this gets merged is of relatively little consequence. The tensor core fragments consume a lot of registers, but for small batch sizes there is no real benefit, so I think a specialized implementation without tensor cores will perform better for those batch sizes.
Here are results on V100 using:

```sh
# baseline
LLAMA_CUBLAS=1 make -j tests && ./tests/test-backend-ops -o ATTN -b CUDA0 perf
# flash attn
LLAMA_CUBLAS=1 make -j tests && ./tests/test-backend-ops -o FLASH_ATTN_EXT -b CUDA0 perf
```
Note that this measures just the attention operations, not end-to-end throughput.
hs | heads | n_kv | n_batch | attn (µs/run) | gg/flash_attn (µs/run) | jg/flash_attn (µs/run) | speedup (attn/gg) | speedup (gg/jg) | speedup (attn/jg) |
---|---|---|---|---|---|---|---|---|---|
128 | 32 | 512 | 1 | 45.56 | 54.45 | 50.54 | 0.84 | 1.08 | 0.90 |
128 | 32 | 512 | 2 | 47.33 | 55.06 | 49.91 | 0.86 | 1.10 | 0.95 |
128 | 32 | 512 | 4 | 56.37 | 55.95 | 47.68 | 1.01 | 1.17 | 1.18 |
128 | 32 | 512 | 8 | 72.85 | 58.97 | 49.99 | 1.24 | 1.18 | 1.46 |
128 | 32 | 512 | 512 | 501.63 | 385.66 | 367.70 | 1.30 | 1.05 | 1.36 |
128 | 32 | 512 | 1024 | 971.40 | 718.00 | 690.30 | 1.35 | 1.04 | 1.41 |
128 | 32 | 512 | 2048 | 1918.88 | 1246.56 | 1213.68 | 1.54 | 1.03 | 1.58 |
128 | 32 | 1024 | 1 | 62.33 | 78.44 | 76.33 | 0.79 | 1.03 | 0.82 |
128 | 32 | 1024 | 2 | 64.11 | 78.89 | 76.72 | 0.81 | 1.03 | 0.84 |
128 | 32 | 1024 | 4 | 77.55 | 79.61 | 77.47 | 0.97 | 1.03 | 1.00 |
128 | 32 | 1024 | 8 | 108.78 | 81.15 | 79.11 | 1.34 | 1.03 | 1.38 |
128 | 32 | 1024 | 512 | 845.88 | 706.42 | 673.63 | 1.20 | 1.05 | 1.26 |
128 | 32 | 1024 | 1024 | 1682.18 | 1310.36 | 1247.34 | 1.28 | 1.05 | 1.35 |
128 | 32 | 1024 | 2048 | 3306.77 | 2206.84 | 2192.80 | 1.50 | 1.01 | 1.51 |
128 | 32 | 2048 | 1 | 96.69 | 131.55 | 126.53 | 0.74 | 1.04 | 0.76 |
128 | 32 | 2048 | 2 | 100.30 | 131.64 | 126.98 | 0.76 | 1.04 | 0.79 |
128 | 32 | 2048 | 4 | 125.38 | 132.14 | 127.47 | 0.95 | 1.04 | 0.98 |
128 | 32 | 2048 | 8 | 179.87 | 134.20 | 129.16 | 1.34 | 1.04 | 1.39 |
128 | 32 | 2048 | 512 | 1395.07 | 1343.34 | 1249.35 | 1.04 | 1.08 | 1.12 |
128 | 32 | 2048 | 1024 | 2717.62 | 2475.43 | 2353.68 | 1.10 | 1.05 | 1.15 |
128 | 32 | 2048 | 2048 | 5349.03 | 4155.13 | 4157.71 | 1.29 | 1.00 | 1.29 |
128 | 32 | 4096 | 1 | 163.47 | 234.90 | 223.14 | 0.70 | 1.05 | 0.73 |
128 | 32 | 4096 | 2 | 169.38 | 235.21 | 223.53 | 0.72 | 1.05 | 0.76 |
128 | 32 | 4096 | 4 | 215.75 | 235.74 | 223.69 | 0.92 | 1.05 | 0.96 |
128 | 32 | 4096 | 8 | 319.53 | 237.44 | 225.08 | 1.35 | 1.05 | 1.42 |
128 | 32 | 4096 | 512 | 2575.27 | 2603.56 | 2386.44 | 0.99 | 1.09 | 1.08 |
128 | 32 | 4096 | 1024 | 5045.29 | 4806.45 | 4575.97 | 1.05 | 1.05 | 1.10 |
128 | 32 | 4096 | 2048 | 9931.37 | 8034.10 | 8106.12 | 1.24 | 0.99 | 1.23 |
64 | 32 | 512 | 1 | 39.07 | 32.72 | 30.75 | 1.19 | 1.06 | 1.27 |
64 | 32 | 512 | 2 | 34.32 | 32.91 | 30.96 | 1.04 | 1.06 | 1.11 |
64 | 32 | 512 | 4 | 40.67 | 33.34 | 31.43 | 1.22 | 1.06 | 1.29 |
64 | 32 | 512 | 8 | 55.57 | 34.33 | 32.41 | 1.62 | 1.06 | 1.71 |
64 | 32 | 512 | 512 | 411.37 | 248.66 | 232.71 | 1.65 | 1.07 | 1.77 |
64 | 32 | 512 | 1024 | 795.00 | 436.61 | 427.91 | 1.82 | 1.02 | 1.86 |
64 | 32 | 512 | 2048 | 1565.25 | 724.60 | 774.42 | 2.16 | 0.94 | 2.02 |
64 | 32 | 1024 | 1 | 53.37 | 53.02 | 50.96 | 1.01 | 1.04 | 1.05 |
64 | 32 | 1024 | 2 | 53.96 | 53.06 | 51.30 | 1.02 | 1.03 | 1.05 |
64 | 32 | 1024 | 4 | 66.24 | 53.68 | 51.72 | 1.23 | 1.04 | 1.28 |
64 | 32 | 1024 | 8 | 93.75 | 54.38 | 52.25 | 1.72 | 1.04 | 1.79 |
64 | 32 | 1024 | 512 | 740.54 | 455.52 | 431.63 | 1.63 | 1.06 | 1.72 |
64 | 32 | 1024 | 1024 | 1475.62 | 769.90 | 798.53 | 1.92 | 0.96 | 1.85 |
64 | 32 | 1024 | 2048 | 2887.93 | 1329.48 | 1437.34 | 2.17 | 0.92 | 2.01 |
64 | 32 | 2048 | 1 | 82.71 | 90.10 | 84.41 | 0.92 | 1.07 | 0.98 |
64 | 32 | 2048 | 2 | 86.41 | 90.37 | 84.73 | 0.96 | 1.07 | 1.02 |
64 | 32 | 2048 | 4 | 110.69 | 90.96 | 85.21 | 1.22 | 1.07 | 1.30 |
64 | 32 | 2048 | 8 | 159.36 | 91.73 | 85.94 | 1.74 | 1.07 | 1.85 |
64 | 32 | 2048 | 512 | 1272.46 | 868.85 | 814.30 | 1.46 | 1.07 | 1.56 |
64 | 32 | 2048 | 1024 | 2486.13 | 1526.42 | 1534.50 | 1.63 | 0.99 | 1.62 |
64 | 32 | 2048 | 2048 | 4864.27 | 2576.05 | 2765.03 | 1.89 | 0.93 | 1.76 |
64 | 32 | 4096 | 1 | 135.99 | 163.04 | 150.03 | 0.83 | 1.09 | 0.91 |
64 | 32 | 4096 | 2 | 144.66 | 163.31 | 150.41 | 0.89 | 1.09 | 0.96 |
64 | 32 | 4096 | 4 | 190.04 | 163.64 | 150.81 | 1.16 | 1.09 | 1.26 |
64 | 32 | 4096 | 8 | 285.76 | 164.50 | 151.75 | 1.74 | 1.08 | 1.88 |
64 | 32 | 4096 | 512 | 2372.08 | 1951.07 | 1574.40 | 1.22 | 1.24 | 1.51 |
64 | 32 | 4096 | 1024 | 4726.56 | 3044.91 | 3005.09 | 1.55 | 1.01 | 1.57 |
64 | 32 | 4096 | 2048 | 9197.90 | 5079.92 | 5395.38 | 1.81 | 0.94 | 1.70 |
80 | 32 | 512 | 1 | 43.19 | 26.85 | 22.80 | 1.61 | 1.18 | 1.89 |
80 | 32 | 512 | 2 | 35.04 | 27.09 | 23.30 | 1.29 | 1.16 | 1.50 |
80 | 32 | 512 | 4 | 43.95 | 27.55 | 23.78 | 1.60 | 1.16 | 1.85 |
80 | 32 | 512 | 8 | 62.47 | 29.35 | 25.30 | 2.13 | 1.16 | 2.47 |
80 | 32 | 512 | 512 | 440.26 | 281.31 | 250.12 | 1.57 | 1.12 | 1.76 |
80 | 32 | 512 | 1024 | 853.62 | 455.14 | 480.31 | 1.88 | 0.95 | 1.78 |
80 | 32 | 512 | 2048 | 1682.13 | 782.02 | 858.72 | 2.15 | 0.91 | 1.96 |
80 | 32 | 1024 | 1 | 57.76 | 42.67 | 36.95 | 1.35 | 1.15 | 1.56 |
80 | 32 | 1024 | 2 | 59.24 | 43.04 | 37.44 | 1.38 | 1.15 | 1.58 |
80 | 32 | 1024 | 4 | 72.08 | 43.57 | 38.00 | 1.65 | 1.15 | 1.90 |
80 | 32 | 1024 | 8 | 100.33 | 44.93 | 39.45 | 2.23 | 1.14 | 2.54 |
80 | 32 | 1024 | 512 | 777.70 | 500.01 | 454.53 | 1.56 | 1.10 | 1.71 |
80 | 32 | 1024 | 1024 | 1553.05 | 800.78 | 863.37 | 1.94 | 0.93 | 1.80 |
80 | 32 | 1024 | 2048 | 3049.70 | 1377.34 | 1545.30 | 2.21 | 0.89 | 1.97 |
80 | 32 | 2048 | 1 | 90.71 | 76.92 | 60.87 | 1.18 | 1.26 | 1.49 |
80 | 32 | 2048 | 2 | 94.60 | 77.33 | 61.04 | 1.22 | 1.27 | 1.55 |
80 | 32 | 2048 | 4 | 118.46 | 77.86 | 61.44 | 1.52 | 1.27 | 1.93 |
80 | 32 | 2048 | 8 | 169.96 | 79.81 | 62.87 | 2.13 | 1.27 | 2.70 |
80 | 32 | 2048 | 512 | 1317.95 | 940.61 | 857.18 | 1.40 | 1.10 | 1.54 |
80 | 32 | 2048 | 1024 | 2572.62 | 1500.10 | 1627.96 | 1.71 | 0.92 | 1.58 |
80 | 32 | 2048 | 2048 | 5053.02 | 2583.43 | 2922.96 | 1.96 | 0.88 | 1.73 |
80 | 32 | 4096 | 1 | 151.67 | 141.47 | 105.64 | 1.07 | 1.34 | 1.44 |
80 | 32 | 4096 | 2 | 160.51 | 141.61 | 105.51 | 1.13 | 1.34 | 1.52 |
80 | 32 | 4096 | 4 | 204.87 | 142.93 | 106.39 | 1.43 | 1.34 | 1.93 |
80 | 32 | 4096 | 8 | 304.14 | 144.33 | 107.52 | 2.11 | 1.34 | 2.83 |
80 | 32 | 4096 | 512 | 2436.21 | 1818.19 | 1646.07 | 1.34 | 1.10 | 1.48 |
80 | 32 | 4096 | 1024 | 4798.57 | 2884.22 | 3152.04 | 1.66 | 0.92 | 1.52 |
80 | 32 | 4096 | 2048 | 9520.45 | 4991.22 | 5630.93 | 1.91 | 0.89 | 1.69 |
The proposed PR here generally seems to help.
> So I think for those batch sizes a specialized implementation without tensor cores will perform better.
I was hoping to find a way to avoid this. The main reason is that the flash attention kernels can be extended to support quantized data (for a quantized KV cache), and having two versions of the kernel would double the amount of code and effort.
I have a prototype of a flash attention kernel for a batch size of 1. Due to the change in memory layout it will not be possible to use int8 intrinsics/tensor cores for the V portion of the KV cache, so essentially the only way to do it will be to use FP16 intrinsics/tensor cores. The overhead per quantization format will then be a conversion to FP16, which we essentially already have and which can just be provided via a template parameter.
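A rough sketch of that idea (hypothetical names, loosely modeled on the per-format dequantize callbacks llama.cpp already uses elsewhere; not the actual prototype):

```cuda
#include <cuda_fp16.h>

// Hypothetical conversion callback type: one instance per quantization
// format, each loading a pair of values from the KV cache as FP16.
typedef void (*to_fp16_t)(const void * x, int i, half2 & v);

// Trivial "conversion" for an unquantized FP16 cache.
static __device__ void fp16_to_fp16(const void * x, int i, half2 & v) {
    v = ((const half2 *) x)[i];
}

// Sketch of a batch-size-1 attention dot product: K data is always loaded
// through the conversion callback, so the arithmetic is always FP16
// regardless of the storage format.
template <int head_size, to_fp16_t to_fp16>
__global__ void attn_dot(const half2 * Q, const void * K, float * S, const int n_kv) {
    const int col = blockIdx.x*blockDim.x + threadIdx.x; // one KV entry per thread
    if (col >= n_kv) {
        return;
    }

    float sum = 0.0f;
#pragma unroll
    for (int i = 0; i < head_size/2; ++i) {
        half2 k;
        to_fp16(K, col*(head_size/2) + i, k); // format-specific load to FP16
        const half2 p = __hmul2(Q[i], k);
        sum += __low2float(p) + __high2float(p);
    }
    S[col] = sum;
}
```

Adding a new quantization format would then only require a new `to_fp16_t` instance; the kernel body itself stays shared across all formats.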
Obsolete now that https://github.com/ggerganov/llama.cpp/pull/5021 has been merged.