torch-mlir icon indicating copy to clipboard operation
torch-mlir copied to clipboard

[ONNX] Support onnx.LSTM

Open renxida opened this issue 2 years ago • 5 comments

renxida avatar Feb 29 '24 21:02 renxida

Test case draft: lstm.onnx.mlir

// Draft test case for the onnx.LSTM -> Torch lowering.
// Shapes (from the types below): seq_len=15, batch=2, input_size=4,
// hidden_size=3, num_directions=1. Per the ONNX LSTM spec the packed
// weight dims are 4*hidden_size = 12 (W: [1,12,4], R: [1,12,3]) and the
// packed bias is 8*hidden_size = 24 (B: [1,24]) — TODO confirm against
// the opset-20 operator definition.
module {
  // Args: %arg0 = X (input sequence), %arg1 = W (input weights),
  // %arg2 = R (recurrence weights), %arg3 = B (biases).
  // Results: Y [15,1,2,3], Y_h [1,2,3], Y_c [1,2,3].
  func.func @lstm(%arg0: !torch.vtensor<[15,2,4],f32>, %arg1: !torch.vtensor<[1,12,4],f32>, %arg2: !torch.vtensor<[1,12,3],f32>, %arg3: !torch.vtensor<[1,24],f32>) -> (!torch.vtensor<[15,1,2,3],f32>, !torch.vtensor<[1,2,3],f32>, !torch.vtensor<[1,2,3],f32>) attributes {torch.onnx_meta.ir_version = 9 : si64, torch.onnx_meta.opset_version = 20 : si64, torch.onnx_meta.producer_name = "", torch.onnx_meta.producer_version = ""} {
    %none = torch.constant.none
    // Unconverted ONNX op carrying the hidden_size attribute; the
    // convert-torch-onnx-to-torch pass is expected to expand this.
    %0:3 = torch.operator "onnx.LSTM"(%arg0, %arg1, %arg2, %arg3) {torch.onnx.hidden_size = 3 : si64} : (!torch.vtensor<[15,2,4],f32>, !torch.vtensor<[1,12,4],f32>, !torch.vtensor<[1,12,3],f32>, !torch.vtensor<[1,24],f32>) -> (!torch.vtensor<[15,1,2,3],f32>, !torch.vtensor<[1,2,3],f32>, !torch.vtensor<[1,2,3],f32>)
    return %0#0, %0#1, %0#2 : !torch.vtensor<[15,1,2,3],f32>, !torch.vtensor<[1,2,3],f32>, !torch.vtensor<[1,2,3],f32>
  }
}
current Torch IR ```mlir module { func.func @lstm(%arg0: !torch.vtensor, %arg1: !torch.vtensor, %arg2: !torch.vtensor, %arg3: !torch.vtensor) -> (!torch.vtensor, !torch.vtensor, !torch.vtensor) attributes {torch.onnx_meta.ir_version = 9 : si64, torch.onnx_meta.opset_version = 20 : si64, torch.onnx_meta.producer_name = "", torch.onnx_meta.producer_version = ""} { %none = torch.constant.none %int0 = torch.constant.int 0 %int0_0 = torch.constant.int 0 %0 = torch.aten.select.int %arg1, %int0, %int0_0 : !torch.vtensor, !torch.int, !torch.int -> !torch.vtensor %int0_1 = torch.constant.int 0 %int0_2 = torch.constant.int 0 %1 = torch.aten.select.int %arg2, %int0_1, %int0_2 : !torch.vtensor, !torch.int, !torch.int -> !torch.vtensor %int0_3 = torch.constant.int 0 %int0_4 = torch.constant.int 0 %2 = torch.aten.select.int %arg3, %int0_3, %int0_4 : !torch.vtensor, !torch.int, !torch.int -> !torch.vtensor %int1 = torch.constant.int 1 %int2 = torch.constant.int 2 %int3 = torch.constant.int 3 %none_5 = torch.constant.none %int0_6 = torch.constant.int 0 %int1_7 = torch.constant.int 1 %3 = torch.prim.ListConstruct %int1, %int2, %int3 : (!torch.int, !torch.int, !torch.int) -> !torch.list %int6 = torch.constant.int 6 %4 = torch.aten.zeros %3, %int6, %none_5, %none_5, %none_5 : !torch.list, !torch.int, !torch.none, !torch.none, !torch.none -> !torch.vtensor %5 = torch.aten.zeros %3, %int6, %none_5, %none_5, %none_5 : !torch.list, !torch.int, !torch.none, !torch.none, !torch.none -> !torch.vtensor %int0_8 = torch.constant.int 0 %int0_9 = torch.constant.int 0 %6 = torch.aten.select.int %4, %int0_8, %int0_9 : !torch.vtensor, !torch.int, !torch.int -> !torch.vtensor %int0_10 = torch.constant.int 0 %int0_11 = torch.constant.int 0 %7 = torch.aten.select.int %5, %int0_10, %int0_11 : !torch.vtensor, !torch.int, !torch.int -> !torch.vtensor %int12 = torch.constant.int 12 %int24 = torch.constant.int 24 %8 = torch.aten.slice.Tensor %2, %int0_6, %int0_6, %int12, %int1_7 : !torch.vtensor, 
!torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor %9 = torch.aten.slice.Tensor %2, %int0_6, %int12, %int24, %int1_7 : !torch.vtensor, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor %10 = torch.prim.ListConstruct : () -> !torch.list> %int15 = torch.constant.int 15 %true = torch.constant.bool true %int0_12 = torch.constant.int 0 %int1_13 = torch.constant.int 1 %11:2 = torch.prim.Loop %int15, %true, init(%6, %7) { ^bb0(%arg4: !torch.int, %arg5: !torch.vtensor, %arg6: !torch.vtensor): %16 = torch.aten.select.int %arg0, %int0_12, %arg4 : !torch.vtensor, !torch.int, !torch.int -> !torch.vtensor %int0_14 = torch.constant.int 0 %int1_15 = torch.constant.int 1 %int4 = torch.constant.int 4 %17 = torch.prim.ListConstruct %int1_15, %int4 : (!torch.int, !torch.int) -> !torch.list %18 = torch.aten.tile %16, %17 : !torch.vtensor, !torch.list -> !torch.vtensor %19 = torch.aten.tile %arg5, %17 : !torch.vtensor, !torch.list -> !torch.vtensor %20 = torch.aten.linear %18, %0, %8 : !torch.vtensor, !torch.vtensor, !torch.vtensor -> !torch.vtensor %21 = torch.aten.linear %19, %1, %9 : !torch.vtensor, !torch.vtensor, !torch.vtensor -> !torch.vtensor %22 = torch.aten.add.Tensor %20, %21, %int1_15 : !torch.vtensor, !torch.vtensor, !torch.int -> !torch.vtensor %int3_16 = torch.constant.int 3 %int6_17 = torch.constant.int 6 %int9 = torch.constant.int 9 %int12_18 = torch.constant.int 12 %23 = torch.aten.slice.Tensor %22, %int1_15, %int0_14, %int9, %int1_15 : !torch.vtensor, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor %24 = torch.aten.slice.Tensor %22, %int1_15, %int9, %int12_18, %int1_15 : !torch.vtensor, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor %25 = torch.aten.sigmoid %23 : !torch.vtensor -> !torch.vtensor %26 = torch.aten.tanh %24 : !torch.vtensor -> !torch.vtensor %27 = torch.aten.slice.Tensor %25, %int1_15, %int0_14, %int3_16, %int1_15 : !torch.vtensor, !torch.int, !torch.int, !torch.int, !torch.int -> 
!torch.vtensor %28 = torch.aten.slice.Tensor %25, %int1_15, %int3_16, %int6_17, %int1_15 : !torch.vtensor, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor %29 = torch.aten.slice.Tensor %25, %int1_15, %int6_17, %int9, %int1_15 : !torch.vtensor, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor %30 = torch.aten.mul.Tensor %29, %arg6 : !torch.vtensor, !torch.vtensor -> !torch.vtensor %31 = torch.aten.mul.Tensor %27, %26 : !torch.vtensor, !torch.vtensor -> !torch.vtensor %32 = torch.aten.add.Tensor %30, %31, %int1_15 : !torch.vtensor, !torch.vtensor, !torch.int -> !torch.vtensor %33 = torch.aten.tanh %32 : !torch.vtensor -> !torch.vtensor %34 = torch.aten.mul.Tensor %28, %33 : !torch.vtensor, !torch.vtensor -> !torch.vtensor %35 = torch.aten.append.t %10, %34 : !torch.list>, !torch.vtensor -> !torch.list> %36 = torch.aten.add.int %arg4, %int1_13 : !torch.int, !torch.int -> !torch.int torch.prim.Loop.condition %true, iter(%34, %32 : !torch.vtensor, !torch.vtensor) } : (!torch.int, !torch.bool, !torch.vtensor, !torch.vtensor) -> (!torch.vtensor, !torch.vtensor) %12 = torch.aten.unsqueeze %11#0, %int0_6 : !torch.vtensor, !torch.int -> !torch.vtensor %13 = torch.aten.unsqueeze %11#1, %int0_6 : !torch.vtensor, !torch.int -> !torch.vtensor %14 = torch.aten.stack %10, %int0_6 : !torch.list>, !torch.int -> !torch.vtensor %15 = torch.aten.unsqueeze %14, %int1_7 : !torch.vtensor, !torch.int -> !torch.vtensor return %15, %12, %13 : !torch.vtensor, !torch.vtensor, !torch.vtensor } }

```
</details>

renxida avatar Mar 01 '24 03:03 renxida

Currently testing with:


/home/azureuser/torch-mlir/build/bin/torch-mlir-opt -pass-pipeline='builtin.module(func.func(convert-torch-onnx-to-torch))' lstm.onnx.mlir  -o lstm.torch.mlir

