INT8 / UINT8 for Quantization
Hello,
Is there a proper way to handle INT8/UINT8 for quantization? I am attempting to reproduce the functions below in order to quantize flash-attention with Triton.
```python
import numpy as np

def quantize_to_int8(tensor, clip_max, quant_range=127):
    scale = quant_range / clip_max
    min_bound = -quant_range
    max_bound = quant_range
    outputs = np.clip((tensor.astype(np.float32) * scale).round(), min_bound, max_bound)
    quant_tensor = outputs.astype(np.int8)
    return quant_tensor

def quantize_to_uint8(tensor, clip_max, quant_range=255):
    scale = quant_range / clip_max
    max_bound = quant_range
    outputs = np.clip((tensor.astype(np.float32) * scale).round(), 0, max_bound)
    quant_tensor = outputs.astype(np.uint8)
    return quant_tensor
```
Any advice would be greatly appreciated.
Thank you,
Enrico
@yuguo68
We used to have a fast code path for this that disappeared when we simplified the IR. We have plans to extend the ExternElementwiseOp so it can accommodate these cases better.
We had a PR on the legacy backend for int8/int4/int2 dequantization. Please take a look at https://github.com/openai/triton/pull/759 and the examples in test_dequantize.py. We are migrating to triton-MLIR; after the migration is complete, we will start working on quantization/dequantization ops.
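In the meantime, plain int8 dequantization can be written as an ordinary elementwise kernel. A rough sketch (per-tensor scale assumed, names illustrative; the packed int4/int2 paths from that PR need extra bit manipulation and are not covered here):

```python
import torch
import triton
import triton.language as tl


@triton.jit
def dequantize_int8_kernel(q_ptr, out_ptr, scale, n_elements, BLOCK: tl.constexpr):
    pid = tl.program_id(axis=0)
    offsets = pid * BLOCK + tl.arange(0, BLOCK)
    mask = offsets < n_elements
    q = tl.load(q_ptr + offsets, mask=mask, other=0)
    # widen int8 -> float32, apply the per-tensor scale, narrow to fp16 for storage
    x = q.to(tl.float32) * scale
    tl.store(out_ptr + offsets, x.to(tl.float16), mask=mask)


def dequantize_int8(q, scale):
    out = torch.empty(q.shape, dtype=torch.float16, device=q.device)
    n = q.numel()
    grid = lambda meta: (triton.cdiv(n, meta["BLOCK"]),)
    dequantize_int8_kernel[grid](q, out, scale, n, BLOCK=1024)
    return out
```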
@ptillet @yuguo68 Thank you for the additional information. I will review #759 and the examples in test_dequantize.py. Support for Triton int8 and uint8 dtype conversions would be greatly beneficial.
Hi, I'm interested in implementing this. Would you be able to provide any guidance?
Any updates on this issue? It seems none of these low-bitwidth kernels work in the current flow.
Hi @yuguo68, is this still planned now that we are on the new MLIR backend? Or is there already an alternative for doing things like weight-only (de)quantization efficiently with Triton kernels? Thanks for your help here!
Following up on this issue as well. Thanks!