
Decompose aten.index_put_impl to aten.index_put.hacked_twin

Open · AmosLewis opened this issue · 12 comments

This patch removes the None indices from aten.index_put by decomposing aten._index_put_impl to aten.index_put.hacked_twin, which is then lowered to stablehlo and tosa. It uses and modifies the commit from [Stablehlo] Add converter for aten._index_put_impl op #2343 and adds a decompose pass, which in turn requires rewriting the tosa index_put lowering and exposes some torchdynamo/linalg bugs. The decompose pass is largely inspired by https://github.com/llvm/torch-mlir/pull/2344, and some of the tosa index_put code is rewritten for it.
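
To illustrate the idea, here is a minimal eager-mode sketch of what the decomposition does (the actual pass is implemented in C++; the tensor values here are made up): each None index is replaced by an arange over that dimension, unsqueezed so it broadcasts against the explicit indices.

import torch

input = torch.tensor([[0, 1, 2, 3]])
index = torch.tensor([1, 2, 3])     # explicit index for dim 1; dim 0 index is None
values = torch.tensor([4, 5, 6])

# index_put with a None index for dim 0 ...
expected = torch.ops.aten.index_put(input, (None, index), values, False)

# ... is equivalent to index_put.hacked_twin with the None replaced by an
# unsqueezed arange over dim 0, which broadcasts against the explicit index
dim0 = torch.arange(input.shape[0]).unsqueeze(-1)   # shape [1, 1]
actual = torch.ops.aten.index_put.hacked_twin(input, (dim0, index), values, False)

assert torch.equal(expected, actual)  # both give tensor([[0, 4, 5, 6]])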

The demand comes from https://github.com/nod-ai/SHARK/issues/1336. Newly generated T5 model: https://storage.googleapis.com/shark_tank/chi-nod/t5_small/stablehlo/t5small_stablehlo_0829_transformers4.26.0_elide.mlir

Six different index broadcast cases need to be supported:

import torch

# In the first 3 cases, index2 has shape torch.Size([3])
# Case 1
input = torch.tensor([[0, 1, 2, 3]])
index1 = torch.tensor([[0]])
index2 = torch.tensor([1,2,3])
update = torch.tensor([4, 5, 6])
output = torch.ops.aten.index_put.hacked_twin(input, (index1, index2), update)
print("index1.shape: ", index1.shape) # torch.Size([1, 1])
print("index2.shape: ", index2.shape) # torch.Size([3])
print(output) # tensor([[0, 4, 5, 6]])
# Case 2
input = torch.tensor([[0, 1, 2, 3]])
index1 = torch.tensor([[0,0,0]])
index2 = torch.tensor([1,2,3])
update = torch.tensor([4, 5, 6])
output = torch.ops.aten.index_put.hacked_twin(input, (index1, index2), update)
print("index1.shape: ", index1.shape) # torch.Size([1, 3])
print("index2.shape: ", index2.shape) # torch.Size([3])
print(output) # tensor([[0, 4, 5, 6]])
# Case 3
input = torch.tensor([[0, 1, 2, 3]])
index1 = torch.tensor([0,0,0])
index2 = torch.tensor([1,2,3])
update = torch.tensor([4, 5, 6])
output = torch.ops.aten.index_put.hacked_twin(input, (index1, index2), update)
print("index1.shape: ", index1.shape) # torch.Size([3])
print("index2.shape: ", index2.shape) # torch.Size([3])
print(output) # tensor([[0, 4, 5, 6]])

# In the next 3 cases, index2 has shape torch.Size([1, 3])
# Case 4
input = torch.tensor([[0, 1, 2, 3]])
index1 = torch.tensor([[0]]) # torch.Size([1,1])
index2 = torch.tensor([[1,2,3]]) # torch.Size([1, 3])
update = torch.tensor([4, 5, 6])
output = torch.ops.aten.index_put.hacked_twin(input, (index1, index2), update)
print("index1.shape: ", index1.shape) # torch.Size([1, 1])
print("index2.shape: ", index2.shape) # torch.Size([1, 3])
print(output) # tensor([[0, 4, 5, 6]])
# Case 5
input = torch.tensor([[0, 1, 2, 3]])
index1 = torch.tensor([[0,0,0]]) # torch.Size([1, 3])
index2 = torch.tensor([[1,2,3]]) # torch.Size([1, 3])
update = torch.tensor([4, 5, 6])
output = torch.ops.aten.index_put.hacked_twin(input, (index1, index2), update)
print("index1.shape: ", index1.shape) # torch.Size([1, 3])
print("index2.shape: ", index2.shape) # torch.Size([1, 3])
print(output) # tensor([[0, 4, 5, 6]])
# Case 6
input = torch.tensor([[0, 1, 2, 3]])
index1 = torch.tensor([0,0,0]) # torch.Size([3])
index2 = torch.tensor([[1,2,3]]) # torch.Size([1, 3])
update = torch.tensor([4, 5, 6])
output = torch.ops.aten.index_put.hacked_twin(input, (index1, index2), update)
print("index1.shape: ", index1.shape) # torch.Size([3])
print("index2.shape: ", index2.shape) # torch.Size([1, 3])
print(output) # tensor([[0, 4, 5, 6]])
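
All six cases reduce to the same update once the two index tensors are broadcast to a common shape [1, 3]. A quick way to check this in eager mode (an illustration, not part of the patch):

import torch

index1 = torch.tensor([[0]])      # Case 1: torch.Size([1, 1])
index2 = torch.tensor([1, 2, 3])  # torch.Size([3])
b1, b2 = torch.broadcast_tensors(index1, index2)
print(b1.shape, b2.shape)  # torch.Size([1, 3]) torch.Size([1, 3])
print(b1)  # tensor([[0, 0, 0]])
print(b2)  # tensor([[1, 2, 3]])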

For anyone (@Vremold @RamirezLucas @mgehre-amd @vivekkhandelwal1 @eric-k256) who might review/test: you can cherry-pick this patch and use the following test code directly (test_indexput_hacketwin_35.py):

import torch
import torch_mlir

class Net(torch.nn.Module):
    def __init__(self) -> None:
        super().__init__()
    def forward(self, input, index1, index2, src):
        return torch.index_put(input, indices=(index1, index2), values=src, accumulate=False)

m = Net()
src = torch.arange(1, 6)
index1 = torch.tensor([0, 0, 0, 0, 0])
index2 = torch.tensor([1, 2, 3, 4, 0])
input = torch.arange(10, 25, step=1, dtype=src.dtype).view(3, 5)
compiled = torch_mlir.compile(m, [input, index1, index2, src], output_type="stablehlo")
print(compiled.operation.get_asm())
compiled = torch_mlir.compile(m, [input, index1, index2, src], output_type="tosa")
print(compiled.operation.get_asm())

AmosLewis · Aug 21 '23

When tested with the T5 model, it breaks on the index_put.hacked_twin op generated from the slice-and-copy pattern (via aten::copy_):

  File "/nodclouddata/chi/src/torch-mlir/build/tools/torch-mlir/python_packages/torch_mlir/torch_mlir/compiler_utils.py", line 69, in run_pipeline_with_repro_report
    raise TorchMlirCompilerError(trimmed_message) from None
torch_mlir.compiler_utils.TorchMlirCompilerError: Lowering TorchScript IR -> Torch Backend IR failed with the following diagnostics:


python exception: Failure while executing pass pipeline:
error: "aten::copy_"("<eval_with_key>.2 from /nodclouddata/chi/src/SHARK/shark_venv/lib/python3.11/site-packages/torch/fx/experimental/proxy_tensor.py:477 in wrapped":9:12): found an op that was marked as backend illegal
note: "aten::copy_"("<eval_with_key>.2 from /nodclouddata/chi/src/SHARK/shark_venv/lib/python3.11/site-packages/torch/fx/experimental/proxy_tensor.py:477 in wrapped":9:12): see current operation: %166 = "torch.aten.index_put.hacked_twin"(%159, %165, %161, %7) : (!torch.vtensor<[1,4],si64>, !torch.list<vtensor>, !torch.vtensor<[1,3],si64>, !torch.bool) -> !torch.vtensor<[1,4],si64>
note: "aten::copy_"("<eval_with_key>.2 from /nodclouddata/chi/src/SHARK/shark_venv/lib/python3.11/site-packages/torch/fx/experimental/proxy_tensor.py:477 in wrapped":9:12): this is likely due to DecomposeComplexOps being unable to decompose this op

For Torch-MLIR developers, the error can be reproduced with:
$ torch-mlir-opt -pass-pipeline='builtin.module(torchscript-module-to-torch-backend-pipeline{backend-legal-ops= extra-library=})' /tmp/_lambda.mlir
Add '-mlir-print-ir-after-all -mlir-disable-threading' to get the IR dump for debugging purpose.

AmosLewis · Aug 21 '23

test_slicecopy.py: the decompose pass now succeeds, but lowering to stablehlo then hits a bug:

Legalizing operation : 'torch.aten.index_put.hacked_twin'(0xeb61fd0) {
  %90 = "torch.aten.index_put.hacked_twin"(%15, %89, %49, %10) : (!torch.vtensor<[1,4],si64>, !torch.list<vtensor>, !torch.vtensor<[1,3],si64>, !torch.bool) -> !torch.vtensor<[1,4],si64>

  * Fold {
  } -> FAILURE : unable to fold

  * Pattern : 'torch.aten.index_put.hacked_twin -> ()' {
Trying to match "mlir::torch::torch_to_stablehlo::ConvertAtenOp<mlir::torch::Torch::AtenIndexPutHackedTwinOp>"
    ** Insert  : 'torch_c.to_builtin_tensor'(0xebdf150)
    ** Insert  : 'torch_c.to_builtin_tensor'(0xebdf1e0)
    ** Insert  : 'stablehlo.reshape'(0xebdf2e0)
    ** Insert  : 'stablehlo.reshape'(0xebdf440)
    ** Failure : unimplemented: Only support multi indexes with same shape
"mlir::torch::torch_to_stablehlo::ConvertAtenOp<mlir::torch::Torch::AtenIndexPutHackedTwinOp>" result 0
  } -> FAILURE : pattern failed to match
} -> FAILURE : no matched legalization pattern
//===-------------------------------------------===//
/nodclouddata/chi/src/models/t5/slicecopy/test_indexputhackedtwin.mlir:17:10: error: failed to legalize operation 'torch.aten.index_put.hacked_twin' that was explicitly marked illegal
    %8 = torch.aten.index_put.hacked_twin %1, %7, %3, %false : !torch.vtensor<[1,4],si64>, !torch.list<vtensor>, !torch.vtensor<[1,3],si64>, !torch.bool -> !torch.vtensor<[1,4],si64>

AmosLewis · Aug 22 '23

From indexput_hacked_twin_debuginfo.txt: error: number of output elements (3) doesn't match expected number of elements (1)

The issue is with this reshape op: https://github.com/llvm/torch-mlir/pull/2407/files#diff-37f576362dfffc48c6d013a2edd9357d53d2bb08345a7d8e7f7217cea90701c8R610-R611. It generates %92 = "stablehlo.reshape"(%90) : (tensor<1x1xi64>) -> tensor<1x3x1xi64>, which is not a pure reshape but a reshape plus broadcast, and that kind of reshape is not supported in stablehlo. We have to first reshape tensor<1x1xi64> -> tensor<1x1x1xi64>, and then broadcast it: tensor<1x1x1xi64> -> tensor<1x3x1xi64>.
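
The same constraint is visible in eager PyTorch, where reshape cannot change the element count and the broadcast has to be a separate step (an analogy to the stablehlo fix, not the pass code itself):

import torch

x = torch.zeros(1, 1, dtype=torch.int64)
# x.reshape(1, 3, 1) would fail: reshape cannot change the number of elements
y = x.reshape(1, 1, 1)           # tensor<1x1xi64> -> tensor<1x1x1xi64>
z = y.broadcast_to(1, 3, 1)      # tensor<1x1x1xi64> -> tensor<1x3x1xi64>
print(z.shape)  # torch.Size([1, 3, 1])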

AmosLewis · Aug 29 '23

index_put.hacked_twin to stablehlo: DONE. Here is the Python e2e test:

import torch
import torch_mlir
class Net(torch.nn.Module):
    def __init__(self) -> None:
        super().__init__()
    def forward(self, input, index1, index2, src):
        return torch.index_put(input, indices=(index1, index2), values=src, accumulate=False)
m = Net()
src = torch.arange(1, 6)
index1 = torch.tensor([0, 0, 0, 0, 0])
index2 = torch.tensor([1, 2, 3, 4, 0])
input = torch.arange(10, 25, step=1, dtype=src.dtype).view(3, 5)
m = torch_mlir.compile(m, [input, index1, index2, src], output_type="stablehlo")
print(m.operation.get_asm())

'''
module attributes {torch.debug_module_name = "Net"} {
  func.func @forward(%arg0: tensor<3x5xi64>, %arg1: tensor<5xi64>, %arg2: tensor<5xi64>, %arg3: tensor<5xi64>) -> tensor<3x5xi64> {
    %0 = stablehlo.reshape %arg1 : (tensor<5xi64>) -> tensor<5x1xi64>
    %1 = stablehlo.reshape %arg2 : (tensor<5xi64>) -> tensor<5x1xi64>
    %2 = stablehlo.concatenate %0, %1, dim = 1 : (tensor<5x1xi64>, tensor<5x1xi64>) -> tensor<5x2xi64>
    %3 = stablehlo.reshape %arg3 : (tensor<5xi64>) -> tensor<5x1xi64>
    %4 = stablehlo.reshape %2 : (tensor<5x2xi64>) -> tensor<5x2xi64>
    %5 = "stablehlo.scatter"(%arg0, %4, %3) ({
    ^bb0(%arg4: tensor<i64>, %arg5: tensor<i64>):
      stablehlo.return %arg5 : tensor<i64>
    }) {indices_are_sorted = false, scatter_dimension_numbers = #stablehlo.scatter<update_window_dims = [1], inserted_window_dims = [0], scatter_dims_to_operand_dims = [0, 1], index_vector_dim = 1>, unique_indices = false} : (tensor<3x5xi64>, tensor<5x2xi64>, tensor<5x1xi64>) -> tensor<3x5xi64>
    return %5 : tensor<3x5xi64>
  }
}
'''
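
As a sanity check, eager PyTorch produces the expected result for the same inputs (row 0 gets src scattered at index2; rows 1 and 2 are untouched):

import torch

src = torch.arange(1, 6)
index1 = torch.tensor([0, 0, 0, 0, 0])
index2 = torch.tensor([1, 2, 3, 4, 0])
input = torch.arange(10, 25, step=1, dtype=src.dtype).view(3, 5)
print(torch.index_put(input, (index1, index2), src, accumulate=False))
# tensor([[ 5,  1,  2,  3,  4],
#         [15, 16, 17, 18, 19],
#         [20, 21, 22, 23, 24]])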

AmosLewis · Aug 29 '23

Adding 34 index_put-related torchdynamo xfails: DONE. Next: fix the TorchDynamo e2e test crash, add the linalg e2e failures, and rewrite the tosa index_put. TorchDynamo e2e test:

python -m e2e_testing.main -f "UnsafeIndexPutHackedTwin1DFloatNonAccumulateModule_basic"
python3.11: /nodclouddata/chi/src/torch-mlir/include/torch-mlir/Dialect/Torch/IR/TorchTypes.h:61: ArrayRef<int64_t> mlir::torch::Torch::BaseTensorType::getSizes() const: Assertion `hasSizes() && "must have sizes"' failed.
[1]    34802 abort (core dumped)  python3.11 -m e2e_testing.main -f 

This bug is from the index_put decomposition: I forgot to add return before a rewriter.notifyMatchFailure() call, so the pattern kept executing past the failure point.

AmosLewis · Aug 30 '23

TorchDynamo e2e test crash: DONE. Adding 41 linalg e2e xfails: DONE. https://github.com/llvm/torch-mlir/actions/runs/6026994334/job/16351185680

AmosLewis · Aug 30 '23

Fix tosa index_put support: DONE. It now passes the previously broken tosa e2e test IndexPutImpl2DNoneIndexStaticModule_basic.

module attributes {torch.debug_module_name = "IndexPutImpl2DNoneIndexStaticModule"} {
  func.func @forward(%arg0: !torch.vtensor<[1,4],si64> loc(unknown), %arg1: !torch.vtensor<[3],si64> loc(unknown), %arg2: !torch.vtensor<[1,3],si64> loc(unknown)) -> !torch.vtensor<[1,4],si64> {
    %none = torch.constant.none loc(#loc3)
    %int4 = torch.constant.int 4 loc(#loc3)
    %int1 = torch.constant.int 1 loc(#loc3)
    %int0 = torch.constant.int 0 loc(#loc3)
    %int-1 = torch.constant.int -1 loc(#loc3)
    %false = torch.constant.bool false loc(#loc2)
    %0 = torch.aten.arange.start_step %int0, %int1, %int1, %int4, %none, %none, %none : !torch.int, !torch.int, !torch.int, !torch.int, !torch.none, !torch.none, !torch.none -> !torch.vtensor<[1],si64> loc(#loc3)
    %1 = torch.aten.unsqueeze %0, %int-1 : !torch.vtensor<[1],si64>, !torch.int -> !torch.vtensor<[1,1],si64> loc(#loc3)
    %2 = torch.prim.ListConstruct %1, %arg1 : (!torch.vtensor<[1,1],si64>, !torch.vtensor<[3],si64>) -> !torch.list<vtensor> loc(#loc3)
    %3 = torch.aten.index_put.hacked_twin %arg0, %2, %arg2, %false : !torch.vtensor<[1,4],si64>, !torch.list<vtensor>, !torch.vtensor<[1,3],si64>, !torch.bool -> !torch.vtensor<[1,4],si64> loc(#loc3)
    return %3 : !torch.vtensor<[1,4],si64> loc(#loc)
  } loc(#loc)
} loc(#loc)

->

module attributes {torch.debug_module_name = "IndexPutImpl2DNoneIndexStaticModule"} {
  func.func @forward(%arg0: !torch.vtensor<[1,4],si64>, %arg1: !torch.vtensor<[3],si64>, %arg2: !torch.vtensor<[1,3],si64>) -> !torch.vtensor<[1,4],si64> {
    %0 = torch_c.to_builtin_tensor %arg0 : !torch.vtensor<[1,4],si64> -> tensor<1x4xi64>
    %1 = torch_c.to_builtin_tensor %arg2 : !torch.vtensor<[1,3],si64> -> tensor<1x3xi64>
    %none = torch.constant.none
    %int4 = torch.constant.int 4
    %int1 = torch.constant.int 1
    %int0 = torch.constant.int 0
    %int-1 = torch.constant.int -1
    %false = torch.constant.bool false
    %2 = "tosa.const"() <{value = dense<0> : tensor<1xi64>}> : () -> tensor<1xi64>
    %3 = "tosa.cast"(%2) : (tensor<1xi64>) -> tensor<1xi64>
    %4 = "tosa.reshape"(%3) <{new_shape = array<i64: 1, 1>}> : (tensor<1xi64>) -> tensor<1x1xi64>
    %5 = torch_c.from_builtin_tensor %4 : tensor<1x1xi64> -> !torch.vtensor<[1,1],si64>
    %6 = torch.prim.ListConstruct %5, %arg1 : (!torch.vtensor<[1,1],si64>, !torch.vtensor<[3],si64>) -> !torch.list<vtensor>
    %7 = torch_c.to_builtin_tensor %5 : !torch.vtensor<[1,1],si64> -> tensor<1x1xi64>
    %8 = torch_c.to_builtin_tensor %arg1 : !torch.vtensor<[3],si64> -> tensor<3xi64>
    %9 = "tosa.cast"(%7) : (tensor<1x1xi64>) -> tensor<1x1xi32>
    %10 = "tosa.reshape"(%9) <{new_shape = array<i64: 1, 1, 1>}> : (tensor<1x1xi32>) -> tensor<1x1x1xi32>
    %11 = "tosa.const"() <{value = dense<0> : tensor<1x3x1xi32>}> : () -> tensor<1x3x1xi32>
    %12 = "tosa.add"(%10, %11) : (tensor<1x1x1xi32>, tensor<1x3x1xi32>) -> tensor<1x3x1xi32>
    %13 = "tosa.cast"(%8) : (tensor<3xi64>) -> tensor<3xi32>
    %14 = "tosa.reshape"(%13) <{new_shape = array<i64: 3, 1>}> : (tensor<3xi32>) -> tensor<3x1xi32>
    %15 = "tosa.const"() <{value = dense<0> : tensor<1x3x1xi32>}> : () -> tensor<1x3x1xi32>
    %16 = "tosa.add"(%14, %15) : (tensor<3x1xi32>, tensor<1x3x1xi32>) -> tensor<1x3x1xi32>
    %17 = "tosa.concat"(%12, %16) <{axis = 2 : i64}> : (tensor<1x3x1xi32>, tensor<1x3x1xi32>) -> tensor<1x3x2xi32>
    %18 = "tosa.reshape"(%1) <{new_shape = array<i64: 1, 3, 1>}> : (tensor<1x3xi64>) -> tensor<1x3x1xi64>
    %19 = "tosa.reshape"(%0) <{new_shape = array<i64: 1, 4, 1>}> : (tensor<1x4xi64>) -> tensor<1x4x1xi64>
    %20 = "tosa.reshape"(%17) <{new_shape = array<i64: 3, 2>}> : (tensor<1x3x2xi32>) -> tensor<3x2xi32>
    %21 = "tosa.const"() <{value = dense<[4, 1]> : tensor<2xi32>}> : () -> tensor<2xi32>
    %22 = "tosa.mul"(%20, %21) <{shift = 0 : i32}> : (tensor<3x2xi32>, tensor<2xi32>) -> tensor<3x2xi32>
    %23 = "tosa.reduce_sum"(%22) <{axis = 1 : i64}> : (tensor<3x2xi32>) -> tensor<3x1xi32>
    %24 = "tosa.reshape"(%23) <{new_shape = array<i64: 1, 3>}> : (tensor<3x1xi32>) -> tensor<1x3xi32>
    %25 = "tosa.scatter"(%19, %24, %18) : (tensor<1x4x1xi64>, tensor<1x3xi32>, tensor<1x3x1xi64>) -> tensor<1x4x1xi64>
    %26 = "tosa.reshape"(%25) <{new_shape = array<i64: 1, 4>}> : (tensor<1x4x1xi64>) -> tensor<1x4xi64>
    %27 = torch_c.from_builtin_tensor %26 : tensor<1x4xi64> -> !torch.vtensor<[1,4],si64>
    return %27 : !torch.vtensor<[1,4],si64>
  }
}
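
The key step in this lowering is flattening the 2-D indices for tosa.scatter, which takes a single flat index per update: each (row, col) pair is multiplied elementwise by the row-major strides (here [4, 1] for a 1x4 input) and summed. In eager terms (an illustration with hypothetical index values, mirroring the tosa.mul + tosa.reduce_sum above):

import torch

idx = torch.tensor([[0, 1], [0, 2], [0, 3]], dtype=torch.int32)  # (row, col) pairs
strides = torch.tensor([4, 1], dtype=torch.int32)  # row-major strides of a 1x4 tensor
flat = (idx * strides).sum(dim=1)  # flat index = row * 4 + col * 1
print(flat)  # tensor([1, 2, 3], dtype=torch.int32)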

AmosLewis · Sep 06 '23

Accidentally merged a commit from another WIP patch; just cleaned it up.

AmosLewis · Sep 07 '23

Removed the iostream include and the stablehlo crash line: DONE.

AmosLewis · Sep 08 '23

Fixed the linalg e2e xfails and 33 torchdynamo xfails by changing the op name in the TMTensor lowering from index_put to index_put.hacked_twin. But in linalg/TMTensor, 2 more e2e xfails still need to be fixed:

    "IndexPutImplIndexWithNoneModule_basic",
    "SliceCopyNonZeroDim_Module_basic",

IndexPutImplIndexWithNoneModule_basic was brought in by https://github.com/llvm/torch-mlir/pull/1762. It would be great if @ramiro050 and @vivekkhandelwal1 could take a look at it and give some advice; the algorithm in the TMTensor index_put implementation is not intuitive to me.

After the decomposition, it looks like:

module attributes {torch.debug_module_name = "IndexPutImplIndexWithNoneModule"} {
  func.func @forward(%arg0: !torch.vtensor<[2,3,4,5],f32> loc(unknown), %arg1: !torch.vtensor<[6,1],si64> loc(unknown), %arg2: !torch.vtensor<[7],si64> loc(unknown), %arg3: !torch.vtensor<[2,3,6,7],f32> loc(unknown)) -> !torch.vtensor<[2,3,4,5],f32> {
    %none = torch.constant.none loc(#loc1)
    %int4 = torch.constant.int 4 loc(#loc4)
    %int3 = torch.constant.int 3 loc(#loc4)
    %int0 = torch.constant.int 0 loc(#loc4)
    %int1 = torch.constant.int 1 loc(#loc4)
    %int-1 = torch.constant.int -1 loc(#loc4)
    %int2 = torch.constant.int 2 loc(#loc4)
    %true = torch.constant.bool true loc(#loc3)
    %0 = torch.aten.arange.start_step %int0, %int3, %int1, %int4, %none, %none, %none : !torch.int, !torch.int, !torch.int, !torch.int, !torch.none, !torch.none, !torch.none -> !torch.vtensor<[3],si64> loc(#loc4)
    %1 = torch.aten.unsqueeze %0, %int-1 : !torch.vtensor<[3],si64>, !torch.int -> !torch.vtensor<[3,1],si64> loc(#loc4)
    %2 = torch.aten.unsqueeze %1, %int-1 : !torch.vtensor<[3,1],si64>, !torch.int -> !torch.vtensor<[3,1,1],si64> loc(#loc4)
    %3 = torch.aten.arange.start_step %int0, %int2, %int1, %int4, %none, %none, %none : !torch.int, !torch.int, !torch.int, !torch.int, !torch.none, !torch.none, !torch.none -> !torch.vtensor<[2],si64> loc(#loc4)
    %4 = torch.aten.unsqueeze %3, %int-1 : !torch.vtensor<[2],si64>, !torch.int -> !torch.vtensor<[2,1],si64> loc(#loc4)
    %5 = torch.aten.unsqueeze %4, %int-1 : !torch.vtensor<[2,1],si64>, !torch.int -> !torch.vtensor<[2,1,1],si64> loc(#loc4)
    %6 = torch.aten.unsqueeze %5, %int-1 : !torch.vtensor<[2,1,1],si64>, !torch.int -> !torch.vtensor<[2,1,1,1],si64> loc(#loc4)
    %7 = torch.prim.ListConstruct %6, %2, %arg1, %arg2 : (!torch.vtensor<[2,1,1,1],si64>, !torch.vtensor<[3,1,1],si64>, !torch.vtensor<[6,1],si64>, !torch.vtensor<[7],si64>) -> !torch.list<vtensor> loc(#loc4)
    %8 = torch.aten.index_put.hacked_twin %arg0, %7, %arg3, %true : !torch.vtensor<[2,3,4,5],f32>, !torch.list<vtensor>, !torch.vtensor<[2,3,6,7],f32>, !torch.bool -> !torch.vtensor<[2,3,4,5],f32> loc(#loc4)
    return %8 : !torch.vtensor<[2,3,4,5],f32> loc(#loc)
  } loc(#loc)
}
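
For reference, this decomposition can be checked in eager mode with the shapes from the module above (index values chosen arbitrarily to stay in bounds):

import torch

x = torch.zeros(2, 3, 4, 5)
i1 = torch.tensor([[0], [1], [2], [3], [2], [1]])  # torch.Size([6, 1]), values < 4
i2 = torch.tensor([0, 1, 2, 3, 4, 3, 2])           # torch.Size([7]), values < 5
v = torch.ones(2, 3, 6, 7)

# the original op: dims 0 and 1 are None (i.e. full slices)
out = torch.ops.aten.index_put(x, (None, None, i1, i2), v, True)

# after the decomposition, the Nones become unsqueezed arange indices
d0 = torch.arange(2).view(2, 1, 1, 1)
d1 = torch.arange(3).view(3, 1, 1)
out2 = torch.ops.aten.index_put(x, (d0, d1, i1, i2), v, True)

assert torch.equal(out, out2)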

AmosLewis · Sep 12 '23

Let me take a look at the xfails and see if I find anything

ramiro050 · Sep 18 '23

The issue is that the TorchToTMTensor pattern expects the op to have only 2 tensors in the indices list (for example, (None, None, tensor1, tensor2) would be valid). However, because we are now replacing the Nones with tensors, the check for 2 tensors fails. We need to generalize the pattern to handle more than 2 tensors. I can help out with this on Friday.
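
A rough eager-mode sketch of the generalization (a hypothetical helper, not the actual TorchToTMTensor code): broadcast however many index tensors there are to a common shape and stack them into one coordinate tensor, instead of assuming exactly two.

import torch

def coords_from_indices(indices):
    # works for any number of index tensors, not just 2
    bcast = torch.broadcast_tensors(*indices)
    return torch.stack(bcast, dim=-1)  # shape [..., len(indices)]

i1 = torch.tensor([[0]])
i2 = torch.tensor([1, 2, 3])
print(coords_from_indices([i1, i2]).shape)      # torch.Size([1, 3, 2])

i3 = torch.tensor([[[2]]])
print(coords_from_indices([i1, i2, i3]).shape)  # torch.Size([1, 1, 3, 3])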

ramiro050 · Sep 20 '23