tvm
[Bug] ConstantOfShape as shape input fails when using the "freeze_params=False" option
I used relay.frontend.from_onnx to import an ONNX model exported from PyTorch. The import works fine when freeze_params is set to True (the default).
When I tried to keep the weights separate from the other constants by setting freeze_params=False, I hit the same issue as #11783.
Here is the traceback:
    mod, params = relay.frontend.from_onnx(graph_def, freeze_params=False)
  File "python/tvm/relay/frontend/onnx.py", line 5816, in from_onnx
    mod, params = g.from_onnx(graph, opset)
  File "python/tvm/relay/frontend/onnx.py", line 5486, in from_onnx
    self._construct_nodes(graph)
  File "python/tvm/relay/frontend/onnx.py", line 5598, in _construct_nodes
    op = self._convert_operator(op_name, inputs, attr, self.opset)
  File "python/tvm/relay/frontend/onnx.py", line 5709, in _convert_operator
    sym = convert_map[op_name](inputs, attrs, self._params)
  File "python/tvm/relay/frontend/onnx.py", line 2732, in _impl_v8
    shape = fold_constant(expand_shape(in_shape, shape))
  File "python/tvm/relay/frontend/onnx.py", line 2701, in expand_shape
    if in_dims < new_dims:
  File "python/tvm/tir/expr.py", line 185, in __bool__
    return self.__nonzero__()
  File "python/tvm/tir/expr.py", line 180, in __nonzero__
    "Cannot use and / or / not operator to Expr, hint: "
ValueError: Cannot use and / or / not operator to Expr, hint: use tvm.tir.all / tvm.tir.any instead
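The relevant function in python/tvm/relay/frontend/onnx.py (the failing comparison is at line 2701):
def expand_shape(in_shape, shape):
    """A function that expands the shape when the rank is lower than that of the given
    input. Also it replaces the extent of the shape with the corresponding extent
    of the input when it is 1.
    """
    in_dims = infer_shape(in_shape)[0]
    new_dims = infer_shape(shape)[0]

    if in_dims < new_dims:  # line 2701, where the ValueError is raised
        ...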
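To make the docstring concrete, here is a minimal pure-Python sketch of the broadcast that expand_shape describes, applied to known shapes. The name expand_shape_static is hypothetical and it works on plain lists of ints rather than Relay expressions; it only illustrates the rank padding and the replacement of extents equal to 1. Note that the rank comparison requires both ranks to be known integers.

```python
from typing import List, Sequence


def expand_shape_static(in_shape: Sequence[int], shape: Sequence[int]) -> List[int]:
    """Concrete-shape sketch of the broadcast that expand_shape describes.

    Hypothetical helper: works on plain Python ints, not Relay expressions.
    """
    in_shape, shape = list(in_shape), list(shape)
    in_dims = len(in_shape)  # rank of the Expand input
    new_dims = len(shape)    # rank requested by the target shape tensor
    # Comparing the ranks only works because both are plain ints here.
    if in_dims < new_dims:
        in_shape = [1] * (new_dims - in_dims) + in_shape
    elif new_dims < in_dims:
        shape = [1] * (in_dims - new_dims) + shape
    result = []
    for a, b in zip(in_shape, shape):
        # Bidirectional (NumPy-style) broadcasting: equal extents, or one of them is 1.
        if a != b and a != 1 and b != 1:
            raise ValueError(f"cannot broadcast extent {a} with {b}")
        result.append(b if a == 1 else a)
    return result


print(expand_shape_static([3, 1], [2, 1, 6]))  # [2, 3, 6]
```

In the failing import, the second rank is not an int at all, so this comparison cannot be evaluated.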
In the failing call, the "shape" argument that expand_shape passes to infer_shape is:
free_var %onnx::ConstantOfShape_139: Tensor[(1), int64];
dyn.full(1, %onnx::ConstantOfShape_139, shape=None, dtype="int64")
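This gives in_dims = 3 but new_dims = ?, an unknown (Any) dimension. The comparison in_dims < new_dims therefore produces a symbolic TIR expression instead of a bool, and evaluating it in the if statement raises the ValueError above.
It seems that freeze_params=False keeps the input of the ConstantOfShape op as a free variable (%onnx::ConstantOfShape_139) instead of a constant, so its value, and therefore the length of the shape tensor produced by ConstantOfShape, is not available during shape inference.
The missing information is arguably expected: running ONNX's own shape inference (onnx.shape_inference.infer_shapes) on the model also leaves a dimension unknown, giving float32[20,12,unk__0].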
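The chain of events can be imitated without TVM. In this pure-Python sketch, SymbolicExpr, AnyDim, constant_of_shape_length, expand_rank and expand_rank_guarded are all hypothetical names, not TVM APIs. The assumed model: when the ConstantOfShape input is a free variable, the length of the resulting shape tensor (and so the Expand target rank) is unknown; comparing the known input rank with it yields a symbolic expression, and using that expression in an if statement raises ValueError. The guarded variant only shows one way to report the problem more clearly; it is not the upstream fix.

```python
class SymbolicExpr:
    """Hypothetical stand-in for an unevaluated TIR comparison (not TVM's class)."""

    def __init__(self, text: str):
        self.text = text

    def __repr__(self) -> str:
        return self.text

    def __bool__(self):
        # A symbolic comparison has no Python truth value, so `if expr:` raises,
        # similar to the __bool__/__nonzero__ frames in the traceback.
        raise ValueError("Cannot use and / or / not operator to Expr")


class AnyDim:
    """Hypothetical stand-in for an unknown ("?") dimension."""

    def __repr__(self) -> str:
        return "?"

    def __gt__(self, other):
        # Python evaluates `3 < AnyDim()` as `AnyDim().__gt__(3)`.
        return SymbolicExpr(f"{other} < ?")


def constant_of_shape_length(shape_input_value):
    """Length of the 1-D output of ConstantOfShape fed by a Tensor[(1), int64].

    None models a free variable (freeze_params=False): the value is unknown.
    """
    if shape_input_value is None:
        return AnyDim()
    return int(shape_input_value[0])


def expand_rank(in_dims, new_dims):
    # Mirrors the failing `if in_dims < new_dims:` line in expand_shape.
    if in_dims < new_dims:
        return new_dims
    return in_dims


def expand_rank_guarded(in_dims, new_dims):
    # Hypothetical guard: reject an unknown rank before comparing it.
    if not isinstance(new_dims, int):
        raise NotImplementedError(
            "Expand target rank is unknown because its shape input is not a "
            "constant; importing with freeze_params=True avoids this"
        )
    return max(in_dims, new_dims)


print(expand_rank(3, constant_of_shape_length([4])))  # 4
try:
    expand_rank(3, constant_of_shape_length(None))
except ValueError as err:
    print("ValueError:", err)
```

With a known value (the freeze_params=True case) the comparison is an ordinary int comparison; with an unknown one it fails in the same way as the traceback.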
Environment
Reproduced on both the TVM 0.9.0 branch and the main branch, with PyTorch 1.12.0 and ONNX 1.12.0. Reproducing the error requires a ConstantOfShape op followed by an Expand op.