~/torch-mlir/build/bin/torch-mlir-opt --mlir-print-debuginfo  --mlir-elide-elementsattrs-if-larger=16 --mlir-print-stacktrace-on-diagnostic --mlir-disable-threading --mlir-print-ir-after-failure --mlir-print-ir-module-scope -pass-pipeline='builtin.module(torch-backend-to-linalg-on-tensors-backend-pipeline)' ./lstm.torch.mlir -o ./lstm.linalg.mlir 2>&1 | tee torchtolinalg.log

renxida avatar Mar 01 '24 03:03 renxida

Encountered a bunch of issues earlier; they are now fixed, thanks to Rob.

Currently stuck on:

./lstm.torch.mlir:61:13: error: 'torch_c.to_builtin_tensor' op operand #0 must be Multi-dimensional array modeling Torch's Tensor type, but got 'tensor' %30 = torch.aten.mul.Tensor %29, %arg6 : !torch.vtensor, !torch.vtensor -> !torch.vtensor ^ ./lstm.torch.mlir:61:13: note: diagnostic emitted with trace: #0 0x000055b8b5e22e2d llvm::sys::PrintStackTrace(llvm::raw_ostream&, int) /home/azureuser/torch-mlir/externals/llvm-project/llvm/lib/Support/Unix/Signals.inc:723:11 #1 0x000055b8b5bf959e emitDiag(mlir::Location, mlir::DiagnosticSeverity, llvm::Twine const&) /home/azureuser/torch-mlir/externals/llvm-project/mlir/lib/IR/Diagnostics.cpp:319:5 #2 0x000055b8b5bf94c5 mlir::emitError(mlir::Location, llvm::Twine const&) /home/azureuser/torch-mlir/externals/llvm-project/mlir/lib/IR/Diagnostics.cpp:330:10 #3 0x000055b8b5c9f888 mlir::Operation::emitError(llvm::Twine const&) /home/azureuser/torch-mlir/externals/llvm-project/mlir/lib/IR/Operation.cpp:269:29 #4 0x000055b8b5c9f359 mlir::Operation::emitOpError(llvm::Twine const&) /home/azureuser/torch-mlir/externals/llvm-project/mlir/lib/IR/Operation.cpp:672:22 #5 0x000055b8b2d9c635 mlir::torch::TorchConversion::__mlir_ods_local_type_constraint_TorchConversionOps1(mlir::Operation*, mlir::Type, llvm::StringRef, unsigned int) /home/azureuser/torch-mlir/build/tools/torch-mlir/include/torch-mlir/Dialect/TorchConversion/IR/TorchConversionOps.cpp.inc:52:39 #6 0x000055b8b2da4a46 mlir::torch::TorchConversion::ToBuiltinTensorOp::verifyInvariantsImpl() /home/azureuser/torch-mlir/build/tools/torch-mlir/include/torch-mlir/Dialect/TorchConversion/IR/TorchConversionOps.cpp.inc:1499:26 #7 0x000055b8b2d93cc2 mlir::OpTrait::OpInvariants<:torch::torchconversion::tobuiltintensorop>::verifyTrait(mlir::Operation*) /home/azureuser/torch-mlir/externals/llvm-project/llvm/../mlir/include/mlir/IR/OpDefinition.h:432:35 #8 0x000055b8b2d93b05 std::enable_if>::value, mlir::LogicalResult>::type 
mlir::op_definition_impl::verifyTrait<:optrait::opinvariants>>(mlir::Operation*) /home/azureuser/torch-mlir/externals/llvm-project/llvm/../mlir/include/mlir/IR/OpDefinition.h:1620:10 #9 0x000055b8b2d938f7 mlir::LogicalResult mlir::op_definition_impl::verifyTraits<:optrait::zeroregions>, mlir::OpTrait::OneResult<:torch::torchconversion::tobuiltintensorop>, mlir::OpTrait::OneTypedResult<:tensortype>::Impl<:torch::torchconversion::tobuiltintensorop>, mlir::OpTrait::ZeroSuccessors<:torch::torchconversion::tobuiltintensorop>, mlir::OpTrait::OneOperand<:torch::torchconversion::tobuiltintensorop>, mlir::OpTrait::OpInvariants<:torch::torchconversion::tobuiltintensorop>, mlir::InferTypeOpInterface::Trait<:torch::torchconversion::tobuiltintensorop>, mlir::ConditionallySpeculatable::Trait<:torch::torchconversion::tobuiltintensorop>, mlir::OpTrait::AlwaysSpeculatableImplTrait<:torch::torchconversion::tobuiltintensorop>, mlir::MemoryEffectOpInterface::Trait<:torch::torchconversion::tobuiltintensorop>>(mlir::Operation*) /home/azureuser/torch-mlir/externals/llvm-project/llvm/../mlir/include/mlir/IR/OpDefinition.h:1631:29 #10 0x000055b8b2d937b5 mlir::Op<:torch::torchconversion::tobuiltintensorop mlir::optrait::zeroregions mlir::optrait::oneresult mlir::optrait::onetypedresult>::Impl, mlir::OpTrait::ZeroSuccessors, mlir::OpTrait::OneOperand, mlir::OpTrait::OpInvariants, mlir::InferTypeOpInterface::Trait, mlir::ConditionallySpeculatable::Trait, mlir::OpTrait::AlwaysSpeculatableImplTrait, mlir::MemoryEffectOpInterface::Trait>::verifyInvariants(mlir::Operation*) /home/azureuser/torch-mlir/externals/llvm-project/llvm/../mlir/include/mlir/IR/OpDefinition.h:2012:16 #11 0x000055b8b20126f5 mlir::LogicalResult llvm::detail::UniqueFunctionBase<:logicalresult mlir::operation>::CallImpl<:logicalresult const>(void*, mlir::Operation*) /home/azureuser/torch-mlir/externals/llvm-project/llvm/include/llvm/ADT/FunctionExtras.h:221:12 #12 0x000055b8b2011e57 llvm::unique_function<:logicalresult 
const>::operator()(mlir::Operation*) const /home/azureuser/torch-mlir/externals/llvm-project/llvm/include/llvm/ADT/FunctionExtras.h:411:12 #13 0x000055b8b2d92466 mlir::RegisteredOperationName::Model<:torch::torchconversion::tobuiltintensorop>::verifyInvariants(mlir::Operation*) /home/azureuser/torch-mlir/externals/llvm-project/llvm/../mlir/include/mlir/IR/OperationSupport.h:558:14 #14 0x000055b8b5ce8ba6 mlir::OperationName::verifyInvariants(mlir::Operation*) const /home/azureuser/torch-mlir/externals/llvm-project/mlir/include/mlir/IR/OperationSupport.h:317:23 #15 0x000055b8b5ce5b8d (anonymous namespace)::OperationVerifier::verifyOnEntrance(mlir::Operation&) /home/azureuser/torch-mlir/externals/llvm-project/mlir/lib/IR/Verifier.cpp:179:48 #16 0x000055b8b5ce58e0 _ZZN12_GLOBAL__N_117OperationVerifier15verifyOperationERN4mlir9OperationEENK3$_2clIS2_EEDaPT_ /home/azureuser/torch-mlir/externals/llvm-project/mlir/lib/IR/Verifier.cpp:293:45 #17 0x000055b8b5ce47f7 _ZZN12_GLOBAL__N_117OperationVerifier15verifyOperationERN4mlir9OperationEENK3$_1clIZNS0_15verifyOperationES3_E3$_2EEDaOT_N4llvm12PointerUnionIJPS2_PNS1_5BlockEEEE /home/azureuser/torch-mlir/externals/llvm-project/mlir/lib/IR/Verifier.cpp:277:16 #18 0x000055b8b5ce401f (anonymous namespace)::OperationVerifier::verifyOperation(mlir::Operation&) /home/azureuser/torch-mlir/externals/llvm-project/mlir/lib/IR/Verifier.cpp:292:16 #19 0x000055b8b5ce3de1 (anonymous namespace)::OperationVerifier::verifyOpAndDominance(mlir::Operation&) /home/azureuser/torch-mlir/externals/llvm-project/mlir/lib/IR/Verifier.cpp:85:14 #20 0x000055b8b5ce3d92 mlir::verify(mlir::Operation*, bool) /home/azureuser/torch-mlir/externals/llvm-project/mlir/lib/IR/Verifier.cpp:423:19 #21 0x000055b8b4826985 mlir::detail::OpToOpPassAdaptor::run(mlir::Pass*, mlir::Operation*, mlir::AnalysisManager, bool, unsigned int) /home/azureuser/torch-mlir/externals/llvm-project/mlir/lib/Pass/Pass.cpp:548:27 #22 0x000055b8b4826e74 
mlir::detail::OpToOpPassAdaptor::runPipeline(mlir::OpPassManager&, mlir::Operation*, mlir::AnalysisManager, bool, unsigned int, mlir::PassInstrumentor*, mlir::PassInstrumentation::PipelineParentInfo const*) /home/azureuser/torch-mlir/externals/llvm-project/mlir/lib/Pass/Pass.cpp:585:16 #23 0x000055b8b4827cce mlir::detail::OpToOpPassAdaptor::runOnOperationImpl(bool) /home/azureuser/torch-mlir/externals/llvm-project/mlir/lib/Pass/Pass.cpp:726:20 #24 0x000055b8b482757d mlir::detail::OpToOpPassAdaptor::runOnOperation(bool) /home/azureuser/torch-mlir/externals/llvm-project/mlir/lib/Pass/Pass.cpp:709:1 #25 0x000055b8b482b186 mlir::detail::OpToOpPassAdaptor::run(mlir::Pass*, mlir::Operation*, mlir::AnalysisManager, bool, unsigned int)::$_1::operator()() const /home/azureuser/torch-mlir/externals/llvm-project/mlir/lib/Pass/Pass.cpp:517:11 #26 0x000055b8b482b135 void llvm::function_ref::callback_fn<:detail::optooppassadaptor::run mlir::operation mlir::analysismanager bool unsigned int>(long) /home/azureuser/torch-mlir/externals/llvm-project/llvm/include/llvm/ADT/STLFunctionalExtras.h:45:5 #27 0x000055b8b1efc0e9 llvm::function_ref::operator()() const /home/azureuser/torch-mlir/externals/llvm-project/llvm/include/llvm/ADT/STLFunctionalExtras.h:68:5 #28 0x000055b8b482ded5 void mlir::MLIRContext::executeAction<:passexecutionaction mlir::pass>(llvm::function_ref, llvm::ArrayRef<:irunit>, mlir::Pass&) /home/azureuser/torch-mlir/externals/llvm-project/mlir/include/mlir/IR/MLIRContext.h:276:3 #29 0x000055b8b48268f3 mlir::detail::OpToOpPassAdaptor::run(mlir::Pass*, mlir::Operation*, mlir::AnalysisManager, bool, unsigned int) /home/azureuser/torch-mlir/externals/llvm-project/mlir/lib/Pass/Pass.cpp:525:17 #30 0x000055b8b4826e74 mlir::detail::OpToOpPassAdaptor::runPipeline(mlir::OpPassManager&, mlir::Operation*, mlir::AnalysisManager, bool, unsigned int, mlir::PassInstrumentor*, mlir::PassInstrumentation::PipelineParentInfo const*) 
/home/azureuser/torch-mlir/externals/llvm-project/mlir/lib/Pass/Pass.cpp:585:16 #31 0x000055b8b48288b8 mlir::PassManager::runPasses(mlir::Operation*, mlir::AnalysisManager) /home/azureuser/torch-mlir/externals/llvm-project/mlir/lib/Pass/Pass.cpp:896:10 #32 0x000055b8b48287e2 mlir::PassManager::run(mlir::Operation*) /home/azureuser/torch-mlir/externals/llvm-project/mlir/lib/Pass/Pass.cpp:876:60 #33 0x000055b8b1e7ee72 performActions(llvm::raw_ostream&, std::shared_ptr<:sourcemgr> const&, mlir::MLIRContext*, mlir::MlirOptMainConfig const&) /home/azureuser/torch-mlir/externals/llvm-project/mlir/lib/Tools/mlir-opt/MlirOptMain.cpp:396:17 #34 0x000055b8b1e7eaa8 processBuffer(llvm::raw_ostream&, std::unique_ptr<:memorybuffer std::default_delete>>, mlir::MlirOptMainConfig const&, mlir::DialectRegistry&, llvm::ThreadPool*) /home/azureuser/torch-mlir/externals/llvm-project/mlir/lib/Tools/mlir-opt/MlirOptMain.cpp:461:12 #35 0x000055b8b1e7e88c mlir::MlirOptMain(llvm::raw_ostream&, std::unique_ptr<:memorybuffer std::default_delete>>, mlir::DialectRegistry&, mlir::MlirOptMainConfig const&)::$_0::operator()(std::unique_ptr<:memorybuffer std::default_delete>>, llvm::raw_ostream&) const /home/azureuser/torch-mlir/externals/llvm-project/mlir/lib/Tools/mlir-opt/MlirOptMain.cpp:532:12 #36 0x000055b8b1e7e826 mlir::LogicalResult llvm::function_ref<:logicalresult std::default_delete>>, llvm::raw_ostream&)>::callback_fn<:mliroptmain std::unique_ptr std::default_delete>>, mlir::DialectRegistry&, mlir::MlirOptMainConfig const&)::$_0>(long, std::unique_ptr<:memorybuffer std::default_delete>>, llvm::raw_ostream&) /home/azureuser/torch-mlir/externals/llvm-project/llvm/include/llvm/ADT/STLFunctionalExtras.h:45:12 #37 0x000055b8b5d0ea22 llvm::function_ref<:logicalresult std::default_delete>>, llvm::raw_ostream&)>::operator()(std::unique_ptr<:memorybuffer std::default_delete>>, llvm::raw_ostream&) const 
/home/azureuser/torch-mlir/externals/llvm-project/llvm/include/llvm/ADT/STLFunctionalExtras.h:68:12 #38 0x000055b8b5d0e03d mlir::splitAndProcessBuffer(std::unique_ptr<:memorybuffer std::default_delete>>, llvm::function_ref<:logicalresult std::default_delete>>, llvm::raw_ostream&)>, llvm::raw_ostream&, bool, bool) /home/azureuser/torch-mlir/externals/llvm-project/mlir/lib/Support/ToolUtilities.cpp:28:12 #39 0x000055b8b1e7b75b mlir::MlirOptMain(llvm::raw_ostream&, std::unique_ptr<:memorybuffer std::default_delete>>, mlir::DialectRegistry&, mlir::MlirOptMainConfig const&) /home/azureuser/torch-mlir/externals/llvm-project/mlir/lib/Tools/mlir-opt/MlirOptMain.cpp:535:10 #40 0x000055b8b1e7b9f5 mlir::MlirOptMain(int, char**, llvm::StringRef, llvm::StringRef, mlir::DialectRegistry&) /home/azureuser/torch-mlir/externals/llvm-project/mlir/lib/Tools/mlir-opt/MlirOptMain.cpp:570:14 #41 0x000055b8b1e7bbc8 mlir::MlirOptMain(int, char**, llvm::StringRef, mlir::DialectRegistry&) /home/azureuser/torch-mlir/externals/llvm-project/mlir/lib/Tools/mlir-opt/MlirOptMain.cpp:586:10 #42 0x000055b8b1e770e5 main /home/azureuser/torch-mlir/tools/torch-mlir-opt/torch-mlir-opt.cpp:43:33 #43 0x00007faabdc23a90 (/lib/x86_64-linux-gnu/libc.so.6+0x23a90) #44 0x00007faabdc23b49 __libc_start_main (/lib/x86_64-linux-gnu/libc.so.6+0x23b49) #45 0x000055b8b1e76f75 _start (/home/azureuser/torch-mlir/build/bin/torch-mlir-opt+0x234f75)

