coremltools

index_put failing when indices are bool type

Open sreeneel opened this issue 1 year ago • 3 comments

🐞Describing the bug

index_put seems to fail during conversion when the indices are boolean. Please see the following steps to reproduce.

To Reproduce

  • Please add a minimal code example that can reproduce the error when running it.
import torch
from torch import nn
import coremltools as ct


class SimpleNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.not_a_point_embed = torch.ones(1, 2)
        self.point_embed = torch.ones(1, 2) * 2

    @torch.no_grad()
    def forward(self, point_embedding, labels):
        point_embedding[labels == 0] += self.not_a_point_embed
        point_embedding[labels == 1] += self.point_embed
        return point_embedding

model = SimpleNet().eval()
point_embedding = torch.zeros(1, 3, 2)
labels = torch.tensor([[1, 0, 1]])

traced_model = torch.jit.trace(model, (point_embedding, labels))

mlmodel = ct.convert(
    traced_model,
    convert_to="mlprogram",
    inputs=[
        ct.TensorType(name="point_embedding", shape=point_embedding.shape),
        ct.TensorType(name="labels", shape=labels.shape),
    ],
)

System environment (please complete the following information):

  • coremltools version: coremltools installed from main branch of the repo
  • OS (e.g. MacOS version or Linux type): macOS
  • Any other relevant version information (e.g. PyTorch or TensorFlow version): PyTorch 2.2.0

Additional context

It's failing on an assert in the index_put torch op:

@register_torch_op
def index_put(context, node):
    inputs = _get_inputs(context, node, expected=4)
    x = inputs[0]
    indices = inputs[1]
    values = inputs[2]
    accumulate = inputs[3].val
    rank = x.rank
    mode = "add" if accumulate else "update"

    indices_type = indices[0].sym_type.get_primitive()

    if types.is_bool(indices_type):
        assert len(indices) == 1, "Unsupported index_put_ usage."
        indices = indices[0]
        assert (
            indices.shape == x.shape
        ), "indices shape must equal to input shape for index put operation."
        indices = mb.cast(x=indices, dtype="int32")
        indices = mb.non_zero(x=indices)

It looks like the assert that fails is the shape check here:

assert (
    indices.shape == x.shape
), "indices shape must equal to input shape for index put operation."

Is this check correct? I tried changing it to indices.shape == x.shape[-1], but then scatter_nd fails in its own later checks.

Any ideas on what's going on here? Appreciate any help in further troubleshooting this.
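
For reference, here is a minimal eager-mode sketch (plain PyTorch, no coremltools involved) of what the boolean index means in the repro. The mask has shape (1, 3) while the input has shape (1, 3, 2), so indices.shape == x.shape cannot hold for this model: the mask covers only the leading dimensions and each hit selects a whole trailing row. The variable names are illustrative and not taken from the converter.

```python
import torch

x = torch.zeros(1, 3, 2)              # same shape as point_embedding
labels = torch.tensor([[1, 0, 1]])    # shape (1, 3)
mask = labels == 1                    # bool, shape (1, 3): only the first two dims of x

# A mask over the leading dims selects whole trailing slices.
selected = x[mask]                    # shape (2, 2): two hits, each a length-2 row

# Equivalent integer form: one (dim0, dim1) pair per hit.
idx = torch.nonzero(mask)             # shape (2, 2)
same = x[idx[:, 0], idx[:, 1]]        # gathers the same rows as x[mask]

mask_shape = tuple(mask.shape)
x_shape = tuple(x.shape)
```

So in eager mode the non-zero positions of the mask are (hits, 2) index pairs into a rank-3 tensor, which matches the (is1, 2) indices shape reported further down in this thread.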

sreeneel avatar Feb 11 '24 16:02 sreeneel

I investigated this a bit more.

When I comment out the indices shape-check assert in the index_put function, I run into a shape mismatch for the values in scatter_nd::type_inference().

Specifically, the is_compatible_symbolic_vector assert below fails:

def type_inference(self):
    assert self.indices.shape[-1] <= self.data.rank
    expected_updates_shape = (
        self.indices.shape[:-1] + self.data.shape[self.indices.shape[-1] :]
    )
    assert is_compatible_symbolic_vector(
        self.updates.shape, tuple(expected_updates_shape)
    )
    return self.data.sym_type

Interestingly, self.updates in the code above has a different shape depending on how many hits the labels == <val> query produces in the following code:

point_embedding[labels == 0] += self.not_a_point_embed
point_embedding[labels == 1] += self.point_embed

For the inputs in the sample above, these checks pass for labels==1 but not for labels==0.

The inputs and the corresponding shapes seen in the scatter_nd::type_inference function are:

# inputs
point_embedding = torch.zeros(1, 3, 2)
labels = torch.tensor([[1, 0, 1]])  # shape = (1, 3)

labels==0: shapes in type_inference of the scatter_nd class

  • self.indices: %non_zero_1: (is1, 2, int32)(Tensor)
  • self.indices.shape: (is1, 2)
  • self.data: %point_embedding: (1, 3, 2, fp32)(Tensor)
  • self.updates: %12: (2, fp32)(Tensor)
  • self.updates.shape: (2,)
  • expected_updates_shape: (is1, 2)

This fails as (2,) != (is1, 2)

labels==1: shapes in type_inference of the scatter_nd class

  • self.indices: %non_zero_1: (is1, 2, int32)(Tensor)
  • self.indices.shape: (is1, 2)
  • self.data: %point_embedding: (1, 3, 2, fp32)(Tensor)
  • self.updates: %9: (is0, 2, fp32)(Tensor)
  • self.updates.shape: (is0, 2)
  • expected_updates_shape: (is1, 2)

This succeeds as (is0, 2) is compatible with (is1, 2)
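
The expected_updates_shape values listed above follow directly from the formula in type_inference. Below is a small pure-Python sketch of that arithmetic. It is not the coremltools implementation: the helper name is hypothetical, and symbolic dimensions are stood in for by strings.

```python
def expected_updates_shape(indices_shape, data_shape):
    """Mirror the shape formula quoted from scatter_nd type_inference."""
    k = indices_shape[-1]            # number of leading dims of data each index row addresses
    assert k <= len(data_shape)
    return tuple(indices_shape[:-1]) + tuple(data_shape[k:])


data_shape = (1, 3, 2)

# Case from this issue: the mask covers 2 of the 3 dims, so each index row has 2 entries.
result = expected_updates_shape(("is1", 2), data_shape)

# Case the index_put assert expects: the mask covers every dim, so each index row has 3 entries.
full = expected_updates_shape(("n", 3), data_shape)
```

With a partial mask, every hit needs a whole trailing row of updates, giving (is1, 2). With a full-rank mask, every hit needs a single scalar, giving (n,).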

self.updates is set from values in index_put_.

Any ideas why this difference in shape occurs for values / updates?
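
One thing that may be relevant, stated as an assumption and not verified against the converter: eager PyTorch broadcasts the values in a masked assignment, so both a (hits, 2) tensor and a plain (2,) tensor are accepted there, whereas the scatter_nd check rejected (2,) against (is1, 2). The sketch below only demonstrates the eager behavior. It does not establish why the traced graph ends up with a (2,) tensor for labels==0.

```python
import torch

x = torch.zeros(1, 3, 2)
mask = torch.tensor([[True, False, True]])   # two hits

per_hit = torch.ones(2, 2)   # shape (hits, 2): one row per selected slot
shared = torch.ones(2)       # shape (2,): a single row, broadcast over the hits

a = x.clone()
a[mask] = per_hit            # values already match the selected shape

b = x.clone()
b[mask] = shared             # values are broadcast to (hits, 2) by PyTorch
```

Both assignments produce the same tensor in eager mode, so a converter that compares the values shape strictly would presumably need the values expanded to (hits, 2) first.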

sreeneel avatar Feb 12 '24 17:02 sreeneel

I observed a similar issue here. May I ask if this has been fixed?

gudgud96 avatar Jun 14 '24 11:06 gudgud96