
Implement flash decoding

Open · pfultz2 opened this issue 3 months ago · 2 comments

Implement flash decoding as described here: https://pytorch.org/blog/flash-decoding/

We have attention operators grouped like this:

Q -> [B, M, k]
K -> [B, k, N]
V -> [B, N, D]

S = dot(Q, K)
P = softmax(S)
O = dot(P, V) # [B, M, D]
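
For reference, a minimal NumPy sketch of this grouped form; the concrete sizes (B=1, M=1, k=64, N=4096, D=64) are illustrative and not taken from the issue:

```python
import numpy as np

def softmax(x, axis=-1):
    # numerically stable softmax along the given axis
    x = x - x.max(axis=axis, keepdims=True)
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)

# illustrative sizes; M is small (decode step) and N is the KV sequence length
B, M, k, N, D = 1, 1, 64, 4096, 64
rng = np.random.default_rng(0)
Q = rng.standard_normal((B, M, k))
K = rng.standard_normal((B, k, N))
V = rng.standard_normal((B, N, D))

S = Q @ K                # [B, M, N]
P = softmax(S, axis=-1)  # [B, M, N]
O = P @ V                # [B, M, D]
```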

To do flash decoding we will need to add another batch dimension G for the groups we split the KV sequence into, and then do:

Q -> [B, G, M, k] # G is a broadcasted dimension
K -> [B, G, k, N/G]
V -> [B, G, N/G, D]

# first kernel
S = dot(Q, K)
P = softmax(S, axis=-1)
L = LSE(S) # [B, G, M, 1]
O' = dot(P, V) # [B, G, M, D]

# second kernel
scale = softmax(L, axis=1) # [B, G, M, 1]
R = mul(O', broadcast(scale)) # [B, G, M, D]
O = sum(R, axis=1) # [B, 1, M, D]
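
A NumPy sketch of the two-kernel scheme, continuing from the baseline sketch above; G=8 is an arbitrary split count chosen for illustration, and LSE is taken to be the log-sum-exp of the scores over the last axis. The final reduction should reproduce the unsplit output:

```python
# split the KV sequence into G contiguous groups (G=8 is an arbitrary choice)
G = 8
assert N % G == 0
Qg = Q[:, None]                                        # [B, 1, M, k], broadcast over G
Kg = K.reshape(B, k, G, N // G).transpose(0, 2, 1, 3)  # [B, G, k, N/G]
Vg = V.reshape(B, G, N // G, D)                        # [B, G, N/G, D]

# first kernel: per-group attention plus log-sum-exp of the scores
Sg = Qg @ Kg                                           # [B, G, M, N/G]
Pg = softmax(Sg, axis=-1)
m  = Sg.max(axis=-1, keepdims=True)
L  = m + np.log(np.exp(Sg - m).sum(axis=-1, keepdims=True))  # LSE, [B, G, M, 1]
Og = Pg @ Vg                                           # [B, G, M, D]

# second kernel: softmax over the group axis turns the per-group LSEs into
# each group's share of the total softmax mass; rescale and reduce over G
scale = softmax(L, axis=1)                             # [B, G, M, 1]
O_split = (Og * scale).sum(axis=1)                     # [B, M, D]

assert np.allclose(O_split, O)                         # matches the unsplit result
```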

We will probably do this directly in the fuse_attention pass after we have done the initial attention grouping.

pfultz2 · Sep 29 '25 17:09

Also, another reference we used when discussing flash decoding: https://arxiv.org/pdf/2402.05099

pfultz2 · Sep 29 '25 17:09

rocMLIR experiments for good G sizes: rocMLIR#1895

bdevorem · Oct 09 '25 16:10