
Softmax (particularly exp operations) becomes a major bottleneck in full FP16 pipeline

Open · phantaurus opened this issue 1 year ago • 6 comments

Hello!

I have recently switched to a full FP16 pipeline using MMA F16F16F16F16. In this setup, softmax becomes the bottleneck limiting TensorCore Active %, particularly when running on embedded GPUs such as Orin.

The latency bottleneck lies mainly in the following code block in the function scale_apply_exp2:

    for (int mi = 0; mi < size<0>(tensor); ++mi) {
        // If max is -inf, then all elements must have been -inf (possibly due to masking).
        // We don't want (-inf - (-inf)) since that would give NaN.
        // If we don't have float around M_LOG2E the multiplication is done in fp64.
        const float max_scaled = max(mi) == -INFINITY ? 0.f : max(mi) * (Scale_max ? scale : float(M_LOG2E));
        #pragma unroll
        for (int ni = 0; ni < size<1>(tensor); ++ni)  {
            // Instead of computing exp(x - max), we compute exp2(x * log_2(e) -
            // max * log_2(e)) This allows the compiler to use the ffma
            // instruction instead of fadd and fmul separately.
            // The following macro will disable the use of fma.
            // See: https://github.com/pytorch/pytorch/issues/121558 for more details
            // This macro is set in PyTorch and not FlashAttention
            #ifdef UNFUSE_FMA
                tensor(mi, ni) = exp2f(__fmul_rn(tensor(mi, ni), scale) - max_scaled);
            #else
                tensor(mi, ni) = exp2f(tensor(mi, ni) * scale - max_scaled);
            #endif
        }
    }
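
(Restating the comment in the code, since this is the whole trick: with m the row max,

    exp(x - m) = 2^((x - m) * log2(e)) = exp2(x * log2(e) - m * log2(e))

so folding log2(e), or the user-supplied scale, into both terms leaves exactly one FFMA (x * scale - max_scaled) plus one exp2 per element.)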

Simply deleting this function increases TensorCore Active % from ~65% to ~90%. (Further removing the warp-wise reduction operations brings TensorCore Active % to ~95%, so the Allreduce function is not the major latency bottleneck in this scenario.)

I have modified the pipeline to use the FP16 exp operation hexp2, which brings TensorCore Active % from ~50% to ~65%. I also tried changing the order of computation in this loop so that values belonging to one MMA atom are computed first before moving on to the others, but that doesn't help at all. The Nsight Compute results also show that the exp pipeline is the bottleneck: no matter how the exp and gemm operations are reordered, the workload is constrained by the capacity of the exp pipeline. I'm particularly surprised that hexp2 is this slow on Orin, since with the same implementation on 40-series GPUs, TensorCore Active % easily exceeds 85%.
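
For reference, the hexp2 version of the inner loop looks roughly like this (a sketch rather than my exact code; it assumes the scores are repacked two-per-register as __half2 so a single h2exp2 from cuda_fp16.h covers two elements):

    #include <cuda_fp16.h>
    #include <math.h>

    // Sketch of the FP16-exp variant (not the exact code I'm running):
    // `row` holds one row of scores packed as __half2, `n2` pairs long;
    // scale and the negated scaled row max are broadcast into both lanes.
    __device__ void scale_apply_exp2_h2(__half2 *row, int n2,
                                        float row_max, float scale) {
        const float max_scaled = row_max == -INFINITY ? 0.f : row_max * scale;
        const __half2 scale2   = __float2half2_rn(scale);
        const __half2 neg_max2 = __float2half2_rn(-max_scaled);
        #pragma unroll
        for (int i = 0; i < n2; ++i) {
            // x * scale - max_scaled in fp16x2 (HFMA2), then the half exp2
            row[i] = h2exp2(__hfma2(row[i], scale2, neg_max2));
        }
    }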

So I'm basically stuck here and there doesn't seem to be much I can do now. I'm wondering if you have any suggestions or insights to share?

Thank you so much!

phantaurus avatar Sep 13 '24 22:09 phantaurus

Yup, exp is one of the bottlenecks. We talked about that a bit in the FA3 paper.

tridao avatar Sep 13 '24 22:09 tridao

exp uses the MUFU (multi-function unit), which has quite low throughput (e.g. 16 ops per clock cycle, which is 4-8x lower than add / mul floating point operations). https://docs.nvidia.com/cuda/cuda-c-programming-guide/index.html#arithmetic-instructions
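
A rough back-of-envelope for why that matters (standard FlashAttention tiling, head dim d for QK^T and d_v for PV; just the shape of the argument, not exact Orin numbers):

    exp ops per score element     : 1
    matmul FLOPs per score element: ~2*d (QK^T) + ~2*d_v (PV) = 2*(d + d_v)
    d = d_v = 128                 : ~512 matmul FLOPs per exp

So the tensor cores can only stay fully busy if their FLOP-per-clock rate is at most ~2*(d + d_v) times the MUFU's exp-per-clock rate; the smaller the head dimension, the sooner exp becomes the limiter.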

tridao avatar Sep 13 '24 22:09 tridao

Thank you so much for your response! I recently read the FA3 paper. I guess GPUs like Orin do not have async MMA operations. Do you think the ping-pong structure described in the paper could be adapted to synchronous MMA operations? Even with synchronous MMAs, the TensorCore and MUFU pipelines should still be able to overlap once instructions are issued, as long as there is no data dependency, so it seems possible to write kernels that explicitly instantiate two sets of registers and mimic the behavior shown in the figure below.

(Screenshot: the ping-pong scheduling figure from the FA3 paper.)

However, you did mention in the paper that the actual case is not as clean as depicted in this figure.
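
Concretely, the structure I have in mind is something like this sketch (the helper names are hypothetical stand-ins, not real CUTE/CUTLASS calls); the point is just the instruction ordering, issuing the next tile's QK^T MMAs before the current tile's exp2 so that independent tensor-core and MUFU work can overlap even with synchronous MMAs:

    // Hypothetical sketch of double-buffered softmax/MMA overlap; Acc, gemm_qk,
    // softmax_exp, and gemm_pv are made-up placeholders for the real CUTE code.
    struct Acc { /* accumulator register fragment */ };

    __device__ void gemm_qk(Acc &s, int tile);                // S = Q * K[tile]^T (tensor cores)
    __device__ void softmax_exp(Acc &s, float scale);         // row max + exp2 (MUFU)
    __device__ void gemm_pv(Acc &o, const Acc &p, int tile);  // O += P * V[tile] (tensor cores)

    __device__ void mainloop(Acc &o, float scale, int n_tiles) {
        Acc s[2];                                // two accumulator register sets
        gemm_qk(s[0], 0);                        // prologue: first tile's QK^T
        for (int i = 0; i < n_tiles; ++i) {
            if (i + 1 < n_tiles)
                gemm_qk(s[(i + 1) & 1], i + 1);  // next tile's MMAs, independent of s[i & 1]
            softmax_exp(s[i & 1], scale);        // exp of current tile; can overlap with the MMAs above
            gemm_pv(o, s[i & 1], i);             // PV MMAs for the current tile
        }
    }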

phantaurus avatar Sep 13 '24 22:09 phantaurus

Yeah, overlapping works reasonably well on Hopper; for older architectures it might be harder to do.

tridao avatar Sep 13 '24 23:09 tridao

So after FA3, is softmax still the bottleneck? I mean, by percentage, how much of the softmax can FA3 overlap with the TC work?

Thanks!

ziyuhuang123 avatar Sep 30 '24 08:09 ziyuhuang123

I don't think that's easy to measure. For small hdim (e.g. 64) and for fp8, softmax is still a bottleneck.

tridao avatar Sep 30 '24 09:09 tridao