executorch
kv cache manipulation?
Is it possible to manipulate the kv cache for llama models?
A common use case during inference is to strike/remove values from the kv cache when regenerating or editing generated outputs, so the LLM does not need to decode from the beginning.
Are there any APIs available to do this right now? If not, can you give me a general pointer on what needs to be done? I'm happy to implement myself.
@l3utterfly, to clarify: do you mean a "stack"-style kv cache? For example, suppose there is an original prompt with its outputs. Later we feed the model the same prompt but a different output. To save computation, we wouldn't repopulate the kv cache for the prompt (or for any inputs that are the same), and would only evict entries from the point where the input differs?
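A minimal sketch of the "stack" idea, assuming the caller records which token id was decoded at each cache position: find the longest prefix the new input shares with what is already cached, keep those positions, and resume decoding from the first position that differs. reusable_prefix_len is a hypothetical helper, not an ExecuTorch API.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

// Hypothetical helper: number of leading tokens shared by the tokens already
// in the cache and the new input. Positions [0, n) can be reused as-is;
// decoding restarts at position n, and cache entries from n onward are stale.
std::size_t reusable_prefix_len(const std::vector<int64_t>& cached_tokens,
                                const std::vector<int64_t>& new_tokens) {
  std::size_t n = 0;
  while (n < cached_tokens.size() && n < new_tokens.size() &&
         cached_tokens[n] == new_tokens[n]) {
    ++n;
  }
  return n;
}
```

If the entire new input is already cached, you would typically still re-run the last token so the model produces logits for the next prediction.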
We'd be happy to review a PR from you. @JacobSzwejbka or @larryliu0820, do you know where a good API/entry point would be for @l3utterfly to add this feature, or whether we should build this API first?
Yes. For reference, llama.cpp does these kinds of kv cache manipulations. Additionally, a great feature would be the ability to save and load kv caches.
I wrote the PR for the rollback/regenerate feature in llama.cpp and am happy to implement something similar here if you can give me a quick pointer on where to add these APIs.
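For the save/load part, a minimal sketch, assuming the kv cache has already been copied out of the runtime as one contiguous byte buffer (for example, a whole planned buffer). save_kv_snapshot and load_kv_snapshot are hypothetical names; the file is raw bytes with no header, so the only safety check is that the size matches what the loading program expects.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <fstream>
#include <string>
#include <vector>

// Write the raw bytes of a kv cache snapshot to disk. Returns false on I/O error.
bool save_kv_snapshot(const std::string& path, const std::vector<uint8_t>& bytes) {
  std::ofstream out(path, std::ios::binary | std::ios::trunc);
  if (!out) return false;
  out.write(reinterpret_cast<const char*>(bytes.data()),
            static_cast<std::streamsize>(bytes.size()));
  return static_cast<bool>(out);
}

// Read a snapshot back. expected_size guards against loading a file written
// for a different model or cache configuration; `bytes` is untouched on mismatch.
bool load_kv_snapshot(const std::string& path, std::size_t expected_size,
                      std::vector<uint8_t>& bytes) {
  std::ifstream in(path, std::ios::binary | std::ios::ate);  // open at end to get size
  if (!in) return false;
  const std::streamoff size = in.tellg();
  if (size < 0 || static_cast<std::size_t>(size) != expected_size) return false;
  in.seekg(0);
  bytes.resize(expected_size);
  in.read(reinterpret_cast<char*>(bytes.data()),
          static_cast<std::streamsize>(expected_size));
  return static_cast<bool>(in);
}
```

A real format would probably also record the model, cache shape, dtype, and number of valid positions, so a snapshot cannot be loaded into an incompatible program.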
@l3utterfly thanks for offering help! We have been talking about implementing different kv cache manipulation techniques but haven't gotten to that part yet. For now, you can look at how it is currently implemented:
https://github.com/pytorch/executorch/blob/main/examples/models/llama2/llama_transformer.py#L183
Feel free to experiment with it and send a PR
Re: "save and load kv caches": I haven't thought about the ability to mutate state from outside model execution, but it should be possible. Let me think about what the APIs would look like, since the concept of which tensors are persistent state is not really available in the runtime today, nor do we serialize the info that would be necessary to figure that out.
One thing you could do is put the mutable buffers in the graph onto their own mem-id AoT with a custom memory plan, and then just copy into and out of that buffer in the runtime. That sidesteps any tensor concepts and lets you mutate the arenas directly.
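From the runtime's point of view, once the kv cache lives on its own mem-id it is just a range of bytes in an arena, so "copy into and out of that buffer" reduces to memcpy. A minimal sketch, assuming you already know the region's offset and size inside the arena; ArenaRegion, copy_out and copy_in are illustrative names, and a std::vector stands in for the arena.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <cstring>
#include <vector>

// Where the mutable buffer (e.g. the kv cache) sits inside a planned arena.
// If the cache is the only thing on its mem-id, offset is 0 and size is the
// whole buffer.
struct ArenaRegion {
  std::size_t offset;
  std::size_t size;
};

// Copy the region out of the arena, e.g. after a decode step.
std::vector<uint8_t> copy_out(const std::vector<uint8_t>& arena, ArenaRegion r) {
  assert(r.offset + r.size <= arena.size());
  const auto first = arena.begin() + static_cast<std::ptrdiff_t>(r.offset);
  return std::vector<uint8_t>(first, first + static_cast<std::ptrdiff_t>(r.size));
}

// Copy a saved snapshot back into the arena before the next step.
void copy_in(std::vector<uint8_t>& arena, ArenaRegion r,
             const std::vector<uint8_t>& snapshot) {
  assert(r.offset + r.size <= arena.size());
  assert(snapshot.size() == r.size);
  if (r.size == 0) return;  // nothing to copy
  std::memcpy(arena.data() + r.offset, snapshot.data(), r.size);
}
```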
How do you want to manipulate the cache, i.e. at what granularity? Are you going by something like layer5.k_cache?
@JacobSzwejbka Currently I'm mainly focused on implementing the kv cache optimisation for transformer models (e.g. llama2 and 3). So I guess only the kv values of the attention heads would work. Are you thinking of a general API in executorch to get/set kv values for all layers?
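For per-layer manipulation, the main thing to know is where a given sequence position lives inside one layer's cache tensor. A sketch, assuming each layer's k_cache (and v_cache) is a contiguous, row-major tensor of shape [max_batch, max_seq_len, n_heads, head_dim]; the actual shape and dimension order should be checked against the KVCache class in llama_transformer.py linked above. CacheLayout and the helper functions are hypothetical.

```cpp
#include <cassert>
#include <cstddef>

// Assumed layout of one layer's k_cache (or v_cache): a contiguous, row-major
// tensor of shape [max_batch, max_seq_len, n_heads, head_dim].
struct CacheLayout {
  std::size_t max_batch;
  std::size_t max_seq_len;
  std::size_t n_heads;
  std::size_t head_dim;
  std::size_t elem_size;  // e.g. sizeof(float) for an fp32 cache
};

// Bytes occupied by a single sequence position (all heads) for one batch entry.
std::size_t bytes_per_position(const CacheLayout& l) {
  return l.n_heads * l.head_dim * l.elem_size;
}

// Byte offset, relative to the start of this layer's cache, of position `pos`
// for batch entry `batch`.
std::size_t position_offset(const CacheLayout& l, std::size_t batch, std::size_t pos) {
  assert(batch < l.max_batch && pos < l.max_seq_len);
  return (batch * l.max_seq_len + pos) * bytes_per_position(l);
}
```

To strike everything from position p onward, the bytes from position_offset(layout, b, p) to the end of that batch entry's row are the stale part. If the model only attends to positions up to the current input position (as a causal mask indexed by position does), rewinding the position counter to p may be enough, and those bytes are simply overwritten on later steps.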
Also, I'm focusing on the C++ code because my goal is to run this on Android.
Putting the mutable buffers on their own mem-id with a custom memory plan, then copying into and out of that buffer in the runtime, sounds like a good workaround for implementing this feature. Can you recommend a good place to put this code in executorch?
First, you need to write a custom memory plan:
https://github.com/pytorch/executorch/blob/main/exir/capture/_config.py#L48C38-L48C56
In that plan you need to identify mutable buffers as you iterate over the graph. This can be a little complex in the general case but for simple kv cache stuff I think this will work: https://github.com/pytorch/executorch/blob/main/exir/memory_planning.py#L308
Then, on the runtime side, you just need to mess around with the buffers you pass to https://github.com/pytorch/executorch/blob/main/runtime/core/hierarchical_allocator.h#L35
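On the runtime side, the allocator does not own the memory: it only refers to the buffers the caller hands it, and addresses are resolved as (buffer id, byte offset). The toy class below mimics the shape of the get_offset_address(memory_id, offset, size) call discussed later in this thread, with bounds checking. It is a simplified model for illustration, not ExecuTorch's HierarchicalAllocator, and it ignores details such as reserved ids and error reporting.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <utility>
#include <vector>

// A (pointer, size) view of a buffer owned by the caller.
struct ToySpan {
  uint8_t* data;
  std::size_t size;
};

// Toy model of a hierarchical allocator: it owns nothing, it only records
// spans for buffers the caller allocated, indexed by id.
class ToyHierarchicalAllocator {
 public:
  explicit ToyHierarchicalAllocator(std::vector<ToySpan> buffers)
      : buffers_(std::move(buffers)) {}

  // Returns the address `offset` bytes into buffer `id`, or nullptr if the
  // id is unknown or [offset, offset + size) does not fit in that buffer.
  uint8_t* get_offset_address(std::size_t id, std::size_t offset,
                              std::size_t size) const {
    if (id >= buffers_.size()) return nullptr;
    const ToySpan& b = buffers_[id];
    if (offset > b.size || size > b.size - offset) return nullptr;
    return b.data + offset;
  }

 private:
  std::vector<ToySpan> buffers_;
};
```

Because the caller keeps the buffers, writing into one of them (for example, the one holding the kv cache) changes exactly the memory the method reads on its next execution.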
In the future we want to make all of these steps easier.
- An easy flag you can set to automatically lift buffers to their own mem_ids.
- A way to associate a string with a mem_id. For buffers like the kv cache, that string could be the fqn.
- A more direct API to update buffers at runtime, and even swap a buffer for a different one after initialization. This would let you "load" a kv cache without a copy, even after init.
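The last item is about avoiding a copy when loading. A small model of the trade-off, using hypothetical types only (the thread describes this API as future work, so nothing here is an ExecuTorch call): loading by copy moves every byte into the buffer the method was initialized with, while loading by swap just repoints the slot at a different, already-populated buffer, which then has to stay alive as long as it is in use.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <cstring>
#include <vector>

// Hypothetical slot the runtime reads the kv cache from on each step.
struct CacheSlot {
  uint8_t* data;
  std::size_t size;
};

// Copy-based load: O(size) bytes moved into the buffer set up at init.
void load_by_copy(CacheSlot& slot, const std::vector<uint8_t>& saved) {
  assert(saved.size() == slot.size);
  if (slot.size != 0) std::memcpy(slot.data, saved.data(), slot.size);
}

// Swap-based load: no bytes move; the slot now points at `replacement`,
// which must outlive its use and match the expected size.
void load_by_swap(CacheSlot& slot, std::vector<uint8_t>& replacement) {
  assert(replacement.size() == slot.size);
  slot.data = replacement.data();
}
```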
This is a bit of a high-level overview, but hopefully it's enough to get you started. If you have more questions, feel free to post here and tag me, and I'll try to help.
cc @mikekgfb @iseeyuan
Thank you so much for the pointers! I will try this
From: Jacob Szwejbka. Sent: Friday, May 10, 2024 3:03:06 AM. To: pytorch/executorch. Subject: Re: [pytorch/executorch] kv cache manipulation? (Issue #3518)
Thanks for putting this together, @JacobSzwejbka! And thanks so much for your contribution to executorch, @l3utterfly! So excited to see your work with Layla and GGML, and super excited to have you become a contributor to executorch!!!!
As always, thanks for your support and leadership @iseeyuan !
@JacobSzwejbka Following your pointers and reading through ExecuTorch's documentation several times, I've managed to implement the custom memory planner and obtain the names of the kv_cache arenas via:
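from typing import Optional

import torch
from executorch.exir.memory_planning import _is_mutable_buffer
from executorch.exir.passes import MemoryPlanningPass
from torch.export.exported_program import ExportGraphSignature
from torch.fx.passes.infra.pass_base import PassResult

class KVMemIdMemoryPlanningPass(MemoryPlanningPass):
    def run(self, graph_module: torch.fx.GraphModule, graph_signature: Optional[ExportGraphSignature] = None) -> PassResult:
        # Visit the top-level graph and any nested submodules (e.g. control-flow branches)
        for subgm in graph_module.modules():
            if not isinstance(subgm, torch.fx.GraphModule):
                continue
            for node in subgm.graph.nodes:
                if _is_mutable_buffer(node, graph_signature):
                    print(f"Mutable buffer found: {node}")
        # Let the default pass do the actual planning
        return super().run(graph_module, graph_signature)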
I'm a little stuck on how I should manipulate this on the C++ side in hierarchical_allocator.h. I see a method to get a memory buffer by passing the memory_id. This seems to be the buffer index, and I need the offset and size of the buffer.
- How do I determine the memory id? I see you can set the "mem_id" by doing node.meta["spec"].mem_id = 1. Does this mean this node will be in buffer[1]? (https://pytorch.org/executorch/stable/compiler-memory-planning.html)
- How do I get the size of the buffer? Is it fixed, calculated as sizeof(float) * the tensor dimensions?
- After obtaining the buffer via hierarchical_allocator->get_offset_address, is modifying the contents via the pointer sufficient? Do I need to handle syncing the buffer with the underlying hardware (GPU, DSP, etc.)?
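On the size question, a dense tensor's byte size is the element size times the product of its dimensions. A sketch of that arithmetic, with tensor_nbytes as a hypothetical helper. Note this is the size of one tensor; a planned buffer may hold several tensors, so the authoritative buffer size is what the program's method metadata reports, and the element size depends on the cache dtype (fp32 in the example, fp16 would use 2 bytes).

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Size in bytes of a dense tensor: element size times the product of its dims.
// An empty dims list is a scalar (one element).
std::size_t tensor_nbytes(const std::vector<std::size_t>& dims, std::size_t elem_size) {
  std::size_t numel = 1;
  for (std::size_t d : dims) numel *= d;
  return numel * elem_size;
}
```

For a Llama-style model with one k and one v cache per layer, the total cache footprint under this assumption would be 2 * n_layers * tensor_nbytes(per-layer shape, element size).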
1 and 2. I think you are approaching it a bit backwards; it might be because you are using the Module class API? You should have the buffers before creating the hierarchical_allocator. The order is 1. Get buffers -> 2. Create allocators -> 3. Create the hierarchical_allocator from the allocators. Here is a link showing how to get the expected number and size of the buffers for the hierarchical_allocator: https://github.com/pytorch/executorch/blob/main/runtime/executor/program.h#L152
3. You should be able to just mutate the contents through the pointer. I'm not aware of any delegates today that consume the buffer, but if/when that is added, this approach will probably start failing, because we don't have any good APIs today to allow mutation of delegates from outside.
Is this the place where I determine which buffers are for the kv cache?
Error Module::load_method(const std::string& method_name) {
if (!is_method_loaded(method_name)) {
ET_CHECK_OK_OR_RETURN_ERROR(load());
MethodHolder method_holder;
const auto method_metadata =
ET_UNWRAP(program_->method_meta(method_name.c_str()));
const auto planned_buffersCount =
method_metadata.num_memory_planned_buffers();
method_holder.planned_buffers.reserve(planned_buffersCount);
method_holder.planned_spans.reserve(planned_buffersCount);
for (auto index = 0; index < planned_buffersCount; ++index) {
const auto buffer_size =
method_metadata.memory_planned_buffer_size(index).get();
method_holder.planned_buffers.emplace_back(buffer_size);
method_holder.planned_spans.emplace_back(
method_holder.planned_buffers.back().data(), buffer_size);
}
method_holder.planned_memory = std::make_unique<HierarchicalAllocator>(Span(
method_holder.planned_spans.data(),
method_holder.planned_spans.size()));
method_holder.memory_manager = std::make_unique<MemoryManager>(
memory_allocator_.get(), method_holder.planned_memory.get());
method_holder.method = ET_UNWRAP_UNIQUE(program_->load_method(
method_name.c_str(),
method_holder.memory_manager.get(),
event_tracer_.get()));
methods_.emplace(method_name, std::move(method_holder));
}
return Error::Ok;
}
@JacobSzwejbka I have managed to get a naive implementation of kv cache save + load working.
I have a question:
My kv cache buffer is at index 0 of planned_buffers in the method_holder.
This code obtains the piece of memory correctly:
// copy the contents of kv cache at each step
auto kv_cache_buffer = module_->methods_["forward"].planned_buffers[0];
kv_cache_buffers[pos].assign(kv_cache_buffer.begin(), kv_cache_buffer.end());
For now, I'm simply copying it to a cache variable at each step.
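One C++ detail in the snippet above, assuming planned_buffers is a std::vector<std::vector<uint8_t>> as in the load_method code: auto kv_cache_buffer = ...planned_buffers[0]; deduces std::vector<uint8_t> and therefore copies the whole buffer. That is fine for saving a snapshot, but writing a snapshot back needs a reference, otherwise the write lands in a copy the method never reads. A sketch with hypothetical helper names:

```cpp
#include <algorithm>
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

// Stand-in for the Module's per-method planned buffers.
using PlannedBuffers = std::vector<std::vector<uint8_t>>;

// Saving: `auto` deduces std::vector<uint8_t>, so this duplicates the bytes
// as they are at the moment of the call, which is what a snapshot needs.
std::vector<uint8_t> snapshot_buffer(const PlannedBuffers& planned, std::size_t index) {
  auto copy = planned.at(index);
  return copy;
}

// Restoring: bind a reference (auto&) so the write lands in the live buffer.
// Copying element-wise keeps the vector's storage (and data pointer) unchanged.
void restore_buffer(PlannedBuffers& planned, std::size_t index,
                    const std::vector<uint8_t>& snapshot) {
  auto& live = planned.at(index);
  assert(live.size() == snapshot.size());
  std::copy(snapshot.begin(), snapshot.end(), live.begin());
}
```

Keeping the data pointer stable matters because the spans handed to the allocator point at that storage.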
I'm trying to use the provided methods in the memory manager et al., but they seem to return different results:
// copy the contents of kv cache at each step
auto kv_cache_buffer = module_->methods_["forward"].memory_manager->planned_memory()->get_offset_address(0, 0, buffer_size).get();
I looked through the load_method code above: memory_manager and planned_memory (HierarchicalAllocator) seem to just hold pointers to the planned_buffers, so I'm unsure why they would return different results. It seems I'm misunderstanding something about the get_offset_address function.
Sorry for the delay; I've been out of town and off the grid.
We have some legacy code where mem_id 0 is reserved for reasons that don't make sense anymore. @dbort put up some code a while back to hide this from users as best we could. You might be running into that when you go directly through the allocator APIs to get the buffer offset (maybe mem id 1 would work for you). If the first code you linked works, I would just continue using that.
You mentioned implementing different kv cache manipulation techniques; does that mean APIs already exist for using the kv cache for LLM inference in ExecuTorch? Or are new APIs being worked on in this issue?
ExecuTorch already supports kv-cache-enabled Llama, if that's what you are asking. You can see this under examples/ in the ET repo or in the torchchat repo, which just launched. @prashantaithal
Thanks for the response; you must mean llama_transformer.py.
I had a basic question: is the kv cache part of the PyTorch Llama model (*.pt/*.pth), or is it implemented separately outside the model? I ask because, looking at the Python script linked above, I see the model's Attention class using the KVCache.
Apologies for hijacking this thread, but I could not find any "discussions" tab in the repositories.
Yes, it's embedded within the .pte as hidden model state, based on the implementation in llama_transformer.py. We had an earlier version running that lifted the cache to model IO as well, so if you want easier manipulation of the cache outside the model, that approach might work better for you.
The current implementation in llama_transformer just aligns better with how we had seen people authoring models (stateful), so we wanted to show support for that.
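To illustrate why lifting the cache to model IO makes outside manipulation easier, here is a toy step function, not a real model: the caller passes the cache and the input position in and gets the updated cache back, so saving is a copy and regenerating from a position is just another call with an older cache. The causal sum is a stand-in for attention that only reads positions up to input_pos. StepResult and toy_step are hypothetical.

```cpp
#include <cassert>
#include <cstddef>
#include <utility>
#include <vector>

// Output of one toy step; the updated cache is returned to the caller.
struct StepResult {
  float output;
  std::vector<float> cache;
};

// Toy "model" step with the cache lifted to IO: write this step's entry at
// input_pos, then combine only positions <= input_pos (causal), so anything
// stored beyond input_pos is never read.
StepResult toy_step(float x, std::size_t input_pos, std::vector<float> cache) {
  assert(input_pos < cache.size());
  cache[input_pos] = x;
  float sum = 0.0f;
  for (std::size_t i = 0; i <= input_pos; ++i) sum += cache[i];
  return {sum, std::move(cache)};
}
```

With a stateful model the same rollback requires reaching into hidden buffers; with IO, the caller already owns every version of the cache it wants to keep.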
Thanks for the response. Where can I find the earlier version?