torch-mlir
quantized model support
Hello,
Is there any example of, or plan for, supporting quantized PyTorch models? I tried two ways to quantize a PyTorch model: pytorch-quantization from TensorRT (which inserts fake-quant nodes) and PyTorch's integrated quantization. However, both of them hit errors when compiled with the torch_mlir.compile() function to convert from PyTorch to the linalg-on-tensors, TOSA, or torch dialects.
Does anyone know how to import these quantized models in torch-mlir?
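For reference, here is a minimal sketch of the PyTorch-integrated (eager-mode) flow I am trying (the model and shapes are simplified placeholders, not my full script):

```python
import torch
import torch_mlir

# Simplified placeholder model for eager-mode (PyTorch-integrated) quantization.
class Net(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.quant = torch.quantization.QuantStub()
        self.conv = torch.nn.Conv2d(1, 8, 3)
        self.dequant = torch.quantization.DeQuantStub()

    def forward(self, x):
        return self.dequant(self.conv(self.quant(x)))

net = Net().eval()
net.qconfig = torch.quantization.get_default_qconfig("fbgemm")
torch.quantization.prepare(net, inplace=True)
net(torch.ones(1, 1, 28, 28))                  # calibration pass
torch.quantization.convert(net, inplace=True)  # Conv2d -> QuantizedConv2d

# This is the call that errors out:
module = torch_mlir.compile(net, torch.ones(1, 1, 28, 28),
                            output_type=torch_mlir.OutputType.TORCH)
```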
Thanks,
Jinsong
There are many ways to support quantized models in PyTorch, and we are actively gauging the demand for each method. For example, we are exploring QPyTorch here: https://github.com/llvm/torch-mlir/pull/909
Can you paste an example script showing how your users' quantized programs are written, so that we can see how best to support them?
Sure @silvasean,
I think the two approaches are quite different: PyTorch's integrated quantization replaces Conv with QuantizedConv modules, while TensorRT's pytorch-quantization inserts fake-quant (Q+DQ) nodes into the model. You can find lots of examples of both. Attached is my MNIST example quantized with PyTorch's integrated flow; I didn't clean up the code, but you can run it directly.
mnist_torch.zip
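For contrast, the TensorRT pytorch-quantization flow looks roughly like this (a minimal sketch, assuming the pytorch-quantization package is installed; not taken from my script):

```python
# pytorch-quantization monkey-patches nn.Conv2d and friends with Quant*
# variants that wrap inputs and weights in TensorQuantizer (Q+DQ) nodes,
# so the graph keeps float convs with fake-quant ops around them.
from pytorch_quantization import quant_modules

quant_modules.initialize()  # must run before the model is constructed

import torch

model = torch.nn.Sequential(torch.nn.Conv2d(1, 8, 3))
print(type(model[0]))  # pytorch_quantization.nn.modules.quant_conv.QuantConv2d
```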
Thanks!
It looks like torch.qtorch_ops.block_quantize_nearest, described in https://github.com/llvm/torch-mlir/pull/909, is not the standard way to quantize a PyTorch model. I think it's better to support the quantization method provided by PyTorch itself, like QuantizedConvRelu2d().
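For reference, that fused module comes from fusing Conv+ReLU before quantization; a minimal sketch using PyTorch's eager-mode API:

```python
import torch

# Fuse Conv2d + ReLU so that eager-mode quantization converts the pair into a
# single fused quantized module.
m = torch.nn.Sequential(torch.nn.Conv2d(1, 8, 3), torch.nn.ReLU()).eval()
m = torch.quantization.fuse_modules(m, [["0", "1"]])
# After the usual prepare/calibrate/convert flow, m[0] becomes
# torch.nn.intrinsic.quantized.ConvReLU2d.
```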
I have not received very good reports on the usability of PyTorch's native quantization method, so I stopped looking into adding support for it. We can continue working on it, though. Are you interested in helping with that?
@silvasean Sure, that sounds great. Translating quantized models from the framework into the MLIR ecosystem is a basic requirement for us, but I am new to torch-mlir. Could you please suggest a proposal/workflow so that I can continue working on it? You can also find my ID (Jinsong) on Discord. Thanks.
Awesome! I looked into your script and I think we are pretty close to being able to make it work.
The current error I see is:
```
error: unable to import Torch CustomClass type '0x25ecdc0' to MLIR type
Traceback (most recent call last):
  File "/tmp/mnist_torch.py", line 329, in <module>
    main(args)
  File "/tmp/mnist_torch.py", line 317, in main
    module = torch_mlir.compile(qnet, torch.ones(1, 1, 28, 28), output_type=torch_mlir.OutputType.TORCH)
  File "/usr/local/google/home/silvasean/pg/torch-mlir/torch-mlir/build/tools/torch-mlir/python_packages/torch_mlir/torch_mlir/__init__.py", line 144, in compile
    mb.import_module(scripted._c, class_annotator)
RuntimeError: see diagnostics
```
I suspect that 0x25ecdc0 is the type ID for ConvPackedParams. You should be able to fix that by adding support for ConvPackedParams here: https://github.com/llvm/torch-mlir/blob/f774e63abdcbe786e884d0cc14ba198192387332/python/torch_mlir/dialects/torch/importer/jit_ir/csrc/torch_to_mlir_utils.cpp#L104. You can then add a test in https://github.com/llvm/torch-mlir/blob/main/test/python/importer/jit_ir/ivalue_import/quantization.py (see development.md for more info on how to work on Torch-MLIR).
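On the test side, something along these lines should work once the importer knows about ConvPackedParams (a rough sketch mirroring the existing quantized Linear test; the FileCheck lines are omitted and will depend on the new type you add):

```python
import torch
from torch_mlir.dialects.torch.importer.jit_ir import ModuleBuilder

# Sketch of a new case for ivalue_import/quantization.py: importing a module
# that holds a quantized Conv2d, whose weights live in ConvPackedParams.
class TestModule(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.conv = torch.nn.quantized.Conv2d(1, 8, 3)

    def forward(self, x):
        return self.conv(x)

mb = ModuleBuilder()
mb.import_module(torch.jit.script(TestModule())._c)
mb.module.operation.print()
```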
After that, a rough sketch of the work that is needed is:
- Adding support for importing Conv packed params (new types and ops). This should be quite similar to what was done for LinearPackedParams in https://github.com/llvm/torch-mlir/commit/d66e8fe1f8e0a4a6a2bbdbb532feed0ba33d34ce
- After that is done, we will need to see what the ideal lowering for the backends would be, and implement that. I don't know right now what the proper answer will be.
@sjarus -- what would be the ideal IR that you would like to see before TorchToTosa for quantized layers? For a simple linear layer, we currently provide IR like this: https://gist.github.com/silvasean/0ac7e4da604ebfeb8e06b056781103f7 -- how would you like that to be massaged, if at all, before it gets to you?
Also, @JasonMaojinsong -- are you interested in LINALG_ON_TENSORS, TOSA, or TORCH OutputType for your compiler?
@silvasean Thanks for your guidance, that's great. Let me try and move forward with Conv packed params.
Regarding the output type (LINALG_ON_TENSORS, TOSA, or TORCH): maybe LINALG_ON_TENSORS can provide the better IR level. TOSA currently uses tosa.rescale to represent the scale factors for activations and weights, and uses the "quantization_info" attribute on conv/matmul to represent the zero points. I am not sure how much work is needed to support the above output types.
@silvasean and @JasonMaojinsong: The quantization_info metadata is constructed at compile time when going from the framework to TOSA; we would construct it in the TorchToTosa pass in this case. I have this partly implemented locally for quantized.linear. The same pathway already exists along the TFLite -> TOSA -> linalg-on-tensors path, so the Torch path would simply leverage the existing plumbing beneath TOSA once the quantized.linear legalization is in place.
@sjarus Sounds great. When will you check in this code? Let me try it once you have some of it finished.