Expose kv_cache in generator API
Presentation of the new feature
I'm dealing with long chat/agent contexts, where the context grows with each interaction. An easy optimisation is to keep the KV cache in memory between calls. This can be done in the plain transformers library by passing past_key_values, and it is in fact done under the hood in outlines.
The problem is that outlines does not expose this functionality, so the model has to recompute the KV cache for the entire context after every interaction (each user chat message).
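For reference, here is a minimal sketch of what KV-cache reuse looks like in plain transformers (the model name and the two-turn loop are just illustrative): each turn only runs the forward pass over the new tokens, while past_key_values carries the attention keys/values of everything seen before.

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

# Minimal sketch of KV-cache reuse in plain transformers.
# "gpt2" is just an example model; the loop stands in for chat turns.
tokenizer = AutoTokenizer.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_pretrained("gpt2")

past_key_values = None
for user_message in ["Hello!", "Tell me more."]:
    # Only the new tokens are fed through the model; the cached
    # keys/values stand in for the whole previous context.
    new_ids = tokenizer(user_message, return_tensors="pt").input_ids
    with torch.no_grad():
        out = model(new_ids, past_key_values=past_key_values, use_cache=True)
    past_key_values = out.past_key_values  # grows with the context
    next_token = out.logits[:, -1].argmax(dim=-1)  # e.g. greedy next token
```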
Where does it fit in Outlines?
I think the easiest way to fix this would be to expose the kv_cache variable in sequence_generator by making it a function parameter:
```python
def sequence_generator(
    model: Callable,
    sampler: Callable,
    fsms: List["Guide"],
    token_ids: torch.Tensor,
    sequence_weights: torch.Tensor,
    attention_masks: torch.Tensor,
    fsm_states: List[int],
    rng: torch.Generator = torch.Generator(),
    kv_cache=None,  # add this
) -> Iterator[GenerationState]:
```
Then do the same for the SequenceGenerator class.
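To make the feature concrete, here is a hypothetical caller-side sketch. Neither the kv_cache parameter nor the returned cache exists in outlines today; this is only what the proposed API could look like.

```python
import outlines

model = outlines.models.transformers("gpt2")
generator = outlines.generate.text(model)

history = ""
kv_cache = None  # hypothetical: reused across turns instead of recomputed

for user_message in ["Hello!", "Tell me more."]:
    history += f"\nUser: {user_message}\nAssistant:"
    # Hypothetical signature: the generator accepts the previous cache
    # and returns the updated one alongside the generated text.
    answer, kv_cache = generator(history, kv_cache=kv_cache)
    history += f" {answer}"
```

Internally, SequenceGenerator would just forward the argument to sequence_generator and hand the updated cache back to the caller.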
The downside is that this doesn't really apply to other model backends, and there are already a lot of arguments being passed around.
Are you willing to open a PR?
Yes, I'm just curious to hear whether this is something that would be accepted before I draft it.
Distantly related to #667 #452
I'm currently working on https://github.com/outlines-dev/outlines/issues/667, where this is something I've already done.