ggml: avoid rebuild of GGML graph for each token (#7456)
Introduces caching of the GGML graph to avoid an unnecessary full rebuild for each token. KV cache parameters, which change with each token, are updated directly in the cached GGML graph. Caching can be disabled with the GGML_DISABLE_GRAPH_CACHING environment variable.
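For illustration, a minimal sketch of the caching idea, assuming hypothetical names (`llm_graph_cache`, `build_full_graph`, `update_kv_params` are placeholders, not the actual PR code):

```cpp
// Minimal sketch of the caching idea; names here are hypothetical placeholders,
// not the actual PR code.
#include <cstdlib>

struct ggml_cgraph; // opaque graph type from ggml.h

struct llm_graph_cache {
    ggml_cgraph * cached_graph = nullptr; // graph built for the previous token
};

static bool graph_caching_enabled() {
    // opt-out switch exposed by this PR
    return std::getenv("GGML_DISABLE_GRAPH_CACHING") == nullptr;
}

static ggml_cgraph * get_graph(llm_graph_cache & cache, ggml_cgraph * (*build_full_graph)()) {
    if (graph_caching_enabled() && cache.cached_graph != nullptr) {
        // only the per-token KV cache parameters are patched in the cached graph
        // update_kv_params(cache.cached_graph);   // hypothetical helper
        return cache.cached_graph;
    }
    cache.cached_graph = build_full_graph(); // original behavior: full rebuild
    return cache.cached_graph;
}
```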
- [x] I have read the contributing guidelines
- Self-reported review complexity:
- [ ] Low
- [x] Medium
- [ ] High
Here are the batch-size-1 (BS=1) inference performance improvements I have measured for this optimization (all on Linux):
| Hardware | Llama 7B Q4 | Llama 13B Q4 |
|---|---|---|
| H100-SXM + AMD EPYC 7413 | 6% | 5% |
| H100-PCIe + AMD EPYC 7313 | 4% | 4% |
| A100-PCIe + Intel 4210R | 7% | 5% |
| L40S + AMD EPYC 7763 | 3% | 2% |
| RTX-4090 + AMD Ryzen 3700X | 4% | 3% |
| A40 + Intel 4210R | 5% | 3% |
@slaren @JohannesGaessler @ggerganov this is currently working OK for my local tests but I'd appreciate any further testing from your side to help determine if it needs to be made more robust.
Did you check whether with this optimization the performance for large batch sizes becomes better with CUDA graphs enabled?
> Did you check whether with this optimization the performance for large batch sizes becomes better with CUDA graphs enabled?
No, at the moment both this optimization and CUDA graphs are only activated with batch size 1 (investigating BS > 1 is future work).
This is currently causing a segfault for llama-bench - working on it.
> This is currently causing a segfault for llama-bench - working on it.
Now fixed.
@agray3 : Any follow-up on this PR? I had to revert it on my own LlamaCPP to keep up with the recent commits, which I regret given the performance boost it gave me back then.
> @agray3 : Any follow-up on this PR? I had to revert it on my own LlamaCPP to keep up with the recent commits, which I regret given the performance boost it gave me back then.
See the comment by Georgi above - https://github.com/ggerganov/llama.cpp/pull/8366#discussion_r1670011134. It's not obvious to me that there's any way to adapt this patch to make it acceptable for merging, so it remains a POC for now.
@agray3 : Well, your POC is working. Would you consider maintaining it, so it can be merged without conflicts with the current master for those who want to use it? The perf boost is a no-brainer for me and, I suppose, for some other folks.
> @agray3 : Well, your POC is working. Would you consider maintaining it, so it can be merged without conflicts with the current master for those who want to use it? The perf boost is a no-brainer for me and, I suppose, for some other folks.
Sure, I have now rebased it on the latest master branch. Let me know if/when it has conflicts again. Glad it is useful.
Fantastic, @agray3. TYVM.
Hey @agray3. I think the PR needs to be adjusted again for the LCPP refactors that landed since your last adjustment!
@ggerganov @slaren @agray3 I'm interested in reducing the CPU overhead associated with building the GGML graph and would like to follow up on this PR. In particular, I'd like to brainstorm ideas to make this PR more generic and broadly applicable.
One of the main concerns appears to be the special handling of the KV cache when reusing a cached graph. A potential solution is to introduce a specialized copy operator for the KV cache that fuses the functionality of ggml_view_1d and ggml_cpy into a single operation. This new operator would compute the appropriate offset within the KV cache for inserting the new token, then copy the new KV token to that offset.
The offset in the KV cache can be calculated using the inp_pos tensor, which is already an input to the graph and is used by the RoPE operator.
This approach would also eliminate the need for pointer indirection introduced in PR #9017 to manage the KV cache when CUDA graph mode is enabled. With this change, neither the GGML nor the CUDA graph would need to be modified during the decode phase.
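A rough CPU-side sketch of what such a fused operator could compute, assuming (as proposed here, and questioned later in the thread) that the token positions also identify the destination slots in the cache:

```cpp
// Illustrative reference semantics for the proposed fused view+copy operator
// (hypothetical, not an existing ggml op). The destination offset is derived at
// run time from a position tensor, so the graph itself never changes per token.
#include <cstdint>
#include <cstring>

// dst      : base pointer of the K (or V) cache buffer
// src      : contiguous K (or V) rows for the new tokens
// pos      : token positions, assumed here to map directly to cache slots
// n_tokens : number of new tokens in this ubatch
// row_size : bytes per token row in the cache
static void kv_fused_copy_ref(uint8_t * dst, const uint8_t * src,
                              const int32_t * pos, int n_tokens, size_t row_size) {
    for (int i = 0; i < n_tokens; ++i) {
        std::memcpy(dst + (size_t) pos[i]*row_size, src + (size_t) i*row_size, row_size);
    }
}
```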
Do you think this idea is worth pursuing? Any other ideas on how I can make this PR acceptable?
The way that I think we should implement this is as follows:
- Make sure that the topology of the graph (i.e. tensors, their shapes and the view offsets) is fully determined by the shapes of the input tensors
- Given that, before constructing a full graph, we would first build the input tensors and check whether their shapes match those used for the previous graph. If they do, we skip building a new graph and reuse the previous one (see the sketch below)
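A minimal sketch of that reuse check, using a hypothetical helper (only the shapes of the graph inputs are collected and compared against those of the previous graph):

```cpp
// Hypothetical sketch of the reuse check: if every input tensor has the same
// shape as in the previous graph, the previous graph can be reused as-is.
#include <array>
#include <cstdint>
#include <vector>

struct input_shape {
    std::array<int64_t, 4> ne; // ggml tensors have up to 4 dimensions
    bool operator==(const input_shape & other) const { return ne == other.ne; }
};

static bool can_reuse_graph(const std::vector<input_shape> & prev_inputs,
                            const std::vector<input_shape> & cur_inputs) {
    // identical input shapes imply identical topology, given the constraint
    // that the graph is fully determined by the shapes of its inputs
    return prev_inputs == cur_inputs;
}
```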
To achieve that, the only thing that is missing for most models is to make the KV cache storing and loading a function of the input tensors. Specifically, this is what is currently preventing it:
Storing:
https://github.com/ggml-org/llama.cpp/blob/e57bb87cede38341963a7a884630dbfb09c7dc00/src/llama-graph.cpp#L1276-L1280
Loading:
https://github.com/ggml-org/llama.cpp/blob/e57bb87cede38341963a7a884630dbfb09c7dc00/src/llama-graph.cpp#L1285-L1286
The second part (the loading) is easy to fix. We just have to create the K and V view tensors at the beginning and mark them as inputs (see llm_graph_input_i). Since their shape depends on n_kv, which is typically padded to 256, the shape of these views will change only once every 256 ubatches in text-generation cases.
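To make the padding effect concrete, a small illustrative helper (the real padding lives in the unified KV cache implementation; 256 is just the typical value mentioned above):

```cpp
// Illustrative only: n_kv is rounded up to a multiple of the padding, so the
// K/V view shapes stay constant for long stretches of token generation.
#include <cstdint>

static int64_t padded_n_kv(int64_t n_used, int64_t pad = 256) {
    return ((n_used + pad - 1)/pad)*pad; // round up to the next multiple of pad
}

// e.g. padded_n_kv(1) == 256, padded_n_kv(256) == 256, padded_n_kv(257) == 512,
// so the view shape changes only once every 256 ubatches during text generation.
```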
The first part however is more difficult to fix. This is because the K and V view tensors in which we store the new KV cache are currently a function of the KV head. For example, here is the K view:
https://github.com/ggml-org/llama.cpp/blob/e57bb87cede38341963a7a884630dbfb09c7dc00/src/llama-kv-cache-unified.cpp#L645-L647
The offset argument of the view is ggml_row_size(k->type, hparams.n_embd_k_gqa(il))*head_cur. Even though the shape of the view is constant for ubatches with the same number of tokens, the offset changes for every ubatch because of the different head_cur.
One way to fix that is to extend the ggml_cpy operation to take an optional offset from a single-element I64 tensor and move the head_cur offset into that tensor. We would then mark this I64 tensor as a graph input. But this might be very hacky, so I am wondering if there is something better we can do.
In any case, the KV head changes are relevant only for models that use a KV cache. For embedding models such as BERT, we can already start implementing the logic for reusing the previous graph in llama_context::encode().
Thanks @ggerganov for the quick response.
This is exactly what I was proposing above:
"A potential solution is to introduce a specialized copy operator for the KV cache that fuses the functionality of ggml_view_1d and ggml_cpy into a single operation. This new operator would compute the appropriate offset within the KV cache for inserting the new token, then copy the new KV token to that offset."
> One way to fix that is to extend the ggml_cpy operation to take an optional offset from a single-element I64 tensor and move the head_cur offset into that tensor. We would then mark this I64 tensor as a graph input. But this might be very hacky, so I am wondering if there is something better we can do.
My proposal above was to use the inp_pos tensor to calculate the offset, since it is already an input to the graph and is used by the RoPE operator. Do you think using inp_pos to calculate the offset makes sense?
> Do you think using inp_pos to calculate the offset makes sense?
No, inp_pos contains the positions of the tokens in the sequence, while the head is the offset in the memory buffer. With the unified KV cache, these two are completely independent and one cannot be used to compute the other.
> Do you think using inp_pos to calculate the offset makes sense?
Not all models use inp_pos (e.g. recurrent models don't). Also, the head of the self-attention unified KV cache doesn't necessarily depend on token positions (this is only a coincidence in single-user/single-sequence inference). In multi-user/multi-sequence inference, the inp_pos is unrelated to the KV cache offset.
The dynamic copy operator you're describing sounds a lot like ggml_get_rows, but for copying values into an existing tensor, kind of like a ggml_set_rows where the I32 tensor would hold destination row indices instead of source indices. I don't really know if non-consecutive indices would be used in practice (except maybe to support non-contiguous KV cache slots?[^1]), so a simpler ggml_cpy-like operator with a destination offset coming from an integer tensor might be sufficient (not sure exactly what API it should have).
[^1]: This could also simplify the recurrent state cache, which currently has to keep track of the defrag done with ggml_get_rows and inp_s_copy at each ubatch, although arguably it might be better to keep the written slots contiguous.
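For reference, a CPU-only sketch of the semantics such a set-rows operator could have (illustrative only; this is not the ggml API, and the index and element types are assumptions):

```cpp
// Illustrative reference semantics for a set-rows style operator: scatter each
// source row into the destination at a row index supplied by an integer tensor
// that is a graph input, so the cache slot can change without rebuilding the graph.
#include <cstdint>
#include <cstring>

static void set_rows_ref(float * dst, const float * src, const int64_t * row_idx,
                         int64_t n_rows, int64_t row_len) {
    for (int64_t r = 0; r < n_rows; ++r) {
        std::memcpy(dst + row_idx[r]*row_len,   // destination row index from tensor
                    src + r*row_len,            // consecutive source rows
                    (size_t) row_len*sizeof(float));
    }
}
```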
> kind of like a ggml_set_rows
Very good idea - this might be exactly what we need.
Thanks, this makes sense.
Do we need a specialized function to handle the transposed V-cache, or will ggml_set_rows be enough? Check this part of the code:
https://github.com/ggml-org/llama.cpp/blob/7675c555a13c9f473249e59a54db35032ce8e0fc/src/llama-kv-cache-unified.cpp#L668-L673
Update: Never mind, I realized this can be handled easily by creating a view that sets the appropriate strides and treats the V-cache as column-major.
I have tried to implement ggml_set_rows in PR #14274
Currently it is only for CPU but I can add other backends if this is what we need here.
> I have tried to implement ggml_set_rows in PR #14274. Currently it is only for CPU but I can add other backends if this is what we need here.
Thanks! We should prototype one of the models using this and see if the approach works (make sure we are not overlooking some detail). The idea is to use this op in the unified KV cache to replace the head offset from the ggml_cpys (see the get_k() and get_v() methods). We will also need to add some basic optional mechanism to first build all the input tensors separately, before building the full graph, so that we can compare them with the input tensors used to build the previous graph. I can try to push an initial version of this soon.
This has been resolved in #14482