A question about your implementation of the decoding phase of LLaMA
Recently I have been studying your code. However, it seems to me that your implementation does not expand the KV cache during the decoding phase. The following code is excerpted from the function _concatenate_to_cache in llama.py.
if query.shape[1] == 1:
    mesh = LLaMAConfig.get_jax_mesh(self.config.mesh_dim)
    def fn(cached_key, cached_value, key, value, cur_index):
        assert key.shape[1] == 1 and value.shape[1] == 1, (key.shape, value.shape)
        sp_size = max_length // mesh.shape['sp']
        axis_index = jax.lax.axis_index('sp')
        cur_index = cur_index - axis_index * sp_size
        key, value = jax.lax.cond(
            jnp.logical_and(cur_index >= 0, cur_index < sp_size),
            lambda: (
                cached_key.at[:, cur_index].set(key[:, -1]),
                cached_value.at[:, cur_index].set(value[:, -1]),
            ),
            lambda: (cached_key, cached_value),
        )
        return key, value
In this function, during the decoding phase we only overwrite one position of cached_key and cached_value with the newly generated key/value, instead of appending the new key/value to cached_key and cached_value. However, it seems to me that a correct KV cache implementation should make the cache grow longer at each step.
Maybe I do not fully understand your code, but I am looking forward to your reply.
Hi. I have another question about the decoding phase. I printed hidden_states.shape at https://github.com/LargeWorldModel/LWM/blob/f45d2b70bda27abfa9cf32e228916b2883801366/lwm/llama.py#L977 and found that the result is sometimes:
(512, 8192, 4096)
(2, 385, 4096)
(2, 128, 4096)
(2, 1, 4096)
In my opinion, if the decoder is autoregressive, it should produce its output one token at a time, so hidden_states.shape should always be (2, 1, 4096) in a single forward pass.
Does this result mean that generating one frame needs only one forward pass?
Thanks for your reply.
Hi. I am not the author of this paper, but it seems you may be missing that the size of the KV cache can stay fixed. Each step writes the new entry into the next free slot, and you just need to update the mask, like this:
T1: KVcache: [kv1, kv2, 0, 0, 0, 0] Mask: [1, 1, 0, 0, 0, 0]
T2: KVcache: [kv1, kv2, kv3, 0, 0, 0] Mask: [1, 1, 1, 0, 0, 0]
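To make the fixed-size idea concrete, here is a minimal sketch in NumPy. It is not code from the LWM repository: the helper names update_cache and attend, and the sizes (batch 1, max length 6, dim 4), are made up for illustration. In JAX the in-place writes would use .at[...].set(...) as in the excerpt above. The point is that attention over the full buffer with a mask gives the same result as attention over a cache that has grown to the current length.

```python
import numpy as np

def update_cache(cached_key, cached_value, mask, key, value, cur_index):
    """Write one new key/value into slot cur_index; the buffers never change size."""
    cached_key, cached_value, mask = cached_key.copy(), cached_value.copy(), mask.copy()
    cached_key[:, cur_index] = key[:, -1]
    cached_value[:, cur_index] = value[:, -1]
    mask[:, cur_index] = True  # mark the slot as valid
    return cached_key, cached_value, mask

def attend(query, cached_key, cached_value, mask):
    """Single-query attention over the whole buffer; masked slots get zero weight.

    query: (batch, dim); cached_key/cached_value: (batch, length, dim);
    mask: (batch, length). Assumes at least one slot per row is valid.
    """
    scores = np.einsum("bd,bld->bl", query, cached_key) / np.sqrt(query.shape[-1])
    scores = np.where(mask, scores, -np.inf)  # empty slots cannot be attended to
    scores = scores - scores.max(axis=-1, keepdims=True)
    weights = np.exp(scores)
    weights = weights / weights.sum(axis=-1, keepdims=True)
    return np.einsum("bl,bld->bd", weights, cached_value)

# Assumed toy sizes, chosen only for this example.
batch, max_length, dim = 1, 6, 4
rng = np.random.default_rng(0)
cached_key = np.zeros((batch, max_length, dim))
cached_value = np.zeros((batch, max_length, dim))
mask = np.zeros((batch, max_length), dtype=bool)

# Three decoding steps: each one fills the next slot and flips one mask bit.
for step in range(3):
    key = rng.normal(size=(batch, 1, dim))
    value = rng.normal(size=(batch, 1, dim))
    cached_key, cached_value, mask = update_cache(
        cached_key, cached_value, mask, key, value, step
    )

query = rng.normal(size=(batch, dim))
out_fixed = attend(query, cached_key, cached_value, mask)
# Same attention over a "grown" cache that holds only the three filled slots.
out_grown = attend(query, cached_key[:, :3], cached_value[:, :3], mask[:, :3])
```

After the three steps the cache still has shape (1, 6, 4), the mask has three ones followed by three zeros (the next row after T2 in the example above), and out_fixed matches out_grown up to floating-point error. A fixed shape also means the decoding step does not have to be recompiled for every new length, which is one common reason for this design in JAX.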