CPU/CUDA: Gemma 2 FlashAttention support
This PR adds FlashAttention support for Gemma 2 on the CPU and CUDA backends by adding another parameter that controls the logit softcap. A value of 0.0f indicates no logit softcapping. When a different value is specified, the scale parameter is divided by it, and prior to the softmax the hyperbolic tangent is applied and the result is scaled by logit_softcap. Because this changes the position of the parameter that indicates the precision needed for FlashAttention, I think this PR breaks the Metal FlashAttention implementation (I am not aware of any other FlashAttention implementations). I tried searching for the spot where the Metal code retrieves the FlashAttention precision but was not able to find it; help would be very much appreciated.
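For illustration, here is a minimal sketch of the per-logit transformation; the helper name is hypothetical and this is not the actual kernel code:

```cpp
#include <math.h>

// Minimal sketch (hypothetical helper, not the actual kernel code) of how a
// single attention score is transformed. With softcapping enabled, the scale
// is divided by logit_softcap, tanh is applied, and the result is multiplied
// by logit_softcap again before it enters the softmax.
static float softcapped_logit(float kq, float scale, float logit_softcap) {
    if (logit_softcap == 0.0f) {
        return kq * scale;  // 0.0f disables softcapping
    }
    return logit_softcap * tanhf(kq * (scale / logit_softcap));
}
```

This bounds every logit to the interval (-logit_softcap, +logit_softcap), which is what Gemma 2's attention expects.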
This PR adds template instances with logit softcapping for the CUDA FlashAttention kernels for head sizes 128 (Gemma 2 27b) and 256 (Gemma 2 9b). For all other head sizes no instances are compiled since (to my knowledge) there are no models that would use them. The FlashAttention kernels that do not use FP16 tensor cores only support head sizes 64 and 128, so Gemma 2 9b is only supported on NVIDIA GPUs with compute capability >= 7.0. For simplicity, tests/test-backend-ops.cpp only checks head size 128 (which should be enough since logit softcapping is a scalar operation and does not depend on head size).
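In other words, whether a FlashAttention path with softcapping exists now depends on the head size; a rough sketch of that availability check (hypothetical helper, not the actual dispatch code):

```cpp
// Rough sketch (hypothetical helper, not the actual dispatch code): with
// logit softcapping enabled, only head sizes 128 and 256 have compiled
// FlashAttention instances; everything else falls back to regular attention.
static bool fa_softcap_instance_available(int head_size, float logit_softcap) {
    if (logit_softcap == 0.0f) {
        return true; // softcap-free instances are unchanged by this PR
    }
    return head_size == 128 || head_size == 256;
}
```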
GPU Performance
| Model | GPU | Microbatch size | Test | t/s eager attention | t/s FlashAttention | Speedup |
|---|---|---|---|---|---|---|
| gemma2 27B Q2_K_M | RX 6800 | 1 | pp4096 | 10.96 | 13.79 | 1.26 |
| gemma2 27B Q2_K_M | RX 6800 | 2 | pp4096 | 17.34 | 20.48 | 1.18 |
| gemma2 27B Q2_K_M | RX 6800 | 4 | pp4096 | 21.76 | 22.31 | 1.03 |
| gemma2 27B Q2_K_M | RX 6800 | 8 | pp4096 | 25.90 | 22.50 | 0.87 |
| gemma2 27B Q2_K_M | RX 6800 | 16 | pp4096 | 58.28 | 40.61 | 0.70 |
| gemma2 27B Q2_K_M | RX 6800 | 32 | pp4096 | 84.49 | 66.19 | 0.78 |
| gemma2 27B Q2_K_M | RX 6800 | 64 | pp4096 | 94.47 | 69.94 | 0.74 |
| gemma2 27B Q2_K_M | RX 6800 | 128 | pp4096 | 112.39 | 84.81 | 0.75 |
| gemma2 27B Q2_K_M | RX 6800 | 256 | pp4096 | 126.55 | 98.17 | 0.78 |
| gemma2 27B Q2_K_M | RX 6800 | 512 | pp4096 | 126.16 | 102.64 | 0.81 |
| gemma2 27B Q2_K_M | RX 6800 | 1024 | pp4096 | 134.20 | 100.69 | 0.75 |
| gemma2 27B Q2_K_M | RX 6800 | 2048 | pp4096 | 130.98 | 95.72 | 0.73 |
| gemma2 27B Q4_0 | RTX 3090 | 1 | pp4096 | 41.22 | 44.08 | 1.07 |
| gemma2 27B Q4_0 | RTX 3090 | 2 | pp4096 | 73.88 | 80.99 | 1.10 |
| gemma2 27B Q4_0 | RTX 3090 | 4 | pp4096 | 115.02 | 125.48 | 1.09 |
| gemma2 27B Q4_0 | RTX 3090 | 8 | pp4096 | 145.30 | 161.40 | 1.11 |
| gemma2 27B Q4_0 | RTX 3090 | 16 | pp4096 | 318.50 | 406.84 | 1.28 |
| gemma2 27B Q4_0 | RTX 3090 | 32 | pp4096 | 585.99 | 654.12 | 1.12 |
| gemma2 27B Q4_0 | RTX 3090 | 64 | pp4096 | 743.71 | 859.31 | 1.16 |
| gemma2 27B Q4_0 | RTX 3090 | 128 | pp4096 | 913.71 | 1092.52 | 1.20 |
| gemma2 27B Q4_0 | RTX 3090 | 256 | pp4096 | 986.56 | 1206.55 | 1.22 |
| gemma2 27B Q4_0 | RTX 3090 | 512 | pp4096 | 992.06 | 1230.06 | 1.24 |
| gemma2 27B Q4_0 | RTX 3090 | 1024 | pp4096 | 978.00 | 1247.45 | 1.28 |
| gemma2 27B Q4_0 | RTX 3090 | 2048 | pp4096 | 942.77 | 1231.07 | 1.31 |
| gemma2 27B Q4_0 | RTX 3090 | 4096 | pp4096 | 850.81 | 1189.47 | 1.40 |
| gemma2 27B Q4_0 | RTX 4090 | 1 | pp4096 | 51.31 | 54.78 | 1.07 |
| gemma2 27B Q4_0 | RTX 4090 | 2 | pp4096 | 99.17 | 103.45 | 1.04 |
| gemma2 27B Q4_0 | RTX 4090 | 4 | pp4096 | 194.06 | 203.43 | 1.05 |
| gemma2 27B Q4_0 | RTX 4090 | 8 | pp4096 | 324.52 | 343.86 | 1.06 |
| gemma2 27B Q4_0 | RTX 4090 | 16 | pp4096 | 601.44 | 640.22 | 1.06 |
| gemma2 27B Q4_0 | RTX 4090 | 32 | pp4096 | 1081.21 | 1173.60 | 1.09 |
| gemma2 27B Q4_0 | RTX 4090 | 64 | pp4096 | 1737.82 | 1934.71 | 1.11 |
| gemma2 27B Q4_0 | RTX 4090 | 128 | pp4096 | 2325.77 | 2735.54 | 1.18 |
| gemma2 27B Q4_0 | RTX 4090 | 256 | pp4096 | 2237.85 | 3151.80 | 1.41 |
| gemma2 27B Q4_0 | RTX 4090 | 512 | pp4096 | 2154.18 | 3224.92 | 1.50 |
| gemma2 27B Q4_0 | RTX 4090 | 1024 | pp4096 | 2053.00 | 3218.63 | 1.57 |
| gemma2 27B Q4_0 | RTX 4090 | 2048 | pp4096 | 1839.51 | 3112.09 | 1.69 |
| gemma2 27B Q4_0 | RTX 4090 | 4096 | pp4096 | 1525.74 | 2873.14 | 1.88 |
| gemma2 27B Q4_0 | P40 | 1 | pp4096 | 14.45 | 13.63 | 0.94 |
| gemma2 27B Q4_0 | P40 | 2 | pp4096 | 19.53 | 25.06 | 1.28 |
| gemma2 27B Q4_0 | P40 | 4 | pp4096 | 31.14 | 34.85 | 1.12 |
| gemma2 27B Q4_0 | P40 | 8 | pp4096 | 36.98 | 39.83 | 1.08 |
| gemma2 27B Q4_0 | P40 | 16 | pp4096 | 82.56 | 144.28 | 1.75 |
| gemma2 27B Q4_0 | P40 | 32 | pp4096 | 127.71 | 193.58 | 1.52 |
| gemma2 27B Q4_0 | P40 | 64 | pp4096 | 165.50 | 213.45 | 1.29 |
| gemma2 27B Q4_0 | P40 | 128 | pp4096 | 206.77 | 245.01 | 1.18 |
| gemma2 27B Q4_0 | P40 | 256 | pp4096 | 236.71 | 267.31 | 1.13 |
| gemma2 27B Q4_0 | P40 | 512 | pp4096 | 248.61 | 272.20 | 1.09 |
| gemma2 27B Q4_0 | P40 | 1024 | pp4096 | 248.30 | 270.26 | 1.09 |
| gemma2 27B Q4_0 | P40 | 2048 | pp4096 | 245.07 | 265.91 | 1.09 |
| gemma2 27B Q4_0 | P40 | 4096 | pp4096 | 229.73 | 253.44 | 1.10 |
| gemma2 9B Q4_0 | RTX 3090 | 1 | pp4096 | 87.69 | 101.38 | 1.16 |
| gemma2 9B Q4_0 | RTX 3090 | 2 | pp4096 | 157.68 | 173.97 | 1.10 |
| gemma2 9B Q4_0 | RTX 3090 | 4 | pp4096 | 257.40 | 291.47 | 1.13 |
| gemma2 9B Q4_0 | RTX 3090 | 8 | pp4096 | 348.79 | 401.57 | 1.15 |
| gemma2 9B Q4_0 | RTX 3090 | 16 | pp4096 | 651.44 | 883.03 | 1.36 |
| gemma2 9B Q4_0 | RTX 3090 | 32 | pp4096 | 1284.64 | 1430.69 | 1.11 |
| gemma2 9B Q4_0 | RTX 3090 | 64 | pp4096 | 1703.04 | 1992.60 | 1.17 |
| gemma2 9B Q4_0 | RTX 3090 | 128 | pp4096 | 2151.58 | 2574.37 | 1.20 |
| gemma2 9B Q4_0 | RTX 3090 | 256 | pp4096 | 2408.43 | 3024.58 | 1.26 |
| gemma2 9B Q4_0 | RTX 3090 | 512 | pp4096 | 2460.62 | 3084.09 | 1.25 |
| gemma2 9B Q4_0 | RTX 3090 | 1024 | pp4096 | 2469.48 | 3215.71 | 1.30 |
| gemma2 9B Q4_0 | RTX 3090 | 2048 | pp4096 | 2348.39 | 3138.11 | 1.34 |
| gemma2 9B Q4_0 | RTX 3090 | 4096 | pp4096 | 2093.35 | 2990.59 | 1.43 |
| gemma2 9B Q4_0 | RTX 4090 | 1 | pp4096 | 114.38 | 129.59 | 1.13 |
| gemma2 9B Q4_0 | RTX 4090 | 2 | pp4096 | 217.58 | 232.48 | 1.07 |
| gemma2 9B Q4_0 | RTX 4090 | 4 | pp4096 | 428.54 | 456.03 | 1.06 |
| gemma2 9B Q4_0 | RTX 4090 | 8 | pp4096 | 719.26 | 772.48 | 1.07 |
| gemma2 9B Q4_0 | RTX 4090 | 16 | pp4096 | 1317.77 | 1402.33 | 1.06 |
| gemma2 9B Q4_0 | RTX 4090 | 32 | pp4096 | 2346.22 | 2568.79 | 1.09 |
| gemma2 9B Q4_0 | RTX 4090 | 64 | pp4096 | 3638.49 | 4005.28 | 1.10 |
| gemma2 9B Q4_0 | RTX 4090 | 128 | pp4096 | 5118.29 | 5714.04 | 1.12 |
| gemma2 9B Q4_0 | RTX 4090 | 256 | pp4096 | 6361.31 | 7390.26 | 1.16 |
| gemma2 9B Q4_0 | RTX 4090 | 512 | pp4096 | 5608.05 | 8214.40 | 1.46 |
| gemma2 9B Q4_0 | RTX 4090 | 1024 | pp4096 | 5171.15 | 8133.34 | 1.57 |
| gemma2 9B Q4_0 | RTX 4090 | 2048 | pp4096 | 4447.58 | 7590.81 | 1.71 |
| gemma2 9B Q4_0 | RTX 4090 | 4096 | pp4096 | 3523.21 | 6467.16 | 1.84 |
Currently, the Metal kernels always use F32 accumulators regardless of the selected precision, so this should not be a problem.
You can fix the failing Metal tests like this:
```diff
diff --git a/ggml/src/ggml-metal.m b/ggml/src/ggml-metal.m
index b5939efa..d16f1c6e 100644
--- a/ggml/src/ggml-metal.m
+++ b/ggml/src/ggml-metal.m
@@ -798,6 +798,15 @@ static bool ggml_metal_supports_op(const struct ggml_metal_context * ctx, const
             if (op->src[0]->ne[0] == 256) {
                 return false;
             }
+            {
+                float logit_softcap;
+
+                memcpy(&logit_softcap, ((const int32_t *) op->op_params) + 2, sizeof(logit_softcap));
+
+                if (logit_softcap != 0.0f) {
+                    return false;
+                }
+            }
             return ctx->support_simdgroup_mm; // TODO: over-restricted for vec-kernels
         case GGML_OP_MUL_MAT:
         case GGML_OP_MUL_MAT_ID:
```
We'll implement support in the future.
It seems Gemma 2 with FlashAttention and a quantized KV cache + partial offloading slows down prompt processing a lot. Is this expected behavior? FA without a quantized KV cache is fine though.
I didn't test that particular case; do you have similar issues with other models, or only with Gemma 2?
> I didn't test that particular case; do you have similar issues with other models, or only with Gemma 2?
Yes, Llama 3 8B is unaffected.
FA without quantized KV cache (`-n 180 -c 4096 -t 6 --gpu-layers 25 --ignore-eos -fa`), RTX 2060, i7 9750H, Gemma 2 Q4_K_S:

```
llama_print_timings: load time = 3774.87 ms
llama_print_timings: sample time = 44.50 ms / 180 runs ( 0.25 ms per token, 4044.76 tokens per second)
llama_print_timings: prompt eval time = 6291.04 ms / 3353 tokens ( 1.88 ms per token, 532.98 tokens per second)
llama_print_timings: eval time = 37701.14 ms / 179 runs ( 210.62 ms per token, 4.75 tokens per second)
llama_print_timings: total time = 44200.76 ms / 3532 tokens
```

FA with quantized KV cache (same as above plus `-ctk q8_0 -ctv q8_0`):

```
llama_print_timings: load time = 3469.10 ms
llama_print_timings: sample time = 43.93 ms / 180 runs ( 0.24 ms per token, 4097.89 tokens per second)
llama_print_timings: prompt eval time = 117637.97 ms / 3353 tokens ( 35.08 ms per token, 28.50 tokens per second)
llama_print_timings: eval time = 49843.21 ms / 179 runs ( 278.45 ms per token, 3.59 tokens per second)
llama_print_timings: total time = 167677.84 ms / 3532 tokens
```

Llama 3 8B Q4_K_S (17 GPU layers)

FA without quantized KV cache:

```
llama_print_timings: load time = 3028.67 ms
llama_print_timings: sample time = 21.87 ms / 180 runs ( 0.12 ms per token, 8229.32 tokens per second)
llama_print_timings: prompt eval time = 5523.65 ms / 3671 tokens ( 1.50 ms per token, 664.60 tokens per second)
llama_print_timings: eval time = 30556.35 ms / 179 runs ( 170.71 ms per token, 5.86 tokens per second)
llama_print_timings: total time = 36237.96 ms / 3850 tokens
```

FA with quantized KV cache:

```
llama_print_timings: load time = 3055.69 ms
llama_print_timings: sample time = 19.35 ms / 180 runs ( 0.11 ms per token, 9304.25 tokens per second)
llama_print_timings: prompt eval time = 5536.41 ms / 3671 tokens ( 1.51 ms per token, 663.07 tokens per second)
llama_print_timings: eval time = 26197.26 ms / 179 runs ( 146.35 ms per token, 6.83 tokens per second)
llama_print_timings: total time = 31884.67 ms / 3850 tokens
```
Here's a comparison for these two models (partially offloaded). As you can see, with Gemma 2 the generation speed and especially the prompt processing speed slow down a lot as soon as you quantize the KV cache, which is not the case with Llama 3 8B (text generation speed even increases nicely).
I found some weird problems using -fa with this branch.
Testing the imatrix versions at Q2 and Q3 of the Tiger-Gemma 27b model completely broke the responses, with the model answering nonsense for any kind of prompt, while running without flash-attn seems to work normally. I've not tested the 9b version.
Interestingly enough, using the Alpaca template worked with flash-attn, but with the tokenization provided by Gemma it would generate an infinite response of "Manneur Manneur Manneur Manneur Manneur..."
I also confirmed the slow token initialization mentioned by Dampfinchen on lower quants.
There are potentially numerical issues that only appear with specific models. I'm assuming you downloaded the models off Hugging Face; can you link them? Also, do the non-imatrix models work correctly?
> There are potentially numerical issues that only appear with specific models. I'm assuming you downloaded the models off Hugging Face; can you link them? Also, do the non-imatrix models work correctly?
https://huggingface.co/mradermacher/Big-Tiger-Gemma-27B-v1-i1-GGUF
I've also tested gemma-2-27b: https://huggingface.co/mradermacher/gemma-2-27b-it-i1-GGUF
The issue I mentioned happens with all of them when using -fa.
I've not tested non-imatrix versions.
If nobody else is going to ask, I will! Has this PR been overlooked?
> Testing the imatrix versions at Q2 and Q3 of the Tiger-Gemma 27b model completely broke the responses, with the model answering nonsense for any kind of prompt, while running without flash-attn seems to work normally. I've not tested the 9b version.
> Interestingly enough, using the Alpaca template worked with flash-attn, but with the tokenization provided by Gemma it would generate an infinite response of "Manneur Manneur Manneur Manneur Manneur..."
Like @vitorfdl, since updating after this merge, Gemma 2 27B refuses to work for me with -fa.
Every response is:
```
<pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad>
```
I'm using Bartowski's gemma-2-27b-it-Q6_K imatrix quant.
sampling params
```
llama_new_context_with_model: n_ctx = 8192
llama_new_context_with_model: n_batch = 2048
llama_new_context_with_model: n_ubatch = 512
llama_new_context_with_model: flash_attn = 1
llama_new_context_with_model: freq_base = 10000.0
llama_new_context_with_model: freq_scale = 1
llama_kv_cache_init: CUDA_Host KV buffer size = 704.00 MiB
llama_kv_cache_init: CUDA0 KV buffer size = 2240.00 MiB
llama_new_context_with_model: KV self size = 2944.00 MiB, K (f16): 1472.00 MiB, V (f16): 1472.00 MiB
llama_new_context_with_model: CUDA_Host output buffer size = 0.98 MiB
llama_new_context_with_model: CUDA0 compute buffer size = 1431.85 MiB
llama_new_context_with_model: CUDA_Host compute buffer size = 41.01 MiB
llama_new_context_with_model: graph nodes = 1530
llama_new_context_with_model: graph splits = 147
system_info: n_threads = 8 / 16 | AVX = 1 | AVX_VNNI = 0 | AVX2 = 1 | AVX512 = 0 | AVX512_VBMI = 0 | AVX512_VNNI = 0 | AVX512_BF16 = 0 | FMA = 1 | NEON = 0 | SVE = 0 | ARM_FMA = 0 | F16C = 1 | FP16_VA = 0 | WASM_SIMD = 0 | BLAS = 1 | SSE3 = 1 | SSSE3 = 1 | VSX = 0 | MATMUL_INT8 = 0 | LLAMAFILE = 1 |
sampling:
repeat_last_n = 64, repeat_penalty = 1.100, frequency_penalty = 0.000, presence_penalty = 0.000
top_k = 40, tfs_z = 1.000, top_p = 0.950, min_p = 0.050, typical_p = 1.000, temp = 0.000
mirostat = 0, mirostat_lr = 0.100, mirostat_ent = 5.000
sampling order:
CFG -> Penalties -> top_k -> tfs_z -> typical_p -> top_p -> min_p -> temperature
generate: n_ctx = 8192, n_batch = 2048, n_predict = -1, n_keep = 1
```
Example prompt:
```
<start_of_turn>user
how many squares are on a chessboard?<end_of_turn>
<start_of_turn>model
```
I am unable to reproduce the issue. What hardware are you using? Can you post the exact command with which the problem occurs?
WSL2, Ubuntu 22.04.4 LTS on Windows 11, Nvidia Geforce 3090.
Built with make clean; make GGML_CUDA=1
```
./llama-cli --model "gemma-2-27b-it-imat-Q6_K (bartowski).gguf" --prompt "<start_of_turn>user\nhow many squares are on a chessboard?<end_of_turn>\n<start_of_turn>model\n" --gpu-layers 99 --ctx-size 512 --temp 0.0 --repeat-penalty 1.1 --seed 1 --flash-attn --verbose
```

result:

```
<pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad>
```
until I cancel. (Without --special there's no output)
Remove --flash-attn and it works as expected:
This is a classic riddle! There are more squares on a chessboard than you might initially think. Here's how to figure it [...]
It seems to come down to whether `-ngl` is used together with `--flash-attn` on Gemma 2.
Simplified settings with FA and without `-ngl`:
```
./llama-cli --model "gemma-2-27b-it-imat-Q6_K (bartowski).gguf" --prompt "<start_of_turn>user\nhow many squares are on a chessboard?<end_of_turn>\n<start_of_turn>model\n" --verbose --special --flash-attn
```

This is a classic riddle! [...]
vs FA with -ngl 99 added:
```
./llama-cli --model "gemma-2-27b-it-imat-Q6_K (bartowski).gguf" --prompt "<start_of_turn>user\nhow many squares are on a chessboard?<end_of_turn>\n<start_of_turn>model\n" --verbose --special --flash-attn -ngl 99
```

```
<unused31><unused31><unused31><unused31><unused31><unused31><unused31><unused31><unused31><unused31><unused31><unused31><unused31><unused31><unused31><unused31><unused31><unused31><unused31><unused31><unused31>
```
It seems it also works with partial offloading: `-ngl 45` and below works; `-ngl 46` and above does not (repeated `<unused...>` output).
It also happens with Metal in #9159.
Thank you, I was able to reproduce the issue. Setting the precision to FP32 fixes it for me: https://github.com/ggerganov/llama.cpp/pull/9166 . I vaguely recall that the Metal FA implementation always uses FP32 precision anyway, though.
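For reference, a fix along those lines amounts to forcing FP32 precision on the FlashAttention node when the graph is built for Gemma 2. A rough sketch of the idea follows; `ggml_flash_attn_ext_set_prec` and `GGML_PREC_F32` are the ggml names, but the surrounding variables are illustrative (see the linked PR for the actual change):

```cpp
// Sketch only: after creating the FlashAttention node in the graph build,
// request FP32 accumulation for Gemma 2 so the softcapped logits are not
// accumulated in FP16. Surrounding context (cur, q, k, v, kq_mask, ...) is
// simplified and not the exact llama.cpp code.
cur = ggml_flash_attn_ext(ctx, q, k, v, kq_mask, kq_scale,
                          hparams.f_max_alibi_bias, hparams.f_attn_logit_softcapping);

if (model.arch == LLM_ARCH_GEMMA2) {
    ggml_flash_attn_ext_set_prec(cur, GGML_PREC_F32);
}
```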