
Initialize kv cache w/num_kv_heads instead of num_heads

Open · rohan-varma opened this issue 2 years ago

This will save memory for grouped-query attention (GQA) and multi-query attention (MQA), but it will require a small refactor of the attention forward pass.
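The forward-pass refactor mentioned above comes from a shape mismatch. If the cache stores only `num_kv_heads` heads, each query head still needs a K/V head to attend over, so the cached heads must be expanded (or indexed) to `num_heads` at attention time. The sketch below is a dependency-free illustration of that mapping, assuming nested lists shaped `[kv_head][seq_pos][head_dim]` for a single batch element; `expand_kv_heads` is a hypothetical helper, not a torchtune API. In PyTorch this grouping is commonly done with `repeat_interleave` along the head dimension.

```python
from typing import List

# Hypothetical layout for one batch element: [kv_head][seq_pos][head_dim].
Cache = List[List[List[float]]]


def expand_kv_heads(kv: Cache, num_heads: int) -> Cache:
    """Return one K (or V) entry per query head from a cache of num_kv_heads heads.

    Same grouping as repeat_interleave on the head dim: with num_heads=8 and
    num_kv_heads=2, query heads 0-3 read KV head 0 and heads 4-7 read KV head 1.
    MQA is the special case num_kv_heads == 1.
    """
    num_kv_heads = len(kv)
    if num_kv_heads == 0 or num_heads % num_kv_heads != 0:
        raise ValueError("num_heads must be a positive multiple of num_kv_heads")
    q_per_kv = num_heads // num_kv_heads  # query heads sharing each KV head
    # Reuses references to the cached per-head data; nothing is copied here.
    return [kv[h // q_per_kv] for h in range(num_heads)]


# Usage: 2 cached KV heads, 1 position, head_dim 2, expanded for 4 query heads.
kv = [[[0.0, 0.1]], [[1.0, 1.1]]]
expanded = expand_kv_heads(kv, num_heads=4)
```

The cache itself keeps only `num_kv_heads` heads; the expansion is a compute-time view, which is where the memory saving comes from.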

rohan-varma · Nov 17 '23 11:11

@rohan-varma I think we can close this, no?

joecummings · Apr 16 '24 02:04

I don't think so: we still use the general n_heads rather than num_kv_heads (for GQA / MQA, n_heads > num_kv_heads). @kartikayk, any context on why this issue was closed?

rohan-varma · May 03 '24 01:05

Why do we want to use num_kv_heads?

kartikayk · May 03 '24 14:05

@kartikayk This will reduce memory usage during inference. The cache shape (https://github.com/pytorch/torchtune/blob/main/torchtune/modules/kv_cache.py#L36) is currently built from num_heads; if we use n_kv_heads instead, the KV cache shrinks by a factor of num_heads / n_kv_heads, e.g. 8x less memory for a model with 8 query heads per KV head.
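The size of the saving follows directly from the cache's element count, which is linear in the number of cached heads. The sketch below assumes one K and one V cache per layer, each shaped `(batch_size, heads, max_seq_len, head_dim)`, and 2 bytes per element (bf16/fp16). `kv_cache_bytes` is a hypothetical helper, and the config numbers are illustrative (64 query heads, 8 KV heads, `head_dim` 128, 80 layers), not taken from torchtune.

```python
def kv_cache_bytes(
    batch_size: int,
    num_cache_heads: int,
    max_seq_len: int,
    head_dim: int,
    num_layers: int,
    bytes_per_elem: int = 2,  # assumes bf16/fp16 storage
) -> int:
    """Total bytes for the K and V caches across all layers.

    Assumes one (batch_size, num_cache_heads, max_seq_len, head_dim) tensor
    each for K and V per layer; the element count is independent of dim order.
    """
    per_tensor = batch_size * num_cache_heads * max_seq_len * head_dim
    return 2 * per_tensor * bytes_per_elem * num_layers  # 2 = one K + one V


# Illustrative config: 64 query heads vs. 8 KV heads, everything else equal.
cfg = dict(batch_size=1, max_seq_len=4096, head_dim=128, num_layers=80)
with_num_heads = kv_cache_bytes(num_cache_heads=64, **cfg)
with_num_kv_heads = kv_cache_bytes(num_cache_heads=8, **cfg)

print(with_num_heads / 2**30)               # 10.0 (GiB)
print(with_num_kv_heads / 2**30)            # 1.25 (GiB)
print(with_num_heads // with_num_kv_heads)  # 8
```

Because every other factor cancels, the reduction is exactly num_heads / num_kv_heads for any batch size, sequence length, or layer count.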

rohan-varma · May 06 '24 07:05