[QST] Weight Format & GEMM
@efrantar
Awesome work -- always enjoy your research on and implementation of efficient model inference.
I was hoping that you could shed some light on the logic of the packing step?
My understanding is that the individual int4 values need to be rearranged in order to use the fast unpack / convert functions from FasterTransformer.
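For concreteness, here is roughly how I picture that packing step, as a NumPy sketch; the interleave order below is purely an illustrative guess on my part, not necessarily what FasterTransformer's unpack routines actually expect:

```python
import numpy as np

# Hypothetical interleave: nibble slot i of the packed word stores the
# logical element ORDER[i], so that the kernel's mask/shift pattern can
# extract adjacent logical pairs in each step.
ORDER = (0, 2, 4, 6, 1, 3, 5, 7)

def pack_int4(vals):
    """Pack 8 unsigned int4 values (given in logical order) into one uint32."""
    assert len(vals) == 8 and all(0 <= v < 16 for v in vals)
    word = 0
    for slot, src in enumerate(ORDER):
        word |= int(vals[src]) << (4 * slot)
    return np.uint32(word)

def unpack_int4(word):
    """Read the nibbles back in storage order, as a kernel would with
    masks and shifts (or lop3-based fp16 tricks on the GPU)."""
    return [(int(word) >> (4 * i)) & 0xF for i in range(8)]

print(unpack_int4(pack_int4(list(range(8)))))  # -> [0, 2, 4, 6, 1, 3, 5, 7]
```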
Is the subsequent interleaving done so that `ldmatrix` can be used on these packed values, with each thread ending up holding the values it needs for `mma.sync`? Typically `ldmatrix` is used on `fp16` / `bf16` types, but here the weights are sub-byte types, hence the additional preprocessing required for an efficient shared -> register copy. I know FasterTransformer has its own formatting logic as a workaround for this issue; I have yet to find a general solution for efficiently leveraging tensor core primitives on sub-byte types without preprocessing the weights into a custom format.
Theoretically, if I were to preprocess the weights of a non-GPTQ `int4` model using the packing function -- i.e., any group-wise quantization method that yields 4-bit weights along with group scales and zeros -- would I be able to use the Marlin kernel on such a model? If not, what changes would need to be made?
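Concretely, the workflow I have in mind looks something like the sketch below; the `marlin.Layer` constructor, its `pack()` method, and the scale layout are assumptions about the repo's Python wrapper and would need to be checked against the actual source:

```python
import torch
import torch.nn as nn
import marlin  # https://github.com/IST-DASLab/marlin

in_f, out_f, groupsize = 4096, 4096, 128

# Stand-ins for the output of *any* group-wise int4 method: dequantized
# fp16 weights plus per-group scales (random here, just to show shapes).
w_deq = torch.randn(out_f, in_f, dtype=torch.half, device="cuda")
scales = torch.ones(in_f // groupsize, out_f, dtype=torch.half, device="cuda")

linear = nn.Linear(in_f, out_f, bias=False).half().cuda()
linear.weight.data = w_deq

layer = marlin.Layer(in_f, out_f, groupsize=groupsize).cuda()  # assumed constructor
layer.pack(linear, scales)  # assumed packing entry point

x = torch.randn(16, in_f, dtype=torch.half, device="cuda")
y = layer(x)  # int4 x fp16 Marlin GEMM
```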
Many thanks!
Hi, Marlin only uses `ldmatrix` for the activations; the weights are already preshuffled optimally for both dequantization and the tensor core fragment layouts. You can find a more detailed description of how this format works here: https://github.com/IST-DASLab/marlin/issues/12.
Marlin is completely independent of GPTQ: the model just needs to be quantized symmetrically, either with group size 128 or row-wise (how you produced the model doesn't matter to Marlin); then you can preprocess the weights and use the Marlin kernels. Zero-points are currently not supported; the reasons for this are discussed here: https://github.com/IST-DASLab/marlin/issues/5#issuecomment-1934082099.
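To make the requirement concrete: symmetric means each group shares a single scale `s` and weights quantize as `q = round(w / s)` with no zero-point, so dequantization is simply `w_hat = q * s`. A minimal PyTorch sketch of such group-wise quantization (layout conventions here are illustrative, not Marlin's internal format):

```python
import torch

def quantize_sym(w: torch.Tensor, groupsize: int = 128):
    """Symmetric int4 group-wise quantization of a [out, in] matrix.

    Each group of `groupsize` consecutive input features shares one
    scale; row-wise quantization is the special case groupsize = in."""
    out_f, in_f = w.shape
    g = w.reshape(out_f, in_f // groupsize, groupsize).float()
    s = g.abs().amax(dim=-1, keepdim=True) / 7   # map max |w| to int4 value 7
    q = torch.clamp(torch.round(g / s), -8, 7)   # int4 range, no zero-point
    w_hat = (q * s).reshape(out_f, in_f)         # what the kernel computes with
    return q.to(torch.int8), s.squeeze(-1), w_hat.half()
```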
@efrantar
Thank you for taking the time to explain.
Have you looked into CUTLASS, specifically the 3.x API that introduced the CuTe abstractions for tensor thread-value manipulation / mapping? I'm wondering whether it could help generalize / extend the handcrafted code in Marlin without sacrificing performance.