
ggml : add Flash Attention

ggerganov opened this issue 1 year ago · 118 comments

ref #3365

Setting up what's needed for Flash Attention support in ggml and llama.cpp

The proposed operator performs:

// new
res = ggml_flash_attn(ctx, q, k, v, kq_mask, kq_scale);

// fused scale + mask + soft_max (old)
kq  = ggml_mul_mat     (ctx, k,  q);
kq  = ggml_soft_max_ext(ctx, kq, kq_mask, kq_scale);
kqv = ggml_mul_mat     (ctx, v,  kq);
kqv = ggml_permute     (ctx, kqv, 0, 2, 1, 3);
res = ggml_cont_2d     (ctx, kqv, n_embd_head_k*n_head, n_tokens);

// unfused (old)
kq  = ggml_mul_mat (ctx, k,  q);
kq  = ggml_scale   (ctx, kq, kq_scale);
kq  = ggml_add     (ctx, kq, kq_mask);
kq  = ggml_soft_max(ctx, kq);
kqv = ggml_mul_mat (ctx, v,  kq);
kqv = ggml_permute (ctx, kqv, 0, 2, 1, 3);
res = ggml_cont_2d (ctx, kqv, n_embd_head_k*n_head, n_tokens);

Suggestions and comments for the API are welcome. Looking for help with implementing efficient GPU kernels - please open a PR against this branch if you have proposals.

  • [x] ggml API: ggml_flash_attn_ext()
  • [x] llama.cpp use in llm_build_kqv()
  • [x] add test-backend-ops test
  • [x] CPU implementation (slow, just for testing)
  • [x] CUDA implementation (https://github.com/ggerganov/llama.cpp/pull/6374)
  • [x] Metal implementation
  • [x] GGML_PREC_F32 support (CUDA) (https://github.com/ggerganov/llama.cpp/pull/6646)
  • [x] GGML_PREC_F32 support (Metal)

Changes to ggml/llama

Things to consider

  • Pass KQ list with/instead of KQ mask
  • Pass block-wise KQ mask
  • Support Alibi
  • Eventually express Alibi as a ggml_add()? (low-prio)
  • No longer store transposed V-cache (gg/flash-attn-online)

Testing

./tests/test-backend-ops -o FLASH_ATTN_EXT
  • main, server: add -fa
  • llama-bench: add -fa 1

Benchmark

Baseline:

# CUDA
LLAMA_CUBLAS=1 make -j tests && ./tests/test-backend-ops -o ATTN -b CUDA0 perf

# Metal
make -j tests && ./tests/test-backend-ops -o ATTN -b Metal perf

FA kernel:

# CUDA
LLAMA_CUBLAS=1 make -j tests && ./tests/test-backend-ops -o FLASH_ATTN_EXT -b CUDA0 perf

# Metal
make -j tests && ./tests/test-backend-ops -o FLASH_ATTN_EXT -b Metal perf

Text-generation after long prompt:

# without flash attention
./batched-bench models/mistral-instruct-7b-v0.2/ggml-model-f16.gguf 10000 2048 512 0 1 99 8192 256 1

# with flash attention
./batched-bench models/mistral-instruct-7b-v0.2/ggml-model-f16.gguf 10000 2048 512 1 1 99 8192 256 1

References

  • https://arxiv.org/pdf/1805.02867.pdf Online softmax
  • https://arxiv.org/pdf/2112.05682.pdf O(n) memory self-attention
  • https://arxiv.org/pdf/2307.08691.pdf Flash-attention 2

ggerganov (Jan 18 '24)

Since we are doing this from scratch, wouldn't it be better to remove the custom attention mask entirely and pass a list of KV cells used in each sequence? Considering our implementation of batching, I think we should be looking at implementing something closer to paged attention rather than flash attention. I suppose it is possible to convert the mask to a list of sequences in the kernels, but it would be less efficient.

slaren (Jan 18 '24)

Yes, we can pass a list instead of a mask. I am not sure about the format though - if each list has a different length, I feel it will hinder GPU performance.

Edit: I just got an idea - we can pass both the kq_mask as it is, plus a second boolean tensor that tells each token which KV blocks it should attend to. For example, we split the KV cache into blocks of 128 (or some other round number) and a token (i.e. a row in q) attends to a block if at least one of the cells in it belongs to the token's sequence. This way, we can skip entire blocks of the KV cache that do not belong to the current sequence and keep the problem parallel-friendly. Thoughts?
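For illustration, a rough sketch of how such a block mask could be derived from the existing kq_mask (hypothetical helper, block size and names assumed - not actual ggml code):

```c
#include <math.h>
#include <stdbool.h>

#define KQ_BLOCK 128 // assumed block size, per the example above

// A query row attends to a KV block if at least one cell in the block is visible
// (i.e. not -INF) in the existing KQ mask. Skipped blocks need no compute at all.
static void build_block_mask(const float * kq_mask,    // [n_kv, n_tokens], 0.0f or -INFINITY
                             bool        * block_mask, // [n_kv/KQ_BLOCK, n_tokens]
                             int n_kv, int n_tokens) {
    const int n_blocks = n_kv / KQ_BLOCK;
    for (int t = 0; t < n_tokens; ++t) {
        for (int b = 0; b < n_blocks; ++b) {
            bool attend = false;
            for (int c = 0; c < KQ_BLOCK && !attend; ++c) {
                attend = kq_mask[t*n_kv + b*KQ_BLOCK + c] != -INFINITY;
            }
            block_mask[t*n_blocks + b] = attend;
        }
    }
}
```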

ggerganov (Jan 18 '24)

We could use a vector with dimension [num_seqs] that contains the length of the sequences, and a 2D tensor with dimensions [max_seq_len, num_seqs] that contains the KV cells in each sequence, padded to the length of the longest sequence.
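For illustration, a sketch of that layout (hypothetical names, the padding sentinel is an assumption):

```c
#include <stdint.h>

// seq_lens[s]     - number of KV cells used by sequence s            (dimension [num_seqs])
// seq_cells[s][i] - i-th KV cell index of sequence s, padded up to
//                   max_seq_len with a sentinel value such as -1     (dimension [max_seq_len, num_seqs])
struct kv_seq_list {
    int32_t * seq_lens;   // num_seqs entries
    int32_t * seq_cells;  // num_seqs * max_seq_len entries
    int32_t   num_seqs;
    int32_t   max_seq_len;
};
```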

slaren (Jan 18 '24)

It seems that vLLM has added a new version of paged attention since we last looked into the implementation (https://github.com/vllm-project/vllm/pull/1348). I am not sure what the changes are, but I think it is worth looking into what they are doing. The kernel is in https://github.com/vllm-project/vllm/blob/main/csrc/attention/attention_kernels.cu

slaren (Jan 18 '24)

Alibi could also be done in this kernel.

slaren (Jan 18 '24)

Regarding the Alibi, I feel reinterpreting it as a KQ_mask via ggml_add() is a more general solution - we will avoid having a ggml_alibi() operator and explicit support in the kernels that we write (like in vLLM).

It remains to be seen though if the KQ_mask will be a bottleneck - my feeling is that just avoiding the extra read/write of KQ will bring us close to the optimal performance, even with the existing "cross-KV compute" drawback.

Will take a look at the vLLM code. I've updated the description with some of the things from this discussion.

ggerganov (Jan 18 '24)

@ggerganov @slaren Together with @JohannesGaessler and @FSSRepo, we are working on the same thing over at https://github.com/Pints-App/llama.cpp/pull/1, which we intend to submit as a PR to llama.cpp once the work is done.

However, I think we will converge with this one. Given the amount of work here, @ggerganov @slaren, how do you want to organise this? The 3 of us are actually in a temporary Discord group to work this out - perhaps we can use that?

What are your thoughts?

calvintwr (Jan 19 '24)

Discord is not an option for me - I prefer to communicate over Github issues / discussions / e-mail.

Happy to see you have started work on the CUDA implementation. Please take into account the API proposed here - note that it is still a WIP and can change. I can review the implementation that you have when you think it is in a good state. I would prefer PRs that are compatible with this branch, so we can verify correctness using test-backend-ops and support for all backends.

ggerganov (Jan 19 '24)

@ggerganov Got it. Let us work on a plan to converge with this PR.

calvintwr (Jan 20 '24)

~~test-backend-ops -o FLASH_ATTN_EXT fails for Metal on my M2 Pro, is this known?~~ edit: I see, not implemented yet.

cebtenzzre (Jan 20 '24)

Any performance numbers?

JianbangZ (Jan 20 '24)

There is now an initial version for Metal - see the kernel_flash_attn_ext_f16 kernel. It's already slightly faster for TG after a long prompt, but the PP speed is ~15% lower compared to master. Still looking into it - any ideas are appreciated.

I think the CPU is already better, but I haven't done too many tests in this regard, as I'm more interested in the GPU performance.

This is based on reading the 3 papers in the description and implementing my understanding of what Flash-attention / Flash-decoding is supposed to do. I guess there is some chance that I'm still misunderstanding something and that's why I can't improve the performance universally.

I've changed the API so that the V tensor is no longer stored transposed in the KV cache. Since we no longer rely on ggml_mul_mat for the attention, this is no longer required, and it would allow for easier quantization of the V cache in the future since the head dimensions are quite often a multiple of 32. Still contemplating this, but I think it will be a good change overall:

https://github.com/ggerganov/llama.cpp/blob/17720fad669eed6171ddf17184da5bab50adeb72/ggml.h#L1622-L1635
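For context, the linked lines declare the extended op. A rough sketch of the signature implied by the proposed operator at the top of the issue (parameter names assumed here - the header is the authoritative source):

```c
// V is passed non-transposed; mask is the KQ mask (broadcast over the heads);
// scale is the 1/sqrt(n_embd_head) factor applied to K*Q before the softmax.
GGML_API struct ggml_tensor * ggml_flash_attn_ext(
        struct ggml_context * ctx,
        struct ggml_tensor  * q,
        struct ggml_tensor  * k,
        struct ggml_tensor  * v,
        struct ggml_tensor  * mask,
        float                 scale);
```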

The ggml_flash_attn_ext op also avoids the final permute + cont that we normally have - this is built into the kernels.

I'm currently looking into the performance for a full KQ_mask, so the format of the masked indices is not relevant yet. The plan is to get to something that is faster without masking, and only after that will I start looking into improving this.

If anyone is interested in porting this implementation to CUDA - feel free to give it a try. For the moment I'm focusing on Metal as it is more convenient for me to develop

ggerganov (Jan 22 '24)

Great job!! I'm struggling to achieve performance improvement in CUDA because I'm having issues where 90% of the kernel execution time is spent on memory I/O, and the remaining 10% on computation.

FSSRepo (Jan 22 '24)

> It's already slightly faster for TG after a long prompt, but the PP speed is ~15% lower compared to master. Still looking into it - any ideas are appreciated.

> I'm struggling to achieve performance improvement in CUDA because I'm having issues where 90% of the kernel execution time is spent on memory I/O, and the remaining 10% on computation.

My experience is that writing GPU matrix multiplication code for cases in which at least one of the matrices is thin in at least one of its dimensions is comparatively easy. No matter what you do, the implementation is going to be I/O bound anyways and a simple implementation based on dot products is going to be close to optimal. If you have large matrices the operation becomes compute rather than I/O bound however. Now the challenge is to utilize the compute pipelines by loading the data in such a way that arithmetic intensity becomes maximal.

The speedup reported by the FlashAttention paper was I think ~15%. This is a significant speedup but it also means that to get any speedup at all you would need to base your FlashAttention implementation off of a matrix multiplication kernel that achieves performance comparable to cuBLAS GEMM. This is no easy task.

I am currently working on matrix multiplication kernels for quantized data based on int8 arithmetic: https://github.com/ggerganov/llama.cpp/pull/4801 . They are already faster than cuBLAS FP16 GEMM and can be used in the future as the basis for custom matrix multiplication kernels. Conceivably one of those custom matrix multiplication kernels could be FlashAttention.

> If anyone is interested in porting this implementation to CUDA - feel free to give it a try. For the moment I'm focusing on Metal as it is more convenient for me to develop

I will not get to it in the foreseeable future. My priority is to get the int8 matrix multiplication in order, after that the quantization of the KV cache, and after that I want to look into training in llama.cpp. If it turns out that FlashAttention would be useful there I will maybe look into it but definitely no promises in terms of timeline or features.

JohannesGaessler (Jan 22 '24)

Having a fused kernel allows you to skip reading and writing KQ to global memory twice, lets you avoid computing half of the operations (via masking) compared to the matrix-multiplication based implementation, and you can also apply online softmax. So even if the fused kernel is not cuBLAS-level optimized, I was hoping to outperform the existing implementation.

ggerganov (Jan 22 '24)

I mean, I don't have the hardware to profile the Metal code but you should check how long the kernel takes compared to the equivalent GEMM kernels. In my experience getting within 50% of the performance of highly optimized GEMM code is already not easy. And if your kernel is slower than 50% GEMM then the total runtime will increase even if you do only half as many calculations. In my int8 PR the actual matrix multiplication kernel is currently only ~10% faster than the cuBLAS FP16 GEMM kernel even though int8 tensor cores are in theory twice as fast as FP16 tensor cores.

For large batch sizes the amount of data read from and written to global memory should not make much of a difference in terms of inference performance since the matrices should have size $O(N^2)$ but the computation needs $O(N^3)$ operations.
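As a rough sketch of the scaling argument (order-of-magnitude reasoning only):

$$\text{arithmetic intensity} = \frac{\text{FLOPs}}{\text{bytes moved}} \sim \frac{O(N^3)}{O(N^2)} = O(N)$$

so for large $N$ there is enough work per byte for the kernel to be compute bound, while thin matrices stay I/O bound.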

JohannesGaessler (Jan 22 '24)

After implementing the kernel with simdgroup matrix ops, it is now universally better than the master version. Tested with head size 128 - it might need some extra work for head sizes that are not divisible by 32, but it also might just work.

Next step will be to port this implementation to CUDA - I will get to this in a few days after catching up a bit with other issues, and if it hasn't been implemented by other people yet.

ggerganov (Jan 25 '24)

> After implementing the kernel with simdgroup matrix ops, it is now universally better than the master version. Tested with head size 128 - it might need some extra work for head sizes that are not divisible by 32, but it also might just work.
>
> Next step will be to port this implementation to CUDA - I will get to this in a few days after catching up a bit with other issues, and if it hasn't been implemented by other people yet.

Can you give some example benchmark numbers for master vs this PR with Metal?

YavorGIvanov (Jan 25 '24)

It's early for numbers - with the current block size of 8x32 (8 queries x 32 cache items) it's marginally better. I'm just glad that the performance finally makes sense and this looks like the correct way to implement this kernel

ggerganov (Jan 25 '24)

@ggerganov I'm trying to port the kernel you already have in Metal to CUDA, but I'm not completely clear on how it works yet. So, I would appreciate it if you could help me clarify my understanding.

  1. I understand that for each threadgroup, 8 queries are processed (the batch size). For every threadgroup there are 2 warps (simdgroups). When the batch size is greater than 4 (number of queries), each warp handles one query. Is that right?

  2. Could you explain the data layout in shared memory?

threadgroup half  * pq  = (threadgroup half  *) (shared +                   0*D); // offset: 0, size: head_dim
threadgroup half4 * pq4 = (threadgroup half4 *) (shared +                0*D);
threadgroup half  * ps  = (threadgroup half  *) (shared + sgitg*(D + 1*C) + 1*D); // offset: head_dim + (warp_id * (head_dim + cache_per_warp)), size: head_dim + cache_per_warp
threadgroup half4 * ps4 = (threadgroup half4 *) (shared + sgitg*(D + 1*C) + 1*D);

// In the code, it implies that it has a size of (head_dim + warps*(head_dim + cache_per_warp)) * queries per threadgroup.
threadgroup half  * ss  = (threadgroup half  *) (shared + sgitg*(D + 1*C) + 2*D); // offset: head_dim*2 + (warp_id * (head_dim + cache_per_warp)), size: ??
  3. In CUDA, there is no vector type half4, only half and half2. Therefore, I will use the latter when copying data from VRAM to SRAM.
const int64_t D4 = D/4;   // changing this to D/2 to allow half2 when copying data?
const int64_t N4 = N_SIMDWIDTH;
const int64_t L4 = (D4 + N4 - 1)/N4;
const int64_t D8 = D/8;

const int64_t T  = D + nsg*(D + 1*C); // shared memory size per query in half
const int64_t T4 = T/4;               // shared memory size per query in half4 <-- this too: T/2
  4. Tensor cores in CUDA require a 16x16x16 multiplication, whereas the multiplication you are performing here is 8x8x8. So, I assume you will need to make the following changes.
// change queries per threadgroup (Q) to 16 to have a 16x16x16 multiplication;
// head_dim and batch_size must be a multiple of 16 to use tensor cores, or add padding
// for alignment to avoid overflow when loading the data into the tensor core fragments
const int64_t L4 = (D4 + N4 - 1)/N4;
const int64_t D8 = D/8; // change this to 16
  5. mqk = make_filled_simdgroup_matrix<half, Q>(0.h); Does this create a QxQ matrix or a vector with Q elements?

  6. In the online softmax loop, it seems like this loop does nothing, since it only scales an array of zeros by ms.

// online softmax
for (int64_t j = 0; j < Q; ++j) {
    ...
    for (int64_t i = 0; i < L4; ++i) {
        ps4[j*T4 + N4*i + tiisg] *= ms; // This does nothing since the ps4 has only been initialized to 0.0 so far.
    }

    ss[j*T + p] = vs;
}

If I have any other questions, I will let you know. Thank you!

FSSRepo (Jan 26 '24)

@FSSRepo Thanks for looking into this. Here is some information:

  1. The threadgroup can have a configurable number of warps:

https://github.com/ggerganov/llama.cpp/blob/6fea843b246409a3c4b26156745a89e4ba01029b/ggml-metal.m#L2255-L2258

It is yet to be determined how many warps to set - I think the paper recommends 4 or 8. But in any case, the kernel should support a configurable number of warps.

Each warp in the threadgroup works on the exact same batch of 8 queries - one head at a time (in the q query tensor, the batch is index 1 and the head is index 2). Each warp processes 1/nsg of the KV cache:

https://github.com/ggerganov/llama.cpp/blob/6fea843b246409a3c4b26156745a89e4ba01029b/ggml-metal.metal#L2129-L2130

Here ne11 is the size of the KV cache, nsg is the number of warps (a.k.a. simdgroups) in the threadgroup and ~~C = 8~~ C = 32 is the number of cache items processed on each iteration. The warps work completely independently from each other thanks to the online softmax. They need to be synchronized just at the end when we reduce their results. ~~For CUDA likely C has to be 16.~~
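One possible reading of that loop structure, as a C-style sketch (names taken from the Metal kernel, exact indexing may differ):

```c
enum { C = 32 }; // cache items processed per iteration

// Warp sgitg (0..nsg-1) walks the KV cache in interleaved chunks of C items,
// so each warp covers ~1/nsg of the ne11 cache entries and the warps only
// need to synchronize for the final reduction of their partial results.
static void warp_kv_chunks(int sgitg, int nsg, int ne11) {
    for (int ic = C*sgitg; ic < ne11; ic += C*nsg) {
        // process KV cache items [ic, ic + C) with this warp's simdgroup matrix ops
    }
}
```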

  2. The total shared memory required by this kernel is:

https://github.com/ggerganov/llama.cpp/blob/6fea843b246409a3c4b26156745a89e4ba01029b/ggml-metal.m#L2263

  • nqptg - number of queries per threadgroup (i.e. currently 8, but for CUDA likely to be 16)
  • ne00 - the head size (e.g. 128 for LLaMA, Mistral, etc.)
  • nsg - number of warps (configurable)
  • ncpsg - number of cache values per simdgroup (i.e. 32, has to be multiple of the warp size)

On Metal, we have 32KB of shared memory per threadgroup, so here are some possible configurations:

| nqptg | ne00 | nsg | ncpsg | smem (bytes) |
| ----- | ---- | --- | ----- | ------------ |
| 8     | 128  | 4   | 32    | 12288        |
| 8     | 128  | 8   | 32    | 22528        |
| 8     | 128  | 12  | 32    | 32768        |
| 8     | 128  | 4   | 64    | 14336        |
| 8     | 128  | 8   | 64    | 26624        |
| 8     | 128  | 4   | 96    | 16384        |
| 8     | 128  | 8   | 96    | 30720        |
| 8     | 128  | 4   | 128   | 18432        |
| 16    | 128  | 4   | 32    | 24576        |
| 16    | 128  | 4   | 64    | 28672        |
| 16    | 128  | 4   | 96    | 32768        |
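These values follow a simple formula - a sketch of my reading of the referenced ggml-metal.m line:

```c
#include <stddef.h>

// Per threadgroup: nqptg queries, each needing the query head (ne00 halfs) plus,
// for each of the nsg warps, a scratch head (ne00) and ncpsg attention scratch values.
static inline size_t fa_smem_bytes(int nqptg, int ne00, int nsg, int ncpsg) {
    return (size_t) nqptg * (ne00 + nsg*(ne00 + ncpsg)) * 2; // 2 bytes per half
}
// e.g. fa_smem_bytes(8, 128, 4, 32) == 12288, matching the first row of the table above
```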

Not sure about CUDA - hopefully it has more shared memory so we can fit more queries and cache values per threadgroup.

The reason for this shared memory layout is that for each query we need to load the head (this is ne00 elements), and then for each warp we need a scratch buffer for storing the resulting head of QKV. We also need a small scratch buffer of size ncpsg to write the intermediate attention values from the Q*K^T result.

In theory the nsg*(ne00 + 1*ncpsg) part of the buffer could reside in the warp registers. But on Metal, there is no way to utilize the simd matrix operations in that case - they require to load and store data to shared or device memory. I don't know what is the situation with CUDA, but one should look into moving this in the registers if possible

  3. Yes, use the biggest vector type available - half2 in this case

  4. Yes, given that CUDA operates with 16x16 matrices, I suspect you should be using 16 queries per threadgroup, or a multiple of 16 queries if the SRAM allows it. I plan to extend the Metal kernel to support a configurable number of queries (multiple of 8).

Head dimensions are always multiple of 16 so this should be OK.

Number of queries should be padded with zeros if not a multiple:

https://github.com/ggerganov/llama.cpp/blob/6fea843b246409a3c4b26156745a89e4ba01029b/ggml-metal.metal#L2055-L2063

  5. Yes, this creates an 8x8 matrix filled with zeros

  6. Notice that the ps and ps4 pointers point to the same shared memory data. ps is necessary for loading the data into the 8x8 simd matrices (i.e. simdgroup_load), while ps4 is needed for faster multiplication. The line that you referenced basically does the scaling of the ~~attention~~ output during online softmax - this is the first term in this equation from the FA2 paper:

[image: the online-softmax output update equation from the FlashAttention-2 paper]
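Roughly, the update in question has the form (paraphrasing the FA2 paper - see the paper for the exact notation):

$$\tilde{\mathbf{O}}^{(2)} = \operatorname{diag}\!\left(e^{\,m^{(1)} - m^{(2)}}\right)\tilde{\mathbf{O}}^{(1)} + e^{\,\mathbf{S}^{(2)} - m^{(2)}}\,\mathbf{V}^{(2)}$$

where the first term rescales the output accumulated so far by the max-correction factor $e^{\,m^{(1)} - m^{(2)}}$ - these are the per-row ms values in the kernel.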

Depending on what matrix operators are available in CUDA, this could be represented in matrix form, and it might even be possible to avoid writing and reading these intermediate results to the ncpsg shared memory buffer. But on Metal, there is no way to create an 8x8 diagonal matrix with different diagonal elements ms, so that is why I've implemented this in scalar form.

I plan to extend the kernel to support larger block sizes (i.e. 8x64, 16x32, 16x64, etc.), which I expect to result in further improvement. But I suppose for an initial implementation one should first try to implement the smallest block size - 16x32 for CUDA - and work from there.

ggerganov (Jan 27 '24)

Here are some preliminary results measuring the performance just for computing the attention. This speedup is a lower bound, since the test data does not use a -INF mask, which in real cases helps the new flash attention kernel significantly by skipping the computation of such blocks.

The speed on master for head sizes that are not a multiple of 64 is pretty bad because in such cases we fall back to the mat-vec kernels instead of using the efficient mat-mat kernels.

For small contexts (< 1024) the bs=1 case on master is better due to the efficient mat-vec kernels, but at bigger contexts the PR is better

The speed in this PR for head size of 256 is pretty bad because it seems that at this size the amount of local memory in the simdgroup is quite large and this somehow affects the performance. Still not 100% sure though - there could be some other explanation

M2 Ultra
Head size Heads n_kv n_batch master us/run PR us/run speedup
64 32 512 1 48.35 52.66 0.918
64 32 512 2 65.10 53.51 1.217
64 32 512 4 64.60 54.14 1.193
64 32 512 8 66.28 54.98 1.206
64 32 512 512 359.21 320.89 1.119
64 32 512 1024 679.95 556.66 1.221
64 32 512 2048 1307.07 1070.70 1.221
64 32 1024 1 73.74 55.80 1.322
64 32 1024 2 84.81 56.87 1.491
64 32 1024 4 87.57 57.65 1.519
64 32 1024 8 90.37 59.91 1.508
64 32 1024 512 615.42 515.89 1.193
64 32 1024 1024 1162.55 976.28 1.191
64 32 1024 2048 2258.26 1897.12 1.190
64 32 2048 1 127.41 77.92 1.635
64 32 2048 2 137.83 74.32 1.855
64 32 2048 4 141.00 73.93 1.907
64 32 2048 8 143.79 75.10 1.915
64 32 2048 512 941.67 955.09 0.986
64 32 2048 1024 1859.78 1829.82 1.016
64 32 2048 2048 3773.73 3563.03 1.059
64 32 4096 1 236.36 134.64 1.755
64 32 4096 2 242.09 113.24 2.138
64 32 4096 4 238.95 107.49 2.223
64 32 4096 8 248.57 107.94 2.303
64 32 4096 512 2128.95 1909.24 1.115
64 32 4096 1024 4191.58 3833.60 1.093
64 32 4096 2048 8662.38 7592.72 1.141
80 32 512 1 51.22 45.78 1.119
80 32 512 2 78.24 46.65 1.677
80 32 512 4 85.26 48.55 1.756
80 32 512 8 113.76 48.40 2.350
80 32 512 512 4179.59 318.88 13.107
80 32 512 1024 8208.02 591.91 13.867
80 32 512 2048 16379.66 1144.30 14.314
80 32 1024 1 78.68 61.85 1.272
80 32 1024 2 120.74 62.71 1.925
80 32 1024 4 135.29 64.29 2.104
80 32 1024 8 194.74 65.25 2.985
80 32 1024 512 8164.54 554.10 14.735
80 32 1024 1024 16046.73 1042.44 15.393
80 32 1024 2048 32204.06 2019.51 15.946
80 32 2048 1 120.86 86.05 1.405
80 32 2048 2 207.56 84.41 2.459
80 32 2048 4 236.96 86.10 2.752
80 32 2048 8 343.46 85.79 4.003
80 32 2048 512 15971.57 1030.95 15.492
80 32 2048 1024 31757.59 1956.62 16.231
80 32 2048 2048 63448.85 3805.53 16.673
80 32 4096 1 220.70 133.37 1.655
80 32 4096 2 385.37 121.89 3.162
80 32 4096 4 433.54 121.74 3.561
80 32 4096 8 656.55 122.32 5.367
80 32 4096 512 32207.92 2042.83 15.766
80 32 4096 1024 63862.88 3952.05 16.159
80 32 4096 2048 127819.94 7618.14 16.778
96 32 512 1 52.16 48.19 1.082
96 32 512 2 64.60 49.42 1.307
96 32 512 4 65.62 50.44 1.301
96 32 512 8 68.31 51.56 1.325
96 32 512 512 479.10 367.21 1.305
96 32 512 1024 918.58 684.40 1.342
96 32 512 2048 1798.89 1323.87 1.359
96 32 1024 1 76.41 62.88 1.215
96 32 1024 2 88.54 64.60 1.371
96 32 1024 4 89.44 65.23 1.371
96 32 1024 8 92.75 66.22 1.401
96 32 1024 512 806.81 643.86 1.253
96 32 1024 1024 1534.11 1220.26 1.257
96 32 1024 2048 3015.06 2366.33 1.274
96 32 2048 1 122.82 88.51 1.388
96 32 2048 2 142.06 87.75 1.619
96 32 2048 4 145.86 89.62 1.628
96 32 2048 8 149.40 88.97 1.679
96 32 2048 512 1263.90 1230.62 1.027
96 32 2048 1024 2496.61 2319.37 1.076
96 32 2048 2048 5082.33 4508.18 1.127
96 32 4096 1 226.70 147.65 1.535
96 32 4096 2 249.52 134.95 1.849
96 32 4096 4 259.10 136.24 1.902
96 32 4096 8 264.56 135.94 1.946
96 32 4096 512 2726.24 2541.81 1.073
96 32 4096 1024 5391.37 4940.26 1.091
96 32 4096 2048 11118.87 9774.00 1.138
112 32 512 1 53.03 50.35 1.053
112 32 512 2 78.74 52.39 1.503
112 32 512 4 89.65 53.01 1.691
112 32 512 8 116.58 54.33 2.146
112 32 512 512 4398.15 388.05 11.334
112 32 512 1024 8582.10 731.05 11.739
112 32 512 2048 17236.83 1416.57 12.168
112 32 1024 1 74.96 67.07 1.118
112 32 1024 2 123.45 68.21 1.810
112 32 1024 4 144.76 68.74 2.106
112 32 1024 8 196.37 70.79 2.774
112 32 1024 512 8554.53 688.48 12.425
112 32 1024 1024 16775.31 1307.00 12.835
112 32 1024 2048 33713.82 2548.04 13.231
112 32 2048 1 126.37 92.52 1.366
112 32 2048 2 213.71 92.35 2.314
112 32 2048 4 248.02 94.07 2.637
112 32 2048 8 357.08 95.50 3.739
112 32 2048 512 16697.99 1317.57 12.673
112 32 2048 1024 33034.48 2491.99 13.256
112 32 2048 2048 66383.25 4868.13 13.636
112 32 4096 1 237.16 155.86 1.522
112 32 4096 2 373.50 144.62 2.583
112 32 4096 4 435.65 146.07 2.982
112 32 4096 8 643.29 146.13 4.402
112 32 4096 512 33739.44 2635.00 12.804
112 32 4096 1024 66798.13 5029.28 13.282
112 32 4096 2048 133687.09 9861.75 13.556
128 32 512 1 46.68 54.74 0.853
128 32 512 2 66.58 57.64 1.155
128 32 512 4 67.92 57.95 1.172
128 32 512 8 70.21 58.34 1.203
128 32 512 512 517.57 436.09 1.187
128 32 512 1024 1000.31 826.79 1.210
128 32 512 2048 1979.87 1605.15 1.233
128 32 1024 1 69.69 69.38 1.004
128 32 1024 2 91.11 71.09 1.282
128 32 1024 4 92.39 72.72 1.270
128 32 1024 8 95.79 72.42 1.323
128 32 1024 512 865.96 770.68 1.124
128 32 1024 1024 1660.97 1463.05 1.135
128 32 1024 2048 3288.41 2845.84 1.156
128 32 2048 1 114.68 99.64 1.151
128 32 2048 2 149.57 100.09 1.494
128 32 2048 4 151.86 100.76 1.507
128 32 2048 8 155.56 101.88 1.527
128 32 2048 512 1367.97 1539.56 0.889
128 32 2048 1024 2697.36 2844.56 0.948
128 32 2048 2048 5535.49 5478.64 1.010
128 32 4096 1 210.81 175.34 1.202
128 32 4096 2 264.01 166.61 1.585
128 32 4096 4 269.43 166.14 1.622
128 32 4096 8 276.98 166.71 1.661
128 32 4096 512 2926.51 3033.96 0.965
128 32 4096 1024 5816.29 5828.25 0.998
128 32 4096 2048 12035.27 11497.53 1.047
256 32 512 1 52.02 326.86 0.159
256 32 512 2 79.17 327.95 0.241
256 32 512 4 79.87 329.14 0.243
256 32 512 8 83.67 329.81 0.254
256 32 512 512 843.42 4572.88 0.184
256 32 512 1024 1665.21 8587.96 0.194
256 32 512 2048 3390.91 16978.11 0.200
256 32 1024 1 75.91 628.88 0.121
256 32 1024 2 116.52 630.33 0.185
256 32 1024 4 118.84 628.09 0.189
256 32 1024 8 125.11 630.85 0.198
256 32 1024 512 1362.92 8857.97 0.154
256 32 1024 1024 2686.44 17106.66 0.157
256 32 1024 2048 5370.67 34220.45 0.157
256 32 2048 1 134.83 917.62 0.147
256 32 2048 2 195.35 916.88 0.213
256 32 2048 4 199.93 909.51 0.220
256 32 2048 8 205.00 909.46 0.225
256 32 2048 512 2232.24 17710.61 0.126
256 32 2048 1024 4480.89 34771.62 0.129
256 32 2048 2048 9132.45 69601.42 0.131
256 32 4096 1 258.21 1530.00 0.169
256 32 4096 2 342.99 1529.23 0.224
256 32 4096 4 357.43 1526.52 0.234
256 32 4096 8 358.46 1523.29 0.235
256 32 4096 512 4506.89 36281.69 0.124
256 32 4096 1024 9059.17 70992.35 0.128
256 32 4096 2048 18719.47 142732.73 0.131
M1 Pro
Head size Heads n_kv n_batch master us/run PR us/run speedup
64 32 512 1 171.72 145.12 1.183
64 32 512 2 158.36 122.98 1.288
64 32 512 4 169.14 127.02 1.332
64 32 512 8 192.28 125.90 1.527
64 32 512 512 1854.92 1516.88 1.223
64 32 512 1024 3205.20 2682.15 1.195
64 32 512 2048 6451.63 5339.80 1.208
64 32 1024 1 173.10 111.86 1.547
64 32 1024 2 175.84 109.55 1.605
64 32 1024 4 186.56 110.96 1.681
64 32 1024 8 193.02 112.15 1.721
64 32 1024 512 2820.42 2422.31 1.164
64 32 1024 1024 5655.72 4825.81 1.172
64 32 1024 2048 11300.35 9621.63 1.174
64 32 2048 1 328.60 202.96 1.619
64 32 2048 2 304.72 181.98 1.674
64 32 2048 4 315.31 149.04 2.116
64 32 2048 8 332.71 150.34 2.213
64 32 2048 512 5477.61 4578.76 1.196
64 32 2048 1024 11708.44 9115.65 1.284
64 32 2048 2048 24465.36 18190.31 1.345
64 32 4096 1 660.99 295.43 2.237
64 32 4096 2 564.48 268.22 2.105
64 32 4096 4 583.85 263.57 2.215
64 32 4096 8 634.65 265.49 2.390
64 32 4096 512 11739.99 8937.42 1.314
64 32 4096 1024 24629.73 17878.37 1.378
64 32 4096 2048 49174.03 35553.78 1.383
80 32 512 1 103.56 84.43 1.227
80 32 512 2 161.68 86.45 1.870
80 32 512 4 215.49 87.03 2.476
80 32 512 8 358.89 88.77 4.043
80 32 512 512 18589.89 1443.56 12.878
80 32 512 1024 37176.46 2862.85 12.986
80 32 512 2048 74406.37 5700.48 13.053
80 32 1024 1 180.91 126.16 1.434
80 32 1024 2 291.73 122.60 2.380
80 32 1024 4 388.21 123.62 3.140
80 32 1024 8 664.89 126.11 5.272
80 32 1024 512 36531.48 2574.14 14.192
80 32 1024 1024 73058.68 5121.40 14.265
80 32 1024 2048 146167.13 10211.46 14.314
80 32 2048 1 345.14 241.11 1.431
80 32 2048 2 543.55 170.14 3.195
80 32 2048 4 724.30 168.35 4.302
80 32 2048 8 1267.28 168.50 7.521
80 32 2048 512 72650.87 4885.48 14.871
80 32 2048 1024 146028.15 9746.80 14.982
80 32 2048 2048 293172.12 19514.49 15.023
80 32 4096 1 705.26 320.43 2.201
80 32 4096 2 1051.25 320.59 3.279
80 32 4096 4 1408.55 321.61 4.380
80 32 4096 8 2515.26 323.06 7.786
80 32 4096 512 145826.01 9616.87 15.164
80 32 4096 1024 292761.77 19220.51 15.232
80 32 4096 2048 585546.11 38428.58 15.237
96 32 512 1 105.99 89.76 1.181
96 32 512 2 128.04 91.61 1.398
96 32 512 4 139.17 92.39 1.506
96 32 512 8 152.91 93.62 1.633
96 32 512 512 2230.92 1658.70 1.345
96 32 512 1024 4455.13 3287.62 1.355
96 32 512 2048 8969.06 6543.33 1.371
96 32 1024 1 184.12 131.07 1.405
96 32 1024 2 214.95 123.40 1.742
96 32 1024 4 227.36 124.23 1.830
96 32 1024 8 242.90 126.34 1.923
96 32 1024 512 3787.77 3007.45 1.259
96 32 1024 1024 7609.07 5979.71 1.272
96 32 1024 2048 15195.30 11897.70 1.277
96 32 2048 1 355.71 247.04 1.440
96 32 2048 2 381.19 192.88 1.976
96 32 2048 4 393.55 188.02 2.093
96 32 2048 8 413.80 190.69 2.170
96 32 2048 512 7170.58 5742.46 1.249
96 32 2048 1024 15010.50 11424.02 1.314
96 32 2048 2048 31179.88 22777.09 1.369
96 32 4096 1 732.07 365.71 2.002
96 32 4096 2 723.18 369.12 1.959
96 32 4096 4 746.12 368.88 2.023
96 32 4096 8 794.66 371.86 2.137
96 32 4096 512 14802.02 11198.49 1.322
96 32 4096 1024 30692.86 22647.14 1.355
96 32 4096 2048 61414.68 45290.27 1.356
112 32 512 1 110.03 94.80 1.161
112 32 512 2 164.12 96.85 1.695
112 32 512 4 224.63 98.12 2.289
112 32 512 8 379.52 99.09 3.830
112 32 512 512 19778.95 1784.52 11.084
112 32 512 1024 39568.86 3540.50 11.176
112 32 512 2048 79228.59 7047.04 11.243
112 32 1024 1 191.73 141.67 1.353
112 32 1024 2 298.63 133.50 2.237
112 32 1024 4 402.43 134.75 2.986
112 32 1024 8 697.98 137.51 5.076
112 32 1024 512 38610.12 3261.74 11.837
112 32 1024 1024 77230.99 6463.59 11.949
112 32 1024 2048 154548.6 12847.12 12.030
112 32 2048 1 377.50 224.07 1.685
112 32 2048 2 559.62 221.53 2.526
112 32 2048 4 756.76 223.23 3.390
112 32 2048 8 1332.31 224.76 5.928
112 32 2048 512 76513.81 6267.96 12.207
112 32 2048 1024 153822.6 12480.19 12.325
112 32 2048 2048 308721.8 24994.85 12.351
112 32 4096 1 764.97 427.78 1.788
112 32 4096 2 1085.56 417.19 2.602
112 32 4096 4 1476.68 421.29 3.505
112 32 4096 8 2641.42 426.24 6.197
112 32 4096 512 153534.6 12365.98 12.416
112 32 4096 1024 307760.6 24717.05 12.451
112 32 4096 2048 615496.0 49432.70 12.451
128 32 512 1 99.07 103.97 0.953
128 32 512 2 138.02 106.05 1.301
128 32 512 4 146.58 107.98 1.357
128 32 512 8 161.30 109.31 1.476
128 32 512 512 2439.23 1998.11 1.221
128 32 512 1024 4894.82 3970.35 1.233
128 32 512 2048 9891.79 7891.88 1.253
128 32 1024 1 170.30 134.44 1.267
128 32 1024 2 229.90 135.96 1.691
128 32 1024 4 235.05 137.93 1.704
128 32 1024 8 258.60 139.78 1.850
128 32 1024 512 4104.27 3614.97 1.135
128 32 1024 1024 8242.82 7162.60 1.151
128 32 1024 2048 16558.43 14240.18 1.163
128 32 2048 1 349.25 262.17 1.332
128 32 2048 2 411.53 257.67 1.597
128 32 2048 4 424.24 259.95 1.632
128 32 2048 8 446.55 261.23 1.709
128 32 2048 512 7650.29 6922.82 1.105
128 32 2048 1024 16121.95 13767.31 1.171
128 32 2048 2048 33444.10 27602.74 1.212
128 32 4096 1 694.53 477.46 1.455
128 32 4096 2 781.76 468.61 1.668
128 32 4096 4 806.22 472.49 1.706
128 32 4096 8 855.88 476.62 1.796
128 32 4096 512 15799.73 13656.81 1.157
128 32 4096 1024 32739.61 27299.12 1.199
128 32 4096 2048 65542.43 54540.84 1.202

The following tests show the advantage of skipping -INF blocks in the mask. This occurs during batched decoding and non-shared prompts (e.g. server slots). For large contexts and a batch size of 8, the TG speed is ~1.5x faster compared to master, since we avoid a large amount of the "cross-sequence" attention compute that we have due to the unified KV cache:

make -j batched-bench && ./bin/batched-bench ../models/mistral-7b/ggml-model-f16.gguf 35000 0 999 0 512,1024,2048,4092 128 1,2,4,8
M2 Ultra

master

| PP | TG | B | N_KV | T_PP (s) | S_PP (t/s) | T_TG (s) | S_TG (t/s) | T (s) | S (t/s) |
| -- | -- | - | ---- | -------- | ---------- | -------- | ---------- | ----- | ------- |
| 512 | 128 | 1 | 640 | 0.384 | 1332.29 | 3.388 | 37.78 | 3.772 | 169.68 |
| 512 | 128 | 2 | 1280 | 0.765 | 1337.80 | 8.241 | 31.07 | 9.006 | 142.13 |
| 512 | 128 | 4 | 2560 | 1.564 | 1309.74 | 8.535 | 59.99 | 10.099 | 253.50 |
| 512 | 128 | 8 | 5120 | 3.292 | 1244.24 | 9.118 | 112.31 | 12.410 | 412.58 |
| 1024 | 128 | 1 | 1152 | 0.764 | 1340.59 | 3.511 | 36.45 | 4.275 | 269.46 |
| 1024 | 128 | 2 | 2304 | 1.563 | 1310.05 | 8.480 | 30.19 | 10.044 | 229.40 |
| 1024 | 128 | 4 | 4608 | 3.291 | 1244.62 | 8.996 | 56.91 | 12.287 | 375.03 |
| 1024 | 128 | 8 | 9216 | 7.329 | 1117.68 | 10.083 | 101.56 | 17.412 | 529.29 |
| 2048 | 128 | 1 | 2176 | 1.564 | 1309.26 | 3.702 | 34.58 | 5.266 | 413.22 |
| 2048 | 128 | 2 | 4352 | 3.291 | 1244.69 | 8.928 | 28.67 | 12.219 | 356.17 |
| 2048 | 128 | 4 | 8704 | 7.328 | 1117.87 | 9.935 | 51.54 | 17.263 | 504.20 |
| 2048 | 128 | 8 | 17408 | 18.209 | 899.77 | 11.950 | 85.69 | 30.159 | 577.21 |
| 4092 | 128 | 1 | 4220 | 3.294 | 1242.34 | 4.099 | 31.23 | 7.393 | 570.84 |
| 4092 | 128 | 2 | 8440 | 7.329 | 1116.71 | 9.853 | 25.98 | 17.181 | 491.24 |
| 4092 | 128 | 4 | 16880 | 18.179 | 900.38 | 11.831 | 43.28 | 30.010 | 562.48 |
| 4092 | 128 | 8 | 33760 | 57.250 | 571.80 | 15.917 | 64.33 | 73.168 | 461.41 |

PR

| PP | TG | B | N_KV | T_PP (s) | S_PP (t/s) | T_TG (s) | S_TG (t/s) | T (s) | S (t/s) |
| -- | -- | - | ---- | -------- | ---------- | -------- | ---------- | ----- | ------- |
| 512 | 128 | 1 | 640 | 0.379 | 1351.74 | 3.443 | 37.17 | 3.822 | 167.45 |
| 512 | 128 | 2 | 1280 | 0.731 | 1400.37 | 8.117 | 31.54 | 8.848 | 144.66 |
| 512 | 128 | 4 | 2560 | 1.486 | 1377.78 | 8.244 | 62.10 | 9.731 | 263.09 |
| 512 | 128 | 8 | 5120 | 3.107 | 1318.17 | 8.424 | 121.56 | 11.531 | 444.01 |
| 1024 | 128 | 1 | 1152 | 0.722 | 1418.04 | 3.485 | 36.73 | 4.207 | 273.82 |
| 1024 | 128 | 2 | 2304 | 1.469 | 1393.88 | 8.214 | 31.17 | 9.683 | 237.93 |
| 1024 | 128 | 4 | 4608 | 3.106 | 1318.78 | 8.473 | 60.43 | 11.579 | 397.97 |
| 1024 | 128 | 8 | 9216 | 6.937 | 1180.91 | 8.748 | 117.06 | 15.685 | 587.58 |
| 2048 | 128 | 1 | 2176 | 1.470 | 1393.22 | 3.599 | 35.57 | 5.069 | 429.31 |
| 2048 | 128 | 2 | 4352 | 3.106 | 1318.69 | 8.439 | 30.33 | 11.545 | 376.94 |
| 2048 | 128 | 4 | 8704 | 6.930 | 1182.06 | 8.936 | 57.30 | 15.866 | 548.59 |
| 2048 | 128 | 8 | 17408 | 16.801 | 975.16 | 9.423 | 108.67 | 26.224 | 663.82 |
| 4092 | 128 | 1 | 4220 | 3.111 | 1315.26 | 3.816 | 33.54 | 6.927 | 609.20 |
| 4092 | 128 | 2 | 8440 | 6.934 | 1180.19 | 8.907 | 28.74 | 15.842 | 532.77 |
| 4092 | 128 | 4 | 16880 | 16.785 | 975.16 | 9.883 | 51.81 | 26.668 | 632.97 |
| 4092 | 128 | 8 | 33760 | 45.540 | 718.85 | 10.802 | 94.80 | 56.342 | 599.20 |

Additionally, some results for prompt processing at different batch sizes and model sizes:

make -j llama-bench && ./bin/llama-bench -m ../models/llama-7b-v2/ggml-model-f16.gguf -m ../models/llama-13b-v2/ggml-model-f16.gguf -p 1024,2048,4096,8192 -b 512,1024,2048,4096,8192 -ngl 99
M2 Ultra
model backend n_batch test master t/s PR t/s speedup
llama 7B F16 Metal 512 pp 1024 1408.56 ± 1.98 1444.83 ± 2.00 1.026
llama 7B F16 Metal 512 pp 2048 1372.74 ± 0.54 1402.69 ± 1.29 1.022
llama 7B F16 Metal 512 pp 4096 1303.37 ± 1.22 1317.79 ± 0.73 1.011
llama 7B F16 Metal 512 pp 8192 1163.45 ± 0.33 1166.52 ± 0.57 1.003
llama 7B F16 Metal 512 tg 128 41.76 ± 0.07 41.87 ± 0.08 1.003
llama 7B F16 Metal 1024 pp 1024 1475.37 ± 3.67 1524.12 ± 3.29 1.033
llama 7B F16 Metal 1024 pp 2048 1439.38 ± 0.99 1478.06 ± 1.58 1.027
llama 7B F16 Metal 1024 pp 4096 1362.46 ± 1.20 1386.09 ± 0.86 1.017
llama 7B F16 Metal 1024 pp 8192 1210.67 ± 1.01 1228.48 ± 0.81 1.015
llama 7B F16 Metal 1024 tg 128 41.86 ± 0.05 41.95 ± 0.01 1.002
llama 7B F16 Metal 2048 pp 1024 1476.71 ± 1.78 1527.52 ± 1.96 1.034
llama 7B F16 Metal 2048 pp 2048 1444.91 ± 2.61 1486.82 ± 1.98 1.029
llama 7B F16 Metal 2048 pp 4096 1359.67 ± 1.37 1392.01 ± 3.13 1.024
llama 7B F16 Metal 2048 pp 8192 1203.30 ± 0.88 1233.46 ± 0.39 1.025
llama 7B F16 Metal 2048 tg 128 41.81 ± 0.04 41.90 ± 0.03 1.002
llama 7B F16 Metal 4096 pp 1024 1476.58 ± 1.68 1526.93 ± 2.56 1.034
llama 7B F16 Metal 4096 pp 2048 1445.21 ± 2.67 1487.42 ± 1.19 1.029
llama 7B F16 Metal 4096 pp 4096 1290.82 ± 2.85 1338.66 ± 3.77 1.037
llama 7B F16 Metal 4096 pp 8192 1148.23 ± 2.84 1194.02 ± 2.38 1.040
llama 7B F16 Metal 4096 tg 128 41.86 ± 0.03 41.85 ± 0.02 1.000
llama 7B F16 Metal 8192 pp 1024 1477.47 ± 0.70 1527.04 ± 1.24 1.034
llama 7B F16 Metal 8192 pp 2048 1445.74 ± 1.11 1487.39 ± 1.97 1.029
llama 7B F16 Metal 8192 pp 4096 1291.49 ± 2.80 1339.12 ± 3.03 1.037
llama 7B F16 Metal 8192 pp 8192 1036.40 ± 3.29 1101.68 ± 3.15 1.063
llama 7B F16 Metal 8192 tg 128 41.83 ± 0.05 41.76 ± 0.06 0.998
llama 13B F16 Metal 512 pp 1024 759.97 ± 0.34 775.94 ± 0.46 1.021
llama 13B F16 Metal 512 pp 2048 743.75 ± 0.32 755.90 ± 0.11 1.016
llama 13B F16 Metal 512 pp 4096 711.48 ± 0.11 716.48 ± 0.17 1.007
llama 13B F16 Metal 512 pp 8192 645.61 ± 0.15 646.88 ± 0.06 1.002
llama 13B F16 Metal 512 tg 128 22.37 ± 0.02 22.53 ± 0.01 1.007
llama 13B F16 Metal 1024 pp 1024 781.60 ± 0.65 805.34 ± 0.22 1.030
llama 13B F16 Metal 1024 pp 2048 765.52 ± 0.62 784.03 ± 0.38 1.024
llama 13B F16 Metal 1024 pp 4096 732.14 ± 0.15 742.55 ± 0.27 1.014
llama 13B F16 Metal 1024 pp 8192 662.75 ± 0.33 670.83 ± 0.11 1.012
llama 13B F16 Metal 1024 tg 128 22.36 ± 0.01 22.52 ± 0.01 1.007
llama 13B F16 Metal 2048 pp 1024 781.22 ± 0.97 804.82 ± 0.77 1.030
llama 13B F16 Metal 2048 pp 2048 768.84 ± 0.58 788.66 ± 0.08 1.026
llama 13B F16 Metal 2048 pp 4096 731.17 ± 0.52 745.80 ± 0.28 1.020
llama 13B F16 Metal 2048 pp 8192 655.80 ± 0.39 675.39 ± 0.28 1.030
llama 13B F16 Metal 2048 tg 128 22.37 ± 0.02 22.52 ± 0.02 1.007
llama 13B F16 Metal 4096 pp 1024 781.29 ± 1.11 805.09 ± 0.96 1.030
llama 13B F16 Metal 4096 pp 2048 768.68 ± 0.84 788.46 ± 0.50 1.026
llama 13B F16 Metal 4096 pp 4096 696.49 ± 0.86 722.03 ± 1.01 1.037
llama 13B F16 Metal 4096 pp 8192 626.14 ± 0.96 658.53 ± 0.63 1.052
llama 13B F16 Metal 4096 tg 128 22.39 ± 0.02 22.52 ± 0.02 1.006
llama 13B F16 Metal 8192 pp 1024 781.37 ± 0.53 804.94 ± 1.53 1.030
llama 13B F16 Metal 8192 pp 2048 768.90 ± 0.45 788.89 ± 0.36 1.026
llama 13B F16 Metal 8192 pp 4096 696.26 ± 0.79 721.58 ± 0.87 1.036
llama 13B F16 Metal 8192 pp 8192 570.26 ± 0.97 615.81 ± 1.01 1.080
llama 13B F16 Metal 8192 tg 128 22.34 ± 0.07 22.53 ± 0.02 1.009

build: b68a1122 (2086)

ggerganov (Jan 29 '24)

After playing with this kernel for some time, I'm more convinced we should put effort into implementing a matrix multiplication kernel that works for src1->ne[1] <= 8 (Metal) / 16 (CUDA) in order to solve the inefficient batched decoding that we currently have at small batch sizes. That kernel would pad the missing src1 rows with zeros and use the built-in matrix types (8x8 for Metal and 16x16 for CUDA) - this should give approximately constant speed across the different batch sizes. Ideally, it would be templated with a dequantization function so it works for all data types, and it would be so performant that we could drop the mat-vec kernels altogether.
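A minimal sketch of the zero-padding idea (hypothetical illustration, names assumed):

```c
#include <string.h>

// Copy an ne10 x ne11 block of src1 (ne11 = batch size, 1..tile_w) into a zero-padded
// ne10 x tile_w tile, so the same 8x8 (Metal) / 16x16 (CUDA) mat-mat path can be used
// for every small batch size at roughly constant cost.
static void load_src1_tile(const float * src1, float * tile, int ne10, int ne11, int tile_w) {
    memset(tile, 0, sizeof(float) * (size_t) ne10 * tile_w); // zero the padded rows
    for (int j = 0; j < ne11; ++j) {
        memcpy(tile + (size_t) j*ne10, src1 + (size_t) j*ne10, sizeof(float) * ne10);
    }
}
```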

ggerganov (Jan 29 '24)

@ggerganov in the function simdgroup_load, what is the last parameter passed? I assume it's the stride of the data or am I wrong?

// load the queries from shared memory into local memory
simdgroup_half8x8 mq[Q8][D8];

for (int64_t j = 0; j < Q8; ++j) {
    for (int64_t i = 0; i < D8; ++i) {
        simdgroup_load(mq[j][i], sq + 8*j*T + i*8, T); // what is the third parameter?
    }
}

FSSRepo (Jan 29 '24)

> in the function simdgroup_load, what is the last parameter passed? I assume it's the stride of the data or am I wrong?

The Metal shading language spec calls it elements_per_row, which defaults to the number of columns in the destination matrix (8 in this case).

The elements_per_row parameter indicates the number of elements in the source memory layout.

cebtenzzre (Jan 29 '24)

Yes, it's the stride of the row in the source buffer (i.e. the shared memory buffer sq that holds the queries). It is specified as a number of elements (i.e. number of halfs). The same goes for simdgroup_store, but there the stride refers to the destination array.

ggerganov (Jan 30 '24)

@ggerganov Is this behavior expected? Did you realize that, written like this, all the elements of the array end up as negative infinity in Metal?

[screenshot: Screenshot 2024-01-30 141733]

FSSRepo (Jan 30 '24)

It's a bug - thanks for spotting it. Should be fixed in d073e4f

ggerganov (Jan 30 '24)

@ggerganov I have been examining the kernel I created in CUDA, but it produces incorrect values despite all the operations being exactly the same. I really want to ask for your help, but I'm not sure if you have the time to at least take a look and compare your code with mine to see if I missed something or if I'm just doing something wrong. link to cuda implementation

FSSRepo (Jan 30 '24)

Cool, will take a detailed look tomorrow. At first glance I suspect a misconfiguration of the matrix layouts (row/col major), as I wrote in the comments there.

ggerganov (Jan 30 '24)