MatMulNBits support for 4-bit quantization

Open • robertknight opened this issue 11 months ago • 1 comment

Sub-8-bit quantization support is still a work in progress in the ONNX standard; see https://github.com/onnx/onnx/issues/6326. The only pattern the standard currently supports is (4-bit constant) => DequantizeLinear => MatMul, which relies on the runtime to fuse these ops for efficiency.
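For reference, here is a minimal pure-Python sketch of what (4-bit constant) => DequantizeLinear => MatMul computes before any fusion. It assumes blockwise quantization of the MatMul weight B (shape K x N) along the K axis, with the 4-bit values already unpacked to ints; `dequantize_blockwise` and `matmul` are hypothetical helper names, and this is not rten's or ONNX Runtime's implementation.

```python
def dequantize_blockwise(q, scale, zero_point, block_size):
    """DequantizeLinear with blockwise quantization along axis 0 (K).

    q:          K x N list of ints (signed int4 values in [-8, 7]).
    scale:      ceil(K / block_size) x N list of floats, one per block per column.
    zero_point: same shape as scale, ints.
    Returns the K x N float matrix y = (q - zero_point) * scale.
    """
    return [
        [(q_kn - zero_point[k // block_size][n]) * scale[k // block_size][n]
         for n, q_kn in enumerate(row)]
        for k, row in enumerate(q)
    ]


def matmul(a, b):
    """Plain float matmul of A (M x K) and B (K x N)."""
    n_cols = len(b[0])
    return [
        [sum(a_row[k] * b[k][n] for k in range(len(b))) for n in range(n_cols)]
        for a_row in a
    ]


# Example: K=4, N=2, block_size=2, so there are 2 blocks along K.
q = [[1, -2], [3, 4], [-8, 7], [0, 5]]
w = dequantize_blockwise(q, [[0.5, 0.25], [1.0, 2.0]], [[0, 0], [0, 0]], 2)
y = matmul([[1.0, 1.0, 1.0, 1.0]], w)
```

A runtime that does not fuse these ops would materialize the full float weight matrix returned by `dequantize_blockwise`, which is the cost that fusing into a single quantized matmul avoids.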

However, many models have started appearing on e.g. Hugging Face that rely on the Microsoft-specific MatMulNBits operator for matmuls with 4-bit weights.

Sub-tasks:

  • [x] Support ONNX operators from domains other than "ai.onnx" in ONNX loader (https://github.com/robertknight/rten/pull/1029, https://github.com/robertknight/rten/pull/1031)
  • [x] Implement matmul of f32 LHS x 4-bit quantized RHS in rten-gemm (https://github.com/robertknight/rten/pull/1030)
  • [x] Implement MatMulNBits operator without zero point support (https://github.com/robertknight/rten/pull/1031)
  • [ ] Compare MatMulNBits performance to ORT on Arm / x64, optimize as needed
  • [ ] Add fast path for GEMV with f32 LHS and 4-bit quantized RHS
  • [ ] Implement zero-point support for MatMulNBits
  • [ ] Update documentation to describe support for weight-only quantization
  • [ ] Investigate whether the MatMulNBits implementation can be re-used for a fusion of MatMul + DequantizeLinear
  • [ ] Update the tools/ort-quantize.py script to support Q4 quantization (https://github.com/robertknight/rten/pull/1035)
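The f32 LHS x 4-bit RHS matmul from the sub-tasks above can be sketched as follows. This models MatMulNBits with bits=4 and no zero-point input, under assumptions taken from my reading of the ONNX Runtime contrib-op docs rather than from rten: B is stored as (N, K/block_size, block_size/2) bytes, the lower-indexed element of each pair sits in the low nibble, and a missing zero point defaults to 8. The function names are hypothetical, and a real kernel would not dequantize element by element like this.

```python
def unpack_u4(byte):
    """Split one byte into two unsigned 4-bit values: (low nibble, high nibble)."""
    return byte & 0x0F, byte >> 4


def matmul_nbits(a, b_packed, scales, zero_point=8):
    """Sketch of A @ B where B is stored in a MatMulNBits-style 4-bit layout.

    a:          M x K list of floats (the f32 LHS).
    b_packed:   N x n_blocks list of byte blobs, each block_size // 2 bytes long;
                b_packed[n][blk] holds the weights of block blk for output column n.
    scales:     N x n_blocks list of float scales, one per block.
    zero_point: fixed scalar zero point (assumed default of 8 when none is supplied).
    """
    n_cols = len(b_packed)
    out = [[0.0] * n_cols for _ in a]
    for n in range(n_cols):
        # Dequantize column n of B (length K) once, then reuse it for every row of A.
        col = []
        for blk, blob in enumerate(b_packed[n]):
            s = scales[n][blk]
            for byte in blob:
                lo, hi = unpack_u4(byte)
                col.append((lo - zero_point) * s)
                col.append((hi - zero_point) * s)
        for m, row in enumerate(a):
            if len(row) != len(col):
                raise ValueError("A's K dimension does not match the packed B")
            out[m][n] = sum(x * w for x, w in zip(row, col))
    return out


# Example: N=1, K=4, block_size=2, so each block is a single byte.
out = matmul_nbits([[1.0, 2.0, 3.0, 0.5]], [[[0x9A], [0x08]]], [[0.5, 2.0]])
```

Dequantizing one column of B and reusing it for every row of A is the natural loop order here, because this layout stores B with one row per output column n.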

robertknight • Feb 04 '25 09:02

Some notes about 4-bit quantization via standard operators only:

ONNX Runtime will fuse DequantizeLinear + MatMul into MatMulNBits. This makes it possible to create and distribute int4-quantized models using standard operators only. However, this combination of operators cannot express an important attribute of MatMulNBits: accuracy_level, which controls whether the internal compute happens in int8 or f32. By default, DequantizeLinear + MatMul gets fused into a MatMulNBits node with accuracy_level=4, which uses int8 compute. See https://github.com/microsoft/onnxruntime/blob/0463aa9fc3ef02d30d7177c0065cd4b7d36a39f7/include/onnxruntime/core/session/onnxruntime_session_options_config_keys.h#L380. Int8 computation is a lot faster (~20ms/token with int8 vs ~105ms/token with f32 on my M3 Pro), but it will reduce accuracy; by how much, I don't know.
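To illustrate why int8 compute trades accuracy for speed, here is a simplified model of the two paths for one output element. It assumes that accuracy_level=4 means the activations (input A) are quantized symmetrically to int8 per block before an integer dot product with the 4-bit weights; ORT's actual kernels may choose block sizes, rounding, and accumulation differently. All names are hypothetical.

```python
def quantize_block_int8(xs):
    """Symmetric per-block int8 quantization: returns (int values, scale)."""
    amax = max(abs(x) for x in xs)
    scale = amax / 127 if amax > 0 else 1.0
    return [round(x / scale) for x in xs], scale


def dot_f32(a, w_q, w_scales, block_size, zero_point=8):
    """f32 path: dequantize each 4-bit weight, then do a float dot product."""
    return sum(
        a[k] * (w_q[k] - zero_point) * w_scales[k // block_size] for k in range(len(a))
    )


def dot_int8(a, w_q, w_scales, block_size, zero_point=8):
    """int8 path: quantize each block of A to int8, accumulate in integers, rescale."""
    total = 0.0
    for start in range(0, len(a), block_size):
        a_q, a_scale = quantize_block_int8(a[start:start + block_size])
        w_blk = w_q[start:start + block_size]
        acc = sum(x * (w - zero_point) for x, w in zip(a_q, w_blk))  # integer-only math
        total += acc * a_scale * w_scales[start // block_size]
    return total


a = [0.1, -0.37, 0.9, 0.25]
w_q = [3, 12, 8, 15]       # unsigned 4-bit weights
w_scales = [0.5, 0.25]     # one scale per block of 2
ref = dot_f32(a, w_q, w_scales, 2)
approx = dot_int8(a, w_q, w_scales, 2)
```

With these inputs the f32 path gives about -0.5525 while the int8 path gives roughly -0.5536; the difference comes from rounding A to int8.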

Unlike most fusions, the DequantizeLinear + MatMul => MatMulNBits fusion is not cheap in time or memory, because it requires transposing the weights into the layout MatMulNBits expects, where weights have shape (N, K/block_size, block_bytes): https://github.com/microsoft/onnxruntime/blob/0463aa9fc3ef02d30d7177c0065cd4b7d36a39f7/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_actions.cc#L385
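The repacking that makes this fusion expensive can be sketched as below. It assumes the DequantizeLinear weight is a K x N matrix of signed int4 values (unpacked to ints in [-8, 7]) with a zero point of 0, blocked along K, and that MatMulNBits wants unsigned nibbles with a zero point of 8, packed low nibble first. `repack_for_matmul_nbits` is a hypothetical name; ORT's actual transform lives in the linked qdq_actions.cc and may differ in detail.

```python
def repack_for_matmul_nbits(q, block_size):
    """Repack a K x N matrix of signed int4 values (zero point 0) into an
    N x (K // block_size) list of blocks, each block_size // 2 bytes long.

    Each value is shifted to uint4 (q + 8) so that a zero point of 8 reproduces it.
    Every weight is read and rewritten, so the cost is O(K * N) time plus a full copy.
    """
    K, N = len(q), len(q[0])
    if K % block_size or block_size % 2:
        raise ValueError("sketch assumes K divisible by block_size and even block_size")
    out = []
    for n in range(N):  # the output column becomes the new outermost axis
        blocks = []
        for start in range(0, K, block_size):
            blob = bytearray()
            for k in range(start, start + block_size, 2):
                lo = q[k][n] + 8      # lower-indexed element goes in the low nibble
                hi = q[k + 1][n] + 8
                blob.append(lo | (hi << 4))
            blocks.append(bytes(blob))
        out.append(blocks)
    return out


packed = repack_for_matmul_nbits([[1, -2], [3, 4], [-8, 7], [0, 5]], 2)
```

Because the loop reads every weight and builds a complete second copy in the new layout, its cost scales with the full size of B, which matches the time and memory concern above.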

robertknight • Oct 27 '25 12:10