
Generalize `QuantizedMatmulToMatmul` to handle `linalg.quantized_batch_matmul`.


Context

Consider the following four linalg named ops and how they differ from each other, comparing their definitions in the Linalg OpDSL (a rough IR sketch of the two quantized variants follows the list):

  1. matmul
  2. batch_matmul
  3. quantized_matmul
  4. quantized_batch_matmul
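
For orientation, here is roughly what the two quantized variants look like in IR. The shapes and element types below are illustrative, not taken from any particular workload:

```mlir
// quantized_matmul: the zero points are passed as two extra scalar inputs.
%0 = linalg.quantized_matmul
    ins(%lhs, %rhs, %lhs_zp, %rhs_zp : tensor<4x3xi8>, tensor<3x5xi8>, i32, i32)
    outs(%acc : tensor<4x5xi32>) -> tensor<4x5xi32>

// quantized_batch_matmul: same, with a leading batch dimension on the tensors.
%1 = linalg.quantized_batch_matmul
    ins(%blhs, %brhs, %lhs_zp, %rhs_zp : tensor<2x4x3xi8>, tensor<2x3x5xi8>, i32, i32)
    outs(%bacc : tensor<2x4x5xi32>) -> tensor<2x4x5xi32>
```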

In summary:

  • The quantized_ variants add extra "zero_point" inputs, which are subtracted from the matrix operands before being multiplied.
  • The batch_ variants add an extra parallel "batch" dimension, but the arithmetic is otherwise the same.

Our compilation strategy for all kinds of matmuls is generally to reduce all these variants to plain matmul (or batch_matmul, but that too gets funneled into the same path as matmul later in codegen). In particular, we have a pass, QuantizedMatmulToMatmul, that reduces quantized_matmul to plain matmul. The underlying mathematics is not complicated: it is just distributing the multiplication over the subtractions of the zero_points. As linked in a comment there, this is explained in section 2.3 of this old paper. A good starting point is to read that, get familiar with how the pass merely implements that idea, and take a look at its test.
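
Written out, with lhs zero point $z_a$, rhs zero point $z_b$, and reduction size $K$:

$$\sum_{k}(a_{mk} - z_a)(b_{kn} - z_b) = \sum_{k} a_{mk} b_{kn} \;-\; z_b \sum_{k} a_{mk} \;-\; z_a \sum_{k} b_{kn} \;+\; K\, z_a z_b$$

So a quantized_matmul becomes one plain matmul plus per-row sums of the lhs, per-column sums of the rhs, and a constant term; the batched case is identical, with a batch index carried along on every term.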

This works fine for quantized_matmul, but the pass does not currently handle quantized_batch_matmul. This appeared on the radar in https://github.com/openxla/iree/issues/15399, where a workload runs very slowly as a result.

What needs to be done

This QuantizedMatmulToMatmul pass needs to be generalized to also handle quantized_batch_matmul. Just as it currently rewrites a quantized_matmul into a matmul, it should now also rewrite a quantized_batch_matmul into a batch_matmul.
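
A rough sketch of what the generalized rewrite might produce; the names and shapes are illustrative, and the op structure is assumed to mirror what the pass already emits for the non-batched case:

```mlir
// Main product on the raw (un-shifted) i8 operands:
%prod = linalg.batch_matmul
    ins(%lhs, %rhs : tensor<2x4x3xi8>, tensor<2x3x5xi8>)
    outs(%zero_init : tensor<2x4x5xi32>) -> tensor<2x4x5xi32>
// Then, as in the non-batched path, linalg.generic reductions compute
// sums of %lhs over k (one value per (b, m)) and sums of %rhs over k
// (one value per (b, n)), and a final elementwise linalg.generic applies:
//   out[b, m, n] = prod[b, m, n] - rhs_zp * lhs_sums[b, m]
//                  - lhs_zp * rhs_sums[b, n] + K * lhs_zp * rhs_zp
```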

The existing test should be expanded with quantized_batch_matmul test cases, expecting a batch_matmul op to be created.
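
A minimal sketch of such a test case, assuming the existing test file's FileCheck conventions (the exact CHECK lines would depend on the pass's actual output):

```mlir
func.func @quantized_batch_matmul(
    %lhs: tensor<2x4x3xi8>, %rhs: tensor<2x3x5xi8>,
    %lhs_zp: i32, %rhs_zp: i32,
    %acc: tensor<2x4x5xi32>) -> tensor<2x4x5xi32> {
  %0 = linalg.quantized_batch_matmul
      ins(%lhs, %rhs, %lhs_zp, %rhs_zp
          : tensor<2x4x3xi8>, tensor<2x3x5xi8>, i32, i32)
      outs(%acc : tensor<2x4x5xi32>) -> tensor<2x4x5xi32>
  return %0 : tensor<2x4x5xi32>
}
// CHECK-LABEL: @quantized_batch_matmul
// CHECK-NOT:   linalg.quantized_batch_matmul
// CHECK:       linalg.batch_matmul
```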

As batching is merely an additional parallel dimension that does not change the arithmetic at all, this should only involve handling more general tensor shapes, iterators, and affine maps; the arithmetic ops should stay the same.
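
Concretely, the generalization in iterators and indexing maps is just the extra leading parallel dimension, written here for illustration:

```mlir
// matmul: iterator_types = ["parallel", "parallel", "reduction"]
affine_map<(m, n, k) -> (m, k)>  // lhs
affine_map<(m, n, k) -> (k, n)>  // rhs
affine_map<(m, n, k) -> (m, n)>  // out

// batch_matmul: iterator_types = ["parallel", "parallel", "parallel", "reduction"]
affine_map<(b, m, n, k) -> (b, m, k)>  // lhs
affine_map<(b, m, n, k) -> (b, k, n)>  // rhs
affine_map<(b, m, n, k) -> (b, m, n)>  // out
```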

This should NOT try to rewrite a quantized_batch_matmul to a matmul. Resolving the batch dimension is out of scope for this pass and is taken care of much later in codegen.

bjacob · Feb 28 '24 15:02