
Expose kv_cache in generator API

Open gautierdag opened this issue 1 year ago • 1 comment

Presentation of the new feature

I'm dealing with long chat/agent contexts, where the context grows with each interaction. An easy optimisation is to keep the KV cache in memory. This can be done in the plain transformers library by passing past_key_values, and it is in fact done under the hood in outlines.

The problem is that outlines does not expose this functionality, so the model has to recompute the KV cache after every interaction (user chat message).
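To make the cost concrete, here is a toy sketch of the KV-cache idea in plain Python (no transformers dependency; `compute_kv`, `forward`, and the call counter are illustrative stand-ins, not outlines' API):

```python
# Toy illustration of KV caching: each token's key/value projection is
# computed once and reused on later turns, so per-turn cost is proportional
# to the number of NEW tokens, not the whole context.

def compute_kv(token):
    """Stand-in for the expensive per-token key/value projection."""
    compute_kv.calls += 1
    return (hash(token) % 97, hash(token) % 89)  # fake (key, value) pair

compute_kv.calls = 0

def forward(tokens, kv_cache=None):
    """Compute K/V only for tokens not already covered by the cache."""
    kv_cache = list(kv_cache) if kv_cache else []
    for token in tokens[len(kv_cache):]:
        kv_cache.append(compute_kv(token))
    return kv_cache

# First turn: 3 tokens computed.
cache = forward(["hello", "how", "are"])
# Second turn: the context grew by 2 tokens; only those 2 are computed.
cache = forward(["hello", "how", "are", "you", "?"], kv_cache=cache)
print(compute_kv.calls)  # 5, instead of 3 + 5 = 8 without a cache
```

Without passing the cache back in, the second call would recompute all 5 tokens from scratch, which is exactly the recomputation this issue is about.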

Where does it fit in Outlines?

I think the easiest fix would be to expose the kv_cache variable in sequence_generator and make it a function parameter instead:

def sequence_generator(
    model: Callable,
    sampler: Callable,
    fsms: List["Guide"],
    token_ids: torch.Tensor,
    sequence_weights: torch.Tensor,
    attention_masks: torch.Tensor,
    fsm_states: List[int],
    rng: torch.Generator = torch.Generator(),
    kv_cache: Optional[Tuple] = None,  # add this: reusable past_key_values
) -> Iterator[GenerationState]:

Then do the same for the SequenceGenerator class.
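As a sketch of what the caller-facing pattern could look like, here is a minimal mock (the `CachedGenerator` class and its `generate` method are hypothetical, not outlines' current API; the cache entries are fake stand-ins for past_key_values):

```python
# Hypothetical pattern: a generator object keeps its kv_cache between calls,
# so each chat turn only feeds the tokens that were appended since last time.

class CachedGenerator:
    def __init__(self):
        self.kv_cache = None  # opaque past_key_values-style object

    def generate(self, prompt_tokens):
        # A real implementation would call the model with
        # past_key_values=self.kv_cache and pass only the new tokens.
        if self.kv_cache is None:
            new = prompt_tokens
        else:
            new = prompt_tokens[len(self.kv_cache):]
        self.kv_cache = (self.kv_cache or []) + [("kv", t) for t in new]
        return f"response to {len(new)} new tokens"

gen = CachedGenerator()
print(gen.generate(["user:", "hi"]))
# -> response to 2 new tokens
print(gen.generate(["user:", "hi", "bot:", "hello", "user:", "bye"]))
# -> response to 4 new tokens (the first 2 were served from the cache)
```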

The downside is that this doesn't really apply to other models, and there are already a lot of arguments being passed around.

Are you willing to open a PR?

Yes, just curious to hear if that's something that would be accepted before I draft it.

Distantly related to #667 #452

gautierdag avatar Mar 21 '24 18:03 gautierdag

I'm currently working on https://github.com/outlines-dev/outlines/issues/667, and this is something I've done there.

miftahmoha avatar Mar 23 '24 22:03 miftahmoha