
KV caching?

Open bryanhpchiang opened this issue 1 year ago • 2 comments

Where is it being done in the code?

bryanhpchiang avatar Aug 09 '23 01:08 bryanhpchiang

Hey, I can try to answer the question. It seems it's done here: [screenshot]

sleepwalker2017 avatar Aug 09 '23 02:08 sleepwalker2017

Sorry, I apparently missed this one. The cache is contained in the ExLlamaCache class, which is just a wrapper for two lists of preallocated tensors, one pair for each layer of the model.
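
For a rough picture, here is a minimal sketch of what such a preallocated per-layer cache can look like. The class name, argument names and tensor layout below are illustrative assumptions, not the actual ExLlamaCache code:

    import torch

    class SimpleKVCache:
        def __init__(self, num_layers, batch_size, num_heads, max_seq_len, head_dim,
                     dtype = torch.float16, device = "cpu"):
            shape = (batch_size, num_heads, max_seq_len, head_dim)
            # One preallocated key tensor and one value tensor per layer; nothing is
            # resized during generation, the buffers are only written into and viewed.
            self.key_states = [torch.zeros(shape, dtype = dtype, device = device) for _ in range(num_layers)]
            self.value_states = [torch.zeros(shape, dtype = dtype, device = device) for _ in range(num_layers)]
            self.current_seq_len = 0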

Caching is performed in the attention function, here:

        # Add keys and values to cache

        new_keys = cache.key_states[self.index].narrow(2, past_len, q_len).narrow(0, 0, bsz)
        new_values = cache.value_states[self.index].narrow(2, past_len, q_len).narrow(0, 0, bsz)
        new_keys.copy_(key_states)
        new_values.copy_(value_states)

This creates a narrow view on the K/V cache for the given layer and copies the keys and values computed for the current hidden state into it. Then another view is taken on the cache tensors to feed into the attention step:


        # Key/value tensors with past

        key_states = cache.key_states[self.index].narrow(2, 0, past_len + q_len).narrow(0, 0, bsz)
        value_states = cache.value_states[self.index].narrow(2, 0, past_len + q_len).narrow(0, 0, bsz)

This is the regular HF-like version of the function. The faster C++ implementation is called in ExLlamaAttention.fused(), where the cuda_ext.exllama_ext.q4_attn() function does the same copy operation, but with a custom kernel defined in q4_attn.cu.
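
Putting the two HF-like steps together, here is a small self-contained sketch (illustrative shapes and names only, not the ExLlama source) of writing the current step's K/V through narrow views and then attending over everything cached so far:

    import torch
    import torch.nn.functional as F

    bsz, heads, head_dim, max_seq_len = 1, 8, 64, 2048
    past_len, q_len = 16, 4

    # Preallocated cache tensors for one layer (hypothetical layout)
    cache_k = torch.zeros(bsz, heads, max_seq_len, head_dim)
    cache_v = torch.zeros(bsz, heads, max_seq_len, head_dim)

    # Q/K/V computed from the current hidden state
    query_states = torch.randn(bsz, heads, q_len, head_dim)
    key_states   = torch.randn(bsz, heads, q_len, head_dim)
    value_states = torch.randn(bsz, heads, q_len, head_dim)

    # Write: narrow() returns a view, so copy_() fills the cache in place
    cache_k.narrow(2, past_len, q_len).narrow(0, 0, bsz).copy_(key_states)
    cache_v.narrow(2, past_len, q_len).narrow(0, 0, bsz).copy_(value_states)

    # Read: a second view covering past_len + q_len tokens feeds the attention
    keys   = cache_k.narrow(2, 0, past_len + q_len).narrow(0, 0, bsz)
    values = cache_v.narrow(2, 0, past_len + q_len).narrow(0, 0, bsz)

    attn_out = F.scaled_dot_product_attention(query_states, keys, values)
    print(attn_out.shape)  # torch.Size([1, 8, 4, 64])

Because both narrow() calls return views rather than copies, the write goes straight into the preallocated buffer, so nothing is concatenated or reallocated per generated token.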

turboderp avatar Sep 04 '23 08:09 turboderp