
Grouped Query Attention without repeat()

Birch-san opened this issue · 15 comments

Grouped Query Attention improves the parameter efficiency of the attention KV projections and reduces IO at inference time, making inference faster.

It can be implemented naïvely by just repeating K and V along the head dim. I did so here:
https://github.com/Birch-san/natten-fwd-ad/blob/gqa/src/natten_block.py
I'm not sure whether repeat() or repeat_interleave() is preferred, but "whatever Llama did" can probably be considered standard practice.
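
For reference, here's roughly what I mean by the naïve approach (a minimal sketch, not the exact code in the linked file; shapes match the example further down, and I believe Llama's repeat_kv is equivalent to repeat_interleave along the head dim):

```python
import torch

# 6 query heads share 3 KV heads, i.e. kv_groups = 2 query heads per KV head
batch, q_heads, kv_heads, hei, wid, channels = 1, 6, 3, 128, 128, 64
kv_groups = q_heads // kv_heads

q = torch.randn(batch, q_heads, hei, wid, channels)
k = torch.randn(batch, kv_heads, hei, wid, channels)
v = torch.randn(batch, kv_heads, hei, wid, channels)

# naïve GQA: materialize kv_groups copies of every KV head along the head dim,
# so K and V match Q's head count and can be fed to an ordinary MHA kernel
k = k.repeat_interleave(kv_groups, dim=1)  # [1, 6, 128, 128, 64]
v = v.repeat_interleave(kv_groups, dim=1)  # [1, 6, 128, 128, 64]
```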

But perhaps, instead of incurring the IO cost of repeat(), we could direct the matmul to visit K kv_groups times?

For example, if we unflattened Q's head dim into (groups, heads):

Q=(1, q_heads)

we can do the same with K and V, then expand() their groups dim by a factor of kv_groups:

K=V=(1, kv_heads)
K=V=(kv_groups, kv_heads)
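
Concretely, a sketch of that shape bookkeeping (illustrative names):

```python
import torch

batch, q_heads, kv_heads, hei, wid, channels = 1, 6, 3, 128, 128, 64
kv_groups = q_heads // kv_heads  # 2

q = torch.randn(batch, q_heads, hei, wid, channels)
k = torch.randn(batch, kv_heads, hei, wid, channels)

# unflatten the head dim into (groups, heads), then broadcast K's groups dim
q_g = q.unflatten(1, (1, q_heads))               # [1, 1, 6, 128, 128, 64]
k_g = k.unflatten(1, (1, kv_heads))              # [1, 1, 3, 128, 128, 64]
k_g = k_g.expand(-1, kv_groups, -1, -1, -1, -1)  # [1, 2, 3, 128, 128, 64]

# expand() gives a view whose broadcast dim has stride 0, so no data is copied
# at this point; the question is whether the kernel can read it efficiently
print(k_g.stride())  # groups dim stride is 0
```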

I'm not sure whether that expand() is effectively free, but if it is, then we would need the NATTEN API to accept arguments with these kinds of shapes:

```python
# [batch, groups, heads, hei, wid, channels]
Q=[1, 1, 6, 128, 128, 64]
K=[1, 2, 3, 128, 128, 64]
V=[1, 2, 3, 128, 128, 64]
```

where, so long as (groups, heads) flattens to the same total number of heads (6), the arguments would be allowed.

Maybe the user would need to tell the NATTEN API whether to access the groups via a repeat() or a repeat_interleave() access pattern. Or maybe only one of those access patterns makes sense to support perf-wise.
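
To spell out the two access patterns (illustrative, with 6 query heads and 3 KV heads), this is which KV head each query head would read under each layout:

```python
import torch

kv_heads, kv_groups = 3, 2
kv_idx = torch.arange(kv_heads)

# KV head index read by each of the 6 query heads, per layout:
print(kv_idx.repeat(kv_groups))             # tensor([0, 1, 2, 0, 1, 2])
print(kv_idx.repeat_interleave(kv_groups))  # tensor([0, 0, 1, 1, 2, 2])
```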

Does it sound like a speedup is possible here compared to the naïve repeat()?

Note: I don't think scaled_dot_product_attention supports this kind of thing (I tried it, but it rejected the tensor shapes), so this isn't a parity item.

Birch-san · Jan 23 '24