[ONNX] Support onnx.LSTM
Test case draft, `lstm.onnx.mlir`:

```mlir
module {
  func.func @lstm(%arg0: !torch.vtensor<[15,2,4],f32>, %arg1: !torch.vtensor<[1,12,4],f32>, %arg2: !torch.vtensor<[1,12,3],f32>, %arg3: !torch.vtensor<[1,24],f32>) -> (!torch.vtensor<[15,1,2,3],f32>, !torch.vtensor<[1,2,3],f32>, !torch.vtensor<[1,2,3],f32>) attributes {torch.onnx_meta.ir_version = 9 : si64, torch.onnx_meta.opset_version = 20 : si64, torch.onnx_meta.producer_name = "", torch.onnx_meta.producer_version = ""} {
    %none = torch.constant.none
    %0:3 = torch.operator "onnx.LSTM"(%arg0, %arg1, %arg2, %arg3) {torch.onnx.hidden_size = 3 : si64} : (!torch.vtensor<[15,2,4],f32>, !torch.vtensor<[1,12,4],f32>, !torch.vtensor<[1,12,3],f32>, !torch.vtensor<[1,24],f32>) -> (!torch.vtensor<[15,1,2,3],f32>, !torch.vtensor<[1,2,3],f32>, !torch.vtensor<[1,2,3],f32>)
    return %0#0, %0#1, %0#2 : !torch.vtensor<[15,1,2,3],f32>, !torch.vtensor<[1,2,3],f32>, !torch.vtensor<[1,2,3],f32>
  }
}
```
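For reference, the operand shapes in the draft follow the layout the ONNX LSTM spec prescribes. A small plain-Python sketch (helper name is mine) that computes the expected shapes from the hyperparameters and checks them against the test case:

```python
def onnx_lstm_shapes(seq_len, batch, input_size, hidden_size, num_directions=1):
    """Expected ONNX LSTM operand/result shapes per the ONNX spec."""
    return {
        "X": (seq_len, batch, input_size),
        "W": (num_directions, 4 * hidden_size, input_size),   # iofc gates stacked
        "R": (num_directions, 4 * hidden_size, hidden_size),
        "B": (num_directions, 8 * hidden_size),               # Wb and Rb concatenated
        "Y": (seq_len, num_directions, batch, hidden_size),
    }

# Matches the shapes in lstm.onnx.mlir above
# (seq_len=15, batch=2, input_size=4, hidden_size=3):
shapes = onnx_lstm_shapes(15, 2, 4, 3)
assert shapes["W"] == (1, 12, 4)
assert shapes["R"] == (1, 12, 3)
assert shapes["B"] == (1, 24)
assert shapes["Y"] == (15, 1, 2, 3)
```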
Current Torch IR (truncated):

```mlir
module {
  func.func @lstm(%arg0: !torch.vtensor, %arg1: !torch.vtensor, %arg2: !torch.vtensor, %arg3: !torch.vtensor) -> (!torch.vtensor, !torch.vtensor, !torch.vtensor) attributes {torch.onnx_meta.ir_version = 9 : si64, torch.onnx_meta.opset_version = 20 : si64, torch.onnx_meta.producer_name = "", torch.onnx_meta.producer_version = ""} {
    %none = torch.constant.none
    %int0 = torch.constant.int 0
    %int0_0 = torch.constant.int 0
    %0 = torch.aten.select.int %arg1, %int0, %int0_0 : !torch.vtensor, !torch.int, !torch.int -> !torch.vtensor
    %int0_1 = torch.constant.int 0
    %int0_2 = torch.constant.int 0
    %1 = torch.aten.select.int %arg2, %int0_1, %int0_2 : !torch.vtensor, !torch.int, !torch.int -> !torch.vtensor
    %int0_3 = torch.constant.int 0
    %int0_4 = torch.constant.int 0
    %2 = torch.aten.select.int %arg3, %int0_3, %int0_4 : !torch.vtensor, !torch.int, !torch.int -> !torch.vtensor
    %int1 = torch.constant.int 1
    %int2 = torch.constant.int 2
    %int3 = torch.constant.int 3
    %none_5 = torch.constant.none
    %int0_6 = torch.constant.int 0
    %int1_7 = torch.constant.int 1
    %3 = torch.prim.ListConstruct %int1, %int2, %int3 : (!torch.int, !torch.int, !torch.int) -> !torch.list<int>
```
Currently testing with:

```shell
/home/azureuser/torch-mlir/build/bin/torch-mlir-opt -pass-pipeline='builtin.module(func.func(convert-torch-onnx-to-torch))' lstm.onnx.mlir -o lstm.torch.mlir
~/torch-mlir/build/bin/torch-mlir-opt --mlir-print-debuginfo --mlir-elide-elementsattrs-if-larger=16 --mlir-print-stacktrace-on-diagnostic --mlir-disable-threading --mlir-print-ir-after-failure --mlir-print-ir-module-scope -pass-pipeline='builtin.module(torch-backend-to-linalg-on-tensors-backend-pipeline)' ./lstm.torch.mlir -o ./lstm.linalg.mlir 2>&1 | tee torchtolinalg.log
```
encountered a bunch of issues earlier, fixed thanks to Rob.
Currently stuck on:

```
./lstm.torch.mlir:61:13: error: 'torch_c.to_builtin_tensor' op operand #0 must be Multi-dimensional array modeling Torch's Tensor type, but got 'tensor<2x3xf32>'
  %30 = torch.aten.mul.Tensor %29, %arg6 : !torch.vtensor, !torch.vtensor -> !torch.vtensor
        ^
./lstm.torch.mlir:61:13: note: diagnostic emitted with trace:
 #0 0x000055b8b5e22e2d llvm::sys::PrintStackTrace(llvm::raw_ostream&, int) /home/azureuser/torch-mlir/externals/llvm-project/llvm/lib/Support/Unix/Signals.inc:723:11
 #1 0x000055b8b5bf959e emitDiag(mlir::Location, mlir::DiagnosticSeverity, llvm::Twine const&) /home/azureuser/torch-mlir/externals/llvm-project/mlir/lib/IR/Diagnostics.cpp:319:5
 #2 0x000055b8b5bf94c5 mlir::emitError(mlir::Location, llvm::Twine const&) /home/azureuser/torch-mlir/externals/llvm-project/mlir/lib/IR/Diagnostics.cpp:330:10
 #3 0x000055b8b5c9f888 mlir::Operation::emitError(llvm::Twine const&) /home/azureuser/torch-mlir/externals/llvm-project/mlir/lib/IR/Operation.cpp:269:29
 #4 0x000055b8b5c9f359 mlir::Operation::emitOpError(llvm::Twine const&) /home/azureuser/torch-mlir/externals/llvm-project/mlir/lib/IR/Operation.cpp:672:22
 #5 0x000055b8b2d9c635 mlir::torch::TorchConversion::__mlir_ods_local_type_constraint_TorchConversionOps1(mlir::Operation*, mlir::Type, llvm::StringRef, unsigned int) /home/azureuser/torch-mlir/build/tools/torch-mlir/include/torch-mlir/Dialect/TorchConversion/IR/TorchConversionOps.cpp.inc:52:39
 #6 0x000055b8b2da4a46 mlir::torch::TorchConversion::ToBuiltinTensorOp::verifyInvariantsImpl() /home/azureuser/torch-mlir/build/tools/torch-mlir/include/torch-mlir/Dialect/TorchConversion/IR/TorchConversionOps.cpp.inc:1499:26
 #7 0x000055b8b2d93cc2 mlir::OpTrait::OpInvariants<mlir::torch::TorchConversion::ToBuiltinTensorOp>::verifyTrait(mlir::Operation*) /home/azureuser/torch-mlir/externals/llvm-project/llvm/../mlir/include/mlir/IR/OpDefinition.h:432:35
 #8 0x000055b8b2d93b05 std::enable_if
./lstm.torch.mlir:61:13: note: see current operation: %45 = "torch_c.to_builtin_tensor"(%arg6) : (tensor<2x3xf32>) -> tensor<2x3xf32> loc("./lstm.torch.mlir":61:13)
```
// -----// IR Dump After ConvertTorchToSCF Failed (convert-torch-to-scf) ('func.func' operation: @lstm) //----- //
#loc2 = loc("./lstm.torch.mlir":2:19)
#loc3 = loc("./lstm.torch.mlir":2:56)
#loc4 = loc("./lstm.torch.mlir":2:93)
#loc5 = loc("./lstm.torch.mlir":2:130)
#loc6 = loc(unknown)
#loc7 = loc("./lstm.torch.mlir":38:13)
#loc18 = loc("./lstm.torch.mlir":61:13)
#loc24 = loc("./lstm.torch.mlir":49:13)
#loc26 = loc("./lstm.torch.mlir":54:13)
#loc27 = loc("./lstm.torch.mlir":55:13)
#loc28 = loc("./lstm.torch.mlir":56:13)
#loc29 = loc("./lstm.torch.mlir":57:13)
#loc30 = loc("./lstm.torch.mlir":58:13)
#loc31 = loc("./lstm.torch.mlir":59:13)
#loc32 = loc("./lstm.torch.mlir":60:13)
#loc33 = loc("./lstm.torch.mlir":62:13)
#loc34 = loc("./lstm.torch.mlir":63:13)
#loc35 = loc("./lstm.torch.mlir":64:13)
#loc36 = loc("./lstm.torch.mlir":65:13)
#map = affine_map<(d0, d1) -> (d0, d1)>
"builtin.module"() ({
"func.func"() <{function_type = (!torch.vtensor<[15,2,4],f32>, !torch.vtensor<[1,12,4],f32>, !torch.vtensor<[1,12,3],f32>, !torch.vtensor<[1,24],f32>) -> (!torch.vtensor<[15,1,2,3],f32>, !torch.vtensor<[1,2,3],f32>, !torch.vtensor<[1,2,3],f32>), sym_name = "lstm"}> ({
^bb0(%arg0: !torch.vtensor<[15,2,4],f32> loc("./lstm.torch.mlir":2:19), %arg1: !torch.vtensor<[1,12,4],f32> loc("./lstm.torch.mlir":2:56), %arg2: !torch.vtensor<[1,12,3],f32> loc("./lstm.torch.mlir":2:93), %arg3: !torch.vtensor<[1,24],f32> loc("./lstm.torch.mlir":2:130)):
%0 = "arith.constant"() <{value = 1.000000e+00 : f32}> : () -> f32 loc(#loc6)
%1 = "arith.constant"() <{value = 0.000000e+00 : f32}> : () -> f32 loc(#loc6)
%2 = "torch.constant.int"() <{value = 0 : i64}> : () -> !torch.int loc(#loc6)
%3 = "torch.constant.int"() <{value = 1 : i64}> : () -> !torch.int loc(#loc6)
%4 = "torch.constant.int"() <{value = 15 : i64}> : () -> !torch.int loc(#loc6)
%5 = "torch_c.to_i64"(%4) : (!torch.int) -> i64 loc(#loc7)
%6 = "torch.constant.bool"() <{value = true}> : () -> !torch.bool loc(#loc6)
%7 = "torch.constant.int"() <{value = 4 : i64}> : () -> !torch.int loc(#loc6)
%8 = "torch.aten.select.int"(%arg1, %2, %2) : (!torch.vtensor<[1,12,4],f32>, !torch.int, !torch.int) -> !torch.vtensor<[12,4],f32> loc(#loc8)
%9 = "torch.aten.select.int"(%arg2, %2, %2) : (!torch.vtensor<[1,12,3],f32>, !torch.int, !torch.int) -> !torch.vtensor<[12,3],f32> loc(#loc9)
%10 = "torch.aten.select.int"(%arg3, %2, %2) : (!torch.vtensor<[1,24],f32>, !torch.int, !torch.int) -> !torch.vtensor<[24],f32> loc(#loc10)
%11 = "torch_c.to_builtin_tensor"(%10) : (!torch.vtensor<[24],f32>) -> tensor<24xf32> loc(#loc11)
%12 = "tensor.empty"() : () -> tensor<1x2x3xf32> loc(#loc12)
%13 = "linalg.fill"(%1, %12) <{operandSegmentSizes = array<i32: 1, 1>}> ({
^bb0(%arg4: f32 loc(unknown), %arg5: f32 loc(unknown)):
"linalg.yield"(%arg4) : (f32) -> () loc(#loc6)
}) : (f32, tensor<1x2x3xf32>) -> tensor<1x2x3xf32> loc(#loc12)
%14 = "torch_c.from_builtin_tensor"(%13) : (tensor<1x2x3xf32>) -> !torch.vtensor<[1,2,3],f32> loc(#loc12)
%15 = "tensor.empty"() : () -> tensor<1x2x3xf32> loc(#loc13)
%16 = "linalg.fill"(%1, %15) <{operandSegmentSizes = array<i32: 1, 1>}> ({
^bb0(%arg4: f32 loc(unknown), %arg5: f32 loc(unknown)):
"linalg.yield"(%arg4) : (f32) -> () loc(#loc6)
}) : (f32, tensor<1x2x3xf32>) -> tensor<1x2x3xf32> loc(#loc13)
%17 = "torch_c.from_builtin_tensor"(%16) : (tensor<1x2x3xf32>) -> !torch.vtensor<[1,2,3],f32> loc(#loc13)
%18 = "torch.aten.select.int"(%14, %2, %2) : (!torch.vtensor<[1,2,3],f32>, !torch.int, !torch.int) -> !torch.vtensor<[2,3],f32> loc(#loc14)
%19 = "torch_c.to_builtin_tensor"(%18) : (!torch.vtensor<[2,3],f32>) -> tensor<2x3xf32> loc(#loc7)
%20 = "torch.aten.select.int"(%17, %2, %2) : (!torch.vtensor<[1,2,3],f32>, !torch.int, !torch.int) -> !torch.vtensor<[2,3],f32> loc(#loc15)
%21 = "torch_c.to_builtin_tensor"(%20) : (!torch.vtensor<[2,3],f32>) -> tensor<2x3xf32> loc(#loc7)
%22 = "tensor.extract_slice"(%11) <{operandSegmentSizes = array<i32: 1, 0, 0, 0>, static_offsets = array<i64: 0>, static_sizes = array<i64: 12>, static_strides = array<i64: 1>}> : (tensor<24xf32>) -> tensor<12xf32> loc(#loc11)
%23 = "torch_c.from_builtin_tensor"(%22) : (tensor<12xf32>) -> !torch.vtensor<[12],f32> loc(#loc11)
%24 = "tensor.extract_slice"(%11) <{operandSegmentSizes = array<i32: 1, 0, 0, 0>, static_offsets = array<i64: 12>, static_sizes = array<i64: 12>, static_strides = array<i64: 1>}> : (tensor<24xf32>) -> tensor<12xf32> loc(#loc16)
%25 = "torch_c.from_builtin_tensor"(%24) : (tensor<12xf32>) -> !torch.vtensor<[12],f32> loc(#loc16)
%26 = "torch.prim.ListConstruct"() : () -> !torch.list<vtensor<[2,3],f32>> loc(#loc17)
%27 = "arith.constant"() <{value = 0 : index}> : () -> index loc(#loc7)
%28 = "arith.constant"() <{value = 1 : index}> : () -> index loc(#loc7)
%29 = "arith.index_cast"(%5) : (i64) -> index loc(#loc7)
%30:2 = "scf.for"(%27, %29, %28, %19, %21) ({
^bb0(%arg4: index loc("./lstm.torch.mlir":38:13), %arg5: tensor<2x3xf32> loc("./lstm.torch.mlir":38:13), %arg6: tensor<2x3xf32> loc("./lstm.torch.mlir":38:13)):
%43 = "arith.index_cast"(%arg4) : (index) -> i64 loc(#loc7)
%44 = "torch_c.from_i64"(%43) : (i64) -> !torch.int loc(#loc7)
%45 = "torch_c.to_builtin_tensor"(%arg6) : (tensor<2x3xf32>) -> tensor<2x3xf32> loc(#loc18)
%46 = "torch.aten.select.int"(%arg0, %2, %44) : (!torch.vtensor<[15,2,4],f32>, !torch.int, !torch.int) -> !torch.vtensor<[2,4],f32> loc(#loc19)
%47 = "torch.prim.ListConstruct"(%3, %7) : (!torch.int, !torch.int) -> !torch.list<int>
Hi @renxida, is there anything remaining for this PR to be merged?
@vivekkhandelwal1 yup — it still needs a numeric test, but that test is hitting some issues lowering past linalg. Specifically:
- scf for loop doesn't support tensor args (resolved by https://github.com/llvm/torch-mlir/pull/3040)
- concat doesn't support list of tensors created in a loop. Currently working on this.
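The workaround I'm exploring for the second item can be sketched in plain Python (a hypothetical stand-in, not the actual lowering): rather than collecting per-timestep results in a list and concatenating after the loop, preallocate the full output and overwrite one slice per iteration, which maps onto `index_put`/`copy_`-style ops inside the loop:

```python
seq_len, batch, hidden = 15, 2, 3

# Preallocate the whole [seq_len, batch, hidden] output up front
# (the analogue of torch.aten.zeros before the loop).
outputs = [[[0.0] * hidden for _ in range(batch)] for _ in range(seq_len)]

for t in range(seq_len):
    # Stand-in for the real h_t produced by the LSTM cell at step t.
    h_t = [[float(t)] * hidden for _ in range(batch)]
    # Overwrite slice t in place (the analogue of index_put / copy_),
    # instead of list.append followed by a concat after the loop.
    outputs[t] = h_t
```

This keeps every loop-carried value a fixed-shape tensor, so no list-of-tensors ever crosses the loop boundary.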
Currently stuck trying to lower this from the aten/torch level to linalg:
MLIR file:
module {
func.func @lstm(%arg0: !torch.vtensor<[15,2,4],f32>, %arg1: !torch.vtensor<[1,12,4],f32>, %arg2: !torch.vtensor<[1,12,3],f32>, %arg3: !torch.vtensor<[1,24],f32>) -> (!torch.vtensor<[15,1,2,3],f32>, !torch.vtensor<[1,2,3],f32>, !torch.vtensor<[1,2,3],f32>) attributes {torch.onnx_meta.ir_version = 9 : si64, torch.onnx_meta.opset_version = 20 : si64, torch.onnx_meta.producer_name = "", torch.onnx_meta.producer_version = ""} {
%none = torch.constant.none
%int0 = torch.constant.int 0
%int0_0 = torch.constant.int 0
%0 = torch.aten.select.int %arg1, %int0, %int0_0 : !torch.vtensor<[1,12,4],f32>, !torch.int, !torch.int -> !torch.vtensor<[12,4],f32>
%int0_1 = torch.constant.int 0
%int0_2 = torch.constant.int 0
%1 = torch.aten.select.int %arg2, %int0_1, %int0_2 : !torch.vtensor<[1,12,3],f32>, !torch.int, !torch.int -> !torch.vtensor<[12,3],f32>
%int0_3 = torch.constant.int 0
%int0_4 = torch.constant.int 0
%2 = torch.aten.select.int %arg3, %int0_3, %int0_4 : !torch.vtensor<[1,24],f32>, !torch.int, !torch.int -> !torch.vtensor<[24],f32>
%int1 = torch.constant.int 1
%int2 = torch.constant.int 2
%int3 = torch.constant.int 3
%none_5 = torch.constant.none
%int0_6 = torch.constant.int 0
%int1_7 = torch.constant.int 1
%3 = torch.prim.ListConstruct %int1, %int2, %int3 : (!torch.int, !torch.int, !torch.int) -> !torch.list<int>
%int6 = torch.constant.int 6
%4 = torch.aten.zeros %3, %int6, %none_5, %none_5, %none_5 : !torch.list<int>, !torch.int, !torch.none, !torch.none, !torch.none -> !torch.vtensor<[1,2,3],f32>
%5 = torch.aten.zeros %3, %int6, %none_5, %none_5, %none_5 : !torch.list<int>, !torch.int, !torch.none, !torch.none, !torch.none -> !torch.vtensor<[1,2,3],f32>
%int0_8 = torch.constant.int 0
%int0_9 = torch.constant.int 0
%6 = torch.aten.select.int %4, %int0_8, %int0_9 : !torch.vtensor<[1,2,3],f32>, !torch.int, !torch.int -> !torch.vtensor<[2,3],f32>
%int0_10 = torch.constant.int 0
%int0_11 = torch.constant.int 0
%7 = torch.aten.select.int %5, %int0_10, %int0_11 : !torch.vtensor<[1,2,3],f32>, !torch.int, !torch.int -> !torch.vtensor<[2,3],f32>
%int12 = torch.constant.int 12
%int24 = torch.constant.int 24
%8 = torch.aten.slice.Tensor %2, %int0_6, %int0_6, %int12, %int1_7 : !torch.vtensor<[24],f32>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[12],f32>
%9 = torch.aten.slice.Tensor %2, %int0_6, %int12, %int24, %int1_7 : !torch.vtensor<[24],f32>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[12],f32>
%none_12 = torch.constant.none
%true = torch.constant.bool true
%int0_13 = torch.constant.int 0
%int1_14 = torch.constant.int 1
%int15 = torch.constant.int 15
%int2_15 = torch.constant.int 2
%int3_16 = torch.constant.int 3
%int3_17 = torch.constant.int 3
%int6_18 = torch.constant.int 6
%int9 = torch.constant.int 9
%int12_19 = torch.constant.int 12
%10 = torch.aten.slice.Tensor %0, %int0_13, %int0_13, %int3_17, %int1_14 : !torch.vtensor<[12,4],f32>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[3,4],f32>
%11 = torch.aten.slice.Tensor %0, %int0_13, %int3_17, %int6_18, %int1_14 : !torch.vtensor<[12,4],f32>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[3,4],f32>
%12 = torch.aten.slice.Tensor %0, %int0_13, %int6_18, %int9, %int1_14 : !torch.vtensor<[12,4],f32>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[3,4],f32>
%13 = torch.aten.slice.Tensor %0, %int0_13, %int9, %int12_19, %int1_14 : !torch.vtensor<[12,4],f32>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[3,4],f32>
%14 = torch.aten.slice.Tensor %1, %int0_13, %int0_13, %int3_17, %int1_14 : !torch.vtensor<[12,3],f32>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[3,3],f32>
%15 = torch.aten.slice.Tensor %1, %int0_13, %int3_17, %int6_18, %int1_14 : !torch.vtensor<[12,3],f32>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[3,3],f32>
%16 = torch.aten.slice.Tensor %1, %int0_13, %int6_18, %int9, %int1_14 : !torch.vtensor<[12,3],f32>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[3,3],f32>
%17 = torch.aten.slice.Tensor %1, %int0_13, %int9, %int12_19, %int1_14 : !torch.vtensor<[12,3],f32>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[3,3],f32>
%18 = torch.aten.slice.Tensor %8, %int0_13, %int0_13, %int3_17, %int1_14 : !torch.vtensor<[12],f32>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[3],f32>
%19 = torch.aten.slice.Tensor %8, %int0_13, %int3_17, %int6_18, %int1_14 : !torch.vtensor<[12],f32>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[3],f32>
%20 = torch.aten.slice.Tensor %8, %int0_13, %int6_18, %int9, %int1_14 : !torch.vtensor<[12],f32>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[3],f32>
%21 = torch.aten.slice.Tensor %8, %int0_13, %int9, %int12_19, %int1_14 : !torch.vtensor<[12],f32>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[3],f32>
%22 = torch.aten.slice.Tensor %9, %int0_13, %int0_13, %int3_17, %int1_14 : !torch.vtensor<[12],f32>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[3],f32>
%23 = torch.aten.slice.Tensor %9, %int0_13, %int3_17, %int6_18, %int1_14 : !torch.vtensor<[12],f32>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[3],f32>
%24 = torch.aten.slice.Tensor %9, %int0_13, %int6_18, %int9, %int1_14 : !torch.vtensor<[12],f32>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[3],f32>
%25 = torch.aten.slice.Tensor %9, %int0_13, %int9, %int12_19, %int1_14 : !torch.vtensor<[12],f32>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[3],f32>
%26 = torch.prim.ListConstruct %int15, %int2_15, %int3_16 : (!torch.int, !torch.int, !torch.int) -> !torch.list<int>
%int6_20 = torch.constant.int 6
%27 = torch.aten.zeros %26, %int6_20, %none_12, %none_12, %none_12 : !torch.list<int>, !torch.int, !torch.none, !torch.none, !torch.none -> !torch.vtensor<[15,2,3],f32>
%28 = torch.copy.to_tensor %27 : !torch.tensor<[15,2,3],f32>
%int15_21 = torch.constant.int 15
%true_22 = torch.constant.bool true
%int0_23 = torch.constant.int 0
%29:2 = torch.prim.Loop %int15_21, %true_22, init(%6, %7) {
^bb0(%arg4: !torch.int, %arg5: !torch.vtensor<[2,3],f32>, %arg6: !torch.vtensor<[2,3],f32>):
%33 = torch.aten.select.int %arg0, %int0_23, %arg4 : !torch.vtensor<[15,2,4],f32>, !torch.int, !torch.int -> !torch.vtensor<[2,4],f32>
%int1_24 = torch.constant.int 1
%34 = torch.aten.linear %33, %10, %18 : !torch.vtensor<[2,4],f32>, !torch.vtensor<[3,4],f32>, !torch.vtensor<[3],f32> -> !torch.vtensor<[2,3],f32>
%35 = torch.aten.linear %arg5, %14, %22 : !torch.vtensor<[2,3],f32>, !torch.vtensor<[3,3],f32>, !torch.vtensor<[3],f32> -> !torch.vtensor<[2,3],f32>
%36 = torch.aten.add.Tensor %34, %35, %int1_24 : !torch.vtensor<[2,3],f32>, !torch.vtensor<[2,3],f32>, !torch.int -> !torch.vtensor<[2,3],f32>
%37 = torch.aten.sigmoid %36 : !torch.vtensor<[2,3],f32> -> !torch.vtensor<[2,3],f32>
%38 = torch.aten.linear %33, %11, %19 : !torch.vtensor<[2,4],f32>, !torch.vtensor<[3,4],f32>, !torch.vtensor<[3],f32> -> !torch.vtensor<[2,3],f32>
%39 = torch.aten.linear %arg5, %15, %23 : !torch.vtensor<[2,3],f32>, !torch.vtensor<[3,3],f32>, !torch.vtensor<[3],f32> -> !torch.vtensor<[2,3],f32>
%40 = torch.aten.add.Tensor %38, %39, %int1_24 : !torch.vtensor<[2,3],f32>, !torch.vtensor<[2,3],f32>, !torch.int -> !torch.vtensor<[2,3],f32>
%41 = torch.aten.sigmoid %40 : !torch.vtensor<[2,3],f32> -> !torch.vtensor<[2,3],f32>
%42 = torch.aten.linear %33, %12, %20 : !torch.vtensor<[2,4],f32>, !torch.vtensor<[3,4],f32>, !torch.vtensor<[3],f32> -> !torch.vtensor<[2,3],f32>
%43 = torch.aten.linear %arg5, %16, %24 : !torch.vtensor<[2,3],f32>, !torch.vtensor<[3,3],f32>, !torch.vtensor<[3],f32> -> !torch.vtensor<[2,3],f32>
%44 = torch.aten.add.Tensor %42, %43, %int1_24 : !torch.vtensor<[2,3],f32>, !torch.vtensor<[2,3],f32>, !torch.int -> !torch.vtensor<[2,3],f32>
%45 = torch.aten.sigmoid %44 : !torch.vtensor<[2,3],f32> -> !torch.vtensor<[2,3],f32>
%46 = torch.aten.linear %33, %13, %21 : !torch.vtensor<[2,4],f32>, !torch.vtensor<[3,4],f32>, !torch.vtensor<[3],f32> -> !torch.vtensor<[2,3],f32>
%47 = torch.aten.linear %arg5, %17, %25 : !torch.vtensor<[2,3],f32>, !torch.vtensor<[3,3],f32>, !torch.vtensor<[3],f32> -> !torch.vtensor<[2,3],f32>
%48 = torch.aten.add.Tensor %46, %47, %int1_24 : !torch.vtensor<[2,3],f32>, !torch.vtensor<[2,3],f32>, !torch.int -> !torch.vtensor<[2,3],f32>
%49 = torch.aten.tanh %48 : !torch.vtensor<[2,3],f32> -> !torch.vtensor<[2,3],f32>
%50 = torch.aten.mul.Tensor %45, %arg6 : !torch.vtensor<[2,3],f32>, !torch.vtensor<[2,3],f32> -> !torch.vtensor<[2,3],f32>
%51 = torch.aten.mul.Tensor %37, %49 : !torch.vtensor<[2,3],f32>, !torch.vtensor<[2,3],f32> -> !torch.vtensor<[2,3],f32>
%52 = torch.aten.add.Tensor %50, %51, %int1_24 : !torch.vtensor<[2,3],f32>, !torch.vtensor<[2,3],f32>, !torch.int -> !torch.vtensor<[2,3],f32>
%53 = torch.aten.tanh %52 : !torch.vtensor<[2,3],f32> -> !torch.vtensor<[2,3],f32>
%54 = torch.aten.mul.Tensor %41, %53 : !torch.vtensor<[2,3],f32>, !torch.vtensor<[2,3],f32> -> !torch.vtensor<[2,3],f32>
%55 = torch.copy.to_tensor %54 : !torch.tensor<[2,3],f32>
%56 = torch.aten.select.int %28, %int0_23, %arg4 : !torch.tensor<[15,2,3],f32>, !torch.int, !torch.int -> !torch.tensor<[2,3],f32>
%57 = torch.aten.copy_ %56, %55, %true : !torch.tensor<[2,3],f32>, !torch.tensor<[2,3],f32>, !torch.bool -> !torch.tensor<[2,3],f32>
torch.prim.Loop.condition %true_22, iter(%54, %52 : !torch.vtensor<[2,3],f32>, !torch.vtensor<[2,3],f32>)
} : (!torch.int, !torch.bool, !torch.vtensor<[2,3],f32>, !torch.vtensor<[2,3],f32>) -> (!torch.vtensor<[2,3],f32>, !torch.vtensor<[2,3],f32>)
%30 = torch.aten.unsqueeze %29#0, %int0_6 : !torch.vtensor<[2,3],f32>, !torch.int -> !torch.vtensor<[1,2,3],f32>
%31 = torch.aten.unsqueeze %29#1, %int0_6 : !torch.vtensor<[2,3],f32>, !torch.int -> !torch.vtensor<[1,2,3],f32>
%32 = torch.aten.unsqueeze %28, %int1_7 : !torch.tensor<[15,2,3],f32>, !torch.int -> !torch.vtensor<[15,1,2,3],f32>
return %32, %30, %31 : !torch.vtensor<[15,1,2,3],f32>, !torch.vtensor<[1,2,3],f32>, !torch.vtensor<[1,2,3],f32>
}
}
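As a sanity check on the loop body above: it computes a standard LSTM cell with ONNX's iofc gate order, where by my reading of the slices %37 = i, %41 = o, %45 = f, and %49 = g (the cell candidate). A plain-Python reconstruction of one step (my naming, unvectorized, per batch element):

```python
import math

def sigmoid(x):
    return 1.0 / (1.0 + math.exp(-x))

def matvec(W, x):
    # W: list of rows, each row the same length as x
    return [sum(w * xi for w, xi in zip(row, x)) for row in W]

def vadd(*vs):
    return [sum(terms) for terms in zip(*vs)]

def lstm_step(x, h, c, Wi, Wo, Wf, Wc, Ri, Ro, Rf, Rc, bi, bo, bf, bc):
    """One LSTM step mirroring the torch.prim.Loop body: slices 0:3 / 3:6 /
    6:9 / 9:12 of W, R, and B correspond to gates i, o, f, c."""
    i = [sigmoid(v) for v in vadd(matvec(Wi, x), matvec(Ri, h), bi)]    # %37
    o = [sigmoid(v) for v in vadd(matvec(Wo, x), matvec(Ro, h), bo)]    # %41
    f = [sigmoid(v) for v in vadd(matvec(Wf, x), matvec(Rf, h), bf)]    # %45
    g = [math.tanh(v) for v in vadd(matvec(Wc, x), matvec(Rc, h), bc)]  # %49
    c_new = [fv * cv + iv * gv for iv, fv, gv, cv in zip(i, f, g, c)]   # %52
    h_new = [ov * math.tanh(cv) for ov, cv in zip(o, c_new)]            # %54
    return h_new, c_new

# Quick smoke test: with all-zero weights and state, g = tanh(0) = 0,
# so both the new cell state and the new hidden state stay zero.
Z34 = [[0.0] * 4 for _ in range(3)]
Z33 = [[0.0] * 3 for _ in range(3)]
z3 = [0.0] * 3
h, c = lstm_step([1.0, 2.0, 3.0, 4.0], z3, z3,
                 Z34, Z34, Z34, Z34, Z33, Z33, Z33, Z33, z3, z3, z3, z3)
assert h == [0.0, 0.0, 0.0] and c == [0.0, 0.0, 0.0]
```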
I'm getting the error message:
./lstm.torch.mlir:93:13: error: failed to legalize operation 'torch.aten.slice.Tensor' that was explicitly marked illegal
%56 = torch.aten.select.int %28, %int0_23, %arg4 : !torch.tensor<[15,2,3],f32>, !torch.int, !torch.int -> !torch.tensor<[2,3],f32>
^
With details indicating the assertion failure:

```
Assertion `use_empty() && "Cannot destroy a value that still has uses!"' failed.
```
./lstm.torch.mlir:93:13: note: see current operation: %1298 = "torch.aten.slice.Tensor"(%785, %22, %1296, %1297, %19) : (!torch.tensor<[15,2,3],f32>, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.tensor<[1,2,3],f32> loc("./lstm.torch.mlir":93:13)
torch-mlir-opt: /home/azureuser/torch-mlir/externals/llvm-project/mlir/include/mlir/IR/UseDefLists.h:198: mlir::IRObjectWithUseList<mlir::OpOperand>::~IRObjectWithUseList() [OperandType = mlir::OpOperand]: Assertion `use_empty() && "Cannot destroy a value that still has uses!"' failed.
Quinn helped.
Will try to stick to using only value tensors (`!torch.vtensor`).
New error:
/home/azureuser/torch-mlir/build/bin/torch-mlir-opt: /home/azureuser/miniconda/lib/libtinfo.so.6: no version information available (required by /home/azureuser/torch-mlir/build/bin/torch-mlir-opt)
./lstm.torch.mlir:93:13: error: 'tm_tensor.scatter' op mismatch in shape of indices and update value at dim#0
%56 = torch.aten.index_put %arg5, %55, %53, %false : !torch.vtensor<[15,2,3],f32>, !torch.list<vtensor<[1],si64>>, !torch.vtensor<[2,3],f32>, !torch.bool -> !torch.vtensor<[15,2,3],f32>
^
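If I'm reading the `tm_tensor.scatter` verifier right, the indices and updates operands must agree at dim 0: for an `index_put` over dim 0 of a `[15,2,3]` base with an index tensor of shape `[1]`, the update should have shape `[1,2,3]`, not `[2,3]`. A hedged shape-check sketch (helper name is mine):

```python
def index_put_update_shape(base_shape, index_len):
    """For base[idx] = update with idx of shape [index_len] indexing dim 0,
    the update must (broadcast to) shape [index_len, *base_shape[1:]].
    Sketch of my understanding of the constraint, not the verifier's code."""
    return (index_len,) + tuple(base_shape[1:])

# The failing op: update is [2,3], but with a [1]-shaped index the
# expected update shape is [1,2,3] — hence the dim#0 mismatch.
assert index_put_update_shape((15, 2, 3), 1) == (1, 2, 3)
```

So the fix is presumably to unsqueeze the per-step `[2,3]` update to `[1,2,3]` before the `index_put`.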
Torch IR:
https://github.com/pytorch/pytorch/issues/91439
LSTM e2e test case fails, but only on torch-nightly.
Should I remove the test case, or somehow xfail it only for nightly? Hmm.
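One option (a sketch of the general idea, not necessarily how torch-mlir's e2e harness gates things): detect nightly builds from the version string, which carries a `.dev` date tag (e.g. `2.3.0.dev20240115`), and xfail conditionally on that:

```python
def is_torch_nightly(version: str) -> bool:
    # Nightly wheels embed a ".dev<date>" tag in torch.__version__; release
    # builds do not. Strip any local "+cu121"-style suffix before checking.
    # (Assumed convention — verify against your installed wheels.)
    return ".dev" in version.split("+")[0]

assert is_torch_nightly("2.3.0.dev20240115")
assert not is_torch_nightly("2.2.1")
assert not is_torch_nightly("2.2.1+cu121")
```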
A lot of this code does not use camelCase because the onnx.LSTM documentation uses snake_case and I'm trying to stay consistent with that.
Note: an e2e test has been added but xfailed. The actual numerical consistency test is done here: https://github.com/nod-ai/SHARK-TestSuite/pull/142
@renxida, you have done the LLVM bump in this PR. Do we need that for the changes in this PR? If not, can we do it in a separate PR?
Also, the CI seems to be broken; it might be because of the LLVM bump.
No, but I was keeping this up to date with main because GitHub is telling me I have merge conflicts.
Is there a better way I should have done it?