
[QST] WIP int8 quantized impl but can't get transposed LDSM to work with current layout

Open carlguo866 opened this issue 1 year ago • 10 comments

I wrote a minimal int8-quantized Flash Attention implementation that sees a 31% speedup with reasonable accuracy. One big issue hindering its performance is that it can't use LDSM for the transposed V matrix and falls back to DefaultCopy. While the Q and K swizzled layout works rather naturally with a CopyAtom of SM75_U32x2_LDSM_N, I'm not sure how to construct a layout for V that matches the CopyAtom SM75_U16x4_LDSM_T. That atom requires two elements per row and then goes down through the columns (see the screenshot below), and I'm not sure a swizzled layout can be constructed for it either. Much appreciated if you have any insights!

[Screenshot: element layout required by SM75_U16x4_LDSM_T]
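For reference, a minimal CuTe sketch of the copy atoms being discussed, following flash-attention's kernel_traits conventions; the int8 alias is my own illustration, not code from the repo:

```cpp
#include "cute/algorithm/copy.hpp"
#include "cutlass/numeric_types.h"
#include <cstdint>

// fp16 path: Q/K use the non-transposed LDSM atom and V uses the transposed one;
// both fit the swizzled smem layout because every element is 16 bits wide.
using SmemCopyAtom           = cute::Copy_Atom<cute::SM75_U32x2_LDSM_N, cutlass::half_t>;
using SmemCopyAtomTransposed = cute::Copy_Atom<cute::SM75_U16x4_LDSM_T, cutlass::half_t>;

// int8 path (hypothetical alias): the non-transposed load still works by moving
// pairs of int8 elements as one 16-bit unit, but there is no 8-bit transposed
// equivalent, since ldmatrix.trans cannot split a 16-bit lane across two rows.
using SmemCopyAtomInt8 = cute::Copy_Atom<cute::SM75_U32x2_LDSM_N, int8_t>;
```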

carlguo866 avatar Jul 10 '24 22:07 carlguo866

Right, LDSM won't work for V if the data is 8-bit. We might have some way to address this soon.

tridao avatar Jul 10 '24 23:07 tridao

Could you provide a bit more detail on how we can approach this? Any idea is welcome, even a half-baked one, and I can explore more on my own. Happy to make a PR once my stuff gets more polished as well.

carlguo866 avatar Jul 10 '24 23:07 carlguo866

soon

tridao avatar Jul 10 '24 23:07 tridao

The update was indeed quite soon lol - sadly I'm not using Hopper GPUs and can't use TMA instructions, among other things, but thanks for the update anyway!

carlguo866 avatar Jul 11 '24 21:07 carlguo866

You can use LDSM.T and byte-permute, then LDSM, as a way to transpose V. We'll release that code soon. I don't know if it works well without warp specialization.

tridao avatar Jul 12 '24 00:07 tridao

Looking forward to reading that code! I'm not sure exactly what you mean by doing two LDSMs, though? I thought we can only do one LDSM to load from smem to registers, and there's a barrier between the two.

carlguo866 avatar Jul 12 '24 18:07 carlguo866

Sorry, I mean LDSM.T, byte permute, then store using STSM. That way you can transpose V.
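As a rough illustration of the byte-permute step, assuming int8 data was loaded at 16-bit granularity so each 16-bit lane holds bytes from two different logical rows; the function name and selector constants are illustrative, not the exact permutation FA3 uses:

```cpp
#include <cstdint>

// Regroup bytes from two registers: selector 0x6420 gathers the low byte of
// every 16-bit lane of r0/r1, selector 0x7531 gathers the high byte, so the
// two logical rows interleaved by the 16-bit load end up in separate registers.
__device__ __forceinline__ void regroup_int8_pairs(uint32_t r0, uint32_t r1,
                                                   uint32_t& out0, uint32_t& out1) {
    out0 = __byte_perm(r0, r1, 0x6420);  // low bytes: one logical row
    out1 = __byte_perm(r0, r1, 0x7531);  // high bytes: the other logical row
}
```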

tridao avatar Jul 12 '24 20:07 tridao

Got it - sadly, STSM is a Hopper-only instruction as well. But now that I think about it, LDSM, byte-permute, and then STSM.T would make more sense (I haven't really looked into STSM, though)? Using LDSM.T would circle back to my original question, which is that the current swizzled layout doesn't allow LDSM.T in the first place: the smallest LDSM unit is 16-bit and the instruction needs two elements per row. The register layout is much easier to modify (byte-permute) in comparison. I guess I can try changing the smem layout to fit what LDSM.T wants, but it doesn't seem clean and would cause more bank conflicts than the swizzled layout.

carlguo866 avatar Jul 12 '24 20:07 carlguo866

I see, I forgot that STSM is Hopper-only. The other option is to transpose V in a separate kernel, or fuse it with a preceding kernel (e.g. a GEMM).
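For the separate-kernel option, a minimal sketch of a standard shared-memory tiled transpose for an int8 V of shape [rows, cols], launched with blockDim = (kTile, kTile); the kernel name and tile size are illustrative:

```cpp
#include <cstdint>

constexpr int kTile = 32;

__global__ void transpose_v_int8(const int8_t* __restrict__ in,
                                 int8_t* __restrict__ out,
                                 int rows, int cols) {
    __shared__ int8_t tile[kTile][kTile + 1];  // +1 padding avoids smem bank conflicts
    int col = blockIdx.x * kTile + threadIdx.x;
    int row = blockIdx.y * kTile + threadIdx.y;
    if (row < rows && col < cols) {
        tile[threadIdx.y][threadIdx.x] = in[row * cols + col];
    }
    __syncthreads();
    // Write the tile back with row/column roles swapped (output is cols x rows).
    col = blockIdx.y * kTile + threadIdx.x;
    row = blockIdx.x * kTile + threadIdx.y;
    if (row < cols && col < rows) {
        out[row * rows + col] = tile[threadIdx.x][threadIdx.y];
    }
}
```

Whether the extra global-memory round trip is acceptable depends on how often V is reused across the attention kernel, which is the performance concern raised below.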

tridao avatar Jul 12 '24 22:07 tridao

Yeah, makes sense - I don't think there's a preceding kernel using the same V, and I'm worried a separate kernel would cost too much performance. I'll explore the different options and let you know if I have more updates.

carlguo866 avatar Jul 12 '24 23:07 carlguo866

Tentatively, I found a solution: use a normal LDSM to load from smem to registers and then use movmatrix to transpose in registers.
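A minimal sketch of that movmatrix step (requires sm_75+; the wrapper name is illustrative). Each of the 32 threads in the warp contributes one 32-bit register holding two b16 elements of an 8x8 fragment, and the instruction returns the transposed fragment:

```cpp
#include <cstdint>

__device__ __forceinline__ uint32_t transpose_m8n8_b16(uint32_t src) {
    uint32_t dst;
    // Warp-wide in-register transpose of an 8x8 matrix of 16-bit elements.
    asm volatile("movmatrix.sync.aligned.m8n8.trans.b16 %0, %1;\n"
                 : "=r"(dst)
                 : "r"(src));
    return dst;
}
```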

carlguo866 avatar Aug 14 '24 17:08 carlguo866

Tentatively, I found a solution: use a normal LDSM to load from smem to registers and then use movmatrix to transpose in registers.

FYI, FA3 has a solution using ldmatrix and stmatrix instructions to implement an in-kernel transpose for fp8 matrices. https://github.com/Dao-AILab/flash-attention/blob/main/hopper/mainloop_fwd_sm90_tma_gmma_ws.hpp#L25
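For context, a raw-PTX sketch of the stmatrix store used in that FA3 path (Hopper/sm_90 only; the actual code goes through CuTe copy atoms, and the wrapper name here is illustrative):

```cpp
#include <cstdint>

// Store one 8x8 b16 fragment from registers back to shared memory.
__device__ __forceinline__ void stmatrix_m8n8_x1(void* smem_ptr, uint32_t frag) {
    uint32_t addr = static_cast<uint32_t>(__cvta_generic_to_shared(smem_ptr));
    asm volatile("stmatrix.sync.aligned.m8n8.x1.shared.b16 [%0], {%1};\n"
                 :: "r"(addr), "r"(frag));
}
```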

itsliupeng avatar Aug 15 '24 03:08 itsliupeng

Thanks for the note - as Tri and I discussed above, stmatrix is only supported on Hopper GPUs, and I don't work with those.

carlguo866 avatar Aug 15 '24 04:08 carlguo866