QPyTorch Support (BFP quantization)
Hi folks! My team and I are looking into compiler support for block floating point (BFP) quantization in Torch-MLIR, and we're wondering what you think about extending Torch-MLIR to cover these cases. Below is a dummy test network I used as an experiment to compile a PyTorch model with BFP additions from qtorch via torch_mlir:
Input Description
The model is a basic MatMul followed by a BFP cast and a ReLU activation.
class SimpleModel(nn.Module):
    def __init__(self, input_dim, output_size):
        super(SimpleModel, self).__init__()
        self.matmul = nn.Linear(input_dim, output_size)
        self.relu = nn.ReLU()

    def forward(self, x):
        matmul_out = self.matmul(x.flatten(1))
        quantized_matmul_out = block_quantize(matmul_out, wl=8, dim=0, rounding="nearest")
        relu_out = self.relu(quantized_matmul_out)
        return relu_out
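For context, block floating point groups values into blocks that share a single exponent, with each value keeping a wl-bit mantissa. The sketch below is a rough, purely illustrative reference of nearest-rounding block quantization; it is not qtorch's exact kernel, and the function name and bit accounting are assumptions.

import torch

def block_quantize_nearest_ref(x: torch.Tensor, wl: int, dim: int = 0) -> torch.Tensor:
    # Illustrative only: all elements along `dim` share one exponent taken from
    # the block's largest magnitude, and values are rounded to a wl-bit grid.
    max_abs = x.abs().amax(dim=dim, keepdim=True).clamp_min(1e-38)
    shared_exp = torch.floor(torch.log2(max_abs))
    scale = torch.pow(2.0, shared_exp - (wl - 2))  # leave room for sign and leading bit
    return torch.round(x / scale) * scale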
Observed Behaviour
I used the torch_mlir.compile() API to compile the module into TOSA IR. While the module seems to forward-propagate fine, compilation hits an assert in qtorch.quant_function.block_quantize() for not having a valid "rounding mode". Removing the BFP quantization from the module's forward pass yields a successful compilation.
Lastly, if I quantize the input tensors of the module and then call torch_mlir.compile() on them, there doesn't seem to be any issue - are the casts optimized out in this case?
Script to Reproduce
For convenience, I made a draft PR with a minimal script to reproduce the issue I'm hitting here: https://github.com/llvm/torch-mlir/pull/909
FYI @silvasean
Here is the stack trace of the error I'm hitting:
RuntimeError:
Arguments for call are not valid.
The following variants are available:
aten::__contains__.int_list(int[] l, int item) -> (bool):
Expected a value of type 'List[int]' for argument 'l' but instead found type 'List[str]'.
aten::__contains__.str_list(str[] l, str item) -> (bool):
Expected a value of type 'str' for argument 'item' but instead found type 'Tensor (inferred)'.
Inferred the value for argument 'item' to be of type 'Tensor' because it was not annotated with an explicit type.
aten::__contains__.str(Dict(str, t) dict, str key) -> (bool):
Could not match type List[str] to Dict[str, t] in argument 'dict': Cannot match a dict to List[str].
aten::__contains__.int(Dict(int, t) dict, int key) -> (bool):
Could not match type List[str] to Dict[int, t] in argument 'dict': Cannot match a dict to List[str].
aten::__contains__.bool(Dict(bool, t) dict, bool key) -> (bool):
Could not match type List[str] to Dict[bool, t] in argument 'dict': Cannot match a dict to List[str].
aten::__contains__.float(Dict(float, t) dict, float key) -> (bool):
Could not match type List[str] to Dict[float, t] in argument 'dict': Cannot match a dict to List[str].
aten::__contains__.complex(Dict(complex, t) dict, complex key) -> (bool):
Could not match type List[str] to Dict[complex, t] in argument 'dict': Cannot match a dict to List[str].
aten::__contains__.Tensor(Dict(Tensor, t) dict, Tensor key) -> (bool):
Could not match type List[str] to Dict[Tensor, t] in argument 'dict': Cannot match a dict to List[str].
aten::__contains__.float_list(float[] l, float item) -> (bool):
Expected a value of type 'List[float]' for argument 'l' but instead found type 'List[str]'.
__contains__(str self, str key) -> (bool):
Expected a value of type 'str' for argument 'self' but instead found type 'List[str]'.
The original call is:
File ".../python3.8/site-packages/qtorch/quant/quant_function.py", line 257
"""
assert isinstance(x, torch.Tensor), "x is not a single precision Floating Point Tensor"
assert rounding in ["stochastic", "nearest"], "invalid rounding mode, {}".format(rounding)
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ <--- HERE
quant_module = get_module(x)
if rounding == "nearest":
'block_quantize' is being compiled since it was called from 'SimpleModel.forward'
File "block_quantize_experiment.py", line 17
def forward(self, x):
matmul_out = self.matmul(x.flatten(1))
quantized_matmul_out = block_quantize(matmul_out, wl=8, dim=0)
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ <--- HERE
relu_out = self.relu(quantized_matmul_out)
return relu_out
Hi, this issue is not specific to Torch-MLIR. It is a general issue with TorchScript'ing the model, and it comes from the lack of a type annotation on the rounding argument of def block_quantize(x, wl, dim=-1, rounding="stochastic"):. Replacing it with def block_quantize(x, wl, dim=-1, rounding: str="stochastic"): fixes that issue, but reveals further issues. Instead, I passed use_tracing=True to torch_mlir.compile, which avoids the issue.
(Note that the issue is not the assertion triggering, but a failure of the TorchScript compiler to compile the line of code which contains the assertion.)
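A minimal standalone sketch of the failure mode (illustrative only; not qtorch's actual code, and the body is a placeholder): TorchScript assumes unannotated parameters are Tensor, so without the annotation the `rounding in [...]` check fails to compile.

import torch

def block_quantize(x: torch.Tensor, wl: int, dim: int = -1,
                   rounding: str = "stochastic") -> torch.Tensor:
    # With `rounding` annotated as str, the list-membership check below is scriptable.
    assert rounding in ["stochastic", "nearest"], "invalid rounding mode"
    return x  # placeholder body for illustration only

scripted = torch.jit.script(block_quantize)  # compiles once `rounding` is annotated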
I then get this IR, which looks like a normal unquantized module... is there something I need to do to enable quantization on the module? (I applied your PR verbatim, except for adding use_tracing=True.)
#loc0 = loc(unknown)
module attributes {torch.debug_module_name = "SimpleModel"} {
func.func @forward(%arg0: !torch.vtensor<[5,64],f32> loc(unknown)) -> !torch.vtensor<[5,4],f32> {
%0 = torch.vtensor.literal(opaque<"elided_large_const", "0xDEADBEEF"> : tensor<4x64xf32>) : !torch.vtensor<[4,64],f32> loc(#loc0)
%1 = torch.vtensor.literal(dense<[5.70356846E-4, 0.0960100442, 0.11995019, -0.00602804124]> : tensor<4xf32>) : !torch.vtensor<[4],f32> loc(#loc0)
%none = torch.constant.none loc(#loc0)
%int0 = torch.constant.int 0 loc(#loc1)
%int-1 = torch.constant.int -1 loc(#loc2)
%int1 = torch.constant.int 1 loc(#loc2)
%2 = torch.aten.flatten.using_ints %arg0, %int1, %int-1 : !torch.vtensor<[5,64],f32>, !torch.int, !torch.int -> !torch.vtensor<[5,64],f32> loc(#loc2)
%3 = torch.aten.linear %2, %0, %1 : !torch.vtensor<[5,64],f32>, !torch.vtensor<[4,64],f32>, !torch.vtensor<[4],f32> -> !torch.vtensor<[5,4],f32> loc(#loc3)
%4 = torch.aten.contiguous %3, %int0 : !torch.vtensor<[5,4],f32>, !torch.int -> !torch.vtensor<[5,4],f32> loc(#loc1)
%int0_0 = torch.constant.int 0 loc(#loc1)
%5 = torch.aten.size.int %4, %int0_0 : !torch.vtensor<[5,4],f32>, !torch.int -> !torch.int loc(#loc1)
%int1_1 = torch.constant.int 1 loc(#loc1)
%6 = torch.aten.size.int %4, %int1_1 : !torch.vtensor<[5,4],f32>, !torch.int -> !torch.int loc(#loc1)
%7 = torch.prim.ListConstruct %5, %6 : (!torch.int, !torch.int) -> !torch.list<int> loc(#loc1)
%8 = torch.aten.empty.memory_format %7, %none, %none, %none, %none, %none : !torch.list<int>, !torch.none, !torch.none, !torch.none, !torch.none, !torch.none -> !torch.vtensor<[5,4],f32> loc(#loc1)
%int0_2 = torch.constant.int 0 loc(#loc1)
%9 = torch.valsem.aten.fill.Scalar %8, %int0_2 : !torch.vtensor<[5,4],f32>, !torch.int -> !torch.vtensor<[5,4],f32> loc(#loc1)
%10 = torch.aten.relu %9 : !torch.vtensor<[5,4],f32> -> !torch.vtensor<[5,4],f32> loc(#loc4)
return %10 : !torch.vtensor<[5,4],f32> loc(#loc0)
} loc(#loc0)
} loc(#loc0)
#loc1 = loc("/usr/local/google/home/silvasean/.local/lib/python3.9/site-packages/qtorch/quant/quant_function.py":260:0)
#loc2 = loc("/usr/local/google/home/silvasean/pg/torch-mlir/torch-mlir/examples/block_quantize_experiment.py":16:0)
#loc3 = loc("/usr/local/google/home/silvasean/.local/lib/python3.9/site-packages/torch/nn/modules/linear.py":114:0)
#loc4 = loc("/usr/local/google/home/silvasean/.local/lib/python3.9/site-packages/torch/nn/functional.py":1415:0)
Hmm... it seems like an issue with torch.jit.trace
No CUDA runtime is found, using CUDA_HOME='/usr/local/cuda'
forward propagate results on inputs is:
tensor([[0.0000, 0.0000, 0.0000, 1.4062],
[0.0000, 0.0000, 0.0000, 0.0000],
[0.0000, 0.0000, 0.0000, 0.0000],
[0.1250, 0.1719, 0.0000, 1.2031],
[0.4531, 0.0000, 0.0000, 0.5938]])
/usr/local/google/home/silvasean/.local/lib/python3.9/site-packages/torch/jit/_trace.py:992: TracerWarning: Output nr 1. of the traced function does not match the corresponding output of the Python function. Detailed error:
Tensor-likes are not close!
Mismatched elements: 6 / 20 (30.0%)
Greatest absolute difference: 1.40625 at index (0, 3) (up to 1e-05 allowed)
Greatest relative difference: 1.0 at index (0, 3) (up to 1e-05 allowed)
I will investigate how to debug this.
So, to summarize, the issues so far are not specific to Torch-MLIR -- we are having trouble even torch.jit.script/torch.jit.trace'ing it, which is a prerequisite for importing into Torch-MLIR.
Hi, so it looks like the source of the issue is that quant_module.block_quantize_nearest is not registered as a PyTorch op. It is just a random function call (similar to calling into numpy or making a network request), so we cannot see it from the compiler.
As a next step, I would recommend that you define your ops with TORCH_LIBRARY (see here) instead of as a PYBIND11_MODULE (which is the current code). Once that is done, we can see those ops in the compiler and take the next step.
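As a rough illustration of the difference (not the actual QPyTorch change, which registers the C++ kernels with the TORCH_LIBRARY macro): registering an op with a schema makes it a first-class op under torch.ops, whereas a pybind11 binding is just an opaque Python call. A Python-side sketch using torch.library, with the namespace, schema, and placeholder kernel all assumed:

import torch
from torch.library import Library

# Hypothetical sketch only; shows the same idea as a C++ TORCH_LIBRARY block.
lib = Library("qtorch_ops", "DEF")
lib.define("block_quantize_nearest(Tensor x, int wl, int dim) -> Tensor")

def _block_quantize_nearest_cpu(x, wl, dim):
    # Placeholder kernel for illustration; a real one would do the BFP rounding.
    return x.clone()

lib.impl("block_quantize_nearest", _block_quantize_nearest_cpu, "CPU")

# Once registered, the op is visible to TorchScript and the compiler as
# torch.ops.qtorch_ops.block_quantize_nearest(x, 8, 0)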
Assigning back to @Svoch as the next step is on the QPyTorch side.
Thank you @silvasean! This is very helpful, and it makes sense why the op is not being picked up. Let me modify the registration method somehow and I'll update this issue with my findings.
I was able to modify the QPyTorch operation bindings such that block_quantize_nearest and block_quantize_stochastic are registered via TORCH_LIBRARY and not PYBIND11 anymore. Torch-MLIR seems able to pick up these operations as custom Torch dialect ops (e.g. torch.operator "my_qtorch_ops.block_quantize_nearest"). However, I wasn't able to lower below the Torch dialect because shape inference seems to have failed for this op. Here is the IR I got:
#loc0 = loc(unknown)
module attributes {torch.debug_module_name = "SimpleModel"} {
func.func @forward(%arg0: !torch.vtensor<[5,64],f32> loc(unknown)) -> !torch.vtensor {
%0 = torch.vtensor.literal(dense<[0.0832831561, 0.0882747471, -0.0846773684, 0.0318481326]> : tensor<4xf32>) : !torch.vtensor<[4],f32> loc(#loc0)
%1 = torch.vtensor.literal(opaque<"elided_large_const", "0xDEADBEEF"> : tensor<4x64xf32>) : !torch.vtensor<[4,64],f32> loc(#loc0)
%int-1 = torch.constant.int -1 loc(#loc0)
%int1 = torch.constant.int 1 loc(#loc1)
%int8 = torch.constant.int 8 loc(#loc2)
%int0 = torch.constant.int 0 loc(#loc3)
%2 = torch.aten.flatten.using_ints %arg0, %int1, %int-1 : !torch.vtensor<[5,64],f32>, !torch.int, !torch.int -> !torch.vtensor<[5,64],f32> loc(#loc4)
%3 = torch.aten.linear %2, %1, %0 : !torch.vtensor<[5,64],f32>, !torch.vtensor<[4,64],f32>, !torch.vtensor<[4],f32> -> !torch.vtensor<[5,4],f32> loc(#loc5)
%4 = torch.tensor_static_info_cast %3 : !torch.vtensor<[5,4],f32> to !torch.vtensor<*,f32> loc(#loc5)
%5 = torch.copy.to_tensor %4 : !torch.tensor<*,f32> loc(#loc5)
%6 = torch.operator "my_qtorch_ops.block_quantize_nearest"(%5, %int8, %int0) : (!torch.tensor<*,f32>, !torch.int, !torch.int) -> !torch.tensor loc(#loc6)
%7 = torch.copy.to_vtensor %6 : !torch.vtensor loc(#loc7)
%8 = torch.aten.relu %7 : !torch.vtensor -> !torch.vtensor loc(#loc7)
return %8 : !torch.vtensor loc(#loc0)
} loc(#loc0)
} loc(#loc0)
#loc1 = loc("example.py":20:43)
#loc2 = loc("example.py":21:90)
#loc3 = loc("example.py":21:93)
#loc4 = loc("example.py":20:33)
#loc5 = loc(callsite(".../QPyTorch/qtorch_venv/lib/python3.8/site-packages/torch/nn/modules/linear.py":103:15 at "example.py":20:21))
#loc6 = loc("example.py":21:31)
#loc7 = loc(callsite(callsite(".../QPyTorch/qtorch_venv/lib/python3.8/site-packages/torch/nn/functional.py":1442:17 at ".../QPyTorch/qtorch_venv/lib/python3.8/site-packages/torch/nn/modules/activation.py":98:15) at "example.py":22:19))
More details on steps to reproduce this error:
1. QPyTorch Modifications
The modifications for the QPyTorch operator bindings are present on the master branch of this QPyTorch fork. The steps to build the custom CPU quantization operators for TorchScript are as follows:
- Build the custom operators for qtorch_ops by creating a directory named build and running the CMake command cmake -DCMAKE_PREFIX_PATH="$(python -c 'import torch.utils; print(torch.utils.cmake_prefix_path)')" $PATH_TO_QUANT-CPU_DIR in it, where PATH_TO_QUANT-CPU_DIR is the path to QPyTorch/qtorch/quant/quant_cpu.
- Run make to create the shared library libqtorch_ops.so.
- Build the forked qtorch package with the pip install . command from the top-level directory in QPyTorch.
- Run the Torch-MLIR experiment script below from the top-level directory in QPyTorch:
import torch
import torch.nn as nn
import qtorch
# from qtorch.quant import block_quantize -> instead we will use the operators in the qtorch_ops
torch.ops.load_library("qtorch/quant/quant_cpu/build/libqtorch_ops.so")
print(torch.ops.qtorch_ops.block_quantize_nearest)
import torch_mlir
from torch_mlir_e2e_test.tosa_backends import linalg_on_tensors
class SimpleModel(nn.Module):
    def __init__(self, input_dim, output_size):
        super(SimpleModel, self).__init__()
        self.matmul = nn.Linear(input_dim, output_size)
        self.relu = nn.ReLU()

    def forward(self, x):
        matmul_out = self.matmul(x.flatten(1))
        quantized_matmul_out = torch.ops.qtorch_ops.block_quantize_nearest(matmul_out, 8, 0)
        relu_out = self.relu(quantized_matmul_out)
        return relu_out
batches = 5
input_dim = 64
output_size = 4
inputs = torch.randn(batches, input_dim)
model = SimpleModel(input_dim, output_size)
print("forward propagate results on inputs is:\n", model.forward(inputs))
# quantized_inputs = block_quantize(inputs, wl=8, dim=0, rounding="nearest")
# print("forward propagate of quantized inputs result is ", model.forward(quantized_inputs))
module = torch_mlir.compile(model, inputs, output_type=torch_mlir.OutputType.TOSA, use_tracing=False)
print("Module compiled to TOSA is:\n", module)
2. The Results
Below are the logs from running the Torch-MLIR experiment script. The custom operators seem to be successfully built and linked, but as seen in the IR in the previous comment, the shape information for ops like torch.tensor_static_info_cast, torch.copy.to_tensor and the custom block quantization op is missing.
<built-in method block_quantize_nearest of PyCapsule object at 0xffffabcd36c0>
forward propagate results on inputs is:
tensor([[0.0000, 0.1875, 0.1719, 0.0000],
[0.0000, 0.0000, 0.1250, 0.0000],
[0.0000, 0.1797, 0.0000, 0.3164],
[0.2109, 0.0000, 0.8203, 0.0000],
[0.7812, 1.7500, 0.9844, 0.9062]])
Traceback (most recent call last):
File "examples/torch-mlir_experiment.py", line 34, in <module>
module = torch_mlir.compile(model, inputs, output_type=torch_mlir.OutputType.TOSA, use_tracing=False)
File ".../torch-mlir/build/tools/torch-mlir/python_packages/torch_mlir/torch_mlir/__init__.py", line 157, in compile
run_pipeline_with_repro_report(
File ".../torch-mlir/build/tools/torch-mlir/python_packages/torch_mlir/torch_mlir/compiler_utils.py", line 49, in run_pipeline_with_repro_report
raise Exception(f"""
Exception:
Lowering Torch Backend IR -> TOSA Backend IR failed with the following diagnostics:
error: unsupported by backend lowering: tensor with unknown rank or dtype
note: see current operation: %8 = "torch.tensor_static_info_cast"(%7) : (!torch.vtensor<[5,4],f32>) -> !torch.vtensor<*,f32>
note: this is likely due to a missing shape transfer function in shape_lib_gen.py
Error can be reproduced with:
$ torch-mlir-opt -pass-pipeline='torch-backend-to-tosa-backend-pipeline' /tmp/SimpleModel.mlir
Add '-print-ir-after-all -mlir-disable-threading' to get the IR dump for debugging purpose.
Do I need to add additional info to enable shape inference for the custom Ops in Torch-MLIR?
You will want to do something like what Bob does in https://github.com/llvm/torch-mlir/pull/895 to add the shape and dtype inference. That will need to be done in a fork of Torch-MLIR for now, but I would consider including QPyTorch support as first class if we can get to a good solution here for our customers (there seem to be a LOT of hardware vendors that would love to have this be really well supported, and upstream PyTorch doesn't seem to be providing a good solution, so I'm interested in incubating something in Torch-MLIR).
We discussed in the Torch-MLIR developer hour that one of the nod.ai folks was going to be building a PoC of QPyTorch lowering into TOSA. Was that you @nithinsubbiah that was going to work on it? I'm happy to provide architectural guidance here to get a really great PoC and deliver really good first-class support here.
cc @powderluv
Hi @silvasean, yes I'll work on this integration. Adding shape inference for this QPyTorch op and checking if that can lower to TOSA would be the first step, I think (please correct me if I'm wrong).
Sounds great @silvasean! This is a very exciting update on the custom Op support front, and it looks like very good timing.
I'll follow the example for the ExampleIdentity custom Op in #895 and share my findings here. Just to confirm though: it seems like the additional steps for supporting a custom Op are very similar to extending the Torch dialect with a new Op, including extending the dialect and the shape inference methods. Is extending torch_ods_gen.py and shape_lib_gen.py enough for this?
cc @rdadolf
You also typically need to update RefineTypes.cpp too -- you can see all the steps here: https://github.com/llvm/torch-mlir/wiki/Torch-ops-E2E-implementation You may want to skip lowering to linalg.
Just to reiterate what Sean said: yes, the process for extensions is intended to follow the same 5-step process that Sean linked. There's a bit more information on the differences in this readme co-located with the example code.
You should not need to write anything that looks like what's in /python/torch_mlir/_torch_mlir_custom_op_example, since your library already exists and will take the place of that. And you shouldn't really need to muck with much of anything cmake- or build-related in the torch-mlir repo unless you decide to add extra files to modularize things.
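For reference, the shape transfer function for the custom op in shape_lib_gen.py could look roughly like the sketch below, following the pattern of the existing entries. The exact mangled name is an assumption here; block quantization is element-wise, so the output shape matches the input.

from typing import List

# Hypothetical entry for shape_lib_gen.py (illustrative only; the real name follows
# Torch-MLIR's mangling of "qtorch_ops.block_quantize_nearest").
def qtorch_ops〇block_quantize_nearest(x: List[int], wl: int, dim: int) -> List[int]:
    # Block quantization does not change the tensor's shape.
    return x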
Thank you @silvasean and @rdadolf for the details! Quick update: I kept hitting some KeyErrors in registry.py when running update_torch_ods.sh. I think it may be due to a torch versioning mismatch on my end, since I got the KeyErrors for non-relevant Ops like:
KeyError: 'aten::zero.functional : (Tensor) -> (Tensor)'
Will update this issue with the findings, @nithinsubbiah is helping with this step.
That sounds accurate. That was changed recently in #915. I've been working through some of the bumps with Nithin on Discord, including his registry.py bugs. The demo code he's worked up is heading in the right direction.
Quick update, I registered one qtorch op with dtype and shape inference following instructions in #895. I got the following Torch IR:
module attributes {torch.debug_module_name = "SimpleModel"} {
func.func @forward(%arg0: !torch.vtensor<[5,64],f32>) -> !torch.vtensor<[5,4],f32> {
%0 = torch.vtensor.literal(dense<[0.0162453204, 0.0135056227, -0.0940410942, 0.0825424045]> : tensor<4xf32>) : !torch.vtensor<[4],f32>
%1 = torch.vtensor.literal(dense<"0xLONGHEX"> : tensor<4x64xf32>) : !torch.vtensor<[4,64],f32>
%int-1 = torch.constant.int -1
%int1 = torch.constant.int 1
%int8 = torch.constant.int 8
%int0 = torch.constant.int 0
%2 = torch.aten.flatten.using_ints %arg0, %int1, %int-1 : !torch.vtensor<[5,64],f32>, !torch.int, !torch.int -> !torch.vtensor<[5,64],f32>
%3 = torch.aten.linear %2, %1, %0 : !torch.vtensor<[5,64],f32>, !torch.vtensor<[4,64],f32>, !torch.vtensor<[4],f32> -> !torch.vtensor<[5,4],f32>
%4 = torch.qtorch_ops.block_quantize_nearest %3, %int8, %int0 : !torch.vtensor<[5,4],f32>, !torch.int, !torch.int -> !torch.vtensor<[5,4],f32>
%5 = torch.aten.relu %4 : !torch.vtensor<[5,4],f32> -> !torch.vtensor<[5,4],f32>
return %5 : !torch.vtensor<[5,4],f32>
}
}
But lowering to TOSA fails with an exception:
Exception:
Lowering Torch Backend IR -> TOSA Backend IR failed with the following diagnostics:
error: failed to legalize operation 'torch.constant.int'
note: see current operation: %3 = "torch.constant.int"() {value = 8 : i64} : () -> !torch.int
Reviving discussion on this - I was able to use the custom op extension to register a qtorch op and did a rewrite from Torch -> TOSA as a tosa.custom op with the quantization parameters as operands. Attaching the TOSA IR below:
module attributes {torch.debug_module_name = "SimpleModel"} {
func.func @forward(%arg0: tensor<5x64xf32>) -> tensor<5x4xf32> {
%0 = "tosa.const"() {value = dense<[[-0.0788066238, 0.0372745693, 0.102689952, 0.0827734619]]> : tensor<1x4xf32>} : () -> tensor<1x4xf32>
%1 = "tosa.const"() {value = dense<"0xLONGHEX"> : tensor<1x64x4xf32>} : () -> tensor<1x64x4xf32>
%2 = "tosa.const"() {value = dense<8> : tensor<1xi32>} : () -> tensor<1xi32>
%3 = "tosa.const"() {value = dense<0> : tensor<1xi32>} : () -> tensor<1xi32>
%4 = "tosa.reshape"(%arg0) {new_shape = [1, 5, 64]} : (tensor<5x64xf32>) -> tensor<1x5x64xf32>
%5 = "tosa.matmul"(%4, %1) : (tensor<1x5x64xf32>, tensor<1x64x4xf32>) -> tensor<1x5x4xf32>
%6 = "tosa.reshape"(%5) {new_shape = [5, 4]} : (tensor<1x5x4xf32>) -> tensor<5x4xf32>
%7 = "tosa.add"(%6, %0) : (tensor<5x4xf32>, tensor<1x4xf32>) -> tensor<5x4xf32>
%8 = "tosa.custom"(%7, %2, %3) {identifier = "qtorch_blockquantizenearest"} : (tensor<5x4xf32>, tensor<1xi32>, tensor<1xi32>) -> tensor<5x4xf32>
%9 = "tosa.clamp"(%8) {max_fp = 3.40282347E+38 : f32, max_int = 2147483647 : i64, min_fp = 0.000000e+00 : f32, min_int = 0 : i64} : (tensor<5x4xf32>) -> tensor<5x4xf32>
return %9 : tensor<5x4xf32>
}
}
cc: @silvasean @Svoch @powderluv
What is the action item here @nithinsubbiah ?
@Svoch @nithinsubbiah can we close this issue?
I think we can close this issue. To get the custom Ops picked up by Torch-MLIR, we needed to modify QPyTorch (like this).
@silvasean - I wonder how the Custom Ops support RFC affects this path, however; we can discuss it there for better visibility.
cc @nithinsubbiah @powderluv