CPU/CUDA: Gemma 2 FlashAttention support

Open: JohannesGaessler opened this PR 1 year ago (8 comments)

This PR adds FlashAttention support for Gemma 2 on the CPU and CUDA backends by adding another parameter that controls the logit softcap. A value of 0.0f means no logit softcapping. For any other value, the scale parameter is divided by logit_softcap and, prior to the softmax, the hyperbolic tangent (tanh) is applied and the result is multiplied by logit_softcap. Because this changes the position of the parameter that indicates the precision needed for FlashAttention, I think that this PR breaks the Metal FlashAttention implementation (I am not aware of any other FlashAttention implementations). I tried searching for the spot where the Metal code retrieves the FlashAttention precision but was not able to find it; help would be very much appreciated.

This PR adds template instances with logit softcapping for the CUDA FlashAttention kernels for head sizes 128 (Gemma 2 27b) and 256 (Gemma 2 9b). For all other head sizes, no instance is compiled since (to my knowledge) no models would use them. The FlashAttention kernels that do not use FP16 tensor cores only support head sizes 64 and 128, so Gemma 2 9b is only supported on NVIDIA GPUs with compute capability >= 7.0. For simplicity, tests/test-backend-ops.cpp only checks head size 128 (which should be enough since logit softcapping is a scalar operation and does not depend on head size).
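The support matrix described above can be summarized in a few lines. This is a hedged C sketch, not the actual ggml-cuda kernel selection code; the function name is hypothetical, and encoding the compute capability as 100*major + 10*minor (e.g. 610 for a P40, 860 for an RTX 3090) is an assumption made for illustration.

```c
#include <assert.h>
#include <stdbool.h>

/* Hypothetical summary of the constraints stated in the PR text.
   compute_capability is assumed to be encoded as 100*major + 10*minor. */
static bool softcap_fa_supported(int head_size, int compute_capability) {
    assert(compute_capability > 0);
    /* Softcap template instances exist only for head sizes 128 and 256. */
    if (head_size != 128 && head_size != 256) {
        return false;
    }
    /* FP16 tensor cores are assumed available from compute capability 7.0. */
    const bool has_fp16_tensor_cores = compute_capability >= 700;
    /* Kernels without FP16 tensor cores only support head sizes 64 and 128. */
    return has_fp16_tensor_cores || head_size == 128;
}
```

Under these assumptions a P40 can run Gemma 2 27b (head size 128) with FlashAttention but not Gemma 2 9b (head size 256), which matches the benchmark table containing only 27B rows for the P40.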

GPU Performance
| Model | GPU | Microbatch size | Test | t/s eager attention | t/s FlashAttention | Speedup |
|---|---|---:|---|---:|---:|---:|
| gemma2 27B Q2_K_M | RX 6800 | 1 | pp4096 | 10.96 | 13.79 | 1.26 |
| gemma2 27B Q2_K_M | RX 6800 | 2 | pp4096 | 17.34 | 20.48 | 1.18 |
| gemma2 27B Q2_K_M | RX 6800 | 4 | pp4096 | 21.76 | 22.31 | 1.03 |
| gemma2 27B Q2_K_M | RX 6800 | 8 | pp4096 | 25.90 | 22.50 | 0.87 |
| gemma2 27B Q2_K_M | RX 6800 | 16 | pp4096 | 58.28 | 40.61 | 0.70 |
| gemma2 27B Q2_K_M | RX 6800 | 32 | pp4096 | 84.49 | 66.19 | 0.78 |
| gemma2 27B Q2_K_M | RX 6800 | 64 | pp4096 | 94.47 | 69.94 | 0.74 |
| gemma2 27B Q2_K_M | RX 6800 | 128 | pp4096 | 112.39 | 84.81 | 0.75 |
| gemma2 27B Q2_K_M | RX 6800 | 256 | pp4096 | 126.55 | 98.17 | 0.78 |
| gemma2 27B Q2_K_M | RX 6800 | 512 | pp4096 | 126.16 | 102.64 | 0.81 |
| gemma2 27B Q2_K_M | RX 6800 | 1024 | pp4096 | 134.20 | 100.69 | 0.75 |
| gemma2 27B Q2_K_M | RX 6800 | 2048 | pp4096 | 130.98 | 95.72 | 0.73 |
| gemma2 27B Q4_0 | RTX 3090 | 1 | pp4096 | 41.22 | 44.08 | 1.07 |
| gemma2 27B Q4_0 | RTX 3090 | 2 | pp4096 | 73.88 | 80.99 | 1.10 |
| gemma2 27B Q4_0 | RTX 3090 | 4 | pp4096 | 115.02 | 125.48 | 1.09 |
| gemma2 27B Q4_0 | RTX 3090 | 8 | pp4096 | 145.30 | 161.40 | 1.11 |
| gemma2 27B Q4_0 | RTX 3090 | 16 | pp4096 | 318.50 | 406.84 | 1.28 |
| gemma2 27B Q4_0 | RTX 3090 | 32 | pp4096 | 585.99 | 654.12 | 1.12 |
| gemma2 27B Q4_0 | RTX 3090 | 64 | pp4096 | 743.71 | 859.31 | 1.16 |
| gemma2 27B Q4_0 | RTX 3090 | 128 | pp4096 | 913.71 | 1092.52 | 1.20 |
| gemma2 27B Q4_0 | RTX 3090 | 256 | pp4096 | 986.56 | 1206.55 | 1.22 |
| gemma2 27B Q4_0 | RTX 3090 | 512 | pp4096 | 992.06 | 1230.06 | 1.24 |
| gemma2 27B Q4_0 | RTX 3090 | 1024 | pp4096 | 978.00 | 1247.45 | 1.28 |
| gemma2 27B Q4_0 | RTX 3090 | 2048 | pp4096 | 942.77 | 1231.07 | 1.31 |
| gemma2 27B Q4_0 | RTX 3090 | 4096 | pp4096 | 850.81 | 1189.47 | 1.40 |
| gemma2 27B Q4_0 | RTX 4090 | 1 | pp4096 | 51.31 | 54.78 | 1.07 |
| gemma2 27B Q4_0 | RTX 4090 | 2 | pp4096 | 99.17 | 103.45 | 1.04 |
| gemma2 27B Q4_0 | RTX 4090 | 4 | pp4096 | 194.06 | 203.43 | 1.05 |
| gemma2 27B Q4_0 | RTX 4090 | 8 | pp4096 | 324.52 | 343.86 | 1.06 |
| gemma2 27B Q4_0 | RTX 4090 | 16 | pp4096 | 601.44 | 640.22 | 1.06 |
| gemma2 27B Q4_0 | RTX 4090 | 32 | pp4096 | 1081.21 | 1173.60 | 1.09 |
| gemma2 27B Q4_0 | RTX 4090 | 64 | pp4096 | 1737.82 | 1934.71 | 1.11 |
| gemma2 27B Q4_0 | RTX 4090 | 128 | pp4096 | 2325.77 | 2735.54 | 1.18 |
| gemma2 27B Q4_0 | RTX 4090 | 256 | pp4096 | 2237.85 | 3151.80 | 1.41 |
| gemma2 27B Q4_0 | RTX 4090 | 512 | pp4096 | 2154.18 | 3224.92 | 1.50 |
| gemma2 27B Q4_0 | RTX 4090 | 1024 | pp4096 | 2053.00 | 3218.63 | 1.57 |
| gemma2 27B Q4_0 | RTX 4090 | 2048 | pp4096 | 1839.51 | 3112.09 | 1.69 |
| gemma2 27B Q4_0 | RTX 4090 | 4096 | pp4096 | 1525.74 | 2873.14 | 1.88 |
| gemma2 27B Q4_0 | P40 | 1 | pp4096 | 14.45 | 13.63 | 0.94 |
| gemma2 27B Q4_0 | P40 | 2 | pp4096 | 19.53 | 25.06 | 1.28 |
| gemma2 27B Q4_0 | P40 | 4 | pp4096 | 31.14 | 34.85 | 1.12 |
| gemma2 27B Q4_0 | P40 | 8 | pp4096 | 36.98 | 39.83 | 1.08 |
| gemma2 27B Q4_0 | P40 | 16 | pp4096 | 82.56 | 144.28 | 1.75 |
| gemma2 27B Q4_0 | P40 | 32 | pp4096 | 127.71 | 193.58 | 1.52 |
| gemma2 27B Q4_0 | P40 | 64 | pp4096 | 165.50 | 213.45 | 1.29 |
| gemma2 27B Q4_0 | P40 | 128 | pp4096 | 206.77 | 245.01 | 1.18 |
| gemma2 27B Q4_0 | P40 | 256 | pp4096 | 236.71 | 267.31 | 1.13 |
| gemma2 27B Q4_0 | P40 | 512 | pp4096 | 248.61 | 272.20 | 1.09 |
| gemma2 27B Q4_0 | P40 | 1024 | pp4096 | 248.30 | 270.26 | 1.09 |
| gemma2 27B Q4_0 | P40 | 2048 | pp4096 | 245.07 | 265.91 | 1.09 |
| gemma2 27B Q4_0 | P40 | 4096 | pp4096 | 229.73 | 253.44 | 1.10 |
| gemma2 9B Q4_0 | RTX 3090 | 1 | pp4096 | 87.69 | 101.38 | 1.16 |
| gemma2 9B Q4_0 | RTX 3090 | 2 | pp4096 | 157.68 | 173.97 | 1.10 |
| gemma2 9B Q4_0 | RTX 3090 | 4 | pp4096 | 257.40 | 291.47 | 1.13 |
| gemma2 9B Q4_0 | RTX 3090 | 8 | pp4096 | 348.79 | 401.57 | 1.15 |
| gemma2 9B Q4_0 | RTX 3090 | 16 | pp4096 | 651.44 | 883.03 | 1.36 |
| gemma2 9B Q4_0 | RTX 3090 | 32 | pp4096 | 1284.64 | 1430.69 | 1.11 |
| gemma2 9B Q4_0 | RTX 3090 | 64 | pp4096 | 1703.04 | 1992.60 | 1.17 |
| gemma2 9B Q4_0 | RTX 3090 | 128 | pp4096 | 2151.58 | 2574.37 | 1.20 |
| gemma2 9B Q4_0 | RTX 3090 | 256 | pp4096 | 2408.43 | 3024.58 | 1.26 |
| gemma2 9B Q4_0 | RTX 3090 | 512 | pp4096 | 2460.62 | 3084.09 | 1.25 |
| gemma2 9B Q4_0 | RTX 3090 | 1024 | pp4096 | 2469.48 | 3215.71 | 1.30 |
| gemma2 9B Q4_0 | RTX 3090 | 2048 | pp4096 | 2348.39 | 3138.11 | 1.34 |
| gemma2 9B Q4_0 | RTX 3090 | 4096 | pp4096 | 2093.35 | 2990.59 | 1.43 |
| gemma2 9B Q4_0 | RTX 4090 | 1 | pp4096 | 114.38 | 129.59 | 1.13 |
| gemma2 9B Q4_0 | RTX 4090 | 2 | pp4096 | 217.58 | 232.48 | 1.07 |
| gemma2 9B Q4_0 | RTX 4090 | 4 | pp4096 | 428.54 | 456.03 | 1.06 |
| gemma2 9B Q4_0 | RTX 4090 | 8 | pp4096 | 719.26 | 772.48 | 1.07 |
| gemma2 9B Q4_0 | RTX 4090 | 16 | pp4096 | 1317.77 | 1402.33 | 1.06 |
| gemma2 9B Q4_0 | RTX 4090 | 32 | pp4096 | 2346.22 | 2568.79 | 1.09 |
| gemma2 9B Q4_0 | RTX 4090 | 64 | pp4096 | 3638.49 | 4005.28 | 1.10 |
| gemma2 9B Q4_0 | RTX 4090 | 128 | pp4096 | 5118.29 | 5714.04 | 1.12 |
| gemma2 9B Q4_0 | RTX 4090 | 256 | pp4096 | 6361.31 | 7390.26 | 1.16 |
| gemma2 9B Q4_0 | RTX 4090 | 512 | pp4096 | 5608.05 | 8214.40 | 1.46 |
| gemma2 9B Q4_0 | RTX 4090 | 1024 | pp4096 | 5171.15 | 8133.34 | 1.57 |
| gemma2 9B Q4_0 | RTX 4090 | 2048 | pp4096 | 4447.58 | 7590.81 | 1.71 |
| gemma2 9B Q4_0 | RTX 4090 | 4096 | pp4096 | 3523.21 | 6467.16 | 1.84 |
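The Speedup column is the FlashAttention throughput divided by the eager-attention throughput (for example 13.79 / 10.96 ≈ 1.26 for the RX 6800 at microbatch size 1). A small C sketch, with hypothetical function names, that computes it and finds the first microbatch size where FlashAttention is slower, using rows copied from the table:

```c
#include <assert.h>

/* Speedup as used in the table: FlashAttention t/s divided by eager t/s. */
static double speedup(double tps_eager, double tps_fa) {
    assert(tps_eager > 0.0);
    return tps_fa / tps_eager;
}

/* Returns the first microbatch size at which FlashAttention is slower than
   eager attention, or -1 if it never is. Rows are ordered by microbatch size. */
static int first_slower_ubatch(const int *ubatch, const double *eager,
                               const double *fa, int n) {
    for (int i = 0; i < n; ++i) {
        if (speedup(eager[i], fa[i]) < 1.0) {
            return ubatch[i];
        }
    }
    return -1;
}
```

Applied to the RX 6800 rows this returns 8, the first row with a speedup below 1 (0.87); for the RTX 3090 and RTX 4090 rows it returns -1.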

JohannesGaessler, Jul 17 '24 14:07

Currently, the Metal kernels always use F32 accumulators regardless of the selected precision, so this should not be a problem.

You can fix the failing Metal tests like this:

diff --git a/ggml/src/ggml-metal.m b/ggml/src/ggml-metal.m
index b5939efa..d16f1c6e 100644
--- a/ggml/src/ggml-metal.m
+++ b/ggml/src/ggml-metal.m
@@ -798,6 +798,15 @@ static bool ggml_metal_supports_op(const struct ggml_metal_context * ctx, const
             if (op->src[0]->ne[0] == 256) {
                 return false;
             }
+            {
+                float logit_softcap;
+
+                memcpy(&logit_softcap, ((const int32_t *) op->op_params) + 2, sizeof(logit_softcap));
+
+                if (logit_softcap != 0.0f) {
+                    return false;
+                }
+            }
             return ctx->support_simdgroup_mm; // TODO: over-restricted for vec-kernels
         case GGML_OP_MUL_MAT:
         case GGML_OP_MUL_MAT_ID:
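The diff reads logit_softcap as a float stored bit-for-bit in the third int32_t slot of op_params. A minimal C sketch of that packing follows; it assumes the float parameters scale, max_bias and logit_softcap occupy slots 0 to 2 and the precision flag sits in slot 3. Only the slot of logit_softcap is confirmed by the diff; the rest of the layout, the array size and all helper names are assumptions for illustration, not the ggml API.

```c
#include <assert.h>
#include <stdint.h>
#include <string.h>

#define N_OP_PARAMS 16  /* hypothetical size of the op_params array */

/* Each float must fit exactly into one int32 slot. */
static_assert(sizeof(float) == sizeof(int32_t), "float must fit in one op_params slot");

/* Store the floats bit-for-bit, mirroring the memcpy in the diff.
   Assumed layout: [0] scale, [1] max_bias, [2] logit_softcap. */
static void fa_pack_params(int32_t *op_params, float scale, float max_bias,
                           float logit_softcap) {
    const float params[3] = { scale, max_bias, logit_softcap };
    memcpy(op_params, params, sizeof(params));
}

/* Assumed: the precision flag moved to slot 3 once logit_softcap took slot 2. */
static void fa_set_prec(int32_t *op_params, int32_t prec) {
    op_params[3] = prec;
}

/* Same read as in the Metal diff: reinterpret slot 2 as a float. */
static float fa_get_logit_softcap(const int32_t *op_params) {
    float logit_softcap;
    memcpy(&logit_softcap, op_params + 2, sizeof(logit_softcap));
    return logit_softcap;
}

static int32_t fa_get_prec(const int32_t *op_params) {
    return op_params[3];
}
```

Under this layout, code that still read the precision from slot 2 would see the bit pattern of logit_softcap instead, which is the kind of breakage the PR description was concerned about.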

We'll implement support in the future.

ggerganov, Jul 17 '24 15:07

It seems that with Gemma 2, Flash Attention combined with a quantized KV cache and partial offloading slows down prompt processing a lot. Is this expected behavior? FA without a quantized KV cache is fine, though.

Dampfinchen, Jul 18 '24 19:07

I didn't test that particular case, do you have similar issues with other models or only with Gemma 2?

JohannesGaessler, Jul 18 '24 19:07

Only with Gemma 2; Llama 3 8B is unaffected.

FA without qKV cache (-n 180 -c 4096 -t 6 --gpu-layers 25 --ignore-eos -fa), RTX 2060, i7 9750H, Gemma 2 Q4_K_S

llama_print_timings:        load time =    3774.87 ms
llama_print_timings:      sample time =      44.50 ms /   180 runs   (    0.25 ms per token,  4044.76 tokens per second)
llama_print_timings: prompt eval time =    6291.04 ms /  3353 tokens (    1.88 ms per token,   532.98 tokens per second)
llama_print_timings:        eval time =   37701.14 ms /   179 runs   (  210.62 ms per token,     4.75 tokens per second)
llama_print_timings:       total time =   44200.76 ms /  3532 tokens 

FA with qKV cache (same as above + -ctk q8_0 -ctv q8_0)

llama_print_timings:        load time =    3469.10 ms
llama_print_timings:      sample time =      43.93 ms /   180 runs   (    0.24 ms per token,  4097.89 tokens per second)
llama_print_timings: prompt eval time =  117637.97 ms /  3353 tokens (   35.08 ms per token,    28.50 tokens per second)
llama_print_timings:        eval time =   49843.21 ms /   179 runs   (  278.45 ms per token,     3.59 tokens per second)
llama_print_timings:       total time =  167677.84 ms /  3532 tokens 

Llama 3 8B Q4_K_S (17 GPU layers)

FA without qKV cache

llama_print_timings:        load time =    3028.67 ms
llama_print_timings:      sample time =      21.87 ms /   180 runs   (    0.12 ms per token,  8229.32 tokens per second)
llama_print_timings: prompt eval time =    5523.65 ms /  3671 tokens (    1.50 ms per token,   664.60 tokens per second)
llama_print_timings:        eval time =   30556.35 ms /   179 runs   (  170.71 ms per token,     5.86 tokens per second)
llama_print_timings:       total time =   36237.96 ms /  3850 tokens 

FA with qKV cache

llama_print_timings:        load time =    3055.69 ms
llama_print_timings:      sample time =      19.35 ms /   180 runs   (    0.11 ms per token,  9304.25 tokens per second)
llama_print_timings: prompt eval time =    5536.41 ms /  3671 tokens (    1.51 ms per token,   663.07 tokens per second)
llama_print_timings:        eval time =   26197.26 ms /   179 runs   (  146.35 ms per token,     6.83 tokens per second)
llama_print_timings:       total time =   31884.67 ms /  3850 tokens

Here is a comparison of these two models (partially offloaded). As you can see, with Gemma 2, as soon as you quantize the KV cache, generation and especially prompt processing slow down a lot, which is not the case with Llama 3 8B (text generation even speeds up nicely).
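The size of the regression can be read directly off the prompt eval lines, since tokens per second is tokens / (ms / 1000). A small C sketch with hypothetical helper names that reproduces the reported rates and the resulting slowdown factors:

```c
#include <assert.h>

/* Throughput from a llama_print_timings line: total milliseconds and token count. */
static double tokens_per_second(double ms, int tokens) {
    assert(ms > 0.0 && tokens > 0);
    return tokens * 1000.0 / ms;
}

/* Factor by which the second run is slower than the first (> 1 means slower). */
static double slowdown(double tps_before, double tps_after) {
    assert(tps_after > 0.0);
    return tps_before / tps_after;
}
```

With the numbers above, Gemma 2 prompt processing with the q8_0 KV cache is roughly 18.7x slower (532.98 vs 28.50 t/s), while Llama 3 8B changes by well under 1% (664.60 vs 663.07 t/s).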

Dampfinchen, Jul 18 '24 19:07

I found some weird problems using -fa with this branch.

Testing the imatrix Q2 and Q3 quants of the Tiger-Gemma 27b model, -fa completely broke the responses, with the model answering nonsense to any kind of prompt. Running without flash-attn seems to work normally. I have not tested the 9b version.

Interestingly enough, using the Alpaca template worked with flash-attn, but with the tokenization provided by Gemma it would generate an infinite response of "Manneur Manneur Manneur Manneur Manneur..."

I can also confirm the slow prompt processing mentioned by Dampfinchen on lower quants.

vitorfdl, Jul 22 '24 20:07

There are potentially numerical issues that only appear with specific models. I'm assuming you downloaded the models from Hugging Face; can you link them? Also, do the non-imatrix models work correctly?

JohannesGaessler, Jul 23 '24 09:07

https://huggingface.co/mradermacher/Big-Tiger-Gemma-27B-v1-i1-GGUF

I've tested the gemma-27b as well https://huggingface.co/mradermacher/gemma-2-27b-it-i1-GGUF

The issue I mentioned happens with all of them when using -fa. I have not tested the non-imatrix versions.

vitorfdl, Jul 23 '24 13:07

If nobody else is going to ask, I will! Has this PR been overlooked?

Rotatingxenomorph, Aug 08 '24 15:08

Like @vitorfdl, since updating after this merge, Gemma 2 27B refuses to work for me with -fa.

Every response is:

<pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad>

I'm using Bartowski's gemma-2-27b-it-Q6_K imatrix quant.

sampling params
llama_new_context_with_model: n_ctx      = 8192
llama_new_context_with_model: n_batch    = 2048
llama_new_context_with_model: n_ubatch   = 512
llama_new_context_with_model: flash_attn = 1
llama_new_context_with_model: freq_base  = 10000.0
llama_new_context_with_model: freq_scale = 1
llama_kv_cache_init:  CUDA_Host KV buffer size =   704.00 MiB
llama_kv_cache_init:      CUDA0 KV buffer size =  2240.00 MiB
llama_new_context_with_model: KV self size  = 2944.00 MiB, K (f16): 1472.00 MiB, V (f16): 1472.00 MiB
llama_new_context_with_model:  CUDA_Host  output buffer size =     0.98 MiB
llama_new_context_with_model:      CUDA0 compute buffer size =  1431.85 MiB
llama_new_context_with_model:  CUDA_Host compute buffer size =    41.01 MiB
llama_new_context_with_model: graph nodes  = 1530
llama_new_context_with_model: graph splits = 147

system_info: n_threads = 8 / 16 | AVX = 1 | AVX_VNNI = 0 | AVX2 = 1 | AVX512 = 0 | AVX512_VBMI = 0 | AVX512_VNNI = 0 | AVX512_BF16 = 0 | FMA = 1 | NEON = 0 | SVE = 0 | ARM_FMA = 0 | F16C = 1 | FP16_VA = 0 | WASM_SIMD = 0 | BLAS = 1 | SSE3 = 1 | SSSE3 = 1 | VSX = 0 | MATMUL_INT8 = 0 | LLAMAFILE = 1 |
sampling:
        repeat_last_n = 64, repeat_penalty = 1.100, frequency_penalty = 0.000, presence_penalty = 0.000
        top_k = 40, tfs_z = 1.000, top_p = 0.950, min_p = 0.050, typical_p = 1.000, temp = 0.000
        mirostat = 0, mirostat_lr = 0.100, mirostat_ent = 5.000
sampling order:
CFG -> Penalties -> top_k -> tfs_z -> typical_p -> top_p -> min_p -> temperature
generate: n_ctx = 8192, n_batch = 2048, n_predict = -1, n_keep = 1

Example prompt:

<start_of_turn>user
how many squares are on a chessboard?<end_of_turn>
<start_of_turn>model

strawberrymelonpanda, Aug 24 '24 21:08

I am unable to reproduce the issue. What hardware are you using? Can you post the exact command with which the problem occurs?

JohannesGaessler, Aug 24 '24 23:08

WSL2, Ubuntu 22.04.4 LTS on Windows 11, NVIDIA GeForce RTX 3090. Built with make clean; make GGML_CUDA=1

./llama-cli --model "gemma-2-27b-it-imat-Q6_K (bartowski).gguf" --prompt "<start_of_turn>user\nhow many squares are on a chessboard?<end_of_turn>\n<start_of_turn>model\n" --gpu-layers 99 --ctx-size 512 --temp 0.0 --repeat-penalty 1.1 --seed 1 --flash-attn --verbose

result:

<pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad>

This continues until I cancel (without --special there is no visible output).

Remove --flash-attn and it works as expected:

This is a classic riddle! There are more squares on a chessboard than you might initially think. Here's how to figure it [...]

With Gemma 2, it seems to come down to whether -ngl is used together with --flash-attn.

Simplified settings with FA and without -ngl:

./llama-cli --model "gemma-2-27b-it-imat-Q6_K (bartowski).gguf" --prompt "<start_of_turn>user\nhow many squares are on a chessboard?<end_of_turn>\n<start_of_turn>model\n" --verbose --special --flash-attn
This is a classic riddle! [...]

vs FA with -ngl 99 added:

./llama-cli --model "gemma-2-27b-it-imat-Q6_K (bartowski).gguf" --prompt "<start_of_turn>user\nhow many squares are on a chessboard?<end_of_turn>\n<start_of_turn>model\n" --verbose --special --flash-attn -ngl 99
<unused31><unused31><unused31><unused31><unused31><unused31><unused31><unused31><unused31><unused31><unused31><unused31><unused31><unused31><unused31><unused31><unused31><unused31><unused31><unused31><unused31>

strawberrymelonpanda, Aug 24 '24 23:08

It also works with partial offloading, up to a point:

-ngl 45 and below works; -ngl 46 and above does not (repeated <unused...> tokens).

strawberrymelonpanda, Aug 24 '24 23:08

It also happens with Metal in #9159.

slaren, Aug 25 '24 01:08

Thank you, I was able to reproduce the issue. Setting the precision to FP32 fixes it for me: https://github.com/ggerganov/llama.cpp/pull/9166 . I vaguely recall that the Metal FA implementation always uses FP32 precision anyway, though.
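The thread does not state why FP16 precision breaks Gemma 2 here. One plausible mechanism, offered only as an assumption, is that intermediate KQ values exceed the FP16 range (the largest finite IEEE 754 half-precision value is 65504), which an FP32 accumulator avoids. A hedged C sketch, with a hypothetical function name, that checks whether a dot product's running sum would leave that range:

```c
#include <assert.h>
#include <math.h>
#include <stdbool.h>

#define FP16_MAX_FINITE 65504.0f  /* largest finite IEEE 754 binary16 value */

/* Returns true if any running partial sum of q·k exceeds the finite FP16
   range, i.e. an FP16 accumulator would (approximately) overflow to +/-inf.
   The accumulation itself is done in FP32 so the overflow can be detected. */
static bool fp16_accum_overflows(const float *q, const float *k, int n) {
    assert(n >= 0);
    float acc = 0.0f;
    for (int i = 0; i < n; ++i) {
        acc += q[i] * k[i];
        if (fabsf(acc) > FP16_MAX_FINITE) {
            return true;
        }
    }
    return false;
}
```

This is only a range check carried out in FP32; it ignores FP16 rounding, and a real FP16 accumulator also loses precision well before it overflows.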

JohannesGaessler, Aug 25 '24 07:08