Decompose aten.index_put_impl to aten.index_put.hacked_twin
This patch removes the None indices in aten.index_put and then lowers the op to stablehlo and tosa. It builds on and modifies the commit from [Stablehlo] Add converter for aten._index_put_impl op #2343, and adds a decompose pass, which in turn requires rewriting the tosa index_put lowering and surfaces some torchdynamo/linalg bugs. The decompose pass is mostly inspired by https://github.com/llvm/torch-mlir/pull/2344, and some tosa index_put code is rewritten for the new decompose pass.
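As a quick illustration of what the decomposition does semantically, here is a minimal plain-PyTorch sketch (not the MLIR pass itself; the variable names are made up): a None index is replaced by an arange over that dimension, unsqueezed so it broadcasts against the remaining indices, and the op then becomes index_put.hacked_twin.
import torch

input = torch.tensor([[0, 1, 2, 3]])
index = torch.tensor([1, 2, 3])
update = torch.tensor([4, 5, 6])

# Original semantics: an index_put with a None index on dim 0.
expected = input.clone()
expected[:, index] = update

# Decomposed form: materialize the None as arange(dim_size), unsqueezed so it
# broadcasts against the explicit index, then call index_put.hacked_twin.
none_replacement = torch.arange(input.shape[0]).unsqueeze(-1)  # shape [1, 1]
decomposed = torch.ops.aten.index_put.hacked_twin(
    input.clone(), (none_replacement, index), update, False)

assert torch.equal(expected, decomposed)  # both give tensor([[0, 4, 5, 6]])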
Demand comes from https://github.com/nod-ai/SHARK/issues/1336. Newly generated T5 model: https://storage.googleapis.com/shark_tank/chi-nod/t5_small/stablehlo/t5small_stablehlo_0829_transformers4.26.0_elide.mlir
6 different index broadcast cases need to be supported (a short broadcast sketch follows the cases):
# For the first 3 cases, index2 is torch.Size([3])
# Case 1
input = torch.tensor([[0, 1, 2, 3]])
index1 = torch.tensor([[0]])
index2 = torch.tensor([1,2,3])
update = torch.tensor([4, 5, 6])
output = torch.ops.aten.index_put.hacked_twin(input, (index1, index2), update)
print("index1.shape: ", index1.shape) # torch.Size([1, 1])
print("index2.shape: ", index2.shape) # torch.Size([3])
print(output) # tensor([[0, 4, 5, 6]])
# Case 2
input = torch.tensor([[0, 1, 2, 3]])
index1 = torch.tensor([[0,0,0]])
index2 = torch.tensor([1,2,3])
update = torch.tensor([4, 5, 6])
output = torch.ops.aten.index_put.hacked_twin(input, (index1, index2), update)
print("index1.shape: ", index1.shape) # torch.Size([1, 3])
print("index2.shape: ", index2.shape) # torch.Size([3])
print(output) # tensor([[0, 4, 5, 6]])
# Case 3
input = torch.tensor([[0, 1, 2, 3]])
index1 = torch.tensor([0,0,0])
index2 = torch.tensor([1,2,3])
update = torch.tensor([4, 5, 6])
output = torch.ops.aten.index_put.hacked_twin(input, (index1, index2), update)
print("index1.shape: ", index1.shape) # torch.Size([1, 3])
print("index2.shape: ", index2.shape) # torch.Size([3])
print(output) # tensor([[0, 4, 5, 6]])
# For the next 3 cases, index2 is torch.Size([1, 3])
# Case 4
input = torch.tensor([[0, 1, 2, 3]])
index1 = torch.tensor([[0]]) # torch.Size([1,1])
index2 = torch.tensor([[1,2,3]]) # torch.Size([1, 3])
update = torch.tensor([4, 5, 6])
output = torch.ops.aten.index_put.hacked_twin(input, (index1, index2), update)
print("index1.shape: ", index1.shape) # torch.Size([1, 1])
print("index2.shape: ", index2.shape) # torch.Size([1, 3])
print(output) # tensor([[0, 4, 5, 6]])
# Case 5
input = torch.tensor([[0, 1, 2, 3]])
index1 = torch.tensor([[0,0,0]]) # torch.Size([1, 3])
index2 = torch.tensor([[1,2,3]]) # torch.Size([1, 3])
update = torch.tensor([4, 5, 6])
output = torch.ops.aten.index_put.hacked_twin(input, (index1, index2), update)
print("index1.shape: ", index1.shape) # torch.Size([1, 3])
print("index2.shape: ", index2.shape) # torch.Size([1, 3])
print(output) # tensor([[0, 4, 5, 6]])
# Case 6
input = torch.tensor([[0, 1, 2, 3]])
index1 = torch.tensor([0,0,0]) # torch.Size([3])
index2 = torch.tensor([[1,2,3]]) # torch.Size([1, 3])
update = torch.tensor([4, 5, 6])
output = torch.ops.aten.index_put.hacked_twin(input, (index1, index2), update)
print("index1.shape: ", index1.shape) # torch.Size([3])
print("index2.shape: ", index2.shape) # torch.Size([1, 3])
print(output) # tensor([[0, 4, 5, 6]])
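For reference, the six cases above all come down to broadcasting the two index tensors against each other before building the scatter indices; a minimal sketch of that (plain PyTorch, not the lowering code):
import torch

index1_variants = (torch.tensor([[0]]), torch.tensor([[0, 0, 0]]), torch.tensor([0, 0, 0]))
index2_variants = (torch.tensor([1, 2, 3]), torch.tensor([[1, 2, 3]]))
for index1 in index1_variants:
    for index2 in index2_variants:
        b1, b2 = torch.broadcast_tensors(index1, index2)
        # Common broadcast shape: torch.Size([1, 3]) (or torch.Size([3]) for case 3).
        print(index1.shape, index2.shape, "->", b1.shape)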
For anyone (@Vremold @RamirezLucas @mgehre-amd @vivekkhandelwal1 @eric-k256) who might review or test, you can cherry-pick this patch and use this test code directly, test_indexput_hacketwin_35.py:
import torch
import torch_mlir

class Net(torch.nn.Module):
    def __init__(self) -> None:
        super().__init__()

    def forward(self, input, index1, index2, src):
        return torch.index_put(input, indices=(index1, index2), values=src, accumulate=False)

m = Net()
src = torch.arange(1, 6)
index1 = torch.tensor([0, 0, 0, 0, 0])
index2 = torch.tensor([1, 2, 3, 4, 0])
input = torch.arange(10, 25, step=1, dtype=src.dtype).view(3, 5)
module_stablehlo = torch_mlir.compile(m, [input, index1, index2, src], output_type="stablehlo")
print(module_stablehlo.operation.get_asm())
module_tosa = torch_mlir.compile(m, [input, index1, index2, src], output_type="tosa")
print(module_tosa.operation.get_asm())
When tested with the T5 model, it breaks before the index_put op is generated by the slice-and-copy pattern:
File "/nodclouddata/chi/src/torch-mlir/build/tools/torch-mlir/python_packages/torch_mlir/torch_mlir/compiler_utils.py", line 69, in run_pipeline_with_repro_report
raise TorchMlirCompilerError(trimmed_message) from None
torch_mlir.compiler_utils.TorchMlirCompilerError: Lowering TorchScript IR -> Torch Backend IR failed with the following diagnostics:
python exception: Failure while executing pass pipeline:
error: "aten::copy_"("<eval_with_key>.2 from /nodclouddata/chi/src/SHARK/shark_venv/lib/python3.11/site-packages/torch/fx/experimental/proxy_tensor.py:477 in wrapped":9:12): found an op that was marked as backend illegal
note: "aten::copy_"("<eval_with_key>.2 from /nodclouddata/chi/src/SHARK/shark_venv/lib/python3.11/site-packages/torch/fx/experimental/proxy_tensor.py:477 in wrapped":9:12): see current operation: %166 = "torch.aten.index_put.hacked_twin"(%159, %165, %161, %7) : (!torch.vtensor<[1,4],si64>, !torch.list<vtensor>, !torch.vtensor<[1,3],si64>, !torch.bool) -> !torch.vtensor<[1,4],si64>
note: "aten::copy_"("<eval_with_key>.2 from /nodclouddata/chi/src/SHARK/shark_venv/lib/python3.11/site-packages/torch/fx/experimental/proxy_tensor.py:477 in wrapped":9:12): this is likely due to DecomposeComplexOps being unable to decompose this op
For Torch-MLIR developers, the error can be reproduced with:
$ torch-mlir-opt -pass-pipeline='builtin.module(torchscript-module-to-torch-backend-pipeline{backend-legal-ops= extra-library=})' /tmp/_lambda.mlir
Add '-mlir-print-ir-after-all -mlir-disable-threading' to get the IR dump for debugging purpose.
test_slicecopy.py gets through the decompose pass, but then hits a bug when lowering to stablehlo:
Legalizing operation : 'torch.aten.index_put.hacked_twin'(0xeb61fd0) {
%90 = "torch.aten.index_put.hacked_twin"(%15, %89, %49, %10) : (!torch.vtensor<[1,4],si64>, !torch.list<vtensor>, !torch.vtensor<[1,3],si64>, !torch.bool) -> !torch.vtensor<[1,4],si64>
* Fold {
} -> FAILURE : unable to fold
* Pattern : 'torch.aten.index_put.hacked_twin -> ()' {
Trying to match "mlir::torch::torch_to_stablehlo::ConvertAtenOp<mlir::torch::Torch::AtenIndexPutHackedTwinOp>"
** Insert : 'torch_c.to_builtin_tensor'(0xebdf150)
** Insert : 'torch_c.to_builtin_tensor'(0xebdf1e0)
** Insert : 'stablehlo.reshape'(0xebdf2e0)
** Insert : 'stablehlo.reshape'(0xebdf440)
** Failure : unimplemented: Only support multi indexes with same shape
"mlir::torch::torch_to_stablehlo::ConvertAtenOp<mlir::torch::Torch::AtenIndexPutHackedTwinOp>" result 0
} -> FAILURE : pattern failed to match
} -> FAILURE : no matched legalization pattern
//===-------------------------------------------===//
/nodclouddata/chi/src/models/t5/slicecopy/test_indexputhackedtwin.mlir:17:10: error: failed to legalize operation 'torch.aten.index_put.hacked_twin' that was explicitly marked illegal
%8 = torch.aten.index_put.hacked_twin %1, %7, %3, %false : !torch.vtensor<[1,4],si64>, !torch.list<vtensor>, !torch.vtensor<[1,3],si64>, !torch.bool -> !torch.vtensor<[1,4],si64>
indexput_hacked_twin_debuginfo.txt
error: number of output elements (3) doesn't match expected number of elements (1)
The issue is with this reshape op: https://github.com/llvm/torch-mlir/pull/2407/files#diff-37f576362dfffc48c6d013a2edd9357d53d2bb08345a7d8e7f7217cea90701c8R610-R611. It generates %92 = "stablehlo.reshape"(%90) : (tensor<1x1xi64>) -> tensor<1x3x1xi64>, which is not a plain reshape but a reshape plus broadcast, and that kind of reshape is not supported in stablehlo. We have to first reshape tensor<1x1xi64> -> tensor<1x1x1xi64>, and then broadcast it tensor<1x1x1xi64> -> tensor<1x3x1xi64>.
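In plain PyTorch terms (a sketch, not the converter code), the single reshape is invalid because the element counts differ, so the lowering has to reshape first and then broadcast:
import torch

idx = torch.zeros(1, 1, dtype=torch.int64)   # plays the role of tensor<1x1xi64>
# idx.reshape(1, 3, 1) would fail: 1 element cannot be reshaped into 3.
step1 = idx.reshape(1, 1, 1)                  # reshape:   1x1   -> 1x1x1
step2 = step1.expand(1, 3, 1)                 # broadcast: 1x1x1 -> 1x3x1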
index_put.hacked_twin to stablehlo is DONE. Here is the Python e2e test:
import torch
import torch_mlir

class Net(torch.nn.Module):
    def __init__(self) -> None:
        super().__init__()

    def forward(self, input, index1, index2, src):
        return torch.index_put(input, indices=(index1, index2), values=src, accumulate=False)

m = Net()
src = torch.arange(1, 6)
index1 = torch.tensor([0, 0, 0, 0, 0])
index2 = torch.tensor([1, 2, 3, 4, 0])
input = torch.arange(10, 25, step=1, dtype=src.dtype).view(3, 5)
m = torch_mlir.compile(m, [input, index1, index2, src], output_type="stablehlo")
print(m.operation.get_asm())
'''
module attributes {torch.debug_module_name = "Net"} {
func.func @forward(%arg0: tensor<3x5xi64>, %arg1: tensor<5xi64>, %arg2: tensor<5xi64>, %arg3: tensor<5xi64>) -> tensor<3x5xi64> {
%0 = stablehlo.reshape %arg1 : (tensor<5xi64>) -> tensor<5x1xi64>
%1 = stablehlo.reshape %arg2 : (tensor<5xi64>) -> tensor<5x1xi64>
%2 = stablehlo.concatenate %0, %1, dim = 1 : (tensor<5x1xi64>, tensor<5x1xi64>) -> tensor<5x2xi64>
%3 = stablehlo.reshape %arg3 : (tensor<5xi64>) -> tensor<5x1xi64>
%4 = stablehlo.reshape %2 : (tensor<5x2xi64>) -> tensor<5x2xi64>
%5 = "stablehlo.scatter"(%arg0, %4, %3) ({
^bb0(%arg4: tensor<i64>, %arg5: tensor<i64>):
stablehlo.return %arg5 : tensor<i64>
}) {indices_are_sorted = false, scatter_dimension_numbers = #stablehlo.scatter<update_window_dims = [1], inserted_window_dims = [0], scatter_dims_to_operand_dims = [0, 1], index_vector_dim = 1>, unique_indices = false} : (tensor<3x5xi64>, tensor<5x2xi64>, tensor<5x1xi64>) -> tensor<3x5xi64>
return %5 : tensor<3x5xi64>
}
}
'''
Adding 34 index_put-related torchdynamo xfails is DONE. Next: fix the TorchDynamo e2e test crash, add the linalg e2e failures, and rewrite the tosa index_put lowering. TorchDynamo e2e test:
python -m e2e_testing.main -f "UnsafeIndexPutHackedTwin1DFloatNonAccumulateModule_basic"
python3.11: /nodclouddata/chi/src/torch-mlir/include/torch-mlir/Dialect/Torch/IR/TorchTypes.h:61: ArrayRef<int64_t> mlir::torch::Torch::BaseTensorType::getSizes() const: Assertion `hasSizes() && "must have sizes"' failed.
[1] 34802 abort (core dumped) python3.11 -m e2e_testing.main -f
This bug comes from the index_put decomposition: a return was missing before a rewriter.notifyMatchFailure() call.
The TorchDynamo e2e test crash is DONE. Adding 41 linalg e2e xfails is DONE. https://github.com/llvm/torch-mlir/actions/runs/6026994334/job/16351185680
Fixing tosa index_put support is DONE, and the previously broken tosa e2e test IndexPutImpl2DNoneIndexStaticModule_basic passes again.
module attributes {torch.debug_module_name = "IndexPutImpl2DNoneIndexStaticModule"} {
func.func @forward(%arg0: !torch.vtensor<[1,4],si64> loc(unknown), %arg1: !torch.vtensor<[3],si64> loc(unknown), %arg2: !torch.vtensor<[1,3],si64> loc(unknown)) -> !torch.vtensor<[1,4],si64> {
%none = torch.constant.none loc(#loc3)
%int4 = torch.constant.int 4 loc(#loc3)
%int1 = torch.constant.int 1 loc(#loc3)
%int0 = torch.constant.int 0 loc(#loc3)
%int-1 = torch.constant.int -1 loc(#loc3)
%false = torch.constant.bool false loc(#loc2)
%0 = torch.aten.arange.start_step %int0, %int1, %int1, %int4, %none, %none, %none : !torch.int, !torch.int, !torch.int, !torch.int, !torch.none, !torch.none, !torch.none -> !torch.vtensor<[1],si64> loc(#loc3)
%1 = torch.aten.unsqueeze %0, %int-1 : !torch.vtensor<[1],si64>, !torch.int -> !torch.vtensor<[1,1],si64> loc(#loc3)
%2 = torch.prim.ListConstruct %1, %arg1 : (!torch.vtensor<[1,1],si64>, !torch.vtensor<[3],si64>) -> !torch.list<vtensor> loc(#loc3)
%3 = torch.aten.index_put.hacked_twin %arg0, %2, %arg2, %false : !torch.vtensor<[1,4],si64>, !torch.list<vtensor>, !torch.vtensor<[1,3],si64>, !torch.bool -> !torch.vtensor<[1,4],si64> loc(#loc3)
return %3 : !torch.vtensor<[1,4],si64> loc(#loc)
} loc(#loc)
} loc(#loc)
->
module attributes {torch.debug_module_name = "IndexPutImpl2DNoneIndexStaticModule"} {
func.func @forward(%arg0: !torch.vtensor<[1,4],si64>, %arg1: !torch.vtensor<[3],si64>, %arg2: !torch.vtensor<[1,3],si64>) -> !torch.vtensor<[1,4],si64> {
%0 = torch_c.to_builtin_tensor %arg0 : !torch.vtensor<[1,4],si64> -> tensor<1x4xi64>
%1 = torch_c.to_builtin_tensor %arg2 : !torch.vtensor<[1,3],si64> -> tensor<1x3xi64>
%none = torch.constant.none
%int4 = torch.constant.int 4
%int1 = torch.constant.int 1
%int0 = torch.constant.int 0
%int-1 = torch.constant.int -1
%false = torch.constant.bool false
%2 = "tosa.const"() <{value = dense<0> : tensor<1xi64>}> : () -> tensor<1xi64>
%3 = "tosa.cast"(%2) : (tensor<1xi64>) -> tensor<1xi64>
%4 = "tosa.reshape"(%3) <{new_shape = array<i64: 1, 1>}> : (tensor<1xi64>) -> tensor<1x1xi64>
%5 = torch_c.from_builtin_tensor %4 : tensor<1x1xi64> -> !torch.vtensor<[1,1],si64>
%6 = torch.prim.ListConstruct %5, %arg1 : (!torch.vtensor<[1,1],si64>, !torch.vtensor<[3],si64>) -> !torch.list<vtensor>
%7 = torch_c.to_builtin_tensor %5 : !torch.vtensor<[1,1],si64> -> tensor<1x1xi64>
%8 = torch_c.to_builtin_tensor %arg1 : !torch.vtensor<[3],si64> -> tensor<3xi64>
%9 = "tosa.cast"(%7) : (tensor<1x1xi64>) -> tensor<1x1xi32>
%10 = "tosa.reshape"(%9) <{new_shape = array<i64: 1, 1, 1>}> : (tensor<1x1xi32>) -> tensor<1x1x1xi32>
%11 = "tosa.const"() <{value = dense<0> : tensor<1x3x1xi32>}> : () -> tensor<1x3x1xi32>
%12 = "tosa.add"(%10, %11) : (tensor<1x1x1xi32>, tensor<1x3x1xi32>) -> tensor<1x3x1xi32>
%13 = "tosa.cast"(%8) : (tensor<3xi64>) -> tensor<3xi32>
%14 = "tosa.reshape"(%13) <{new_shape = array<i64: 3, 1>}> : (tensor<3xi32>) -> tensor<3x1xi32>
%15 = "tosa.const"() <{value = dense<0> : tensor<1x3x1xi32>}> : () -> tensor<1x3x1xi32>
%16 = "tosa.add"(%14, %15) : (tensor<3x1xi32>, tensor<1x3x1xi32>) -> tensor<1x3x1xi32>
%17 = "tosa.concat"(%12, %16) <{axis = 2 : i64}> : (tensor<1x3x1xi32>, tensor<1x3x1xi32>) -> tensor<1x3x2xi32>
%18 = "tosa.reshape"(%1) <{new_shape = array<i64: 1, 3, 1>}> : (tensor<1x3xi64>) -> tensor<1x3x1xi64>
%19 = "tosa.reshape"(%0) <{new_shape = array<i64: 1, 4, 1>}> : (tensor<1x4xi64>) -> tensor<1x4x1xi64>
%20 = "tosa.reshape"(%17) <{new_shape = array<i64: 3, 2>}> : (tensor<1x3x2xi32>) -> tensor<3x2xi32>
%21 = "tosa.const"() <{value = dense<[4, 1]> : tensor<2xi32>}> : () -> tensor<2xi32>
%22 = "tosa.mul"(%20, %21) <{shift = 0 : i32}> : (tensor<3x2xi32>, tensor<2xi32>) -> tensor<3x2xi32>
%23 = "tosa.reduce_sum"(%22) <{axis = 1 : i64}> : (tensor<3x2xi32>) -> tensor<3x1xi32>
%24 = "tosa.reshape"(%23) <{new_shape = array<i64: 1, 3>}> : (tensor<3x1xi32>) -> tensor<1x3xi32>
%25 = "tosa.scatter"(%19, %24, %18) : (tensor<1x4x1xi64>, tensor<1x3xi32>, tensor<1x3x1xi64>) -> tensor<1x4x1xi64>
%26 = "tosa.reshape"(%25) <{new_shape = array<i64: 1, 4>}> : (tensor<1x4x1xi64>) -> tensor<1x4xi64>
%27 = torch_c.from_builtin_tensor %26 : tensor<1x4xi64> -> !torch.vtensor<[1,4],si64>
return %27 : !torch.vtensor<[1,4],si64>
}
}
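For reference, the [4, 1] constant, tosa.mul, and tosa.reduce_sum in the lowering above flatten the 2-D (row, col) indices into linear positions for tosa.scatter; a small PyTorch sketch of that arithmetic (the names are made up):
import torch

rows = torch.tensor([0, 0, 0])
cols = torch.tensor([1, 2, 3])
strides = torch.tensor([4, 1])   # row-major strides of the [1, 4] input
flat = (torch.stack([rows, cols], dim=-1) * strides).sum(-1)
print(flat)  # tensor([1, 2, 3]): the flat positions written by tosa.scatter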
A commit from another WIP patch was accidentally merged; it has been cleaned up.
Removing the iostream include and the stablehlo crash line is DONE.
The linalg e2e xfails and 33 torchdynamo xfails are fixed by changing the TMTensor lowering's op name from index_put to index_put.hacked_twin. But in linalg/TMTensor, 2 more e2e xfails still need to be fixed:
"IndexPutImplIndexWithNoneModule_basic",
"SliceCopyNonZeroDim_Module_basic",
IndexPutImplIndexWithNoneModule_basic was brought in by https://github.com/llvm/torch-mlir/pull/1762. It would be great if @ramiro050 and @vivekkhandelwal1 could take a look at it and give some advice; the algorithm in the TMTensor index_put implementation does not look intuitive to me.
After decomposition, it looks like:
module attributes {torch.debug_module_name = "IndexPutImplIndexWithNoneModule"} {
func.func @forward(%arg0: !torch.vtensor<[2,3,4,5],f32> loc(unknown), %arg1: !torch.vtensor<[6,1],si64> loc(unknown), %arg2: !torch.vtensor<[7],si64> loc(unknown), %arg3: !torch.vtensor<[2,3,6,7],f32> loc(unknown)) -> !torch.vtensor<[2,3,4,5],f32> {
%none = torch.constant.none loc(#loc1)
%int4 = torch.constant.int 4 loc(#loc4)
%int3 = torch.constant.int 3 loc(#loc4)
%int0 = torch.constant.int 0 loc(#loc4)
%int1 = torch.constant.int 1 loc(#loc4)
%int-1 = torch.constant.int -1 loc(#loc4)
%int2 = torch.constant.int 2 loc(#loc4)
%true = torch.constant.bool true loc(#loc3)
%0 = torch.aten.arange.start_step %int0, %int3, %int1, %int4, %none, %none, %none : !torch.int, !torch.int, !torch.int, !torch.int, !torch.none, !torch.none, !torch.none -> !torch.vtensor<[3],si64> loc(#loc4)
%1 = torch.aten.unsqueeze %0, %int-1 : !torch.vtensor<[3],si64>, !torch.int -> !torch.vtensor<[3,1],si64> loc(#loc4)
%2 = torch.aten.unsqueeze %1, %int-1 : !torch.vtensor<[3,1],si64>, !torch.int -> !torch.vtensor<[3,1,1],si64> loc(#loc4)
%3 = torch.aten.arange.start_step %int0, %int2, %int1, %int4, %none, %none, %none : !torch.int, !torch.int, !torch.int, !torch.int, !torch.none, !torch.none, !torch.none -> !torch.vtensor<[2],si64> loc(#loc4)
%4 = torch.aten.unsqueeze %3, %int-1 : !torch.vtensor<[2],si64>, !torch.int -> !torch.vtensor<[2,1],si64> loc(#loc4)
%5 = torch.aten.unsqueeze %4, %int-1 : !torch.vtensor<[2,1],si64>, !torch.int -> !torch.vtensor<[2,1,1],si64> loc(#loc4)
%6 = torch.aten.unsqueeze %5, %int-1 : !torch.vtensor<[2,1,1],si64>, !torch.int -> !torch.vtensor<[2,1,1,1],si64> loc(#loc4)
%7 = torch.prim.ListConstruct %6, %2, %arg1, %arg2 : (!torch.vtensor<[2,1,1,1],si64>, !torch.vtensor<[3,1,1],si64>, !torch.vtensor<[6,1],si64>, !torch.vtensor<[7],si64>) -> !torch.list<vtensor> loc(#loc4)
%8 = torch.aten.index_put.hacked_twin %arg0, %7, %arg3, %true : !torch.vtensor<[2,3,4,5],f32>, !torch.list<vtensor>, !torch.vtensor<[2,3,6,7],f32>, !torch.bool -> !torch.vtensor<[2,3,4,5],f32> loc(#loc4)
return %8 : !torch.vtensor<[2,3,4,5],f32> loc(#loc)
} loc(#loc)
}
Let me take a look at the xfails and see if I find anything
The issue is that the TorchToTMTensor pattern expects the op to only have 2 tensors in the indices list of tensors (for example: (None, None, tensor1, tensor2) would be valid). However, because now we are replacing the Nones with tensors, the check for 2 tensors fails. We need to generalize the pattern to handle more than 2 tensors. I can help out with this on Friday.
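For reference, a minimal PyTorch sketch (shapes taken from the IndexPutImplIndexWithNoneModule example above; this is an illustration, not the TMTensor pattern) of what "more than 2 tensors" means after decomposition: every None becomes an arange index, so the indices list can hold one tensor per dimension, and the tensors only need to be broadcast to a common shape.
import torch

input = torch.randn(2, 3, 4, 5)
values = torch.randn(2, 3, 6, 7)
indices = [
    torch.arange(2).view(2, 1, 1, 1),   # decomposed None for dim 0
    torch.arange(3).view(3, 1, 1),      # decomposed None for dim 1
    torch.randint(0, 4, (6, 1)),        # user index for dim 2
    torch.randint(0, 5, (7,)),          # user index for dim 3
]
out = torch.ops.aten.index_put.hacked_twin(input.clone(), indices, values, True)

# The same call with all four index tensors pre-broadcast to one shape
# ([2, 3, 6, 7]) gives the same result, which is what a generalized pattern can rely on.
broadcast = list(torch.broadcast_tensors(*indices))
out2 = torch.ops.aten.index_put.hacked_twin(input.clone(), broadcast, values, True)
assert torch.allclose(out, out2)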