
Question about the v2 paper: how to parallelize along the sequence length?

deltaguo opened this issue 2 years ago · 6 comments

I noticed that in v1 the grid is divided as follows (csrc/flash_attn/src/fmha_fwd_launch_template.h, line 86):

dim3 grid(launch_params.params.b, launch_params.params.h, launch_params.params.num_splits);

where num_splits has to be computed from batch_size and head_num to make the best use of the available SMs. In v2, the grid is divided as follows (csrc/flash_attn/src/flash_fwd_launch_template.h, lines 27 and 28):

const int num_m_block = (params.seqlen_q + Kernel_traits::kBlockM - 1) / Kernel_traits::kBlockM;
dim3 grid(num_m_block, params.b, params.h);
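To make the two launch configurations concrete, here is a small Python sketch that only reproduces the grid arithmetic from the snippets quoted in the question. The function names `v1_grid`, `v2_grid` and `waves`, and the `ctas_per_sm` occupancy parameter, are hypothetical and not part of the flash-attention code; `waves` assumes every thread block (CTA) takes the same time, which is a simplification.

```python
import math


def v1_grid(batch: int, heads: int, num_splits: int) -> tuple:
    """v1-style grid: num_splits is chosen on the host before the launch."""
    return (batch, heads, num_splits)


def v2_grid(seqlen_q: int, batch: int, heads: int, block_m: int = 128) -> tuple:
    """v2-style grid: one CTA per (query block, batch, head)."""
    num_m_block = (seqlen_q + block_m - 1) // block_m  # ceiling division
    return (num_m_block, batch, heads)


def waves(grid: tuple, num_sms: int, ctas_per_sm: int = 1) -> int:
    """Rounds of CTAs needed if each SM runs ctas_per_sm CTAs at a time
    and all CTAs take equally long (a simplifying assumption)."""
    total = grid[0] * grid[1] * grid[2]
    return math.ceil(total / (num_sms * ctas_per_sm))
```

For example, `v2_grid(4096, 2, 16)` is `(32, 2, 16)`, i.e. 1024 CTAs, which a 108-SM GPU running one CTA per SM would process in 10 waves. The v2 grid depends only on the problem shape; the SM count only decides how the hardware scheduler spreads those CTAs over time.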

Am I right that the grid definition in v2 has nothing to do with the number of SMs? If so, what is the difference between the seq_len parallelism described in the v2 paper and the third grid dimension in v1? It seems to me that v1 also parallelizes along seq_len. Is that right?

deltaguo · Aug 09 '23 05:08

Yes. The original FlashAttention implementation (May 2022) didn't have any seqlen parallelism. Later on (in the v1 code) we added a kind of seqlen parallelism in the forward pass: we decide up front how many splits to use, and each thread block then handles (seqlen_q / num_splits) rows along the seqlen_q dimension. In v2, each thread block handles a fixed 128 rows (or sometimes 64 rows for large head dimensions) along the seqlen_q dimension, so the number of thread blocks grows with the sequence length. This is simpler and works better.
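The difference between the two schemes can be sketched by listing which query rows each thread block owns. This is a minimal Python illustration under stated assumptions: the helper names are hypothetical, and `v1_row_ranges` assumes the v1 splits are contiguous chunks of ceil(seqlen_q / num_splits) rows, which may not match the exact rounding in the v1 kernels.

```python
def v1_row_ranges(seqlen_q: int, num_splits: int) -> list:
    """v1-style (assumed): a split count fixed up front, contiguous row chunks.
    Returns at most num_splits half-open (start, end) ranges."""
    rows_per_split = (seqlen_q + num_splits - 1) // num_splits
    return [(s, min(s + rows_per_split, seqlen_q))
            for s in range(0, seqlen_q, rows_per_split)]


def v2_row_ranges(seqlen_q: int, block_m: int = 128) -> list:
    """v2-style: every thread block owns a fixed tile of block_m query rows."""
    return [(s, min(s + block_m, seqlen_q)) for s in range(0, seqlen_q, block_m)]
```

With `seqlen_q = 8192`, a v1 launch with `num_splits = 4` gives 4 blocks of 2048 rows each (per batch and head), whereas v2 gives 64 blocks of 128 rows. In v2 the per-block work is constant and longer sequences simply produce more blocks, so no split count has to be tuned against the SM count.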

tridao · Aug 09 '23 05:08

Thank you for your reply. This work is very interesting to me! I reproduced the paper's results on an A10 using v1.0.7 and v2.0.7.

I understand that in the forward pass v2 is much faster than v1 because the loop order is swapped, which greatly reduces HBM reads and writes. But I still have one doubt:

The paper says there is not much difference between the backward passes of v1 and v2. Looking at the v1 and v2 code, both parallelize along seq_len_k, with the inner loop along seq_len_q. Yet in the experiments v2 still improves substantially over v1. Apart from not using split-k, what is the main reason for the speedup?
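The backward loop structure described here (parallel over K/V tiles, inner loop over Q tiles) can be modeled in NumPy and checked against unblocked gradients. This is a sketch for illustration, not the CUDA kernel: the function names, the 16-row tile sizes and the float64 arithmetic are assumptions. It recomputes P from the saved per-row logsumexp, keeps dK_j and dV_j local to the outer iteration, and accumulates dQ across outer iterations.

```python
import numpy as np


def forward_with_lse(q, k, v, scale):
    """Plain softmax attention that also returns the per-row logsumexp."""
    s = scale * q @ k.T
    m = s.max(axis=1, keepdims=True)
    e = np.exp(s - m)
    l = e.sum(axis=1, keepdims=True)
    return (e / l) @ v, (m + np.log(l))[:, 0]


def reference_backward(q, k, v, do, scale):
    """Unblocked gradients of O = softmax(scale * Q K^T) V, for checking."""
    s = scale * q @ k.T
    p = np.exp(s - s.max(axis=1, keepdims=True))
    p /= p.sum(axis=1, keepdims=True)
    o = p @ v
    ds = p * (do @ v.T - (do * o).sum(axis=1, keepdims=True))
    return scale * ds @ k, scale * ds.T @ q, p.T @ do


def blocked_backward(q, k, v, do, o, lse, scale, block_m=16, block_n=16):
    """Outer loop over K/V tiles (one thread block each), inner loop over Q tiles."""
    dq, dk, dv = np.zeros_like(q), np.zeros_like(k), np.zeros_like(v)
    delta = (do * o).sum(axis=1)  # rowsum(dO * O), computed once up front
    for j0 in range(0, k.shape[0], block_n):
        kj, vj = k[j0:j0 + block_n], v[j0:j0 + block_n]
        dkj, dvj = np.zeros_like(kj), np.zeros_like(vj)  # kept on-chip
        for i0 in range(0, q.shape[0], block_m):
            rows = slice(i0, i0 + block_m)
            qi, doi = q[rows], do[rows]
            p = np.exp(scale * qi @ kj.T - lse[rows, None])  # recompute P tile
            dvj += p.T @ doi
            ds = p * (doi @ vj.T - delta[rows, None])
            dkj += scale * ds.T @ qi
            # Every K/V tile updates the same dQ rows; on a GPU this has to go
            # through global memory (e.g. atomic adds) instead of a local sum.
            dq[rows] += scale * ds @ kj
        dk[j0:j0 + block_n] = dkj  # dK_j and dV_j are written exactly once
        dv[j0:j0 + block_n] = dvj
    return dq, dk, dv
```

The sketch only captures the loop order; the speedups discussed in the reply below (no split-k, warp partitioning, cp.async) live at the warp and instruction level, which a NumPy model cannot show.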

deltaguo · Aug 15 '23 11:08

The improvement in the backward pass is a combination of factors:

  • Not using split-k, which reduces the amount of shared memory needed and the number of shared memory reads/writes.
  • Better work partitioning between warps, so that we can use more "square" block sizes (e.g. 128 x 128 or 64 x 128 in v2, instead of 16 x 128 or 16 x 256 in v1).
  • Some lower-level optimizations, such as using the cp.async feature on Ampere.
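One simple way to see why "square" blocks help is to count how much arithmetic the S = Q_i K_j^T tile does per element of Q_i and K_j it loads. This is a simplified model of data reuse, an assumption of this sketch rather than the metric used in the reply; the function name is hypothetical.

```python
def tile_reuse(block_m: int, block_n: int, head_dim: int) -> float:
    """FLOPs of the block_m x block_n tile of S = Q_i K_j^T per loaded element.

    Simplified model: 2 * block_m * block_n * head_dim FLOPs (multiply-adds)
    over (block_m + block_n) * head_dim loaded elements. head_dim cancels,
    leaving 2 * block_m * block_n / (block_m + block_n).
    """
    flops = 2 * block_m * block_n * head_dim
    loaded = (block_m + block_n) * head_dim
    return flops / loaded
```

Under this model a 128 x 128 tile does 128 FLOPs per loaded element, versus about 28 for 16 x 128. For a fixed tile area the ratio is largest when the two sides are equal, which is why a partitioning that allows square tiles can use each load from shared memory more times.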

tridao · Aug 15 '23 16:08

@deltaguo can you explain why looping over the queries in the outer loop reduces the number of HBM accesses? I'd think it's the same whichever way you loop. Is it because of how the K and V vectors are laid out in memory?

j93hahn · Jul 26 '24 05:07

@j93hahn On the GPU, each iteration of the outer loop is handled by a different thread block (and so, in general, a different SM), and the iterations run in parallel. Computing the softmax requires a reduction along each row of the attention matrix. To keep these reductions on a single SM, each outer iteration should process one block of Q against all of K and V and produce one block of O; that is why the outer loop runs over Q. If the outer loop ran over K and V instead, the row reductions of the attention matrix would have to communicate through global memory, which is slow.
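The reply's point can be illustrated with a NumPy sketch of the Q-outer forward pass. Each outer iteration owns one Q tile and sweeps all of K and V, keeping the running row maximum and row sum (the "online softmax" rescaling trick) local, so the row reduction finishes inside that iteration and O is written once. The function name and tile sizes are assumptions for illustration, not the kernel's code.

```python
import numpy as np


def flash_forward_q_outer(q, k, v, scale, block_m=16, block_n=16):
    """Blocked softmax attention with the outer loop over query tiles."""
    n_q, d_v = q.shape[0], v.shape[1]
    o = np.empty((n_q, d_v))
    for i0 in range(0, n_q, block_m):  # independent: one thread block each
        qi = q[i0:i0 + block_m]
        m = np.full(qi.shape[0], -np.inf)   # running row max
        l = np.zeros(qi.shape[0])           # running row sum of exp
        acc = np.zeros((qi.shape[0], d_v))  # unnormalized output rows
        for j0 in range(0, k.shape[0], block_n):  # inner loop over all K/V
            s = scale * qi @ k[j0:j0 + block_n].T
            m_new = np.maximum(m, s.max(axis=1))
            p = np.exp(s - m_new[:, None])
            correction = np.exp(m - m_new)  # rescale what was accumulated so far
            l = l * correction + p.sum(axis=1)
            acc = acc * correction[:, None] + p @ v[j0:j0 + block_n]
            m = m_new
        # Row reductions are complete locally, so O_i is written exactly once.
        o[i0:i0 + block_m] = acc / l[:, None]
    return o
```

With the loops swapped (K/V outer), the statistics m, l and the partial O rows for a given Q tile would be produced by different outer iterations, so they would have to be combined through global memory.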

deltaguo · Jul 29 '24 02:07
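On the HBM question itself, a rough element count helps. This Python sketch uses a simplified model (an assumption): square attention with n query and n key rows, head dimension d, equal tiles of `block` rows, no cache reuse, and the small softmax statistics ignored. The K/V-outer count follows the loop order of Algorithm 1 in the original FlashAttention paper, where O_i is read and written back in every inner iteration; the function names are hypothetical.

```python
def hbm_elements_kv_outer(n: int, d: int, block: int) -> int:
    """K/V outer, Q inner: O_i makes a round trip for every (j, i) pair."""
    assert n % block == 0
    t = n // block
    kv = t * 2 * block * d         # each K_j and V_j tile loaded once
    q_o = t * t * 3 * block * d    # per (j, i): read Q_i, read O_i, write O_i
    return kv + q_o


def hbm_elements_q_outer(n: int, d: int, block: int) -> int:
    """Q outer, K/V inner: K_j, V_j re-read per Q tile, O_i written once."""
    assert n % block == 0
    t = n // block
    q = t * block * d              # each Q_i loaded once
    kv = t * t * 2 * block * d     # per (i, j): read K_j and V_j
    o = t * block * d              # each O_i written once
    return q + kv + o
```

In this model the total reads of the two orders come out roughly comparable, which matches the intuition in the question; the gap is almost entirely the repeated O writes, and the two totals differ by n * d * (n / block) elements.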

Thanks for the reply! I read the FlashAttention-2 paper, and the memory layout makes a lot more sense to me now (particularly the work partitioning between warps).

j93hahn · Jul 29 '24 02:07