
llama : refactor llama_kv_cache, llama_context and llm_build_context

Open ggerganov opened this issue 1 month ago • 15 comments

Overview

This PR is an intermediate step towards a more generic implementation that will support different underlying implementations of llama_kv_cache, llama_context and the graph building logic (a.k.a. llm_build_context). llama_kv_cache is also introduced as an object in the public API, but its actual functionality is yet to be defined in follow-up PRs.

No functional changes have been introduced yet; the code has mainly been reorganized to allow implementing the new abstractions. The main changes in the implementation are:

  • Avoid all explicit references to llama_kv_cache in llm_build_context. The goal is to be able to construct the computation graphs only through the abstract llama_context interface, which hides the actual KV cache implementation and thus allows it to be overloaded based on the parameters of the specific use case. More generally, the llama_context hides not only the KV cache implementation, but all of the internal state (such as applied adapters, masks, etc., if any), with the exception of the model weights; these remain available to llm_build_context so that it can construct the backbone graph of the various architectures.
  • Avoid all explicit references to llama_kv_cache in llama_decode/llama_encode. These are abstracted through a new object, llama_batch_manager, which is produced by the current llama_context. Again, the goal is to avoid explicit assumptions about the underlying KV cache implementation while processing batches and to delegate this logic to the llama_context. The llama_batch_manager handles logic such as restoring the KV cache to a consistent state upon errors, splitting the input batch into micro-batches according to the internal processing logic, etc. (a rough sketch follows this list).
  • Add initial serialization primitives to llama_kv_cache. In the future, these will be overloaded for the specific KV cache implementations through a common abstract interface.
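To make the llama_batch_manager idea more concrete, here is a rough sketch of what such an interface and its use during decoding could look like. This is not the committed code: the method names (next, prepare, restore, update) and the prepare_batch() factory on llama_context are assumptions for illustration; only the type names llama_batch, llama_ubatch and llama_batch_manager_i come from the PR itself.

// Illustrative sketch only -- the method names below are hypothetical.
struct llama_batch_manager_i {
    virtual ~llama_batch_manager_i() = default;

    // split the input batch into micro-batches according to the internal logic
    virtual bool next(llama_ubatch & ubatch) = 0;

    // prepare the KV cache for the current micro-batch (e.g. find a free slot)
    virtual bool prepare() = 0;

    // restore the KV cache to a consistent state after a failure
    virtual void restore() = 0;

    // commit the processed micro-batch to the internal state
    virtual void update() = 0;
};

// llama_decode() would then only talk to the manager produced by the context:
static int32_t decode_sketch(llama_context & lctx, const llama_batch & batch) {
    auto bman = lctx.prepare_batch(batch); // hypothetical factory method

    llama_ubatch ubatch;
    while (bman->next(ubatch)) {
        if (!bman->prepare()) {
            bman->restore();               // roll back any partial KV cache changes
            return -1;
        }
        // ... build and compute the graph for ubatch via the llama_context ...
        bman->update();
    }
    return 0;
}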

The modifications so far are quite substantial and touch a large number of lines. Even though the code is in a very intermediate state, with many members still publicly exposed and without a proper object-oriented implementation in place, it should still be mergeable.

The general class hierarchy that I have in mind is like this:

graph TD;
llama_kv_cache_unified --> llama_kv_cache;
llama_kv_cache_standard --> llama_kv_cache;
llama_kv_cache_mamba --> llama_kv_cache;
... --> llama_kv_cache;

Here, llama_kv_cache_unified is basically the llama_kv_cache implementation that we currently have. In the future, we will add more implementations that would be appropriate for multi-user scenarios (e.g. llama_kv_cache_standard) or for Mamba architectures (llama_kv_cache_mamba).
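As a rough sketch of the common interface this hierarchy implies (the method names here are illustrative; the PR only adds initial serialization primitives, and the exact signatures may differ):

// Sketch of an abstract KV cache interface; method names are assumptions.
struct llama_kv_cache {
    virtual ~llama_kv_cache() = default;

    // basic cache management
    virtual void clear() = 0;
    virtual bool seq_rm(llama_seq_id seq_id, llama_pos p0, llama_pos p1) = 0;

    // serialization primitives, specialized per implementation
    virtual size_t state_size() const = 0;
    virtual size_t state_write(uint8_t * dst) const = 0;
    virtual size_t state_read (const uint8_t * src, size_t size) = 0;
};

// today's implementation: one contiguous buffer shared by all sequences
struct llama_kv_cache_unified : llama_kv_cache { /* ... */ };

// possible future variants
struct llama_kv_cache_standard : llama_kv_cache { /* per-sequence cells for multi-user use */ };
struct llama_kv_cache_mamba    : llama_kv_cache { /* recurrent state instead of K/V entries */ };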

graph TD;
llama_context --> llama_model;
llama_context --> llama_cparams;
llama_context --> llama_adapter;
llama_context --> etc..;

llama_context[<b>llama_context</b>];

llama_context_no_kv[<b>llama_context_no_kv</b><br><br>];
llama_context_unified[<b>llama_context_unified</b><br><br>llama_kv_cache_unified];
llama_context_standard[<b>llama_context_standard</b><br><br>llama_kv_cache_standard];
llama_context_mamba[<b>llama_context_mamba</b><br><br>llama_kv_cache_mamba];
llama_context_enc_dec[<b>llama_context_enc_dec</b><br><br>llama_kv_cache_standard];

llama_context_no_kv -.-> llama_context;
llama_context_unified -.-> llama_context;
llama_context_standard -.-> llama_context;
llama_context_mamba -.-> llama_context;
llama_context_enc_dec -.-> llama_context;
... -.-> llama_context;

The base llama_context class will implement common functionality such as low-level ggml buffer and backend management and adapter handling, without any notion of a KV cache. The derived classes will specialize llama_context for different use cases.
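A minimal sketch of that split, with illustrative member and method names (has_kv_self() and the constructors are assumptions, not the committed interface):

// Base class: model reference, context params, ggml buffers/backends, adapters.
struct llama_context {
    virtual ~llama_context() = default;

    const llama_model & model;
    llama_cparams       cparams;
    // ggml backends, scheduler, adapter state, etc. would live here as well

    explicit llama_context(const llama_model & model) : model(model) {}

    // hook used by the graph build / batch processing code (illustrative)
    virtual bool has_kv_self() const { return false; }
};

// Specialization that owns a concrete KV cache.
struct llama_context_unified : llama_context {
    llama_kv_cache_unified kv_self; // the "unified" cache we have today

    explicit llama_context_unified(const llama_model & model) : llama_context(model) {}

    bool has_kv_self() const override { return true; }
};

// A cache-less context (e.g. for encoder-only models) simply adds no cache.
struct llama_context_no_kv : llama_context {
    using llama_context::llama_context;
};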

The llm_build_context would operate only through the llama_build_i interface, and the batch processing would likewise interact only with the llama_batch_manager_i interface. The type of llama_context to construct in functions such as llama_init_from_model() would be determined based on the model and the specified context parameters. For example, the user would be able to create both a llama_context_unified and a llama_context_standard for an LLM_ARCH_QWEN2 model, or a llama_context_no_kv for an encoder-only LLM_ARCH_BERT model, and so on.
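One way that selection inside llama_init_from_model() could look; the helper predicates below are purely hypothetical and the real dispatch is left to follow-up PRs:

// Hypothetical factory; model_has_decoder()/model_is_recurrent() are assumptions.
static llama_context * llama_context_create(const llama_model & model, const llama_context_params & params) {
    if (!model_has_decoder(model)) {
        // encoder-only models (e.g. LLM_ARCH_BERT) need no KV cache at all
        return new llama_context_no_kv(model);
    }
    if (model_is_recurrent(model)) {
        // Mamba-style architectures keep recurrent state instead of K/V entries
        return new llama_context_mamba(model);
    }
    // default: the unified cache we have today; a context parameter could
    // instead select a "standard" per-sequence cache for multi-user serving
    (void) params;
    return new llama_context_unified(model);
}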

API changes

The current changes only make the API naming more consistent with the existing convention. To migrate, simply replace the old API calls with the new ones.

  • Deprecate llama_kv_cache_... API
  • Add llama_kv_self_... API

In the future, the llama_kv_cache_... API will be changed to work with struct llama_kv_cache instead of struct llama_context and the functionality will be extended to support things like saving, copying, loading, etc.
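For example, code that clears the cache or removes a sequence range migrates by a direct rename (assuming the argument lists stay the same; check llama.h in the PR for the full list of deprecated and added functions):

// before (deprecated):
llama_kv_cache_clear(ctx);
llama_kv_cache_seq_rm(ctx, seq_id, p0, p1);

// after:
llama_kv_self_clear(ctx);
llama_kv_self_seq_rm(ctx, seq_id, p0, p1);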

Notes

  • [x] Fix build_qwen2vl, inp_pos, lctx.n_pos_per_token hack
  • [x] Worst-case values for n_outputs and n_outputs_enc in llm_build_context seem incorrect
  • [x] Remove inp_s_seq - not used
  • [x] fix
    const bool           is_sliding = il % sliding_window_pattern < (sliding_window_pattern - 1);
    struct ggml_tensor * KQ_mask_l = is_sliding ? KQ_mask_swa : KQ_mask;
    
  • [ ] Fix T5
  • [x] Fix RWKV
  • [ ] Fix batch.pos == NULL - llama_context::pos_max() is used incorrectly
  • [x] Dedup the reserve code
  • [x] Errors on unimplemented interface
  • [ ] Build multiple graphs per model (e.g. enc, dec, no-logits, etc.)
  • [x] Implement causal input for cache-less llama_context
  • [ ] Simplify encode()/decode()
  • [x] Remove worst_case from the llama_graph_i API?
  • [x] Wrap input tensors in structs
  • [x] Add trace logs

PRs to resolve

  • [x] https://github.com/ggerganov/llama.cpp/pull/11381 (e665b57)
  • [ ] https://github.com/ggerganov/llama.cpp/pull/11446
  • [x] https://github.com/ggerganov/llama.cpp/pull/11445
  • [x] https://github.com/ggerganov/llama.cpp/pull/10573

New features

  • [ ] https://github.com/ggerganov/llama.cpp/pull/11571

ggerganov · Jan 13 '25 12:01