./lstm.torch.mlir:61:13: note: see current operation: %45 = "torch_c.to_builtin_tensor"(%arg6) : (tensor<2x3xf32>) -> tensor<2x3xf32> loc("./lstm.torch.mlir":61:13) // -----// IR Dump After ConvertTorchToSCF Failed (convert-torch-to-scf) ('func.func' operation: @lstm) //----- // #loc2 = loc("./lstm.torch.mlir":2:19) #loc3 = loc("./lstm.torch.mlir":2:56) #loc4 = loc("./lstm.torch.mlir":2:93) #loc5 = loc("./lstm.torch.mlir":2:130) #loc6 = loc(unknown) #loc7 = loc("./lstm.torch.mlir":38:13) #loc18 = loc("./lstm.torch.mlir":61:13) #loc24 = loc("./lstm.torch.mlir":49:13) #loc26 = loc("./lstm.torch.mlir":54:13) #loc27 = loc("./lstm.torch.mlir":55:13) #loc28 = loc("./lstm.torch.mlir":56:13) #loc29 = loc("./lstm.torch.mlir":57:13) #loc30 = loc("./lstm.torch.mlir":58:13) #loc31 = loc("./lstm.torch.mlir":59:13) #loc32 = loc("./lstm.torch.mlir":60:13) #loc33 = loc("./lstm.torch.mlir":62:13) #loc34 = loc("./lstm.torch.mlir":63:13) #loc35 = loc("./lstm.torch.mlir":64:13) #loc36 = loc("./lstm.torch.mlir":65:13) #map = affine_map<(d0, d1) -> (d0, d1)> "builtin.module"() ({ "func.func"() <{function_type = (!torch.vtensor<[15,2,4],f32>, !torch.vtensor<[1,12,4],f32>, !torch.vtensor<[1,12,3],f32>, !torch.vtensor<[1,24],f32>) -> (!torch.vtensor<[15,1,2,3],f32>, !torch.vtensor<[1,2,3],f32>, !torch.vtensor<[1,2,3],f32>), sym_name = "lstm"}> ({ ^bb0(%arg0: !torch.vtensor<[15,2,4],f32> loc("./lstm.torch.mlir":2:19), %arg1: !torch.vtensor<[1,12,4],f32> loc("./lstm.torch.mlir":2:56), %arg2: !torch.vtensor<[1,12,3],f32> loc("./lstm.torch.mlir":2:93), %arg3: !torch.vtensor<[1,24],f32> loc("./lstm.torch.mlir":2:130)): %0 = "arith.constant"() <{value = 1.000000e+00 : f32}> : () -> f32 loc(#loc6) %1 = "arith.constant"() <{value = 0.000000e+00 : f32}> : () -> f32 loc(#loc6) %2 = "torch.constant.int"() <{value = 0 : i64}> : () -> !torch.int loc(#loc6) %3 = "torch.constant.int"() <{value = 1 : i64}> : () -> !torch.int loc(#loc6) %4 = "torch.constant.int"() <{value = 15 : i64}> : () -> !torch.int 
loc(#loc6) %5 = "torch_c.to_i64"(%4) : (!torch.int) -> i64 loc(#loc7) %6 = "torch.constant.bool"() <{value = true}> : () -> !torch.bool loc(#loc6) %7 = "torch.constant.int"() <{value = 4 : i64}> : () -> !torch.int loc(#loc6) %8 = "torch.aten.select.int"(%arg1, %2, %2) : (!torch.vtensor<[1,12,4],f32>, !torch.int, !torch.int) -> !torch.vtensor<[12,4],f32> loc(#loc8) %9 = "torch.aten.select.int"(%arg2, %2, %2) : (!torch.vtensor<[1,12,3],f32>, !torch.int, !torch.int) -> !torch.vtensor<[12,3],f32> loc(#loc9) %10 = "torch.aten.select.int"(%arg3, %2, %2) : (!torch.vtensor<[1,24],f32>, !torch.int, !torch.int) -> !torch.vtensor<[24],f32> loc(#loc10) %11 = "torch_c.to_builtin_tensor"(%10) : (!torch.vtensor<[24],f32>) -> tensor<24xf32> loc(#loc11) %12 = "tensor.empty"() : () -> tensor<1x2x3xf32> loc(#loc12) %13 = "linalg.fill"(%1, %12) <{operandSegmentSizes = array<i32: 1, 1>}> ({ ^bb0(%arg4: f32 loc(unknown), %arg5: f32 loc(unknown)): "linalg.yield"(%arg4) : (f32) -> () loc(#loc6) }) : (f32, tensor<1x2x3xf32>) -> tensor<1x2x3xf32> loc(#loc12) %14 = "torch_c.from_builtin_tensor"(%13) : (tensor<1x2x3xf32>) -> !torch.vtensor<[1,2,3],f32> loc(#loc12) %15 = "tensor.empty"() : () -> tensor<1x2x3xf32> loc(#loc13) %16 = "linalg.fill"(%1, %15) <{operandSegmentSizes = array<i32: 1, 1>}> ({ ^bb0(%arg4: f32 loc(unknown), %arg5: f32 loc(unknown)): "linalg.yield"(%arg4) : (f32) -> () loc(#loc6) }) : (f32, tensor<1x2x3xf32>) -> tensor<1x2x3xf32> loc(#loc13) %17 = "torch_c.from_builtin_tensor"(%16) : (tensor<1x2x3xf32>) -> !torch.vtensor<[1,2,3],f32> loc(#loc13) %18 = "torch.aten.select.int"(%14, %2, %2) : (!torch.vtensor<[1,2,3],f32>, !torch.int, !torch.int) -> !torch.vtensor<[2,3],f32> loc(#loc14) %19 = "torch_c.to_builtin_tensor"(%18) : (!torch.vtensor<[2,3],f32>) -> tensor<2x3xf32> loc(#loc7) %20 = "torch.aten.select.int"(%17, %2, %2) : (!torch.vtensor<[1,2,3],f32>, !torch.int, !torch.int) -> !torch.vtensor<[2,3],f32> loc(#loc15) %21 = "torch_c.to_builtin_tensor"(%20) : 
(!torch.vtensor<[2,3],f32>) -> tensor<2x3xf32> loc(#loc7) %22 = "tensor.extract_slice"(%11) <{operandSegmentSizes = array<i32: 1, 0, 0, 0>, static_offsets = array<i64: 0>, static_sizes = array<i64: 12>, static_strides = array<i64: 1>}> : (tensor<24xf32>) -> tensor<12xf32> loc(#loc11) %23 = "torch_c.from_builtin_tensor"(%22) : (tensor<12xf32>) -> !torch.vtensor<[12],f32> loc(#loc11) %24 = "tensor.extract_slice"(%11) <{operandSegmentSizes = array<i32: 1, 0, 0, 0>, static_offsets = array<i64: 12>, static_sizes = array<i64: 12>, static_strides = array<i64: 1>}> : (tensor<24xf32>) -> tensor<12xf32> loc(#loc16) %25 = "torch_c.from_builtin_tensor"(%24) : (tensor<12xf32>) -> !torch.vtensor<[12],f32> loc(#loc16) %26 = "torch.prim.ListConstruct"() : () -> !torch.list<vtensor<[2,3],f32>> loc(#loc17) %27 = "arith.constant"() <{value = 0 : index}> : () -> index loc(#loc7) %28 = "arith.constant"() <{value = 1 : index}> : () -> index loc(#loc7) %29 = "arith.index_cast"(%5) : (i64) -> index loc(#loc7) %30:2 = "scf.for"(%27, %29, %28, %19, %21) ({ ^bb0(%arg4: index loc("./lstm.torch.mlir":38:13), %arg5: tensor<2x3xf32> loc("./lstm.torch.mlir":38:13), %arg6: tensor<2x3xf32> loc("./lstm.torch.mlir":38:13)): %43 = "arith.index_cast"(%arg4) : (index) -> i64 loc(#loc7) %44 = "torch_c.from_i64"(%43) : (i64) -> !torch.int loc(#loc7) %45 = "torch_c.to_builtin_tensor"(%arg6) : (tensor<2x3xf32>) -> tensor<2x3xf32> loc(#loc18) %46 = "torch.aten.select.int"(%arg0, %2, %44) : (!torch.vtensor<[15,2,4],f32>, !torch.int, !torch.int) -> !torch.vtensor<[2,4],f32> loc(#loc19) %47 = "torch.prim.ListConstruct"(%3, %7) : (!torch.int, !torch.int) -> !torch.list loc(#loc20) %48 = "torch.aten.tile"(%46, %47) : (!torch.vtensor<[2,4],f32>, !torch.list) -> !torch.vtensor<[2,16],f32> loc(#loc21) %49 = "torch.aten.tile"(%arg5, %47) : (tensor<2x3xf32>, !torch.list) -> !torch.vtensor<[2,12],f32> loc(#loc22) %50 = "torch.aten.linear"(%48, %8, %23) : (!torch.vtensor<[2,16],f32>, !torch.vtensor<[12,4],f32>, 
!torch.vtensor<[12],f32>) -> !torch.vtensor<[2,12],f32> loc(#loc23) %51 = "torch_c.to_builtin_tensor"(%50) : (!torch.vtensor<[2,12],f32>) -> tensor<2x12xf32> loc(#loc24) %52 = "torch.aten.linear"(%49, %9, %25) : (!torch.vtensor<[2,12],f32>, !torch.vtensor<[12,3],f32>, !torch.vtensor<[12],f32>) -> !torch.vtensor<[2,12],f32> loc(#loc25) %53 = "torch_c.to_builtin_tensor"(%52) : (!torch.vtensor<[2,12],f32>) -> tensor<2x12xf32> loc(#loc24) %54 = "tensor.empty"() : () -> tensor<2x12xf32> loc(#loc24) %55 = "linalg.generic"(%51, %53, %54) <{indexing_maps = [#map, #map, #map], iterator_types = [#linalg.iterator_type, #linalg.iterator_type], operandSegmentSizes = array<i32: 2, 1>}> ({ ^bb0(%arg7: f32 loc("./lstm.torch.mlir":49:13), %arg8: f32 loc("./lstm.torch.mlir":49:13), %arg9: f32 loc("./lstm.torch.mlir":49:13)): %79 = "arith.addf"(%arg7, %arg8) <{fastmath = #arith.fastmath}> : (f32, f32) -> f32 loc(#loc24) "linalg.yield"(%79) : (f32) -> () loc(#loc24) }) : (tensor<2x12xf32>, tensor<2x12xf32>, tensor<2x12xf32>) -> tensor<2x12xf32> loc(#loc24) %56 = "tensor.extract_slice"(%55) <{operandSegmentSizes = array<i32: 1, 0, 0, 0>, static_offsets = array<i64: 0, 0>, static_sizes = array<i64: 2, 9>, static_strides = array<i64: 1, 1>}> : (tensor<2x12xf32>) -> tensor<2x9xf32> loc(#loc26) %57 = "tensor.extract_slice"(%55) <{operandSegmentSizes = array<i32: 1, 0, 0, 0>, static_offsets = array<i64: 0, 9>, static_sizes = array<i64: 2, 3>, static_strides = array<i64: 1, 1>}> : (tensor<2x12xf32>) -> tensor<2x3xf32> loc(#loc27) %58 = "tensor.empty"() : () -> tensor<2x9xf32> loc(#loc28) %59 = "linalg.generic"(%56, %58) <{indexing_maps = [#map, #map], iterator_types = [#linalg.iterator_type, #linalg.iterator_type], operandSegmentSizes = array<i32: 1, 1>}> ({ ^bb0(%arg7: f32 loc("./lstm.torch.mlir":54:13), %arg8: f32 loc("./lstm.torch.mlir":56:13)): %79 = "arith.negf"(%arg7) <{fastmath = #arith.fastmath}> : (f32) -> f32 loc(#loc28) %80 = "math.exp"(%79) <{fastmath = #arith.fastmath}> : (f32) 
-> f32 loc(#loc28) %81 = "arith.addf"(%80, %0) <{fastmath = #arith.fastmath}> : (f32, f32) -> f32 loc(#loc28) %82 = "arith.divf"(%0, %81) <{fastmath = #arith.fastmath}> : (f32, f32) -> f32 loc(#loc28) "linalg.yield"(%82) : (f32) -> () loc(#loc28) }) : (tensor<2x9xf32>, tensor<2x9xf32>) -> tensor<2x9xf32> loc(#loc28) %60 = "tensor.empty"() : () -> tensor<2x3xf32> loc(#loc29) %61 = "linalg.generic"(%57, %60) <{indexing_maps = [#map, #map], iterator_types = [#linalg.iterator_type, #linalg.iterator_type], operandSegmentSizes = array<i32: 1, 1>}> ({ ^bb0(%arg7: f32 loc("./lstm.torch.mlir":55:13), %arg8: f32 loc("./lstm.torch.mlir":57:13)): %79 = "math.tanh"(%arg7) <{fastmath = #arith.fastmath}> : (f32) -> f32 loc(#loc29) "linalg.yield"(%79) : (f32) -> () loc(#loc29) }) : (tensor<2x3xf32>, tensor<2x3xf32>) -> tensor<2x3xf32> loc(#loc29) %62 = "tensor.extract_slice"(%59) <{operandSegmentSizes = array<i32: 1, 0, 0, 0>, static_offsets = array<i64: 0, 0>, static_sizes = array<i64: 2, 3>, static_strides = array<i64: 1, 1>}> : (tensor<2x9xf32>) -> tensor<2x3xf32> loc(#loc30) %63 = "tensor.extract_slice"(%59) <{operandSegmentSizes = array<i32: 1, 0, 0, 0>, static_offsets = array<i64: 0, 3>, static_sizes = array<i64: 2, 3>, static_strides = array<i64: 1, 1>}> : (tensor<2x9xf32>) -> tensor<2x3xf32> loc(#loc31) %64 = "tensor.extract_slice"(%59) <{operandSegmentSizes = array<i32: 1, 0, 0, 0>, static_offsets = array<i64: 0, 6>, static_sizes = array<i64: 2, 3>, static_strides = array<i64: 1, 1>}> : (tensor<2x9xf32>) -> tensor<2x3xf32> loc(#loc32) %65 = "tensor.empty"() : () -> tensor<2x3xf32> loc(#loc18) %66 = "linalg.generic"(%64, %45, %65) <{indexing_maps = [#map, #map, #map], iterator_types = [#linalg.iterator_type, #linalg.iterator_type], operandSegmentSizes = array<i32: 2, 1>}> ({ ^bb0(%arg7: f32 loc("./lstm.torch.mlir":60:13), %arg8: f32 loc("./lstm.torch.mlir":61:13), %arg9: f32 loc("./lstm.torch.mlir":61:13)): %79 = "arith.mulf"(%arg7, %arg8) <{fastmath = #arith.fastmath}> : 
(f32, f32) -> f32 loc(#loc18) "linalg.yield"(%79) : (f32) -> () loc(#loc18) }) : (tensor<2x3xf32>, tensor<2x3xf32>, tensor<2x3xf32>) -> tensor<2x3xf32> loc(#loc18) %67 = "tensor.empty"() : () -> tensor<2x3xf32> loc(#loc33) %68 = "linalg.generic"(%62, %61, %67) <{indexing_maps = [#map, #map, #map], iterator_types = [#linalg.iterator_type, #linalg.iterator_type], operandSegmentSizes = array<i32: 2, 1>}> ({ ^bb0(%arg7: f32 loc("./lstm.torch.mlir":58:13), %arg8: f32 loc("./lstm.torch.mlir":57:13), %arg9: f32 loc("./lstm.torch.mlir":62:13)): %79 = "arith.mulf"(%arg7, %arg8) <{fastmath = #arith.fastmath}> : (f32, f32) -> f32 loc(#loc33) "linalg.yield"(%79) : (f32) -> () loc(#loc33) }) : (tensor<2x3xf32>, tensor<2x3xf32>, tensor<2x3xf32>) -> tensor<2x3xf32> loc(#loc33) %69 = "tensor.empty"() : () -> tensor<2x3xf32> loc(#loc34) %70 = "linalg.generic"(%66, %68, %69) <{indexing_maps = [#map, #map, #map], iterator_types = [#linalg.iterator_type, #linalg.iterator_type], operandSegmentSizes = array<i32: 2, 1>}> ({ ^bb0(%arg7: f32 loc("./lstm.torch.mlir":61:13), %arg8: f32 loc("./lstm.torch.mlir":62:13), %arg9: f32 loc("./lstm.torch.mlir":63:13)): %79 = "arith.addf"(%arg7, %arg8) <{fastmath = #arith.fastmath}> : (f32, f32) -> f32 loc(#loc34) "linalg.yield"(%79) : (f32) -> () loc(#loc34) }) : (tensor<2x3xf32>, tensor<2x3xf32>, tensor<2x3xf32>) -> tensor<2x3xf32> loc(#loc34) %71 = "torch_c.from_builtin_tensor"(%70) : (tensor<2x3xf32>) -> !torch.vtensor<[2,3],f32> loc(#loc34) %72 = "tensor.empty"() : () -> tensor<2x3xf32> loc(#loc35) %73 = "linalg.generic"(%70, %72) <{indexing_maps = [#map, #map], iterator_types = [#linalg.iterator_type, #linalg.iterator_type], operandSegmentSizes = array<i32: 1, 1>}> ({ ^bb0(%arg7: f32 loc("./lstm.torch.mlir":63:13), %arg8: f32 loc("./lstm.torch.mlir":64:13)): %79 = "math.tanh"(%arg7) <{fastmath = #arith.fastmath}> : (f32) -> f32 loc(#loc35) "linalg.yield"(%79) : (f32) -> () loc(#loc35) }) : (tensor<2x3xf32>, tensor<2x3xf32>) -> tensor<2x3xf32> 
loc(#loc35) %74 = "tensor.empty"() : () -> tensor<2x3xf32> loc(#loc36) %75 = "linalg.generic"(%63, %73, %74) <{indexing_maps = [#map, #map, #map], iterator_types = [#linalg.iterator_type, #linalg.iterator_type], operandSegmentSizes = array<i32: 2, 1>}> ({ ^bb0(%arg7: f32 loc("./lstm.torch.mlir":59:13), %arg8: f32 loc("./lstm.torch.mlir":64:13), %arg9: f32 loc("./lstm.torch.mlir":65:13)): %79 = "arith.mulf"(%arg7, %arg8) <{fastmath = #arith.fastmath}> : (f32, f32) -> f32 loc(#loc36) "linalg.yield"(%79) : (f32) -> () loc(#loc36) }) : (tensor<2x3xf32>, tensor<2x3xf32>, tensor<2x3xf32>) -> tensor<2x3xf32> loc(#loc36) %76 = "torch_c.from_builtin_tensor"(%75) : (tensor<2x3xf32>) -> !torch.vtensor<[2,3],f32> loc(#loc36) %77 = "torch.aten.append.t"(%26, %76) : (!torch.list<vtensor<[2,3],f32>>, !torch.vtensor<[2,3],f32>) -> !torch.list<vtensor<[2,3],f32>> loc(#loc37) %78 = "torch.aten.add.int"(%44, %3) : (!torch.int, !torch.int) -> !torch.int loc(#loc38) "scf.yield"(%76, %71) : (!torch.vtensor<[2,3],f32>, !torch.vtensor<[2,3],f32>) -> () loc(#loc7) }) : (index, index, index, tensor<2x3xf32>, tensor<2x3xf32>) -> (tensor<2x3xf32>, tensor<2x3xf32>) loc(#loc7) %31 = "torch_c.from_builtin_tensor"(%30#0) : (tensor<2x3xf32>) -> !torch.vtensor<[2,3],f32> loc(#loc7) %32 = "torch_c.from_builtin_tensor"(%30#1) : (tensor<2x3xf32>) -> !torch.vtensor<[2,3],f32> loc(#loc7) %33 = "torch_c.to_builtin_tensor"(%31) : (!torch.vtensor<[2,3],f32>) -> tensor<2x3xf32> loc(#loc39) %34 = "torch_c.to_builtin_tensor"(%32) : (!torch.vtensor<[2,3],f32>) -> tensor<2x3xf32> loc(#loc40) %35 = "tensor.expand_shape"(%33) <{reassociation = [[0, 1], [2]]}> : (tensor<2x3xf32>) -> tensor<1x2x3xf32> loc(#loc39) %36 = "torch_c.from_builtin_tensor"(%35) : (tensor<1x2x3xf32>) -> !torch.vtensor<[1,2,3],f32> loc(#loc39) %37 = "tensor.expand_shape"(%34) <{reassociation = [[0, 1], [2]]}> : (tensor<2x3xf32>) -> tensor<1x2x3xf32> loc(#loc40) %38 = "torch_c.from_builtin_tensor"(%37) : (tensor<1x2x3xf32>) -> 
!torch.vtensor<[1,2,3],f32> loc(#loc40) %39 = "torch.aten.stack"(%26, %2) : (!torch.list<vtensor<[2,3],f32>>, !torch.int) -> !torch.vtensor<[15,2,3],f32> loc(#loc41) %40 = "torch_c.to_builtin_tensor"(%39) : (!torch.vtensor<[15,2,3],f32>) -> tensor<15x2x3xf32> loc(#loc42) %41 = "tensor.expand_shape"(%40) <{reassociation = [[0], [1, 2], [3]]}> : (tensor<15x2x3xf32>) -> tensor<15x1x2x3xf32> loc(#loc42) %42 = "torch_c.from_builtin_tensor"(%41) : (tensor<15x1x2x3xf32>) -> !torch.vtensor<[15,1,2,3],f32> loc(#loc42) "func.return"(%42, %36, %38) : (!torch.vtensor<[15,1,2,3],f32>, !torch.vtensor<[1,2,3],f32>, !torch.vtensor<[1,2,3],f32>) -> () loc(#loc43) }) {torch.onnx_meta.ir_version = 9 : si64, torch.onnx_meta.opset_version = 20 : si64, torch.onnx_meta.producer_name = "", torch.onnx_meta.producer_version = ""} : () -> () loc(#loc1) }) : () -> () loc(#loc) #loc = loc("./lstm.torch.mlir":1:1) #loc1 = loc("./lstm.torch.mlir":2:3) #loc8 = loc("./lstm.torch.mlir":6:10) #loc9 = loc("./lstm.torch.mlir":9:10) #loc10 = loc("./lstm.torch.mlir":12:10) #loc11 = loc("./lstm.torch.mlir":31:10) #loc12 = loc("./lstm.torch.mlir":21:10) #loc13 = loc("./lstm.torch.mlir":22:10) #loc14 = loc("./lstm.torch.mlir":25:10) #loc15 = loc("./lstm.torch.mlir":28:10) #loc16 = loc("./lstm.torch.mlir":32:10) #loc17 = loc("./lstm.torch.mlir":33:11) #loc19 = loc("./lstm.torch.mlir":40:13) #loc20 = loc("./lstm.torch.mlir":44:13) #loc21 = loc("./lstm.torch.mlir":45:13) #loc22 = loc("./lstm.torch.mlir":46:13) #loc23 = loc("./lstm.torch.mlir":47:13) #loc25 = loc("./lstm.torch.mlir":48:13) #loc37 = loc("./lstm.torch.mlir":66:13) #loc38 = loc("./lstm.torch.mlir":67:13) #loc39 = loc("./lstm.torch.mlir":70:11) #loc40 = loc("./lstm.torch.mlir":71:11) #loc41 = loc("./lstm.torch.mlir":72:11) #loc42 = loc("./lstm.torch.mlir":73:11) #loc43 = loc("./lstm.torch.mlir":74:5)

renxida avatar Mar 01 '24 03:03 renxida

Hi @renxida, is there anything remaining for this PR to be merged?

vivekkhandelwal1 avatar Mar 21 '24 05:03 vivekkhandelwal1

@vivekkhandelwal1 Yup — it still needs a numeric test, but it's hitting some issues lowering past Linalg. Specifically:

  1. scf for loop doesn't support tensor args (resolved by https://github.com/llvm/torch-mlir/pull/3040)
  2. concat doesn't support list of tensors created in a loop. Currently working on this.

renxida avatar Mar 21 '24 21:03 renxida

Currently stuck trying to lower this from aten / torch to linalg:

mlir file
module {
  func.func @lstm(%arg0: !torch.vtensor<[15,2,4],f32>, %arg1: !torch.vtensor<[1,12,4],f32>, %arg2: !torch.vtensor<[1,12,3],f32>, %arg3: !torch.vtensor<[1,24],f32>) -> (!torch.vtensor<[15,1,2,3],f32>, !torch.vtensor<[1,2,3],f32>, !torch.vtensor<[1,2,3],f32>) attributes {torch.onnx_meta.ir_version = 9 : si64, torch.onnx_meta.opset_version = 20 : si64, torch.onnx_meta.producer_name = "", torch.onnx_meta.producer_version = ""} {
    %none = torch.constant.none
    %int0 = torch.constant.int 0
    %int0_0 = torch.constant.int 0
    %0 = torch.aten.select.int %arg1, %int0, %int0_0 : !torch.vtensor<[1,12,4],f32>, !torch.int, !torch.int -> !torch.vtensor<[12,4],f32>
    %int0_1 = torch.constant.int 0
    %int0_2 = torch.constant.int 0
    %1 = torch.aten.select.int %arg2, %int0_1, %int0_2 : !torch.vtensor<[1,12,3],f32>, !torch.int, !torch.int -> !torch.vtensor<[12,3],f32>
    %int0_3 = torch.constant.int 0
    %int0_4 = torch.constant.int 0
    %2 = torch.aten.select.int %arg3, %int0_3, %int0_4 : !torch.vtensor<[1,24],f32>, !torch.int, !torch.int -> !torch.vtensor<[24],f32>
    %int1 = torch.constant.int 1
    %int2 = torch.constant.int 2
    %int3 = torch.constant.int 3
    %none_5 = torch.constant.none
    %int0_6 = torch.constant.int 0
    %int1_7 = torch.constant.int 1
    %3 = torch.prim.ListConstruct %int1, %int2, %int3 : (!torch.int, !torch.int, !torch.int) -> !torch.list<int>
    %int6 = torch.constant.int 6
    %4 = torch.aten.zeros %3, %int6, %none_5, %none_5, %none_5 : !torch.list<int>, !torch.int, !torch.none, !torch.none, !torch.none -> !torch.vtensor<[1,2,3],f32>
    %5 = torch.aten.zeros %3, %int6, %none_5, %none_5, %none_5 : !torch.list<int>, !torch.int, !torch.none, !torch.none, !torch.none -> !torch.vtensor<[1,2,3],f32>
    %int0_8 = torch.constant.int 0
    %int0_9 = torch.constant.int 0
    %6 = torch.aten.select.int %4, %int0_8, %int0_9 : !torch.vtensor<[1,2,3],f32>, !torch.int, !torch.int -> !torch.vtensor<[2,3],f32>
    %int0_10 = torch.constant.int 0
    %int0_11 = torch.constant.int 0
    %7 = torch.aten.select.int %5, %int0_10, %int0_11 : !torch.vtensor<[1,2,3],f32>, !torch.int, !torch.int -> !torch.vtensor<[2,3],f32>
    %int12 = torch.constant.int 12
    %int24 = torch.constant.int 24
    %8 = torch.aten.slice.Tensor %2, %int0_6, %int0_6, %int12, %int1_7 : !torch.vtensor<[24],f32>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[12],f32>
    %9 = torch.aten.slice.Tensor %2, %int0_6, %int12, %int24, %int1_7 : !torch.vtensor<[24],f32>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[12],f32>
    %none_12 = torch.constant.none
    %true = torch.constant.bool true
    %int0_13 = torch.constant.int 0
    %int1_14 = torch.constant.int 1
    %int15 = torch.constant.int 15
    %int2_15 = torch.constant.int 2
    %int3_16 = torch.constant.int 3
    %int3_17 = torch.constant.int 3
    %int6_18 = torch.constant.int 6
    %int9 = torch.constant.int 9
    %int12_19 = torch.constant.int 12
    %10 = torch.aten.slice.Tensor %0, %int0_13, %int0_13, %int3_17, %int1_14 : !torch.vtensor<[12,4],f32>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[3,4],f32>
    %11 = torch.aten.slice.Tensor %0, %int0_13, %int3_17, %int6_18, %int1_14 : !torch.vtensor<[12,4],f32>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[3,4],f32>
    %12 = torch.aten.slice.Tensor %0, %int0_13, %int6_18, %int9, %int1_14 : !torch.vtensor<[12,4],f32>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[3,4],f32>
    %13 = torch.aten.slice.Tensor %0, %int0_13, %int9, %int12_19, %int1_14 : !torch.vtensor<[12,4],f32>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[3,4],f32>
    %14 = torch.aten.slice.Tensor %1, %int0_13, %int0_13, %int3_17, %int1_14 : !torch.vtensor<[12,3],f32>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[3,3],f32>
    %15 = torch.aten.slice.Tensor %1, %int0_13, %int3_17, %int6_18, %int1_14 : !torch.vtensor<[12,3],f32>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[3,3],f32>
    %16 = torch.aten.slice.Tensor %1, %int0_13, %int6_18, %int9, %int1_14 : !torch.vtensor<[12,3],f32>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[3,3],f32>
    %17 = torch.aten.slice.Tensor %1, %int0_13, %int9, %int12_19, %int1_14 : !torch.vtensor<[12,3],f32>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[3,3],f32>
    %18 = torch.aten.slice.Tensor %8, %int0_13, %int0_13, %int3_17, %int1_14 : !torch.vtensor<[12],f32>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[3],f32>
    %19 = torch.aten.slice.Tensor %8, %int0_13, %int3_17, %int6_18, %int1_14 : !torch.vtensor<[12],f32>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[3],f32>
    %20 = torch.aten.slice.Tensor %8, %int0_13, %int6_18, %int9, %int1_14 : !torch.vtensor<[12],f32>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[3],f32>
    %21 = torch.aten.slice.Tensor %8, %int0_13, %int9, %int12_19, %int1_14 : !torch.vtensor<[12],f32>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[3],f32>
    %22 = torch.aten.slice.Tensor %9, %int0_13, %int0_13, %int3_17, %int1_14 : !torch.vtensor<[12],f32>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[3],f32>
    %23 = torch.aten.slice.Tensor %9, %int0_13, %int3_17, %int6_18, %int1_14 : !torch.vtensor<[12],f32>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[3],f32>
    %24 = torch.aten.slice.Tensor %9, %int0_13, %int6_18, %int9, %int1_14 : !torch.vtensor<[12],f32>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[3],f32>
    %25 = torch.aten.slice.Tensor %9, %int0_13, %int9, %int12_19, %int1_14 : !torch.vtensor<[12],f32>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[3],f32>
    %26 = torch.prim.ListConstruct %int15, %int2_15, %int3_16 : (!torch.int, !torch.int, !torch.int) -> !torch.list<int>
    %int6_20 = torch.constant.int 6
    %27 = torch.aten.zeros %26, %int6_20, %none_12, %none_12, %none_12 : !torch.list<int>, !torch.int, !torch.none, !torch.none, !torch.none -> !torch.vtensor<[15,2,3],f32>
    %28 = torch.copy.to_tensor %27 : !torch.tensor<[15,2,3],f32>
    %int15_21 = torch.constant.int 15
    %true_22 = torch.constant.bool true
    %int0_23 = torch.constant.int 0
    %29:2 = torch.prim.Loop %int15_21, %true_22, init(%6, %7) {
    ^bb0(%arg4: !torch.int, %arg5: !torch.vtensor<[2,3],f32>, %arg6: !torch.vtensor<[2,3],f32>):
      %33 = torch.aten.select.int %arg0, %int0_23, %arg4 : !torch.vtensor<[15,2,4],f32>, !torch.int, !torch.int -> !torch.vtensor<[2,4],f32>
      %int1_24 = torch.constant.int 1
      %34 = torch.aten.linear %33, %10, %18 : !torch.vtensor<[2,4],f32>, !torch.vtensor<[3,4],f32>, !torch.vtensor<[3],f32> -> !torch.vtensor<[2,3],f32>
      %35 = torch.aten.linear %arg5, %14, %22 : !torch.vtensor<[2,3],f32>, !torch.vtensor<[3,3],f32>, !torch.vtensor<[3],f32> -> !torch.vtensor<[2,3],f32>
      %36 = torch.aten.add.Tensor %34, %35, %int1_24 : !torch.vtensor<[2,3],f32>, !torch.vtensor<[2,3],f32>, !torch.int -> !torch.vtensor<[2,3],f32>
      %37 = torch.aten.sigmoid %36 : !torch.vtensor<[2,3],f32> -> !torch.vtensor<[2,3],f32>
      %38 = torch.aten.linear %33, %11, %19 : !torch.vtensor<[2,4],f32>, !torch.vtensor<[3,4],f32>, !torch.vtensor<[3],f32> -> !torch.vtensor<[2,3],f32>
      %39 = torch.aten.linear %arg5, %15, %23 : !torch.vtensor<[2,3],f32>, !torch.vtensor<[3,3],f32>, !torch.vtensor<[3],f32> -> !torch.vtensor<[2,3],f32>
      %40 = torch.aten.add.Tensor %38, %39, %int1_24 : !torch.vtensor<[2,3],f32>, !torch.vtensor<[2,3],f32>, !torch.int -> !torch.vtensor<[2,3],f32>
      %41 = torch.aten.sigmoid %40 : !torch.vtensor<[2,3],f32> -> !torch.vtensor<[2,3],f32>
      %42 = torch.aten.linear %33, %12, %20 : !torch.vtensor<[2,4],f32>, !torch.vtensor<[3,4],f32>, !torch.vtensor<[3],f32> -> !torch.vtensor<[2,3],f32>
      %43 = torch.aten.linear %arg5, %16, %24 : !torch.vtensor<[2,3],f32>, !torch.vtensor<[3,3],f32>, !torch.vtensor<[3],f32> -> !torch.vtensor<[2,3],f32>
      %44 = torch.aten.add.Tensor %42, %43, %int1_24 : !torch.vtensor<[2,3],f32>, !torch.vtensor<[2,3],f32>, !torch.int -> !torch.vtensor<[2,3],f32>
      %45 = torch.aten.sigmoid %44 : !torch.vtensor<[2,3],f32> -> !torch.vtensor<[2,3],f32>
      %46 = torch.aten.linear %33, %13, %21 : !torch.vtensor<[2,4],f32>, !torch.vtensor<[3,4],f32>, !torch.vtensor<[3],f32> -> !torch.vtensor<[2,3],f32>
      %47 = torch.aten.linear %arg5, %17, %25 : !torch.vtensor<[2,3],f32>, !torch.vtensor<[3,3],f32>, !torch.vtensor<[3],f32> -> !torch.vtensor<[2,3],f32>
      %48 = torch.aten.add.Tensor %46, %47, %int1_24 : !torch.vtensor<[2,3],f32>, !torch.vtensor<[2,3],f32>, !torch.int -> !torch.vtensor<[2,3],f32>
      %49 = torch.aten.tanh %48 : !torch.vtensor<[2,3],f32> -> !torch.vtensor<[2,3],f32>
      %50 = torch.aten.mul.Tensor %45, %arg6 : !torch.vtensor<[2,3],f32>, !torch.vtensor<[2,3],f32> -> !torch.vtensor<[2,3],f32>
      %51 = torch.aten.mul.Tensor %37, %49 : !torch.vtensor<[2,3],f32>, !torch.vtensor<[2,3],f32> -> !torch.vtensor<[2,3],f32>
      %52 = torch.aten.add.Tensor %50, %51, %int1_24 : !torch.vtensor<[2,3],f32>, !torch.vtensor<[2,3],f32>, !torch.int -> !torch.vtensor<[2,3],f32>
      %53 = torch.aten.tanh %52 : !torch.vtensor<[2,3],f32> -> !torch.vtensor<[2,3],f32>
      %54 = torch.aten.mul.Tensor %41, %53 : !torch.vtensor<[2,3],f32>, !torch.vtensor<[2,3],f32> -> !torch.vtensor<[2,3],f32>
      %55 = torch.copy.to_tensor %54 : !torch.tensor<[2,3],f32>
      %56 = torch.aten.select.int %28, %int0_23, %arg4 : !torch.tensor<[15,2,3],f32>, !torch.int, !torch.int -> !torch.tensor<[2,3],f32>
      %57 = torch.aten.copy_ %56, %55, %true : !torch.tensor<[2,3],f32>, !torch.tensor<[2,3],f32>, !torch.bool -> !torch.tensor<[2,3],f32>
      torch.prim.Loop.condition %true_22, iter(%54, %52 : !torch.vtensor<[2,3],f32>, !torch.vtensor<[2,3],f32>)
    } : (!torch.int, !torch.bool, !torch.vtensor<[2,3],f32>, !torch.vtensor<[2,3],f32>) -> (!torch.vtensor<[2,3],f32>, !torch.vtensor<[2,3],f32>)
    %30 = torch.aten.unsqueeze %29#0, %int0_6 : !torch.vtensor<[2,3],f32>, !torch.int -> !torch.vtensor<[1,2,3],f32>
    %31 = torch.aten.unsqueeze %29#1, %int0_6 : !torch.vtensor<[2,3],f32>, !torch.int -> !torch.vtensor<[1,2,3],f32>
    %32 = torch.aten.unsqueeze %28, %int1_7 : !torch.tensor<[15,2,3],f32>, !torch.int -> !torch.vtensor<[15,1,2,3],f32>
    return %32, %30, %31 : !torch.vtensor<[15,1,2,3],f32>, !torch.vtensor<[1,2,3],f32>, !torch.vtensor<[1,2,3],f32>
  }
}

I'm getting the error message:

./lstm.torch.mlir:93:13: error: failed to legalize operation 'torch.aten.slice.Tensor' that was explicitly marked illegal
      %56 = torch.aten.select.int %28, %int0_23, %arg4 : !torch.tensor<[15,2,3],f32>, !torch.int, !torch.int -> !torch.tensor<[2,3],f32>
            ^

With details that indicate a failure of:

`*(click to see full error message)* Assertion `use_empty() && "Cannot destroy a value that still has uses!"' failed. `
./lstm.torch.mlir:93:13: note: see current operation: %1298 = "torch.aten.slice.Tensor"(%785, %22, %1296, %1297, %19) : (!torch.tensor<[15,2,3],f32>, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.tensor<[1,2,3],f32> loc("./lstm.torch.mlir":93:13)
torch-mlir-opt: /home/azureuser/torch-mlir/externals/llvm-project/mlir/include/mlir/IR/UseDefLists.h:198: mlir::IRObjectWithUseList<mlir::OpOperand>::~IRObjectWithUseList() [OperandType = mlir::OpOperand]: Assertion `use_empty() && "Cannot destroy a value that still has uses!"' failed.

renxida avatar Mar 25 '24 19:03 renxida

Quinn helped.

Will try to stick to using only VTensors

renxida avatar Mar 25 '24 20:03 renxida

New error:

/home/azureuser/torch-mlir/build/bin/torch-mlir-opt: /home/azureuser/miniconda/lib/libtinfo.so.6: no version information available (required by /home/azureuser/torch-mlir/build/bin/torch-mlir-opt)
./lstm.torch.mlir:93:13: error: 'tm_tensor.scatter' op mismatch in shape of indices and update value at dim#0
      %56 = torch.aten.index_put %arg5, %55, %53, %false : !torch.vtensor<[15,2,3],f32>, !torch.list<vtensor<[1],si64>>, !torch.vtensor<[2,3],f32>, !torch.bool -> !torch.vtensor<[15,2,3],f32>
            ^

Torch IR:

```mlir module { func.func @lstm(%arg0: !torch.vtensor, %arg1: !torch.vtensor, %arg2: !torch.vtensor, %arg3: !torch.vtensor) -> (!torch.vtensor, !torch.vtensor, !torch.vtensor) attributes {torch.onnx_meta.ir_version = 9 : si64, torch.onnx_meta.opset_version = 20 : si64, torch.onnx_meta.producer_name = "", torch.onnx_meta.producer_version = ""} { %none = torch.constant.none %int0 = torch.constant.int 0 %int0_0 = torch.constant.int 0 %0 = torch.aten.select.int %arg1, %int0, %int0_0 : !torch.vtensor, !torch.int, !torch.int -> !torch.vtensor %int0_1 = torch.constant.int 0 %int0_2 = torch.constant.int 0 %1 = torch.aten.select.int %arg2, %int0_1, %int0_2 : !torch.vtensor, !torch.int, !torch.int -> !torch.vtensor %int0_3 = torch.constant.int 0 %int0_4 = torch.constant.int 0 %2 = torch.aten.select.int %arg3, %int0_3, %int0_4 : !torch.vtensor, !torch.int, !torch.int -> !torch.vtensor %int1 = torch.constant.int 1 %int2 = torch.constant.int 2 %int3 = torch.constant.int 3 %none_5 = torch.constant.none %int0_6 = torch.constant.int 0 %int1_7 = torch.constant.int 1 %3 = torch.prim.ListConstruct %int1, %int2, %int3 : (!torch.int, !torch.int, !torch.int) -> !torch.list %int6 = torch.constant.int 6 %4 = torch.aten.zeros %3, %int6, %none_5, %none_5, %none_5 : !torch.list, !torch.int, !torch.none, !torch.none, !torch.none -> !torch.vtensor %5 = torch.aten.zeros %3, %int6, %none_5, %none_5, %none_5 : !torch.list, !torch.int, !torch.none, !torch.none, !torch.none -> !torch.vtensor %int0_8 = torch.constant.int 0 %int0_9 = torch.constant.int 0 %6 = torch.aten.select.int %4, %int0_8, %int0_9 : !torch.vtensor, !torch.int, !torch.int -> !torch.vtensor %int0_10 = torch.constant.int 0 %int0_11 = torch.constant.int 0 %7 = torch.aten.select.int %5, %int0_10, %int0_11 : !torch.vtensor, !torch.int, !torch.int -> !torch.vtensor %int12 = torch.constant.int 12 %int24 = torch.constant.int 24 %8 = torch.aten.slice.Tensor %2, %int0_6, %int0_6, %int12, %int1_7 : !torch.vtensor, !torch.int, !torch.int, 
!torch.int, !torch.int -> !torch.vtensor %9 = torch.aten.slice.Tensor %2, %int0_6, %int12, %int24, %int1_7 : !torch.vtensor, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor %none_12 = torch.constant.none %int0_13 = torch.constant.int 0 %int1_14 = torch.constant.int 1 %int15 = torch.constant.int 15 %int2_15 = torch.constant.int 2 %int3_16 = torch.constant.int 3 %int3_17 = torch.constant.int 3 %int6_18 = torch.constant.int 6 %int9 = torch.constant.int 9 %int12_19 = torch.constant.int 12 %10 = torch.aten.slice.Tensor %0, %int0_13, %int0_13, %int3_17, %int1_14 : !torch.vtensor, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor %11 = torch.aten.slice.Tensor %0, %int0_13, %int3_17, %int6_18, %int1_14 : !torch.vtensor, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor %12 = torch.aten.slice.Tensor %0, %int0_13, %int6_18, %int9, %int1_14 : !torch.vtensor, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor %13 = torch.aten.slice.Tensor %0, %int0_13, %int9, %int12_19, %int1_14 : !torch.vtensor, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor %14 = torch.aten.slice.Tensor %1, %int0_13, %int0_13, %int3_17, %int1_14 : !torch.vtensor, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor %15 = torch.aten.slice.Tensor %1, %int0_13, %int3_17, %int6_18, %int1_14 : !torch.vtensor, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor %16 = torch.aten.slice.Tensor %1, %int0_13, %int6_18, %int9, %int1_14 : !torch.vtensor, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor %17 = torch.aten.slice.Tensor %1, %int0_13, %int9, %int12_19, %int1_14 : !torch.vtensor, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor %18 = torch.aten.slice.Tensor %8, %int0_13, %int0_13, %int3_17, %int1_14 : !torch.vtensor, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor %19 = torch.aten.slice.Tensor %8, %int0_13, %int3_17, %int6_18, %int1_14 : !torch.vtensor, 
!torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor %20 = torch.aten.slice.Tensor %8, %int0_13, %int6_18, %int9, %int1_14 : !torch.vtensor, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor %21 = torch.aten.slice.Tensor %8, %int0_13, %int9, %int12_19, %int1_14 : !torch.vtensor, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor %22 = torch.aten.slice.Tensor %9, %int0_13, %int0_13, %int3_17, %int1_14 : !torch.vtensor, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor %23 = torch.aten.slice.Tensor %9, %int0_13, %int3_17, %int6_18, %int1_14 : !torch.vtensor, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor %24 = torch.aten.slice.Tensor %9, %int0_13, %int6_18, %int9, %int1_14 : !torch.vtensor, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor %25 = torch.aten.slice.Tensor %9, %int0_13, %int9, %int12_19, %int1_14 : !torch.vtensor, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor %26 = torch.prim.ListConstruct %int15, %int2_15, %int3_16 : (!torch.int, !torch.int, !torch.int) -> !torch.list %int6_20 = torch.constant.int 6 %27 = torch.aten.zeros %26, %int6_20, %none_12, %none_12, %none_12 : !torch.list, !torch.int, !torch.none, !torch.none, !torch.none -> !torch.vtensor %int15_21 = torch.constant.int 15 %true = torch.constant.bool true %false = torch.constant.bool false %int0_22 = torch.constant.int 0 %28:3 = torch.prim.Loop %int15_21, %true, init(%27, %6, %7) { ^bb0(%arg4: !torch.int, %arg5: !torch.vtensor, %arg6: !torch.vtensor, %arg7: !torch.vtensor): %32 = torch.aten.select.int %arg0, %int0_22, %arg4 : !torch.vtensor, !torch.int, !torch.int -> !torch.vtensor %int1_23 = torch.constant.int 1 %33 = torch.aten.linear %32, %10, %18 : !torch.vtensor, !torch.vtensor, !torch.vtensor -> !torch.vtensor %34 = torch.aten.linear %arg6, %14, %22 : !torch.vtensor, !torch.vtensor, !torch.vtensor -> !torch.vtensor %35 = torch.aten.add.Tensor %33, %34, %int1_23 : !torch.vtensor, 
!torch.vtensor, !torch.int -> !torch.vtensor %36 = torch.aten.sigmoid %35 : !torch.vtensor -> !torch.vtensor %37 = torch.aten.linear %32, %11, %19 : !torch.vtensor, !torch.vtensor, !torch.vtensor -> !torch.vtensor %38 = torch.aten.linear %arg6, %15, %23 : !torch.vtensor, !torch.vtensor, !torch.vtensor -> !torch.vtensor %39 = torch.aten.add.Tensor %37, %38, %int1_23 : !torch.vtensor, !torch.vtensor, !torch.int -> !torch.vtensor %40 = torch.aten.sigmoid %39 : !torch.vtensor -> !torch.vtensor %41 = torch.aten.linear %32, %12, %20 : !torch.vtensor, !torch.vtensor, !torch.vtensor -> !torch.vtensor %42 = torch.aten.linear %arg6, %16, %24 : !torch.vtensor, !torch.vtensor, !torch.vtensor -> !torch.vtensor %43 = torch.aten.add.Tensor %41, %42, %int1_23 : !torch.vtensor, !torch.vtensor, !torch.int -> !torch.vtensor %44 = torch.aten.sigmoid %43 : !torch.vtensor -> !torch.vtensor %45 = torch.aten.linear %32, %13, %21 : !torch.vtensor, !torch.vtensor, !torch.vtensor -> !torch.vtensor %46 = torch.aten.linear %arg6, %17, %25 : !torch.vtensor, !torch.vtensor, !torch.vtensor -> !torch.vtensor %47 = torch.aten.add.Tensor %45, %46, %int1_23 : !torch.vtensor, !torch.vtensor, !torch.int -> !torch.vtensor %48 = torch.aten.tanh %47 : !torch.vtensor -> !torch.vtensor %49 = torch.aten.mul.Tensor %44, %arg7 : !torch.vtensor, !torch.vtensor -> !torch.vtensor %50 = torch.aten.mul.Tensor %36, %48 : !torch.vtensor, !torch.vtensor -> !torch.vtensor %51 = torch.aten.add.Tensor %49, %50, %int1_23 : !torch.vtensor, !torch.vtensor, !torch.int -> !torch.vtensor %52 = torch.aten.tanh %51 : !torch.vtensor -> !torch.vtensor %53 = torch.aten.mul.Tensor %40, %52 : !torch.vtensor, !torch.vtensor -> !torch.vtensor %54 = torch.aten.tensor.int %arg4, %none_12, %none_12, %false : !torch.int, !torch.none, !torch.none, !torch.bool -> !torch.vtensor %55 = torch.prim.ListConstruct %54 : (!torch.vtensor) -> !torch.list> %56 = torch.aten.index_put %arg5, %55, %53, %false : !torch.vtensor, !torch.list>, 
!torch.vtensor, !torch.bool -> !torch.vtensor torch.prim.Loop.condition %true, iter(%56, %53, %51 : !torch.vtensor, !torch.vtensor, !torch.vtensor) } : (!torch.int, !torch.bool, !torch.vtensor, !torch.vtensor, !torch.vtensor) -> (!torch.vtensor, !torch.vtensor, !torch.vtensor) %29 = torch.aten.unsqueeze %28#1, %int0_6 : !torch.vtensor, !torch.int -> !torch.vtensor %30 = torch.aten.unsqueeze %28#2, %int0_6 : !torch.vtensor, !torch.int -> !torch.vtensor %31 = torch.aten.unsqueeze %28#0, %int1_7 : !torch.vtensor, !torch.int -> !torch.vtensor return %31, %29, %30 : !torch.vtensor, !torch.vtensor, !torch.vtensor } } ```

renxida avatar Mar 26 '24 00:03 renxida

https://github.com/pytorch/pytorch/issues/91439

The LSTM e2e test case fails, but only on torch-nightly.

Should I remove the test case, or somehow xfail it only for nightly?

renxida avatar Mar 29 '24 20:03 renxida

A lot of this code does not use camelCase because the onnx.LSTM documentation uses snake_case and I'm trying to stay consistent with that.

renxida avatar Mar 29 '24 20:03 renxida

Note: an e2e test has been added but xfailed. The actual numerical consistency test is done here: https://github.com/nod-ai/SHARK-TestSuite/pull/142

renxida avatar Apr 01 '24 16:04 renxida

@renxida, you have done the LLVM bump in this PR. Do we need that for the changes in this PR? If not, can we do it in a separate PR?

Also, the CI seems to be broken might be because of the LLVM bump.

vivekkhandelwal1 avatar Apr 08 '24 05:04 vivekkhandelwal1

@renxida, you have done the LLVM bump in this PR. Do we need that for the changes in this PR? If not, can we do it in a separate PR?

Also, the CI seems to be broken might be because of the LLVM bump.

No, but I was keeping this up to date with main because GitHub was telling me I have merge conflicts.

Is there a better way I should have done it?

renxida avatar Apr 08 '24 17:04 renxida