
GEMM API for efficient LLM inference with W8A16

oleotiger opened this issue on Jan 20, 2024 · 3 comments

I want to perform inference on quantized LLAMA (W8A16) on ARM-v9 (with SVE) using oneDNN. The LLAMA weights are per-group quantized.

Based on my understanding, I need to prepack the weights to avoid the cost of repeated packing. However, packing disrupts the layout of the per-group quantization scales and zero points. I understand that dequantization needs to be fused into the compute kernel; if it were instead fused into packing, that would amount to storing another copy of the weights in FP16, which defeats the purpose of quantization.

I haven't figured out how to combine prepacking and per-group dequantization.

Which interface should I use for prepacking? SVE vectors can be 256-bit or 512-bit wide; how does oneDNN choose the packed layout accordingly? And after prepacking and saving the weights, how do I fuse dequantization into the kernel at compute time?
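For concreteness, here is a minimal sketch of the prepacking flow as I currently understand it, using the v3.x C++ API. The shapes and data types are placeholders, and `format_tag::any` is used so the library can pick its preferred packed layout; please correct me if this is the wrong approach.

```cpp
// Minimal prepacking sketch with the oneDNN matmul primitive.
// Shapes, data types, and the "plain" weight layout are placeholders.
#include <oneapi/dnnl/dnnl.hpp>
using namespace dnnl;

int main() {
    engine eng(engine::kind::cpu, 0);
    stream s(eng);

    const memory::dim M = 1, K = 4096, N = 4096;

    // format_tag::any lets the implementation pick its preferred (packed)
    // weight layout for the target ISA (e.g. SVE-256 vs. SVE-512).
    auto src_md = memory::desc({M, K}, memory::data_type::f16, memory::format_tag::ab);
    auto wei_md = memory::desc({K, N}, memory::data_type::f16, memory::format_tag::any);
    auto dst_md = memory::desc({M, N}, memory::data_type::f16, memory::format_tag::ab);

    auto pd = matmul::primitive_desc(eng, src_md, wei_md, dst_md);

    // One-time reorder from the plain layout into the queried packed layout;
    // the packed buffer can then be cached and reused across calls.
    auto plain_wei_md = memory::desc({K, N}, memory::data_type::f16, memory::format_tag::ab);
    memory plain_wei(plain_wei_md, eng), packed_wei(pd.weights_desc(), eng);
    reorder(plain_wei, packed_wei).execute(s, plain_wei, packed_wei);
    s.wait();
    return 0;
}
```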

oleotiger avatar Jan 20 '24 09:01 oleotiger

@oleotiger, we are working on enabling per-group quantization in oneDNN. You can find a description of the proposed design for fused weight decompression here. The implementation is not yet available on any platform, though. The only option for now is to decompress the weights separately, as you indicated.
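As a rough illustration of that interim approach, a standalone decompression pass could look something like the following sketch. The `[K x N]` row-major layout, group size `G` along K, and per-`(group, n)` scale/zero-point arrangement are assumptions made for the example, not requirements of the library.

```cpp
// Stand-alone per-group weight decompression (s8 -> f16) ahead of the GEMM.
#include <cstdint>

using f16 = _Float16; // AArch64 GCC/Clang extension; substitute your f16 type

void decompress_w8_per_group(const int8_t *w_s8, const f16 *scales,
                             const int8_t *zps, f16 *w_f16,
                             int64_t K, int64_t N, int64_t G) {
    for (int64_t k = 0; k < K; ++k) {
        const int64_t g = k / G; // quantization group index along K
        for (int64_t n = 0; n < N; ++n) {
            const int64_t i = k * N + n;
            // Subtract the group's zero point, then apply the group's scale.
            w_f16[i] = static_cast<f16>(w_s8[i] - zps[g * N + n])
                     * scales[g * N + n];
        }
    }
}
```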

vpirogov avatar Jan 23 '24 00:01 vpirogov

+@igorsafo

vpirogov avatar Jan 23 '24 02:01 vpirogov

The API and validation changes necessary to support W8A16 quantization have landed in the main and rls-v3.4 branches. The specifics are covered in the GPT Quantization RFC.

+@jondea, @milpuz01 for additional comments on Arm specifics.
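A rough sketch of how the new attributes are intended to be used for a W8A16 matmul with fused weight decompression is below. The shapes and group size are illustrative, availability on Arm depends on the implementation status, and the exact attribute signatures may differ slightly between releases; the RFC is the authoritative reference.

```cpp
// Sketch of a matmul with fused per-group weight decompression (W8A16).
#include <oneapi/dnnl/dnnl.hpp>
using namespace dnnl;

int main() {
    engine eng(engine::kind::cpu, 0);

    const memory::dim M = 32, K = 4096, N = 4096, G = 128; // G: quant group size

    auto src_md = memory::desc({M, K}, memory::data_type::f16, memory::format_tag::ab);
    auto wei_md = memory::desc({K, N}, memory::data_type::s8, memory::format_tag::any);
    auto dst_md = memory::desc({M, N}, memory::data_type::f16, memory::format_tag::ab);

    primitive_attr attr;
    // Allow integer weights to be up-converted to the f16 compute type.
    attr.set_fpmath_mode(fpmath_mode::f16, /* apply_to_int = */ true);
    // Per-group scales on weights: groups of size G along K, per-channel along N.
    attr.set_scales(DNNL_ARG_WEIGHTS, /* mask = */ (1 << 0) + (1 << 1),
                    {G, 1}, memory::data_type::f16);

    auto pd = matmul::primitive_desc(eng, src_md, wei_md, dst_md, attr);
    // Weights would then be reordered to pd.weights_desc() once and cached,
    // with the scales passed at execution time via DNNL_ARG_ATTR_SCALES.
    (void)pd;
    return 0;
}
```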

vpirogov avatar Feb 01 '24 18:02 vpirogov