
Where is the flash-decoding second-stage (reduce) code?

liuqi123123 opened this issue 1 year ago

According to https://pytorch.org/blog/flash-decoding/, flash decoding has two stages, and the second stage is "reduce and rescale the contribution of each split". But I can't find the reduce kernel after the kernel compute_attn_1rowblock_splitkv. Where is it?

liuqi123123 avatar Sep 27 '24 03:09 liuqi123123

https://github.com/Dao-AILab/flash-attention/blob/53a4f341634fcbc96bb999a3c804c192ea14f2ea/csrc/flash_attn/src/flash_fwd_kernel.h#L1108

tridao avatar Sep 27 '24 09:09 tridao

What’s the best way to trigger the flash-decoding path when using flash_fwd_splitkv_kernel(...)? Is it correct to set num_splits = 0 and let the heuristics decide automatically?

For flash-decoding, is the num_splits_heuristic function the recommended way to determine the optimal number of splits? I tried hardcoding num_splits to 2, 4, and 8, but saw worse results on an A100 (batch size 8, 48 q heads, new seqlen between 1 and 10). Even though the heuristic picks num_splits == 1 as the best choice, it seems that combining multiple q heads with the new-seqlen tokens in one GEMM is better at maximizing tensor core (TC) utilization in my case. Thanks a lot in advance for your insights!

Btw, the flash-decoding release notes mentioned a minimal example, but the link still leads to this "coming soon" page: https://github.com/Dao-AILab/flash-attention/tree/main/examples/inference. Not sure if the link is still valid.

SimpleTheoryOfTypes avatar Nov 05 '24 07:11 SimpleTheoryOfTypes
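(For reference, a minimal usage sketch of reaching the split-KV path from the Python side. The keyword arguments follow the flash_attn interface as I understand it and the shapes are assumed, so treat this as a sketch to double-check rather than a verified snippet.)

```python
# Sketch only: decode-time call through the Python API. num_splits=0 (the
# default) defers to the library's heuristic, which may or may not pick the
# split-KV (flash-decoding) path depending on occupancy.
import torch
from flash_attn import flash_attn_with_kvcache

b, s_cache, h, d = 8, 4096, 48, 128
q = torch.randn(b, 1, h, d, dtype=torch.float16, device="cuda")           # one new token per sequence
k_cache = torch.randn(b, s_cache, h, d, dtype=torch.float16, device="cuda")
v_cache = torch.randn(b, s_cache, h, d, dtype=torch.float16, device="cuda")
cache_seqlens = torch.full((b,), 1024, dtype=torch.int32, device="cuda")  # tokens already in the cache

out = flash_attn_with_kvcache(
    q, k_cache, v_cache,
    cache_seqlens=cache_seqlens,
    causal=True,
    num_splits=0,  # 0 = let the heuristic decide; >0 forces that many splits
)
```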

If num_splits = 0 we use a heuristic to decide whether to split (and into how many splits). With batch = 8 and 48 q heads there are 8 x 48 = 384 pieces of parallel work, more than the number of SMs on an A100 (108), so there's no reason to split. A split is usually only needed when there isn't enough parallel work to assign to all the SMs.

tridao avatar Nov 05 '24 07:11 tridao
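A simplified sketch of that occupancy argument (this is not the actual num_splits_heuristic code; num_sms and max_splits are assumed values):

```python
# Illustrative only: split along seqlen_k just enough to keep all SMs busy.
def pick_num_splits(batch: int, nheads: int, num_sms: int = 108, max_splits: int = 8) -> int:
    parallel_work = batch * nheads        # one thread block per (batch, head) pair
    if parallel_work >= num_sms:
        return 1                          # enough blocks already; splitting only adds reduce overhead
    # Otherwise split the KV sequence so that roughly every SM gets a block.
    return min(max_splits, -(-num_sms // parallel_work))  # ceil(num_sms / parallel_work)

print(pick_num_splits(batch=8, nheads=48))   # 384 >= 108 -> 1, matching the case above
print(pick_num_splits(batch=1, nheads=16))   # only 16 blocks -> 7 splits
```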

Is it feasible to vectorize the S=QK^T and SV GEMMs along the batch dimension in flash decoding? For example, during decoding, the query q has a shape of [b, 1, 48, 128], and a KV tile has a shape of [b, 64, 48, 128], where 48 represents the number of attention heads, 128 is the head dimension, and 64 is kBlockN, with only 1 token being decoded at a time.

In the current implementation, a separate flash::gemm is run for each batch, resulting in GEMM shapes like:

q: 1 x 128 K: 64 x 128

Here, flash::gemm computes a q: kBlockM=64 x 128 by K: kBlockN=64 x 128 GEMM, but only the first row of the result is used, while the remaining 63 rows are discarded (partition utilization 1/64 ≈ 1.6%).

Would it be feasible to perform the following using a single flash::gemm for the entire batch in flash decoding, i.e.:

q: (b, 1) x 128 K: (b, 64) x 128

This approach could potentially improve tensor core partition utilization by a factor of b, as it would allow us to keep b rows of the output tensor instead of just 1.

If flash decoding can already do this batch-dim vectorization, how can I enable it? Thanks!

SimpleTheoryOfTypes avatar Nov 06 '24 00:11 SimpleTheoryOfTypes

No, that doesn't increase tensor core utilization. The operation is memory bound anyway (you can measure that its speed is close to memcpy), so it doesn't matter that we're doing extra compute.

tridao avatar Nov 06 '24 00:11 tridao
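A back-of-the-envelope check of the memory-bound claim (assumed shapes and a long KV cache; the numbers are illustrative, not measured):

```python
# Rough arithmetic intensity of decode attention: FLOPs per byte of KV cache read.
batch, nheads, headdim, seqlen_k = 8, 48, 128, 8192
bytes_per_elem = 2  # fp16/bf16

kv_bytes = 2 * batch * nheads * seqlen_k * headdim * bytes_per_elem  # K and V streamed from HBM
flops = 4 * batch * nheads * seqlen_k * headdim                      # Q@K^T and P@V, 2 FLOPs per MAC each

print(f"arithmetic intensity = {flops / kv_bytes:.1f} FLOP/byte")    # about 1 FLOP/byte
# An A100 needs on the order of hundreds of FLOPs per byte to become compute
# bound, so the wasted tensor-core rows cost essentially nothing.
```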

Given that, with small batch sizes, the attention kernel during decoding is memory-bound, why would maximizing SM utilization by creating more parallel work along the sequence dimension still lead to improved latency?

Flash decoding does help a lot with small-batch-size (< 5) decoding; I just wanted to verify my understanding. Flash decoding's main optimization appears to be compute-unit utilization, which seems at odds with the fact that the attention kernel during decoding is memory-bound. Is it because the scheduling is suboptimal?

SimpleTheoryOfTypes avatar Nov 19 '24 22:11 SimpleTheoryOfTypes

Mem bound here means most of the time is spent waiting for data to be loaded from global memory, so you want more thread blocks issuing load instructions. If batch size = 1, seqlen_q = 1, and nheads = 16, then you only have 16 thread blocks issuing loads (out of 108 or 132 SMs), so you're not saturating memory bandwidth. You want to parallelize along the seqlen_k dimension so that more thread blocks are issuing loads and saturating memory bandwidth.

tridao avatar Nov 19 '24 22:11 tridao
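Putting rough numbers on that (illustrative only; 108 SMs on A100, 132 on H100 as mentioned above):

```python
# How splitting seqlen_k multiplies the number of thread blocks issuing loads.
num_sms = 108                 # A100
batch, nheads = 1, 16         # seqlen_q = 1 during decoding

for num_splits in (1, 4, 8):
    blocks = batch * nheads * num_splits   # each split streams its own chunk of the KV cache
    print(f"num_splits={num_splits}: {blocks:3d} blocks issuing loads "
          f"(~{min(blocks / num_sms, 1.0):.0%} of SMs)")
```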

Thanks a lot for the explanation! That makes sense: flash decoding also optimizes memory-bandwidth utilization by issuing more LD/ST instructions in parallel.

SimpleTheoryOfTypes avatar Nov 19 '24 23:11 SimpleTheoryOfTypes

For those who don't know how to do the reduction: each split i outputs O_i and LSE_i, and you can then get the final output by

LSE_final = log(sum(exp(LSE_i)))
O_final = sum(exp(LSE_i - LSE_final) * O_i)

In the CUDA implementation, the logsumexp trick is applied again: max(LSE_i) is subtracted before exponentiating to avoid overflow.

ZJUGuoShuai avatar Dec 30 '24 15:12 ZJUGuoShuai
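A NumPy sketch of that combine step (assumed layout: one partial output and one LSE per split for a single query/head; the real kernel linked earlier works on whole tiles, so this is only the math):

```python
import numpy as np

def combine_splits(O, LSE):
    # O:   (num_splits, headdim) per-split partial outputs O_i from the formula above
    # LSE: (num_splits,)         log-sum-exp of the attention logits within each split
    lse_max = LSE.max()                                       # subtract the max first to avoid overflow
    lse_final = lse_max + np.log(np.exp(LSE - lse_max).sum())
    scale = np.exp(LSE - lse_final)                           # rescale each split's contribution
    return (scale[:, None] * O).sum(axis=0), lse_final
```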

Hi @tridao, there's one thing that has been bothering me. Why do we use the log-sum-exp trick instead of the exp-normalize trick? As the description of the exp-normalize trick says, softmax is our goal, so isn't the exp-normalize trick more intuitive? 🤔


xiaofeihan1 avatar Sep 28 '25 13:09 xiaofeihan1

They're fundamentally the same thing. We use them.

tridao avatar Sep 28 '25 14:09 tridao
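For completeness, a quick numerical check (arbitrary values) that the two tricks give the same softmax; the practical difference is that log-sum-exp gives each split a single scalar it can export and combine later:

```python
import numpy as np

x = np.array([1000.0, 1001.0, 999.0])   # logits that would overflow a naive exp

# exp-normalize: subtract the max, exponentiate, normalize.
p_expnorm = np.exp(x - x.max())
p_expnorm /= p_expnorm.sum()

# log-sum-exp: compute LSE (with the same max subtraction), then exponentiate the difference.
lse = x.max() + np.log(np.exp(x - x.max()).sum())
p_lse = np.exp(x - lse)

assert np.allclose(p_expnorm, p_lse)
```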