flash-attention
[QST] WIP int8 quantized impl but can't get transposed LDSM to work with current layout
I wrote a minimal int8 quantized Flash Attention implementation that sees a 31% speedup with reasonable accuracy. One big issue hindering its performance is that it can't use LDSM for the transposed V matrix and falls back to DefaultCopy. While the swizzled layout for Q and K works rather naturally with a CopyAtom of SM75_U32x2_LDSM_N, I'm not sure how to construct a layout for V that matches the CopyAtom SM75_U16x4_LDSM_T. That atom requires two elements per row and then goes down through the columns (see below), and I'm not sure we can construct a swizzled layout out of it either. Much appreciated if you have any insights!
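For context, a minimal CuTe sketch of the copy-atom setup described above; this is not the actual implementation, and `tiled_mma`, `sQ`, and the `partition_sQ` helper are placeholders:

```cpp
#include <cute/tensor.hpp>
#include <cute/atom/copy_atom.hpp>
#include <cute/atom/mma_atom.hpp>

using namespace cute;

// Non-transposed LDSM still works for int8 Q/K: each ldmatrix row is a contiguous
// 16-byte segment, so two int8 values simply pack into each "b16" slot.
using SmemCopyAtomQK = Copy_Atom<SM75_U32x2_LDSM_N, int8_t>;

// V^T would want the transposed variant, but ldmatrix.trans works at 16-bit
// granularity, so with 8-bit data the current swizzled layout can't feed it
// correctly, hence the fallback:
// using SmemCopyAtomV = Copy_Atom<SM75_U16x4_LDSM_T, int8_t>;  // wrong element order for 8-bit
using SmemCopyAtomV = Copy_Atom<DefaultCopy, int8_t>;            // slow but correct

// Hypothetical helper: build the per-thread smem source partition of Q for cute::copy.
template <class TiledMma, class SmemTensorQ>
CUTE_DEVICE auto partition_sQ(TiledMma const& tiled_mma, SmemTensorQ const& sQ, int tidx) {
  auto tiled_copy_Q = make_tiled_copy_A(SmemCopyAtomQK{}, tiled_mma);
  auto thr_copy_Q   = tiled_copy_Q.get_thread_slice(tidx);
  return thr_copy_Q.partition_S(sQ);
}
```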
Right, LDSM won't work for V if the data is 8-bit. We might have some way to address this soon.
Could you provide a bit more detail on how we could approach this? Any idea is welcome, even a half-baked one; I can explore more on my own. Happy to make a PR once my stuff gets more polished as well.
soon
The update was indeed quite soon lol - sadly I'm not using Hopper GPUs and can't use TMA instructions among other things, but thanks for the update anyway!
You can use LDSM.T and byte-permute, then LDSM, as a way to transpose V. We'll release that code soon; idk if it works well without warp specialization.
Looking forward to reading that code! I'm not sure exactly what you mean by doing two LDSMs? I thought we can only do one LDSM to load from smem to registers, and there's a barrier between the two.
Sorry, I mean LDSM.T, byte-permute, then store using STSM. That way you can transpose V.
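For illustration, a hedged sketch of the byte-permute piece of that recipe, not FA's actual code: after any transpose that works at 16-bit granularity, each 32-bit register holds int8 values from two different logical rows, and __byte_perm can regroup them. The selectors below are illustrative; the real mapping depends on the fragment layout.

```cpp
#include <cstdint>

// Regroup int8 values after a 16-bit-granularity shuffle.
// __byte_perm(x, y, s) selects result bytes from the 8-byte pool {x0..x3, y0..y3},
// one hex nibble of s per result byte, starting from the least significant byte.
__device__ __forceinline__ void regroup_int8_pairs(uint32_t a, uint32_t b,
                                                   uint32_t& even, uint32_t& odd) {
  even = __byte_perm(a, b, 0x6420);  // {a0, a2, b0, b2}
  odd  = __byte_perm(a, b, 0x7531);  // {a1, a3, b1, b3}
}
```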
Got it - sadly STSM is a Hopper-only instruction as well, but now that I think about it, LDSM, byte-permute, and then STSM.T would make more sense (I haven't really looked into STSM, though)? Using LDSM.T would circle back to my original question, which is that the current swizzled layout doesn't allow LDSM.T in the first place because the smallest LDSM unit is 16-bit and the instruction needs two elements per row. The register layout is much easier to modify (byte-permute) in comparison. I guess I could try changing the smem layout to fit what LDSM.T wants, but it doesn't seem clean and would cause more bank conflicts than the swizzled layout.
I see, I forgot that STSM is Hopper-only. The other option is to transpose V in a separate kernel, or fuse it with a preceding kernel (e.g. a GEMM).
Yeah, makes sense - I don't think there's a preceding kernel using the same V, and I'm worried a separate kernel would cost too much in performance. I'll explore the different options and let you know if I have more updates.
Tentatively, I found a solution: use a normal LDSM to load from smem to registers and then use movmatrix to transpose in registers.
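For reference, a minimal sketch of that register-side transpose, not the poster's actual code: movmatrix.sync.aligned.m8n8.trans.b16 transposes one 8x8 tile of b16 elements spread across the 32 threads of a warp (sm_75+ with a sufficiently recent CUDA/PTX toolchain); since it works on 16-bit elements, any byte-level handling of the int8 pairs is still up to the caller.

```cpp
#include <cstdint>

// Warp-synchronous transpose of an 8x8 b16 tile held in registers.
// Each thread passes its 32-bit fragment and receives the transposed fragment back.
__device__ __forceinline__ uint32_t movmatrix_trans_m8n8(uint32_t src) {
  uint32_t dst;
  asm volatile("movmatrix.sync.aligned.m8n8.trans.b16 %0, %1;\n"
               : "=r"(dst)
               : "r"(src));
  return dst;
}
```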
FYI, FA3 has a solution using the ldmatrix and stmatrix instructions to implement an in-kernel transpose for the fp8 matrix: https://github.com/Dao-AILab/flash-attention/blob/main/hopper/mainloop_fwd_sm90_tma_gmma_ws.hpp#L25
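For anyone on Hopper reading along, a hedged sketch of the stmatrix half of that trick; the linked FA3 mainloop is the authoritative version, and the helper below is only illustrative.

```cpp
#include <cstdint>

// Store four 8x8 b16 tiles from registers to shared memory (sm_90 only).
// smem_int_ptr is a 32-bit shared-memory address, e.g.
// static_cast<uint32_t>(__cvta_generic_to_shared(ptr)).
__device__ __forceinline__ void stmatrix_x4(uint32_t smem_int_ptr,
                                            uint32_t r0, uint32_t r1,
                                            uint32_t r2, uint32_t r3) {
  asm volatile("stmatrix.sync.aligned.x4.m8n8.shared.b16 [%0], {%1, %2, %3, %4};\n"
               :: "r"(smem_int_ptr), "r"(r0), "r"(r1), "r"(r2), "r"(r3));
}
```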
Thanks for the note - as Tri and I discussed above, stmatrix is only supported on Hopper GPUs, which I don't work with.