torch-mlir
F.interpolate fails due to a missing shape transfer function in shape_lib_gen.py
Reproducible code:
import torch.nn as nn
import torch.nn.functional as F
import torch
import torch_mlir  # torch-mlir==20220823.574
from torch_mlir_e2e_test.linalg_on_tensors_backends import refbackend

class Model(nn.Module):
    def __init__(self):
        super().__init__()

    def forward(self, x):
        x = F.interpolate(x, size=torch.Size([8, 8]), mode="nearest")
        return x

inputs = torch.randn(1, 64, 4, 4)
model = Model()
results = model(inputs)  # torch.Size([1, 64, 4, 4]) -> torch.Size([1, 64, 8, 8])
traced = torch.jit.trace(model, (inputs, ))
linalg_on_tensors_mlir = torch_mlir.compile(traced, [inputs, ], output_type=torch_mlir.OutputType.LINALG_ON_TENSORS)

# # run using torch_mlir
# backend = refbackend.RefBackendLinalgOnTensorsBackend()
# compiled = backend.compile(linalg_on_tensors_mlir)
# jit_module = backend.load(compiled)
# results_mlir = jit_module.forward(inputs.numpy())
# from numpy.testing import assert_almost_equal
# assert_almost_equal(results.detach().numpy(), results_mlir, decimal=5)
Error message:
    linalg_on_tensors_mlir = torch_mlir.compile(traced, [inputs, ], output_type=torch_mlir.OutputType.LINALG_ON_TENSORS)
  File "/home/username/venv/mlir_venv/lib/python3.8/site-packages/torch_mlir/__init__.py", line 247, in compile
    run_pipeline_with_repro_report(
  File "/home/username/venv/mlir_venv/lib/python3.8/site-packages/torch_mlir/compiler_utils.py", line 73, in run_pipeline_with_repro_report
    raise TorchMlirCompilerError(trimmed_message) from None
torch_mlir.compiler_utils.TorchMlirCompilerError: Lowering TorchScript IR -> Torch Backend IR failed with the following diagnostics:
error: unsupported by backend contract: tensor with unknown rank
note: see current operation: %2 = "torch.tensor_static_info_cast"(%arg0) : (!torch.vtensor<[1,64,4,4],f32>) -> !torch.vtensor<*,f32>
note: this is likely due to a missing shape transfer function in shape_lib_gen.py
Error can be reproduced with:
$ torch-mlir-opt -pass-pipeline='torchscript-module-to-torch-backend-pipeline{backend-legal-ops=torch.aten.flatten.using_ints}' /tmp/Model.mlir
Add '-mlir-print-ir-after-all -mlir-disable-threading' to get the IR dump for debugging purpose.
I'm hitting this error when compiling a model that uses F.interpolate with torch-mlir. Can someone help me fix it? Thanks!
Here's our guide for adding a new op! https://github.com/llvm/torch-mlir/wiki/Torch-ops-E2E-implementation
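For readers unfamiliar with the term in the diagnostic: a shape transfer function computes an op's output shape from its input shapes and arguments, without touching tensor data. Below is a minimal plain-Python sketch of that idea for a nearest-neighbour 2D upsample. The function name and signature are hypothetical, chosen for illustration only. This is not the actual entry in shape_lib_gen.py, and the naming and registration conventions torch-mlir expects are described in the guide linked above.

```python
from typing import List, Optional


def upsample_nearest2d_shape(
    input_shape: List[int],
    output_size: List[int],
    scales_h: Optional[float] = None,  # unused in this sketch
    scales_w: Optional[float] = None,  # unused in this sketch
) -> List[int]:
    """Hypothetical shape function: returns the output shape of a 2D upsample.

    Assumes an NCHW input and an explicit (height, width) output_size.
    """
    assert len(input_shape) == 4, "expected a rank-4 NCHW input"
    assert len(output_size) == 2, "expected output_size = [height, width]"
    # Batch and channel dimensions pass through; spatial dims come from output_size.
    return [input_shape[0], input_shape[1], output_size[0], output_size[1]]


# Shapes from the repro: a [1, 64, 4, 4] input upsampled to 8 x 8.
out_shape = upsample_nearest2d_shape([1, 64, 4, 4], [8, 8])
print(out_shape)
```

When a function of this kind is missing for an op, the output rank cannot be inferred, which matches the "tensor with unknown rank" diagnostic in the report.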
Thanks for the fix, @powderluv! (For anyone else hitting this: use torch.ops.aten.upsample_nearest2d in place of F.interpolate.)
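A minimal sketch of what that change might look like in the repro model, assuming the suggestion means calling the ATen op directly with the target height and width. Only the eager-mode behaviour is checked here; whether a given torch-mlir build then compiles the traced module is not verified by this snippet.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class Model(nn.Module):
    def forward(self, x):
        # Call the ATen op directly instead of going through F.interpolate.
        # The second argument is the target [height, width], as in the repro.
        return torch.ops.aten.upsample_nearest2d(x, [8, 8])


inputs = torch.randn(1, 64, 4, 4)
model = Model()
results = model(inputs)

# Sanity check in eager mode against the original F.interpolate call.
reference = F.interpolate(inputs, size=[8, 8], mode="nearest")
assert results.shape == reference.shape
assert torch.equal(results, reference)

traced = torch.jit.trace(model, (inputs,))
# The torch_mlir.compile call from the original report would then be applied
# to `traced` unchanged.
```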