# llama : refactor llama_kv_cache, llama_context and llm_build_context
## Overview
This PR is an intermediate step towards a more generic implementation that will support different underlying implementations of `llama_kv_cache`, `llama_context` and the graph building logic (a.k.a. `llm_build_context`). The `llama_kv_cache` is also introduced in the public API as an object, but its actual functionality is yet to be defined in follow-up PRs.
Currently, no functional changes have been introduced; the code has mainly been reorganized to allow implementing the new abstractions. The main changes in the implementation are:
- Avoid all explicit references to `llama_kv_cache` in `llm_build_context`. The goal is to be able to construct the computation graphs only through the abstract `llama_context` interface, which hides the actual KV cache implementation and thus allows it to be overloaded based on the parameters of the specific use case. More generally, the `llama_context` hides not only the KV cache implementation, but all of the internal state (such as applied adapters, masks, etc., if any) with the exception of the model weights - these are still available to the `llm_build_context` so that it can construct the backbone graph of the various architectures.
- Avoid all explicit references to `llama_kv_cache` in `llama_decode`/`llama_encode`. These are abstracted through a new object `llama_batch_manager`, which is produced by the current `llama_context`. Again, the goal is to avoid explicit assumptions about the underlying KV cache implementation while processing batches and to delegate this logic to the `llama_context`. The `llama_batch_manager` will handle logic such as restoring the KV cache to a consistent state upon errors, splitting the input batch into micro-batches according to the internal processing logic, etc. (see the sketch after this list).
- Add initial serialization primitives to `llama_kv_cache`. In the future, these will be overloaded for the specific KV cache implementations through a common abstract interface.
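To illustrate the direction, here is a minimal sketch of how decoding could interact with a batch manager instead of touching the KV cache directly. Only `llama_batch_manager` itself is named by this PR; the method names, the `llama_ubatch` placeholder and the decode loop are illustrative assumptions:

```cpp
// Illustrative sketch only - names other than llama_batch_manager are assumptions.
struct llama_ubatch { /* tokens, positions, seq ids for one micro-batch */ };

struct llama_batch_manager_i {
    virtual ~llama_batch_manager_i() = default;

    virtual llama_ubatch next()    = 0; // split off the next micro-batch
    virtual bool         prepare() = 0; // reserve KV cache space for it
    virtual void         restore() = 0; // roll the KV cache back to a consistent state on error
    virtual void         update()  = 0; // commit the processed micro-batch
};

// int llama_context::decode(llama_batch & batch) {
//     auto bman = batch_manager(batch); // the llama_context produces the manager
//
//     while (/* tokens remaining in batch */) {
//         auto ubatch = bman->next();
//
//         if (!bman->prepare()) {
//             bman->restore(); // the manager restores the cache - decode() makes no assumptions
//             return -1;
//         }
//
//         // ... build and compute the graph for ubatch ...
//
//         bman->update();
//     }
//
//     return 0;
// }
```

The point is that `llama_decode`/`llama_encode` no longer need to know which KV cache implementation sits behind the manager.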
The modifications so far are quite substantial and touch a lot of lines. Even though the code is in a very intermediate state, with many members still publicly exposed and without a proper object-oriented implementation in place, it should still be mergeable.

The general class hierarchy that I have in mind is as follows:
```mermaid
graph TD;
  llama_kv_cache_unified --> llama_kv_cache;
  llama_kv_cache_standard --> llama_kv_cache;
  llama_kv_cache_mamba --> llama_kv_cache;
  ... --> llama_kv_cache;
```
Here, `llama_kv_cache_unified` is basically the `llama_kv_cache` implementation that we currently have. In the future, we will add more implementations that would be appropriate for multi-user scenarios (e.g. `llama_kv_cache_standard`) or for Mamba architectures (`llama_kv_cache_mamba`).
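For concreteness, a minimal sketch of that hierarchy; the class names come from the diagram above, but the virtual methods shown here are placeholders rather than the interface introduced by this PR:

```cpp
#include "llama.h" // for llama_seq_id, llama_pos

// Sketch only - method names are placeholders.
struct llama_kv_cache {
    virtual ~llama_kv_cache() = default;

    // per-sequence manipulation (placeholder signature)
    virtual bool seq_rm(llama_seq_id seq_id, llama_pos p0, llama_pos p1) = 0;

    // initial serialization primitives, to be overloaded per implementation
    virtual void state_write(/* writer */) const = 0;
    virtual void state_read (/* reader */)       = 0;
};

// the current implementation: one buffer shared ("unified") across all sequences
struct llama_kv_cache_unified : llama_kv_cache {
    bool seq_rm(llama_seq_id seq_id, llama_pos p0, llama_pos p1) override;
    void state_write(/* writer */) const override;
    void state_read (/* reader */)       override;
};

// future: per-sequence cache for multi-user scenarios
struct llama_kv_cache_standard : llama_kv_cache { /* ... */ };

// future: recurrent state for Mamba-style architectures
struct llama_kv_cache_mamba    : llama_kv_cache { /* ... */ };
```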
```mermaid
graph TD;
  llama_context --> llama_model;
  llama_context --> llama_cparams;
  llama_context --> llama_adapter;
  llama_context --> etc..;
  llama_context[<b>llama_context</b>];
  llama_context_no_kv[<b>llama_context_no_kv</b><br><br>];
  llama_context_unified[<b>llama_context_unified</b><br><br>llama_kv_cache_unified];
  llama_context_standard[<b>llama_context_standard</b><br><br>llama_kv_cache_standard];
  llama_context_mamba[<b>llama_context_mamba</b><br><br>llama_kv_cache_mamba];
  llama_context_enc_dec[<b>llama_context_enc_dec</b><br><br>llama_kv_cache_standard];
  llama_context_no_kv -.-> llama_context;
  llama_context_unified -.-> llama_context;
  llama_context_standard -.-> llama_context;
  llama_context_mamba -.-> llama_context;
  llama_context_enc_dec -.-> llama_context;
  ... -.-> llama_context;
```
The base `llama_context` class will implement common functionality such as low-level `ggml` buffer and backend management, plus adapters, without the notion of a KV cache. The derived classes will specialize the `llama_context` for different use cases.
The `llm_build_context` would operate only through the `llama_build_i` interface, and the batch processing would respectively interact only with the `llama_batch_manager_i` interface. The type of `llama_context` to construct in functions such as `llama_init_from_model()` would be determined based on the model and the specified context parameters. For example, the user would be able to create either a `llama_context_unified` or a `llama_context_standard` for a `LLM_ARCH_QWEN2` model, or a `llama_context_no_kv` for an encoding-only `LLM_ARCH_BERT` model. And so on.
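A rough sketch of how such a selection could look; the helper checks and the `kv_unified` parameter are hypothetical, used here only to show the choice being driven by the model architecture and the context parameters, with the context classes taken from the diagram above:

```cpp
// Hypothetical factory - illustrates the selection logic described above.
llama_context * llama_context_create(const llama_model & model, const llama_context_params & params) {
    // encoder-only models (e.g. LLM_ARCH_BERT) do not need a KV cache
    if (model_is_encoder_only(model)) {       // hypothetical helper
        return new llama_context_no_kv(model, params);
    }

    // recurrent architectures use their own cache type
    if (model_is_recurrent(model)) {          // hypothetical helper
        return new llama_context_mamba(model, params);
    }

    // e.g. a LLM_ARCH_QWEN2 model could use either the unified or the standard
    // cache, depending on a (hypothetical) context parameter
    if (params.kv_unified) {
        return new llama_context_unified(model, params);
    }

    return new llama_context_standard(model, params);
}
```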
## API changes
The current changes are only meant to make the API more consistent with the naming convention. To migrate, simply replace the old API calls with the new ones (see the example after the list below).
- Deprecate the `llama_kv_cache_...` API
- Add the `llama_kv_self_...` API
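For example, a minimal before/after for the cache-clear and sequence-removal calls, assuming the rename follows the same pattern across the whole `llama_kv_cache_...` family:

```cpp
#include "llama.h"

// before (deprecated):
//   llama_kv_cache_clear (ctx);
//   llama_kv_cache_seq_rm(ctx, seq_id, p0, p1);

// after - same arguments, new names:
void clear_cache(llama_context * ctx) {
    llama_kv_self_clear(ctx);
}

void remove_seq_range(llama_context * ctx, llama_seq_id seq_id, llama_pos p0, llama_pos p1) {
    llama_kv_self_seq_rm(ctx, seq_id, p0, p1);
}
```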
In the future, the `llama_kv_cache_...` API will be changed to work with `struct llama_kv_cache` instead of `struct llama_context`, and the functionality will be extended to support things like saving, copying, loading, etc.
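A hypothetical sketch of what that could look like; none of these signatures are part of this PR:

```cpp
// Hypothetical sketch of future declarations - not part of this PR.
#include <stddef.h>
#include <stdint.h>

struct llama_context;
struct llama_kv_cache;   // exposed as an opaque object by this PR

// obtain the cache managed by a context (hypothetical accessor)
struct llama_kv_cache * llama_get_kv_cache(struct llama_context * ctx);

// the existing operations, re-targeted at the cache object
void llama_kv_cache_clear (struct llama_kv_cache * kv);
bool llama_kv_cache_seq_rm(struct llama_kv_cache * kv, int32_t seq_id, int32_t p0, int32_t p1);

// extended functionality mentioned above: saving and loading the cache state
size_t llama_kv_cache_state_save(const struct llama_kv_cache * kv, uint8_t * dst, size_t size);
size_t llama_kv_cache_state_load(struct llama_kv_cache * kv, const uint8_t * src, size_t size);
```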
## Notes
- [x] Fix `build_qwen2vl`, `inp_pos`, `lctx.n_pos_per_token` hack
- [x] Worst case for `n_outputs` and `n_outputs_enc` in `llm_build_context` seems incorrect
- [x] Remove `inp_s_seq` - not used
- [x] Fix `const bool is_sliding = il % sliding_window_pattern < (sliding_window_pattern - 1); struct ggml_tensor * KQ_mask_l = is_sliding ? KQ_mask_swa : KQ_mask;`
- [ ] Fix T5
- [x] Fix RWKV
- [ ] Fix `batch.pos == NULL` - `llama_context::pos_max()` is used incorrectly
- [x] Dedup the reserve code
- [x] Errors on unimplemented interface
- [ ] Build multiple graphs per model (e.g. enc, dec, no-logits, etc.)
- [x] Implement causal input for cache-less `llama_context`
- [ ] Simplify `encode()`/`decode()`
- [x] Remove `worst_case` from the `llama_graph_i` API?
- [x] Wrap input tensors in structs
- [x] Add trace logs
## PRs to resolve
- [x] https://github.com/ggerganov/llama.cpp/pull/11381 (e665b57)
- [ ] https://github.com/ggerganov/llama.cpp/pull/11446
- [x] https://github.com/ggerganov/llama.cpp/pull/11445
- [x] https://github.com/ggerganov/llama.cpp/pull/10573
## New features
- [ ] https://github.com/ggerganov/llama.cpp/pull/11571