Unable to compile PyTorch quantized models with torch-mlir due to CustomClass
What happened?
I'm trying to convert a quantized ResNet-50 model (using PyTorch's quantization-aware training) to MLIR format, but I'm encountering an error when torch-mlir attempts to import the model.
Environment
- python: 3.11.3
- pytorch: 2.7.1+cpu
- torch-mlir: 1983b6db9cdba2dfef4939a4fb521e9ac25dc08e
Steps to reproduce the issue
import torch
import torch.nn as nn
import torch.quantization
import torchvision
import torchvision.transforms as transforms
import torch_mlir.torchscript
import os
transform = transforms.Compose([
transforms.Resize(256),
transforms.CenterCrop(224),
transforms.ToTensor(),
transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])
train_dataset = torchvision.datasets.FakeData(transform=transform)
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=32, shuffle=True)
model = torchvision.models.quantization.resnet50(pretrained=True, quantize=False)
model.qconfig = torch.quantization.get_default_qat_qconfig("fbgemm")
torch.quantization.prepare_qat(model, inplace=True)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)
model.train()
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=1e-3, momentum=0.9)
# QAT
for epoch in range(1):
for images, targets in train_loader:
images, targets = images.to(device), targets.to(device)
optimizer.zero_grad()
outputs = model(images)
loss = criterion(outputs, targets)
loss.backward()
optimizer.step()
print(f"[Epoch {epoch+1}] QAT training complete")
model.eval()
quantized_model = torch.quantization.convert(model.eval(), inplace=False)
print(quantized_model)
scripted_model = torch.jit.script(quantized_model)
example_input = torch.randn(1, 3, 224, 224)
mlir_module = torch_mlir.torchscript.compile(
scripted_model,
[example_input],
output_type="torch"
)
with open('resnet50_int8.mlir', 'w') as f:
f.write(str(mlir_module))
log:
QuantizableResNet(
(conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
(bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(relu): ReLU()
(maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
(layer1): Sequential(
(0): QuantizableBottleneck(
(conv1): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
(bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(conv3): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
(bn3): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(relu): ReLU()
(downsample): Sequential(
(0): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
(1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
)
(skip_add_relu): FloatFunctional(
(activation_post_process): Identity()
)
(relu1): ReLU()
(relu2): ReLU()
)
...
Traceback (most recent call last):
File "/workspace/qat_resnet50.py", line 57, in <module>
mlir_module = torch_mlir.torchscript.compile(
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/path/to/torch-mlir/build/python_packages/torch_mlir/torch_mlir/torchscript.py", line 282, in compile
raise Exception(
Exception:
PyTorch TorchScript module -> torch-mlir Object Graph IR import failed with:
### Importer C++ Exception:
see diagnostics
### Importer Diagnostics:
error: unable to import Torch CustomClass type '0x5baca5da1580' to MLIR type
Is there currently any supported way to convert a quantized model to MLIR?
torchscript.compile has been deprecated for a while. Have you tried the newer fx.export_and_import API? See https://github.com/llvm/torch-mlir/blob/60ffb919b465a9cce77e5c4454ae5958e9350fd8/projects/pt1/examples/fximporter_resnet18.py#L32 for an example.
I encountered another error:
mlir_module = export_and_import(
quantized_model,
example_input,
output_type="torch"
)
Traceback (most recent call last):
File "/workspace/resnet50.py", line 49, in <module>
mlir_module = export_and_import(
^^^^^^^^^^^^^^^^^^
File "/usr/local/lib/python3.11/site-packages/torch_mlir/fx.py", line 98, in export_and_import
prog = torch.export.export(
^^^^^^^^^^^^^^^^^^^^
File "/usr/local/lib/python3.11/site-packages/torch/export/__init__.py", line 319, in export
raise e
File "/usr/local/lib/python3.11/site-packages/torch/export/__init__.py", line 286, in export
return _export(
^^^^^^^^
File "/usr/local/lib/python3.11/site-packages/torch/export/_trace.py", line 1164, in wrapper
raise e
File "/usr/local/lib/python3.11/site-packages/torch/export/_trace.py", line 1130, in wrapper
ep = fn(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^
File "/usr/local/lib/python3.11/site-packages/torch/export/exported_program.py", line 123, in wrapper
return fn(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^
File "/usr/local/lib/python3.11/site-packages/torch/export/_trace.py", line 2176, in _export
ep = _export_for_training(
^^^^^^^^^^^^^^^^^^^^^
File "/usr/local/lib/python3.11/site-packages/torch/export/_trace.py", line 1164, in wrapper
raise e
File "/usr/local/lib/python3.11/site-packages/torch/export/_trace.py", line 1130, in wrapper
ep = fn(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^
File "/usr/local/lib/python3.11/site-packages/torch/export/exported_program.py", line 123, in wrapper
return fn(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^
File "/usr/local/lib/python3.11/site-packages/torch/export/_trace.py", line 2037, in _export_for_training
export_artifact = export_func(
^^^^^^^^^^^^
File "/usr/local/lib/python3.11/site-packages/torch/export/_trace.py", line 1968, in _non_strict_export
with (
File "/usr/local/lib/python3.11/contextlib.py", line 137, in __enter__
return next(self.gen)
^^^^^^^^^^^^^^
File "/usr/local/lib/python3.11/site-packages/torch/_export/non_strict_utils.py", line 915, in _fakify_script_objects
fake_script_obj = _maybe_fakify_obj(obj)
^^^^^^^^^^^^^^^^^^^^^^
File "/usr/local/lib/python3.11/site-packages/torch/_export/non_strict_utils.py", line 899, in _maybe_fakify_obj
fake_obj = torch._library.fake_class_registry.maybe_to_fake_obj(fake_mode, obj)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/usr/local/lib/python3.11/site-packages/torch/_library/fake_class_registry.py", line 142, in maybe_to_fake_obj
flat_x = x.__obj_flatten__() # type: ignore[attr-defined]
^^^^^^^^^^^^^^^^^
AttributeError: __torch__.torch.classes.quantized.Conv2dPackedParamsBase (of Python compilation unit at: 0) does not have a field with name '__obj_flatten__'
I have exactly the same problem when quantizing a model and then exporting it with fx.export_and_import:
Error during torch-mlir export: __torch__.torch.classes.quantized.Conv2dPackedParamsBase (of Python compilation unit at: 0) does not have a field with name '__obj_flatten__'
This happens when applying torch.ao.quantization.quantize_fx.prepare_fx and then convert_fx before calling export_and_import.
The currently supported path for importing PyTorch models into MLIR, fx.export_and_import, requires the PyTorch model to be exportable to the torch.export.ExportedProgram format via the torch.export.export API. Models quantized with the new PT2 export workflows (for example, https://docs.pytorch.org/ao/stable/tutorials_source/pt2e_quant_qat.html) can be converted to MLIR (see https://github.com/llvm/torch-mlir/issues/4356).
The quantization approach used in the repro does not appear to use the new PT2 export path, so it is unlikely that torch-mlir can import this model into MLIR. Also note that the error you ran into appears to come from the torch.export.export call itself, which is outside the scope of this repo; you can file an issue in the PyTorch GitHub repository. I suggest trying the new PT2 export path for quantizing the model.