
[ExecuTorch] Cannot Use Bool Index

Open · YifanShenSZ opened this issue · 1 comment

This toy model fails in ExecuTorch

        import torch

        class IndexModel(torch.nn.Module):
            def __init__(self, axis):
                super().__init__()
                self.axis = axis

            def forward(self, x, y):
                index = y > 0.5
                if self.axis == 0:
                    return x[index]
                elif self.axis == 1:
                    return x[:, index]
                elif self.axis == 2:
                    return x[:, :, index]
                else:
                    return x[:, :, :, index]

due to

torch._export.verifier.SpecViolationError: Operator torch._ops.aten._assert_async.msg is not Aten Canonical.
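For context, the bool index itself runs fine in eager mode; the wrinkle is that the output length is data dependent. A minimal eager sketch (with made-up input values, outside the ExecuTorch flow):

```python
import torch

# Made-up inputs for illustration; not from the original report.
x = torch.arange(10.0)
y = torch.tensor([0.0, 1.0, 0.2, 0.9, 0.4, 0.6, 0.0, 0.8, 0.3, 0.7])

index = y > 0.5   # bool mask
out = x[index]    # eager mode works, but the length depends on y's values
print(out.shape)  # here 5 of the 10 entries exceed 0.5
```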

From the aten-dialect exported program, this assertion appears to be a bound check on the data-dependent output size introduced by the bool index `index = y > 0.5`:

ExportedProgram:
    class GraphModule(torch.nn.Module):
        def forward(self, arg0_1: "f32[10]", arg1_1: "f32[10]"):
            # File: /Volumes/data/Software/Mine/coremltools-github_non-param-non-buffer-const/coremltools/converters/mil/frontend/torch/test/test_torch_ops.py:8066 in forward, code: index = y > 0.5
            gt: "b8[10]" = torch.ops.aten.gt.Scalar(arg1_1, 0.5);  arg1_1 = None
            
            # File: /Volumes/data/Software/Mine/coremltools-github_non-param-non-buffer-const/coremltools/converters/mil/frontend/torch/test/test_torch_ops.py:8068 in forward, code: return x[index]
            index: "f32[u4]" = torch.ops.aten.index.Tensor(arg0_1, [gt]);  arg0_1 = gt = None
            
            # File: /Users/yifanshensz/Software/Mine/coremltools-github_non-param-non-buffer-const/envs/coremltools-github_non-param-non-buffer-const-py3.10/lib/python3.10/site-packages/torch/_export/pass_base.py:54 in _create_dummy_node_metadata, code: return NodeMetadata({"stack_trace": "".join(traceback.format_stack(limit=1))})
            sym_size: "Sym(u4)" = torch.ops.aten.sym_size.int(index, 0)
            ge: "Sym(u4 >= 0)" = sym_size >= 0
            scalar_tensor: "f32[]" = torch.ops.aten.scalar_tensor.default(ge);  ge = None
            _assert_async = torch.ops.aten._assert_async.msg(scalar_tensor, 'index.shape[0] is outside of inline constraint [0, 10].');  scalar_tensor = None
            le: "Sym(u4 <= 10)" = sym_size <= 10;  sym_size = None
            scalar_tensor_1: "f32[]" = torch.ops.aten.scalar_tensor.default(le);  le = None
            _assert_async_1 = torch.ops.aten._assert_async.msg(scalar_tensor_1, 'index.shape[0] is outside of inline constraint [0, 10].');  scalar_tensor_1 = None
            return (index,)
            
Graph signature: ExportGraphSignature(input_specs=[InputSpec(kind=<InputKind.USER_INPUT: 1>, arg=TensorArgument(name='arg0_1'), target=None, persistent=None), InputSpec(kind=<InputKind.USER_INPUT: 1>, arg=TensorArgument(name='arg1_1'), target=None, persistent=None)], output_specs=[OutputSpec(kind=<OutputKind.USER_OUTPUT: 1>, arg=TensorArgument(name='index'), target=None)])
Range constraints: {u0: ValueRanges(lower=0, upper=10, is_bool=False), u1: ValueRanges(lower=0, upper=10, is_bool=False), u2: ValueRanges(lower=0, upper=10, is_bool=False), u3: ValueRanges(lower=0, upper=10, is_bool=False), u4: ValueRanges(lower=0, upper=10, is_bool=False)}
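The `[0, 10]` range in the constraints matches what one would expect: the filtered length can be anywhere from zero (mask all `False`) to `len(x) = 10` (mask all `True`). A quick eager sketch of the two extremes:

```python
import torch

x = torch.arange(10.0)

# The two extremes of the data-dependent output size (u4 above):
empty = x[torch.zeros(10, dtype=torch.bool)]  # no element selected
full = x[torch.ones(10, dtype=torch.bool)]    # every element selected
print(empty.shape, full.shape)  # torch.Size([0]) torch.Size([10])
```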

YifanShenSZ avatar Apr 04 '24 00:04 YifanShenSZ

The single-argument case of torch.where is also affected by this issue:

        import torch
        from torch import nn

        class WhereModelSingleParam(nn.Module):
            def forward(self, x):
                return torch.where(x)
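This is presumably the same root cause: single-argument `torch.where(x)` is documented to be equivalent to `torch.nonzero(x, as_tuple=True)`, whose output length is likewise data dependent. A small eager sketch with made-up values:

```python
import torch

x = torch.tensor([0.0, 1.5, 0.0, 2.0])

# Single-arg where returns an index tuple whose length depends on the data.
idx = torch.where(x)
print(idx[0])  # tensor([1, 3])
```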

YifanShenSZ avatar Apr 16 '24 20:04 YifanShenSZ