torch-mlir
torch-mlir copied to clipboard
[Stablehlo] Add accumulate support for AtenIndexPutHackedTwinOp
torch-mlir-opt --convert-torch-to-stablehlo /nodclouddata/chi/src/models/t5/slicecopy/EmbeddingBagDenseBackwardsModule.mlir
// Fallback location alias for ops with no recorded source position.
#loc = loc(unknown)
// Input IR (torch dialect), before --convert-torch-to-stablehlo:
// builds a 2x2 zero tensor and writes the 3x2 gradient rows into it with
// index_put, accumulating duplicates (the trailing %true operand).
module attributes {torch.debug_module_name = "EmbeddingBagDenseBackwardModule"} {
func.func @forward(%arg0: !torch.vtensor<[3,2],f32> loc(unknown), %arg1: !torch.vtensor<[3],si64> loc(unknown), %arg2: !torch.vtensor<[3],si64> loc(unknown), %arg3: !torch.vtensor<[1],si64> loc(unknown)) -> !torch.vtensor<[2,2],f32> {
%int2 = torch.constant.int 2 loc(#loc2)
%none = torch.constant.none loc(#loc2)
// Dtype code passed to aten.zeros; the op's result type below is f32.
%int6 = torch.constant.int 6 loc(#loc2)
%true = torch.constant.bool true loc(#loc2)
// Target shape [2, 2] for the zero-initialized accumulator.
%0 = torch.prim.ListConstruct %int2, %int2 : (!torch.int, !torch.int) -> !torch.list<int> loc(#loc2)
%1 = torch.aten.zeros %0, %int6, %none, %none, %none : !torch.list<int>, !torch.int, !torch.none, !torch.none, !torch.none -> !torch.vtensor<[2,2],f32> loc(#loc2)
// Single index tensor (%arg1, 3 row indices) wrapped in a list for index_put.
%2 = torch.prim.ListConstruct %arg1 : (!torch.vtensor<[3],si64>) -> !torch.list<vtensor> loc(#loc2)
// Scatter %arg0's rows into the zeros; %true is the accumulate flag
// (per AtenIndexPutHackedTwinOp's signature), so repeated indices add up.
%3 = torch.aten.index_put.hacked_twin %1, %2, %arg0, %true : !torch.vtensor<[2,2],f32>, !torch.list<vtensor>, !torch.vtensor<[3,2],f32>, !torch.bool -> !torch.vtensor<[2,2],f32> loc(#loc2)
return %3 : !torch.vtensor<[2,2],f32> loc(#loc)
} loc(#loc)
} loc(#loc)
// Source position the lowering attributed all converted ops to.
#loc1 = loc("/home/nithin/torch-mlir/build-debug/tools/torch-mlir/python_packages/torch_mlir/torch_mlir_e2e_test/test_suite/backprop.py":324:15)
#loc2 = loc("aten::_embedding_bag_dense_backward"(#loc1))
// Output IR after --convert-torch-to-stablehlo: index_put with accumulate
// becomes a stablehlo.scatter whose update region performs an add.
// (Dead torch-dialect constants/lists below are leftover conversion
// artifacts, normally removed by a later DCE/canonicalize pass.)
module attributes {torch.debug_module_name = "EmbeddingBagDenseBackwardModule"} {
func.func @forward(%arg0: !torch.vtensor<[3,2],f32>, %arg1: !torch.vtensor<[3],si64>, %arg2: !torch.vtensor<[3],si64>, %arg3: !torch.vtensor<[1],si64>) -> !torch.vtensor<[2,2],f32> {
%0 = torch_c.to_builtin_tensor %arg0 : !torch.vtensor<[3,2],f32> -> tensor<3x2xf32>
%int2 = torch.constant.int 2
%none = torch.constant.none
%int6 = torch.constant.int 6
%true = torch.constant.bool true
%1 = torch.prim.ListConstruct %int2, %int2 : (!torch.int, !torch.int) -> !torch.list<int>
// aten.zeros lowered as an i32 zero splat converted to f32.
%2 = stablehlo.constant dense<0> : tensor<2x2xi32>
%3 = stablehlo.convert %2 : (tensor<2x2xi32>) -> tensor<2x2xf32>
%4 = torch.prim.ListConstruct %arg1 : (!torch.vtensor<[3],si64>) -> !torch.list<vtensor>
%5 = torch_c.to_builtin_tensor %arg1 : !torch.vtensor<[3],si64> -> tensor<3xi64>
// Build the scatter index tensor: 3 indices -> 3x1 column; the
// single-operand concatenate is an identity here (one index list).
%6 = stablehlo.reshape %5 : (tensor<3xi64>) -> tensor<3x1xi64>
%7 = stablehlo.concatenate %6, dim = 1 : (tensor<3x1xi64>) -> tensor<3x1xi64>
// Identity reshape of the updates; then broadcast the 3x1 indices to
// 3x2 so each update element gets a full [row, col] coordinate.
%8 = stablehlo.reshape %0 : (tensor<3x2xf32>) -> tensor<3x2xf32>
%9 = stablehlo.reshape %7 : (tensor<3x1xi64>) -> tensor<3x1xi64>
%10 = stablehlo.broadcast_in_dim %9, dims = [0, 1] : (tensor<3x1xi64>) -> tensor<3x2xi64>
// Scatter updates %8 into the zero tensor %3 at indices %10. The region
// returns operand + update, which realizes accumulate=true: duplicate
// indices sum instead of overwriting.
%11 = "stablehlo.scatter"(%3, %10, %8) ({
^bb0(%arg4: tensor<f32>, %arg5: tensor<f32>):
%13 = stablehlo.add %arg4, %arg5 : tensor<f32>
stablehlo.return %13 : tensor<f32>
}) {indices_are_sorted = false, scatter_dimension_numbers = #stablehlo.scatter<update_window_dims = [1], inserted_window_dims = [0], scatter_dims_to_operand_dims = [0, 1], index_vector_dim = 1>, unique_indices = false} : (tensor<2x2xf32>, tensor<3x2xi64>, tensor<3x2xf32>) -> tensor<2x2xf32>
%12 = torch_c.from_builtin_tensor %11 : tensor<2x2xf32> -> !torch.vtensor<[2,2],f32>
return %12 : !torch.vtensor<[2,2],f32>
}
}