llama.cpp
CUDA: optimize FA for GQA + large batches
This PR adds the following optimizations to the CUDA FlashAttention code:
- For models with group-query attention (GQA), re-use the loaded K/V data across multiple attention heads (see the sketch after this list). This also reduces the number of tokens that each CUDA block works on by a factor of up to 8 (limited by the ratio of Q to K/V heads), so there is less wasted compute due to padded tokens (up to 63 on master). For GQA models the mma-based kernels now seem to be faster than the generic vector kernels for batch size 1, so they are used if possible. The compilation time for a non-native CUDA build does not increase if enough threads are available: there is no increase at 32 threads, and with 16 threads the compilation time increases by 6%. The size of `ggml-cuda.so` increases from 465 MiB to 499 MiB.
- Asynchronously load the KQ mask if possible. Re-use the data for multiple attention heads when using GQA.
- Assign more Q columns to each warp. This increases arithmetic intensity and also makes it possible to use a column-major layout for KQ and KQV, which needs less data movement. The implementation supports up to 32 columns per warp, but due to register pressure this is only 4% faster than 16 columns (kernel throughput, not end-to-end throughput), and it would require compiling an extra kernel version for a total of 128 columns. I'm keeping the 32-column implementation even though it is currently unused because I think it will be worthwhile for FP16 and BF16 precision (right now FP32 precision is always used).
- Reduce stream-k overhead by reducing the granularity with which KQ columns are assigned and by using more threads.
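As a rough illustration of the first point, here is a minimal sketch of the K/V re-use idea. This is not the actual llama.cpp kernel: the names `gqa_kq_tile`, `D` and `KQ_TILE` are made up for this example, and it only computes plain KQ dot products, not the full FlashAttention pipeline. The point is that one CUDA block loads a K tile into shared memory once and then re-uses it for all Q heads that share the same K/V head under GQA:

```cpp
#include <cuda_fp16.h>

#define D       128 // head size (assumption for this sketch)
#define KQ_TILE  64 // K tokens handled per block (assumption for this sketch)

// q:   [n_head_q,  n_tokens_q, D]
// k:   [n_head_kv, n_tokens_k, D], n_tokens_k assumed to be a multiple of KQ_TILE
// dst: [n_head_q,  n_tokens_q, n_tokens_k]   (KQ scores before softmax)
// grid: (n_tokens_k/KQ_TILE, n_head_kv), i.e. one block per K/V head and K tile
__global__ void gqa_kq_tile(const half * q, const half * k, float * dst,
                            const int n_tokens_q, const int n_tokens_k, const int gqa_ratio) {
    __shared__ half k_tile[KQ_TILE][D];

    const int head_kv = blockIdx.y;           // K/V head handled by this block
    const int i0      = blockIdx.x * KQ_TILE; // first K token of this tile

    // Load the K tile into shared memory exactly once per block:
    for (int idx = threadIdx.x; idx < KQ_TILE*D; idx += blockDim.x) {
        k_tile[idx / D][idx % D] = k[(head_kv*n_tokens_k + i0)*D + idx];
    }
    __syncthreads();

    // Re-use the same tile for every Q head in this GQA group:
    for (int g = 0; g < gqa_ratio; ++g) {
        const int head_q = head_kv*gqa_ratio + g;

        for (int j = threadIdx.x; j < n_tokens_q; j += blockDim.x) {
            for (int i = 0; i < KQ_TILE; ++i) {
                float sum = 0.0f;
                for (int c = 0; c < D; ++c) {
                    sum += __half2float(q[(head_q*n_tokens_q + j)*D + c]) * __half2float(k_tile[i][c]);
                }
                dst[(head_q*n_tokens_q + j)*n_tokens_k + i0 + i] = sum;
            }
        }
    }
}
```

Without the re-use, each Q head would get its own blocks and therefore its own copy of the K/V loads; sharing the tile amortizes that global memory traffic across the whole GQA group.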
Performance changes
GPU | Model | Microbatch size | Test | t/s master | t/s PR | Speedup |
---|---|---|---|---|---|---|
RTX 3090 | gemma 2B F16 | 1 | pp16384 | 151.98 | 161.85 | 1.06 |
RTX 3090 | gemma 2B F16 | 2 | pp16384 | 248.87 | 259.87 | 1.04 |
RTX 3090 | gemma 2B F16 | 3 | pp16384 | 369.76 | 383.88 | 1.04 |
RTX 3090 | gemma 2B F16 | 4 | pp16384 | 498.01 | 521.12 | 1.05 |
RTX 3090 | gemma 2B F16 | 6 | pp16384 | 727.41 | 736.19 | 1.01 |
RTX 3090 | gemma 2B F16 | 8 | pp16384 | 951.74 | 972.83 | 1.02 |
RTX 3090 | gemma 2B F16 | 12 | pp16384 | 1349.70 | 1399.32 | 1.04 |
RTX 3090 | gemma 2B F16 | 16 | pp16384 | 1740.43 | 1859.05 | 1.07 |
RTX 3090 | gemma 2B F16 | 24 | pp16384 | 2318.17 | 2700.83 | 1.17 |
RTX 3090 | gemma 2B F16 | 32 | pp16384 | 2836.44 | 3411.02 | 1.20 |
RTX 3090 | gemma 2B F16 | 48 | pp16384 | 3258.66 | 4758.89 | 1.46 |
RTX 3090 | gemma 2B F16 | 64 | pp16384 | 3779.50 | 6000.23 | 1.59 |
RTX 3090 | gemma 2B F16 | 96 | pp16384 | 5828.39 | 7501.01 | 1.29 |
RTX 3090 | gemma 2B F16 | 128 | pp16384 | 7515.51 | 9210.69 | 1.23 |
RTX 3090 | gemma 2B F16 | 256 | pp16384 | 9844.77 | 10642.87 | 1.08 |
RTX 3090 | gemma 2B F16 | 512 | pp16384 | 10951.22 | 11514.83 | 1.05 |
RTX 3090 | gemma 2B F16 | 1024 | pp16384 | 11572.58 | 11844.54 | 1.02 |
RTX 3090 | gemma 2B F16 | 2048 | pp16384 | 11724.84 | 12087.77 | 1.03 |
RTX 3090 | gemma 2B F16 | 4096 | pp16384 | 11730.26 | 12089.48 | 1.03 |
RTX 3090 | gemma 2B F16 | 8192 | pp16384 | 11732.49 | 12104.29 | 1.03 |
RTX 3090 | gemma 2B F16 | 16384 | pp16384 | 11731.33 | 12123.26 | 1.03 |
RTX 3090 | llama 8B Q4_0 | 1 | pp16384 | 108.63 | 125.22 | 1.15 |
RTX 3090 | llama 8B Q4_0 | 2 | pp16384 | 158.55 | 223.90 | 1.41 |
RTX 3090 | llama 8B Q4_0 | 3 | pp16384 | 231.06 | 319.86 | 1.38 |
RTX 3090 | llama 8B Q4_0 | 4 | pp16384 | 287.54 | 388.13 | 1.35 |
RTX 3090 | llama 8B Q4_0 | 6 | pp16384 | 373.77 | 474.34 | 1.27 |
RTX 3090 | llama 8B Q4_0 | 8 | pp16384 | 429.04 | 523.66 | 1.22 |
RTX 3090 | llama 8B Q4_0 | 12 | pp16384 | 643.30 | 793.60 | 1.23 |
RTX 3090 | llama 8B Q4_0 | 16 | pp16384 | 848.68 | 1047.13 | 1.23 |
RTX 3090 | llama 8B Q4_0 | 24 | pp16384 | 1165.66 | 1357.03 | 1.16 |
RTX 3090 | llama 8B Q4_0 | 32 | pp16384 | 1448.24 | 1666.54 | 1.15 |
RTX 3090 | llama 8B Q4_0 | 48 | pp16384 | 1903.56 | 2173.54 | 1.14 |
RTX 3090 | llama 8B Q4_0 | 64 | pp16384 | 2229.33 | 2471.21 | 1.11 |
RTX 3090 | llama 8B Q4_0 | 96 | pp16384 | 2390.78 | 2751.08 | 1.15 |
RTX 3090 | llama 8B Q4_0 | 128 | pp16384 | 2830.15 | 3007.53 | 1.06 |
RTX 3090 | llama 8B Q4_0 | 256 | pp16384 | 3544.97 | 3648.56 | 1.03 |
RTX 3090 | llama 8B Q4_0 | 512 | pp16384 | 3668.09 | 3796.45 | 1.03 |
RTX 3090 | llama 8B Q4_0 | 1024 | pp16384 | 3805.70 | 3978.51 | 1.05 |
RTX 3090 | llama 8B Q4_0 | 2048 | pp16384 | 3787.72 | 3968.53 | 1.05 |
RTX 3090 | llama 8B Q4_0 | 4096 | pp16384 | 3772.90 | 3974.59 | 1.05 |
RTX 3090 | llama 8B Q4_0 | 8192 | pp16384 | 3795.54 | 3974.16 | 1.05 |
RTX 3090 | llama 8B Q4_0 | 16384 | pp16384 | 3796.23 | 3973.22 | 1.05 |
RTX 3090 | phi2 3B F16 | 1 | pp16384 | 82.97 | 81.13 | 0.98 |
RTX 3090 | phi2 3B F16 | 2 | pp16384 | 144.66 | 141.37 | 0.98 |
RTX 3090 | phi2 3B F16 | 3 | pp16384 | 216.39 | 211.16 | 0.98 |
RTX 3090 | phi2 3B F16 | 4 | pp16384 | 284.68 | 280.74 | 0.99 |
RTX 3090 | phi2 3B F16 | 6 | pp16384 | 420.45 | 417.67 | 0.99 |
RTX 3090 | phi2 3B F16 | 8 | pp16384 | 556.61 | 554.69 | 1.00 |
RTX 3090 | phi2 3B F16 | 12 | pp16384 | 802.97 | 811.75 | 1.01 |
RTX 3090 | phi2 3B F16 | 16 | pp16384 | 1065.07 | 1073.04 | 1.01 |
RTX 3090 | phi2 3B F16 | 24 | pp16384 | 1505.43 | 1559.60 | 1.04 |
RTX 3090 | phi2 3B F16 | 32 | pp16384 | 1893.57 | 1998.59 | 1.06 |
RTX 3090 | phi2 3B F16 | 48 | pp16384 | 2462.76 | 2734.08 | 1.11 |
RTX 3090 | phi2 3B F16 | 64 | pp16384 | 3154.33 | 3623.37 | 1.15 |
RTX 3090 | phi2 3B F16 | 96 | pp16384 | 3539.53 | 3820.18 | 1.08 |
RTX 3090 | phi2 3B F16 | 128 | pp16384 | 4573.37 | 4917.37 | 1.08 |
RTX 3090 | phi2 3B F16 | 256 | pp16384 | 5981.49 | 6405.33 | 1.07 |
RTX 3090 | phi2 3B F16 | 512 | pp16384 | 6438.91 | 6622.56 | 1.03 |
RTX 3090 | phi2 3B F16 | 1024 | pp16384 | 6494.09 | 6894.79 | 1.06 |
RTX 3090 | phi2 3B F16 | 2048 | pp16384 | 6609.94 | 7041.81 | 1.07 |
RTX 3090 | phi2 3B F16 | 4096 | pp16384 | 6674.94 | 7038.99 | 1.05 |
RTX 3090 | phi2 3B F16 | 8192 | pp16384 | 6674.89 | 7046.58 | 1.06 |
RTX 3090 | phi2 3B F16 | 16384 | pp16384 | 6676.13 | 7050.12 | 1.06 |
RTX 4090 | gemma 2B F16 | 1 | pp16384 | 188.71 | 192.21 | 1.02 |
RTX 4090 | gemma 2B F16 | 2 | pp16384 | 315.29 | 319.45 | 1.01 |
RTX 4090 | gemma 2B F16 | 3 | pp16384 | 467.69 | 472.12 | 1.01 |
RTX 4090 | gemma 2B F16 | 4 | pp16384 | 636.10 | 646.32 | 1.02 |
RTX 4090 | gemma 2B F16 | 6 | pp16384 | 924.58 | 919.83 | 0.99 |
RTX 4090 | gemma 2B F16 | 8 | pp16384 | 1208.77 | 1220.84 | 1.01 |
RTX 4090 | gemma 2B F16 | 12 | pp16384 | 1704.04 | 1809.66 | 1.06 |
RTX 4090 | gemma 2B F16 | 16 | pp16384 | 2174.59 | 2382.26 | 1.10 |
RTX 4090 | gemma 2B F16 | 24 | pp16384 | 2951.43 | 3506.05 | 1.19 |
RTX 4090 | gemma 2B F16 | 32 | pp16384 | 3679.03 | 4612.43 | 1.25 |
RTX 4090 | gemma 2B F16 | 48 | pp16384 | 4425.07 | 6216.53 | 1.40 |
RTX 4090 | gemma 2B F16 | 64 | pp16384 | 5176.16 | 8052.42 | 1.56 |
RTX 4090 | gemma 2B F16 | 96 | pp16384 | 8523.83 | 10718.71 | 1.26 |
RTX 4090 | gemma 2B F16 | 128 | pp16384 | 11237.08 | 13797.04 | 1.23 |
RTX 4090 | gemma 2B F16 | 256 | pp16384 | 19068.33 | 20758.07 | 1.09 |
RTX 4090 | gemma 2B F16 | 512 | pp16384 | 25130.53 | 26197.02 | 1.04 |
RTX 4090 | gemma 2B F16 | 1024 | pp16384 | 24602.15 | 25386.13 | 1.03 |
RTX 4090 | gemma 2B F16 | 2048 | pp16384 | 24860.08 | 25424.70 | 1.02 |
RTX 4090 | gemma 2B F16 | 4096 | pp16384 | 24780.69 | 25453.42 | 1.03 |
RTX 4090 | gemma 2B F16 | 8192 | pp16384 | 24787.61 | 25393.45 | 1.02 |
RTX 4090 | gemma 2B F16 | 16384 | pp16384 | 24823.15 | 25419.06 | 1.02 |
RTX 4090 | llama 8B Q4_0 | 1 | pp16384 | 141.40 | 146.86 | 1.04 |
RTX 4090 | llama 8B Q4_0 | 2 | pp16384 | 266.21 | 272.68 | 1.02 |
RTX 4090 | llama 8B Q4_0 | 3 | pp16384 | 396.38 | 402.35 | 1.02 |
RTX 4090 | llama 8B Q4_0 | 4 | pp16384 | 523.18 | 531.34 | 1.02 |
RTX 4090 | llama 8B Q4_0 | 6 | pp16384 | 738.60 | 733.59 | 0.99 |
RTX 4090 | llama 8B Q4_0 | 8 | pp16384 | 910.13 | 910.43 | 1.00 |
RTX 4090 | llama 8B Q4_0 | 12 | pp16384 | 1166.29 | 1156.32 | 0.99 |
RTX 4090 | llama 8B Q4_0 | 16 | pp16384 | 1529.82 | 1531.36 | 1.00 |
RTX 4090 | llama 8B Q4_0 | 24 | pp16384 | 2151.15 | 2215.12 | 1.03 |
RTX 4090 | llama 8B Q4_0 | 32 | pp16384 | 2682.64 | 2803.91 | 1.05 |
RTX 4090 | llama 8B Q4_0 | 48 | pp16384 | 3593.15 | 3906.87 | 1.09 |
RTX 4090 | llama 8B Q4_0 | 64 | pp16384 | 4331.18 | 5025.63 | 1.16 |
RTX 4090 | llama 8B Q4_0 | 96 | pp16384 | 5336.88 | 5825.80 | 1.09 |
RTX 4090 | llama 8B Q4_0 | 128 | pp16384 | 6511.68 | 6979.52 | 1.07 |
RTX 4090 | llama 8B Q4_0 | 256 | pp16384 | 8419.66 | 8866.89 | 1.05 |
RTX 4090 | llama 8B Q4_0 | 512 | pp16384 | 9595.88 | 9909.55 | 1.03 |
RTX 4090 | llama 8B Q4_0 | 1024 | pp16384 | 9620.89 | 9995.12 | 1.04 |
RTX 4090 | llama 8B Q4_0 | 2048 | pp16384 | 9189.02 | 9598.75 | 1.04 |
RTX 4090 | llama 8B Q4_0 | 4096 | pp16384 | 9197.18 | 9607.84 | 1.04 |
RTX 4090 | llama 8B Q4_0 | 8192 | pp16384 | 9190.60 | 9600.52 | 1.04 |
RTX 4090 | llama 8B Q4_0 | 16384 | pp16384 | 9201.13 | 9600.01 | 1.04 |
RTX 4090 | phi2 3B F16 | 1 | pp16384 | 103.75 | 103.12 | 0.99 |
RTX 4090 | phi2 3B F16 | 2 | pp16384 | 179.58 | 179.09 | 1.00 |
RTX 4090 | phi2 3B F16 | 3 | pp16384 | 268.44 | 267.97 | 1.00 |
RTX 4090 | phi2 3B F16 | 4 | pp16384 | 356.08 | 356.40 | 1.00 |
RTX 4090 | phi2 3B F16 | 6 | pp16384 | 527.54 | 531.12 | 1.01 |
RTX 4090 | phi2 3B F16 | 8 | pp16384 | 699.58 | 705.18 | 1.01 |
RTX 4090 | phi2 3B F16 | 12 | pp16384 | 1010.70 | 1039.98 | 1.03 |
RTX 4090 | phi2 3B F16 | 16 | pp16384 | 1314.95 | 1379.61 | 1.05 |
RTX 4090 | phi2 3B F16 | 24 | pp16384 | 1947.20 | 2028.79 | 1.04 |
RTX 4090 | phi2 3B F16 | 32 | pp16384 | 2529.39 | 2679.80 | 1.06 |
RTX 4090 | phi2 3B F16 | 48 | pp16384 | 3492.89 | 3817.46 | 1.09 |
RTX 4090 | phi2 3B F16 | 64 | pp16384 | 4354.42 | 4947.95 | 1.14 |
RTX 4090 | phi2 3B F16 | 96 | pp16384 | 6032.40 | 6465.26 | 1.07 |
RTX 4090 | phi2 3B F16 | 128 | pp16384 | 7906.06 | 8408.58 | 1.06 |
RTX 4090 | phi2 3B F16 | 256 | pp16384 | 11747.52 | 12016.59 | 1.02 |
RTX 4090 | phi2 3B F16 | 512 | pp16384 | 14291.87 | 14384.89 | 1.01 |
RTX 4090 | phi2 3B F16 | 1024 | pp16384 | 15898.14 | 16127.53 | 1.01 |
RTX 4090 | phi2 3B F16 | 2048 | pp16384 | 14419.83 | 14688.35 | 1.02 |
RTX 4090 | phi2 3B F16 | 4096 | pp16384 | 14447.28 | 14664.04 | 1.02 |
RTX 4090 | phi2 3B F16 | 8192 | pp16384 | 14423.79 | 14669.22 | 1.02 |
RTX 4090 | phi2 3B F16 | 16384 | pp16384 | 14421.78 | 14721.65 | 1.02 |
For context, LLaMA 3 uses GQA with 4 Q heads per K/V head, Gemma uses GQA with 8 Q heads per K/V head, and Phi 2 does not use GQA at all.
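A back-of-the-envelope sketch of what these ratios mean for the kernel. The head counts are the standard configurations of these models; the 64-column tile width is only inferred from the "up to 63 padded tokens on master" figure above and is an assumption for this example:

```cpp
#include <cstdio>

int main() {
    const struct { const char * name; int n_head_q; int n_head_kv; } models[] = {
        {"llama 8B", 32,  8}, // LLaMA 3: 4 Q heads per K/V head
        {"gemma 2B",  8,  1}, // Gemma:   8 Q heads per K/V head
        {"phi2 3B",  32, 32}, // Phi 2:   no GQA, 1 Q head per K/V head
    };
    const int cols_per_block = 64; // assumed tile width (63 padded tokens = 64 - 1)

    for (const auto & m : models) {
        const int gqa_ratio = m.n_head_q / m.n_head_kv;
        // If the Q columns of one tile can come from gqa_ratio different heads,
        // each head only needs to contribute cols_per_block/gqa_ratio tokens,
        // so at small batch sizes far fewer of them are padding:
        const int cols_per_head = cols_per_block / gqa_ratio;
        printf("%-8s  gqa_ratio=%2d  ->  %2d Q columns per head per block, "
               "up to %2d of them padding at batch size 1\n",
               m.name, gqa_ratio, cols_per_head, cols_per_head - 1);
    }
    return 0;
}
```

This is consistent with the table above: the largest speedups show up for gemma (ratio 8), smaller ones for llama (ratio 4), and little to no change at small batch sizes for phi2 (no GQA).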