
[QST] Weight Format & GEMM

Open jeromeku opened this issue 1 year ago • 2 comments

@efrantar

Awesome work -- always enjoy your research on and implementation of efficient model inference.

I was hoping that you could shed some light on the logic of the packing step?

  • My understanding is that the individual int4 values need to be rearranged in order to use the fast unpack / convert functions from FasterTransformer.

  • Is the subsequent interleaving done so that ldmatrix can be used on these packed values, leaving each thread holding exactly the values it needs for mma.sync? Typically ldmatrix is used on fp16 / bf16 types, but in this case the weights are sub-byte types, hence the additional preprocessing required for an efficient shared -> register copy. I know FasterTransformer has its own formatting logic as a workaround for this issue; I have yet to find a general way to efficiently leverage tensor core primitives on sub-byte types without preprocessing the weights into a custom format.

  • Theoretically, if I were to preprocess the weights of a non-GPTQ int4 model using the packing function -- i.e., any groupwise quantization method that yields 4b weights along with group scales and zeros -- would I be able to use the Marlin kernel on such a model? If not, what changes would need to be made?
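For concreteness, here is a minimal sketch of the two steps being asked about: packing eight 4-bit values into one 32-bit word, with an optional interleave applied first. The permutation used below is purely illustrative -- the actual Marlin / FasterTransformer layouts differ and are more involved.

```python
def pack_int4(vals, perm=None):
    """Pack eight unsigned 4-bit values (0..15) into one 32-bit word,
    lowest nibble first, optionally permuting the values first.
    Illustrative only -- not Marlin's real layout."""
    assert len(vals) == 8 and all(0 <= v < 16 for v in vals)
    if perm is not None:
        vals = [vals[p] for p in perm]
    word = 0
    for i, v in enumerate(vals):
        word |= v << (4 * i)
    return word

def unpack_int4(word, perm=None):
    """Inverse of pack_int4: extract the eight nibbles, then undo the permutation."""
    vals = [(word >> (4 * i)) & 0xF for i in range(8)]
    if perm is not None:
        out = [0] * 8
        for i, p in enumerate(perm):
            out[p] = vals[i]
        vals = out
    return vals

# Hypothetical interleave: even-indexed then odd-indexed elements -- the kind of
# reordering that lets a kernel pull adjacent logical pairs out with cheap bit ops.
PERM = [0, 2, 4, 6, 1, 3, 5, 7]
```

In a real kernel the unpack side is done in registers with shift/mask (or lop3) tricks rather than a loop, which is exactly why the pack-time permutation matters.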

Many thanks!

jeromeku avatar Apr 01 '24 22:04 jeromeku

Hi, Marlin only uses ldmatrix for the activations, as the weights are already preshuffled optimally for both dequantization and tensor core fragment layouts. You can find a more detailed description of how this format works here: https://github.com/IST-DASLab/marlin/issues/12.

Marlin is completely independent of GPTQ; the model needs to be quantized symmetrically, either with groupsize 128 or row-wise (how you produced this model doesn't matter to Marlin), and then you can preprocess the weights and use Marlin kernels. Zero-points are currently not supported; the reasons for this are discussed here: https://github.com/IST-DASLab/marlin/issues/5#issuecomment-1934082099.
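As a sketch of the symmetric, zero-point-free groupwise scheme described above (function name and exact clipping convention are my assumptions, not Marlin's reference code):

```python
import numpy as np

def quantize_symmetric_int4(w, group_size=128):
    """Symmetric 4-bit groupwise quantization sketch: per-group scale,
    no zero-point, quantized values in [-8, 7]. Illustrative only."""
    rows, cols = w.shape
    assert rows % group_size == 0
    wg = w.reshape(rows // group_size, group_size, cols)
    # One scale per (group, column); max|w| maps to +/-7.
    scales = np.abs(wg).max(axis=1, keepdims=True) / 7.0
    scales = np.maximum(scales, 1e-12)  # avoid division by zero for all-zero groups
    q = np.clip(np.round(wg / scales), -8, 7).astype(np.int8)
    return q.reshape(rows, cols), scales.squeeze(1)
```

Dequantization is then just `q * scale` per group, with no zero-point addition in the inner loop -- which is part of why the symmetric case is cheap to fuse into the GEMM.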

efrantar avatar Apr 02 '24 16:04 efrantar

@efrantar

Thank you for taking the time to explain.

Have you looked into CUTLASS, specifically the 3.x API that introduced the CuTe abstractions for tensor thread-value manipulation / mapping? I'm wondering if it could help generalize or extend the handcrafted code in Marlin without sacrificing performance.

jeromeku avatar Apr 03 '24 13:04 jeromeku