coremltools
coremltools copied to clipboard
[ExecuTorch] Cannot Use Bool Index
This toy model fails in ExecuTorch:
class IndexModel(torch.nn.Module):
    """Toy model that indexes ``x`` with a boolean mask along one axis.

    The mask is computed inside ``forward`` as ``y > 0.5``, so the size of
    the result along the indexed axis depends on the values in ``y``.
    """

    def __init__(self, axis):
        """Store the axis of ``x`` that the boolean mask is applied to.

        Args:
            axis: 0, 1 or 2 selects that dimension; any other value makes
                ``forward`` index the fourth dimension (index 3).
        """
        super().__init__()
        self.axis = axis

    def forward(self, x, y):
        """Return the entries of ``x`` selected by ``y > 0.5`` along ``self.axis``.

        Args:
            x: Tensor to index.
            y: Tensor compared against 0.5 to build the boolean mask.
                NOTE(review): presumably 1-D with length equal to the size
                of ``x`` along the indexed axis -- confirm against callers.

        Returns:
            The result of indexing ``x`` with the boolean mask on the
            configured axis.
        """
        index = y > 0.5
        if self.axis == 0:
            return x[index]
        elif self.axis == 1:
            return x[:, index]
        elif self.axis == 2:
            return x[:, :, index]
        else:
            # Every other axis value falls through to the fourth dimension.
            return x[:, :, :, index]
due to the following error:
torch._export.verifier.SpecViolationError: Operator torch._ops.aten._assert_async.msg is not Aten Canonical.
Judging from the ATen-dialect exported program below, this assertion looks like a bounds check introduced by the boolean index `index = y > 0.5`:
ExportedProgram:
class GraphModule(torch.nn.Module):
def forward(self, arg0_1: "f32[10]", arg1_1: "f32[10]"):
# File: /Volumes/data/Software/Mine/coremltools-github_non-param-non-buffer-const/coremltools/converters/mil/frontend/torch/test/test_torch_ops.py:8066 in forward, code: index = y > 0.5
gt: "b8[10]" = torch.ops.aten.gt.Scalar(arg1_1, 0.5); arg1_1 = None
# File: /Volumes/data/Software/Mine/coremltools-github_non-param-non-buffer-const/coremltools/converters/mil/frontend/torch/test/test_torch_ops.py:8068 in forward, code: return x[index]
index: "f32[u4]" = torch.ops.aten.index.Tensor(arg0_1, [gt]); arg0_1 = gt = None
# File: /Users/yifanshensz/Software/Mine/coremltools-github_non-param-non-buffer-const/envs/coremltools-github_non-param-non-buffer-const-py3.10/lib/python3.10/site-packages/torch/_export/pass_base.py:54 in _create_dummy_node_metadata, code: return NodeMetadata({"stack_trace": "".join(traceback.format_stack(limit=1))})
sym_size: "Sym(u4)" = torch.ops.aten.sym_size.int(index, 0)
ge: "Sym(u4 >= 0)" = sym_size >= 0
scalar_tensor: "f32[]" = torch.ops.aten.scalar_tensor.default(ge); ge = None
_assert_async = torch.ops.aten._assert_async.msg(scalar_tensor, 'index.shape[0] is outside of inline constraint [0, 10].'); scalar_tensor = None
le: "Sym(u4 <= 10)" = sym_size <= 10; sym_size = None
scalar_tensor_1: "f32[]" = torch.ops.aten.scalar_tensor.default(le); le = None
_assert_async_1 = torch.ops.aten._assert_async.msg(scalar_tensor_1, 'index.shape[0] is outside of inline constraint [0, 10].'); scalar_tensor_1 = None
return (index,)
Graph signature: ExportGraphSignature(input_specs=[InputSpec(kind=<InputKind.USER_INPUT: 1>, arg=TensorArgument(name='arg0_1'), target=None, persistent=None), InputSpec(kind=<InputKind.USER_INPUT: 1>, arg=TensorArgument(name='arg1_1'), target=None, persistent=None)], output_specs=[OutputSpec(kind=<OutputKind.USER_OUTPUT: 1>, arg=TensorArgument(name='index'), target=None)])
Range constraints: {u0: ValueRanges(lower=0, upper=10, is_bool=False), u1: ValueRanges(lower=0, upper=10, is_bool=False), u2: ValueRanges(lower=0, upper=10, is_bool=False), u3: ValueRanges(lower=0, upper=10, is_bool=False), u4: ValueRanges(lower=0, upper=10, is_bool=False)}
The single-argument form of `torch.where` is also affected by this issue:
class WhereModelSingleParam(nn.Module):
    """Toy model that calls ``torch.where`` with only a condition argument."""

    def forward(self, x):
        """Return ``torch.where(x)`` for the input tensor ``x``.

        With a single argument, ``torch.where`` returns a tuple of index
        tensors (one per dimension of ``x``) locating its non-zero elements,
        so the output length depends on the values in ``x``.
        """
        return torch.where(x)