flash-attention
Softmax (particularly exp operations) becomes a major bottleneck in full FP16 pipeline
Hello!
I have recently switched to a full FP16 pipeline using MMA F16F16F16F16. In this setup, softmax becomes the main limiter of Tensor Core Active %, particularly when running on embedded GPUs such as Orin.
The latency bottleneck lies mainly in the following code block in the function scale_apply_exp2:
#pragma unroll
for (int mi = 0; mi < size<0>(tensor); ++mi) {
    // If max is -inf, then all elements must have been -inf (possibly due to masking).
    // We don't want (-inf - (-inf)) since that would give NaN.
    // If we don't have float around M_LOG2E the multiplication is done in fp64.
    const float max_scaled = max(mi) == -INFINITY ? 0.f : max(mi) * (Scale_max ? scale : float(M_LOG2E));
    #pragma unroll
    for (int ni = 0; ni < size<1>(tensor); ++ni) {
        // Instead of computing exp(x - max), we compute exp2(x * log_2(e) -
        // max * log_2(e)). This allows the compiler to use the ffma
        // instruction instead of fadd and fmul separately.
        // The following macro will disable the use of fma.
        // See: https://github.com/pytorch/pytorch/issues/121558 for more details.
        // This macro is set in PyTorch and not in FlashAttention.
        #ifdef UNFUSE_FMA
        tensor(mi, ni) = exp2f(__fmul_rn(tensor(mi, ni), scale) - max_scaled);
        #else
        tensor(mi, ni) = exp2f(tensor(mi, ni) * scale - max_scaled);
        #endif
    }
}
Simply deleting this function increases Tensor Core Active % from ~65% to ~90%. (Further removing the warp-wise reduction operations brings Tensor Core Active % to ~95%, so the Allreduce function is not the major latency bottleneck in this scenario.)
I have already modified the pipeline to use the FP16 exp operation hexp2, which brings Tensor Core Active % from ~50% to ~65%. I also tried changing the order of computation in this loop so that the values belonging to one MMA atom are computed first before moving on to the others; that doesn't help at all. The Nsight Compute results also show that the exp pipeline is the bottleneck: no matter how the exp and gemm operations are reordered, the workload is constrained by the throughput of the exp pipeline. I'm particularly shocked that hexp2 is this slow on Orin, since with the same implementation on 40-series GPUs, Tensor Core Active % easily exceeds 85%.
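For reference, this is roughly what my hexp2 path looks like. It is a simplified sketch rather than the exact code, and scale_apply_exp2_half2 plus the variable names are just made up for illustration; the point is only that packing the scores as __half2 lets each h2exp2 cover two elements:

#include <cuda_fp16.h>

// Simplified sketch, not the actual kernel code: apply exp2(x * scale - max_scaled)
// to a row of scores stored as packed __half2, so each h2exp2 instruction covers
// two elements and the number of MUFU operations per row is halved.
template <int kN2>
__device__ inline void scale_apply_exp2_half2(__half2 (&row)[kN2],
                                              float scale, float max_scaled) {
    const __half2 scale2   = __half2half2(__float2half(scale));
    const __half2 neg_max2 = __half2half2(__float2half(-max_scaled));
    #pragma unroll
    for (int i = 0; i < kN2; ++i) {
        // One __hfma2 plus one h2exp2 (MUFU) per pair of elements,
        // mirroring the fused scale/subtract of the fp32 path above.
        row[i] = h2exp2(__hfma2(row[i], scale2, neg_max2));
    }
}

Even with this pair-wise packing, the kernel is still limited by the MUFU issue rate on Orin.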
So I'm basically stuck here and there doesn't seem to be much I can do now. I'm wondering if you have any suggestions or insights to share?
Thank you so much!
Yup exp is one of the bottlenecks. We talked about that a bit in the FA3 paper.
exp uses the MUFU (multi-function unit), which has quite low throughput (e.g. 16 ops per clock cycle per SM, which is 4-8x lower than add / mul floating point operations). https://docs.nvidia.com/cuda/cuda-c-programming-guide/index.html#arithmetic-instructions
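If you want to see the gap directly on your device, a tiny standalone microbenchmark along these lines (my own throwaway sketch, not code from this repo; the kernel and variable names are made up) compares a dependent chain of exp2f, which compiles to the MUFU EX2 instruction, against a chain of fmaf on the regular FP32 pipe:

#include <cstdio>
#include <cuda_runtime.h>

// Dependent chain of exp2f calls (MUFU EX2).
__global__ void exp2_chain(float* out, int iters) {
    float x = 1.0f + threadIdx.x * 1e-6f;
    for (int i = 0; i < iters; ++i) x = exp2f(x * 0.5f);
    out[blockIdx.x * blockDim.x + threadIdx.x] = x;  // keep the result live
}

// Dependent chain of fused multiply-adds on the FP32 pipe.
__global__ void fma_chain(float* out, int iters) {
    float x = 1.0f + threadIdx.x * 1e-6f;
    for (int i = 0; i < iters; ++i) x = fmaf(x, 0.999999f, 1e-7f);
    out[blockIdx.x * blockDim.x + threadIdx.x] = x;
}

int main() {
    float* out;
    cudaMalloc(&out, 1024 * 256 * sizeof(float));
    const int iters = 1 << 16;
    cudaEvent_t beg, end;
    cudaEventCreate(&beg);
    cudaEventCreate(&end);

    cudaEventRecord(beg);
    exp2_chain<<<1024, 256>>>(out, iters);
    cudaEventRecord(end);
    cudaEventSynchronize(end);
    float ms_exp2 = 0.f;
    cudaEventElapsedTime(&ms_exp2, beg, end);

    cudaEventRecord(beg);
    fma_chain<<<1024, 256>>>(out, iters);
    cudaEventRecord(end);
    cudaEventSynchronize(end);
    float ms_fma = 0.f;
    cudaEventElapsedTime(&ms_fma, beg, end);

    printf("exp2f chain: %.2f ms, fmaf chain: %.2f ms, ratio: %.1fx\n",
           ms_exp2, ms_fma, ms_exp2 / ms_fma);
    cudaFree(out);
    return 0;
}

With enough resident warps both chains become throughput-bound, so the ratio of the two times roughly tracks the per-SM instruction rates in the table linked above.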
Thank you so much for your response! I recently read the FA3 paper. I guess GPUs like Orin do not have async MMA operations. Do you think the ping-pong structure described in the paper could be adapted to synchronous MMA operations? It seems that even with synchronous MMAs, the Tensor Core and MUFU pipelines can still overlap once the instructions are issued, as long as there is no data dependency, so it should be possible to write kernels that explicitly instantiate two sets of registers and mimic the behavior shown in the paper's pipelining figure.
However, you did mention in the paper that the actual case is not as clean as depicted in that figure.
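In other words, something like the sketch below is what I have in mind. It is rough pseudocode, not working code: the names (acc_s0, acc_s1, tSrQ, tSrK, tOrV, softmax_rescale_and_exp2, convert_to_half) are placeholders loosely modeled on the cute-based kernels here, and the PV gemm is only schematic:

// Rough sketch only: keep two score accumulators so the HMMAs for block k+1
// have no data dependency on the exp2 work of block k, which should let the
// tensor-core and MUFU pipelines overlap even with synchronous MMA instructions.
Tensor acc_s0 = partition_fragment_C(tiled_mma, Shape<Int<kBlockM>, Int<kBlockN>>{});
Tensor acc_s1 = partition_fragment_C(tiled_mma, Shape<Int<kBlockM>, Int<kBlockN>>{});

clear(acc_s0);
cute::gemm(tiled_mma, tSrQ, tSrK(_, _, 0), acc_s0);               // S_0 = Q K_0^T
for (int k = 0; k < n_blocks; ++k) {
    auto& cur = (k % 2 == 0) ? acc_s0 : acc_s1;
    auto& nxt = (k % 2 == 0) ? acc_s1 : acc_s0;
    if (k + 1 < n_blocks) {
        clear(nxt);
        cute::gemm(tiled_mma, tSrQ, tSrK(_, _, k + 1), nxt);      // issue S_{k+1} (HMMA) ...
    }
    softmax_rescale_and_exp2(cur);                                // ... while exp2 (MUFU) runs on S_k
    cute::gemm(tiled_mma, convert_to_half(cur), tOrV(_, _, k), acc_o);  // O += P_k V_k
}

The hope is that, with enough independent instructions in flight, the warp scheduler can keep both pipelines busy even without Hopper-style async MMA.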
Yeah, overlapping works reasonably well on Hopper; for older architectures it might be harder to do.
So after FA3, is softmax still the bottleneck? I mean, in terms of percentage, how much of the softmax can FA3 overlap with the Tensor Core work?
Thanks!
I don't think that's easy to measure. For small hdim (e.g. 64) and for fp8, softmax is still a bottleneck.