nx
Quantization via MLIR
There are quantized tensor types in the StableHLO spec, but they still appear to be under discussion, for example https://github.com/openxla/stablehlo/issues/1491. Also, that is only the specification: based on https://github.com/openxla/xla/issues/9291, I infer it is not implemented in XLA yet. So at this level, the best bet is to wait.
There is also google/aqt, which is actually the only quantization reference/project I found for JAX/Flax. It lifts quantization one level up by introducing a quantized tensor type and config, and by reimplementing some operations with quantization in mind.
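To illustrate what "reimplementing an operation with quantization in mind" can mean, here is a minimal sketch of a quantized matmul using symmetric per-tensor int8 quantization. This is not AQT's API: the function names (`quantize_int8`, `dequantize`, `quantized_matmul`) are hypothetical, NumPy is used instead of JAX for portability, and per-tensor symmetric int8 is only one of many possible schemes.

```python
import numpy as np


def quantize_int8(x: np.ndarray) -> tuple[np.ndarray, float]:
    """Symmetric per-tensor quantization: x ≈ q * scale, with q in [-127, 127]."""
    max_abs = float(np.max(np.abs(x)))
    # Avoid division by zero for an all-zero tensor.
    scale = max_abs / 127.0 if max_abs > 0 else 1.0
    q = np.clip(np.round(x / scale), -127, 127).astype(np.int8)
    return q, scale


def dequantize(q: np.ndarray, scale: float) -> np.ndarray:
    """Map int8 values back to (approximate) floats."""
    return q.astype(np.float32) * scale


def quantized_matmul(a: np.ndarray, b: np.ndarray) -> np.ndarray:
    """Matmul on int8 inputs with int32 accumulation, rescaled to float at the end."""
    qa, sa = quantize_int8(a)
    qb, sb = quantize_int8(b)
    # Accumulate in int32 so the products of int8 values do not overflow.
    acc = qa.astype(np.int32) @ qb.astype(np.int32)
    # The combined scale of the output is the product of the input scales.
    return acc.astype(np.float32) * (sa * sb)


# Usage: compare against the full-precision result.
rng = np.random.default_rng(0)
a = rng.standard_normal((4, 8)).astype(np.float32)
b = rng.standard_normal((8, 3)).astype(np.float32)
exact = a @ b
approx = quantized_matmul(a, b)
rel_err = float(np.linalg.norm(exact - approx) / np.linalg.norm(exact))
```

The design choice worth noting is that only the inputs are stored in int8; the accumulation happens in a wider integer type and a single float rescale recovers the output, which is roughly the pattern hardware int8 kernels follow.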
Finally, it's worth noting that quantization is a broad topic: there are many approaches and algorithms, and both the use case and the hardware play a role. hf/transformers has a matrix of different techniques.
Closing in favor of #1528. FP8 is already in.