
ggml : add Flash Attention

ggerganov opened this issue · 118 comments

ref #3365

Setting up what's needed for Flash Attention support in ggml and llama.cpp

The proposed operator performs:

// new
res = ggml_flash_attn(ctx, q, k, v, kq_mask, kq_scale);

// fused scale + mask + soft_max (old)
kq  = ggml_mul_mat     (ctx, k,  q);
kq  = ggml_soft_max_ext(ctx, kq, kq_mask, kq_scale);
kqv = ggml_mul_mat     (ctx, v,  kq);
kqv = ggml_permute     (ctx, kqv, 0, 2, 1, 3);
res = ggml_cont_2d     (ctx, kqv, n_embd_head_k*n_head, n_tokens);

// unfused (old)
kq  = ggml_mul_mat (ctx, k,  q);
kq  = ggml_scale   (ctx, kq, kq_scale);
kq  = ggml_add     (ctx, kq, kq_mask);
kq  = ggml_soft_max(ctx, kq);
kqv = ggml_mul_mat (ctx, v,  kq);
kqv = ggml_permute (ctx, kqv, 0, 2, 1, 3);
res = ggml_cont_2d (ctx, kqv, n_embd_head_k*n_head, n_tokens);

Suggestions and comments on the API are welcome. Looking for help implementing efficient GPU kernels; please open a PR against this branch if you have proposals.

  • [x] ggml API: ggml_flash_attn_ext()
  • [x] llama.cpp use in llm_build_kqv()
  • [x] add test-backend-ops test
  • [x] CPU implementation (slow, just for testing)
  • [x] CUDA implementation (https://github.com/ggerganov/llama.cpp/pull/6374)
  • [x] Metal implementation
  • [x] GGML_PREC_F32 support (CUDA) (https://github.com/ggerganov/llama.cpp/pull/6646)
  • [x] GGML_PREC_F32 support (Metal)

Changes to ggml/llama

Things to consider

  • Pass KQ list with/instead of KQ mask
  • Pass block-wise KQ mask
  • Support Alibi
  • Finally transform Alibi as ggml_add()? (low-prio)
  • No longer store transposed V-cache (gg/flash-attn-online)

Testing

./tests/test-backend-ops -o FLASH_ATTN_EXT

  • main, server: add -fa
  • llama-bench: add -fa 1

Benchmark

Baseline:

# CUDA
LLAMA_CUBLAS=1 make -j tests && ./tests/test-backend-ops -o ATTN -b CUDA0 perf

# Metal
LLAMA_METAL=1 make -j tests && ./tests/test-backend-ops -o ATTN -b Metal perf

FA kernel:

# CUDA
LLAMA_CUBLAS=1 make -j tests && ./tests/test-backend-ops -o FLASH_ATTN_EXT -b CUDA0 perf

# Metal
LLAMA_METAL=1 make -j tests && ./tests/test-backend-ops -o FLASH_ATTN_EXT -b Metal perf

Text-generation after long prompt:

# without flash attention (4th numeric argument = 0)
./batched-bench models/mistral-instruct-7b-v0.2/ggml-model-f16.gguf 10000 2048 512 0 1 99 8192 256 1

# with flash attention (4th numeric argument = 1)
./batched-bench models/mistral-instruct-7b-v0.2/ggml-model-f16.gguf 10000 2048 512 1 1 99 8192 256 1

References

  • https://arxiv.org/pdf/1805.02867.pdf Online softmax
  • https://arxiv.org/pdf/2112.05682.pdf O(n) memory self-attention
  • https://arxiv.org/pdf/2307.08691.pdf Flash-attention 2

ggerganov · Jan 18 '24 17:01