
Error lowering `index_put_`: 'tm_tensor.scatter' op mismatch in shape of indices and update value at dim#0

Open · ScottTodd opened this issue 1 year ago · 2 comments

Here are assorted experiments that I'm trying to rework into concrete test cases suitable for torch-mlir. For now they use FxProgramsBuilder from iree-turbine to get MLIR from Python: https://colab.research.google.com/gist/ScottTodd/f5e657c773e79be7a95aafb774cb3fd3/index_put-pytorch-torch-mlir-iree-turbine-iree.ipynb#scrollTo=UHFkgOtMz0k5

Op documentation: https://pytorch.org/docs/stable/generated/torch.Tensor.index_put_.html

This puts three values (0.3, 1.4, and 2.5) at the (row, column) positions [0, 3], [1, 4], and [2, 5]:

import torch
a = torch.zeros(3, 6)
a.index_put_(indices=[torch.tensor([0, 1, 2]), torch.tensor([3, 4, 5])], values=torch.tensor([0.3, 1.4, 2.5]))
print(a)

tensor([[0.0000, 0.0000, 0.0000, 0.3000, 0.0000, 0.0000],
        [0.0000, 0.0000, 0.0000, 0.0000, 1.4000, 0.0000],
        [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 2.5000]])

that imports to this IR:

module @module {
  func.func @simple_index_put(%arg0: !torch.tensor<[3,6],f32>) -> !torch.vtensor<[3,6],f32> {
    %0 = torch.vtensor.literal(dense_resource<torch_tensor_3_torch.int64> : tensor<3xsi64>) : !torch.vtensor<[3],si64>
    %1 = torch.vtensor.literal(dense_resource<torch_tensor_3_torch.int64_1> : tensor<3xsi64>) : !torch.vtensor<[3],si64>
    %2 = torch.vtensor.literal(dense_resource<torch_tensor_3_torch.float32> : tensor<3xf32>) : !torch.vtensor<[3],f32>
    %3 = torch.copy.to_vtensor %arg0 : !torch.vtensor<[3,6],f32>
    %none = torch.constant.none
    %4 = torch.aten.clone %0, %none : !torch.vtensor<[3],si64>, !torch.none -> !torch.vtensor<[3],si64>
    %none_0 = torch.constant.none
    %5 = torch.aten.clone %1, %none_0 : !torch.vtensor<[3],si64>, !torch.none -> !torch.vtensor<[3],si64>
    %none_1 = torch.constant.none
    %6 = torch.aten.clone %2, %none_1 : !torch.vtensor<[3],f32>, !torch.none -> !torch.vtensor<[3],f32>
    %7 = torch.prim.ListConstruct %4, %5 : (!torch.vtensor<[3],si64>, !torch.vtensor<[3],si64>) -> !torch.list<optional<vtensor>>
    %false = torch.constant.bool false
    %8 = torch.aten.index_put %3, %7, %6, %false : !torch.vtensor<[3,6],f32>, !torch.list<optional<vtensor>>, !torch.vtensor<[3],f32>, !torch.bool -> !torch.vtensor<[3,6],f32>
    torch.overwrite.tensor.contents %8 overwrites %arg0 : !torch.vtensor<[3,6],f32>, !torch.tensor<[3,6],f32>
    return %8 : !torch.vtensor<[3,6],f32>
  }
}

{-#
  dialect_resources: {
    builtin: {
      torch_tensor_3_torch.int64: "0x08000000000000000000000001000000000000000200000000000000",
      torch_tensor_3_torch.int64_1: "0x08000000030000000000000004000000000000000500000000000000",
      torch_tensor_3_torch.float32: "0x040000009A99993E3333B33F00002040"
    }
  }
#-}

which compiles successfully through IREE and also through `torch-mlir-opt --pass-pipeline='builtin.module(func.func(torch-decompose-complex-ops,convert-torch-to-tmtensor))'` (the pipeline string needs quoting in a shell because of the parentheses).
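The constants in `dialect_resources` can be checked against the Python example without PyTorch. The sketch below assumes MLIR's hex blob layout (a leading 4-byte little-endian alignment field, then packed little-endian element data), decodes the three blobs, and feeds them to a plain-Python model of `index_put` with `accumulate=False` (the `%false` operand). `decode_blob` and `index_put_2d` are hypothetical helpers for illustration, not torch-mlir or PyTorch APIs.

```python
import struct
from typing import List, Sequence, Tuple


def decode_blob(blob: str, elem_fmt: str) -> Tuple[int, tuple]:
    """Decode a dense_resource hex blob, assuming u32 alignment + packed data.

    elem_fmt is a struct code: "q" for int64, "f" for float32.
    """
    raw = bytes.fromhex(blob[2:] if blob.startswith("0x") else blob)
    (alignment,) = struct.unpack_from("<I", raw, 0)
    data = raw[4:]
    count = len(data) // struct.calcsize("<" + elem_fmt)
    return alignment, struct.unpack(f"<{count}{elem_fmt}", data)


def index_put_2d(a: List[List[float]], rows: Sequence[int], cols: Sequence[int],
                 values: Sequence[float], accumulate: bool = False) -> List[List[float]]:
    """Hypothetical model of torch.aten.index_put with one index list per dim.

    Writes values[i] at (rows[i], cols[i]); broadcasting is not modeled here.
    """
    if not len(rows) == len(cols) == len(values):
        raise ValueError("rows, cols and values must have the same length")
    out = [list(row) for row in a]  # functional op: leave the input untouched
    for r, c, v in zip(rows, cols, values):
        # accumulate=False overwrites; accumulate=True adds to the existing value
        out[r][c] = out[r][c] + v if accumulate else v
    return out


_, rows = decode_blob("0x08000000000000000000000001000000000000000200000000000000", "q")
_, cols = decode_blob("0x08000000030000000000000004000000000000000500000000000000", "q")
_, vals = decode_blob("0x040000009A99993E3333B33F00002040", "f")
print(rows, cols, [round(v, 6) for v in vals])  # (0, 1, 2) (3, 4, 5) [0.3, 1.4, 2.5]

result = index_put_2d([[0.0] * 6 for _ in range(3)], rows, cols, vals)
for row in result:
    print([round(v, 4) for v in row])
```

The float32 values decode to the nearest doubles (for example 0.30000001192092896), which is why they are rounded before printing.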

The index_put_ op also appears to support broadcasting the "values" from a single element to all indices:

import torch
a = torch.zeros(3, 6)
a.index_put_(indices=[torch.tensor([0, 1, 2]), torch.tensor([3, 4, 5])], values=torch.tensor([0.3]))
print(a)

tensor([[0.0000, 0.0000, 0.0000, 0.3000, 0.0000, 0.0000],
        [0.0000, 0.0000, 0.0000, 0.0000, 0.3000, 0.0000],
        [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.3000]])

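that, however, imports to IR that fails to compile (this capture used a value of 0.5 rather than 0.3, which doesn't matter here; the relevant difference from the working case is the `[1]`-shaped values tensor):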

module @module {
  func.func @simple_index_put(%arg0: !torch.tensor<[3,6],f32>) -> !torch.vtensor<[3,6],f32> {
    %0 = torch.vtensor.literal(dense_resource<torch_tensor_3_torch.int64> : tensor<3xsi64>) : !torch.vtensor<[3],si64>
    %1 = torch.vtensor.literal(dense_resource<torch_tensor_3_torch.int64_1> : tensor<3xsi64>) : !torch.vtensor<[3],si64>
    %2 = torch.vtensor.literal(dense<5.000000e-01> : tensor<1xf32>) : !torch.vtensor<[1],f32>
    %3 = torch.copy.to_vtensor %arg0 : !torch.vtensor<[3,6],f32>
    %none = torch.constant.none
    %4 = torch.aten.clone %0, %none : !torch.vtensor<[3],si64>, !torch.none -> !torch.vtensor<[3],si64>
    %none_0 = torch.constant.none
    %5 = torch.aten.clone %1, %none_0 : !torch.vtensor<[3],si64>, !torch.none -> !torch.vtensor<[3],si64>
    %none_1 = torch.constant.none
    %6 = torch.aten.clone %2, %none_1 : !torch.vtensor<[1],f32>, !torch.none -> !torch.vtensor<[1],f32>
    %7 = torch.prim.ListConstruct %4, %5 : (!torch.vtensor<[3],si64>, !torch.vtensor<[3],si64>) -> !torch.list<optional<vtensor>>
    %false = torch.constant.bool false
    %8 = torch.aten.index_put %3, %7, %6, %false : !torch.vtensor<[3,6],f32>, !torch.list<optional<vtensor>>, !torch.vtensor<[1],f32>, !torch.bool -> !torch.vtensor<[3,6],f32>
    torch.overwrite.tensor.contents %8 overwrites %arg0 : !torch.vtensor<[3,6],f32>, !torch.tensor<[3,6],f32>
    return %8 : !torch.vtensor<[3,6],f32>
  }
}

{-#
  dialect_resources: {
    builtin: {
      torch_tensor_3_torch.int64: "0x08000000000000000000000001000000000000000200000000000000",
      torch_tensor_3_torch.int64_1: "0x08000000030000000000000004000000000000000500000000000000"
    }
  }
#-}
/tmp/index_put_broadcast.mlir:15:10: error: 'tm_tensor.scatter' op mismatch in shape of indices and update value at dim#0
    %8 = torch.aten.index_put %3, %7, %6, %false : !torch.vtensor<[3,6],f32>, !torch.list<optional<vtensor>>, !torch.vtensor<[1],f32>, !torch.bool -> !torch.vtensor<[3,6],f32>
         ^
/tmp/index_put_broadcast.mlir:15:10: note: see current operation: 
%38 = "tm_tensor.scatter"(%36, %37, %35) <{dimension_map = array<i64: 0, 1>, operandSegmentSizes = array<i32: 2, 1>, unique_indices = false}> ({
^bb0(%arg1: f32, %arg2: f32):
  "tm_tensor.yield"(%arg1) : (f32) -> ()
}) : (tensor<1x1x1xf32>, tensor<3x2xi32>, tensor<3x6xf32>) -> tensor<3x6xf32>
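Reading the dump: `indices` is a `tensor<3x2xi32>` holding one (row, column) coordinate per selected element, matching `dimension_map = array<i64: 0, 1>`, while `updates` is `tensor<1x1x1xf32>`. The verifier expects one update per coordinate row, so dim 0 compares 3 against 1. One possible direction, sketched below in plain Python, is to expand the size-1 values to one update per coordinate before building the scatter. `check_dim0`, `expand_updates`, and `scatter_2d` are hypothetical models, not torch-mlir code, and treating the region's `tm_tensor.yield %arg1` as "overwrite with the update" assumes `%arg1` is the update operand.

```python
from typing import List, Sequence, Tuple

Coord = Tuple[int, int]


def check_dim0(indices_shape: Sequence[int], updates_shape: Sequence[int]) -> None:
    """Model of the failing check: one update slice per index row."""
    if indices_shape[0] != updates_shape[0]:
        raise ValueError("mismatch in shape of indices and update value at dim#0 "
                         f"({indices_shape[0]} vs {updates_shape[0]})")


def expand_updates(values: List[float], num_coords: int) -> List[float]:
    """Broadcast a single value to one update per coordinate (size-1 rule only)."""
    if len(values) == num_coords:
        return list(values)
    if len(values) == 1:
        return values * num_coords
    raise ValueError(f"cannot broadcast {len(values)} values to {num_coords} coordinates")


def scatter_2d(original: List[List[float]], coords: List[Coord],
               updates: List[float]) -> List[List[float]]:
    """Write updates[i] at coords[i] = (row, col), overwriting the original value."""
    check_dim0((len(coords), 2), (len(updates),))
    out = [list(row) for row in original]
    for (r, c), u in zip(coords, updates):
        out[r][c] = u  # region yields the update value: plain overwrite
    return out


coords = [(0, 3), (1, 4), (2, 5)]  # the tensor<3x2xi32> operand
zeros = [[0.0] * 6 for _ in range(3)]

try:
    scatter_2d(zeros, coords, [0.5])  # 1 update row vs 3 index rows
except ValueError as err:
    print(err)  # mismatch in shape of indices and update value at dim#0 (3 vs 1)

# Expanding the single value first gives the shapes the verifier expects.
fixed = scatter_2d(zeros, coords, expand_updates([0.5], len(coords)))
print(fixed[1])  # [0.0, 0.0, 0.0, 0.0, 0.5, 0.0]
```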

There are other broadcasting semantics involving "indices", some of which torch-mlir might already handle correctly, but I'm not sure. I'd like to write a suite of e2e tests covering all the edge cases, possibly drawing on https://github.com/pytorch/pytorch/blob/main/test/test_indexing.py
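One index-broadcasting case worth covering: in PyTorch advanced indexing (as in NumPy), the index tensors are broadcast against each other, so a `[3, 1]` row index and a `[2]` column index select a 3 x 2 block of elements. The plain-Python sketch below expands broadcast index tensors, given as nested lists, into the explicit coordinate list that a scatter-based lowering would need. `broadcast_coords` and its helpers are hypothetical, written only to illustrate the rule.

```python
import itertools
from typing import Any, List, Tuple


def shape_of(t: Any) -> Tuple[int, ...]:
    """Shape of a rectangular nested list (a bare int has shape ())."""
    shape = []
    while isinstance(t, list):
        shape.append(len(t))
        t = t[0] if t else None
    return tuple(shape)


def broadcast_shape(shapes: List[Tuple[int, ...]]) -> Tuple[int, ...]:
    """Right-align the shapes; every dimension must match or be 1."""
    rank = max(len(s) for s in shapes)
    padded = [(1,) * (rank - len(s)) + s for s in shapes]
    out = []
    for dims in zip(*padded):
        sizes = {d for d in dims if d != 1}
        if len(sizes) > 1:
            raise ValueError(f"index shapes {shapes} are not broadcastable")
        out.append(sizes.pop() if sizes else 1)
    return tuple(out)


def read(t: Any, shape: Tuple[int, ...], pos: Tuple[int, ...]) -> int:
    """Read t at a position of the broadcast shape; size-1 dims repeat element 0."""
    for size, i in zip(shape, pos[len(pos) - len(shape):]):
        t = t[0 if size == 1 else i]
    return t


def broadcast_coords(*index_tensors: Any) -> List[Tuple[int, ...]]:
    """One coordinate tuple per element of the broadcast index shape (row-major)."""
    shapes = [shape_of(t) for t in index_tensors]
    out_shape = broadcast_shape(shapes)
    return [tuple(read(t, s, pos) for t, s in zip(index_tensors, shapes))
            for pos in itertools.product(*(range(n) for n in out_shape))]


print(broadcast_coords([0, 1, 2], [3, 4, 5]))  # [(0, 3), (1, 4), (2, 5)]
print(broadcast_coords([[0], [1], [2]], [3, 4]))
# [(0, 3), (0, 4), (1, 3), (1, 4), (2, 3), (2, 4)]
```

In the second case `values` would then need to broadcast to the `[3, 2]` index shape, which is the same size-1 expansion problem as in the failing scatter.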

ScottTodd · Jun 07 '24 18:06

I wrote some test cases (TBD whether they land in an existing or a new test suite): https://gist.github.com/ScottTodd/1e95795e79d17964078217ca98a3a398

iree runtime + compiler at 20240410.859:

test_single_value                  | PASS
test_multiple_values               | PASS
test_broadcast_value_along_axis    | FAIL
test_broadcast_value_along_indices | FAIL
test_broadcast_values_along_axis   | PASS

iree runtime + compiler at 20240606.916:

test_single_value                  | PASS
test_multiple_values               | PASS (then crash)
test_broadcast_value_along_axis    | FAIL
test_broadcast_value_along_indices | FAIL
test_broadcast_values_along_axis   | PASS (then crash)

The new "pass (then crash)" cases are suspicious and need debugging. I could look at IR dumps or bisect through nightly releases to find the culprit commit range.
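Bisecting nightlies is a plain binary search over an ordered list of builds. A minimal sketch, assuming the first build is known good, the last is known bad, and the crash persists once introduced. `bisect_nightlies`, the `nightly-NN` labels, and `fake_is_bad` are hypothetical placeholders, not real IREE release names or tooling; a real `is_bad` would install that build and rerun a test such as `test_multiple_values`.

```python
from typing import Callable, List, Sequence, Tuple


def bisect_nightlies(versions: Sequence[str],
                     is_bad: Callable[[str], bool]) -> Tuple[str, str]:
    """Return (last_good, first_bad) using O(log n) calls to is_bad.

    Assumes versions are ordered oldest to newest, versions[0] is good,
    versions[-1] is bad, and the failure persists once introduced.
    """
    if len(versions) < 2:
        raise ValueError("need at least one good and one bad build")
    lo, hi = 0, len(versions) - 1  # invariant: versions[lo] good, versions[hi] bad
    while hi - lo > 1:
        mid = (lo + hi) // 2
        if is_bad(versions[mid]):
            hi = mid
        else:
            lo = mid
    return versions[lo], versions[hi]


# Hypothetical stand-ins for a real list of nightly builds and a real test run.
nightlies = [f"nightly-{i:02d}" for i in range(12)]
checked: List[str] = []


def fake_is_bad(version: str) -> bool:
    checked.append(version)
    return int(version.split("-")[1]) >= 7  # pretend the regression landed in 07


print(bisect_nightlies(nightlies, fake_is_bad))  # ('nightly-06', 'nightly-07')
print(len(checked), "builds tested")             # 4 builds tested
```

The returned pair brackets the culprit commit range between two consecutive nightlies.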

The "broadcast_value_along" cases look like they are simply unimplemented. That's not too surprising, since https://pytorch.org/docs/stable/generated/torch.Tensor.index_put_.html contains almost no information about how the op should behave...

ScottTodd · Jun 10 '24 21:06

The "pass (then crash)" issues were unrelated to the op lowerings and can be worked around in IREE by using copy_buffer instead of wrap_buffer.

I would still like to see the broadcasting cases of torch.aten.index_put / tm_tensor.scatter implemented here in torch-mlir.

ScottTodd · Jun 13 '24 17:06