NATTEN
Grouped Query Attention without repeat()
Grouped Query Attention improves the parameter efficiency of attention KV projections and reduces IO at inference time, making inference faster.
It can be implemented naïvely by just repeating K and V along the head dim. I did so here:
https://github.com/Birch-san/natten-fwd-ad/blob/gqa/src/natten_block.py
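For context, here is a minimal sketch of what that naive version looks like. This is an illustration over a plain attention call, not the linked NATTEN block, and it assumes a [batch, heads, seq, channels] layout with q_heads divisible by kv_heads; it happens to use repeat_interleave(), but see the next point.

```python
import torch
import torch.nn.functional as F

def naive_gqa(q, k, v):
    # q: [B, q_heads, N, D]; k, v: [B, kv_heads, N, D]
    kv_groups = q.shape[1] // k.shape[1]
    # Materialise the missing KV heads by repeating along the head dim,
    # paying the IO cost of writing kv_groups copies of K and V.
    k = k.repeat_interleave(kv_groups, dim=1)
    v = v.repeat_interleave(kv_groups, dim=1)
    return F.scaled_dot_product_attention(q, k, v)
```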
I am not sure whether repeat() or repeat_interleave() is preferred, but "whatever Llama did" can probably be considered the standard practice.
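For reference, the two only differ in how the repeated heads are ordered along the head dim; a quick illustration with kv_heads=3, kv_groups=2:

```python
import torch

kv_head_ids = torch.arange(3)     # tensor([0, 1, 2])
kv_head_ids.repeat(2)             # tensor([0, 1, 2, 0, 1, 2])
kv_head_ids.repeat_interleave(2)  # tensor([0, 0, 1, 1, 2, 2])
```

If I remember right, Llama's reference repeat_kv() (an expand() followed by a reshape()) matches the repeat_interleave() ordering, so each block of consecutive query heads shares one KV head.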
But perhaps, instead of incurring the IO cost of repeat(), we could direct the matmul to visit K kv_groups times?
For example, we could unflatten Q's head dim into (groups, heads):
Q=(1, q_heads)
then do the same with K and V, and expand() their groups dim from 1 to kv_groups:
K=V=(1, kv_heads) → K=V=(kv_groups, kv_heads)
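Spelled out in PyTorch (the [batch, heads, hei, wid, channels] layout and the variable names are my assumption, not an existing NATTEN signature):

```python
import torch

B, q_heads, kv_heads, H, W, D = 1, 6, 3, 128, 128, 64
kv_groups = q_heads // kv_heads  # 2

q = torch.randn(B, q_heads, H, W, D)
k = torch.randn(B, kv_heads, H, W, D)
v = torch.randn(B, kv_heads, H, W, D)

# unflatten the head dim into (groups, heads)
q = q.unflatten(1, (1, q_heads))   # [1, 1, 6, 128, 128, 64]
k = k.unflatten(1, (1, kv_heads))  # [1, 1, 3, 128, 128, 64]
v = v.unflatten(1, (1, kv_heads))  # [1, 1, 3, 128, 128, 64]

# expand() the KV groups dim from 1 to kv_groups: a view, no copy
k = k.expand(B, kv_groups, kv_heads, H, W, D)  # [1, 2, 3, 128, 128, 64]
v = v.expand(B, kv_groups, kv_heads, H, W, D)  # [1, 2, 3, 128, 128, 64]
```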
I'm not sure whether that expand() is free, but if it is, then we would need the NATTEN API to accept arguments with these kinds of shapes:
# [batch, groups, heads, hei, wid, channels]
Q=[1, 1, 6, 128, 128, 64]
K=[1, 2, 3, 128, 128, 64]
V=[1, 2, 3, 128, 128, 64]
where, so long as [groups, heads] flattens to the same total number of heads (6), the arguments would be allowed.
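On whether the expand() is free: in PyTorch it returns a view with stride 0 along the expanded dim, so no memory is allocated or copied; whether a kernel can then consume that stride-0 layout efficiently is the open question. A quick check:

```python
import torch

k = torch.randn(1, 3, 128, 128, 64)  # [batch, kv_heads, hei, wid, channels]
k_grouped = k.unflatten(1, (1, 3)).expand(1, 2, 3, 128, 128, 64)

print(k_grouped.stride())                    # groups dim has stride 0
print(k_grouped.data_ptr() == k.data_ptr())  # True: same storage, nothing was copied
```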
Maybe the user would need to tell the NATTEN API whether to access the groups via a repeat() or repeat_interleave() access pattern. Or maybe only one of those access patterns makes sense to support perf-wise.
Does it sound like a speedup is possible here compared to the naïve repeat()?
Note: I don't think scaled_dot_product_attention supports this kind of thing (I tried it, but it rejected the tensor shapes), so this isn't a parity item.