A question on your implementation of the decoding phase of LLaMA

wangtianxia-sjtu opened this issue 1 year ago • 2 comments

Recently I have been studying your code. However, it seems to me that your implementation does not expand the KV cache during the decoding phase. The following code is excerpted from the function _concatenate_to_cache in llama.py.

if query.shape[1] == 1:
    mesh = LLaMAConfig.get_jax_mesh(self.config.mesh_dim)
    def fn(cached_key, cached_value, key, value, cur_index):
        assert key.shape[1] == 1 and value.shape[1] == 1, (key.shape, value.shape)
        sp_size = max_length // mesh.shape['sp']
        axis_index = jax.lax.axis_index('sp')
        cur_index = cur_index - axis_index * sp_size
        key, value = jax.lax.cond(
            jnp.logical_and(cur_index >= 0, cur_index < sp_size),
            lambda: (
                cached_key.at[:, cur_index].set(key[:, -1]),
                cached_value.at[:, cur_index].set(value[:, -1]),
            ),
            lambda: (cached_key, cached_value),
        )
        return key, value

In this function, the newly generated key/value only overwrites the slot at cur_index in cached_key and cached_value during the decoding phase, rather than being appended to them. However, it seems to me that a correct KV-cache implementation should make the cache grow longer at every step.
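
To make my question concrete, here is a small sketch of what I mean by the two strategies (this is not your code; the shapes and names are only illustrative):

import jax.numpy as jnp

batch, n_heads, head_dim, max_length = 2, 4, 8, 16

# (a) A cache that grows: the new key is appended, so the sequence length increases every step.
def append_to_cache(cached_key, key):
    # cached_key: (batch, cur_len, n_heads, head_dim), key: (batch, 1, n_heads, head_dim)
    return jnp.concatenate([cached_key, key], axis=1)

# (b) A preallocated cache: the slot at cur_index is overwritten, the length stays max_length.
def write_to_cache(cached_key, key, cur_index):
    # cached_key: (batch, max_length, n_heads, head_dim)
    return cached_key.at[:, cur_index].set(key[:, -1])

key = jnp.ones((batch, 1, n_heads, head_dim))

grown = append_to_cache(jnp.zeros((batch, 3, n_heads, head_dim)), key)
print(grown.shape)   # (2, 4, 4, 8): the sequence axis got longer

fixed = write_to_cache(jnp.zeros((batch, max_length, n_heads, head_dim)), key, cur_index=3)
print(fixed.shape)   # (2, 16, 4, 8): the sequence axis never changes

My expectation was strategy (a), while your code seems to do (b).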

Maybe I do not fully understand your code, but I am looking forward to your reply.

wangtianxia-sjtu avatar Jul 01 '24 09:07 wangtianxia-sjtu

Hi. I have another question about the decoding phase. I printed hidden_states.shape at https://github.com/LargeWorldModel/LWM/blob/f45d2b70bda27abfa9cf32e228916b2883801366/lwm/llama.py#L977 and sometimes the results are

(512, 8192, 4096)

(2, 385, 4096)

(2, 128, 4096)

(2, 1, 4096)

In my opinion, if the decoder is autoregressive, the output should be produced one token at a time, so hidden_states.shape should always be (2, 1, 4096) within a single forward pass.

Do these results prove that generating one frame needs only a single forward pass?

Thanks for your reply.

Haodong-Lei-Ray avatar Apr 26 '25 12:04 Haodong-Lei-Ray

Hi. I am not the author of this paper, but note that the size of the KV cache can stay fixed: the cache is preallocated to max_length, and only the attention mask changes at each step. Just like this:

T1: KV cache: [kv1, kv2, 0, 0, 0, 0]   Mask: [1, 1, 0, 0, 0, 0]
T2: KV cache: [kv1, kv2, kv3, 0, 0, 0]   Mask: [1, 1, 1, 0, 0, 0]
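
A minimal sketch of this scheme in JAX (not the repository's code; names and shapes are only illustrative): the cache is preallocated to max_length, each decoding step writes one slot in place, and only the mask marks which slots are valid.

import jax.numpy as jnp

batch, max_length, n_heads, head_dim = 1, 6, 2, 4

# Preallocated cache; its shape never changes during decoding.
cached_key = jnp.zeros((batch, max_length, n_heads, head_dim))
cached_value = jnp.zeros((batch, max_length, n_heads, head_dim))
mask = jnp.zeros((batch, max_length), dtype=bool)

def decode_step(cached_key, cached_value, mask, key, value, cur_index):
    # key/value: (batch, 1, n_heads, head_dim) for the single new token.
    cached_key = cached_key.at[:, cur_index].set(key[:, 0])
    cached_value = cached_value.at[:, cur_index].set(value[:, 0])
    # Only the mask changes: positions up to cur_index are now valid for attention.
    mask = mask.at[:, cur_index].set(True)
    return cached_key, cached_value, mask

for t in range(3):
    new_key = jnp.full((batch, 1, n_heads, head_dim), float(t + 1))
    new_value = jnp.full((batch, 1, n_heads, head_dim), float(t + 1))
    cached_key, cached_value, mask = decode_step(
        cached_key, cached_value, mask, new_key, new_value, cur_index=t)
    print(mask.astype(int))  # [[1 1 1 0 0 0]] after the third step

Attention then ignores the zeroed slots via the mask, so the cache never needs to grow.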

Haodong-Lei-Ray avatar May 05 '25 11:05 Haodong-Lei-Ray