
Pipelining GmemCopy on kHeadDim

Open phantaurus opened this issue 1 year ago • 6 comments

Hello!

I'm currently studying FlashAttention-2 and noticed that, when copying from global memory to shared memory, the entire HeadDim (the K dimension in MNK tiling) has to be copied to shared memory before a thread synchronization, which is then followed by the shared-memory-to-register copies and the MMA operations.

The pseudocode for QK^T looks like this:

```cpp
// Copy Q from gmem to smem
// Copy K from gmem to smem

for (int n_block = 0; n_block < n_blocks; ++n_block) {
    // ...
    // Sync threads

    // Smem-to-reg copy for Q(_, _, 0)
    // Smem-to-reg copy for K(_, _, 0)
    for (int i = 0; i < size<2>(Q); ++i) {
        if (i < size<2>(Q) - 1) {
            // Smem-to-reg copy for Q(_, _, i + 1)
            // Smem-to-reg copy for K(_, _, i + 1)
        }
        gemm(Q(_, _, i), K(_, _, i));
    }
    // Sync threads

    // Update gmem address for K
    // Copy the next K block from gmem to smem
}
```

Is it possible to partition HeadDim into smaller chunks and pipeline the global memory copies? For example, there could be two sets of shared memory buffers and registers for the K tensor, each handling HeadDim/2. While the first set is used for the MMA, the second set could be filled from global memory concurrently, without waiting for the MMA to finish.
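A host-side C++ sketch of the proposed HeadDim split, assuming `kHeadDim` divides evenly into `kChunks` chunks (two halves here) and that both Q and K are staged chunk by chunk. The tile sizes, `kChunks`, `copy_chunk`, `mma_chunk`, and `qk_chunked` are all hypothetical. Two staging buffers alternate: the copy of chunk c + 1 is issued before the multiply-accumulate on chunk c. On a CPU these steps run one after another, so the sketch only shows the buffer ordering and that splitting the K reduction leaves QK^T unchanged; it makes no performance claim.

```cpp
#include <array>
#include <cassert>
#include <cstddef>
#include <vector>

constexpr int kBlockM = 4;   // rows of Q in the tile (hypothetical size)
constexpr int kBlockN = 4;   // rows of K in the tile (hypothetical size)
constexpr int kHeadDim = 8;  // reduction (K) dimension of Q K^T
constexpr int kChunks = 2;   // hypothetical: split HeadDim into two halves
constexpr int kChunk = kHeadDim / kChunks;
static_assert(kHeadDim % kChunks == 0, "HeadDim must split evenly");

using Buffer = std::vector<float>;  // row-major storage

// "gmem -> smem": copy columns [c * kChunk, (c + 1) * kChunk) of a
// [rows x kHeadDim] tile into a [rows x kChunk] staging buffer.
void copy_chunk(const Buffer& src, int rows, int c, Buffer& dst) {
    assert(c >= 0 && c < kChunks);
    dst.assign(static_cast<std::size_t>(rows) * kChunk, 0.0f);
    for (int r = 0; r < rows; ++r)
        for (int k = 0; k < kChunk; ++k)
            dst[r * kChunk + k] = src[r * kHeadDim + c * kChunk + k];
}

// "MMA": S += Q_chunk * K_chunk^T for one HeadDim chunk.
void mma_chunk(const Buffer& q, const Buffer& k, Buffer& s) {
    for (int m = 0; m < kBlockM; ++m)
        for (int n = 0; n < kBlockN; ++n)
            for (int kk = 0; kk < kChunk; ++kk)
                s[m * kBlockN + n] += q[m * kChunk + kk] * k[n * kChunk + kk];
}

// S = Q * K^T, one HeadDim chunk at a time, using two ping-pong buffers.
Buffer qk_chunked(const Buffer& Q, const Buffer& K) {
    std::array<Buffer, 2> q_stage, k_stage;
    Buffer S(kBlockM * kBlockN, 0.0f);
    copy_chunk(Q, kBlockM, 0, q_stage[0]);  // prologue: stage chunk 0
    copy_chunk(K, kBlockN, 0, k_stage[0]);
    for (int c = 0; c < kChunks; ++c) {
        const int cur = c % 2, nxt = (c + 1) % 2;
        if (c + 1 < kChunks) {  // issue the copy of chunk c + 1 first...
            copy_chunk(Q, kBlockM, c + 1, q_stage[nxt]);
            copy_chunk(K, kBlockN, c + 1, k_stage[nxt]);
        }
        mma_chunk(q_stage[cur], k_stage[cur], S);  // ...then compute on chunk c
    }
    return S;
}
```

On a GPU the copy would be asynchronous (for example via `cp.async`) and each chunk would still need a barrier before its MMA, so any benefit depends on whether those waits were on the critical path to begin with.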


Do you think this approach would be beneficial, or are there other factors in the system that might undermine its effectiveness? I appreciate your insight!

phantaurus • Sep 05 '24 20:09

For gmem copy, you just want to issue the instructions and then do other work. Currently I don't think the gmem copy is what's slowing things down. It's possible to pipeline it, but I'm not sure how much benefit you'd get. You should look at Nsight Compute to see what the bottlenecks are.
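The "issue the copy, then do other work" pattern can be illustrated on the host with `std::async`, as a rough analogy for asynchronous gmem-to-smem copies rather than the CUDA mechanism itself. `load_tile` and `process_all` are hypothetical names: the load of tile t + 1 is launched before tile t is processed, and the code blocks on `get()` only when that tile is actually needed.

```cpp
#include <cassert>
#include <cstddef>
#include <functional>
#include <future>
#include <numeric>
#include <vector>

// Hypothetical stand-in for copying one tile out of "global memory".
std::vector<float> load_tile(const std::vector<float>& gmem, std::size_t offset, std::size_t n) {
    assert(offset + n <= gmem.size());
    return std::vector<float>(gmem.begin() + offset, gmem.begin() + offset + n);
}

// Sums every full tile (leftover elements past the last full tile are ignored).
// The load of tile t + 1 is issued before tile t is processed, and we wait
// for it only right before it is needed.
float process_all(const std::vector<float>& gmem, std::size_t tile) {
    const std::size_t num_tiles = tile == 0 ? 0 : gmem.size() / tile;
    if (num_tiles == 0) return 0.0f;
    float total = 0.0f;
    std::vector<float> cur = load_tile(gmem, 0, tile);  // prologue: first tile
    for (std::size_t t = 0; t < num_tiles; ++t) {
        std::future<std::vector<float>> next;
        if (t + 1 < num_tiles)  // issue the next load without waiting for it
            next = std::async(std::launch::async, load_tile, std::cref(gmem), (t + 1) * tile, tile);
        total += std::accumulate(cur.begin(), cur.end(), 0.0f);  // "other work" on tile t
        if (next.valid()) cur = next.get();  // block only when tile t + 1 is needed
    }
    return total;
}
```

`std::launch::async` runs each load on a separate thread; on some toolchains that requires linking with `-pthread`.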

tridao • Sep 05 '24 20:09

Thank you so much for your reply!

I think you're describing an ideal asynchronous scenario: if the system has sufficient parallelism, one pipeline (compute or memory) becomes saturated while the other remains underutilized. In this case, pipelining wouldn't help the saturated pipeline, since its capacity is the limiting factor. It wouldn't help the underutilized pipeline either, since its operations can be scheduled flexibly without affecting the overall latency.

However, when parallelism is insufficient, for example when GPU occupancy is low and synchronizations occur relatively frequently, the workload no longer matches that ideal asynchronous scenario and may become latency-bound. In these cases, pipelining can help by exposing more parallelism.
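A back-of-the-envelope model of this argument, under simplifying assumptions: one CTA with no other warps to switch to, a fixed load time `t_load` and compute time `t_compute` per tile, and perfect overlap with two buffers. Without pipelining each tile costs `t_load + t_compute`; with double buffering the steady state costs `max(t_load, t_compute)` per tile, plus one fill and one drain. The function names are hypothetical.

```cpp
#include <algorithm>
#include <cassert>

// No overlap: each tile's load and compute run back to back.
double serial_time(int n, double t_load, double t_compute) {
    assert(n >= 0);
    return n * (t_load + t_compute);
}

// Two-stage pipeline: load tile i + 1 while computing tile i.
// Fill (first load) + (n - 1) steady-state steps + drain (last compute).
double pipelined_time(int n, double t_load, double t_compute) {
    assert(t_load >= 0.0 && t_compute >= 0.0);
    if (n <= 0) return 0.0;
    return t_load + (n - 1) * std::max(t_load, t_compute) + t_compute;
}
```

With 8 tiles and equal load and compute times, the model gives 16 versus 9 time units; with compute nine times longer than the load, it gives 80 versus 73. The gain is largest when the two stages are balanced and shrinks when one pipe dominates, which matches the saturated-pipe case discussed above.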

I'm still getting familiar with Nsight Compute. I will look for the metrics that indicate whether the workload is latency-bound.

(I apologize if my earlier description was misleading. If we pipeline over kBlockN instead of kHeadDim, we can actually pipeline the entire workflow, including the shared-memory-to-register copies and the MMA operations, not only the GmemCopy.)
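Pipelining over the kBlockN loop is a classic software pipeline with a prologue, a steady state, and an epilogue. The sketch below builds that schedule for a hypothetical three-stage split of each K block (gmem->smem copy, smem->reg copy, MMA), where at step t stage s works on block t - s. The stage names, `pipeline_schedule`, and the one-block offset per stage are illustrative assumptions, not how FlashAttention-2 is actually organized.

```cpp
#include <cassert>
#include <string>
#include <vector>

// Hypothetical stages each K block passes through, in order.
const std::vector<std::string> kStages = {"gmem->smem", "smem->reg", "mma"};

// For each time step, the "stage(block)" operations that are in flight together.
// At step t, stage s works on block t - s, so the prologue (filling), steady
// state, and epilogue (draining) all fall out of the same loop.
std::vector<std::vector<std::string>> pipeline_schedule(int n_blocks) {
    assert(n_blocks >= 0);
    const int n_stages = static_cast<int>(kStages.size());
    std::vector<std::vector<std::string>> steps;
    if (n_blocks == 0) return steps;
    for (int t = 0; t < n_blocks + n_stages - 1; ++t) {
        std::vector<std::string> ops;
        for (int s = 0; s < n_stages; ++s) {
            const int block = t - s;
            if (block >= 0 && block < n_blocks)
                ops.push_back(kStages[s] + "(" + std::to_string(block) + ")");
        }
        steps.push_back(ops);
    }
    return steps;
}
```

For 4 blocks this yields 6 steps; step 2, for example, overlaps the gmem->smem copy of block 2, the smem->reg copy of block 1, and the MMA of block 0. In a real kernel each block in flight needs its own smem or register buffer, so deeper pipelines trade on-chip memory (and possibly occupancy) for overlap.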

phantaurus • Sep 06 '24 18:09

Yeah, thinking more about it, on the 4090 we should be able to get 70%+ tensor core utilization. The current version (FA2) might not get there because our existing pipelining isn't good enough. It was optimized for A100 at the time, so there's probably quite a bit of headroom on the 4090.

tridao • Sep 06 '24 20:09

I have confirmed that memory copies, whether from global memory (GmemCopy) or shared memory (SmemCopy), are not what caps Tensor Core Active % at 50%. I removed all data copying, reducing the main loop to just the two GEMM operations running directly on registers. Even so, Tensor Core Active % still maxes out at 50%.

Even with only a single GEMM operation left, Nsight Compute still reports 50% Tensor Core Active %.

I then considered whether iterating over the K dimension creates a data dependency on the acc registers:

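```cpp
// Every iteration accumulates into the same acc fragment, so each MMA
// reads the acc values written by the previous one.
for (int i = 0; i < size<2>(tCrA); ++i) {
    cute::gemm(tiled_mma, tCrA(_, _, i), tCrB(_, _, i), acc);
}
```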

To test this, I created two sets of acc registers so that two independent accumulation chains could run in parallel, but Tensor Core Active % still stayed at 50%.
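The two-accumulator experiment can be sketched on the host as follows, assuming a scalar stand-in (`mma_step`, hypothetical) for one `cute::gemm` call over a K slice. Even slices go into `acc0` and odd slices into `acc1`, so consecutive steps no longer depend on each other, and the partial sums are combined once after the loop. This shows only the dependency-breaking transformation; it says nothing about GPU throughput.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical stand-in for one MMA over K slice i: acc += a[i] * b[i].
inline void mma_step(float a, float b, float& acc) { acc += a * b; }

// Single accumulator: every step reads the acc written by the previous step.
float dot_one_acc(const std::vector<float>& a, const std::vector<float>& b) {
    assert(a.size() == b.size());
    float acc = 0.0f;
    for (std::size_t i = 0; i < a.size(); ++i) mma_step(a[i], b[i], acc);
    return acc;
}

// Two accumulators: even and odd slices form independent dependency chains,
// so back-to-back steps do not have to wait on each other.
float dot_two_acc(const std::vector<float>& a, const std::vector<float>& b) {
    assert(a.size() == b.size());
    float acc0 = 0.0f, acc1 = 0.0f;
    std::size_t i = 0;
    for (; i + 1 < a.size(); i += 2) {
        mma_step(a[i], b[i], acc0);
        mma_step(a[i + 1], b[i + 1], acc1);
    }
    if (i < a.size()) mma_step(a[i], b[i], acc0);  // odd-length tail
    return acc0 + acc1;  // combine the partial sums once at the end
}
```

In floating point the two versions can differ in the last bits, because the order of the additions changes.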

Profiling the workload with Nsight Compute shows that the primary stall reason is "Math Pipe Throttle." Based on my understanding, this suggests that the Tensor Core is oversubscribed...

At this point, I’m quite confused and unsure of what might be causing this 50% limit.

phantaurus • Sep 09 '24 01:09

Do you know what TFLOPS you get with FA2 on the 4090? The 4090's theoretical max is 165 TFLOPS with an fp32 accumulator and 330 TFLOPS with an fp16 accumulator. We use an fp32 accumulator. If you measure FA2 and get around 160 TFLOPS, then it's probably because ncu is reporting %TC relative to the 330 TFLOPS max. If you measure FA2 and get around 80 TFLOPS, then the issue is elsewhere.
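The check suggested here is simple arithmetic. The sketch below uses the FLOP count commonly used when benchmarking attention forward passes, 4 · batch · nheads · seqlen² · headdim (two matmuls, QK^T and PV, at 2 FLOPs per multiply-add, halved for causal masking), and compares an achieved rate against both 4090 peaks quoted above. The function and constant names are hypothetical, and any shapes or timings plugged in are up to the reader; none here are measurements.

```cpp
#include <cassert>

// Forward-pass FLOPs for attention: QK^T and PV, 2 FLOPs per multiply-add.
double attn_fwd_flops(double batch, double nheads, double seqlen, double headdim, bool causal) {
    const double flops = 4.0 * batch * nheads * seqlen * seqlen * headdim;
    return causal ? flops / 2.0 : flops;
}

// Achieved TFLOPS from a FLOP count and a measured runtime.
double tflops(double flops, double seconds) {
    assert(seconds > 0.0);
    return flops / seconds / 1e12;
}

// RTX 4090 dense FP16 tensor-core peaks, as quoted in this thread.
constexpr double kPeakFp32Acc = 165.0;  // FP16 inputs, FP32 accumulate
constexpr double kPeakFp16Acc = 330.0;  // FP16 inputs, FP16 accumulate

// Percentage of a given peak.
double pct_of_peak(double achieved_tflops, double peak_tflops) {
    return 100.0 * achieved_tflops / peak_tflops;
}
```

For example, 160 TFLOPS is about 48% of 330 TFLOPS but about 97% of 165 TFLOPS, so a kernel running close to the FP32-accumulate peak would show roughly 50% Tensor Core utilization if the profiler normalizes to the FP16-accumulate peak.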

tridao • Sep 09 '24 01:09

Ah, I see. I was measuring against the max FP16-accumulate TFLOPS, so the numbers make a lot more sense now. I guess we have to use FP32 (for the accumulator and the softmax), so achieving around 50% of the max FP16 TFLOPS is the best throughput we can expect. In that case, FlashAttention-2 actually performs very well in our models! On the 30-series, 40-series, and Orin GPUs, it typically reaches over 45% Tensor Core Active %.

phantaurus • Sep 09 '24 02:09