
Unable to compile PyTorch quantized models with torch-mlir due to CustomClass

Open jysh1214 opened this issue 4 months ago • 4 comments

What happened?

I'm trying to convert a quantized ResNet-50 model (using PyTorch's quantization-aware training) to MLIR format, but I'm encountering an error when torch-mlir attempts to import the model.

Environment

  • python: 3.11.3
  • pytorch: 2.7.1+cpu
  • torch-mlir: 1983b6db9cdba2dfef4939a4fb521e9ac25dc08e

Steps to reproduce the issue

import torch
import torch.nn as nn
import torch.quantization
import torchvision
import torchvision.transforms as transforms
import torch_mlir.torchscript

transform = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])
train_dataset = torchvision.datasets.FakeData(transform=transform)
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=32, shuffle=True)

model = torchvision.models.quantization.resnet50(pretrained=True, quantize=False)

# Attach the default QAT qconfig and insert fake-quant/observer modules.
model.qconfig = torch.quantization.get_default_qat_qconfig("fbgemm")
torch.quantization.prepare_qat(model, inplace=True)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)
model.train()

criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=1e-3, momentum=0.9)

# Quantization-aware training loop
for epoch in range(1):
    for images, targets in train_loader:
        images, targets = images.to(device), targets.to(device)
        optimizer.zero_grad()
        outputs = model(images)
        loss = criterion(outputs, targets)
        loss.backward()
        optimizer.step()
    print(f"[Epoch {epoch+1}] QAT training complete")

# Convert the fake-quantized QAT model to a real int8 model.
# Quantized kernels are CPU-only, so move the model off the GPU first.
model = model.cpu().eval()
quantized_model = torch.quantization.convert(model, inplace=False)

print(quantized_model)

scripted_model = torch.jit.script(quantized_model)
example_input = torch.randn(1, 3, 224, 224)

mlir_module = torch_mlir.torchscript.compile(
    scripted_model, 
    [example_input], 
    output_type="torch"
)

with open('resnet50_int8.mlir', 'w') as f:
    f.write(str(mlir_module))

Log:

QuantizableResNet(
  (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
  (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (relu): ReLU()
  (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
  (layer1): Sequential(
    (0): QuantizableBottleneck(
      (conv1): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv3): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn3): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU()
      (downsample): Sequential(
        (0): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
      (skip_add_relu): FloatFunctional(
        (activation_post_process): Identity()
      )
      (relu1): ReLU()
      (relu2): ReLU()
    )
...

Traceback (most recent call last):
  File "/workspace/qat_resnet50.py", line 57, in <module>
    mlir_module = torch_mlir.torchscript.compile(
                  ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/path/to/torch-mlir/build/python_packages/torch_mlir/torch_mlir/torchscript.py", line 282, in compile
    raise Exception(
Exception: 
PyTorch TorchScript module -> torch-mlir Object Graph IR import failed with:
### Importer C++ Exception:
see diagnostics
### Importer Diagnostics:
error: unable to import Torch CustomClass type '0x5baca5da1580' to MLIR type

Is there currently any supported way to convert a quantized model to MLIR?

jysh1214 avatar Aug 04 '25 09:08 jysh1214

torchscript.compile has been deprecated for a while; have you tried the newer fx.export_and_import API? See https://github.com/llvm/torch-mlir/blob/60ffb919b465a9cce77e5c4454ae5958e9350fd8/projects/pt1/examples/fximporter_resnet18.py#L32
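Roughly like this (untested sketch; resnet18 stands in for your model):

import torch
import torchvision
from torch_mlir import fx

model = torchvision.models.resnet18(weights=None).eval()
example_input = torch.randn(1, 3, 224, 224)

# export_and_import runs torch.export under the hood and then lowers
# the exported program to the requested output type.
module = fx.export_and_import(model, example_input, output_type="torch")
print(module)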

sahas3 avatar Aug 07 '25 12:08 sahas3

I encountered another error:

from torch_mlir.fx import export_and_import

mlir_module = export_and_import(
    quantized_model,
    example_input,
    output_type="torch",
)
Traceback (most recent call last):
  File "/workspace/resnet50.py", line 49, in <module>
    mlir_module = export_and_import(
                  ^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.11/site-packages/torch_mlir/fx.py", line 98, in export_and_import
    prog = torch.export.export(
           ^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.11/site-packages/torch/export/__init__.py", line 319, in export
    raise e
  File "/usr/local/lib/python3.11/site-packages/torch/export/__init__.py", line 286, in export
    return _export(
           ^^^^^^^^
  File "/usr/local/lib/python3.11/site-packages/torch/export/_trace.py", line 1164, in wrapper
    raise e
  File "/usr/local/lib/python3.11/site-packages/torch/export/_trace.py", line 1130, in wrapper
    ep = fn(*args, **kwargs)
         ^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.11/site-packages/torch/export/exported_program.py", line 123, in wrapper
    return fn(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.11/site-packages/torch/export/_trace.py", line 2176, in _export
    ep = _export_for_training(
         ^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.11/site-packages/torch/export/_trace.py", line 1164, in wrapper
    raise e
  File "/usr/local/lib/python3.11/site-packages/torch/export/_trace.py", line 1130, in wrapper
    ep = fn(*args, **kwargs)
         ^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.11/site-packages/torch/export/exported_program.py", line 123, in wrapper
    return fn(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.11/site-packages/torch/export/_trace.py", line 2037, in _export_for_training
    export_artifact = export_func(
                      ^^^^^^^^^^^^
  File "/usr/local/lib/python3.11/site-packages/torch/export/_trace.py", line 1968, in _non_strict_export
    with (
  File "/usr/local/lib/python3.11/contextlib.py", line 137, in __enter__
    return next(self.gen)
           ^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.11/site-packages/torch/_export/non_strict_utils.py", line 915, in _fakify_script_objects
    fake_script_obj = _maybe_fakify_obj(obj)
                      ^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.11/site-packages/torch/_export/non_strict_utils.py", line 899, in _maybe_fakify_obj
    fake_obj = torch._library.fake_class_registry.maybe_to_fake_obj(fake_mode, obj)
               ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.11/site-packages/torch/_library/fake_class_registry.py", line 142, in maybe_to_fake_obj
    flat_x = x.__obj_flatten__()  # type: ignore[attr-defined]
             ^^^^^^^^^^^^^^^^^
AttributeError: __torch__.torch.classes.quantized.Conv2dPackedParamsBase (of Python compilation unit at: 0) does not have a field with name '__obj_flatten__'

jysh1214 avatar Aug 11 '25 15:08 jysh1214

I have exactly the same problem when trying to quantize and then export with fx.export_and_import:

Error during torch-mlir export: __torch__.torch.classes.quantized.Conv2dPackedParamsBase (of Python compilation unit at: 0) does not have a field with name '__obj_flatten__'

This happens when applying torch.ao.quantization.quantize_fx.prepare_fx and then convert_fx before calling export_and_import, roughly as in the sketch below.
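
Minimal sketch of the flow (the real model and calibration data differ):

import torch
import torchvision
from torch.ao.quantization import get_default_qconfig_mapping
from torch.ao.quantization.quantize_fx import prepare_fx, convert_fx
from torch_mlir.fx import export_and_import

model = torchvision.models.resnet50(weights=None).eval()
example_inputs = (torch.randn(1, 3, 224, 224),)

# FX graph mode quantization (the pre-PT2E workflow).
qconfig_mapping = get_default_qconfig_mapping("fbgemm")
prepared = prepare_fx(model, qconfig_mapping, example_inputs)
prepared(*example_inputs)  # calibration stand-in
quantized = convert_fx(prepared)

# Fails as above: convert_fx packs weights into Conv2dPackedParamsBase
# script objects, which torch.export cannot fakify.
module = export_and_import(quantized, *example_inputs, output_type="torch")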

copparihollmann avatar Oct 20 '25 02:10 copparihollmann

The currently supported path for importing PyTorch models into MLIR via fx.export_and_import requires the model to be exportable to the torch.export.ExportedProgram format via the torch.export.export API. Models quantized with the new PT2 export (PT2E) workflows (for example, https://docs.pytorch.org/ao/stable/tutorials_source/pt2e_quant_qat.html) can be converted to MLIR (see https://github.com/llvm/torch-mlir/issues/4356).

The quantization approach in the repro doesn't use the new PT2 export path, so it is unlikely that torch-mlir can support importing this model into MLIR. Also note that the error you ran into comes from the torch.export.export call itself, which is outside the scope of this repo -- you can file an issue on the PyTorch GitHub. I suggest quantizing the model via the new PT2 export path instead, roughly as sketched below.
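
For illustration, a rough sketch of the PT2E QAT flow from that tutorial (untested; the capture and quantizer entry points have moved between PyTorch releases, and the QAT training loop is elided):

import torch
import torchvision
from torch.ao.quantization.quantize_pt2e import prepare_qat_pt2e, convert_pt2e
from torch.ao.quantization.quantizer.xnnpack_quantizer import (
    XNNPACKQuantizer,
    get_symmetric_quantization_config,
)
from torch_mlir.fx import export_and_import

model = torchvision.models.resnet50(weights=None)
example_inputs = (torch.randn(1, 3, 224, 224),)

# Capture the model as an FX graph via the PT2 export path.
captured = torch.export.export_for_training(model, example_inputs).module()

# Insert fake-quant ops according to the quantizer's annotations.
quantizer = XNNPACKQuantizer().set_global(
    get_symmetric_quantization_config(is_qat=True)
)
prepared = prepare_qat_pt2e(captured, quantizer)

# ... run QAT on `prepared` here ...

quantized = convert_pt2e(prepared)
torch.ao.quantization.move_exported_model_to_eval(quantized)

# The result is a plain GraphModule with quantize/dequantize ops
# (no packed-params custom classes), so the import can go through.
module = export_and_import(quantized, *example_inputs, output_type="torch")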

sahas3 avatar Nov 05 '25 13:11 sahas3