onnx-mlir icon indicating copy to clipboard operation
onnx-mlir copied to clipboard

Gemm inside a loop-if

Open AlexandreEichenberger opened this issue 3 years ago • 3 comments

With the new Keras models translated to TF and then to ONNX, we have the following pattern:

  onnx.loop (...) {
     ...
     onnx.gemm(...)
     ...
 }

Simplified to the extreme, I isolated this pattern, which causes a problem once we are in the krnl dialect:

    %1001 = krnl.define_loops 1
    krnl.iterate(%1001) with (%1001 -> %1002 = 0 to 35) {
      %1010 = krnl.load %1000[] : memref<i1>
      scf.if %1010 { // comment out to make it work

        // pattern generated when lowering Gemm
        // end of pattern generated for Gemm
        
      } // if
    } // iter

A pattern that has only the "if" or only the "iterate" works just fine. But the one with both results in the following error when further lowering to affine.

(onnx-mlir) alexe@pancetta i117-gemm-in-loop % onnx-mlir-opt --convert-krnl-to-affine --canonicalize onnx-gemm-191-krnl-loopif.mlir
Assertion failed: (isPerfectlyNested(input) && "input not perfectly nested"), function permuteLoops, file /Users/alexe/Onnxcode/llvm-project/mlir/lib/Transforms/Utils/LoopUtils.cpp, line 1548.
PLEASE submit a bug report to https://bugs.llvm.org/ and include the crash backtrace.
Stack dump:
0.	Program arguments: onnx-mlir-opt --convert-krnl-to-affine --canonicalize onnx-gemm-191-krnl-loopif.mlir
1.	Stack dump without symbol names (ensure you have llvm-symbolizer in your PATH or set the environment var `LLVM_SYMBOLIZER_PATH` to point to it):
0  onnx-mlir-opt            0x00000001094b05bb llvm::sys::PrintStackTrace(llvm::raw_ostream&, int) + 43
1  onnx-mlir-opt            0x00000001094af238 llvm::sys::RunSignalHandlers() + 248
2  onnx-mlir-opt            0x00000001094b0fb7 SignalHandler(int) + 295
3  libsystem_platform.dylib 0x00007fff2049dd7d _sigtramp + 29
4  libsystem_platform.dylib 0x0000000000018b00 _sigtramp + 18446603339974552992
5  libsystem_c.dylib        0x00007fff203ad411 abort + 120
6  libsystem_c.dylib        0x00007fff203ac7e8 err + 0
7  onnx-mlir-opt            0x0000000109587143 mlir::permuteLoops(llvm::MutableArrayRef<mlir::AffineForOp>, llvm::ArrayRef<unsigned int>) (.cold.19) + 35
8  onnx-mlir-opt            0x0000000108c6dcbf mlir::permuteLoops(llvm::MutableArrayRef<mlir::AffineForOp>, llvm::ArrayRef<unsigned int>) + 1535
9  onnx-mlir-opt            0x00000001087222bb (anonymous namespace)::interpretOperation(mlir::Operation*, mlir::OpBuilder&, llvm::SmallDenseMap<mlir::Value, mlir::AffineForOp, 4u, llvm::DenseMapInfo<mlir::Value>, llvm::detail::DenseMapPair<mlir::Value, mlir::AffineForOp> >&, llvm::SmallPtrSetImpl<mlir::Operation*>&, (anonymous namespace)::LoopBodyMover&) + 4283
10 onnx-mlir-opt            0x00000001087213ad (anonymous namespace)::interpretOperation(mlir::Operation*, mlir::OpBuilder&, llvm::SmallDenseMap<mlir::Value, mlir::AffineForOp, 4u, llvm::DenseMapInfo<mlir::Value>, llvm::detail::DenseMapPair<mlir::Value, mlir::AffineForOp> >&, llvm::SmallPtrSetImpl<mlir::Operation*>&, (anonymous namespace)::LoopBodyMover&) + 429
11 onnx-mlir-opt            0x00000001087213ad (anonymous namespace)::interpretOperation(mlir::Operation*, mlir::OpBuilder&, llvm::SmallDenseMap<mlir::Value, mlir::AffineForOp, 4u, llvm::DenseMapInfo<mlir::Value>, llvm::detail::DenseMapPair<mlir::Value, mlir::AffineForOp> >&, llvm::SmallPtrSetImpl<mlir::Operation*>&, (anonymous namespace)::LoopBodyMover&) + 429
12 onnx-mlir-opt            0x00000001087213ad (anonymous namespace)::interpretOperation(mlir::Operation*, mlir::OpBuilder&, llvm::SmallDenseMap<mlir::Value, mlir::AffineForOp, 4u, llvm::DenseMapInfo<mlir::Value>, llvm::detail::DenseMapPair<mlir::Value, mlir::AffineForOp> >&, llvm::SmallPtrSetImpl<mlir::Operation*>&, (anonymous namespace)::LoopBodyMover&) + 429
13 onnx-mlir-opt            0x000000010871f93a (anonymous namespace)::ConvertKrnlToAffinePass::runOnFunction() + 394
14 onnx-mlir-opt            0x0000000108e0dbb7 mlir::detail::OpToOpPassAdaptor::run(mlir::Pass*, mlir::Operation*, mlir::AnalysisManager, bool, unsigned int) + 487
15 onnx-mlir-opt            0x0000000108e0e096 mlir::detail::OpToOpPassAdaptor::runPipeline(llvm::iterator_range<llvm::pointee_iterator<std::__1::unique_ptr<mlir::Pass, std::__1::default_delete<mlir::Pass> >*, mlir::Pass> >, mlir::Operation*, mlir::AnalysisManager, bool, unsigned int, mlir::PassInstrumentor*, mlir::PassInstrumentation::PipelineParentInfo const*) + 134
16 onnx-mlir-opt            0x0000000108e13ee4 mlir::detail::OpToOpPassAdaptor::runOnOperationAsyncImpl(bool)::$_8::operator()(llvm::MutableArrayRef<mlir::OpPassManager>) const + 452
17 onnx-mlir-opt            0x0000000108e0efc1 mlir::detail::OpToOpPassAdaptor::runOnOperationAsyncImpl(bool) + 1761
18 onnx-mlir-opt            0x0000000108e0dd4c mlir::detail::OpToOpPassAdaptor::run(mlir::Pass*, mlir::Operation*, mlir::AnalysisManager, bool, unsigned int) + 892
19 onnx-mlir-opt            0x0000000108e0e096 mlir::detail::OpToOpPassAdaptor::runPipeline(llvm::iterator_range<llvm::pointee_iterator<std::__1::unique_ptr<mlir::Pass, std::__1::default_delete<mlir::Pass> >*, mlir::Pass> >, mlir::Operation*, mlir::AnalysisManager, bool, unsigned int, mlir::PassInstrumentor*, mlir::PassInstrumentation::PipelineParentInfo const*) + 134
20 onnx-mlir-opt            0x0000000108e10ba3 mlir::PassManager::run(mlir::Operation*) + 739
21 onnx-mlir-opt            0x0000000108aa4b90 performActions(llvm::raw_ostream&, bool, bool, llvm::SourceMgr&, mlir::MLIRContext*, mlir::PassPipelineCLParser const&) + 400
22 onnx-mlir-opt            0x0000000108aa2dea processBuffer(llvm::raw_ostream&, std::__1::unique_ptr<llvm::MemoryBuffer, std::__1::default_delete<llvm::MemoryBuffer> >, bool, bool, bool, bool, mlir::PassPipelineCLParser const&, mlir::DialectRegistry&) + 410
23 onnx-mlir-opt            0x0000000108aa2c24 mlir::MlirOptMain(llvm::raw_ostream&, std::__1::unique_ptr<llvm::MemoryBuffer, std::__1::default_delete<llvm::MemoryBuffer> >, mlir::PassPipelineCLParser const&, mlir::DialectRegistry&, bool, bool, bool, bool, bool) + 180
24 onnx-mlir-opt            0x000000010834a05c main + 1468
25 libdyld.dylib            0x00007fff20473f3d start + 1
zsh: abort      onnx-mlir-opt --convert-krnl-to-affine --canonicalize

