Initialize KV cache with num_kv_heads instead of num_heads
This will save memory for GQA / MQA, but it will require a bit of refactoring of the attention forward pass.
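For context, the refactor is roughly this: if the cache stores only `num_kv_heads` heads, the forward pass has to expand (or broadcast) the cached K/V heads back up to `num_heads` before the attention matmul. A minimal sketch of that expansion step; the function name and shapes here are illustrative, not torchtune's actual API:

```python
import torch

def repeat_kv(x: torch.Tensor, n_rep: int) -> torch.Tensor:
    """Expand cached KV heads so they line up with the query heads.

    x: [batch, num_kv_heads, seq_len, head_dim]
    returns: [batch, num_kv_heads * n_rep, seq_len, head_dim]
    """
    if n_rep == 1:  # MHA case: nothing to do
        return x
    b, n_kv, s, d = x.shape
    # Insert a repeat dim, expand (no copy), then flatten heads back together.
    return (
        x[:, :, None, :, :]
        .expand(b, n_kv, n_rep, s, d)
        .reshape(b, n_kv * n_rep, s, d)
    )

# e.g. 8 KV heads expanded to match 32 query heads (n_rep = 4)
k = torch.randn(1, 8, 16, 64)
print(repeat_kv(k, 4).shape)  # torch.Size([1, 32, 16, 64])
```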
@rohan-varma I think we can close this, no?
I don't think so; we still size the cache with the general n_heads rather than num_kv_heads (where n_heads > num_kv_heads for GQA / MQA). @kartikayk Any context on why this issue was closed?
Why do we want to use num_kv_heads?
@kartikayk This will reduce memory usage during inference. For the cache shape: https://github.com/pytorch/torchtune/blob/main/torchtune/modules/kv_cache.py#L36, if we use num_kv_heads instead of num_heads, the cache shrinks by a factor of num_heads / num_kv_heads, e.g. 8x for a model with 8x as many query heads as KV heads.
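A quick back-of-the-envelope check of that ratio; the batch size, sequence length, head dim, and the 64-vs-8 head config below are illustrative numbers, not taken from any specific torchtune model:

```python
def kv_cache_bytes(batch_size, max_seq_len, n_heads, head_dim, bytes_per_elem=2):
    # Two tensors (K and V), each [batch_size, n_heads, max_seq_len, head_dim],
    # at bytes_per_elem per element (2 for fp16/bf16).
    return 2 * batch_size * n_heads * max_seq_len * head_dim * bytes_per_elem

# Illustrative GQA config: 64 query heads but only 8 KV heads.
full = kv_cache_bytes(1, 4096, 64, 128)  # cache allocated with num_heads
gqa = kv_cache_bytes(1, 4096, 8, 128)    # cache allocated with num_kv_heads
print(full // gqa)  # → 8
```

So the saving is exactly the query-to-KV head ratio; for MQA (num_kv_heads = 1) it would be num_heads-fold.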