Error when running `torch-mlir-opt` on a network with `BatchNorm1d`
Hi, I am fairly new to torch-mlir and I ran into an issue when trying to use `torch-mlir-opt`. I have a script that exports my network (based on Transformers), which I run with `python export_model_mlir.py > exported_model.mlir`:
```python
import torch
from torch_mlir.dialects.torch.importer.jit_ir import ClassAnnotator, ModuleBuilder

from model.net import ConstituentNet

script_module = torch.jit.script(ConstituentNet())

ca = ClassAnnotator()
ca.exportNone(script_module._c._type())
ca.exportPath(script_module._c._type(), ["forward"])
ca.annotateArgs(
    script_module._c._type(),
    ["forward"],
    [None, ([1, 30, 16], torch.float32, True)],
)

mb = ModuleBuilder()
mb.import_module(script_module._c, ca)
mb.module.operation.print()
```
I then tried to run:
```shell
torch-mlir-opt \
  -torchscript-module-to-torch-backend-pipeline="optimize=true" \
  -torch-backend-to-tosa-backend-pipeline="optimize=true" \
  < exported_model.mlir > model.mlir
```
but I got the following error:
```
<stdin>:169:14: error: 'torch.copy.to_vtensor' op failed to verify that operand is corresponding !torch.tensor
          %186 = torch.aten.dim %arg1 : !torch.tensor -> !torch.int
             ^
<stdin>:493:12: note: called from
        %201 = torch.prim.CallMethod %197["forward"] (%193) : !torch.nn.Module<"__torch__.model.layer.Transformer">, (!torch.tensor) -> !torch.tensor
           ^
<stdin>:169:14: note: see current operation: %486 = "torch.copy.to_vtensor"(%112) : (!torch.tensor<[?,?,2],unk>) -> !torch.vtensor
          %186 = torch.aten.dim %arg1 : !torch.tensor -> !torch.int
```
Here is the fragment of MLIR that seems to cause the issue (line 169 from the error is inside it); if I understand correctly, it is simply the input-dimension check of `BatchNorm1d`:
```mlir
159   func.func private @__torch__.torch.nn.modules.batchnorm.___torch_mangle_3.BatchNorm1d._check_input_dim(%arg0: !torch.nn.Module<"__torch__.torch.nn.modules.batchnorm.___torch_mangle_3.BatchNorm1d">, %arg1: !torch.tensor) -> !torch.none {
160     %none_0 = torch.constant.none
161     %str = torch.constant.str "builtins.ValueError"
162     %str_1 = torch.constant.str "expected 2D or 3D input (got {}D input)"
163     %false_2 = torch.constant.bool false
164     %int2_3 = torch.constant.int 2
165     %int3 = torch.constant.int 3
166     %183 = torch.aten.dim %arg1 : !torch.tensor -> !torch.int
167     %184 = torch.aten.ne.int %183, %int2_3 : !torch.int, !torch.int -> !torch.bool
168     %185 = torch.prim.If %184 -> (!torch.bool) {
169       %186 = torch.aten.dim %arg1 : !torch.tensor -> !torch.int
170       %187 = torch.aten.ne.int %186, %int3 : !torch.int, !torch.int -> !torch.bool
171       torch.prim.If.yield %187 : !torch.bool
172     } else {
173       torch.prim.If.yield %false_2 : !torch.bool
174     }
175     torch.prim.If %185 -> () {
176       %186 = torch.aten.dim %arg1 : !torch.tensor -> !torch.int
177       %187 = torch.aten.format(%str_1, %186) : !torch.str, !torch.int -> !torch.str
178       torch.prim.RaiseException %187, %str : !torch.str, !torch.str
179       torch.prim.If.yield
180     } else {
181       torch.prim.If.yield
182     }
183     return %none_0 : !torch.none
184   }
```
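For context, the IR above is the lowered form of the `_check_input_dim` method that PyTorch's `BatchNorm1d` runs on every forward pass. In plain Python, the logic is roughly the following (a sketch of the check itself, not of anything torch-mlir-specific):

```python
def check_input_dim(ndim: int) -> None:
    """Plain-Python sketch of the dimension check the MLIR above implements:
    BatchNorm1d accepts only 2D (N, C) or 3D (N, C, L) inputs."""
    # The short-circuit `and` is what becomes the nested torch.prim.If:
    # compare against 2 first, and only if that differs compare against 3.
    if ndim != 2 and ndim != 3:
        raise ValueError("expected 2D or 3D input (got {}D input)".format(ndim))
```

Note that the annotated input shape `[1, 30, 16]` is 3D, so the check itself would pass at runtime; the verifier failure happens while compiling this function, not while evaluating it.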
I use `BatchNorm1d` in a few places in my network and I don't know how to determine which instance is causing the problem, but maybe from the MLIR code you can give me an idea of what might be going wrong.
I am not sure whether this is relevant, but I am using a slightly older commit (ea371a9), as I am following the steps listed in another project.
It's hard to tell without being able to reproduce the issue. This does look like something going wrong inside Torch-MLIR (a verifier failure), but I can't help beyond that without more information. Could you try to reproduce this at HEAD and provide a complete repro script?
Closing this as not reproducible.