
CUDA: optimize FA for GQA + large batches


This PR adds the following optimizations to the CUDA FlashAttention code:

  • For models with grouped-query attention (GQA), re-use the loaded K/V data across multiple attention heads; a minimal scheduling sketch of this idea follows after this list. This also reduces the number of tokens that each CUDA block works on by a factor of up to 8 (limited by the ratio of Q heads to K/V heads), so there is less wasted compute on padded tokens (up to 63 on master). For GQA models the mma-based kernels now seem to be faster than the generic vector kernels even at batch size 1, so they are used whenever possible. The compilation time of a non-native CUDA build does not increase if enough threads are available: there is no increase with 32 threads, and with 16 threads it increases by 6%. The size of ggml-cuda.so grows from 465 MiB to 499 MiB.
  • Asynchronously load the KQ mask where possible, and re-use the loaded mask data across multiple attention heads when using GQA.
  • Assign more Q columns to each warp. This increases arithmetic intensity and also makes it possible to use a column-major layout for KQ and KQV, which needs less data movement. The implementation supports up to 32 columns per warp, but due to register pressure this is only 4% faster than 16 columns (kernel throughput, not end-to-end throughput), and enabling it would require compiling an extra kernel version for a total of 128 columns. I'm keeping the 32-column implementation even though it is currently unused because I think it will be worthwhile for FP16 and BF16 precision (right now FP32 precision is always used).
  • Reduce stream-k overhead by reducing the granularity with which KQ columns are assigned and by using more threads.
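
The scheduling idea behind the first bullet (re-using a loaded K/V tile for all Q heads in the same GQA group) can be sketched with a small standalone kernel. This is only an illustration under assumed layouts, not the PR's actual code: the real mma-based FlashAttention kernels also handle the KQ mask, online softmax, V accumulation and tensor core fragment layouts. All names here (`kq_scores_gqa`, `TILE_KV`, `HEAD_DIM`) are hypothetical, and the kernel assumes Q heads are grouped consecutively per K/V head.

```cuda
// Sketch only: load one K tile into shared memory per block and re-use it
// for all Q heads that map to the same K/V head under GQA.
#include <cuda_runtime.h>
#include <cstdio>
#include <vector>

#define TILE_KV  32 // K/V tokens handled per block iteration (hypothetical)
#define HEAD_DIM 64 // head dimension (hypothetical)

// Q: [n_head_q ][n_tok_q ][HEAD_DIM]
// K: [n_head_kv][n_tok_kv][HEAD_DIM]
// S: [n_head_q ][n_tok_q ][n_tok_kv]   raw KQ scores, pre-softmax
__global__ void kq_scores_gqa(const float * Q, const float * K, float * S,
                              int n_tok_q, int n_tok_kv, int gqa_ratio) {
    const int head_kv = blockIdx.y;           // one block row per K/V head
    const int i0      = blockIdx.x * TILE_KV; // first K/V token of this tile
    const int tid     = threadIdx.x;

    __shared__ float K_tile[TILE_KV][HEAD_DIM];

    // Load the K tile ONCE per block:
    for (int idx = tid; idx < TILE_KV*HEAD_DIM; idx += blockDim.x) {
        const int i = idx / HEAD_DIM;
        const int d = idx % HEAD_DIM;
        K_tile[i][d] = i0 + i < n_tok_kv ? K[((size_t)head_kv*n_tok_kv + i0 + i)*HEAD_DIM + d] : 0.0f;
    }
    __syncthreads();

    // Re-use the loaded tile for every Q head that shares this K/V head
    // (assumes Q heads of one group are stored consecutively):
    for (int h = 0; h < gqa_ratio; ++h) {
        const int head_q = head_kv*gqa_ratio + h;
        for (int idx = tid; idx < n_tok_q*TILE_KV; idx += blockDim.x) {
            const int j = idx / TILE_KV; // Q token
            const int i = idx % TILE_KV; // K/V token within the tile
            if (i0 + i >= n_tok_kv) continue;
            float sum = 0.0f;
            for (int d = 0; d < HEAD_DIM; ++d) {
                sum += Q[((size_t)head_q*n_tok_q + j)*HEAD_DIM + d] * K_tile[i][d];
            }
            S[((size_t)head_q*n_tok_q + j)*n_tok_kv + i0 + i] = sum;
        }
    }
}

int main() {
    // Gemma-like toy sizes: 1 K/V head shared by 8 Q heads.
    const int n_head_kv = 1, gqa_ratio = 8, n_head_q = n_head_kv*gqa_ratio;
    const int n_tok_q = 4, n_tok_kv = 128;

    std::vector<float> hQ((size_t)n_head_q *n_tok_q *HEAD_DIM, 1.0f);
    std::vector<float> hK((size_t)n_head_kv*n_tok_kv*HEAD_DIM, 1.0f);
    std::vector<float> hS((size_t)n_head_q *n_tok_q *n_tok_kv);

    float *dQ, *dK, *dS;
    cudaMalloc(&dQ, hQ.size()*sizeof(float));
    cudaMalloc(&dK, hK.size()*sizeof(float));
    cudaMalloc(&dS, hS.size()*sizeof(float));
    cudaMemcpy(dQ, hQ.data(), hQ.size()*sizeof(float), cudaMemcpyHostToDevice);
    cudaMemcpy(dK, hK.data(), hK.size()*sizeof(float), cudaMemcpyHostToDevice);

    const dim3 grid((n_tok_kv + TILE_KV - 1)/TILE_KV, n_head_kv);
    kq_scores_gqa<<<grid, 128>>>(dQ, dK, dS, n_tok_q, n_tok_kv, gqa_ratio);
    cudaMemcpy(hS.data(), dS, hS.size()*sizeof(float), cudaMemcpyDeviceToHost);

    printf("S[0] = %.1f (expected %d)\n", hS[0], HEAD_DIM); // all-ones inputs -> dot product == HEAD_DIM
    cudaFree(dQ); cudaFree(dK); cudaFree(dS);
    return 0;
}
```

With Gemma-style GQA (8 Q heads per K/V head) the same shared-memory K tile serves 8 heads instead of being reloaded 8 times; with a non-GQA model like Phi 2 the head loop degenerates to a single iteration and there is nothing to re-use from this particular optimization.
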
Performance changes
| GPU | Model | Microbatch size | Test | t/s master | t/s PR | Speedup |
| --- | --- | --- | --- | --- | --- | --- |
| RTX 3090 | gemma 2B F16 | 1 | pp16384 | 151.98 | 161.85 | 1.06 |
| RTX 3090 | gemma 2B F16 | 2 | pp16384 | 248.87 | 259.87 | 1.04 |
| RTX 3090 | gemma 2B F16 | 3 | pp16384 | 369.76 | 383.88 | 1.04 |
| RTX 3090 | gemma 2B F16 | 4 | pp16384 | 498.01 | 521.12 | 1.05 |
| RTX 3090 | gemma 2B F16 | 6 | pp16384 | 727.41 | 736.19 | 1.01 |
| RTX 3090 | gemma 2B F16 | 8 | pp16384 | 951.74 | 972.83 | 1.02 |
| RTX 3090 | gemma 2B F16 | 12 | pp16384 | 1349.70 | 1399.32 | 1.04 |
| RTX 3090 | gemma 2B F16 | 16 | pp16384 | 1740.43 | 1859.05 | 1.07 |
| RTX 3090 | gemma 2B F16 | 24 | pp16384 | 2318.17 | 2700.83 | 1.17 |
| RTX 3090 | gemma 2B F16 | 32 | pp16384 | 2836.44 | 3411.02 | 1.20 |
| RTX 3090 | gemma 2B F16 | 48 | pp16384 | 3258.66 | 4758.89 | 1.46 |
| RTX 3090 | gemma 2B F16 | 64 | pp16384 | 3779.50 | 6000.23 | 1.59 |
| RTX 3090 | gemma 2B F16 | 96 | pp16384 | 5828.39 | 7501.01 | 1.29 |
| RTX 3090 | gemma 2B F16 | 128 | pp16384 | 7515.51 | 9210.69 | 1.23 |
| RTX 3090 | gemma 2B F16 | 256 | pp16384 | 9844.77 | 10642.87 | 1.08 |
| RTX 3090 | gemma 2B F16 | 512 | pp16384 | 10951.22 | 11514.83 | 1.05 |
| RTX 3090 | gemma 2B F16 | 1024 | pp16384 | 11572.58 | 11844.54 | 1.02 |
| RTX 3090 | gemma 2B F16 | 2048 | pp16384 | 11724.84 | 12087.77 | 1.03 |
| RTX 3090 | gemma 2B F16 | 4096 | pp16384 | 11730.26 | 12089.48 | 1.03 |
| RTX 3090 | gemma 2B F16 | 8192 | pp16384 | 11732.49 | 12104.29 | 1.03 |
| RTX 3090 | gemma 2B F16 | 16384 | pp16384 | 11731.33 | 12123.26 | 1.03 |
| RTX 3090 | llama 8B Q4_0 | 1 | pp16384 | 108.63 | 125.22 | 1.15 |
| RTX 3090 | llama 8B Q4_0 | 2 | pp16384 | 158.55 | 223.90 | 1.41 |
| RTX 3090 | llama 8B Q4_0 | 3 | pp16384 | 231.06 | 319.86 | 1.38 |
| RTX 3090 | llama 8B Q4_0 | 4 | pp16384 | 287.54 | 388.13 | 1.35 |
| RTX 3090 | llama 8B Q4_0 | 6 | pp16384 | 373.77 | 474.34 | 1.27 |
| RTX 3090 | llama 8B Q4_0 | 8 | pp16384 | 429.04 | 523.66 | 1.22 |
| RTX 3090 | llama 8B Q4_0 | 12 | pp16384 | 643.30 | 793.60 | 1.23 |
| RTX 3090 | llama 8B Q4_0 | 16 | pp16384 | 848.68 | 1047.13 | 1.23 |
| RTX 3090 | llama 8B Q4_0 | 24 | pp16384 | 1165.66 | 1357.03 | 1.16 |
| RTX 3090 | llama 8B Q4_0 | 32 | pp16384 | 1448.24 | 1666.54 | 1.15 |
| RTX 3090 | llama 8B Q4_0 | 48 | pp16384 | 1903.56 | 2173.54 | 1.14 |
| RTX 3090 | llama 8B Q4_0 | 64 | pp16384 | 2229.33 | 2471.21 | 1.11 |
| RTX 3090 | llama 8B Q4_0 | 96 | pp16384 | 2390.78 | 2751.08 | 1.15 |
| RTX 3090 | llama 8B Q4_0 | 128 | pp16384 | 2830.15 | 3007.53 | 1.06 |
| RTX 3090 | llama 8B Q4_0 | 256 | pp16384 | 3544.97 | 3648.56 | 1.03 |
| RTX 3090 | llama 8B Q4_0 | 512 | pp16384 | 3668.09 | 3796.45 | 1.03 |
| RTX 3090 | llama 8B Q4_0 | 1024 | pp16384 | 3805.70 | 3978.51 | 1.05 |
| RTX 3090 | llama 8B Q4_0 | 2048 | pp16384 | 3787.72 | 3968.53 | 1.05 |
| RTX 3090 | llama 8B Q4_0 | 4096 | pp16384 | 3772.90 | 3974.59 | 1.05 |
| RTX 3090 | llama 8B Q4_0 | 8192 | pp16384 | 3795.54 | 3974.16 | 1.05 |
| RTX 3090 | llama 8B Q4_0 | 16384 | pp16384 | 3796.23 | 3973.22 | 1.05 |
| RTX 3090 | phi2 3B F16 | 1 | pp16384 | 82.97 | 81.13 | 0.98 |
| RTX 3090 | phi2 3B F16 | 2 | pp16384 | 144.66 | 141.37 | 0.98 |
| RTX 3090 | phi2 3B F16 | 3 | pp16384 | 216.39 | 211.16 | 0.98 |
| RTX 3090 | phi2 3B F16 | 4 | pp16384 | 284.68 | 280.74 | 0.99 |
| RTX 3090 | phi2 3B F16 | 6 | pp16384 | 420.45 | 417.67 | 0.99 |
| RTX 3090 | phi2 3B F16 | 8 | pp16384 | 556.61 | 554.69 | 1.00 |
| RTX 3090 | phi2 3B F16 | 12 | pp16384 | 802.97 | 811.75 | 1.01 |
| RTX 3090 | phi2 3B F16 | 16 | pp16384 | 1065.07 | 1073.04 | 1.01 |
| RTX 3090 | phi2 3B F16 | 24 | pp16384 | 1505.43 | 1559.60 | 1.04 |
| RTX 3090 | phi2 3B F16 | 32 | pp16384 | 1893.57 | 1998.59 | 1.06 |
| RTX 3090 | phi2 3B F16 | 48 | pp16384 | 2462.76 | 2734.08 | 1.11 |
| RTX 3090 | phi2 3B F16 | 64 | pp16384 | 3154.33 | 3623.37 | 1.15 |
| RTX 3090 | phi2 3B F16 | 96 | pp16384 | 3539.53 | 3820.18 | 1.08 |
| RTX 3090 | phi2 3B F16 | 128 | pp16384 | 4573.37 | 4917.37 | 1.08 |
| RTX 3090 | phi2 3B F16 | 256 | pp16384 | 5981.49 | 6405.33 | 1.07 |
| RTX 3090 | phi2 3B F16 | 512 | pp16384 | 6438.91 | 6622.56 | 1.03 |
| RTX 3090 | phi2 3B F16 | 1024 | pp16384 | 6494.09 | 6894.79 | 1.06 |
| RTX 3090 | phi2 3B F16 | 2048 | pp16384 | 6609.94 | 7041.81 | 1.07 |
| RTX 3090 | phi2 3B F16 | 4096 | pp16384 | 6674.94 | 7038.99 | 1.05 |
| RTX 3090 | phi2 3B F16 | 8192 | pp16384 | 6674.89 | 7046.58 | 1.06 |
| RTX 3090 | phi2 3B F16 | 16384 | pp16384 | 6676.13 | 7050.12 | 1.06 |
| RTX 4090 | gemma 2B F16 | 1 | pp16384 | 188.71 | 192.21 | 1.02 |
| RTX 4090 | gemma 2B F16 | 2 | pp16384 | 315.29 | 319.45 | 1.01 |
| RTX 4090 | gemma 2B F16 | 3 | pp16384 | 467.69 | 472.12 | 1.01 |
| RTX 4090 | gemma 2B F16 | 4 | pp16384 | 636.10 | 646.32 | 1.02 |
| RTX 4090 | gemma 2B F16 | 6 | pp16384 | 924.58 | 919.83 | 0.99 |
| RTX 4090 | gemma 2B F16 | 8 | pp16384 | 1208.77 | 1220.84 | 1.01 |
| RTX 4090 | gemma 2B F16 | 12 | pp16384 | 1704.04 | 1809.66 | 1.06 |
| RTX 4090 | gemma 2B F16 | 16 | pp16384 | 2174.59 | 2382.26 | 1.10 |
| RTX 4090 | gemma 2B F16 | 24 | pp16384 | 2951.43 | 3506.05 | 1.19 |
| RTX 4090 | gemma 2B F16 | 32 | pp16384 | 3679.03 | 4612.43 | 1.25 |
| RTX 4090 | gemma 2B F16 | 48 | pp16384 | 4425.07 | 6216.53 | 1.40 |
| RTX 4090 | gemma 2B F16 | 64 | pp16384 | 5176.16 | 8052.42 | 1.56 |
| RTX 4090 | gemma 2B F16 | 96 | pp16384 | 8523.83 | 10718.71 | 1.26 |
| RTX 4090 | gemma 2B F16 | 128 | pp16384 | 11237.08 | 13797.04 | 1.23 |
| RTX 4090 | gemma 2B F16 | 256 | pp16384 | 19068.33 | 20758.07 | 1.09 |
| RTX 4090 | gemma 2B F16 | 512 | pp16384 | 25130.53 | 26197.02 | 1.04 |
| RTX 4090 | gemma 2B F16 | 1024 | pp16384 | 24602.15 | 25386.13 | 1.03 |
| RTX 4090 | gemma 2B F16 | 2048 | pp16384 | 24860.08 | 25424.70 | 1.02 |
| RTX 4090 | gemma 2B F16 | 4096 | pp16384 | 24780.69 | 25453.42 | 1.03 |
| RTX 4090 | gemma 2B F16 | 8192 | pp16384 | 24787.61 | 25393.45 | 1.02 |
| RTX 4090 | gemma 2B F16 | 16384 | pp16384 | 24823.15 | 25419.06 | 1.02 |
| RTX 4090 | llama 8B Q4_0 | 1 | pp16384 | 141.40 | 146.86 | 1.04 |
| RTX 4090 | llama 8B Q4_0 | 2 | pp16384 | 266.21 | 272.68 | 1.02 |
| RTX 4090 | llama 8B Q4_0 | 3 | pp16384 | 396.38 | 402.35 | 1.02 |
| RTX 4090 | llama 8B Q4_0 | 4 | pp16384 | 523.18 | 531.34 | 1.02 |
| RTX 4090 | llama 8B Q4_0 | 6 | pp16384 | 738.60 | 733.59 | 0.99 |
| RTX 4090 | llama 8B Q4_0 | 8 | pp16384 | 910.13 | 910.43 | 1.00 |
| RTX 4090 | llama 8B Q4_0 | 12 | pp16384 | 1166.29 | 1156.32 | 0.99 |
| RTX 4090 | llama 8B Q4_0 | 16 | pp16384 | 1529.82 | 1531.36 | 1.00 |
| RTX 4090 | llama 8B Q4_0 | 24 | pp16384 | 2151.15 | 2215.12 | 1.03 |
| RTX 4090 | llama 8B Q4_0 | 32 | pp16384 | 2682.64 | 2803.91 | 1.05 |
| RTX 4090 | llama 8B Q4_0 | 48 | pp16384 | 3593.15 | 3906.87 | 1.09 |
| RTX 4090 | llama 8B Q4_0 | 64 | pp16384 | 4331.18 | 5025.63 | 1.16 |
| RTX 4090 | llama 8B Q4_0 | 96 | pp16384 | 5336.88 | 5825.80 | 1.09 |
| RTX 4090 | llama 8B Q4_0 | 128 | pp16384 | 6511.68 | 6979.52 | 1.07 |
| RTX 4090 | llama 8B Q4_0 | 256 | pp16384 | 8419.66 | 8866.89 | 1.05 |
| RTX 4090 | llama 8B Q4_0 | 512 | pp16384 | 9595.88 | 9909.55 | 1.03 |
| RTX 4090 | llama 8B Q4_0 | 1024 | pp16384 | 9620.89 | 9995.12 | 1.04 |
| RTX 4090 | llama 8B Q4_0 | 2048 | pp16384 | 9189.02 | 9598.75 | 1.04 |
| RTX 4090 | llama 8B Q4_0 | 4096 | pp16384 | 9197.18 | 9607.84 | 1.04 |
| RTX 4090 | llama 8B Q4_0 | 8192 | pp16384 | 9190.60 | 9600.52 | 1.04 |
| RTX 4090 | llama 8B Q4_0 | 16384 | pp16384 | 9201.13 | 9600.01 | 1.04 |
| RTX 4090 | phi2 3B F16 | 1 | pp16384 | 103.75 | 103.12 | 0.99 |
| RTX 4090 | phi2 3B F16 | 2 | pp16384 | 179.58 | 179.09 | 1.00 |
| RTX 4090 | phi2 3B F16 | 3 | pp16384 | 268.44 | 267.97 | 1.00 |
| RTX 4090 | phi2 3B F16 | 4 | pp16384 | 356.08 | 356.40 | 1.00 |
| RTX 4090 | phi2 3B F16 | 6 | pp16384 | 527.54 | 531.12 | 1.01 |
| RTX 4090 | phi2 3B F16 | 8 | pp16384 | 699.58 | 705.18 | 1.01 |
| RTX 4090 | phi2 3B F16 | 12 | pp16384 | 1010.70 | 1039.98 | 1.03 |
| RTX 4090 | phi2 3B F16 | 16 | pp16384 | 1314.95 | 1379.61 | 1.05 |
| RTX 4090 | phi2 3B F16 | 24 | pp16384 | 1947.20 | 2028.79 | 1.04 |
| RTX 4090 | phi2 3B F16 | 32 | pp16384 | 2529.39 | 2679.80 | 1.06 |
| RTX 4090 | phi2 3B F16 | 48 | pp16384 | 3492.89 | 3817.46 | 1.09 |
| RTX 4090 | phi2 3B F16 | 64 | pp16384 | 4354.42 | 4947.95 | 1.14 |
| RTX 4090 | phi2 3B F16 | 96 | pp16384 | 6032.40 | 6465.26 | 1.07 |
| RTX 4090 | phi2 3B F16 | 128 | pp16384 | 7906.06 | 8408.58 | 1.06 |
| RTX 4090 | phi2 3B F16 | 256 | pp16384 | 11747.52 | 12016.59 | 1.02 |
| RTX 4090 | phi2 3B F16 | 512 | pp16384 | 14291.87 | 14384.89 | 1.01 |
| RTX 4090 | phi2 3B F16 | 1024 | pp16384 | 15898.14 | 16127.53 | 1.01 |
| RTX 4090 | phi2 3B F16 | 2048 | pp16384 | 14419.83 | 14688.35 | 1.02 |
| RTX 4090 | phi2 3B F16 | 4096 | pp16384 | 14447.28 | 14664.04 | 1.02 |
| RTX 4090 | phi2 3B F16 | 8192 | pp16384 | 14423.79 | 14669.22 | 1.02 |
| RTX 4090 | phi2 3B F16 | 16384 | pp16384 | 14421.78 | 14721.65 | 1.02 |

For context, LLaMA 3 uses GQA with 4 Q heads per K/V head, Gemma uses GQA with 8 Q heads per K/V head, and Phi 2 does not use GQA at all.
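
As a quick cross-check of those ratios, here is a trivial host-side snippet (plain C++; compiles with nvcc as well). The per-model head counts are the commonly published hyperparameters, assumed here rather than taken from this PR:

```cuda
#include <cstdio>

int main() {
    // Assumed head counts per model; the GQA ratio is simply n_head_q / n_head_kv.
    struct { const char * name; int n_head_q; int n_head_kv; } models[] = {
        {"LLaMA 3 8B", 32,  8}, // 32 Q heads share 8 K/V heads -> ratio 4
        {"Gemma 2B",    8,  1}, // multi-query attention        -> ratio 8
        {"Phi-2",      32, 32}, // no GQA                       -> ratio 1
    };
    for (const auto & m : models) {
        printf("%-11s gqa_ratio = %d\n", m.name, m.n_head_q / m.n_head_kv);
    }
    return 0;
}
```

This ratio is exactly what bounds the "factor of up to 8" re-use described in the first bullet above.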

JohannesGaessler · Feb 21 '25 22:02