`index_put` fails when the indices are boolean
🐞 Describing the bug
`index_put` appears to fail when the indices are boolean. Steps to reproduce are below.
To Reproduce
```python
import torch
from torch import nn
import coremltools as ct


class SimpleNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.not_a_point_embed = torch.ones(1, 2)
        self.point_embed = torch.ones(1, 2) * 2

    @torch.no_grad()
    def forward(self, point_embedding, labels):
        # Boolean-mask updates; the assignment traces to index_put_ with bool indices
        point_embedding[labels == 0] += self.not_a_point_embed
        point_embedding[labels == 1] += self.point_embed
        return point_embedding


model = SimpleNet().eval()
point_embedding = torch.zeros(1, 3, 2)
labels = torch.tensor([[1, 0, 1]])
traced_model = torch.jit.trace(model, (point_embedding, labels))

# Conversion fails inside the index_put converter
mlmodel = ct.convert(
    traced_model,
    convert_to="mlprogram",
    inputs=[
        ct.TensorType(name="point_embedding", shape=point_embedding.shape),
        ct.TensorType(name="labels", shape=labels.shape),
    ],
)
```
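As a point of reference (not part of the original report), NumPy boolean indexing follows the same rules as PyTorch here, so the intended result of `forward` can be sketched without PyTorch or coremltools. The mask `labels == 0` has shape (1, 3) and covers only the first two dimensions of the (1, 3, 2) tensor, so each selection has shape (k, 2), where k is the number of matching labels. `simple_net_forward` is a hypothetical NumPy mirror of the model, not the converted model:

```python
import numpy as np


def simple_net_forward(point_embedding, labels):
    """Hypothetical NumPy mirror of SimpleNet.forward (works on a copy)."""
    not_a_point_embed = np.ones((1, 2), dtype=np.float32)
    point_embed = np.ones((1, 2), dtype=np.float32) * 2
    out = point_embedding.copy()
    # A (1, 3) mask on a (1, 3, 2) array selects k rows of shape (2,);
    # the (1, 2) embedding broadcasts across the k selected rows.
    out[labels == 0] += not_a_point_embed
    out[labels == 1] += point_embed
    return out


point_embedding = np.zeros((1, 3, 2), dtype=np.float32)
labels = np.array([[1, 0, 1]])

print(point_embedding[labels == 0].shape)  # (1, 2): one label equals 0
print(point_embedding[labels == 1].shape)  # (2, 2): two labels equal 1
print(simple_net_forward(point_embedding, labels))
```

Note that the number of selected rows (1 vs. 2) depends on the label values, which is why the converted graph has to treat it as a runtime (symbolic) dimension.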
System environment:
- coremltools version: installed from the `main` branch of the repo
- OS: macOS
- PyTorch version: 2.2.0
Additional context
It fails on an assertion in the `index_put` torch op converter:
```python
@register_torch_op
def index_put(context, node):
    inputs = _get_inputs(context, node, expected=4)
    x = inputs[0]
    indices = inputs[1]
    values = inputs[2]
    accumulate = inputs[3].val
    rank = x.rank
    mode = "add" if accumulate else "update"

    indices_type = indices[0].sym_type.get_primitive()

    if types.is_bool(indices_type):
        assert len(indices) == 1, "Unsupported index_put_ usage."
        indices = indices[0]
        assert (
            indices.shape == x.shape
        ), "indices shape must equal to input shape for index put operation."
        indices = mb.cast(x=indices, dtype="int32")
        indices = mb.non_zero(x=indices)
        # ... (rest of the converter omitted)
```
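For reference, the boolean branch can be emulated in NumPy to see which shapes it needs. This is a sketch under assumptions: `np.argwhere` stands in for `mb.non_zero` (both return an (N, rank) array of coordinates of the nonzero entries), the loop stands in for `scatter_nd` with `mode="add"`, and `bool_index_put_add` is a hypothetical name. Unlike the current assertion, it accepts a mask that covers only the leading dimensions of `data`, as PyTorch does:

```python
import numpy as np


def bool_index_put_add(data, mask, updates):
    """Hypothetical NumPy stand-in for the bool branch of index_put (accumulate=True)."""
    m = mask.ndim
    # Relaxed check: the mask only has to match the leading m dimensions.
    if mask.shape != data.shape[:m]:
        raise ValueError(f"mask shape {mask.shape} is not a prefix of {data.shape}")
    # Like mb.cast + mb.non_zero: coordinates of True entries, shape (N, m).
    indices = np.argwhere(mask.astype(np.int32))
    # Same formula as scatter_nd: indices.shape[:-1] + data.shape[m:]
    expected = indices.shape[:-1] + data.shape[m:]
    if updates.shape != expected:
        raise ValueError(f"updates shape {updates.shape} != expected {expected}")
    out = data.copy()
    for row, update in zip(indices, updates):
        out[tuple(row)] += update  # mode="add"
    return out


data = np.zeros((1, 3, 2), dtype=np.float32)
mask = np.array([[1, 0, 1]]) == 0             # shape (1, 3), one True entry
updates = np.ones((1, 2), dtype=np.float32)   # (N, 2) with N = 1
print(bool_index_put_add(data, mask, updates))
```

With a mask of rank m, `updates` must have shape `(N,) + data.shape[m:]`, which for this model is (N, 2); an `updates` tensor of shape (2,) is rejected by the same check.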
The assertion that fails is the shape check:
```python
assert (
    indices.shape == x.shape
), "indices shape must equal to input shape for index put operation."
```
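Is this check correct? In this example the mask `labels == 0` has shape (1, 3) while `point_embedding` has shape (1, 3, 2), so the two shapes can never be equal. I tried changing the check to `indices.shape == x.shape[-1]`, but `scatter_nd` then fails in later checks.
Any ideas on what is going on here? I'd appreciate any help troubleshooting this further.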
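The current assertion requires the mask shape to equal the full input shape, whereas PyTorch only requires a boolean mask to match the leading dimensions it indexes. The difference can be sketched in plain Python; `current_check` and `prefix_check` are hypothetical names, and plain integer tuples stand in for MIL shapes (symbolic dimensions are not handled):

```python
def current_check(mask_shape, data_shape):
    # What the assertion in index_put currently requires: full equality.
    return tuple(mask_shape) == tuple(data_shape)


def prefix_check(mask_shape, data_shape):
    # PyTorch-style rule: the mask must match the leading dims it indexes.
    m = len(mask_shape)
    return m <= len(data_shape) and tuple(mask_shape) == tuple(data_shape[:m])


mask_shape, data_shape = (1, 3), (1, 3, 2)
print(current_check(mask_shape, data_shape))  # False
print(prefix_check(mask_shape, data_shape))   # True
```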
I investigated this a bit more.
When I comment out the indices shape assertion in `index_put`, I instead hit a shape mismatch on the values in `scatter_nd.type_inference()`, specifically in the `is_compatible_symbolic_vector` check below:
```python
def type_inference(self):
    assert self.indices.shape[-1] <= self.data.rank
    expected_updates_shape = (
        self.indices.shape[:-1] + self.data.shape[self.indices.shape[-1] :]
    )
    assert is_compatible_symbolic_vector(
        self.updates.shape, tuple(expected_updates_shape)
    )
    return self.data.sym_type
```
Interestingly, `self.updates` in the code above has a different shape depending on how many elements match the `labels == <val>` query in these lines:
```python
point_embedding[labels == 0] += self.not_a_point_embed
point_embedding[labels == 1] += self.point_embed
```
With the sample inputs above, the check passes for `labels == 1` but fails for `labels == 0`. Below are the inputs and the corresponding shapes inside `scatter_nd.type_inference()`:
```python
# inputs
point_embedding = torch.zeros(1, 3, 2)
labels = torch.tensor([[1, 0, 1]])  # shape = (1, 3)
```
`labels == 0`: shapes in `scatter_nd.type_inference()`
```text
self.indices:           %non_zero_1: (is1, 2, int32)(Tensor)
self.indices.shape:     (is1, 2)
self.data:              %point_embedding: (1, 3, 2, fp32)(Tensor)
self.updates:           %12: (2,fp32)(Tensor)
self.updates.shape:     (2,)
expected_updates_shape: (is1, 2)
```
This fails because `(2,)` is not compatible with `(is1, 2)`: the ranks differ.
`labels == 1`: shapes in `scatter_nd.type_inference()`
```text
self.indices:           %non_zero_1: (is1, 2, int32)(Tensor)
self.indices.shape:     (is1, 2)
self.data:              %point_embedding: (1, 3, 2, fp32)(Tensor)
self.updates:           %9: (is0, 2, fp32)(Tensor)
self.updates.shape:     (is0, 2)
expected_updates_shape: (is1, 2)
```
This succeeds because `(is0, 2)` is compatible with `(is1, 2)`.
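The two cases above can be reproduced with plain tuples. In this simplified sketch, symbolic dimensions such as `is1` are represented as strings that match any size; `compatible` is a hypothetical stand-in, not coremltools' actual `is_compatible_symbolic_vector`, but it captures the rank mismatch:

```python
def expected_updates_shape(indices_shape, data_shape):
    # Same formula as scatter_nd.type_inference
    k = indices_shape[-1]
    assert k <= len(data_shape)
    return tuple(indices_shape[:-1]) + tuple(data_shape[k:])


def compatible(a, b):
    # Simplified: equal rank, and each dim equal unless either one is symbolic (str).
    if len(a) != len(b):
        return False
    return all(isinstance(x, str) or isinstance(y, str) or x == y for x, y in zip(a, b))


expected = expected_updates_shape(("is1", 2), (1, 3, 2))
print(expected)                          # ('is1', 2)
print(compatible((2,), expected))        # False: the labels == 0 case
print(compatible(("is0", 2), expected))  # True: the labels == 1 case
```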
`self.updates` is set from `values` in `index_put_`.
Any ideas why `values`/`updates` end up with different shapes in the two cases?
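One hypothesis worth checking (not verified against the converter): PyTorch only requires `values` in `index_put_` to be broadcastable to the indexed shape, whereas `scatter_nd` expects `updates` of exactly `(N,) + data.shape[m:]`, where N is the number of True entries and m is the mask rank. If that is the cause, broadcasting `values` to that shape before the scatter would make both cases agree. A NumPy sketch of the idea, with `broadcast_updates` as a hypothetical helper:

```python
import numpy as np


def broadcast_updates(values, num_selected, trailing_shape):
    # Target shape that scatter_nd expects: (N,) + data.shape[m:]
    target = (num_selected,) + tuple(trailing_shape)
    # np.broadcast_to raises ValueError if values is not broadcastable.
    return np.broadcast_to(values, target)


trailing = (2,)  # data.shape[m:] for data of shape (1, 3, 2) and a rank-2 mask
# labels == 0 case: values of shape (2,), one selected row
print(broadcast_updates(np.ones(2, dtype=np.float32), 1, trailing).shape)       # (1, 2)
# labels == 1 case: values already (N, 2) with N = 2
print(broadcast_updates(np.ones((2, 2), dtype=np.float32), 2, trailing).shape)  # (2, 2)
```

In the converted model N is only known at runtime (it is the leading dimension of the `non_zero` output), so an equivalent MIL implementation would need a dynamically shaped broadcast; that part is not sketched here.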
I observed a similar issue. Has this been fixed?