Error lowering `index_put_`: 'tm_tensor.scatter' op mismatch in shape of indices and update value at dim#0
Here are assorted experiments that I'm trying to rework into concrete test cases suitable for use here in torch-mlir (at the moment they use `FxProgramsBuilder` from iree-turbine to get MLIR from Python): https://colab.research.google.com/gist/ScottTodd/f5e657c773e79be7a95aafb774cb3fd3/index_put-pytorch-torch-mlir-iree-turbine-iree.ipynb#scrollTo=UHFkgOtMz0k5
Op docs: https://pytorch.org/docs/stable/generated/torch.Tensor.index_put_.html
This puts three values (0.3, 1.4, and 2.5) into place at indices [0, 3], [1, 4], and [2, 5]:
```python
import torch
a = torch.zeros(3, 6)
a.index_put_(indices=[torch.tensor([0, 1, 2]), torch.tensor([3, 4, 5])], values=torch.tensor([0.3, 1.4, 2.5]))
print(a)
```
```
tensor([[0.0000, 0.0000, 0.0000, 0.3000, 0.0000, 0.0000],
        [0.0000, 0.0000, 0.0000, 0.0000, 1.4000, 0.0000],
        [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 2.5000]])
```
That example imports to this IR:
```mlir
module @module {
  func.func @simple_index_put(%arg0: !torch.tensor<[3,6],f32>) -> !torch.vtensor<[3,6],f32> {
    %0 = torch.vtensor.literal(dense_resource<torch_tensor_3_torch.int64> : tensor<3xsi64>) : !torch.vtensor<[3],si64>
    %1 = torch.vtensor.literal(dense_resource<torch_tensor_3_torch.int64_1> : tensor<3xsi64>) : !torch.vtensor<[3],si64>
    %2 = torch.vtensor.literal(dense_resource<torch_tensor_3_torch.float32> : tensor<3xf32>) : !torch.vtensor<[3],f32>
    %3 = torch.copy.to_vtensor %arg0 : !torch.vtensor<[3,6],f32>
    %none = torch.constant.none
    %4 = torch.aten.clone %0, %none : !torch.vtensor<[3],si64>, !torch.none -> !torch.vtensor<[3],si64>
    %none_0 = torch.constant.none
    %5 = torch.aten.clone %1, %none_0 : !torch.vtensor<[3],si64>, !torch.none -> !torch.vtensor<[3],si64>
    %none_1 = torch.constant.none
    %6 = torch.aten.clone %2, %none_1 : !torch.vtensor<[3],f32>, !torch.none -> !torch.vtensor<[3],f32>
    %7 = torch.prim.ListConstruct %4, %5 : (!torch.vtensor<[3],si64>, !torch.vtensor<[3],si64>) -> !torch.list<optional<vtensor>>
    %false = torch.constant.bool false
    %8 = torch.aten.index_put %3, %7, %6, %false : !torch.vtensor<[3,6],f32>, !torch.list<optional<vtensor>>, !torch.vtensor<[3],f32>, !torch.bool -> !torch.vtensor<[3,6],f32>
    torch.overwrite.tensor.contents %8 overwrites %arg0 : !torch.vtensor<[3,6],f32>, !torch.tensor<[3,6],f32>
    return %8 : !torch.vtensor<[3,6],f32>
  }
}
{-#
  dialect_resources: {
    builtin: {
      torch_tensor_3_torch.int64: "0x08000000000000000000000001000000000000000200000000000000",
      torch_tensor_3_torch.int64_1: "0x08000000030000000000000004000000000000000500000000000000",
      torch_tensor_3_torch.float32: "0x040000009A99993E3333B33F00002040"
    }
  }
#-}
```
This compiles successfully both through IREE and through `torch-mlir-opt --pass-pipeline='builtin.module(func.func(torch-decompose-complex-ops,convert-torch-to-tmtensor))'`.
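For test writing it helps to pin down what this case means in plain PyTorch. A minimal sketch, assuming `accumulate=False` and one index tensor per dimension: `index_put_` then behaves like an advanced-indexing assignment, which in turn is a loop over `(row, col)` pairs, the form a scatter-style lowering has to reproduce.

```python
import torch

# Same inputs as the first example above.
indices = (torch.tensor([0, 1, 2]), torch.tensor([3, 4, 5]))
values = torch.tensor([0.3, 1.4, 2.5])

a = torch.zeros(3, 6)
a.index_put_(indices, values)

# Equivalent advanced-indexing assignment: the k-th (row, col) pair is
# (indices[0][k], indices[1][k]), and it receives values[k].
b = torch.zeros(3, 6)
b[indices[0], indices[1]] = values
assert torch.equal(a, b)

# The same semantics as an explicit loop over index pairs.
c = torch.zeros(3, 6)
for row, col, v in zip(indices[0].tolist(), indices[1].tolist(), values.tolist()):
    c[row, col] = v
assert torch.equal(a, c)
```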
The `index_put_` op also appears to support broadcasting `values` from a single element to all indices:
```python
import torch
a = torch.zeros(3, 6)
a.index_put_(indices=[torch.tensor([0, 1, 2]), torch.tensor([3, 4, 5])], values=torch.tensor([0.3]))
print(a)
```
```
tensor([[0.0000, 0.0000, 0.0000, 0.3000, 0.0000, 0.0000],
        [0.0000, 0.0000, 0.0000, 0.0000, 0.3000, 0.0000],
        [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.3000]])
```
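However, this imports to IR that fails to compile (the IR below was captured with a value of 0.5 instead of 0.3, but `values` has the same single-element shape):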
```mlir
module @module {
  func.func @simple_index_put(%arg0: !torch.tensor<[3,6],f32>) -> !torch.vtensor<[3,6],f32> {
    %0 = torch.vtensor.literal(dense_resource<torch_tensor_3_torch.int64> : tensor<3xsi64>) : !torch.vtensor<[3],si64>
    %1 = torch.vtensor.literal(dense_resource<torch_tensor_3_torch.int64_1> : tensor<3xsi64>) : !torch.vtensor<[3],si64>
    %2 = torch.vtensor.literal(dense<5.000000e-01> : tensor<1xf32>) : !torch.vtensor<[1],f32>
    %3 = torch.copy.to_vtensor %arg0 : !torch.vtensor<[3,6],f32>
    %none = torch.constant.none
    %4 = torch.aten.clone %0, %none : !torch.vtensor<[3],si64>, !torch.none -> !torch.vtensor<[3],si64>
    %none_0 = torch.constant.none
    %5 = torch.aten.clone %1, %none_0 : !torch.vtensor<[3],si64>, !torch.none -> !torch.vtensor<[3],si64>
    %none_1 = torch.constant.none
    %6 = torch.aten.clone %2, %none_1 : !torch.vtensor<[1],f32>, !torch.none -> !torch.vtensor<[1],f32>
    %7 = torch.prim.ListConstruct %4, %5 : (!torch.vtensor<[3],si64>, !torch.vtensor<[3],si64>) -> !torch.list<optional<vtensor>>
    %false = torch.constant.bool false
    %8 = torch.aten.index_put %3, %7, %6, %false : !torch.vtensor<[3,6],f32>, !torch.list<optional<vtensor>>, !torch.vtensor<[1],f32>, !torch.bool -> !torch.vtensor<[3,6],f32>
    torch.overwrite.tensor.contents %8 overwrites %arg0 : !torch.vtensor<[3,6],f32>, !torch.tensor<[3,6],f32>
    return %8 : !torch.vtensor<[3,6],f32>
  }
}
{-#
  dialect_resources: {
    builtin: {
      torch_tensor_3_torch.int64: "0x08000000000000000000000001000000000000000200000000000000",
      torch_tensor_3_torch.int64_1: "0x08000000030000000000000004000000000000000500000000000000"
    }
  }
#-}
```
```
/tmp/index_put_broadcast.mlir:15:10: error: 'tm_tensor.scatter' op mismatch in shape of indices and update value at dim#0
    %8 = torch.aten.index_put %3, %7, %6, %false : !torch.vtensor<[3,6],f32>, !torch.list<optional<vtensor>>, !torch.vtensor<[1],f32>, !torch.bool -> !torch.vtensor<[3,6],f32>
         ^
/tmp/index_put_broadcast.mlir:15:10: note: see current operation:
  %38 = "tm_tensor.scatter"(%36, %37, %35) <{dimension_map = array<i64: 0, 1>, operandSegmentSizes = array<i32: 2, 1>, unique_indices = false}> ({
  ^bb0(%arg1: f32, %arg2: f32):
    "tm_tensor.yield"(%arg1) : (f32) -> ()
  }) : (tensor<1x1x1xf32>, tensor<3x2xi32>, tensor<3x6xf32>) -> tensor<3x6xf32>
```
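In the failing scatter, the two index tensors appear as a single `tensor<3x2xi32>` of coordinates (one row per update, with `dimension_map = array<i64: 0, 1>` pointing the columns at dims 0 and 1), while the updates operand is `tensor<1x1x1xf32>`. The sketch below is a hypothetical eager model of that op, inferred from the error text rather than taken from torch-mlir: it shows why a leading dim of 1 is rejected against 3 coordinates, and that expanding the single value to one update per coordinate would line the shapes up. Whether the lowering should insert that expansion is an assumption about a possible fix.

```python
import torch


def scatter_model(original, coords, updates):
    """Hypothetical eager model of the tm_tensor.scatter shown in the error.

    coords:  [N, K] integer tensor holding one K-d coordinate per row.
    updates: tensor with leading dim N, one single-element update per row.
    """
    n_coords, n_updates = coords.shape[0], updates.shape[0]
    if n_coords != n_updates:
        # Mirrors the verifier message; not torch-mlir's actual code.
        raise ValueError(
            "mismatch in shape of indices and update value at dim#0: "
            f"{n_coords} vs {n_updates}"
        )
    result = original.clone()
    for n in range(n_coords):
        # Overwrite semantics (accumulate=False).
        result[tuple(coords[n].tolist())] = updates[n].reshape(())
    return result


indices = (torch.tensor([0, 1, 2]), torch.tensor([3, 4, 5]))
coords = torch.stack(indices, dim=1)  # shape [3, 2], like tensor<3x2xi32>
original = torch.zeros(3, 6)
value = torch.tensor([0.5])

# A [1, 1, 1] update against 3 coordinates fails the dim#0 check.
try:
    scatter_model(original, coords, value.reshape(1, 1, 1))
    rejected = False
except ValueError:
    rejected = True

# Expanding the single value to one update per coordinate satisfies the
# check and matches eager PyTorch's broadcasting index_put_.
fixed = scatter_model(original, coords, value.expand(coords.shape[0]))
eager = torch.zeros(3, 6).index_put_(indices, value)
assert rejected
assert torch.equal(fixed, eager)
```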
There are other broadcasting semantics involving `indices`, some of which torch-mlir might already handle correctly, but I'm not sure. I'd like to write a suite of e2e tests that covers all the edge cases, possibly drawing on https://github.com/pytorch/pytorch/blob/main/test/test_indexing.py
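One example of index broadcasting: index tensors of shapes `[3, 1]` and `[2]` broadcast together to `[3, 2]`, so six elements are written. A short sketch of the kind of e2e test case this could become, checking `index_put_` against an explicit loop over the broadcast indices (the shapes and values are arbitrary choices for illustration):

```python
import itertools

import torch

rows = torch.tensor([[0], [1], [2]])  # shape [3, 1]
cols = torch.tensor([3, 5])           # shape [2]
values = torch.arange(6, dtype=torch.float32).reshape(3, 2)

a = torch.zeros(3, 6)
a.index_put_((rows, cols), values)

# Reference: broadcast the index tensors explicitly, then write each element.
r, c = torch.broadcast_tensors(rows, cols)
expected = torch.zeros(3, 6)
for i, j in itertools.product(range(r.shape[0]), range(r.shape[1])):
    expected[r[i, j].item(), c[i, j].item()] = values[i, j]

assert r.shape == (3, 2)
assert torch.equal(a, expected)
```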
I wrote some test cases (TBD how these can land in an existing or new test suite): https://gist.github.com/ScottTodd/1e95795e79d17964078217ca98a3a398
Results with the IREE runtime + compiler at two nightly versions:

| Test | 20240410.859 | 20240606.916 |
|---|---|---|
| test_single_value | PASS | PASS |
| test_multiple_values | PASS | PASS (then crash) |
| test_broadcast_value_along_axis | FAIL | FAIL |
| test_broadcast_value_along_indices | FAIL | FAIL |
| test_broadcast_values_along_axis | PASS | PASS (then crash) |
The new "pass (then crash)" results are suspicious. I need to debug the source of that, either by looking at IR dumps or by bisecting through nightly releases to find a culprit commit range.
The "broadcast_value_along" cases look like they are simply unimplemented. That's not too surprising, since https://pytorch.org/docs/stable/generated/torch.Tensor.index_put_.html contains almost no information about how the op should behave.
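Since the op docs say little, one way to state the expected behavior for tests is through advanced-indexing shapes: when the index tensors cover the leading dims of `self` (no `None` entries), `values` must broadcast to the shape of `self[indices]`, i.e. the broadcast shape of the index tensors followed by the remaining, unindexed dims. A minimal sketch of that rule with a hypothetical helper name; a lowering could expand `values` to this shape before building the scatter, but that is an assumption about a possible fix, not how torch-mlir currently works.

```python
import torch


def index_put_values_shape(self_shape, index_shapes):
    """Hypothetical helper: shape `values` must broadcast to in index_put_.

    Assumes the index tensors cover the leading len(index_shapes) dims of
    `self` with no None entries; other index layouts are not handled.
    """
    indexed = torch.broadcast_shapes(*index_shapes)
    return torch.Size(tuple(indexed) + tuple(self_shape[len(index_shapes):]))


base = torch.zeros(3, 6)
i0, i1 = torch.tensor([0, 1, 2]), torch.tensor([3, 4, 5])

# Both dims indexed: one value per index pair, shape [3].
both = index_put_values_shape(base.shape, [i0.shape, i1.shape])
# Only dim 0 indexed: each selected row takes a full [6] slice, shape [3, 6].
rows_only = index_put_values_shape(base.shape, [i0.shape])

# The rule agrees with the shape advanced indexing produces.
assert both == base[i0, i1].shape == (3,)
assert rows_only == base[i0].shape == (3, 6)

# A single value broadcasts to either target, so expanding it first gives
# the same result as PyTorch's implicit broadcast.
value = torch.tensor([0.5])
implicit = torch.zeros(3, 6).index_put_((i0,), value)
explicit = torch.zeros(3, 6).index_put_((i0,), value.expand(rows_only))
assert torch.equal(implicit, explicit)
```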
The "pass (then crash)" issues turned out to be unrelated to the op lowerings; they can be worked around in IREE by using `copy_buffer` instead of `wrap_buffer`.
I would still like to see the broadcasting cases of `torch.aten.index_put` / `tm_tensor.scatter` implemented here in torch-mlir.