onnx-mlir
Gemm inside a loop-if
With the new Keras models translated to TF and then to ONNX, we have the following pattern:
onnx.loop (...) {
...
onnx.gemm(...)
}
Simplified to the extreme, I isolated this pattern that causes a problem once we are in the krnl dialect:
%1001 = krnl.define_loops 1
krnl.iterate(%1001) with (%1001 -> %1002 = 0 to 35) {
%1010 = krnl.load %1000[] : memref<i1>
scf.if %1010 { // comment out to make it work
// pattern generated when lowering Gemm
// end of pattern generated for Gemm
} // if
} // iter
A pattern that has only the "if" or only the "iterate" works just fine, but the one with both results in the following error when lowering further to affine.
(onnx-mlir) alexe@pancetta i117-gemm-in-loop % onnx-mlir-opt --convert-krnl-to-affine --canonicalize onnx-gemm-191-krnl-loopif.mlir
Assertion failed: (isPerfectlyNested(input) && "input not perfectly nested"), function permuteLoops, file /Users/alexe/Onnxcode/llvm-project/mlir/lib/Transforms/Utils/LoopUtils.cpp, line 1548.
PLEASE submit a bug report to https://bugs.llvm.org/ and include the crash backtrace.
Stack dump:
0. Program arguments: onnx-mlir-opt --convert-krnl-to-affine --canonicalize onnx-gemm-191-krnl-loopif.mlir
1. Stack dump without symbol names (ensure you have llvm-symbolizer in your PATH or set the environment var `LLVM_SYMBOLIZER_PATH` to point to it):
0 onnx-mlir-opt 0x00000001094b05bb llvm::sys::PrintStackTrace(llvm::raw_ostream&, int) + 43
1 onnx-mlir-opt 0x00000001094af238 llvm::sys::RunSignalHandlers() + 248
2 onnx-mlir-opt 0x00000001094b0fb7 SignalHandler(int) + 295
3 libsystem_platform.dylib 0x00007fff2049dd7d _sigtramp + 29
4 libsystem_platform.dylib 0x0000000000018b00 _sigtramp + 18446603339974552992
5 libsystem_c.dylib 0x00007fff203ad411 abort + 120
6 libsystem_c.dylib 0x00007fff203ac7e8 err + 0
7 onnx-mlir-opt 0x0000000109587143 mlir::permuteLoops(llvm::MutableArrayRef<mlir::AffineForOp>, llvm::ArrayRef<unsigned int>) (.cold.19) + 35
8 onnx-mlir-opt 0x0000000108c6dcbf mlir::permuteLoops(llvm::MutableArrayRef<mlir::AffineForOp>, llvm::ArrayRef<unsigned int>) + 1535
9 onnx-mlir-opt 0x00000001087222bb (anonymous namespace)::interpretOperation(mlir::Operation*, mlir::OpBuilder&, llvm::SmallDenseMap<mlir::Value, mlir::AffineForOp, 4u, llvm::DenseMapInfo<mlir::Value>, llvm::detail::DenseMapPair<mlir::Value, mlir::AffineForOp> >&, llvm::SmallPtrSetImpl<mlir::Operation*>&, (anonymous namespace)::LoopBodyMover&) + 4283
10 onnx-mlir-opt 0x00000001087213ad (anonymous namespace)::interpretOperation(mlir::Operation*, mlir::OpBuilder&, llvm::SmallDenseMap<mlir::Value, mlir::AffineForOp, 4u, llvm::DenseMapInfo<mlir::Value>, llvm::detail::DenseMapPair<mlir::Value, mlir::AffineForOp> >&, llvm::SmallPtrSetImpl<mlir::Operation*>&, (anonymous namespace)::LoopBodyMover&) + 429
11 onnx-mlir-opt 0x00000001087213ad (anonymous namespace)::interpretOperation(mlir::Operation*, mlir::OpBuilder&, llvm::SmallDenseMap<mlir::Value, mlir::AffineForOp, 4u, llvm::DenseMapInfo<mlir::Value>, llvm::detail::DenseMapPair<mlir::Value, mlir::AffineForOp> >&, llvm::SmallPtrSetImpl<mlir::Operation*>&, (anonymous namespace)::LoopBodyMover&) + 429
12 onnx-mlir-opt 0x00000001087213ad (anonymous namespace)::interpretOperation(mlir::Operation*, mlir::OpBuilder&, llvm::SmallDenseMap<mlir::Value, mlir::AffineForOp, 4u, llvm::DenseMapInfo<mlir::Value>, llvm::detail::DenseMapPair<mlir::Value, mlir::AffineForOp> >&, llvm::SmallPtrSetImpl<mlir::Operation*>&, (anonymous namespace)::LoopBodyMover&) + 429
13 onnx-mlir-opt 0x000000010871f93a (anonymous namespace)::ConvertKrnlToAffinePass::runOnFunction() + 394
14 onnx-mlir-opt 0x0000000108e0dbb7 mlir::detail::OpToOpPassAdaptor::run(mlir::Pass*, mlir::Operation*, mlir::AnalysisManager, bool, unsigned int) + 487
15 onnx-mlir-opt 0x0000000108e0e096 mlir::detail::OpToOpPassAdaptor::runPipeline(llvm::iterator_range<llvm::pointee_iterator<std::__1::unique_ptr<mlir::Pass, std::__1::default_delete<mlir::Pass> >*, mlir::Pass> >, mlir::Operation*, mlir::AnalysisManager, bool, unsigned int, mlir::PassInstrumentor*, mlir::PassInstrumentation::PipelineParentInfo const*) + 134
16 onnx-mlir-opt 0x0000000108e13ee4 mlir::detail::OpToOpPassAdaptor::runOnOperationAsyncImpl(bool)::$_8::operator()(llvm::MutableArrayRef<mlir::OpPassManager>) const + 452
17 onnx-mlir-opt 0x0000000108e0efc1 mlir::detail::OpToOpPassAdaptor::runOnOperationAsyncImpl(bool) + 1761
18 onnx-mlir-opt 0x0000000108e0dd4c mlir::detail::OpToOpPassAdaptor::run(mlir::Pass*, mlir::Operation*, mlir::AnalysisManager, bool, unsigned int) + 892
19 onnx-mlir-opt 0x0000000108e0e096 mlir::detail::OpToOpPassAdaptor::runPipeline(llvm::iterator_range<llvm::pointee_iterator<std::__1::unique_ptr<mlir::Pass, std::__1::default_delete<mlir::Pass> >*, mlir::Pass> >, mlir::Operation*, mlir::AnalysisManager, bool, unsigned int, mlir::PassInstrumentor*, mlir::PassInstrumentation::PipelineParentInfo const*) + 134
20 onnx-mlir-opt 0x0000000108e10ba3 mlir::PassManager::run(mlir::Operation*) + 739
21 onnx-mlir-opt 0x0000000108aa4b90 performActions(llvm::raw_ostream&, bool, bool, llvm::SourceMgr&, mlir::MLIRContext*, mlir::PassPipelineCLParser const&) + 400
22 onnx-mlir-opt 0x0000000108aa2dea processBuffer(llvm::raw_ostream&, std::__1::unique_ptr<llvm::MemoryBuffer, std::__1::default_delete<llvm::MemoryBuffer> >, bool, bool, bool, bool, mlir::PassPipelineCLParser const&, mlir::DialectRegistry&) + 410
23 onnx-mlir-opt 0x0000000108aa2c24 mlir::MlirOptMain(llvm::raw_ostream&, std::__1::unique_ptr<llvm::MemoryBuffer, std::__1::default_delete<llvm::MemoryBuffer> >, mlir::PassPipelineCLParser const&, mlir::DialectRegistry&, bool, bool, bool, bool, bool) + 180
24 onnx-mlir-opt 0x000000010834a05c main + 1468
25 libdyld.dylib 0x00007fff20473f3d start + 1
zsh: abort onnx-mlir-opt --convert-krnl-to-affine --canonicalize
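For context on the assertion: per the stack trace, the krnl-to-affine pass ends up in mlir::permuteLoops, which requires the band of affine.for loops it reorders to be perfectly nested, i.e. each loop body contains nothing but the next loop in the band. Below is a minimal sketch of that distinction in hypothetical IR (function names, memref shapes, and bounds are made up for illustration, not taken from the reproducer):
// Perfectly nested band: each loop body holds only the next loop, so
// mlir::permuteLoops can legally reorder %i and %j.
func @perfect(%A: memref<64x128xf32>, %v: f32) {
  affine.for %i = 0 to 64 {
    affine.for %j = 0 to 128 {
      affine.store %v, %A[%i, %j] : memref<64x128xf32>
    }
  }
  return
}
// Not perfectly nested: the scf.if sitting between the two loops breaks
// the band, which is presumably the shape the loop-if + Gemm pattern
// above ends up in once the krnl loops are materialized as affine.for.
func @imperfect(%A: memref<64x128xf32>, %v: f32, %cond: i1) {
  affine.for %i = 0 to 64 {
    scf.if %cond {
      affine.for %j = 0 to 128 {
        affine.store %v, %A[%i, %j] : memref<64x128xf32>
      }
    }
  }
  return
}
This is only a sketch of the invariant the assertion enforces, not a claim about exactly which loops end up in the offending band.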
For reference, here is the full file that has both the iterate and the if.
module {
func @test_gemm_big(%arg0: memref<1x1xf32>, %arg1: memref<1x9xf32>, %arg2: memref<1x9xf32>) -> memref<1x9xf32> {
%c1 = constant 1 : index
%c9 = constant 9 : index
%cst = constant 0.000000e+00 : f32
%c0 = constant 0 : index
%0 = memref.alloc() {alignment = 128 : i64} : memref<64x512xf32>
%1 = memref.alloc() {alignment = 128 : i64} : memref<512x128xf32>
%2 = memref.alloc() {alignment = 128 : i64} : memref<64x512xf32>
%3 = memref.alloc() {alignment = 128 : i64} : memref<1x9xf32>
%1000 = memref.alloc() : memref<i1>
%1001 = krnl.define_loops 1
krnl.iterate(%1001) with (%1001 -> %1002 = 0 to 35) {
%1010 = krnl.load %1000[] : memref<i1>
scf.if %1010 { // comment out to make it work
// pattern generated when lowering Gemm
%4:2 = krnl.define_loops 2
krnl.iterate(%4#0, %4#1) with (%4#0 -> %arg3 = 0 to 1, %4#1 -> %arg4 = 0 to 9) {
%7:2 = krnl.get_induction_var_value(%4#0, %4#1) : (!krnl.loop, !krnl.loop) -> (index, index)
krnl.store %cst, %3[%7#0, %7#1] : memref<1x9xf32>
}
%5:3 = krnl.define_loops 3
%loop_block, %loop_local = krnl.block %5#0 64 : (!krnl.loop) -> (!krnl.loop, !krnl.loop)
%loop_block_0, %loop_local_1 = krnl.block %loop_local 4 : (!krnl.loop) -> (!krnl.loop, !krnl.loop)
%loop_block_2, %loop_local_3 = krnl.block %5#1 128 : (!krnl.loop) -> (!krnl.loop, !krnl.loop)
%loop_block_4, %loop_local_5 = krnl.block %loop_local_3 8 : (!krnl.loop) -> (!krnl.loop, !krnl.loop)
%loop_block_6, %loop_local_7 = krnl.block %5#2 512 : (!krnl.loop) -> (!krnl.loop, !krnl.loop)
krnl.permute(%loop_block, %loop_block_0, %loop_local_1, %loop_block_2, %loop_block_4, %loop_local_5, %loop_block_6, %loop_local_7) [0, 4, 5, 1, 3, 6, 2, 7] : !krnl.loop, !krnl.loop, !krnl.loop, !krnl.loop, !krnl.loop, !krnl.loop, !krnl.loop, !krnl.loop
krnl.iterate(%loop_block, %loop_block_2) with (%5#0 -> %arg3 = 0 to 1, %5#1 -> %arg4 = 0 to 9) {
%7:2 = krnl.get_induction_var_value(%loop_block, %loop_block_2) : (!krnl.loop, !krnl.loop) -> (index, index)
krnl.copy_to_tile_buffer %0, %3[%7#0, %7#1], %cst {padToNext = [], tileSize = [], transpose = false} : memref<64x512xf32>, memref<1x9xf32>
krnl.iterate(%loop_block_6) with (%5#2 -> %arg5 = 0 to 1) {
%8 = krnl.get_induction_var_value(%loop_block_6) : (!krnl.loop) -> index
krnl.copy_to_tile_buffer %2, %arg0[%7#0, %8], %cst {padToNext = [], tileSize = [], transpose = false} : memref<64x512xf32>, memref<1x1xf32>
krnl.copy_to_tile_buffer %1, %arg1[%8, %7#1], %cst {padToNext = [], tileSize = [], transpose = false} : memref<512x128xf32>, memref<1x9xf32>
krnl.iterate(%loop_block_4, %loop_block_0) with () {
%9:2 = krnl.get_induction_var_value(%loop_block_4, %loop_block_0) : (!krnl.loop, !krnl.loop) -> (index, index)
krnl.matmul %2[%7#0, %8], %1[%8, %7#1], %0[%7#0, %7#1], (%loop_local_1, %loop_local_5, %loop_local_7), (%9#1, %9#0, %8), (%c1, %c9, %c1) {aTileSize = [], bTileSize = [], cTileSize = [], computeTileSize = [4, 8, 512], overcompute = false, simdize = true, unroll = true} : memref<64x512xf32>, memref<512x128xf32>, memref<64x512xf32>, (!krnl.loop, !krnl.loop, !krnl.loop)
}
}
krnl.copy_from_tile_buffer %0, %3[%7#0, %7#1] {tileSize = []} : memref<64x512xf32>, memref<1x9xf32>
}
%6:2 = krnl.define_loops 2
krnl.iterate(%6#0, %6#1) with (%6#0 -> %arg3 = 0 to 1, %6#1 -> %arg4 = 0 to 9) {
%7:2 = krnl.get_induction_var_value(%6#0, %6#1) : (!krnl.loop, !krnl.loop) -> (index, index)
%8 = krnl.load %3[%7#0, %7#1] : memref<1x9xf32>
%9 = krnl.load %arg2[%c0, %7#1] : memref<1x9xf32>
%10 = addf %8, %9 : f32
krnl.store %10, %3[%7#0, %7#1] : memref<1x9xf32>
}
// end of pattern generated for Gemm
} // if
} // iter
memref.dealloc %2 : memref<64x512xf32>
memref.dealloc %1 : memref<512x128xf32>
memref.dealloc %0 : memref<64x512xf32>
return %3 : memref<1x9xf32>
}
}
A smaller version, with the krnl.copy ops removed as well as the init/add-C loops. Even the krnl.matmul can be removed (but I kept it, as it does add loops and conditionals to the innermost structure).
module {
func @test_gemm_big(%arg0: memref<1x1xf32>, %arg1: memref<1x9xf32>, %arg2: memref<1x9xf32>) -> memref<1x9xf32> {
%c1 = constant 1 : index
%c9 = constant 9 : index
%cst = constant 0.000000e+00 : f32
%c0 = constant 0 : index
%0 = memref.alloc() {alignment = 128 : i64} : memref<64x512xf32>
%1 = memref.alloc() {alignment = 128 : i64} : memref<512x128xf32>
%2 = memref.alloc() {alignment = 128 : i64} : memref<64x512xf32>
%3 = memref.alloc() {alignment = 128 : i64} : memref<1x9xf32>
%1000 = memref.alloc() : memref<i1>
%1001 = krnl.define_loops 1
krnl.iterate(%1001) with (%1001 -> %1002 = 0 to 35) {
%1010 = krnl.load %1000[] : memref<i1>
scf.if %1010 { // comment out to make it work
// pattern generated when lowering Gemm
%5:3 = krnl.define_loops 3
%loop_block, %loop_local = krnl.block %5#0 64 : (!krnl.loop) -> (!krnl.loop, !krnl.loop)
%loop_block_0, %loop_local_1 = krnl.block %loop_local 4 : (!krnl.loop) -> (!krnl.loop, !krnl.loop)
%loop_block_2, %loop_local_3 = krnl.block %5#1 128 : (!krnl.loop) -> (!krnl.loop, !krnl.loop)
%loop_block_4, %loop_local_5 = krnl.block %loop_local_3 8 : (!krnl.loop) -> (!krnl.loop, !krnl.loop)
%loop_block_6, %loop_local_7 = krnl.block %5#2 512 : (!krnl.loop) -> (!krnl.loop, !krnl.loop)
krnl.permute(%loop_block, %loop_block_0, %loop_local_1, %loop_block_2, %loop_block_4, %loop_local_5, %loop_block_6, %loop_local_7) [0, 4, 5, 1, 3, 6, 2, 7] : !krnl.loop, !krnl.loop, !krnl.loop, !krnl.loop, !krnl.loop, !krnl.loop, !krnl.loop, !krnl.loop
krnl.iterate(%loop_block, %loop_block_2) with (%5#0 -> %arg3 = 0 to 1, %5#1 -> %arg4 = 0 to 9) {
%7:2 = krnl.get_induction_var_value(%loop_block, %loop_block_2) : (!krnl.loop, !krnl.loop) -> (index, index)
krnl.iterate(%loop_block_6) with (%5#2 -> %arg5 = 0 to 1) {
%8 = krnl.get_induction_var_value(%loop_block_6) : (!krnl.loop) -> index
krnl.iterate(%loop_block_4, %loop_block_0) with () {
%9:2 = krnl.get_induction_var_value(%loop_block_4, %loop_block_0) : (!krnl.loop, !krnl.loop) -> (index, index)
// matmul can be commented out and the error still shows.
krnl.matmul %2[%7#0, %8], %1[%8, %7#1], %0[%7#0, %7#1], (%loop_local_1, %loop_local_5, %loop_local_7), (%9#1, %9#0, %8), (%c1, %c9, %c1) {aTileSize = [], bTileSize = [], cTileSize = [], computeTileSize = [4, 8, 512], overcompute = false, simdize = true, unroll = true} : memref<64x512xf32>, memref<512x128xf32>, memref<64x512xf32>, (!krnl.loop, !krnl.loop, !krnl.loop)
}
}
}
// end of pattern generated for Gemm
} // if
} // iter
memref.dealloc %2 : memref<64x512xf32>
memref.dealloc %1 : memref<512x128xf32>
memref.dealloc %0 : memref<64x512xf32>
return %3 : memref<1x9xf32>
}
}
@AlexandreEichenberger there seems to be a really easy fix: just put all the unoptimized loop bounds in a single parenthesis, i.e. declare them in one with clause on the outer krnl.iterate.
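Concretely, the only change relative to the reproducer above is in the bound declarations of the optimized iterates (excerpted here with bodies elided):
// Before: bounds split across two iterates (the shape that asserts).
krnl.iterate(%loop_block, %loop_block_2) with (%5#0 -> %arg3 = 0 to 1, %5#1 -> %arg4 = 0 to 9) {
  ...
  krnl.iterate(%loop_block_6) with (%5#2 -> %arg5 = 0 to 1) { ... }
}
// After: all three unoptimized loop bounds declared on the outer iterate.
krnl.iterate(%loop_block, %loop_block_2) with (%5#0 -> %arg3 = 0 to 1, %5#1 -> %arg4 = 0 to 9, %5#2 -> %arg5 = 0 to 1) {
  ...
  krnl.iterate(%loop_block_6) with () { ... }
}
The full file with this change: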
module {
func @test_gemm_big(%arg0: memref<1x1xf32>, %arg1: memref<1x9xf32>, %arg2: memref<1x9xf32>) -> memref<1x9xf32> {
%c1 = constant 1 : index
%c9 = constant 9 : index
%cst = constant 0.000000e+00 : f32
%c0 = constant 0 : index
%0 = memref.alloc() {alignment = 128 : i64} : memref<64x512xf32>
%1 = memref.alloc() {alignment = 128 : i64} : memref<512x128xf32>
%2 = memref.alloc() {alignment = 128 : i64} : memref<64x512xf32>
%3 = memref.alloc() {alignment = 128 : i64} : memref<1x9xf32>
%1000 = memref.alloc() : memref<i1>
%1001 = krnl.define_loops 1
krnl.iterate(%1001) with (%1001 -> %1002 = 0 to 35) {
%1010 = krnl.load %1000[] : memref<i1>
scf.if %1010 { // comment out to make it work
// pattern generated when lowering Gemm
%5:3 = krnl.define_loops 3
%loop_block, %loop_local = krnl.block %5#0 64 : (!krnl.loop) -> (!krnl.loop, !krnl.loop)
%loop_block_0, %loop_local_1 = krnl.block %loop_local 4 : (!krnl.loop) -> (!krnl.loop, !krnl.loop)
%loop_block_2, %loop_local_3 = krnl.block %5#1 128 : (!krnl.loop) -> (!krnl.loop, !krnl.loop)
%loop_block_4, %loop_local_5 = krnl.block %loop_local_3 8 : (!krnl.loop) -> (!krnl.loop, !krnl.loop)
%loop_block_6, %loop_local_7 = krnl.block %5#2 512 : (!krnl.loop) -> (!krnl.loop, !krnl.loop)
krnl.permute(%loop_block, %loop_block_0, %loop_local_1, %loop_block_2, %loop_block_4, %loop_local_5, %loop_block_6, %loop_local_7) [0, 4, 5, 1, 3, 6, 2, 7] : !krnl.loop, !krnl.loop, !krnl.loop, !krnl.loop, !krnl.loop, !krnl.loop, !krnl.loop, !krnl.loop
krnl.iterate(%loop_block, %loop_block_2) with (%5#0 -> %arg3 = 0 to 1, %5#1 -> %arg4 = 0 to 9, %5#2 -> %arg5 = 0 to 1) {
%7:2 = krnl.get_induction_var_value(%loop_block, %loop_block_2) : (!krnl.loop, !krnl.loop) -> (index, index)
krnl.iterate(%loop_block_6) with () {
%8 = krnl.get_induction_var_value(%loop_block_6) : (!krnl.loop) -> index
krnl.iterate(%loop_block_4, %loop_block_0) with () {
%9:2 = krnl.get_induction_var_value(%loop_block_4, %loop_block_0) : (!krnl.loop, !krnl.loop) -> (index, index)
// matmul can be commented out and the error still shows.
krnl.matmul %2[%7#0, %8], %1[%8, %7#1], %0[%7#0, %7#1], (%loop_local_1, %loop_local_5, %loop_local_7), (%9#1, %9#0, %8), (%c1, %c9, %c1) {aTileSize = [], bTileSize = [], cTileSize = [], computeTileSize = [4, 8, 512], overcompute = false, simdize = true, unroll = true} : memref<64x512xf32>, memref<512x128xf32>, memref<64x512xf32>, (!krnl.loop, !krnl.loop, !krnl.loop)
}
}
}
// end of pattern generated for Gemm
} // if
} // iter
memref.dealloc %2 : memref<64x512xf32>
memref.dealloc %1 : memref<512x128xf32>
memref.dealloc %0 : memref<64x512xf32>
return %3 : memref<1x9xf32>
}
}
Let me look at this issue a bit more to get to the root of the problem.
Confirming that the proposed solution works.