
FlashAttention: pragma unroll, use_mask template parameter

JohannesGaessler opened this pull request

This PR adds two very simple optimizations for FlashAttention:

  • Add #pragma unroll wherever possible so that the compiler unrolls more loops.
  • Add a template parameter use_mask to eliminate some runtime checks (see the sketch below).
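
Below is a minimal sketch of what the two changes look like in an isolated CUDA kernel; the kernel and all names are hypothetical and only illustrate the pattern, not the actual llama.cpp FlashAttention code:

```cpp
// Minimal sketch, hypothetical names: not the actual llama.cpp kernel.
#include <cuda_fp16.h>

// use_mask is a compile-time constant, so the compiler removes the masked
// branch (and the mask loads) entirely when the kernel is instantiated with
// use_mask == false. head_size being a template parameter gives the inner
// loop a fixed trip count that can be fully unrolled.
template <int head_size, bool use_mask>
static __global__ void attn_scores_sketch(
        const half * __restrict__ Q,    // 1 query row, head_size values
        const half * __restrict__ K,    // n_kv key rows, head_size values each
        const half * __restrict__ mask, // n_kv mask values (unused if !use_mask)
        float      * __restrict__ S,    // n_kv output scores
        const int    n_kv) {
    const int j = blockIdx.x*blockDim.x + threadIdx.x; // one thread per KV position
    if (j >= n_kv) {
        return;
    }

    float sum = 0.0f;
#pragma unroll // fixed trip count -> the compiler can fully unroll this loop
    for (int i = 0; i < head_size; ++i) {
        sum += __half2float(Q[i]) * __half2float(K[j*head_size + i]);
    }

    if (use_mask) { // resolved at compile time, no runtime check in the generated code
        sum += __half2float(mask[j]);
    }

    S[j] = sum;
}
```

The host code would then instantiate the kernel once per (head_size, use_mask) combination and dispatch depending on whether a mask tensor is present.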

Together these provide ~2% speedup for small batch sizes on my system:

[flash_attention_perf benchmark plot]

Quite honestly, I don't think the current code can be made to work well for small batch sizes (< 16), so whether or not this gets merged is of relatively little consequence. The tensor core fragments consume a lot of registers, but for small batch sizes there is no real benefit. So I think for those batch sizes a specialized implementation without tensor cores will perform better.

JohannesGaessler, Mar 19 '24

Here are results on V100 using:

```sh
# baseline
LLAMA_CUBLAS=1 make -j tests && ./tests/test-backend-ops -o ATTN -b CUDA0 perf

# flash attn
LLAMA_CUBLAS=1 make -j tests && ./tests/test-backend-ops -o FLASH_ATTN_EXT -b CUDA0 perf
```

Note that this measures just the attention operations, not the entire throughput. In the table, attn is the baseline ATTN path, gg/flash_attn is the FLASH_ATTN_EXT kernel without this PR, and jg/flash_attn is the same kernel with this PR applied.

| hs | heads | n_kv | n_batch | attn (us/run) | gg/flash_attn (us/run) | jg/flash_attn (us/run) | speedup (attn/gg) | speedup (gg/jg) | speedup (attn/jg) |
|---:|---:|---:|---:|---:|---:|---:|---:|---:|---:|
| 128 | 32 | 512 | 1 | 45.56 | 54.45 | 50.54 | 0.84 | 1.08 | 0.90 |
| 128 | 32 | 512 | 2 | 47.33 | 55.06 | 49.91 | 0.86 | 1.10 | 0.95 |
| 128 | 32 | 512 | 4 | 56.37 | 55.95 | 47.68 | 1.01 | 1.17 | 1.18 |
| 128 | 32 | 512 | 8 | 72.85 | 58.97 | 49.99 | 1.24 | 1.18 | 1.46 |
| 128 | 32 | 512 | 512 | 501.63 | 385.66 | 367.70 | 1.30 | 1.05 | 1.36 |
| 128 | 32 | 512 | 1024 | 971.40 | 718.00 | 690.30 | 1.35 | 1.04 | 1.41 |
| 128 | 32 | 512 | 2048 | 1918.88 | 1246.56 | 1213.68 | 1.54 | 1.03 | 1.58 |
| 128 | 32 | 1024 | 1 | 62.33 | 78.44 | 76.33 | 0.79 | 1.03 | 0.82 |
| 128 | 32 | 1024 | 2 | 64.11 | 78.89 | 76.72 | 0.81 | 1.03 | 0.84 |
| 128 | 32 | 1024 | 4 | 77.55 | 79.61 | 77.47 | 0.97 | 1.03 | 1.00 |
| 128 | 32 | 1024 | 8 | 108.78 | 81.15 | 79.11 | 1.34 | 1.03 | 1.38 |
| 128 | 32 | 1024 | 512 | 845.88 | 706.42 | 673.63 | 1.20 | 1.05 | 1.26 |
| 128 | 32 | 1024 | 1024 | 1682.18 | 1310.36 | 1247.34 | 1.28 | 1.05 | 1.35 |
| 128 | 32 | 1024 | 2048 | 3306.77 | 2206.84 | 2192.80 | 1.50 | 1.01 | 1.51 |
| 128 | 32 | 2048 | 1 | 96.69 | 131.55 | 126.53 | 0.74 | 1.04 | 0.76 |
| 128 | 32 | 2048 | 2 | 100.30 | 131.64 | 126.98 | 0.76 | 1.04 | 0.79 |
| 128 | 32 | 2048 | 4 | 125.38 | 132.14 | 127.47 | 0.95 | 1.04 | 0.98 |
| 128 | 32 | 2048 | 8 | 179.87 | 134.20 | 129.16 | 1.34 | 1.04 | 1.39 |
| 128 | 32 | 2048 | 512 | 1395.07 | 1343.34 | 1249.35 | 1.04 | 1.08 | 1.12 |
| 128 | 32 | 2048 | 1024 | 2717.62 | 2475.43 | 2353.68 | 1.10 | 1.05 | 1.15 |
| 128 | 32 | 2048 | 2048 | 5349.03 | 4155.13 | 4157.71 | 1.29 | 1.00 | 1.29 |
| 128 | 32 | 4096 | 1 | 163.47 | 234.90 | 223.14 | 0.70 | 1.05 | 0.73 |
| 128 | 32 | 4096 | 2 | 169.38 | 235.21 | 223.53 | 0.72 | 1.05 | 0.76 |
| 128 | 32 | 4096 | 4 | 215.75 | 235.74 | 223.69 | 0.92 | 1.05 | 0.96 |
| 128 | 32 | 4096 | 8 | 319.53 | 237.44 | 225.08 | 1.35 | 1.05 | 1.42 |
| 128 | 32 | 4096 | 512 | 2575.27 | 2603.56 | 2386.44 | 0.99 | 1.09 | 1.08 |
| 128 | 32 | 4096 | 1024 | 5045.29 | 4806.45 | 4575.97 | 1.05 | 1.05 | 1.10 |
| 128 | 32 | 4096 | 2048 | 9931.37 | 8034.10 | 8106.12 | 1.24 | 0.99 | 1.23 |
| 64 | 32 | 512 | 1 | 39.07 | 32.72 | 30.75 | 1.19 | 1.06 | 1.27 |
| 64 | 32 | 512 | 2 | 34.32 | 32.91 | 30.96 | 1.04 | 1.06 | 1.11 |
| 64 | 32 | 512 | 4 | 40.67 | 33.34 | 31.43 | 1.22 | 1.06 | 1.29 |
| 64 | 32 | 512 | 8 | 55.57 | 34.33 | 32.41 | 1.62 | 1.06 | 1.71 |
| 64 | 32 | 512 | 512 | 411.37 | 248.66 | 232.71 | 1.65 | 1.07 | 1.77 |
| 64 | 32 | 512 | 1024 | 795.00 | 436.61 | 427.91 | 1.82 | 1.02 | 1.86 |
| 64 | 32 | 512 | 2048 | 1565.25 | 724.60 | 774.42 | 2.16 | 0.94 | 2.02 |
| 64 | 32 | 1024 | 1 | 53.37 | 53.02 | 50.96 | 1.01 | 1.04 | 1.05 |
| 64 | 32 | 1024 | 2 | 53.96 | 53.06 | 51.30 | 1.02 | 1.03 | 1.05 |
| 64 | 32 | 1024 | 4 | 66.24 | 53.68 | 51.72 | 1.23 | 1.04 | 1.28 |
| 64 | 32 | 1024 | 8 | 93.75 | 54.38 | 52.25 | 1.72 | 1.04 | 1.79 |
| 64 | 32 | 1024 | 512 | 740.54 | 455.52 | 431.63 | 1.63 | 1.06 | 1.72 |
| 64 | 32 | 1024 | 1024 | 1475.62 | 769.90 | 798.53 | 1.92 | 0.96 | 1.85 |
| 64 | 32 | 1024 | 2048 | 2887.93 | 1329.48 | 1437.34 | 2.17 | 0.92 | 2.01 |
| 64 | 32 | 2048 | 1 | 82.71 | 90.10 | 84.41 | 0.92 | 1.07 | 0.98 |
| 64 | 32 | 2048 | 2 | 86.41 | 90.37 | 84.73 | 0.96 | 1.07 | 1.02 |
| 64 | 32 | 2048 | 4 | 110.69 | 90.96 | 85.21 | 1.22 | 1.07 | 1.30 |
| 64 | 32 | 2048 | 8 | 159.36 | 91.73 | 85.94 | 1.74 | 1.07 | 1.85 |
| 64 | 32 | 2048 | 512 | 1272.46 | 868.85 | 814.30 | 1.46 | 1.07 | 1.56 |
| 64 | 32 | 2048 | 1024 | 2486.13 | 1526.42 | 1534.50 | 1.63 | 0.99 | 1.62 |
| 64 | 32 | 2048 | 2048 | 4864.27 | 2576.05 | 2765.03 | 1.89 | 0.93 | 1.76 |
| 64 | 32 | 4096 | 1 | 135.99 | 163.04 | 150.03 | 0.83 | 1.09 | 0.91 |
| 64 | 32 | 4096 | 2 | 144.66 | 163.31 | 150.41 | 0.89 | 1.09 | 0.96 |
| 64 | 32 | 4096 | 4 | 190.04 | 163.64 | 150.81 | 1.16 | 1.09 | 1.26 |
| 64 | 32 | 4096 | 8 | 285.76 | 164.50 | 151.75 | 1.74 | 1.08 | 1.88 |
| 64 | 32 | 4096 | 512 | 2372.08 | 1951.07 | 1574.40 | 1.22 | 1.24 | 1.51 |
| 64 | 32 | 4096 | 1024 | 4726.56 | 3044.91 | 3005.09 | 1.55 | 1.01 | 1.57 |
| 64 | 32 | 4096 | 2048 | 9197.90 | 5079.92 | 5395.38 | 1.81 | 0.94 | 1.70 |
| 80 | 32 | 512 | 1 | 43.19 | 26.85 | 22.80 | 1.61 | 1.18 | 1.89 |
| 80 | 32 | 512 | 2 | 35.04 | 27.09 | 23.30 | 1.29 | 1.16 | 1.50 |
| 80 | 32 | 512 | 4 | 43.95 | 27.55 | 23.78 | 1.60 | 1.16 | 1.85 |
| 80 | 32 | 512 | 8 | 62.47 | 29.35 | 25.30 | 2.13 | 1.16 | 2.47 |
| 80 | 32 | 512 | 512 | 440.26 | 281.31 | 250.12 | 1.57 | 1.12 | 1.76 |
| 80 | 32 | 512 | 1024 | 853.62 | 455.14 | 480.31 | 1.88 | 0.95 | 1.78 |
| 80 | 32 | 512 | 2048 | 1682.13 | 782.02 | 858.72 | 2.15 | 0.91 | 1.96 |
| 80 | 32 | 1024 | 1 | 57.76 | 42.67 | 36.95 | 1.35 | 1.15 | 1.56 |
| 80 | 32 | 1024 | 2 | 59.24 | 43.04 | 37.44 | 1.38 | 1.15 | 1.58 |
| 80 | 32 | 1024 | 4 | 72.08 | 43.57 | 38.00 | 1.65 | 1.15 | 1.90 |
| 80 | 32 | 1024 | 8 | 100.33 | 44.93 | 39.45 | 2.23 | 1.14 | 2.54 |
| 80 | 32 | 1024 | 512 | 777.70 | 500.01 | 454.53 | 1.56 | 1.10 | 1.71 |
| 80 | 32 | 1024 | 1024 | 1553.05 | 800.78 | 863.37 | 1.94 | 0.93 | 1.80 |
| 80 | 32 | 1024 | 2048 | 3049.70 | 1377.34 | 1545.30 | 2.21 | 0.89 | 1.97 |
| 80 | 32 | 2048 | 1 | 90.71 | 76.92 | 60.87 | 1.18 | 1.26 | 1.49 |
| 80 | 32 | 2048 | 2 | 94.60 | 77.33 | 61.04 | 1.22 | 1.27 | 1.55 |
| 80 | 32 | 2048 | 4 | 118.46 | 77.86 | 61.44 | 1.52 | 1.27 | 1.93 |
| 80 | 32 | 2048 | 8 | 169.96 | 79.81 | 62.87 | 2.13 | 1.27 | 2.70 |
| 80 | 32 | 2048 | 512 | 1317.95 | 940.61 | 857.18 | 1.40 | 1.10 | 1.54 |
| 80 | 32 | 2048 | 1024 | 2572.62 | 1500.10 | 1627.96 | 1.71 | 0.92 | 1.58 |
| 80 | 32 | 2048 | 2048 | 5053.02 | 2583.43 | 2922.96 | 1.96 | 0.88 | 1.73 |
| 80 | 32 | 4096 | 1 | 151.67 | 141.47 | 105.64 | 1.07 | 1.34 | 1.44 |
| 80 | 32 | 4096 | 2 | 160.51 | 141.61 | 105.51 | 1.13 | 1.34 | 1.52 |
| 80 | 32 | 4096 | 4 | 204.87 | 142.93 | 106.39 | 1.43 | 1.34 | 1.93 |
| 80 | 32 | 4096 | 8 | 304.14 | 144.33 | 107.52 | 2.11 | 1.34 | 2.83 |
| 80 | 32 | 4096 | 512 | 2436.21 | 1818.19 | 1646.07 | 1.34 | 1.10 | 1.48 |
| 80 | 32 | 4096 | 1024 | 4798.57 | 2884.22 | 3152.04 | 1.66 | 0.92 | 1.52 |
| 80 | 32 | 4096 | 2048 | 9520.45 | 4991.22 | 5630.93 | 1.91 | 0.89 | 1.69 |

The proposed PR here generally seems to help.

> So I think for those batch sizes a specialized implementation without tensor cores will perform better.

I was hoping to find a way to avoid this. The main reason is that the flash attention kernels can be extended to support quantized data (for a quantized KV cache), and having two versions of the kernel would double the amount of code and effort.

ggerganov, Mar 20 '24

> I was hoping to find a way to avoid this. The main reason is that the flash attention kernels can be extended to support quantized data (for a quantized KV cache), and having two versions of the kernel would double the amount of code and effort.

I have a prototype of a flash attention kernel for a batch size of 1. Due to the change in memory layout it will not be possible to use int8 intrinsics/tensor cores for the V portion of the KV cache, so essentially the only way to do it will be to use FP16 intrinsics/tensor cores. The overhead per quantization format will then be a conversion to FP16, which we essentially already have and which can simply be provided via a template parameter.
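
As an illustration of the template-parameter idea, here is a minimal sketch (hypothetical names and data layout, not the actual prototype): the kernel takes a dequantize-to-FP16/FP32 conversion as a template parameter, so supporting another quantization format for V only requires plugging in its conversion function.

```cpp
// Minimal sketch, hypothetical names: the V dequantization is a template
// parameter, so each KV-cache format only has to supply a conversion function.
#include <cuda_fp16.h>
#include <math.h>

typedef float (*dequantize_v_t)(const void * vx, int i);

// FP16 V cache: the "conversion" is just a load.
static __device__ float dequantize_v_f16(const void * vx, int i) {
    return __half2float(((const half *) vx)[i]);
}

// Flash attention for batch size 1: one thread block per head, one thread per
// element of the head dimension, streaming (online) softmax over the KV positions.
template <int head_size, dequantize_v_t dequantize_v>
static __global__ void flash_attn_vec_sketch(
        const half * __restrict__ Q,   // head_size values
        const half * __restrict__ K,   // n_kv x head_size values
        const void * __restrict__ V,   // n_kv x head_size values, possibly quantized
        half       * __restrict__ dst, // head_size output values
        const int    n_kv) {
    const int d = threadIdx.x; // output element handled by this thread

    float max_s   = -INFINITY;
    float sum_exp = 0.0f;
    float acc     = 0.0f;

    for (int j = 0; j < n_kv; ++j) {
        // Attention score for KV position j (recomputed by every thread to keep the sketch short).
        float s = 0.0f;
#pragma unroll
        for (int i = 0; i < head_size; ++i) {
            s += __half2float(Q[i]) * __half2float(K[j*head_size + i]);
        }

        // Online softmax: rescale the running sums whenever a new maximum appears.
        const float max_new = fmaxf(max_s, s);
        const float scale   = expf(max_s - max_new);
        const float p       = expf(s     - max_new);

        sum_exp = sum_exp*scale + p;
        acc     = acc    *scale + p*dequantize_v(V, j*head_size + d);
        max_s   = max_new;
    }

    dst[d] = __float2half(acc / sum_exp);
}

// A quantized V cache would add its own dequantize_v_<format> function and
// instantiate flash_attn_vec_sketch<head_size, dequantize_v_<format>>.
```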

JohannesGaessler, Mar 20 '24

Obsolete now that https://github.com/ggerganov/llama.cpp/pull/5021 has been merged.

JohannesGaessler, Apr 30 '24