For reference, here is the file that has both the iterate and the if.

module  {
  func @test_gemm_big(%arg0: memref<1x1xf32>, %arg1: memref<1x9xf32>, %arg2: memref<1x9xf32>) -> memref<1x9xf32> {
    %c1 = constant 1 : index
    %c9 = constant 9 : index
    %cst = constant 0.000000e+00 : f32
    %c0 = constant 0 : index
    %0 = memref.alloc() {alignment = 128 : i64} : memref<64x512xf32>
    %1 = memref.alloc() {alignment = 128 : i64} : memref<512x128xf32>
    %2 = memref.alloc() {alignment = 128 : i64} : memref<64x512xf32>
    %3 = memref.alloc() {alignment = 128 : i64} : memref<1x9xf32>
    %1000 = memref.alloc() : memref<i1>
    %1001 = krnl.define_loops 1
    krnl.iterate(%1001) with (%1001 -> %1002 = 0 to 35) {
      %1010 = krnl.load %1000[] : memref<i1>
      scf.if %1010 { // comment out to make it work

        // pattern generated when lowering Gemm
        %4:2 = krnl.define_loops 2
        krnl.iterate(%4#0, %4#1) with (%4#0 -> %arg3 = 0 to 1, %4#1 -> %arg4 = 0 to 9) {
          %7:2 = krnl.get_induction_var_value(%4#0, %4#1) : (!krnl.loop, !krnl.loop) -> (index, index)
          krnl.store %cst, %3[%7#0, %7#1] : memref<1x9xf32>
        }
        %5:3 = krnl.define_loops 3
        %loop_block, %loop_local = krnl.block %5#0 64 : (!krnl.loop) -> (!krnl.loop, !krnl.loop)
        %loop_block_0, %loop_local_1 = krnl.block %loop_local 4 : (!krnl.loop) -> (!krnl.loop, !krnl.loop)
        %loop_block_2, %loop_local_3 = krnl.block %5#1 128 : (!krnl.loop) -> (!krnl.loop, !krnl.loop)
        %loop_block_4, %loop_local_5 = krnl.block %loop_local_3 8 : (!krnl.loop) -> (!krnl.loop, !krnl.loop)
        %loop_block_6, %loop_local_7 = krnl.block %5#2 512 : (!krnl.loop) -> (!krnl.loop, !krnl.loop)
        krnl.permute(%loop_block, %loop_block_0, %loop_local_1, %loop_block_2, %loop_block_4, %loop_local_5, %loop_block_6, %loop_local_7) [0, 4, 5, 1, 3, 6, 2, 7] : !krnl.loop, !krnl.loop, !krnl.loop, !krnl.loop, !krnl.loop, !krnl.loop, !krnl.loop, !krnl.loop
        krnl.iterate(%loop_block, %loop_block_2) with (%5#0 -> %arg3 = 0 to 1, %5#1 -> %arg4 = 0 to 9) {
          %7:2 = krnl.get_induction_var_value(%loop_block, %loop_block_2) : (!krnl.loop, !krnl.loop) -> (index, index)
          krnl.copy_to_tile_buffer %0, %3[%7#0, %7#1], %cst {padToNext = [], tileSize = [], transpose = false} : memref<64x512xf32>, memref<1x9xf32>
          krnl.iterate(%loop_block_6) with (%5#2 -> %arg5 = 0 to 1) {
            %8 = krnl.get_induction_var_value(%loop_block_6) : (!krnl.loop) -> index
            krnl.copy_to_tile_buffer %2, %arg0[%7#0, %8], %cst {padToNext = [], tileSize = [], transpose = false} : memref<64x512xf32>, memref<1x1xf32>
            krnl.copy_to_tile_buffer %1, %arg1[%8, %7#1], %cst {padToNext = [], tileSize = [], transpose = false} : memref<512x128xf32>, memref<1x9xf32>
            krnl.iterate(%loop_block_4, %loop_block_0) with () {
              %9:2 = krnl.get_induction_var_value(%loop_block_4, %loop_block_0) : (!krnl.loop, !krnl.loop) -> (index, index)
              krnl.matmul %2[%7#0, %8], %1[%8, %7#1], %0[%7#0, %7#1], (%loop_local_1, %loop_local_5, %loop_local_7), (%9#1, %9#0, %8), (%c1, %c9, %c1) {aTileSize = [], bTileSize = [], cTileSize = [], computeTileSize = [4, 8, 512], overcompute = false, simdize = true, unroll = true} : memref<64x512xf32>, memref<512x128xf32>, memref<64x512xf32>, (!krnl.loop, !krnl.loop, !krnl.loop)
            }
          }
          krnl.copy_from_tile_buffer %0, %3[%7#0, %7#1] {tileSize = []} : memref<64x512xf32>, memref<1x9xf32>
        }
        %6:2 = krnl.define_loops 2
        krnl.iterate(%6#0, %6#1) with (%6#0 -> %arg3 = 0 to 1, %6#1 -> %arg4 = 0 to 9) {
          %7:2 = krnl.get_induction_var_value(%6#0, %6#1) : (!krnl.loop, !krnl.loop) -> (index, index)
          %8 = krnl.load %3[%7#0, %7#1] : memref<1x9xf32>
          %9 = krnl.load %arg2[%c0, %7#1] : memref<1x9xf32>
          %10 = addf %8, %9 : f32
          krnl.store %10, %3[%7#0, %7#1] : memref<1x9xf32>
        }
        // end of pattern generated for Gemm
        
      } // if
    } // iter
    memref.dealloc %2 : memref<64x512xf32>
    memref.dealloc %1 : memref<512x128xf32>
    memref.dealloc %0 : memref<64x512xf32>
    return %3 : memref<1x9xf32>
  }
}

AlexandreEichenberger avatar May 14 '21 20:05 AlexandreEichenberger

Here is a smaller reproducer, with the krnl.copy operations removed as well as the init/add C loops. Even the krnl.matmul can be removed (but I kept it, as it does add loops and conditionals in the innermost structure).

module  {
  func @test_gemm_big(%arg0: memref<1x1xf32>, %arg1: memref<1x9xf32>, %arg2: memref<1x9xf32>) -> memref<1x9xf32> {
    %c1 = constant 1 : index
    %c9 = constant 9 : index
    %cst = constant 0.000000e+00 : f32
    %c0 = constant 0 : index
    %0 = memref.alloc() {alignment = 128 : i64} : memref<64x512xf32>
    %1 = memref.alloc() {alignment = 128 : i64} : memref<512x128xf32>
    %2 = memref.alloc() {alignment = 128 : i64} : memref<64x512xf32>
    %3 = memref.alloc() {alignment = 128 : i64} : memref<1x9xf32>
    %1000 = memref.alloc() : memref<i1>
    %1001 = krnl.define_loops 1
    krnl.iterate(%1001) with (%1001 -> %1002 = 0 to 35) {
      %1010 = krnl.load %1000[] : memref<i1>
      scf.if %1010 { // comment out to make it work

        // pattern generated when lowering Gemm
        %5:3 = krnl.define_loops 3
        %loop_block, %loop_local = krnl.block %5#0 64 : (!krnl.loop) -> (!krnl.loop, !krnl.loop)
        %loop_block_0, %loop_local_1 = krnl.block %loop_local 4 : (!krnl.loop) -> (!krnl.loop, !krnl.loop)
        %loop_block_2, %loop_local_3 = krnl.block %5#1 128 : (!krnl.loop) -> (!krnl.loop, !krnl.loop)
        %loop_block_4, %loop_local_5 = krnl.block %loop_local_3 8 : (!krnl.loop) -> (!krnl.loop, !krnl.loop)
        %loop_block_6, %loop_local_7 = krnl.block %5#2 512 : (!krnl.loop) -> (!krnl.loop, !krnl.loop)
        krnl.permute(%loop_block, %loop_block_0, %loop_local_1, %loop_block_2, %loop_block_4, %loop_local_5, %loop_block_6, %loop_local_7) [0, 4, 5, 1, 3, 6, 2, 7] : !krnl.loop, !krnl.loop, !krnl.loop, !krnl.loop, !krnl.loop, !krnl.loop, !krnl.loop, !krnl.loop
        krnl.iterate(%loop_block, %loop_block_2) with (%5#0 -> %arg3 = 0 to 1, %5#1 -> %arg4 = 0 to 9) {
          %7:2 = krnl.get_induction_var_value(%loop_block, %loop_block_2) : (!krnl.loop, !krnl.loop) -> (index, index)
          krnl.iterate(%loop_block_6) with (%5#2 -> %arg5 = 0 to 1) {
            %8 = krnl.get_induction_var_value(%loop_block_6) : (!krnl.loop) -> index
            krnl.iterate(%loop_block_4, %loop_block_0) with () {
              %9:2 = krnl.get_induction_var_value(%loop_block_4, %loop_block_0) : (!krnl.loop, !krnl.loop) -> (index, index)
              // matmul can be commented out and the error still shows.
              krnl.matmul %2[%7#0, %8], %1[%8, %7#1], %0[%7#0, %7#1], (%loop_local_1, %loop_local_5, %loop_local_7), (%9#1, %9#0, %8), (%c1, %c9, %c1) {aTileSize = [], bTileSize = [], cTileSize = [], computeTileSize = [4, 8, 512], overcompute = false, simdize = true, unroll = true} : memref<64x512xf32>, memref<512x128xf32>, memref<64x512xf32>, (!krnl.loop, !krnl.loop, !krnl.loop)
            }
          }
        }
        // end of pattern generated for Gemm
        
      } // if
    } // iter
    memref.dealloc %2 : memref<64x512xf32>
    memref.dealloc %1 : memref<512x128xf32>
    memref.dealloc %0 : memref<64x512xf32>
    return %3 : memref<1x9xf32>
  }
}

AlexandreEichenberger avatar May 14 '21 20:05 AlexandreEichenberger

@AlexandreEichenberger there seems to be a really easy fix: just put all the unoptimized loops in a single pair of parentheses.

module  {
  func @test_gemm_big(%arg0: memref<1x1xf32>, %arg1: memref<1x9xf32>, %arg2: memref<1x9xf32>) -> memref<1x9xf32> {
    %c1 = constant 1 : index
    %c9 = constant 9 : index
    %cst = constant 0.000000e+00 : f32
    %c0 = constant 0 : index
    %0 = memref.alloc() {alignment = 128 : i64} : memref<64x512xf32>
    %1 = memref.alloc() {alignment = 128 : i64} : memref<512x128xf32>
    %2 = memref.alloc() {alignment = 128 : i64} : memref<64x512xf32>
    %3 = memref.alloc() {alignment = 128 : i64} : memref<1x9xf32>
    %1000 = memref.alloc() : memref<i1>
    %1001 = krnl.define_loops 1
    krnl.iterate(%1001) with (%1001 -> %1002 = 0 to 35) {
      %1010 = krnl.load %1000[] : memref<i1>
      scf.if %1010 { // comment out to make it work

        // pattern generated when lowering Gemm
        %5:3 = krnl.define_loops 3
        %loop_block, %loop_local = krnl.block %5#0 64 : (!krnl.loop) -> (!krnl.loop, !krnl.loop)
        %loop_block_0, %loop_local_1 = krnl.block %loop_local 4 : (!krnl.loop) -> (!krnl.loop, !krnl.loop)
        %loop_block_2, %loop_local_3 = krnl.block %5#1 128 : (!krnl.loop) -> (!krnl.loop, !krnl.loop)
        %loop_block_4, %loop_local_5 = krnl.block %loop_local_3 8 : (!krnl.loop) -> (!krnl.loop, !krnl.loop)
        %loop_block_6, %loop_local_7 = krnl.block %5#2 512 : (!krnl.loop) -> (!krnl.loop, !krnl.loop)
        krnl.permute(%loop_block, %loop_block_0, %loop_local_1, %loop_block_2, %loop_block_4, %loop_local_5, %loop_block_6, %loop_local_7) [0, 4, 5, 1, 3, 6, 2, 7] : !krnl.loop, !krnl.loop, !krnl.loop, !krnl.loop, !krnl.loop, !krnl.loop, !krnl.loop, !krnl.loop
        krnl.iterate(%loop_block, %loop_block_2) with (%5#0 -> %arg3 = 0 to 1, %5#1 -> %arg4 = 0 to 9, %5#2 -> %arg5 = 0 to 1) {
          %7:2 = krnl.get_induction_var_value(%loop_block, %loop_block_2) : (!krnl.loop, !krnl.loop) -> (index, index)
          krnl.iterate(%loop_block_6) with () {
            %8 = krnl.get_induction_var_value(%loop_block_6) : (!krnl.loop) -> index
            krnl.iterate(%loop_block_4, %loop_block_0) with () {
              %9:2 = krnl.get_induction_var_value(%loop_block_4, %loop_block_0) : (!krnl.loop, !krnl.loop) -> (index, index)
              // matmul can be commented out and the error still shows.
              krnl.matmul %2[%7#0, %8], %1[%8, %7#1], %0[%7#0, %7#1], (%loop_local_1, %loop_local_5, %loop_local_7), (%9#1, %9#0, %8), (%c1, %c9, %c1) {aTileSize = [], bTileSize = [], cTileSize = [], computeTileSize = [4, 8, 512], overcompute = false, simdize = true, unroll = true} : memref<64x512xf32>, memref<512x128xf32>, memref<64x512xf32>, (!krnl.loop, !krnl.loop, !krnl.loop)
            }
          }
        }
        // end of pattern generated for Gemm
        
      } // if
    } // iter
    memref.dealloc %2 : memref<64x512xf32>
    memref.dealloc %1 : memref<512x128xf32>
    memref.dealloc %0 : memref<64x512xf32>
    return %3 : memref<1x9xf32>
  }
}

Let me look at this issue a bit more to get to the root of the problem.

tjingrant avatar Jun 07 '21 01:06 tjingrant

Confirming that the proposed solution works

AlexandreEichenberger avatar Nov 05 '21 19:11 AlexandreEichenberger