ggml : add Flash Attention
ref #3365
Setting up what's needed for Flash Attention support in ggml and llama.cpp.
The proposed operator performs:
```c
// new
res = ggml_flash_attn(ctx, q, k, v, kq_mask, kq_scale);

// fused scale + mask + soft_max (old)
kq  = ggml_mul_mat     (ctx, k,  q);
kq  = ggml_soft_max_ext(ctx, kq, kq_mask, kq_scale);
kqv = ggml_mul_mat     (ctx, v,  kq);
kqv = ggml_permute     (ctx, kqv, 0, 2, 1, 3);
res = ggml_cont_2d     (ctx, kqv, n_embd_head_k*n_head, n_tokens);

// unfused (old)
kq  = ggml_mul_mat (ctx, k,  q);
kq  = ggml_scale   (ctx, kq, kq_scale);
kq  = ggml_add     (ctx, kq, kq_mask);
kq  = ggml_soft_max(ctx, kq);
kqv = ggml_mul_mat (ctx, v,  kq);
kqv = ggml_permute (ctx, kqv, 0, 2, 1, 3);
res = ggml_cont_2d (ctx, kqv, n_embd_head_k*n_head, n_tokens);
```
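For illustration, a minimal sketch of what a call site for the fused op might look like. The wrapper name and tensor shapes here are assumptions based on the draft above, and the final `ggml_flash_attn_ext()` signature may differ:

```c
// Minimal sketch of a call site for the fused op. Shapes and the helper name
// are illustrative; the final ggml_flash_attn_ext() signature may differ.
#include "ggml.h"
#include <math.h>

static struct ggml_tensor * build_attn(
        struct ggml_context * ctx,
        struct ggml_tensor  * q,       // [n_embd_head, n_tokens, n_head]
        struct ggml_tensor  * k,       // [n_embd_head, n_kv,     n_head_kv]
        struct ggml_tensor  * v,       // [n_embd_head, n_kv,     n_head_kv] (not transposed)
        struct ggml_tensor  * kq_mask, // [n_kv, n_tokens], F16, padded
        int                   n_embd_head) {
    const float kq_scale = 1.0f/sqrtf((float) n_embd_head);

    // one fused op replaces mul_mat + soft_max_ext + mul_mat + permute + cont
    return ggml_flash_attn_ext(ctx, q, k, v, kq_mask, kq_scale);
}
```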
Suggestions and comments for the API are welcome. Looking for help implementing efficient GPU kernels - please open a PR against this branch if you have proposals.
- [x] `ggml` API: `ggml_flash_attn_ext()`
- [x] `llama.cpp`: use in `llm_build_kqv()`
- [x] add `test-backend-ops` test
- [x] CPU implementation (slow, just for testing)
- [x] CUDA implementation (https://github.com/ggerganov/llama.cpp/pull/6374)
- [x] Metal implementation
- [x] `GGML_PREC_F32` support (CUDA) (https://github.com/ggerganov/llama.cpp/pull/6646)
- [x] `GGML_PREC_F32` support (Metal) (see the precision sketch after this list)
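On the two precision items: a hedged sketch of how F32 accumulation might be requested for the FA result. `ggml_flash_attn_ext_set_prec()` follows the pattern of the existing `ggml_mul_mat_set_prec()` helper and is an assumption here; verify the exact name in ggml.h:

```c
// Hedged sketch: request F32 accumulation for the FA op (the kernels may
// otherwise accumulate in F16). ggml_flash_attn_ext_set_prec() mirrors
// ggml_mul_mat_set_prec() and is an assumption; check ggml.h.
#include "ggml.h"

static struct ggml_tensor * build_attn_f32_prec(
        struct ggml_context * ctx,
        struct ggml_tensor  * q,
        struct ggml_tensor  * k,
        struct ggml_tensor  * v,
        struct ggml_tensor  * kq_mask,
        float                 kq_scale) {
    struct ggml_tensor * kqv = ggml_flash_attn_ext(ctx, q, k, v, kq_mask, kq_scale);
    ggml_flash_attn_ext_set_prec(kqv, GGML_PREC_F32); // accumulate in F32
    return kqv;
}
```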
Changes to ggml/llama
- Add new op `GGML_OP_FLASH_ATTN_EXT` and `ggml_flash_attn_ext()` call (before merging we can consider reusing the old `GGML_OP_FLASH_ATTN` and removing the legacy code)
- Change the `mask` type to F16 for `ggml_soft_max_ext()` and require that it is padded to `GGML_KQ_MASK_PAD` (32) (see the padding sketch after this list)
- The `n_kv` denoting the number of computed tokens from the KV cache is now padded to 128 (from 32) to support larger FA blocks without out-of-bounds accesses
- The minimum `llama_context_params.n_batch` that can be used is `GGML_KQ_MASK_PAD` (32), to avoid out-of-bounds accesses in the FA kernels for small batch sizes
- The `V` tensor is no longer transposed when storing it in the KV cache
- The input buffer is cleared with zeros to avoid NaNs in the padded tensors
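To make the padding rules above concrete, a sketch assuming `GGML_PAD` (ggml's round-up-to-multiple helper) and the `GGML_KQ_MASK_PAD 32` constant from this PR; the builder name is illustrative:

```c
// Illustrative sketch of the padding rules described above.
#include "ggml.h"

#define GGML_KQ_MASK_PAD 32 // per this PR; defined in llama.cpp

static struct ggml_tensor * build_kq_mask(struct ggml_context * ctx, int n_kv, int n_tokens) {
    // n_kv is padded to 128 so larger FA blocks never read out of bounds
    const int n_kv_pad     = GGML_PAD(n_kv, 128);
    // the token dimension of the mask is padded to GGML_KQ_MASK_PAD
    const int n_tokens_pad = GGML_PAD(n_tokens, GGML_KQ_MASK_PAD);

    // the mask is F16 now (required by ggml_soft_max_ext() and the FA kernels);
    // the padded region must be cleared with zeros to avoid NaNs
    return ggml_new_tensor_2d(ctx, GGML_TYPE_F16, n_kv_pad, n_tokens_pad);
}
```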
Things to consider
- Pass KQ list with/instead of KQ mask
- Pass block-wise KQ mask
- Support ALiBi
- Finally transform ALiBi as `ggml_add()`? (low-prio; see the sketch after this list)
- No longer store transposed V-cache (gg/flash-attn-online)
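On the `ggml_add()` idea: since ALiBi is just a per-head linear bias on the attention scores, it can be baked into the additive KQ mask so no dedicated operator is needed. A sketch of the idea (names and the slope schedule are illustrative, not llama.cpp's implementation):

```c
// Illustrative sketch: fold the ALiBi bias into the additive KQ mask, so the
// unfused path's ggml_add(ctx, kq, kq_mask) covers it with no extra op.
#include <math.h>

static float alibi_slope(int head, int n_head) {
    // common ALiBi slope schedule for power-of-two head counts
    return powf(2.0f, -8.0f*(head + 1)/n_head);
}

static void add_alibi_bias(float * mask, int n_kv, int n_tokens, int head, int n_head, int pos0) {
    const float m = alibi_slope(head, n_head);
    for (int i = 0; i < n_tokens; ++i) {  // query position pos0 + i
        for (int j = 0; j < n_kv; ++j) {  // key position j
            // linear distance penalty; causal positions j > pos0 + i are
            // already masked to -INFINITY elsewhere
            mask[i*n_kv + j] += -m * (float)((pos0 + i) - j);
        }
    }
}
```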
Testing
```sh
./tests/test-backend-ops -o FLASH_ATTN_EXT
```
- `main`, `server`: add `-fa`
- `llama-bench`: add `-fa 1`
Benchmark
Baseline:
```sh
# CUDA
LLAMA_CUBLAS=1 make -j tests && ./tests/test-backend-ops -o ATTN -b CUDA0 perf

# Metal
make -j tests && ./tests/test-backend-ops -o ATTN -b Metal perf
```
FA kernel:
```sh
# CUDA
LLAMA_CUBLAS=1 make -j tests && ./tests/test-backend-ops -o FLASH_ATTN_EXT -b CUDA0 perf

# Metal
make -j tests && ./tests/test-backend-ops -o FLASH_ATTN_EXT -b Metal perf
```
Text-generation after long prompt:
```sh
# without flash attention (FA flag = 0)
./batched-bench models/mistral-instruct-7b-v0.2/ggml-model-f16.gguf 10000 2048 512 0 1 99 8192 256 1

# with flash attention (FA flag = 1)
./batched-bench models/mistral-instruct-7b-v0.2/ggml-model-f16.gguf 10000 2048 512 1 1 99 8192 256 1
```
References
- Online softmax: https://arxiv.org/pdf/1805.02867.pdf
- Self-attention with O(n) memory: https://arxiv.org/pdf/2112.05682.pdf
- FlashAttention-2: https://arxiv.org/pdf/2307.08691.pdf
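For context, the core trick from the online softmax paper, which is what lets the FA kernels process the KV cache block by block with constant extra state: given scores $x_1, \dots, x_N$, a running maximum and normalizer are updated as

$$
m_i = \max(m_{i-1}, x_i), \qquad d_i = d_{i-1}\,e^{m_{i-1}-m_i} + e^{x_i-m_i},
$$

with $m_0 = -\infty$ and $d_0 = 0$, so that $\mathrm{softmax}(x)_j = e^{x_j - m_N}/d_N$ once the last score has been seen.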