Repository: torch-mlir

[Stablehlo] Add accumulate support for AtenIndexPutHackedTwinOp

Status: Open · AmosLewis opened this issue 2 years ago · 1 comment

AmosLewis commented on Sep 8, 2023 at 05:09:

To reproduce, run:

torch-mlir-opt --convert-torch-to-stablehlo /nodclouddata/chi/src/models/t5/slicecopy/EmbeddingBagDenseBackwardsModule.mlir

The input IR (Torch dialect) and the resulting StableHLO output follow.

// Input IR (Torch dialect) for the EmbeddingBagDenseBackwardModule reproducer.
// The op of interest is torch.aten.index_put.hacked_twin, called with its
// accumulate flag set to %true.
#loc = loc(unknown)
module attributes {torch.debug_module_name = "EmbeddingBagDenseBackwardModule"} {
  // %arg0 is passed to index_put as the values operand and %arg1 as its only
  // index tensor; %arg2 and %arg3 are not referenced in this function body.
  func.func @forward(%arg0: !torch.vtensor<[3,2],f32> loc(unknown), %arg1: !torch.vtensor<[3],si64> loc(unknown), %arg2: !torch.vtensor<[3],si64> loc(unknown), %arg3: !torch.vtensor<[1],si64> loc(unknown)) -> !torch.vtensor<[2,2],f32> {    
    %int2 = torch.constant.int 2 loc(#loc2)
    %none = torch.constant.none loc(#loc2)
    %int6 = torch.constant.int 6 loc(#loc2)
    %true = torch.constant.bool true loc(#loc2)
    // Shape list [2, 2] for the zero-initialized destination tensor.
    %0 = torch.prim.ListConstruct %int2, %int2 : (!torch.int, !torch.int) -> !torch.list<int> loc(#loc2)
    // Destination of the index_put: zeros of shape [2,2]; dtype code %int6
    // corresponds to the f32 result type shown.
    %1 = torch.aten.zeros %0, %int6, %none, %none, %none : !torch.list<int>, !torch.int, !torch.none, !torch.none, !torch.none -> !torch.vtensor<[2,2],f32> loc(#loc2)    
    // Single-element index list; presumably this indexes only the leading
    // dimension of %1 (per index_put semantics) — confirm.
    %2 = torch.prim.ListConstruct %arg1 : (!torch.vtensor<[3],si64>) -> !torch.list<vtensor> loc(#loc2)
    // The fourth operand (%true) is the accumulate flag whose StableHLO
    // lowering this issue is about.
    %3 = torch.aten.index_put.hacked_twin %1, %2, %arg0, %true : !torch.vtensor<[2,2],f32>, !torch.list<vtensor>, !torch.vtensor<[3,2],f32>, !torch.bool -> !torch.vtensor<[2,2],f32> loc(#loc2)
    return %3 : !torch.vtensor<[2,2],f32> loc(#loc)
  } loc(#loc)
} loc(#loc)
// Location aliases; #loc2 (used by every op in @forward) is defined here,
// after its uses.
#loc1 = loc("/home/nithin/torch-mlir/build-debug/tools/torch-mlir/python_packages/torch_mlir/torch_mlir_e2e_test/test_suite/backprop.py":324:15)
#loc2 = loc("aten::_embedding_bag_dense_backward"(#loc1))
// Output IR produced by --convert-torch-to-stablehlo, as pasted in this issue.
module attributes {torch.debug_module_name = "EmbeddingBagDenseBackwardModule"} {
  func.func @forward(%arg0: !torch.vtensor<[3,2],f32>, %arg1: !torch.vtensor<[3],si64>, %arg2: !torch.vtensor<[3],si64>, %arg3: !torch.vtensor<[1],si64>) -> !torch.vtensor<[2,2],f32> {
    %0 = torch_c.to_builtin_tensor %arg0 : !torch.vtensor<[3,2],f32> -> tensor<3x2xf32>
    // The torch constants and the list constructs %1 and %4 have no remaining
    // uses in this output; presumably a later cleanup pass removes them.
    %int2 = torch.constant.int 2
    %none = torch.constant.none
    %int6 = torch.constant.int 6
    %true = torch.constant.bool true
    %1 = torch.prim.ListConstruct %int2, %int2 : (!torch.int, !torch.int) -> !torch.list<int>
    // aten.zeros is lowered to an i32 zero constant converted to f32.
    %2 = stablehlo.constant dense<0> : tensor<2x2xi32>
    %3 = stablehlo.convert %2 : (tensor<2x2xi32>) -> tensor<2x2xf32>
    %4 = torch.prim.ListConstruct %arg1 : (!torch.vtensor<[3],si64>) -> !torch.list<vtensor>
    %5 = torch_c.to_builtin_tensor %arg1 : !torch.vtensor<[3],si64> -> tensor<3xi64>
    // Scatter indices: reshape 3 -> 3x1. %7 concatenates a single operand and
    // %8/%9 reshape to their own types, so these three ops are no-ops here.
    %6 = stablehlo.reshape %5 : (tensor<3xi64>) -> tensor<3x1xi64>
    %7 = stablehlo.concatenate %6, dim = 1 : (tensor<3x1xi64>) -> tensor<3x1xi64>
    %8 = stablehlo.reshape %0 : (tensor<3x2xf32>) -> tensor<3x2xf32>
    %9 = stablehlo.reshape %7 : (tensor<3x1xi64>) -> tensor<3x1xi64>
    // Broadcasts the 3x1 indices to 3x2, so each index row becomes (i, i).
    %10 = stablehlo.broadcast_in_dim %9, dims = [0, 1] : (tensor<3x1xi64>) -> tensor<3x2xi64>
    // Scatter with an add combiner, i.e. the accumulate = true case.
    // NOTE(review): with scatter_dims_to_operand_dims = [0, 1] and the (i, i)
    // index rows from %10, each update window starts at operand (i, i) and
    // spans 2 columns (update_window_dims = [1], updates are 3x2). For an index
    // value of 1 that window runs past the 2x2 operand, whereas index_put would
    // start every row at column 0. This looks like a lowering bug — TODO
    // confirm against the StableHLO scatter spec (out-of-bounds windows are
    // not updated as intended there).
    %11 = "stablehlo.scatter"(%3, %10, %8) ({
    ^bb0(%arg4: tensor<f32>, %arg5: tensor<f32>):
      %13 = stablehlo.add %arg4, %arg5 : tensor<f32>
      stablehlo.return %13 : tensor<f32>
    }) {indices_are_sorted = false, scatter_dimension_numbers = #stablehlo.scatter<update_window_dims = [1], inserted_window_dims = [0], scatter_dims_to_operand_dims = [0, 1], index_vector_dim = 1>, unique_indices = false} : (tensor<2x2xf32>, tensor<3x2xi64>, tensor<3x2xf32>) -> tensor<2x2xf32>
    %12 = torch_c.from_builtin_tensor %11 : tensor<2x2xf32> -> !torch.vtensor<[2,2],f32>
    return %12 : !torch.vtensor<[2,2],f32>
  }
}

AmosLewis commented on Sep 8, 2023 at 07:09: