ggml : add Flash Attention
ref #3365
Setting up what's needed for Flash Attention support in ggml and llama.cpp.
The proposed operator performs:
```c
// new
res = ggml_flash_attn(ctx, q, k, v, kq_mask, kq_scale);

// fused scale + mask + soft_max (old)
kq  = ggml_mul_mat     (ctx, k, q);
kq  = ggml_soft_max_ext(ctx, kq, kq_mask, kq_scale);
kqv = ggml_mul_mat     (ctx, v, kq);
kqv = ggml_permute     (ctx, kqv, 0, 2, 1, 3);
res = ggml_cont_2d     (ctx, kqv, n_embd_head_k*n_head, n_tokens);

// unfused (old)
kq  = ggml_mul_mat (ctx, k, q);
kq  = ggml_scale   (ctx, kq, kq_scale);
kq  = ggml_add     (ctx, kq, kq_mask);
kq  = ggml_soft_max(ctx, kq);
kqv = ggml_mul_mat (ctx, v, kq);
kqv = ggml_permute (ctx, kqv, 0, 2, 1, 3);
res = ggml_cont_2d (ctx, kqv, n_embd_head_k*n_head, n_tokens);
```
Suggestions and comments for the API are welcome. Looking for help in implementing efficient GPU kernels - please open a PR against this branch if you have proposals.
- [x] `ggml` API: `ggml_flash_attn_ext()`
- [x] `llama.cpp` use in `llm_build_kqv()`
- [x] add `test-backend-ops` test
- [x] CPU implementation (slow, just for testing)
- [x] CUDA implementation (https://github.com/ggerganov/llama.cpp/pull/6374)
- [x] Metal implementation
- [x] `GGML_PREC_F32` support (CUDA) (https://github.com/ggerganov/llama.cpp/pull/6646)
- [x] `GGML_PREC_F32` support (Metal)
Changes to ggml / llama
- Add new op `GGML_OP_FLASH_ATTN_EXT` and `ggml_flash_attn_ext()` call (before merging we can consider reusing the old `GGML_OP_FLASH_ATTN` and removing the legacy code)
- Change the `mask` type to F16 for `ggml_soft_max_ext()` and require that it is padded to `GGML_KQ_MASK_PAD 32` (a padding sketch follows this list)
- The `n_kv`, denoting the number of computed tokens from the KV cache, is now padded to 128 (from 32) to support larger FA blocks without out-of-bounds accesses
- The minimum `llama_context_params.n_batch` that can be used is `GGML_KQ_MASK_PAD 32`, to avoid out-of-bounds access in the FA kernels for small batch sizes
- The `V` tensor is no longer transposed when storing it in the KV cache
- The input buffer is cleared with zeros to avoid NaNs in the padded tensors
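For illustration only, here is a minimal sketch of how the padding rules above might be applied when allocating the mask. The helper name and call site are invented for the example; `GGML_PAD` is the existing helper macro from ggml.h, and `GGML_KQ_MASK_PAD` is the constant introduced by this change:

```c
#include "ggml.h"

// hypothetical helper, not the PR's actual code
static struct ggml_tensor * build_kq_mask(struct ggml_context * ctx, int n_kv, int n_tokens) {
    const int n_kv_pad     = GGML_PAD(n_kv, 128);                  // KV length padded so larger FA blocks stay in bounds
    const int n_tokens_pad = GGML_PAD(n_tokens, GGML_KQ_MASK_PAD); // batch dimension padded to GGML_KQ_MASK_PAD

    // F16 mask, as now required by ggml_soft_max_ext() / ggml_flash_attn_ext();
    // the padded region must be cleared to avoid NaNs in the padded tensors
    return ggml_new_tensor_2d(ctx, GGML_TYPE_F16, n_kv_pad, n_tokens_pad);
}
```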
Things to consider
- Pass KQ list with/instead of KQ mask
- Pass block-wise KQ mask
- Support Alibi
- Finally transform Alibi as `ggml_add()`? (low-prio)
- No longer store transposed V-cache (gg/flash-attn-online)
Testing
./tests/test-backend-ops -o FLASH_ATTN_EXT
- `main`, `server`: add `-fa`
- `llama-bench`: add `-fa 1`
Benchmark
Baseline:
# CUDA
LLAMA_CUBLAS=1 make -j tests && ./tests/test-backend-ops -o ATTN -b CUDA0 perf
# Metal
make -j tests && ./tests/test-backend-ops -o ATTN -b Metal perf
FA kernel:
# CUDA
LLAMA_CUBLAS=1 make -j tests && ./tests/test-backend-ops -o FLASH_ATTN_EXT -b CUDA0 perf
# Metal
make -j tests && ./tests/test-backend-ops -o FLASH_ATTN_EXT -b Metal perf
Text-generation after long prompt:
# without flash attention
./batched-bench models/mistral-instruct-7b-v0.2/ggml-model-f16.gguf 10000 2048 512 0 1 99 8192 256 1
# with flash attention
./batched-bench models/mistral-instruct-7b-v0.2/ggml-model-f16.gguf 10000 2048 512 1 1 99 8192 256 1
References
- https://arxiv.org/pdf/1805.02867.pdf - Online softmax
- https://arxiv.org/pdf/2112.05682.pdf - O(n) memory self-attention
- https://arxiv.org/pdf/2307.08691.pdf - FlashAttention-2
Since we are doing this from scratch, wouldn't it be better to remove the custom attention mask entirely and pass a list of KV cells used in each sequence? Considering our implementation of batching, I think we should be looking at implementing something closer to paged attention rather than flash attention. I suppose it is possible to convert the mask to a list of sequences in the kernels, but it would be less efficient.
Yes, we can pass a list instead of the mask. I am not sure about the format though - if each list has a different length, I feel it will hinder GPU performance.
Edit: I just got an idea - we can pass both the `kq_mask` as it is, plus a second boolean tensor that tells each token to which KV blocks it should attend. For example, we split the KV cache into blocks of 128 (or some other round number), and a token (i.e. a row in `q`) attends to a block if at least one of the cells in it belongs to the token's sequence. This way, we can skip entire blocks of the KV cache that do not belong to the current sequence and keep the problem parallel-friendly. Thoughts?
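To make this concrete, here is a rough host-side sketch of deriving such a per-block table from the existing `-INF` mask, with the KV cache split into blocks of 128 cells (all names and the layout are invented for the example, they are not part of this PR):

```c
#include <math.h>
#include <stdbool.h>

#define KV_BLOCK 128

// block_mask[i*n_blocks + b] == true  =>  query row i attends to KV block b
static void build_block_mask(bool * block_mask, const float * kq_mask, int n_tokens, int n_kv) {
    const int n_blocks = (n_kv + KV_BLOCK - 1)/KV_BLOCK;
    for (int i = 0; i < n_tokens; ++i) {
        for (int b = 0; b < n_blocks; ++b) {
            const int j0 = b*KV_BLOCK;
            const int j1 = j0 + KV_BLOCK < n_kv ? j0 + KV_BLOCK : n_kv;
            bool attend = false;
            for (int j = j0; j < j1 && !attend; ++j) {
                attend = !isinf(kq_mask[i*n_kv + j]); // at least one cell is visible to this query
            }
            block_mask[i*n_blocks + b] = attend;
        }
    }
}
```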
We could use a vector with dimension `[num_seqs]` that contains the length of the sequences, and a 2D tensor with dimensions `[max_seq_len, num_seqs]` that contains the KV cells in each sequence, padded to the length of the longest sequence.
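For illustration, such a layout could look roughly like this on the host side (the struct and sizes below are invented for the example; the cell indices are stored contiguously per sequence, matching the `[max_seq_len, num_seqs]` convention mentioned above):

```c
#include <stdint.h>

#define NUM_SEQS    8
#define MAX_SEQ_LEN 4096

// seq_len[s]      - number of KV cells used by sequence s
// seq_cells[s][i] - i-th KV cell index of sequence s; entries past seq_len[s] are padded (e.g. with -1)
typedef struct {
    int32_t seq_len  [NUM_SEQS];
    int32_t seq_cells[NUM_SEQS][MAX_SEQ_LEN];
} kv_seq_index;
```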
It seems that vLLM has added a new version of paged attention since I looked into the implementation (https://github.com/vllm-project/vllm/pull/1348). I am not sure what the changes are, but I think it is worth looking into what they are doing. The kernel is in https://github.com/vllm-project/vllm/blob/main/csrc/attention/attention_kernels.cu
Alibi could also be done in this kernel.
Regarding the Alibi, I feel that reinterpreting it as a `KQ_mask` via `ggml_add()` is a more general solution - we will avoid having a `ggml_alibi()` operator and explicit support in the kernels that we write (like in vLLM).
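As a rough illustration of this idea (the function and layout are invented for the example; the slope formula is the common ALiBi choice for a power-of-two number of heads), the bias can be baked into the additive mask on the host so the kernels only ever see a single mask:

```c
#include <math.h>

// builds, for one head, a mask that combines the causal -INF entries with the ALiBi bias
// (assuming, for simplicity, that query row i corresponds to KV position i)
static void build_alibi_mask(float * mask, int n_tokens, int n_kv, int n_head, int head) {
    const float slope = powf(2.0f, -8.0f*(head + 1)/n_head);
    for (int i = 0; i < n_tokens; ++i) {
        for (int j = 0; j < n_kv; ++j) {
            mask[i*n_kv + j] = (j > i) ? -INFINITY : -slope*(i - j);
        }
    }
}
```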
It remains to be seen though whether the `KQ_mask` will be a bottleneck - my feeling is that just avoiding the extra read/write of KQ will bring us close to optimal performance, even with the existing "cross-KV compute" drawback.
I will take a look at the vLLM code. I've updated the description with some of the things from this discussion.
@ggerganov @slaren Together with @JohannesGaessler and @FSSRepo we are working on the same thing over at https://github.com/Pints-App/llama.cpp/pull/1, which we intend to submit as a PR to llama.cpp once the work is done.
However, I think we will converge with this one. Given the amount of work here, @ggerganov @slaren how do you want to organise this? The three of us are actually in a temporary Discord group to work this out - perhaps we can use that?
What are your thoughts?
Discord is not an option for me - I prefer to communicate over GitHub issues / discussions / e-mail.
Happy to see you have started work on the CUDA implementation. Please take into account the API proposed here - note that it is still a WIP and can change. I can review your implementation when you think it is in a good state. I would prefer PRs that are compatible with this branch, so we can verify correctness using `test-backend-ops` and support all backends.
@ggerganov Got it. Let us work on a plan to converge with this PR.
~~`test-backend-ops -o FLASH_ATTN_EXT` fails for Metal on my M2 Pro, is this known?~~
edit: I see, not implemented yet.
Any performance numbers?
There is now an initial version for Metal - see the `kernel_flash_attn_ext_f16` kernel. It's already slightly faster for TG after a long prompt, but the PP speed is ~15% lower compared to `master`. Still looking into it - any ideas are appreciated.
I think the CPU is already better, but I haven't done too many tests in this regard, as I'm more interested in the GPU performance.
This is based on reading the 3 papers in the description and implementing my understanding of what Flash-attention / Flash-decoding is supposed to do. I guess there is some chance that I'm still misunderstanding something and that's why I can't improve the performance universally.
I've changed the API so that the `V` tensor is no longer stored transposed in the KV cache. Since we no longer rely on `ggml_mul_mat` in the attention, this is no longer required, and it would allow for easier quantization of the V cache in the future since the head dimensions are quite often a multiple of 32. Still contemplating this, but I think it will be a good change overall:
https://github.com/ggerganov/llama.cpp/blob/17720fad669eed6171ddf17184da5bab50adeb72/ggml.h#L1622-L1635
The `ggml_flash_attn_ext` also avoids the final permute + cont that we normally have - this is built into the kernels.
I'm currently looking into the performance for a full `KQ_mask`, so the format of the masked indices is not relevant yet. The plan is to get to something that is faster without masking, and only after that will I start looking into improving this.
If anyone is interested in porting this implementation to CUDA - feel free to give it a try. For the moment I'm focusing on Metal as it is more convenient for me to develop
Great job!! I'm struggling to achieve performance improvement in CUDA because I'm having issues where 90% of the kernel execution time is spent on memory I/O, and the remaining 10% on computation.
> It's already slightly faster for TG after a long prompt, but the PP speed is ~15% lower compared to master. Still looking into it - any ideas are appreciated.

> I'm struggling to achieve performance improvement in CUDA because I'm having issues where 90% of the kernel execution time is spent on memory I/O, and the remaining 10% on computation.
My experience is that writing GPU matrix multiplication code for cases in which at least one of the matrices is thin in at least one of its dimensions is comparatively easy. No matter what you do, the implementation is going to be I/O bound anyways and a simple implementation based on dot products is going to be close to optimal. If you have large matrices the operation becomes compute rather than I/O bound however. Now the challenge is to utilize the compute pipelines by loading the data in such a way that arithmetic intensity becomes maximal.
The speedup reported by the FlashAttention paper was I think ~15%. This is a significant speedup but it also means that to get any speedup at all you would need to base your FlashAttention implementation off of a matrix multiplication kernel that achieves performance comparable to cuBLAS GEMM. This is no easy task.
I am currently working on matrix multiplication kernels for quantized data based on int8 arithmetic: https://github.com/ggerganov/llama.cpp/pull/4801 . They are already faster than cuBLAS FP16 GEMM and can be used in the future as the basis for custom matrix multiplication kernels. Conceivably one of those custom matrix multiplication kernels could be FlashAttention.
> If anyone is interested in porting this implementation to CUDA - feel free to give it a try. For the moment I'm focusing on Metal as it is more convenient for me to develop
I will not get to it in the foreseeable future. My priority is to get the int8 matrix multiplication in order, after that the quantization of the KV cache, and after that I want to look into training in llama.cpp. If it turns out that FlashAttention would be useful there I will maybe look into it but definitely no promises in terms of timeline or features.
Having a fused kernel allows you to skip reading and writing KQ to global memory twice and, via masking, to avoid computing half the operations compared to the matrix-multiplication based implementation, and you can also apply online softmax. So even if the fused kernel is not cuBLAS-level optimized, I was hoping to outperform the existing implementation.
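For a rough sense of scale (illustrative numbers, not measurements from this PR): with an F16 KQ tensor, $n_{head} = 32$ and $n_{batch} = n_{kv} = 4096$, the intermediate KQ tensor alone is

$$4096 \cdot 4096 \cdot 32 \cdot 2\ \text{bytes} \approx 1\ \text{GiB},$$

so skipping two extra reads and two extra writes of it avoids on the order of 4 GiB of global memory traffic for a single attention pass.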
I mean, I don't have the hardware to profile the Metal code but you should check how long the kernel takes compared to the equivalent GEMM kernels. In my experience getting within 50% of the performance of highly optimized GEMM code is already not easy. And if your kernel is slower than 50% GEMM then the total runtime will increase even if you do only half as many calculations. In my int8 PR the actual matrix multiplication kernel is currently only ~10% faster than the cuBLAS FP16 GEMM kernel even though int8 tensor cores are in theory twice as fast as FP16 tensor cores.
For large batch sizes the amount of data read from and written to global memory should not make much of a difference in terms of inference performance since the matrices should have size $O(N^2)$ but the computation needs $O(N^3)$ operations.
After implementing the kernel with simdgroup matrix ops, it is now universally better than the `master` version. Tested with head size 128 - it might need some extra work for head sizes that are not divisible by 32, but it also might just work.
Next step will be to port this implementation in CUDA - will get into this in a few days after catching up a bit with other issues and if it hasn't been implemented yet by other people
> After implementing the kernel with simdgroup matrix ops, it is now universally better than the `master` version. Tested with head size 128 - it might need some extra work for head sizes that are not divisible by 32, but it also might just work. Next step will be to port this implementation in CUDA - will get into this in a few days after catching up a bit with other issues and if it hasn't been implemented yet by other people
Can you give some example benchmark numbers on master vs this PR with Metal?
It's early for numbers - with the current block size of 8x32 (8 queries x 32 cache items) it's marginally better. I'm just glad that the performance finally makes sense and this looks like the correct way to implement this kernel
@ggerganov I'm trying to port the kernel you already have in Metal to CUDA, but I'm not completely clear on how it works yet. So, I would appreciate it if you could help me clarify my understanding.
- I understand that for each thread group, 8 queries are processed (batch size). For every thread group, there are 2 warps (`simdgroups`). Then, if the batch size is greater than 4 (number of queries), each warp handles one query. Is that right?
- Could you explain the data layout in shared memory?
```c
threadgroup half  * pq  = (threadgroup half  *) (shared + 0*D);                   // offset: 0, size: head_dim
threadgroup half4 * pq4 = (threadgroup half4 *) (shared + 0*D);
threadgroup half  * ps  = (threadgroup half  *) (shared + sgitg*(D + 1*C) + 1*D); // offset: head_dim + (warp_id * (head_dim + cache_per_warp)), size: head_dim + cache_per_warp
threadgroup half4 * ps4 = (threadgroup half4 *) (shared + sgitg*(D + 1*C) + 1*D);

// In the code, it implies that it has a size of (head_dim + warps*(head_dim + cache_per_warp)) * queries per threadgroup.
threadgroup half  * ss  = (threadgroup half  *) (shared + sgitg*(D + 1*C) + 2*D); // offset: head_dim*2 + (warp_id * (head_dim + cache_per_warp)), size: ??
```
- In CUDA, there is no vector type `half4`, only `half` and `half2`. Therefore, I will use the latter when copying data from VRAM to SRAM.
```c
const int64_t D4 = D/4; // changing this to D/2 to allow half2 when copying data?
const int64_t N4 = N_SIMDWIDTH;
const int64_t L4 = (D4 + N4 - 1)/N4;
const int64_t D8 = D/8;

const int64_t T  = D + nsg*(D + 1*C); // shared memory size per query in half
const int64_t T4 = T/4;               // shared memory size per query in half4 <-- this too, T/2
```
- Tensor cores in CUDA require a 16x16x16 multiplication, whereas the multiplication you are performing here is 8x8x8. So, I assume you will need to make the following changes:
```c
// change queries per threadgroup (Q) to 16 to get a 16x16x16 multiplication;
// head_dim and batch_size must be a multiple of 16 to use tensor cores, or add padding
// for alignment to avoid overflow when loading the data into the tensor core fragments
const int64_t L4 = (D4 + N4 - 1)/N4;
const int64_t D8 = D/8; // change this to 16
```
- `mqk = make_filled_simdgroup_matrix<half, Q>(0.h);` - does this create a QxQ matrix or a vector with Q elements?
- In the online softmax loop, it seems like this loop does nothing, since it only multiplies `ms` by an array of zeros:
```c
// online softmax
for (int64_t j = 0; j < Q; ++j) {
    ...
    for (int64_t i = 0; i < L4; ++i) {
        ps4[j*T4 + N4*i + tiisg] *= ms; // This does nothing since ps4 has only been initialized to 0.0 so far.
    }
    ss[j*T + p] = vs;
}
```
Any other questions I have, I will let you know. Thank you!
@FSSRepo Thanks for looking into this. Here is some information:
- The threadgroup can have a configurable number of warps:
https://github.com/ggerganov/llama.cpp/blob/6fea843b246409a3c4b26156745a89e4ba01029b/ggml-metal.m#L2255-L2258
It is yet to be determined how many warps to set - I think the paper recommends 4 or 8. But in any case, the kernel should support a configurable number of warps.
Each warp in the threadgroup works on the exact same batch of 8 queries - one head at a time (in the `q` query tensor, the batch index is 1 and the head index is 2). Each warp processes `1/nsg` of the KV cache:
https://github.com/ggerganov/llama.cpp/blob/6fea843b246409a3c4b26156745a89e4ba01029b/ggml-metal.metal#L2129-L2130
Here `ne11` is the size of the KV cache, `nsg` is the number of warps (a.k.a. simdgroups) in the threadgroup and ~~`C = 8`~~ `C = 32` is the number of cache items processed on each iteration. The warps work completely independently from each other thanks to the online softmax - they need to be synchronized just at the end, when we reduce their results. ~~For CUDA, `C` likely has to be 16.~~
- The total shared memory required by this kernel is:
https://github.com/ggerganov/llama.cpp/blob/6fea843b246409a3c4b26156745a89e4ba01029b/ggml-metal.m#L2263
- `nqptg` - number of queries per threadgroup (i.e. currently 8, but for CUDA likely to be 16)
- `ne00` - the head size (e.g. 128 for LLaMA, Mistral, etc.)
- `nsg` - number of warps (configurable)
- `ncpsg` - number of cache values per simdgroup (i.e. 32, has to be a multiple of the warp size)
On Metal, we have 32KB of shared memory per threadgroup, so here are some possible configurations:
nqptg | ne00 | nsg | ncpsg | smem (bytes) |
---|---|---|---|---|
8 | 128 | 4 | 32 | 12288 |
8 | 128 | 8 | 32 | 22528 |
8 | 128 | 12 | 32 | 32768 |
8 | 128 | 4 | 64 | 14336 |
8 | 128 | 8 | 64 | 26624 |
8 | 128 | 4 | 96 | 16384 |
8 | 128 | 8 | 96 | 30720 |
8 | 128 | 4 | 128 | 18432 |
16 | 128 | 4 | 32 | 24576 |
16 | 128 | 4 | 64 | 28672 |
16 | 128 | 4 | 96 | 32768 |
Not sure about CUDA - hopefully it has more shared memory so we can fit more queries and cache values per threadgroup.
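As a sanity check, the shared-memory sizes in the table can be reproduced from the per-query size `T = D + nsg*(D + 1*C)` shown earlier; this small standalone program is illustrative, not code from the PR:

```c
#include <stdio.h>

int main(void) {
    const int nqptg = 8;   // queries per threadgroup
    const int ne00  = 128; // head size
    const int ncpsg = 32;  // cache values per simdgroup

    for (int nsg = 4; nsg <= 12; nsg += 4) { // number of warps
        const int smem = nqptg*(ne00 + nsg*(ne00 + ncpsg))*2; // *2 bytes per half
        printf("nqptg=%d ne00=%d nsg=%2d ncpsg=%d -> smem = %d bytes\n", nqptg, ne00, nsg, ncpsg, smem);
    }
    return 0;
}
```

This reproduces the 12288 / 22528 / 32768 byte rows of the table above.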
The reason for this shared memory layout is that for each query we need to load the head (this is `ne00` elements), and then for each warp we need a scratch buffer for storing the resulting head of `QKV`. We also need a small scratch buffer of size `ncpsg` to write the intermediate attention values from the `Q*K^T` result.
In theory the `nsg*(ne00 + 1*ncpsg)` part of the buffer could reside in the warp registers. But on Metal, there is no way to utilize the simd matrix operations in that case - they require loading and storing data to shared or device memory. I don't know what the situation is with CUDA, but one should look into moving this into registers if possible.
- Yes, use the biggest vector type available - `half2` in this case
- Yes, given that CUDA operates with 16x16 matrices, I suspect you should be using 16 queries per threadgroup, or a multiple of 16 queries if the SRAM allows it. I plan to extend the Metal kernel to support a configurable number of queries (multiple of 8). Head dimensions are always a multiple of 16, so this should be OK. The number of queries should be padded with zeros if not a multiple:
https://github.com/ggerganov/llama.cpp/blob/6fea843b246409a3c4b26156745a89e4ba01029b/ggml-metal.metal#L2055-L2063
- Yes, this creates an 8x8 matrix filled with zeros
- Notice that the `ps` and `ps4` pointers point to the same shared memory data. `ps` is necessary for loading the data into the 8x8 simd matrices (i.e. `simdgroup_load`), while `ps4` is needed for faster multiplication. The line that you referenced does the scaling of the ~~attention~~ output during online softmax - this is the first term of the output-update equation from the FA2 paper (a scalar sketch of this update follows after this list).
Depending on what matrix operators are available in CUDA, this could be represented in matrix form, and it might even be possible to avoid writing and reading these intermediate results to the `ncpsg` shared memory buffer. But on Metal, there is no way to create an 8x8 diagonal matrix with different diagonal elements `ms`, so that is why I've implemented this in scalar form.
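For reference, here is a scalar sketch of the online-softmax update that this scaling is part of, once the accumulated output is no longer all zeros. It is an illustrative standalone C version rather than the Metal code; `D` and `C` mirror the head size and per-iteration cache block used above:

```c
#include <math.h>

#define D 128 // head size
#define C 32  // KV cells processed per iteration

// processes one block of C cells for a single query row:
//   s[j]    - attention scores for the block (already scaled and masked)
//   v[j][d] - the corresponding V rows
//   M, S, O - running max, running sum and running (unnormalized) output;
//             initialize M = -INFINITY, S = 0, O = {0} before the first block
static void online_softmax_step(float * M, float * S, float O[D],
                                const float s[C], const float v[C][D]) {
    float m_new = *M;
    for (int j = 0; j < C; ++j) {
        if (s[j] > m_new) m_new = s[j];
    }

    const float ms = expf(*M - m_new); // correction factor for everything accumulated so far
    for (int d = 0; d < D; ++d) {
        O[d] *= ms;                    // the rescaling performed by the referenced loop
    }
    *S *= ms;

    for (int j = 0; j < C; ++j) {
        const float p = expf(s[j] - m_new);
        *S += p;
        for (int d = 0; d < D; ++d) {
            O[d] += p*v[j][d];
        }
    }

    *M = m_new;
    // after the last block, O[d] / *S gives the corresponding row of softmax(Q*K^T)*V
}
```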
I plan to extend the kernel to support larger block sizes (i.e. 8x64, 16x32, 16x64, etc.), which I expect to result in further improvement. But I suppose for an initial implementation one should first try the smallest block size - 16x32 for CUDA - and work from there.
Here are some preliminary results measuring the performance just for computing the attention. This speedup is a lower bound, since the test data does not use a `-INF` mask, which in real cases helps the new flash attention kernel significantly by skipping the computation of such blocks.
The speed on `master` for head sizes that are not a multiple of 64 is pretty bad, because in such cases we fall back to the mat-vec kernels instead of using the efficient mat-mat kernels.
For small contexts (< 1024) the `bs=1` case on `master` is better due to the efficient mat-vec kernels, but at bigger contexts the PR is better.
The speed in this PR for head size of 256 is pretty bad because it seems that at this size the amount of local memory in the simdgroup is quite large and this somehow affects the performance. Still not 100% sure though - there could be some other explanation
M2 Ultra
Head size | Heads | n_kv | n_batch | master us/run | PR us/run | speedup |
---|---|---|---|---|---|---|
64 | 32 | 512 | 1 | 48.35 | 52.66 | 0.918 |
64 | 32 | 512 | 2 | 65.10 | 53.51 | 1.217 |
64 | 32 | 512 | 4 | 64.60 | 54.14 | 1.193 |
64 | 32 | 512 | 8 | 66.28 | 54.98 | 1.206 |
64 | 32 | 512 | 512 | 359.21 | 320.89 | 1.119 |
64 | 32 | 512 | 1024 | 679.95 | 556.66 | 1.221 |
64 | 32 | 512 | 2048 | 1307.07 | 1070.70 | 1.221 |
64 | 32 | 1024 | 1 | 73.74 | 55.80 | 1.322 |
64 | 32 | 1024 | 2 | 84.81 | 56.87 | 1.491 |
64 | 32 | 1024 | 4 | 87.57 | 57.65 | 1.519 |
64 | 32 | 1024 | 8 | 90.37 | 59.91 | 1.508 |
64 | 32 | 1024 | 512 | 615.42 | 515.89 | 1.193 |
64 | 32 | 1024 | 1024 | 1162.55 | 976.28 | 1.191 |
64 | 32 | 1024 | 2048 | 2258.26 | 1897.12 | 1.190 |
64 | 32 | 2048 | 1 | 127.41 | 77.92 | 1.635 |
64 | 32 | 2048 | 2 | 137.83 | 74.32 | 1.855 |
64 | 32 | 2048 | 4 | 141.00 | 73.93 | 1.907 |
64 | 32 | 2048 | 8 | 143.79 | 75.10 | 1.915 |
64 | 32 | 2048 | 512 | 941.67 | 955.09 | 0.986 |
64 | 32 | 2048 | 1024 | 1859.78 | 1829.82 | 1.016 |
64 | 32 | 2048 | 2048 | 3773.73 | 3563.03 | 1.059 |
64 | 32 | 4096 | 1 | 236.36 | 134.64 | 1.755 |
64 | 32 | 4096 | 2 | 242.09 | 113.24 | 2.138 |
64 | 32 | 4096 | 4 | 238.95 | 107.49 | 2.223 |
64 | 32 | 4096 | 8 | 248.57 | 107.94 | 2.303 |
64 | 32 | 4096 | 512 | 2128.95 | 1909.24 | 1.115 |
64 | 32 | 4096 | 1024 | 4191.58 | 3833.60 | 1.093 |
64 | 32 | 4096 | 2048 | 8662.38 | 7592.72 | 1.141 |
80 | 32 | 512 | 1 | 51.22 | 45.78 | 1.119 |
80 | 32 | 512 | 2 | 78.24 | 46.65 | 1.677 |
80 | 32 | 512 | 4 | 85.26 | 48.55 | 1.756 |
80 | 32 | 512 | 8 | 113.76 | 48.40 | 2.350 |
80 | 32 | 512 | 512 | 4179.59 | 318.88 | 13.107 |
80 | 32 | 512 | 1024 | 8208.02 | 591.91 | 13.867 |
80 | 32 | 512 | 2048 | 16379.66 | 1144.30 | 14.314 |
80 | 32 | 1024 | 1 | 78.68 | 61.85 | 1.272 |
80 | 32 | 1024 | 2 | 120.74 | 62.71 | 1.925 |
80 | 32 | 1024 | 4 | 135.29 | 64.29 | 2.104 |
80 | 32 | 1024 | 8 | 194.74 | 65.25 | 2.985 |
80 | 32 | 1024 | 512 | 8164.54 | 554.10 | 14.735 |
80 | 32 | 1024 | 1024 | 16046.73 | 1042.44 | 15.393 |
80 | 32 | 1024 | 2048 | 32204.06 | 2019.51 | 15.946 |
80 | 32 | 2048 | 1 | 120.86 | 86.05 | 1.405 |
80 | 32 | 2048 | 2 | 207.56 | 84.41 | 2.459 |
80 | 32 | 2048 | 4 | 236.96 | 86.10 | 2.752 |
80 | 32 | 2048 | 8 | 343.46 | 85.79 | 4.003 |
80 | 32 | 2048 | 512 | 15971.57 | 1030.95 | 15.492 |
80 | 32 | 2048 | 1024 | 31757.59 | 1956.62 | 16.231 |
80 | 32 | 2048 | 2048 | 63448.85 | 3805.53 | 16.673 |
80 | 32 | 4096 | 1 | 220.70 | 133.37 | 1.655 |
80 | 32 | 4096 | 2 | 385.37 | 121.89 | 3.162 |
80 | 32 | 4096 | 4 | 433.54 | 121.74 | 3.561 |
80 | 32 | 4096 | 8 | 656.55 | 122.32 | 5.367 |
80 | 32 | 4096 | 512 | 32207.92 | 2042.83 | 15.766 |
80 | 32 | 4096 | 1024 | 63862.88 | 3952.05 | 16.159 |
80 | 32 | 4096 | 2048 | 127819.94 | 7618.14 | 16.778 |
96 | 32 | 512 | 1 | 52.16 | 48.19 | 1.082 |
96 | 32 | 512 | 2 | 64.60 | 49.42 | 1.307 |
96 | 32 | 512 | 4 | 65.62 | 50.44 | 1.301 |
96 | 32 | 512 | 8 | 68.31 | 51.56 | 1.325 |
96 | 32 | 512 | 512 | 479.10 | 367.21 | 1.305 |
96 | 32 | 512 | 1024 | 918.58 | 684.40 | 1.342 |
96 | 32 | 512 | 2048 | 1798.89 | 1323.87 | 1.359 |
96 | 32 | 1024 | 1 | 76.41 | 62.88 | 1.215 |
96 | 32 | 1024 | 2 | 88.54 | 64.60 | 1.371 |
96 | 32 | 1024 | 4 | 89.44 | 65.23 | 1.371 |
96 | 32 | 1024 | 8 | 92.75 | 66.22 | 1.401 |
96 | 32 | 1024 | 512 | 806.81 | 643.86 | 1.253 |
96 | 32 | 1024 | 1024 | 1534.11 | 1220.26 | 1.257 |
96 | 32 | 1024 | 2048 | 3015.06 | 2366.33 | 1.274 |
96 | 32 | 2048 | 1 | 122.82 | 88.51 | 1.388 |
96 | 32 | 2048 | 2 | 142.06 | 87.75 | 1.619 |
96 | 32 | 2048 | 4 | 145.86 | 89.62 | 1.628 |
96 | 32 | 2048 | 8 | 149.40 | 88.97 | 1.679 |
96 | 32 | 2048 | 512 | 1263.90 | 1230.62 | 1.027 |
96 | 32 | 2048 | 1024 | 2496.61 | 2319.37 | 1.076 |
96 | 32 | 2048 | 2048 | 5082.33 | 4508.18 | 1.127 |
96 | 32 | 4096 | 1 | 226.70 | 147.65 | 1.535 |
96 | 32 | 4096 | 2 | 249.52 | 134.95 | 1.849 |
96 | 32 | 4096 | 4 | 259.10 | 136.24 | 1.902 |
96 | 32 | 4096 | 8 | 264.56 | 135.94 | 1.946 |
96 | 32 | 4096 | 512 | 2726.24 | 2541.81 | 1.073 |
96 | 32 | 4096 | 1024 | 5391.37 | 4940.26 | 1.091 |
96 | 32 | 4096 | 2048 | 11118.87 | 9774.00 | 1.138 |
112 | 32 | 512 | 1 | 53.03 | 50.35 | 1.053 |
112 | 32 | 512 | 2 | 78.74 | 52.39 | 1.503 |
112 | 32 | 512 | 4 | 89.65 | 53.01 | 1.691 |
112 | 32 | 512 | 8 | 116.58 | 54.33 | 2.146 |
112 | 32 | 512 | 512 | 4398.15 | 388.05 | 11.334 |
112 | 32 | 512 | 1024 | 8582.10 | 731.05 | 11.739 |
112 | 32 | 512 | 2048 | 17236.83 | 1416.57 | 12.168 |
112 | 32 | 1024 | 1 | 74.96 | 67.07 | 1.118 |
112 | 32 | 1024 | 2 | 123.45 | 68.21 | 1.810 |
112 | 32 | 1024 | 4 | 144.76 | 68.74 | 2.106 |
112 | 32 | 1024 | 8 | 196.37 | 70.79 | 2.774 |
112 | 32 | 1024 | 512 | 8554.53 | 688.48 | 12.425 |
112 | 32 | 1024 | 1024 | 16775.31 | 1307.00 | 12.835 |
112 | 32 | 1024 | 2048 | 33713.82 | 2548.04 | 13.231 |
112 | 32 | 2048 | 1 | 126.37 | 92.52 | 1.366 |
112 | 32 | 2048 | 2 | 213.71 | 92.35 | 2.314 |
112 | 32 | 2048 | 4 | 248.02 | 94.07 | 2.637 |
112 | 32 | 2048 | 8 | 357.08 | 95.50 | 3.739 |
112 | 32 | 2048 | 512 | 16697.99 | 1317.57 | 12.673 |
112 | 32 | 2048 | 1024 | 33034.48 | 2491.99 | 13.256 |
112 | 32 | 2048 | 2048 | 66383.25 | 4868.13 | 13.636 |
112 | 32 | 4096 | 1 | 237.16 | 155.86 | 1.522 |
112 | 32 | 4096 | 2 | 373.50 | 144.62 | 2.583 |
112 | 32 | 4096 | 4 | 435.65 | 146.07 | 2.982 |
112 | 32 | 4096 | 8 | 643.29 | 146.13 | 4.402 |
112 | 32 | 4096 | 512 | 33739.44 | 2635.00 | 12.804 |
112 | 32 | 4096 | 1024 | 66798.13 | 5029.28 | 13.282 |
112 | 32 | 4096 | 2048 | 133687.09 | 9861.75 | 13.556 |
128 | 32 | 512 | 1 | 46.68 | 54.74 | 0.853 |
128 | 32 | 512 | 2 | 66.58 | 57.64 | 1.155 |
128 | 32 | 512 | 4 | 67.92 | 57.95 | 1.172 |
128 | 32 | 512 | 8 | 70.21 | 58.34 | 1.203 |
128 | 32 | 512 | 512 | 517.57 | 436.09 | 1.187 |
128 | 32 | 512 | 1024 | 1000.31 | 826.79 | 1.210 |
128 | 32 | 512 | 2048 | 1979.87 | 1605.15 | 1.233 |
128 | 32 | 1024 | 1 | 69.69 | 69.38 | 1.004 |
128 | 32 | 1024 | 2 | 91.11 | 71.09 | 1.282 |
128 | 32 | 1024 | 4 | 92.39 | 72.72 | 1.270 |
128 | 32 | 1024 | 8 | 95.79 | 72.42 | 1.323 |
128 | 32 | 1024 | 512 | 865.96 | 770.68 | 1.124 |
128 | 32 | 1024 | 1024 | 1660.97 | 1463.05 | 1.135 |
128 | 32 | 1024 | 2048 | 3288.41 | 2845.84 | 1.156 |
128 | 32 | 2048 | 1 | 114.68 | 99.64 | 1.151 |
128 | 32 | 2048 | 2 | 149.57 | 100.09 | 1.494 |
128 | 32 | 2048 | 4 | 151.86 | 100.76 | 1.507 |
128 | 32 | 2048 | 8 | 155.56 | 101.88 | 1.527 |
128 | 32 | 2048 | 512 | 1367.97 | 1539.56 | 0.889 |
128 | 32 | 2048 | 1024 | 2697.36 | 2844.56 | 0.948 |
128 | 32 | 2048 | 2048 | 5535.49 | 5478.64 | 1.010 |
128 | 32 | 4096 | 1 | 210.81 | 175.34 | 1.202 |
128 | 32 | 4096 | 2 | 264.01 | 166.61 | 1.585 |
128 | 32 | 4096 | 4 | 269.43 | 166.14 | 1.622 |
128 | 32 | 4096 | 8 | 276.98 | 166.71 | 1.661 |
128 | 32 | 4096 | 512 | 2926.51 | 3033.96 | 0.965 |
128 | 32 | 4096 | 1024 | 5816.29 | 5828.25 | 0.998 |
128 | 32 | 4096 | 2048 | 12035.27 | 11497.53 | 1.047 |
256 | 32 | 512 | 1 | 52.02 | 326.86 | 0.159 |
256 | 32 | 512 | 2 | 79.17 | 327.95 | 0.241 |
256 | 32 | 512 | 4 | 79.87 | 329.14 | 0.243 |
256 | 32 | 512 | 8 | 83.67 | 329.81 | 0.254 |
256 | 32 | 512 | 512 | 843.42 | 4572.88 | 0.184 |
256 | 32 | 512 | 1024 | 1665.21 | 8587.96 | 0.194 |
256 | 32 | 512 | 2048 | 3390.91 | 16978.11 | 0.200 |
256 | 32 | 1024 | 1 | 75.91 | 628.88 | 0.121 |
256 | 32 | 1024 | 2 | 116.52 | 630.33 | 0.185 |
256 | 32 | 1024 | 4 | 118.84 | 628.09 | 0.189 |
256 | 32 | 1024 | 8 | 125.11 | 630.85 | 0.198 |
256 | 32 | 1024 | 512 | 1362.92 | 8857.97 | 0.154 |
256 | 32 | 1024 | 1024 | 2686.44 | 17106.66 | 0.157 |
256 | 32 | 1024 | 2048 | 5370.67 | 34220.45 | 0.157 |
256 | 32 | 2048 | 1 | 134.83 | 917.62 | 0.147 |
256 | 32 | 2048 | 2 | 195.35 | 916.88 | 0.213 |
256 | 32 | 2048 | 4 | 199.93 | 909.51 | 0.220 |
256 | 32 | 2048 | 8 | 205.00 | 909.46 | 0.225 |
256 | 32 | 2048 | 512 | 2232.24 | 17710.61 | 0.126 |
256 | 32 | 2048 | 1024 | 4480.89 | 34771.62 | 0.129 |
256 | 32 | 2048 | 2048 | 9132.45 | 69601.42 | 0.131 |
256 | 32 | 4096 | 1 | 258.21 | 1530.00 | 0.169 |
256 | 32 | 4096 | 2 | 342.99 | 1529.23 | 0.224 |
256 | 32 | 4096 | 4 | 357.43 | 1526.52 | 0.234 |
256 | 32 | 4096 | 8 | 358.46 | 1523.29 | 0.235 |
256 | 32 | 4096 | 512 | 4506.89 | 36281.69 | 0.124 |
256 | 32 | 4096 | 1024 | 9059.17 | 70992.35 | 0.128 |
256 | 32 | 4096 | 2048 | 18719.47 | 142732.73 | 0.131 |
M1 Pro
Head size | Heads | n_kv | n_batch | master us/run | PR us/run | speedup |
---|---|---|---|---|---|---|
64 | 32 | 512 | 1 | 171.72 | 145.12 | 1.183 |
64 | 32 | 512 | 2 | 158.36 | 122.98 | 1.288 |
64 | 32 | 512 | 4 | 169.14 | 127.02 | 1.332 |
64 | 32 | 512 | 8 | 192.28 | 125.90 | 1.527 |
64 | 32 | 512 | 512 | 1854.92 | 1516.88 | 1.223 |
64 | 32 | 512 | 1024 | 3205.20 | 2682.15 | 1.195 |
64 | 32 | 512 | 2048 | 6451.63 | 5339.80 | 1.208 |
64 | 32 | 1024 | 1 | 173.10 | 111.86 | 1.547 |
64 | 32 | 1024 | 2 | 175.84 | 109.55 | 1.605 |
64 | 32 | 1024 | 4 | 186.56 | 110.96 | 1.681 |
64 | 32 | 1024 | 8 | 193.02 | 112.15 | 1.721 |
64 | 32 | 1024 | 512 | 2820.42 | 2422.31 | 1.164 |
64 | 32 | 1024 | 1024 | 5655.72 | 4825.81 | 1.172 |
64 | 32 | 1024 | 2048 | 11300.35 | 9621.63 | 1.174 |
64 | 32 | 2048 | 1 | 328.60 | 202.96 | 1.619 |
64 | 32 | 2048 | 2 | 304.72 | 181.98 | 1.674 |
64 | 32 | 2048 | 4 | 315.31 | 149.04 | 2.116 |
64 | 32 | 2048 | 8 | 332.71 | 150.34 | 2.213 |
64 | 32 | 2048 | 512 | 5477.61 | 4578.76 | 1.196 |
64 | 32 | 2048 | 1024 | 11708.44 | 9115.65 | 1.284 |
64 | 32 | 2048 | 2048 | 24465.36 | 18190.31 | 1.345 |
64 | 32 | 4096 | 1 | 660.99 | 295.43 | 2.237 |
64 | 32 | 4096 | 2 | 564.48 | 268.22 | 2.105 |
64 | 32 | 4096 | 4 | 583.85 | 263.57 | 2.215 |
64 | 32 | 4096 | 8 | 634.65 | 265.49 | 2.390 |
64 | 32 | 4096 | 512 | 11739.99 | 8937.42 | 1.314 |
64 | 32 | 4096 | 1024 | 24629.73 | 17878.37 | 1.378 |
64 | 32 | 4096 | 2048 | 49174.03 | 35553.78 | 1.383 |
80 | 32 | 512 | 1 | 103.56 | 84.43 | 1.227 |
80 | 32 | 512 | 2 | 161.68 | 86.45 | 1.870 |
80 | 32 | 512 | 4 | 215.49 | 87.03 | 2.476 |
80 | 32 | 512 | 8 | 358.89 | 88.77 | 4.043 |
80 | 32 | 512 | 512 | 18589.89 | 1443.56 | 12.878 |
80 | 32 | 512 | 1024 | 37176.46 | 2862.85 | 12.986 |
80 | 32 | 512 | 2048 | 74406.37 | 5700.48 | 13.053 |
80 | 32 | 1024 | 1 | 180.91 | 126.16 | 1.434 |
80 | 32 | 1024 | 2 | 291.73 | 122.60 | 2.380 |
80 | 32 | 1024 | 4 | 388.21 | 123.62 | 3.140 |
80 | 32 | 1024 | 8 | 664.89 | 126.11 | 5.272 |
80 | 32 | 1024 | 512 | 36531.48 | 2574.14 | 14.192 |
80 | 32 | 1024 | 1024 | 73058.68 | 5121.40 | 14.265 |
80 | 32 | 1024 | 2048 | 146167.13 | 10211.46 | 14.314 |
80 | 32 | 2048 | 1 | 345.14 | 241.11 | 1.431 |
80 | 32 | 2048 | 2 | 543.55 | 170.14 | 3.195 |
80 | 32 | 2048 | 4 | 724.30 | 168.35 | 4.302 |
80 | 32 | 2048 | 8 | 1267.28 | 168.50 | 7.521 |
80 | 32 | 2048 | 512 | 72650.87 | 4885.48 | 14.871 |
80 | 32 | 2048 | 1024 | 146028.15 | 9746.80 | 14.982 |
80 | 32 | 2048 | 2048 | 293172.12 | 19514.49 | 15.023 |
80 | 32 | 4096 | 1 | 705.26 | 320.43 | 2.201 |
80 | 32 | 4096 | 2 | 1051.25 | 320.59 | 3.279 |
80 | 32 | 4096 | 4 | 1408.55 | 321.61 | 4.380 |
80 | 32 | 4096 | 8 | 2515.26 | 323.06 | 7.786 |
80 | 32 | 4096 | 512 | 145826.01 | 9616.87 | 15.164 |
80 | 32 | 4096 | 1024 | 292761.77 | 19220.51 | 15.232 |
80 | 32 | 4096 | 2048 | 585546.11 | 38428.58 | 15.237 |
96 | 32 | 512 | 1 | 105.99 | 89.76 | 1.181 |
96 | 32 | 512 | 2 | 128.04 | 91.61 | 1.398 |
96 | 32 | 512 | 4 | 139.17 | 92.39 | 1.506 |
96 | 32 | 512 | 8 | 152.91 | 93.62 | 1.633 |
96 | 32 | 512 | 512 | 2230.92 | 1658.70 | 1.345 |
96 | 32 | 512 | 1024 | 4455.13 | 3287.62 | 1.355 |
96 | 32 | 512 | 2048 | 8969.06 | 6543.33 | 1.371 |
96 | 32 | 1024 | 1 | 184.12 | 131.07 | 1.405 |
96 | 32 | 1024 | 2 | 214.95 | 123.40 | 1.742 |
96 | 32 | 1024 | 4 | 227.36 | 124.23 | 1.830 |
96 | 32 | 1024 | 8 | 242.90 | 126.34 | 1.923 |
96 | 32 | 1024 | 512 | 3787.77 | 3007.45 | 1.259 |
96 | 32 | 1024 | 1024 | 7609.07 | 5979.71 | 1.272 |
96 | 32 | 1024 | 2048 | 15195.30 | 11897.70 | 1.277 |
96 | 32 | 2048 | 1 | 355.71 | 247.04 | 1.440 |
96 | 32 | 2048 | 2 | 381.19 | 192.88 | 1.976 |
96 | 32 | 2048 | 4 | 393.55 | 188.02 | 2.093 |
96 | 32 | 2048 | 8 | 413.80 | 190.69 | 2.170 |
96 | 32 | 2048 | 512 | 7170.58 | 5742.46 | 1.249 |
96 | 32 | 2048 | 1024 | 15010.50 | 11424.02 | 1.314 |
96 | 32 | 2048 | 2048 | 31179.88 | 22777.09 | 1.369 |
96 | 32 | 4096 | 1 | 732.07 | 365.71 | 2.002 |
96 | 32 | 4096 | 2 | 723.18 | 369.12 | 1.959 |
96 | 32 | 4096 | 4 | 746.12 | 368.88 | 2.023 |
96 | 32 | 4096 | 8 | 794.66 | 371.86 | 2.137 |
96 | 32 | 4096 | 512 | 14802.02 | 11198.49 | 1.322 |
96 | 32 | 4096 | 1024 | 30692.86 | 22647.14 | 1.355 |
96 | 32 | 4096 | 2048 | 61414.68 | 45290.27 | 1.356 |
112 | 32 | 512 | 1 | 110.03 | 94.80 | 1.161 |
112 | 32 | 512 | 2 | 164.12 | 96.85 | 1.695 |
112 | 32 | 512 | 4 | 224.63 | 98.12 | 2.289 |
112 | 32 | 512 | 8 | 379.52 | 99.09 | 3.830 |
112 | 32 | 512 | 512 | 19778.95 | 1784.52 | 11.084 |
112 | 32 | 512 | 1024 | 39568.86 | 3540.50 | 11.176 |
112 | 32 | 512 | 2048 | 79228.59 | 7047.04 | 11.243 |
112 | 32 | 1024 | 1 | 191.73 | 141.67 | 1.353 |
112 | 32 | 1024 | 2 | 298.63 | 133.50 | 2.237 |
112 | 32 | 1024 | 4 | 402.43 | 134.75 | 2.986 |
112 | 32 | 1024 | 8 | 697.98 | 137.51 | 5.076 |
112 | 32 | 1024 | 512 | 38610.12 | 3261.74 | 11.837 |
112 | 32 | 1024 | 1024 | 77230.99 | 6463.59 | 11.949 |
112 | 32 | 1024 | 2048 | 154548.6 | 12847.12 | 12.030 |
112 | 32 | 2048 | 1 | 377.50 | 224.07 | 1.685 |
112 | 32 | 2048 | 2 | 559.62 | 221.53 | 2.526 |
112 | 32 | 2048 | 4 | 756.76 | 223.23 | 3.390 |
112 | 32 | 2048 | 8 | 1332.31 | 224.76 | 5.928 |
112 | 32 | 2048 | 512 | 76513.81 | 6267.96 | 12.207 |
112 | 32 | 2048 | 1024 | 153822.6 | 12480.19 | 12.325 |
112 | 32 | 2048 | 2048 | 308721.8 | 24994.85 | 12.351 |
112 | 32 | 4096 | 1 | 764.97 | 427.78 | 1.788 |
112 | 32 | 4096 | 2 | 1085.56 | 417.19 | 2.602 |
112 | 32 | 4096 | 4 | 1476.68 | 421.29 | 3.505 |
112 | 32 | 4096 | 8 | 2641.42 | 426.24 | 6.197 |
112 | 32 | 4096 | 512 | 153534.6 | 12365.98 | 12.416 |
112 | 32 | 4096 | 1024 | 307760.6 | 24717.05 | 12.451 |
112 | 32 | 4096 | 2048 | 615496.0 | 49432.70 | 12.451 |
128 | 32 | 512 | 1 | 99.07 | 103.97 | 0.953 |
128 | 32 | 512 | 2 | 138.02 | 106.05 | 1.301 |
128 | 32 | 512 | 4 | 146.58 | 107.98 | 1.357 |
128 | 32 | 512 | 8 | 161.30 | 109.31 | 1.476 |
128 | 32 | 512 | 512 | 2439.23 | 1998.11 | 1.221 |
128 | 32 | 512 | 1024 | 4894.82 | 3970.35 | 1.233 |
128 | 32 | 512 | 2048 | 9891.79 | 7891.88 | 1.253 |
128 | 32 | 1024 | 1 | 170.30 | 134.44 | 1.267 |
128 | 32 | 1024 | 2 | 229.90 | 135.96 | 1.691 |
128 | 32 | 1024 | 4 | 235.05 | 137.93 | 1.704 |
128 | 32 | 1024 | 8 | 258.60 | 139.78 | 1.850 |
128 | 32 | 1024 | 512 | 4104.27 | 3614.97 | 1.135 |
128 | 32 | 1024 | 1024 | 8242.82 | 7162.60 | 1.151 |
128 | 32 | 1024 | 2048 | 16558.43 | 14240.18 | 1.163 |
128 | 32 | 2048 | 1 | 349.25 | 262.17 | 1.332 |
128 | 32 | 2048 | 2 | 411.53 | 257.67 | 1.597 |
128 | 32 | 2048 | 4 | 424.24 | 259.95 | 1.632 |
128 | 32 | 2048 | 8 | 446.55 | 261.23 | 1.709 |
128 | 32 | 2048 | 512 | 7650.29 | 6922.82 | 1.105 |
128 | 32 | 2048 | 1024 | 16121.95 | 13767.31 | 1.171 |
128 | 32 | 2048 | 2048 | 33444.10 | 27602.74 | 1.212 |
128 | 32 | 4096 | 1 | 694.53 | 477.46 | 1.455 |
128 | 32 | 4096 | 2 | 781.76 | 468.61 | 1.668 |
128 | 32 | 4096 | 4 | 806.22 | 472.49 | 1.706 |
128 | 32 | 4096 | 8 | 855.88 | 476.62 | 1.796 |
128 | 32 | 4096 | 512 | 15799.73 | 13656.81 | 1.157 |
128 | 32 | 4096 | 1024 | 32739.61 | 27299.12 | 1.199 |
128 | 32 | 4096 | 2048 | 65542.43 | 54540.84 | 1.202 |
The following tests show the advantage of skipping `-INF` blocks in the mask. This occurs during batched decoding with non-shared prompts (e.g. `server` slots). For large contexts and a batch size of 8, the TG speed is ~1.5x faster compared to `master`, since we avoid a large amount of the "cross-sequence" attention compute that we have due to the unified KV cache:
make -j batched-bench && ./bin/batched-bench ../models/mistral-7b/ggml-model-f16.gguf 35000 0 999 0 512,1024,2048,4092 128 1,2,4,8
M2 Ultra
master
PP | TG | B | N_KV | T_PP s | S_PP t/s | T_TG s | S_TG t/s | T s | S t/s |
---|---|---|---|---|---|---|---|---|---|
512 | 128 | 1 | 640 | 0.384 | 1332.29 | 3.388 | 37.78 | 3.772 | 169.68 |
512 | 128 | 2 | 1280 | 0.765 | 1337.80 | 8.241 | 31.07 | 9.006 | 142.13 |
512 | 128 | 4 | 2560 | 1.564 | 1309.74 | 8.535 | 59.99 | 10.099 | 253.50 |
512 | 128 | 8 | 5120 | 3.292 | 1244.24 | 9.118 | 112.31 | 12.410 | 412.58 |
1024 | 128 | 1 | 1152 | 0.764 | 1340.59 | 3.511 | 36.45 | 4.275 | 269.46 |
1024 | 128 | 2 | 2304 | 1.563 | 1310.05 | 8.480 | 30.19 | 10.044 | 229.40 |
1024 | 128 | 4 | 4608 | 3.291 | 1244.62 | 8.996 | 56.91 | 12.287 | 375.03 |
1024 | 128 | 8 | 9216 | 7.329 | 1117.68 | 10.083 | 101.56 | 17.412 | 529.29 |
2048 | 128 | 1 | 2176 | 1.564 | 1309.26 | 3.702 | 34.58 | 5.266 | 413.22 |
2048 | 128 | 2 | 4352 | 3.291 | 1244.69 | 8.928 | 28.67 | 12.219 | 356.17 |
2048 | 128 | 4 | 8704 | 7.328 | 1117.87 | 9.935 | 51.54 | 17.263 | 504.20 |
2048 | 128 | 8 | 17408 | 18.209 | 899.77 | 11.950 | 85.69 | 30.159 | 577.21 |
4092 | 128 | 1 | 4220 | 3.294 | 1242.34 | 4.099 | 31.23 | 7.393 | 570.84 |
4092 | 128 | 2 | 8440 | 7.329 | 1116.71 | 9.853 | 25.98 | 17.181 | 491.24 |
4092 | 128 | 4 | 16880 | 18.179 | 900.38 | 11.831 | 43.28 | 30.010 | 562.48 |
4092 | 128 | 8 | 33760 | 57.250 | 571.80 | 15.917 | 64.33 | 73.168 | 461.41 |
PR
PP | TG | B | N_KV | T_PP s | S_PP t/s | T_TG s | S_TG t/s | T s | S t/s |
---|---|---|---|---|---|---|---|---|---|
512 | 128 | 1 | 640 | 0.379 | 1351.74 | 3.443 | 37.17 | 3.822 | 167.45 |
512 | 128 | 2 | 1280 | 0.731 | 1400.37 | 8.117 | 31.54 | 8.848 | 144.66 |
512 | 128 | 4 | 2560 | 1.486 | 1377.78 | 8.244 | 62.10 | 9.731 | 263.09 |
512 | 128 | 8 | 5120 | 3.107 | 1318.17 | 8.424 | 121.56 | 11.531 | 444.01 |
1024 | 128 | 1 | 1152 | 0.722 | 1418.04 | 3.485 | 36.73 | 4.207 | 273.82 |
1024 | 128 | 2 | 2304 | 1.469 | 1393.88 | 8.214 | 31.17 | 9.683 | 237.93 |
1024 | 128 | 4 | 4608 | 3.106 | 1318.78 | 8.473 | 60.43 | 11.579 | 397.97 |
1024 | 128 | 8 | 9216 | 6.937 | 1180.91 | 8.748 | 117.06 | 15.685 | 587.58 |
2048 | 128 | 1 | 2176 | 1.470 | 1393.22 | 3.599 | 35.57 | 5.069 | 429.31 |
2048 | 128 | 2 | 4352 | 3.106 | 1318.69 | 8.439 | 30.33 | 11.545 | 376.94 |
2048 | 128 | 4 | 8704 | 6.930 | 1182.06 | 8.936 | 57.30 | 15.866 | 548.59 |
2048 | 128 | 8 | 17408 | 16.801 | 975.16 | 9.423 | 108.67 | 26.224 | 663.82 |
4092 | 128 | 1 | 4220 | 3.111 | 1315.26 | 3.816 | 33.54 | 6.927 | 609.20 |
4092 | 128 | 2 | 8440 | 6.934 | 1180.19 | 8.907 | 28.74 | 15.842 | 532.77 |
4092 | 128 | 4 | 16880 | 16.785 | 975.16 | 9.883 | 51.81 | 26.668 | 632.97 |
4092 | 128 | 8 | 33760 | 45.540 | 718.85 | 10.802 | 94.80 | 56.342 | 599.20 |
Additionally, some results for prompt processing at different batch sizes and model sizes:
make -j llama-bench && ./bin/llama-bench -m ../models/llama-7b-v2/ggml-model-f16.gguf -m ../models/llama-13b-v2/ggml-model-f16.gguf -p 1024,2048,4096,8192 -b 512,1024,2048,4096,8192 -ngl 99
M2 Ultra
model | backend | n_batch | test | master t/s | PR t/s | speedup |
---|---|---|---|---|---|---|
llama 7B F16 | Metal | 512 | pp 1024 | 1408.56 ± 1.98 | 1444.83 ± 2.00 | 1.026 |
llama 7B F16 | Metal | 512 | pp 2048 | 1372.74 ± 0.54 | 1402.69 ± 1.29 | 1.022 |
llama 7B F16 | Metal | 512 | pp 4096 | 1303.37 ± 1.22 | 1317.79 ± 0.73 | 1.011 |
llama 7B F16 | Metal | 512 | pp 8192 | 1163.45 ± 0.33 | 1166.52 ± 0.57 | 1.003 |
llama 7B F16 | Metal | 512 | tg 128 | 41.76 ± 0.07 | 41.87 ± 0.08 | 1.003 |
llama 7B F16 | Metal | 1024 | pp 1024 | 1475.37 ± 3.67 | 1524.12 ± 3.29 | 1.033 |
llama 7B F16 | Metal | 1024 | pp 2048 | 1439.38 ± 0.99 | 1478.06 ± 1.58 | 1.027 |
llama 7B F16 | Metal | 1024 | pp 4096 | 1362.46 ± 1.20 | 1386.09 ± 0.86 | 1.017 |
llama 7B F16 | Metal | 1024 | pp 8192 | 1210.67 ± 1.01 | 1228.48 ± 0.81 | 1.015 |
llama 7B F16 | Metal | 1024 | tg 128 | 41.86 ± 0.05 | 41.95 ± 0.01 | 1.002 |
llama 7B F16 | Metal | 2048 | pp 1024 | 1476.71 ± 1.78 | 1527.52 ± 1.96 | 1.034 |
llama 7B F16 | Metal | 2048 | pp 2048 | 1444.91 ± 2.61 | 1486.82 ± 1.98 | 1.029 |
llama 7B F16 | Metal | 2048 | pp 4096 | 1359.67 ± 1.37 | 1392.01 ± 3.13 | 1.024 |
llama 7B F16 | Metal | 2048 | pp 8192 | 1203.30 ± 0.88 | 1233.46 ± 0.39 | 1.025 |
llama 7B F16 | Metal | 2048 | tg 128 | 41.81 ± 0.04 | 41.90 ± 0.03 | 1.002 |
llama 7B F16 | Metal | 4096 | pp 1024 | 1476.58 ± 1.68 | 1526.93 ± 2.56 | 1.034 |
llama 7B F16 | Metal | 4096 | pp 2048 | 1445.21 ± 2.67 | 1487.42 ± 1.19 | 1.029 |
llama 7B F16 | Metal | 4096 | pp 4096 | 1290.82 ± 2.85 | 1338.66 ± 3.77 | 1.037 |
llama 7B F16 | Metal | 4096 | pp 8192 | 1148.23 ± 2.84 | 1194.02 ± 2.38 | 1.040 |
llama 7B F16 | Metal | 4096 | tg 128 | 41.86 ± 0.03 | 41.85 ± 0.02 | 1.000 |
llama 7B F16 | Metal | 8192 | pp 1024 | 1477.47 ± 0.70 | 1527.04 ± 1.24 | 1.034 |
llama 7B F16 | Metal | 8192 | pp 2048 | 1445.74 ± 1.11 | 1487.39 ± 1.97 | 1.029 |
llama 7B F16 | Metal | 8192 | pp 4096 | 1291.49 ± 2.80 | 1339.12 ± 3.03 | 1.037 |
llama 7B F16 | Metal | 8192 | pp 8192 | 1036.40 ± 3.29 | 1101.68 ± 3.15 | 1.063 |
llama 7B F16 | Metal | 8192 | tg 128 | 41.83 ± 0.05 | 41.76 ± 0.06 | 0.998 |
llama 13B F16 | Metal | 512 | pp 1024 | 759.97 ± 0.34 | 775.94 ± 0.46 | 1.021 |
llama 13B F16 | Metal | 512 | pp 2048 | 743.75 ± 0.32 | 755.90 ± 0.11 | 1.016 |
llama 13B F16 | Metal | 512 | pp 4096 | 711.48 ± 0.11 | 716.48 ± 0.17 | 1.007 |
llama 13B F16 | Metal | 512 | pp 8192 | 645.61 ± 0.15 | 646.88 ± 0.06 | 1.002 |
llama 13B F16 | Metal | 512 | tg 128 | 22.37 ± 0.02 | 22.53 ± 0.01 | 1.007 |
llama 13B F16 | Metal | 1024 | pp 1024 | 781.60 ± 0.65 | 805.34 ± 0.22 | 1.030 |
llama 13B F16 | Metal | 1024 | pp 2048 | 765.52 ± 0.62 | 784.03 ± 0.38 | 1.024 |
llama 13B F16 | Metal | 1024 | pp 4096 | 732.14 ± 0.15 | 742.55 ± 0.27 | 1.014 |
llama 13B F16 | Metal | 1024 | pp 8192 | 662.75 ± 0.33 | 670.83 ± 0.11 | 1.012 |
llama 13B F16 | Metal | 1024 | tg 128 | 22.36 ± 0.01 | 22.52 ± 0.01 | 1.007 |
llama 13B F16 | Metal | 2048 | pp 1024 | 781.22 ± 0.97 | 804.82 ± 0.77 | 1.030 |
llama 13B F16 | Metal | 2048 | pp 2048 | 768.84 ± 0.58 | 788.66 ± 0.08 | 1.026 |
llama 13B F16 | Metal | 2048 | pp 4096 | 731.17 ± 0.52 | 745.80 ± 0.28 | 1.020 |
llama 13B F16 | Metal | 2048 | pp 8192 | 655.80 ± 0.39 | 675.39 ± 0.28 | 1.030 |
llama 13B F16 | Metal | 2048 | tg 128 | 22.37 ± 0.02 | 22.52 ± 0.02 | 1.007 |
llama 13B F16 | Metal | 4096 | pp 1024 | 781.29 ± 1.11 | 805.09 ± 0.96 | 1.030 |
llama 13B F16 | Metal | 4096 | pp 2048 | 768.68 ± 0.84 | 788.46 ± 0.50 | 1.026 |
llama 13B F16 | Metal | 4096 | pp 4096 | 696.49 ± 0.86 | 722.03 ± 1.01 | 1.037 |
llama 13B F16 | Metal | 4096 | pp 8192 | 626.14 ± 0.96 | 658.53 ± 0.63 | 1.052 |
llama 13B F16 | Metal | 4096 | tg 128 | 22.39 ± 0.02 | 22.52 ± 0.02 | 1.006 |
llama 13B F16 | Metal | 8192 | pp 1024 | 781.37 ± 0.53 | 804.94 ± 1.53 | 1.030 |
llama 13B F16 | Metal | 8192 | pp 2048 | 768.90 ± 0.45 | 788.89 ± 0.36 | 1.026 |
llama 13B F16 | Metal | 8192 | pp 4096 | 696.26 ± 0.79 | 721.58 ± 0.87 | 1.036 |
llama 13B F16 | Metal | 8192 | pp 8192 | 570.26 ± 0.97 | 615.81 ± 1.01 | 1.080 |
llama 13B F16 | Metal | 8192 | tg 128 | 22.34 ± 0.07 | 22.53 ± 0.02 | 1.009 |
build: b68a1122 (2086)
After playing for some time with this kernel, I'm more convinced that we should put effort into implementing a matrix multiplication kernel that works for `src1->ne[1] <= 8` (Metal) / `16` (CUDA), in order to solve the inefficient batched decoding that we currently have at small batch sizes. That kernel would pad the missing `src1` rows with zeros (as sketched below) and use the built-in matrix types (8x8 for Metal and 16x16 for CUDA) - this should give approximately constant speed across the different batch sizes. Ideally, it would be templated with a dequantization function so it works for all data types, and it would be so performant that we can drop the mat-vec kernels altogether.
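A rough sketch of the padding step for such a kernel (names invented; F32 used for brevity, and the exact orientation of `src1` in the real kernel may differ):

```c
#include <string.h>

#define TILE 8 // 8 for Metal simdgroup matrices, 16 for CUDA tensor cores

// copies an n_rows x n_cols slice of src1 (n_cols <= TILE) into a zero-filled
// n_rows x TILE buffer, so the kernel can always issue full TILE-wide matrix ops
static void pad_src1_tile(float * dst, const float * src1, int n_rows, int n_cols) {
    memset(dst, 0, sizeof(float)*(size_t)n_rows*TILE);
    for (int r = 0; r < n_rows; ++r) {
        memcpy(dst + r*TILE, src1 + r*n_cols, sizeof(float)*n_cols);
    }
}
```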
@ggerganov In the function `simdgroup_load`, what is the last parameter passed? I assume it's the stride of the data, or am I wrong?
```c
// load the queries from shared memory into local memory
simdgroup_half8x8 mq[Q8][D8];

for (int64_t j = 0; j < Q8; ++j) {
    for (int64_t i = 0; i < D8; ++i) {
        simdgroup_load(mq[j][i], sq + 8*j*T + i*8, T); // what is the third parameter?
    }
}
```
> in the function `simdgroup_load`, what is the last parameter passed? I assume it's the stride of the data or am I wrong?
The Metal shading language spec calls it `elements_per_row`, which defaults to the number of columns in the destination matrix (8 in this case). The `elements_per_row` parameter indicates the number of elements in the source memory layout.
Yes, it's the stride of the row in the source buffer (i.e. the shared memory buffer `sq` that holds the queries). It is specified as a number of elements (i.e. a number of `half`s). The same goes for `simdgroup_store`, but there the stride is in the destination array.
@ggerganov Is this behavior expected? Did you consider that, if it is like this, all the elements of the array would be negative infinity in Metal?
It's a bug - thanks for spotting it. Should be fixed in d073e4f
@ggerganov I have been examining the kernel I created in CUDA, but it produces incorrect values despite all the operations being exactly the same. I really want to ask for your help, but I'm not sure if you have the time to at least take a look and compare your code with mine to see if I missed something or if I'm just doing something wrong. link to cuda implementation
Cool, will take a detailed look tomorrow. At first glance, I suspect a misconfiguration of the matrix layouts (row/col major), as I wrote in the comments there.