[MLIR][Linalg] Use Top-Down traversal to safely optimize multi-use producer fusion

Open milos1397 opened this issue 1 month ago • 4 comments

Switches the greedy rewrite traversal for the multi-use producer fusion pattern to Top-Down (Pre-Order).

The previous Bottom-Up (Post-Order) traversal caused an SSA dominance violation when a producer (P) had multiple users (I and C) and the first user (I) appeared before the current consumer (C) in the block. Because C was visited first, fusing P into C created a new fused operation, F, at C's position. The rewrite then replaced P's result in its remaining user (I) with the corresponding result of F. Since I is located before F in the block, this replacement breaks SSA dominance and crashes the compiler. To ensure correctness, the first user (I) must be processed and fused before the second user (C). Top-Down traversal visits and rewrites operations in program order, which guarantees this.

Take a look at this example, which represents a three-operation chain where the first operation, P (%13:2), has two users: an intermediate operation I (%15:2) and a final consumer C (%17:2):

#map = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
module {
  func.func @avgpool2d_pad_top(%arg0: tensor<1x32x32x8xf32>) -> tensor<1x32x32x8xf32> attributes {llvm.emit_c_interface} {
    %0 = llvm.mlir.constant(0.000000e+00 : f32) : f32
    %1 = llvm.mlir.constant(31 : index) : i64
    %11 = tensor.empty() : tensor<1x32x32x8xf32>
    %12 = tensor.empty() : tensor<1x32x32x8xindex>
    %13:2 = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%arg0 : tensor<1x32x32x8xf32>) outs(%11, %12 : tensor<1x32x32x8xf32>, tensor<1x32x32x8xindex>) {
    ^bb0(%in: f32, %out: f32, %out_0: index):
      %59 = linalg.index 1 : index
      linalg.yield %0, %59 : f32, index
    } -> (tensor<1x32x32x8xf32>, tensor<1x32x32x8xindex>)
    %14 = tensor.empty() : tensor<1x32x32x8xi64>
    %15:2 = linalg.generic {indexing_maps = [#map, #map, #map, #map], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%arg0, %13#1 : tensor<1x32x32x8xf32>, tensor<1x32x32x8xindex>) outs(%11, %14 : tensor<1x32x32x8xf32>, tensor<1x32x32x8xi64>) {
    ^bb0(%in: f32, %in_0: index, %out: f32, %out_1: i64):
      %59 = builtin.unrealized_conversion_cast %in_0 : index to i64
      linalg.yield %0, %59 : f32, i64
    } -> (tensor<1x32x32x8xf32>, tensor<1x32x32x8xi64>)
    %16 = tensor.empty() : tensor<1x32x32x8xi64>
    %17:2 = linalg.generic {indexing_maps = [#map, #map, #map, #map, #map], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%arg0, %13#1, %15#1 : tensor<1x32x32x8xf32>, tensor<1x32x32x8xindex>, tensor<1x32x32x8xi64>) outs(%11, %16 : tensor<1x32x32x8xf32>, tensor<1x32x32x8xi64>) {
    ^bb0(%in: f32, %in_0: index, %in_1: i64, %out: f32, %out_2: i64):
      %59 = llvm.sub %1, %in_1 : i64
      linalg.yield %0, %59 : f32, i64
    } -> (tensor<1x32x32x8xf32>, tensor<1x32x32x8xi64>)
    return %17#0 : tensor<1x32x32x8xf32>
  }
}

If the fused op is inserted at the position of %17, the rewrite must update all remaining users of P's result (%13). Since the intermediate user I (%15) appears before the final consumer C (%17) in the block, replacing I's operand (%13#1) with the corresponding result of the new fused operation violates SSA dominance and crashes the compiler.

Issue: #131446

milos1397, Dec 14 '25 15:12

Thank you for submitting a Pull Request (PR) to the LLVM Project!

This PR will be automatically labeled and the relevant teams will be notified.

If you wish to, you can add reviewers by using the "Reviewers" section on this page.

If this is not working for you, it is probably because you do not have write permissions for the repository, in which case you can instead tag reviewers by name in a comment by using @ followed by their GitHub username.

If you have received no comments on your PR for a week, you can request a review by "ping"ing the PR by adding a comment “Ping”. The common courtesy "ping" rate is once a week. Please remember that you are asking for valuable time from other developers.

If you have further questions, they may be answered by the LLVM GitHub User Guide.

You can also ask questions in a comment on this PR, on the LLVM Discord or on the forums.

github-actions[bot], Dec 14 '25 15:12

@llvm/pr-subscribers-mlir-linalg

@llvm/pr-subscribers-mlir

Author: Miloš Poletanović (milos1397)

Changes

Full diff: https://github.com/llvm/llvm-project/pull/172216.diff

2 Files Affected:

  • (modified) mlir/test/Dialect/Linalg/fusion-multiuse-producer.mlir (+71)
  • (modified) mlir/test/lib/Dialect/Linalg/TestLinalgElementwiseFusion.cpp (+2-1)
diff --git a/mlir/test/Dialect/Linalg/fusion-multiuse-producer.mlir b/mlir/test/Dialect/Linalg/fusion-multiuse-producer.mlir
index 7871ae08fd54a..96845448dd1c2 100644
--- a/mlir/test/Dialect/Linalg/fusion-multiuse-producer.mlir
+++ b/mlir/test/Dialect/Linalg/fusion-multiuse-producer.mlir
@@ -32,3 +32,74 @@ func.func @multi_use_producer(%arg0 : tensor<?x?xf32>, %arg1 : tensor<?x?xf32>,
 // CHECK-SAME:     %[[ARG4:[a-zA-Z0-9]+]]: tensor<?x?xf32>)
 //      CHECK:   %[[RESULT:.+]]:3 = linalg.generic
 //      CHECK:   return %[[RESULT]]#0, %[[RESULT]]#1, %[[RESULT]]#2
+
+func.func @multi_use_producer_2(%arg0: tensor<1x32x32x8xf32>) -> tensor<1x32x32x8xf32> attributes {llvm.emit_c_interface} {
+  %0 = llvm.mlir.constant(0.000000e+00 : f32) : f32
+  %1 = llvm.mlir.constant(31 : index) : i64
+  %2 = tensor.empty() : tensor<1x32x32x8xf32>
+  %3 = tensor.empty() : tensor<1x32x32x8xindex>
+  %4:2 = linalg.generic {
+    indexing_maps = [
+      affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>,
+      affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>,
+      affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
+    ], 
+    iterator_types = ["parallel", "parallel", "parallel", "parallel"]
+  } 
+  ins(%arg0 : tensor<1x32x32x8xf32>) 
+  outs(%2, %3 : tensor<1x32x32x8xf32>, tensor<1x32x32x8xindex>) {
+    ^bb0(%in: f32, %out: f32, %out_0: index):
+      %9 = linalg.index 1 : index
+      linalg.yield %0, %9 : f32, index
+  } -> (tensor<1x32x32x8xf32>, tensor<1x32x32x8xindex>)
+
+  %5 = tensor.empty() : tensor<1x32x32x8xi64>
+  %6:2 = linalg.generic {
+    indexing_maps = [
+      affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>,
+      affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>,
+      affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>,
+      affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
+    ], 
+    iterator_types = ["parallel", "parallel", "parallel", "parallel"]
+  } 
+  ins(%arg0, %4#1 : tensor<1x32x32x8xf32>, tensor<1x32x32x8xindex>) 
+  outs(%2, %5 : tensor<1x32x32x8xf32>, tensor<1x32x32x8xi64>) {
+    ^bb0(%in: f32, %in_0: index, %out: f32, %out_1: i64):
+      %9 = builtin.unrealized_conversion_cast %in_0 : index to i64
+      linalg.yield %0, %9 : f32, i64
+  } -> (tensor<1x32x32x8xf32>, tensor<1x32x32x8xi64>)
+
+  %7 = tensor.empty() : tensor<1x32x32x8xi64>
+  %8:2 = linalg.generic {
+    indexing_maps = [
+      affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>,
+      affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>,
+      affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>,
+      affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>,
+      affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
+    ], 
+    iterator_types = ["parallel", "parallel", "parallel", "parallel"]
+  } 
+  ins(%arg0, %4#1, %6#1 : tensor<1x32x32x8xf32>, tensor<1x32x32x8xindex>, tensor<1x32x32x8xi64>) 
+  outs(%2, %7 : tensor<1x32x32x8xf32>, tensor<1x32x32x8xi64>) {
+    ^bb0(%in: f32, %in_0: index, %in_1: i64, %out: f32, %out_2: i64):
+      %9 = llvm.sub %1, %in_1 : i64
+      linalg.yield %0, %9 : f32, i64
+  } -> (tensor<1x32x32x8xf32>, tensor<1x32x32x8xi64>)
+  return %8#0 : tensor<1x32x32x8xf32>
+}
+// CHECK-LABEL: func @multi_use_producer_2(
+// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]: tensor<1x32x32x8xf32>)
+// CHECK-SAME: -> tensor<1x32x32x8xf32>
+// CHECK: %[[C31:.+]] = llvm.mlir.constant(31 : index) : i64
+// CHECK: %[[R0:.+]]:2 = linalg.generic {
+// CHECK-SAME: ins(%[[ARG0]], %[[ARG0]], %[[ARG0]], %[[ARG0]], %[[ARG0]] : tensor<1x32x32x8xf32>, tensor<1x32x32x8xf32>, tensor<1x32x32x8xf32>, tensor<1x32x32x8xf32>, tensor<1x32x32x8xf32>)
+// CHECK-SAME: outs(%[[INIT:.+]], %[[INIT_1:.+]] : tensor<1x32x32x8xf32>, tensor<1x32x32x8xi64>)
+// CHECK: ^bb0(%[[IN:.+]]: f32, %[[IN_1:.+]]: f32, %[[IN_2:.+]]: f32, %[[IN_3:.+]]: f32, %[[IN_4:.+]]: f32, %[[OUT:.+]]: f32, %[[OUT_I:.+]]: i64):
+// CHECK: %[[IDX_9:.+]] = linalg.index 1 : index
+// CHECK: %[[C_9:.+]] = builtin.unrealized_conversion_cast %[[IDX_9]] : index to i64
+// CHECK: %[[C_SUB:.+]] = llvm.sub %[[C31]], %[[C_9]] : i64
+// CHECK: linalg.yield %[[C0:.+]], %[[C_SUB]] : f32, i64
+// CHECK: } -> (tensor<1x32x32x8xf32>, tensor<1x32x32x8xi64>)
+// CHECK: return %[[R0]]#0 : tensor<1x32x32x8xf32>
\ No newline at end of file
diff --git a/mlir/test/lib/Dialect/Linalg/TestLinalgElementwiseFusion.cpp b/mlir/test/lib/Dialect/Linalg/TestLinalgElementwiseFusion.cpp
index cb215197253bb..f8f2184b0de6f 100644
--- a/mlir/test/lib/Dialect/Linalg/TestLinalgElementwiseFusion.cpp
+++ b/mlir/test/lib/Dialect/Linalg/TestLinalgElementwiseFusion.cpp
@@ -252,7 +252,8 @@ struct TestLinalgElementwiseFusion
     if (fuseMultiUseProducer) {
       RewritePatternSet patterns(context);
       patterns.insert<TestMultiUseProducerFusion>(context);
-      if (failed(applyPatternsGreedily(funcOp.getBody(), std::move(patterns))))
+      if (failed(applyPatternsGreedily(funcOp.getBody(), std::move(patterns),
+                GreedyRewriteConfig().setUseTopDownTraversal(true))))
         return signalPassFailure();
       return;
     }

llvmbot, Dec 14 '25 15:12

Thanks for the PR. I think previously, when I tried switching the traversal for cases where we don't hit such issues (and I didn't hit this issue), top-down versus bottom-up traversal didn't make a difference. I think we need to look into compile-time impacts, which we aren't set up to do upstream. Also, you are only fixing a test here. Can you give a bit more context on how you are fixing this for your use case internally? Are you just updating the pattern application on your end? Could you update the documentation of linalg::fuseElementwiseOps to indicate this?

The PR itself seems good to go for me.

MaheshRavishankar, Dec 15 '25 18:12

Hey @MaheshRavishankar, thanks for the review.

That’s correct, this PR only fixes the test case. Here, we configure the pattern application so that linalg::fuseElementwiseOps processes operations in an order that yields a valid result (as explained above). In effect, a current precondition for using linalg::fuseElementwiseOps is that the IR be processed in top-down order. Alternatively, more robust logic could be added to linalg::fuseElementwiseOps itself to handle these cases automatically and prevent the ordering issue described above.

milos1397, Dec 16 '25 16:12

Can this be merged? @MaheshRavishankar

milos1397, Dec 23 '25 11:12

@milos1397 Congratulations on having your first Pull Request (PR) merged into the LLVM Project!

Your changes will be combined with recent changes from other authors, then tested by our build bots. If there is a problem with a build, you may receive a report in an email or a comment on this PR.

Please check whether problems have been caused by your change specifically, as the builds can include changes from many authors. It is not uncommon for your change to be included in a build that fails due to someone else's changes, or infrastructure issues.

How to do this, and the rest of the post-merge process, is covered in detail here.

If your change does cause a problem, it may be reverted, or you can revert it yourself. This is a normal part of LLVM development. You can fix your changes and open a new PR to merge them again.

If you don't get any reports, no action is required from you. Your changes are working as expected, well done!

github-actions[bot], Dec 23 '25 16:12