llvm-project icon indicating copy to clipboard operation
llvm-project copied to clipboard

[mlir][bufferization] Return BufferLikeType in BufferizableOpInterface

Open andrey-golubev opened this pull request 6 months ago • 5 comments

Support custom types (2/N): allow value-owning operations (e.g. allocation ops) to bufferize custom tensors into custom buffers. To make this possible, `BufferizableOpInterface::getBufferType()` now returns `BufferLikeType` instead of `BaseMemRefType`.

Affected implementors of the interface are updated accordingly.

Relates to commit ee070d08163ac09842d9bf0c1315f311df39faf1.

andrey-golubev avatar Jun 19 '25 10:06 andrey-golubev

@llvm/pr-subscribers-mlir

@llvm/pr-subscribers-mlir-arith

Author: Andrei Golubev (andrey-golubev)

Changes

Support custom types (2/N): allow value-owning operations (e.g. allocation ops) to bufferize custom tensors into custom buffers. This requires BufferizableOpInterface::getBufferType() to return BufferLikeType instead of BaseMemRefType.

Affected implementors of the interface are updated accordingly.


Patch is 32.25 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/144867.diff

14 Files Affected:

  • (modified) mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h (+1-1)
  • (modified) mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.td (+1-1)
  • (modified) mlir/include/mlir/Dialect/Bufferization/IR/BufferizationOps.td (+3-3)
  • (modified) mlir/include/mlir/Dialect/Bufferization/IR/BufferizationTypeInterfaces.h (+1)
  • (modified) mlir/include/mlir/Dialect/Bufferization/IR/UnstructuredControlFlow.h (+2-2)
  • (modified) mlir/lib/Dialect/Arith/Transforms/BufferizableOpInterfaceImpl.cpp (+4-4)
  • (modified) mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp (+9-6)
  • (modified) mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp (+3-2)
  • (modified) mlir/lib/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.cpp (+8-7)
  • (modified) mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp (+28-33)
  • (modified) mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp (+28-24)
  • (modified) mlir/test/Dialect/Bufferization/Transforms/one-shot-bufferize.mlir (+21-2)
  • (modified) mlir/test/lib/Dialect/Test/TestOpDefs.cpp (+34)
  • (modified) mlir/test/lib/Dialect/Test/TestOps.td (+53)
diff --git a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h
index c1529a36465ac..6245f88db3d19 100644
--- a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h
+++ b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h
@@ -712,7 +712,7 @@ AliasingOpOperandList defaultGetAliasingOpOperands(Value value,
 /// This is the default implementation of
 /// BufferizableOpInterface::getBufferType. Should not be called from other
 /// places.
-FailureOr<BaseMemRefType>
+FailureOr<BufferLikeType>
 defaultGetBufferType(Value value, const BufferizationOptions &options,
                      const BufferizationState &state,
                      SmallVector<Value> &invocationStack);
diff --git a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.td b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.td
index cafe05fe5f189..246ae77f327cf 100644
--- a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.td
+++ b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.td
@@ -525,7 +525,7 @@ def BufferizableOpInterface : OpInterface<"BufferizableOpInterface"> {
           Note: This interface method should never be called directly from user
           code. Always use `bufferization::getBufferType`.
         }],
-        /*retType=*/"::mlir::FailureOr<::mlir::BaseMemRefType>",
+        /*retType=*/"::mlir::FailureOr<::mlir::bufferization::BufferLikeType>",
         /*methodName=*/"getBufferType",
         /*args=*/(ins "::mlir::Value":$value,
                       "const ::mlir::bufferization::BufferizationOptions &":$options,
diff --git a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizationOps.td b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizationOps.td
index 32c53ea9c494a..f175b15c8770f 100644
--- a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizationOps.td
+++ b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizationOps.td
@@ -111,7 +111,7 @@ def Bufferization_AllocTensorOp : Bufferization_Op<"alloc_tensor",
     AliasingValueList getAliasingValues(
         OpOperand &opOperand, const AnalysisState &state);
 
-    FailureOr<BaseMemRefType> getBufferType(
+    FailureOr<BufferLikeType> getBufferType(
         Value value, const BufferizationOptions &options,
         const BufferizationState &state,
         SmallVector<Value> &invocationStack);
@@ -478,10 +478,10 @@ def Bufferization_ToTensorOp : Bufferization_Op<"to_tensor", [
 
     bool isWritable(Value value, const AnalysisState &state);
 
-    FailureOr<BaseMemRefType> getBufferType(
+    FailureOr<BufferLikeType> getBufferType(
         Value value, const BufferizationOptions &options,
         const BufferizationState &state, SmallVector<Value> &invocationStack) {
-      return ::llvm::cast<BaseMemRefType>(getBuffer().getType());
+      return getBuffer().getType();
     }
   }];
 
diff --git a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizationTypeInterfaces.h b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizationTypeInterfaces.h
index cbb6054fcf886..da7fee4b4a220 100644
--- a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizationTypeInterfaces.h
+++ b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizationTypeInterfaces.h
@@ -13,6 +13,7 @@
 // Bufferization Type Interfaces
 //===----------------------------------------------------------------------===//
 
+#include "mlir/IR/BuiltinTypes.h"
 #include "mlir/IR/Diagnostics.h"
 #include "mlir/IR/Types.h"
 
diff --git a/mlir/include/mlir/Dialect/Bufferization/IR/UnstructuredControlFlow.h b/mlir/include/mlir/Dialect/Bufferization/IR/UnstructuredControlFlow.h
index f56c10555f02c..e8a81c74bd77a 100644
--- a/mlir/include/mlir/Dialect/Bufferization/IR/UnstructuredControlFlow.h
+++ b/mlir/include/mlir/Dialect/Bufferization/IR/UnstructuredControlFlow.h
@@ -32,7 +32,7 @@ template <typename ConcreteModel, typename ConcreteOp>
 struct OpWithUnstructuredControlFlowBufferizableOpInterfaceExternalModel
     : public BufferizableOpInterface::ExternalModel<ConcreteModel, ConcreteOp> {
 
-  FailureOr<BaseMemRefType>
+  FailureOr<BufferLikeType>
   getBufferType(Operation *op, Value value, const BufferizationOptions &options,
                 const BufferizationState &state,
                 SmallVector<Value> &invocationStack) const {
@@ -110,7 +110,7 @@ struct OpWithUnstructuredControlFlowBufferizableOpInterfaceExternalModel
     if (!bufferType)
       return op->emitOpError("could not infer buffer type of block argument");
 
-    return bufferType;
+    return cast<BufferLikeType>(bufferType);
   }
 
 protected:
diff --git a/mlir/lib/Dialect/Arith/Transforms/BufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/Arith/Transforms/BufferizableOpInterfaceImpl.cpp
index 85d1b5ac73bf4..afee162053bea 100644
--- a/mlir/lib/Dialect/Arith/Transforms/BufferizableOpInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Arith/Transforms/BufferizableOpInterfaceImpl.cpp
@@ -181,7 +181,7 @@ struct SelectOpInterface
     return success();
   }
 
-  FailureOr<BaseMemRefType>
+  FailureOr<BufferLikeType>
   getBufferType(Operation *op, Value value, const BufferizationOptions &options,
                 const BufferizationState &state,
                 SmallVector<Value> &invocationStack) const {
@@ -196,17 +196,17 @@ struct SelectOpInterface
     if (failed(trueType) || failed(falseType))
       return failure();
     if (*trueType == *falseType)
-      return *trueType;
+      return cast<BufferLikeType>(*trueType);
     if (trueType->getMemorySpace() != falseType->getMemorySpace())
       return op->emitError("inconsistent memory space on true/false operands");
 
     // If the buffers have different types, they differ only in their layout
     // map.
     auto memrefType = llvm::cast<MemRefType>(*trueType);
-    return getMemRefTypeWithFullyDynamicLayout(
+    return cast<BufferLikeType>(getMemRefTypeWithFullyDynamicLayout(
         RankedTensorType::get(memrefType.getShape(),
                               memrefType.getElementType()),
-        memrefType.getMemorySpace());
+        memrefType.getMemorySpace()));
   }
 };
 
diff --git a/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp b/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp
index 2ab182c9b7b2e..55784ac20d353 100644
--- a/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp
+++ b/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp
@@ -945,7 +945,7 @@ AliasingOpOperandList bufferization::detail::defaultGetAliasingOpOperands(
   return AliasingOpOperandList(std::move(result));
 }
 
-FailureOr<BaseMemRefType> bufferization::detail::defaultGetBufferType(
+FailureOr<BufferLikeType> bufferization::detail::defaultGetBufferType(
     Value value, const BufferizationOptions &options,
     const BufferizationState &bufferizationState,
     SmallVector<Value> &invocationStack) {
@@ -953,8 +953,10 @@ FailureOr<BaseMemRefType> bufferization::detail::defaultGetBufferType(
   auto tensorType = cast<TensorType>(value.getType());
 
   // No further analysis is possible for a block argument.
-  if (llvm::isa<BlockArgument>(value))
-    return bufferization::getMemRefType(tensorType, options);
+  if (llvm::isa<BlockArgument>(value)) {
+    return cast<BufferLikeType>(
+        bufferization::getMemRefType(tensorType, options));
+  }
 
   // Value is an OpResult.
   Operation *op = getOwnerOfValue(value);
@@ -966,8 +968,8 @@ FailureOr<BaseMemRefType> bufferization::detail::defaultGetBufferType(
     // If the OpResult has an equivalent OpOperand, both OpResult and
     // OpOperand bufferize to the exact same buffer type.
     Value equivalentOperand = aliases.getAliases().front().opOperand->get();
-    return asMemRefType(getBufferType(equivalentOperand, options,
-                                      bufferizationState, invocationStack));
+    return getBufferType(equivalentOperand, options, bufferizationState,
+                         invocationStack);
   }
 
   // If we do not know the memory space and there is no default memory space,
@@ -977,7 +979,8 @@ FailureOr<BaseMemRefType> bufferization::detail::defaultGetBufferType(
   if (!memSpace.has_value())
     return op->emitError("could not infer memory space");
 
-  return getMemRefType(tensorType, options, /*layout=*/{}, *memSpace);
+  return cast<BufferLikeType>(
+      getMemRefType(tensorType, options, /*layout=*/{}, *memSpace));
 }
 
 bool bufferization::detail::defaultIsRepetitiveRegion(
diff --git a/mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp b/mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp
index 9bd87d66c7d36..66949c96798de 100644
--- a/mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp
+++ b/mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp
@@ -222,7 +222,7 @@ AliasingValueList AllocTensorOp::getAliasingValues(OpOperand &opOperand,
   return {};
 }
 
-FailureOr<BaseMemRefType>
+FailureOr<BufferLikeType>
 AllocTensorOp::getBufferType(Value value, const BufferizationOptions &options,
                              const BufferizationState &state,
                              SmallVector<Value> &invocationStack) {
@@ -245,7 +245,8 @@ AllocTensorOp::getBufferType(Value value, const BufferizationOptions &options,
     return getOperation()->emitError("could not infer memory space");
   }
 
-  return getMemRefTypeWithStaticIdentityLayout(getType(), memorySpace);
+  return cast<BufferLikeType>(
+      getMemRefTypeWithStaticIdentityLayout(getType(), memorySpace));
 }
 
 LogicalResult AllocTensorOp::verify() {
diff --git a/mlir/lib/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.cpp
index 453ed43bcadd2..bd2aebca68079 100644
--- a/mlir/lib/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.cpp
@@ -211,7 +211,7 @@ struct CallOpInterface
     return result;
   }
 
-  FailureOr<BaseMemRefType>
+  FailureOr<BufferLikeType>
   getBufferType(Operation *op, Value value, const BufferizationOptions &options,
                 const BufferizationState &state,
                 SmallVector<Value> &invocationStack) const {
@@ -229,12 +229,13 @@ struct CallOpInterface
     Type resultType =
         funcType.getResult(cast<OpResult>(value).getResultNumber());
     if (auto bufferizedType = dyn_cast<BaseMemRefType>(resultType))
-      return bufferizedType;
+      return cast<BufferLikeType>(bufferizedType);
 
     // Otherwise, call the type converter to compute the bufferized type.
     auto tensorType = cast<TensorType>(resultType);
-    return options.functionArgTypeConverterFn(
-        tensorType, *options.defaultMemorySpaceFn(tensorType), funcOp, options);
+    return cast<BufferLikeType>(options.functionArgTypeConverterFn(
+        tensorType, *options.defaultMemorySpaceFn(tensorType), funcOp,
+        options));
   }
 
   /// All function arguments are writable. It is the responsibility of the
@@ -396,7 +397,7 @@ struct FuncOpInterface
     return getAliasingBranchOpOperands(op, cast<BlockArgument>(value), state);
   }
 
-  FailureOr<BaseMemRefType>
+  FailureOr<BufferLikeType>
   getBufferType(Operation *op, Value value, const BufferizationOptions &options,
                 const BufferizationState &state,
                 SmallVector<Value> &invocationStack) const {
@@ -405,8 +406,8 @@ struct FuncOpInterface
 
     // Function arguments are special.
     if (bbArg.getOwner() == &funcOp.getBody().front())
-      return getBufferizedFunctionArgType(funcOp, bbArg.getArgNumber(),
-                                          options);
+      return cast<BufferLikeType>(
+          getBufferizedFunctionArgType(funcOp, bbArg.getArgNumber(), options));
 
     return OpWithUnstructuredControlFlowBufferizableOpInterfaceExternalModel::
         getBufferType(op, value, options, state, invocationStack);
diff --git a/mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp
index 58562536be61f..d36d91249ed36 100644
--- a/mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp
@@ -274,7 +274,7 @@ struct IfOpInterface
     return success();
   }
 
-  FailureOr<BaseMemRefType>
+  FailureOr<BufferLikeType>
   getBufferType(Operation *op, Value value, const BufferizationOptions &options,
                 const BufferizationState &state,
                 SmallVector<Value> &invocationStack) const {
@@ -313,15 +313,15 @@ struct IfOpInterface
 
     // Best case: Both branches have the exact same buffer type.
     if (thenBufferType == elseBufferType)
-      return thenBufferType;
+      return cast<BufferLikeType>(thenBufferType);
 
     // Memory space mismatch.
     if (thenBufferType.getMemorySpace() != elseBufferType.getMemorySpace())
       return op->emitError("inconsistent memory space on then/else branches");
 
     // Layout maps are different: Promote to fully dynamic layout map.
-    return getMemRefTypeWithFullyDynamicLayout(
-        cast<TensorType>(opResult.getType()), thenBufferType.getMemorySpace());
+    return cast<BufferLikeType>(getMemRefTypeWithFullyDynamicLayout(
+        cast<TensorType>(opResult.getType()), thenBufferType.getMemorySpace()));
   }
 };
 
@@ -392,7 +392,7 @@ struct IndexSwitchOpInterface
     return success();
   }
 
-  FailureOr<BaseMemRefType>
+  FailureOr<BufferLikeType>
   getBufferType(Operation *op, Value value, const BufferizationOptions &options,
                 const BufferizationState &state,
                 SmallVector<Value> &invocationStack) const {
@@ -436,7 +436,7 @@ struct IndexSwitchOpInterface
           cast<TensorType>(value.getType()), bufferType.getMemorySpace());
     }
 
-    return bufferType;
+    return cast<BufferLikeType>(bufferType);
   }
 };
 
@@ -522,13 +522,13 @@ getBbArgReplacements(RewriterBase &rewriter, Block::BlockArgListType bbArgs,
 /// If both buffer types are equal, no casts are needed the computed buffer type
 /// can be used directly. Otherwise, the buffer types can only differ in their
 /// layout map and a cast must be inserted.
-static FailureOr<BaseMemRefType> computeLoopRegionIterArgBufferType(
+static FailureOr<BufferLikeType> computeLoopRegionIterArgBufferType(
     Operation *loopOp, BlockArgument iterArg, Value initArg, Value yieldedValue,
     const BufferizationOptions &options, const BufferizationState &state,
     SmallVector<Value> &invocationStack) {
   // Determine the buffer type of the init_arg.
-  auto initArgBufferType = bufferization::detail::asMemRefType(
-      bufferization::getBufferType(initArg, options, state, invocationStack));
+  auto initArgBufferType =
+      bufferization::getBufferType(initArg, options, state, invocationStack);
   if (failed(initArgBufferType))
     return failure();
 
@@ -547,16 +547,15 @@ static FailureOr<BaseMemRefType> computeLoopRegionIterArgBufferType(
   }
 
   // Compute the buffer type of the yielded value.
-  BaseMemRefType yieldedValueBufferType;
+  BufferLikeType yieldedValueBufferType;
   if (isa<BaseMemRefType>(yieldedValue.getType())) {
     // scf.yield was already bufferized.
-    yieldedValueBufferType = cast<BaseMemRefType>(yieldedValue.getType());
+    yieldedValueBufferType = cast<BufferLikeType>(yieldedValue.getType());
   } else {
     // Note: This typically triggers a recursive call for the buffer type of
     // the iter_arg.
-    auto maybeBufferType =
-        bufferization::detail::asMemRefType(bufferization::getBufferType(
-            yieldedValue, options, state, invocationStack));
+    auto maybeBufferType = bufferization::getBufferType(yieldedValue, options,
+                                                        state, invocationStack);
     if (failed(maybeBufferType))
       return failure();
     yieldedValueBufferType = *maybeBufferType;
@@ -584,8 +583,8 @@ static FailureOr<BaseMemRefType> computeLoopRegionIterArgBufferType(
         "expected same shape");
   }
 #endif // NDEBUG
-  return getMemRefTypeWithFullyDynamicLayout(
-      iterTensorType, yieldedBufferType.getMemorySpace());
+  return cast<BufferLikeType>(getMemRefTypeWithFullyDynamicLayout(
+      iterTensorType, yieldedBufferType.getMemorySpace()));
 }
 
 /// Return `true` if the given loop may have 0 iterations.
@@ -708,7 +707,7 @@ struct ForOpInterface
     return success();
   }
 
-  FailureOr<BaseMemRefType>
+  FailureOr<BufferLikeType>
   getBufferType(Operation *op, Value value, const BufferizationOptions &options,
                 const BufferizationState &state,
                 SmallVector<Value> &invocationStack) const {
@@ -719,12 +718,8 @@ struct ForOpInterface
     if (auto opResult = dyn_cast<OpResult>(value)) {
       // The type of an OpResult must match the corresponding iter_arg type.
       BlockArgument bbArg = forOp.getTiedLoopRegionIterArg(opResult);
-      auto bufferType =
-          bufferization::getBufferType(bbArg, options, state, invocationStack);
-      if (failed(bufferType))
-        return failure();
-      assert(isa<BaseMemRefType>(*bufferType) && "expected memref type");
-      return cast<BaseMemRefType>(*bufferType);
+      return bufferization::getBufferType(bbArg, options, state,
+                                          invocationStack);
     }
 
     // Compute result/argument number.
@@ -1047,7 +1042,7 @@ struct WhileOpInterface
     return success();
   }
 
-  FailureOr<BaseMemRefType>
+  FailureOr<BufferLikeType>
   getBufferType(Operation *op, Value value, const BufferizationOptions &options,
                 const BufferizationState &state,
                 SmallVector<Value> &invocationStack) const {
@@ -1081,10 +1076,10 @@ struct WhileOpInterface
     Value conditionYieldedVal = whileOp.getConditionOp().getArgs()[resultNum];
     if (!isa<TensorType>(conditionYieldedVal.getType())) {
       // scf.condition was already bufferized.
-      return cast<BaseMemRefType>(conditionYieldedVal.getType());
+      return cast<BufferLikeType>(conditionYieldedVal.getType());
     }
-    return bufferization::detail::asMemRefType(bufferization::getBufferType(
-        conditionYieldedVal, options, state, invocationStack));
+    return bufferization::getBufferType(conditionYieldedVal, options, state,
+                                        invocationStack);
   }
 
   /// Assert that yielded values of an scf.while op are equivalent to their
@@ -1303,7 +1298,7 @@ struct ForallOpInterface
     return success();
   }
 
-  FailureOr<BaseMemRefType>
+  FailureOr<BufferLikeType>
   getBufferType(Operation *op, Value value, const BufferizationOptions &options,
                 const BufferizationState &state,
                 SmallVector<Value> &invocationStack) const {
@@ -1312,15 +1307,15 @@ struct ForallOpInterface
     if (auto bbArg = dyn_cast<BlockArgument>(value))
       // A tensor block argument has the same bufferized type as the
       // corresponding output operand.
-      return bufferization::detail::asMemRefType(
-          bufferization::getBufferType(forallOp.getTiedOpOperand(bbArg)->get(),
-                                       options, state, invocationStack));
+      return bufferization::getBufferType(
+          forallOp.getTiedOpOperand(bbArg)->get(), options, state,
+          invocationStack);
 
     // The bufferized result type is the same as the bufferized type of the
     // corresponding output operand.
-    return bufferization::detail::asMemRefType(bufferization::getBufferType(
+    return bufferization::getBufferType(
         forallOp.getOutputs()[cast<OpResult>(value).getResultNumber()], options,
-        state, invocationStack));
+        state, invocationStack);
   }
 
   bool isRepetitiveRegion(Operation *op, unsigned index) const {
diff --git a/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp
index 729c048db4560..829b2ab92ac24 100644
--- a/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp
@@ -49,7 +49,7 @@ struct CastOpInterface
     return {{op->getResult(0), BufferRelation::Equivalent}};
   }
 
-  FailureOr<BaseMemRefType>
+  FailureOr<BufferLikeType>
   getBufferType(Opera...
[truncated]

llvmbot avatar Jun 19 '25 10:06 llvmbot

@llvm/pr-subscribers-mlir-tensor

Author: Andrei Golubev (andrey-golubev)

Changes

Support custom types (2/N): allow value-owning operations (e.g. allocation ops) to bufferize custom tensors into custom buffers. This requires BufferizableOpInterface::getBufferType() to return BufferLikeType instead of BaseMemRefType.

Affected implementors of the interface are updated accordingly.


Patch is 32.25 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/144867.diff

14 Files Affected:

  • (modified) mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h (+1-1)
  • (modified) mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.td (+1-1)
  • (modified) mlir/include/mlir/Dialect/Bufferization/IR/BufferizationOps.td (+3-3)
  • (modified) mlir/include/mlir/Dialect/Bufferization/IR/BufferizationTypeInterfaces.h (+1)
  • (modified) mlir/include/mlir/Dialect/Bufferization/IR/UnstructuredControlFlow.h (+2-2)
  • (modified) mlir/lib/Dialect/Arith/Transforms/BufferizableOpInterfaceImpl.cpp (+4-4)
  • (modified) mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp (+9-6)
  • (modified) mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp (+3-2)
  • (modified) mlir/lib/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.cpp (+8-7)
  • (modified) mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp (+28-33)
  • (modified) mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp (+28-24)
  • (modified) mlir/test/Dialect/Bufferization/Transforms/one-shot-bufferize.mlir (+21-2)
  • (modified) mlir/test/lib/Dialect/Test/TestOpDefs.cpp (+34)
  • (modified) mlir/test/lib/Dialect/Test/TestOps.td (+53)
diff --git a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h
index c1529a36465ac..6245f88db3d19 100644
--- a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h
+++ b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h
@@ -712,7 +712,7 @@ AliasingOpOperandList defaultGetAliasingOpOperands(Value value,
 /// This is the default implementation of
 /// BufferizableOpInterface::getBufferType. Should not be called from other
 /// places.
-FailureOr<BaseMemRefType>
+FailureOr<BufferLikeType>
 defaultGetBufferType(Value value, const BufferizationOptions &options,
                      const BufferizationState &state,
                      SmallVector<Value> &invocationStack);
diff --git a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.td b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.td
index cafe05fe5f189..246ae77f327cf 100644
--- a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.td
+++ b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.td
@@ -525,7 +525,7 @@ def BufferizableOpInterface : OpInterface<"BufferizableOpInterface"> {
           Note: This interface method should never be called directly from user
           code. Always use `bufferization::getBufferType`.
         }],
-        /*retType=*/"::mlir::FailureOr<::mlir::BaseMemRefType>",
+        /*retType=*/"::mlir::FailureOr<::mlir::bufferization::BufferLikeType>",
         /*methodName=*/"getBufferType",
         /*args=*/(ins "::mlir::Value":$value,
                       "const ::mlir::bufferization::BufferizationOptions &":$options,
diff --git a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizationOps.td b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizationOps.td
index 32c53ea9c494a..f175b15c8770f 100644
--- a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizationOps.td
+++ b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizationOps.td
@@ -111,7 +111,7 @@ def Bufferization_AllocTensorOp : Bufferization_Op<"alloc_tensor",
     AliasingValueList getAliasingValues(
         OpOperand &opOperand, const AnalysisState &state);
 
-    FailureOr<BaseMemRefType> getBufferType(
+    FailureOr<BufferLikeType> getBufferType(
         Value value, const BufferizationOptions &options,
         const BufferizationState &state,
         SmallVector<Value> &invocationStack);
@@ -478,10 +478,10 @@ def Bufferization_ToTensorOp : Bufferization_Op<"to_tensor", [
 
     bool isWritable(Value value, const AnalysisState &state);
 
-    FailureOr<BaseMemRefType> getBufferType(
+    FailureOr<BufferLikeType> getBufferType(
         Value value, const BufferizationOptions &options,
         const BufferizationState &state, SmallVector<Value> &invocationStack) {
-      return ::llvm::cast<BaseMemRefType>(getBuffer().getType());
+      return getBuffer().getType();
     }
   }];
 
diff --git a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizationTypeInterfaces.h b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizationTypeInterfaces.h
index cbb6054fcf886..da7fee4b4a220 100644
--- a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizationTypeInterfaces.h
+++ b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizationTypeInterfaces.h
@@ -13,6 +13,7 @@
 // Bufferization Type Interfaces
 //===----------------------------------------------------------------------===//
 
+#include "mlir/IR/BuiltinTypes.h"
 #include "mlir/IR/Diagnostics.h"
 #include "mlir/IR/Types.h"
 
diff --git a/mlir/include/mlir/Dialect/Bufferization/IR/UnstructuredControlFlow.h b/mlir/include/mlir/Dialect/Bufferization/IR/UnstructuredControlFlow.h
index f56c10555f02c..e8a81c74bd77a 100644
--- a/mlir/include/mlir/Dialect/Bufferization/IR/UnstructuredControlFlow.h
+++ b/mlir/include/mlir/Dialect/Bufferization/IR/UnstructuredControlFlow.h
@@ -32,7 +32,7 @@ template <typename ConcreteModel, typename ConcreteOp>
 struct OpWithUnstructuredControlFlowBufferizableOpInterfaceExternalModel
     : public BufferizableOpInterface::ExternalModel<ConcreteModel, ConcreteOp> {
 
-  FailureOr<BaseMemRefType>
+  FailureOr<BufferLikeType>
   getBufferType(Operation *op, Value value, const BufferizationOptions &options,
                 const BufferizationState &state,
                 SmallVector<Value> &invocationStack) const {
@@ -110,7 +110,7 @@ struct OpWithUnstructuredControlFlowBufferizableOpInterfaceExternalModel
     if (!bufferType)
       return op->emitOpError("could not infer buffer type of block argument");
 
-    return bufferType;
+    return cast<BufferLikeType>(bufferType);
   }
 
 protected:
diff --git a/mlir/lib/Dialect/Arith/Transforms/BufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/Arith/Transforms/BufferizableOpInterfaceImpl.cpp
index 85d1b5ac73bf4..afee162053bea 100644
--- a/mlir/lib/Dialect/Arith/Transforms/BufferizableOpInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Arith/Transforms/BufferizableOpInterfaceImpl.cpp
@@ -181,7 +181,7 @@ struct SelectOpInterface
     return success();
   }
 
-  FailureOr<BaseMemRefType>
+  FailureOr<BufferLikeType>
   getBufferType(Operation *op, Value value, const BufferizationOptions &options,
                 const BufferizationState &state,
                 SmallVector<Value> &invocationStack) const {
@@ -196,17 +196,17 @@ struct SelectOpInterface
     if (failed(trueType) || failed(falseType))
       return failure();
     if (*trueType == *falseType)
-      return *trueType;
+      return cast<BufferLikeType>(*trueType);
     if (trueType->getMemorySpace() != falseType->getMemorySpace())
       return op->emitError("inconsistent memory space on true/false operands");
 
     // If the buffers have different types, they differ only in their layout
     // map.
     auto memrefType = llvm::cast<MemRefType>(*trueType);
-    return getMemRefTypeWithFullyDynamicLayout(
+    return cast<BufferLikeType>(getMemRefTypeWithFullyDynamicLayout(
         RankedTensorType::get(memrefType.getShape(),
                               memrefType.getElementType()),
-        memrefType.getMemorySpace());
+        memrefType.getMemorySpace()));
   }
 };
 
diff --git a/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp b/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp
index 2ab182c9b7b2e..55784ac20d353 100644
--- a/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp
+++ b/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp
@@ -945,7 +945,7 @@ AliasingOpOperandList bufferization::detail::defaultGetAliasingOpOperands(
   return AliasingOpOperandList(std::move(result));
 }
 
-FailureOr<BaseMemRefType> bufferization::detail::defaultGetBufferType(
+FailureOr<BufferLikeType> bufferization::detail::defaultGetBufferType(
     Value value, const BufferizationOptions &options,
     const BufferizationState &bufferizationState,
     SmallVector<Value> &invocationStack) {
@@ -953,8 +953,10 @@ FailureOr<BaseMemRefType> bufferization::detail::defaultGetBufferType(
   auto tensorType = cast<TensorType>(value.getType());
 
   // No further analysis is possible for a block argument.
-  if (llvm::isa<BlockArgument>(value))
-    return bufferization::getMemRefType(tensorType, options);
+  if (llvm::isa<BlockArgument>(value)) {
+    return cast<BufferLikeType>(
+        bufferization::getMemRefType(tensorType, options));
+  }
 
   // Value is an OpResult.
   Operation *op = getOwnerOfValue(value);
@@ -966,8 +968,8 @@ FailureOr<BaseMemRefType> bufferization::detail::defaultGetBufferType(
     // If the OpResult has an equivalent OpOperand, both OpResult and
     // OpOperand bufferize to the exact same buffer type.
     Value equivalentOperand = aliases.getAliases().front().opOperand->get();
-    return asMemRefType(getBufferType(equivalentOperand, options,
-                                      bufferizationState, invocationStack));
+    return getBufferType(equivalentOperand, options, bufferizationState,
+                         invocationStack);
   }
 
   // If we do not know the memory space and there is no default memory space,
@@ -977,7 +979,8 @@ FailureOr<BaseMemRefType> bufferization::detail::defaultGetBufferType(
   if (!memSpace.has_value())
     return op->emitError("could not infer memory space");
 
-  return getMemRefType(tensorType, options, /*layout=*/{}, *memSpace);
+  return cast<BufferLikeType>(
+      getMemRefType(tensorType, options, /*layout=*/{}, *memSpace));
 }
 
 bool bufferization::detail::defaultIsRepetitiveRegion(
diff --git a/mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp b/mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp
index 9bd87d66c7d36..66949c96798de 100644
--- a/mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp
+++ b/mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp
@@ -222,7 +222,7 @@ AliasingValueList AllocTensorOp::getAliasingValues(OpOperand &opOperand,
   return {};
 }
 
-FailureOr<BaseMemRefType>
+FailureOr<BufferLikeType>
 AllocTensorOp::getBufferType(Value value, const BufferizationOptions &options,
                              const BufferizationState &state,
                              SmallVector<Value> &invocationStack) {
@@ -245,7 +245,8 @@ AllocTensorOp::getBufferType(Value value, const BufferizationOptions &options,
     return getOperation()->emitError("could not infer memory space");
   }
 
-  return getMemRefTypeWithStaticIdentityLayout(getType(), memorySpace);
+  return cast<BufferLikeType>(
+      getMemRefTypeWithStaticIdentityLayout(getType(), memorySpace));
 }
 
 LogicalResult AllocTensorOp::verify() {
diff --git a/mlir/lib/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.cpp
index 453ed43bcadd2..bd2aebca68079 100644
--- a/mlir/lib/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.cpp
@@ -211,7 +211,7 @@ struct CallOpInterface
     return result;
   }
 
-  FailureOr<BaseMemRefType>
+  FailureOr<BufferLikeType>
   getBufferType(Operation *op, Value value, const BufferizationOptions &options,
                 const BufferizationState &state,
                 SmallVector<Value> &invocationStack) const {
@@ -229,12 +229,13 @@ struct CallOpInterface
     Type resultType =
         funcType.getResult(cast<OpResult>(value).getResultNumber());
     if (auto bufferizedType = dyn_cast<BaseMemRefType>(resultType))
-      return bufferizedType;
+      return cast<BufferLikeType>(bufferizedType);
 
     // Otherwise, call the type converter to compute the bufferized type.
     auto tensorType = cast<TensorType>(resultType);
-    return options.functionArgTypeConverterFn(
-        tensorType, *options.defaultMemorySpaceFn(tensorType), funcOp, options);
+    return cast<BufferLikeType>(options.functionArgTypeConverterFn(
+        tensorType, *options.defaultMemorySpaceFn(tensorType), funcOp,
+        options));
   }
 
   /// All function arguments are writable. It is the responsibility of the
@@ -396,7 +397,7 @@ struct FuncOpInterface
     return getAliasingBranchOpOperands(op, cast<BlockArgument>(value), state);
   }
 
-  FailureOr<BaseMemRefType>
+  FailureOr<BufferLikeType>
   getBufferType(Operation *op, Value value, const BufferizationOptions &options,
                 const BufferizationState &state,
                 SmallVector<Value> &invocationStack) const {
@@ -405,8 +406,8 @@ struct FuncOpInterface
 
     // Function arguments are special.
     if (bbArg.getOwner() == &funcOp.getBody().front())
-      return getBufferizedFunctionArgType(funcOp, bbArg.getArgNumber(),
-                                          options);
+      return cast<BufferLikeType>(
+          getBufferizedFunctionArgType(funcOp, bbArg.getArgNumber(), options));
 
     return OpWithUnstructuredControlFlowBufferizableOpInterfaceExternalModel::
         getBufferType(op, value, options, state, invocationStack);
diff --git a/mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp
index 58562536be61f..d36d91249ed36 100644
--- a/mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp
@@ -274,7 +274,7 @@ struct IfOpInterface
     return success();
   }
 
-  FailureOr<BaseMemRefType>
+  FailureOr<BufferLikeType>
   getBufferType(Operation *op, Value value, const BufferizationOptions &options,
                 const BufferizationState &state,
                 SmallVector<Value> &invocationStack) const {
@@ -313,15 +313,15 @@ struct IfOpInterface
 
     // Best case: Both branches have the exact same buffer type.
     if (thenBufferType == elseBufferType)
-      return thenBufferType;
+      return cast<BufferLikeType>(thenBufferType);
 
     // Memory space mismatch.
     if (thenBufferType.getMemorySpace() != elseBufferType.getMemorySpace())
       return op->emitError("inconsistent memory space on then/else branches");
 
     // Layout maps are different: Promote to fully dynamic layout map.
-    return getMemRefTypeWithFullyDynamicLayout(
-        cast<TensorType>(opResult.getType()), thenBufferType.getMemorySpace());
+    return cast<BufferLikeType>(getMemRefTypeWithFullyDynamicLayout(
+        cast<TensorType>(opResult.getType()), thenBufferType.getMemorySpace()));
   }
 };
 
@@ -392,7 +392,7 @@ struct IndexSwitchOpInterface
     return success();
   }
 
-  FailureOr<BaseMemRefType>
+  FailureOr<BufferLikeType>
   getBufferType(Operation *op, Value value, const BufferizationOptions &options,
                 const BufferizationState &state,
                 SmallVector<Value> &invocationStack) const {
@@ -436,7 +436,7 @@ struct IndexSwitchOpInterface
           cast<TensorType>(value.getType()), bufferType.getMemorySpace());
     }
 
-    return bufferType;
+    return cast<BufferLikeType>(bufferType);
   }
 };
 
@@ -522,13 +522,13 @@ getBbArgReplacements(RewriterBase &rewriter, Block::BlockArgListType bbArgs,
 /// If both buffer types are equal, no casts are needed the computed buffer type
 /// can be used directly. Otherwise, the buffer types can only differ in their
 /// layout map and a cast must be inserted.
-static FailureOr<BaseMemRefType> computeLoopRegionIterArgBufferType(
+static FailureOr<BufferLikeType> computeLoopRegionIterArgBufferType(
     Operation *loopOp, BlockArgument iterArg, Value initArg, Value yieldedValue,
     const BufferizationOptions &options, const BufferizationState &state,
     SmallVector<Value> &invocationStack) {
   // Determine the buffer type of the init_arg.
-  auto initArgBufferType = bufferization::detail::asMemRefType(
-      bufferization::getBufferType(initArg, options, state, invocationStack));
+  auto initArgBufferType =
+      bufferization::getBufferType(initArg, options, state, invocationStack);
   if (failed(initArgBufferType))
     return failure();
 
@@ -547,16 +547,15 @@ static FailureOr<BaseMemRefType> computeLoopRegionIterArgBufferType(
   }
 
   // Compute the buffer type of the yielded value.
-  BaseMemRefType yieldedValueBufferType;
+  BufferLikeType yieldedValueBufferType;
   if (isa<BaseMemRefType>(yieldedValue.getType())) {
     // scf.yield was already bufferized.
-    yieldedValueBufferType = cast<BaseMemRefType>(yieldedValue.getType());
+    yieldedValueBufferType = cast<BufferLikeType>(yieldedValue.getType());
   } else {
     // Note: This typically triggers a recursive call for the buffer type of
     // the iter_arg.
-    auto maybeBufferType =
-        bufferization::detail::asMemRefType(bufferization::getBufferType(
-            yieldedValue, options, state, invocationStack));
+    auto maybeBufferType = bufferization::getBufferType(yieldedValue, options,
+                                                        state, invocationStack);
     if (failed(maybeBufferType))
       return failure();
     yieldedValueBufferType = *maybeBufferType;
@@ -584,8 +583,8 @@ static FailureOr<BaseMemRefType> computeLoopRegionIterArgBufferType(
         "expected same shape");
   }
 #endif // NDEBUG
-  return getMemRefTypeWithFullyDynamicLayout(
-      iterTensorType, yieldedBufferType.getMemorySpace());
+  return cast<BufferLikeType>(getMemRefTypeWithFullyDynamicLayout(
+      iterTensorType, yieldedBufferType.getMemorySpace()));
 }
 
 /// Return `true` if the given loop may have 0 iterations.
@@ -708,7 +707,7 @@ struct ForOpInterface
     return success();
   }
 
-  FailureOr<BaseMemRefType>
+  FailureOr<BufferLikeType>
   getBufferType(Operation *op, Value value, const BufferizationOptions &options,
                 const BufferizationState &state,
                 SmallVector<Value> &invocationStack) const {
@@ -719,12 +718,8 @@ struct ForOpInterface
     if (auto opResult = dyn_cast<OpResult>(value)) {
       // The type of an OpResult must match the corresponding iter_arg type.
       BlockArgument bbArg = forOp.getTiedLoopRegionIterArg(opResult);
-      auto bufferType =
-          bufferization::getBufferType(bbArg, options, state, invocationStack);
-      if (failed(bufferType))
-        return failure();
-      assert(isa<BaseMemRefType>(*bufferType) && "expected memref type");
-      return cast<BaseMemRefType>(*bufferType);
+      return bufferization::getBufferType(bbArg, options, state,
+                                          invocationStack);
     }
 
     // Compute result/argument number.
@@ -1047,7 +1042,7 @@ struct WhileOpInterface
     return success();
   }
 
-  FailureOr<BaseMemRefType>
+  FailureOr<BufferLikeType>
   getBufferType(Operation *op, Value value, const BufferizationOptions &options,
                 const BufferizationState &state,
                 SmallVector<Value> &invocationStack) const {
@@ -1081,10 +1076,10 @@ struct WhileOpInterface
     Value conditionYieldedVal = whileOp.getConditionOp().getArgs()[resultNum];
     if (!isa<TensorType>(conditionYieldedVal.getType())) {
       // scf.condition was already bufferized.
-      return cast<BaseMemRefType>(conditionYieldedVal.getType());
+      return cast<BufferLikeType>(conditionYieldedVal.getType());
     }
-    return bufferization::detail::asMemRefType(bufferization::getBufferType(
-        conditionYieldedVal, options, state, invocationStack));
+    return bufferization::getBufferType(conditionYieldedVal, options, state,
+                                        invocationStack);
   }
 
   /// Assert that yielded values of an scf.while op are equivalent to their
@@ -1303,7 +1298,7 @@ struct ForallOpInterface
     return success();
   }
 
-  FailureOr<BaseMemRefType>
+  FailureOr<BufferLikeType>
   getBufferType(Operation *op, Value value, const BufferizationOptions &options,
                 const BufferizationState &state,
                 SmallVector<Value> &invocationStack) const {
@@ -1312,15 +1307,15 @@ struct ForallOpInterface
     if (auto bbArg = dyn_cast<BlockArgument>(value))
       // A tensor block argument has the same bufferized type as the
       // corresponding output operand.
-      return bufferization::detail::asMemRefType(
-          bufferization::getBufferType(forallOp.getTiedOpOperand(bbArg)->get(),
-                                       options, state, invocationStack));
+      return bufferization::getBufferType(
+          forallOp.getTiedOpOperand(bbArg)->get(), options, state,
+          invocationStack);
 
     // The bufferized result type is the same as the bufferized type of the
     // corresponding output operand.
-    return bufferization::detail::asMemRefType(bufferization::getBufferType(
+    return bufferization::getBufferType(
         forallOp.getOutputs()[cast<OpResult>(value).getResultNumber()], options,
-        state, invocationStack));
+        state, invocationStack);
   }
 
   bool isRepetitiveRegion(Operation *op, unsigned index) const {
diff --git a/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp
index 729c048db4560..829b2ab92ac24 100644
--- a/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp
@@ -49,7 +49,7 @@ struct CastOpInterface
     return {{op->getResult(0), BufferRelation::Equivalent}};
   }
 
-  FailureOr<BaseMemRefType>
+  FailureOr<BufferLikeType>
   getBufferType(Opera...
[truncated]

llvmbot avatar Jun 19 '25 10:06 llvmbot

@llvm/pr-subscribers-mlir-bufferization

Author: Andrei Golubev (andrey-golubev)

Changes

Support custom types (2/N): allow value-owning operations (e.g. allocation ops) to bufferize custom tensors into custom buffers. This requires BufferizableOpInterface::getBufferType() to return BufferLikeType instead of BaseMemRefType.

Affected implementors of the interface are updated accordingly.


Patch is 32.25 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/144867.diff

14 Files Affected:

  • (modified) mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h (+1-1)
  • (modified) mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.td (+1-1)
  • (modified) mlir/include/mlir/Dialect/Bufferization/IR/BufferizationOps.td (+3-3)
  • (modified) mlir/include/mlir/Dialect/Bufferization/IR/BufferizationTypeInterfaces.h (+1)
  • (modified) mlir/include/mlir/Dialect/Bufferization/IR/UnstructuredControlFlow.h (+2-2)
  • (modified) mlir/lib/Dialect/Arith/Transforms/BufferizableOpInterfaceImpl.cpp (+4-4)
  • (modified) mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp (+9-6)
  • (modified) mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp (+3-2)
  • (modified) mlir/lib/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.cpp (+8-7)
  • (modified) mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp (+28-33)
  • (modified) mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp (+28-24)
  • (modified) mlir/test/Dialect/Bufferization/Transforms/one-shot-bufferize.mlir (+21-2)
  • (modified) mlir/test/lib/Dialect/Test/TestOpDefs.cpp (+34)
  • (modified) mlir/test/lib/Dialect/Test/TestOps.td (+53)
diff --git a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h
index c1529a36465ac..6245f88db3d19 100644
--- a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h
+++ b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h
@@ -712,7 +712,7 @@ AliasingOpOperandList defaultGetAliasingOpOperands(Value value,
 /// This is the default implementation of
 /// BufferizableOpInterface::getBufferType. Should not be called from other
 /// places.
-FailureOr<BaseMemRefType>
+FailureOr<BufferLikeType>
 defaultGetBufferType(Value value, const BufferizationOptions &options,
                      const BufferizationState &state,
                      SmallVector<Value> &invocationStack);
diff --git a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.td b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.td
index cafe05fe5f189..246ae77f327cf 100644
--- a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.td
+++ b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.td
@@ -525,7 +525,7 @@ def BufferizableOpInterface : OpInterface<"BufferizableOpInterface"> {
           Note: This interface method should never be called directly from user
           code. Always use `bufferization::getBufferType`.
         }],
-        /*retType=*/"::mlir::FailureOr<::mlir::BaseMemRefType>",
+        /*retType=*/"::mlir::FailureOr<::mlir::bufferization::BufferLikeType>",
         /*methodName=*/"getBufferType",
         /*args=*/(ins "::mlir::Value":$value,
                       "const ::mlir::bufferization::BufferizationOptions &":$options,
diff --git a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizationOps.td b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizationOps.td
index 32c53ea9c494a..f175b15c8770f 100644
--- a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizationOps.td
+++ b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizationOps.td
@@ -111,7 +111,7 @@ def Bufferization_AllocTensorOp : Bufferization_Op<"alloc_tensor",
     AliasingValueList getAliasingValues(
         OpOperand &opOperand, const AnalysisState &state);
 
-    FailureOr<BaseMemRefType> getBufferType(
+    FailureOr<BufferLikeType> getBufferType(
         Value value, const BufferizationOptions &options,
         const BufferizationState &state,
         SmallVector<Value> &invocationStack);
@@ -478,10 +478,10 @@ def Bufferization_ToTensorOp : Bufferization_Op<"to_tensor", [
 
     bool isWritable(Value value, const AnalysisState &state);
 
-    FailureOr<BaseMemRefType> getBufferType(
+    FailureOr<BufferLikeType> getBufferType(
         Value value, const BufferizationOptions &options,
         const BufferizationState &state, SmallVector<Value> &invocationStack) {
-      return ::llvm::cast<BaseMemRefType>(getBuffer().getType());
+      return getBuffer().getType();
     }
   }];
 
diff --git a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizationTypeInterfaces.h b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizationTypeInterfaces.h
index cbb6054fcf886..da7fee4b4a220 100644
--- a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizationTypeInterfaces.h
+++ b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizationTypeInterfaces.h
@@ -13,6 +13,7 @@
 // Bufferization Type Interfaces
 //===----------------------------------------------------------------------===//
 
+#include "mlir/IR/BuiltinTypes.h"
 #include "mlir/IR/Diagnostics.h"
 #include "mlir/IR/Types.h"
 
diff --git a/mlir/include/mlir/Dialect/Bufferization/IR/UnstructuredControlFlow.h b/mlir/include/mlir/Dialect/Bufferization/IR/UnstructuredControlFlow.h
index f56c10555f02c..e8a81c74bd77a 100644
--- a/mlir/include/mlir/Dialect/Bufferization/IR/UnstructuredControlFlow.h
+++ b/mlir/include/mlir/Dialect/Bufferization/IR/UnstructuredControlFlow.h
@@ -32,7 +32,7 @@ template <typename ConcreteModel, typename ConcreteOp>
 struct OpWithUnstructuredControlFlowBufferizableOpInterfaceExternalModel
     : public BufferizableOpInterface::ExternalModel<ConcreteModel, ConcreteOp> {
 
-  FailureOr<BaseMemRefType>
+  FailureOr<BufferLikeType>
   getBufferType(Operation *op, Value value, const BufferizationOptions &options,
                 const BufferizationState &state,
                 SmallVector<Value> &invocationStack) const {
@@ -110,7 +110,7 @@ struct OpWithUnstructuredControlFlowBufferizableOpInterfaceExternalModel
     if (!bufferType)
       return op->emitOpError("could not infer buffer type of block argument");
 
-    return bufferType;
+    return cast<BufferLikeType>(bufferType);
   }
 
 protected:
diff --git a/mlir/lib/Dialect/Arith/Transforms/BufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/Arith/Transforms/BufferizableOpInterfaceImpl.cpp
index 85d1b5ac73bf4..afee162053bea 100644
--- a/mlir/lib/Dialect/Arith/Transforms/BufferizableOpInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Arith/Transforms/BufferizableOpInterfaceImpl.cpp
@@ -181,7 +181,7 @@ struct SelectOpInterface
     return success();
   }
 
-  FailureOr<BaseMemRefType>
+  FailureOr<BufferLikeType>
   getBufferType(Operation *op, Value value, const BufferizationOptions &options,
                 const BufferizationState &state,
                 SmallVector<Value> &invocationStack) const {
@@ -196,17 +196,17 @@ struct SelectOpInterface
     if (failed(trueType) || failed(falseType))
       return failure();
     if (*trueType == *falseType)
-      return *trueType;
+      return cast<BufferLikeType>(*trueType);
     if (trueType->getMemorySpace() != falseType->getMemorySpace())
       return op->emitError("inconsistent memory space on true/false operands");
 
     // If the buffers have different types, they differ only in their layout
     // map.
     auto memrefType = llvm::cast<MemRefType>(*trueType);
-    return getMemRefTypeWithFullyDynamicLayout(
+    return cast<BufferLikeType>(getMemRefTypeWithFullyDynamicLayout(
         RankedTensorType::get(memrefType.getShape(),
                               memrefType.getElementType()),
-        memrefType.getMemorySpace());
+        memrefType.getMemorySpace()));
   }
 };
 
diff --git a/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp b/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp
index 2ab182c9b7b2e..55784ac20d353 100644
--- a/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp
+++ b/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp
@@ -945,7 +945,7 @@ AliasingOpOperandList bufferization::detail::defaultGetAliasingOpOperands(
   return AliasingOpOperandList(std::move(result));
 }
 
-FailureOr<BaseMemRefType> bufferization::detail::defaultGetBufferType(
+FailureOr<BufferLikeType> bufferization::detail::defaultGetBufferType(
     Value value, const BufferizationOptions &options,
     const BufferizationState &bufferizationState,
     SmallVector<Value> &invocationStack) {
@@ -953,8 +953,10 @@ FailureOr<BaseMemRefType> bufferization::detail::defaultGetBufferType(
   auto tensorType = cast<TensorType>(value.getType());
 
   // No further analysis is possible for a block argument.
-  if (llvm::isa<BlockArgument>(value))
-    return bufferization::getMemRefType(tensorType, options);
+  if (llvm::isa<BlockArgument>(value)) {
+    return cast<BufferLikeType>(
+        bufferization::getMemRefType(tensorType, options));
+  }
 
   // Value is an OpResult.
   Operation *op = getOwnerOfValue(value);
@@ -966,8 +968,8 @@ FailureOr<BaseMemRefType> bufferization::detail::defaultGetBufferType(
     // If the OpResult has an equivalent OpOperand, both OpResult and
     // OpOperand bufferize to the exact same buffer type.
     Value equivalentOperand = aliases.getAliases().front().opOperand->get();
-    return asMemRefType(getBufferType(equivalentOperand, options,
-                                      bufferizationState, invocationStack));
+    return getBufferType(equivalentOperand, options, bufferizationState,
+                         invocationStack);
   }
 
   // If we do not know the memory space and there is no default memory space,
@@ -977,7 +979,8 @@ FailureOr<BaseMemRefType> bufferization::detail::defaultGetBufferType(
   if (!memSpace.has_value())
     return op->emitError("could not infer memory space");
 
-  return getMemRefType(tensorType, options, /*layout=*/{}, *memSpace);
+  return cast<BufferLikeType>(
+      getMemRefType(tensorType, options, /*layout=*/{}, *memSpace));
 }
 
 bool bufferization::detail::defaultIsRepetitiveRegion(
diff --git a/mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp b/mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp
index 9bd87d66c7d36..66949c96798de 100644
--- a/mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp
+++ b/mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp
@@ -222,7 +222,7 @@ AliasingValueList AllocTensorOp::getAliasingValues(OpOperand &opOperand,
   return {};
 }
 
-FailureOr<BaseMemRefType>
+FailureOr<BufferLikeType>
 AllocTensorOp::getBufferType(Value value, const BufferizationOptions &options,
                              const BufferizationState &state,
                              SmallVector<Value> &invocationStack) {
@@ -245,7 +245,8 @@ AllocTensorOp::getBufferType(Value value, const BufferizationOptions &options,
     return getOperation()->emitError("could not infer memory space");
   }
 
-  return getMemRefTypeWithStaticIdentityLayout(getType(), memorySpace);
+  return cast<BufferLikeType>(
+      getMemRefTypeWithStaticIdentityLayout(getType(), memorySpace));
 }
 
 LogicalResult AllocTensorOp::verify() {
diff --git a/mlir/lib/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.cpp
index 453ed43bcadd2..bd2aebca68079 100644
--- a/mlir/lib/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.cpp
@@ -211,7 +211,7 @@ struct CallOpInterface
     return result;
   }
 
-  FailureOr<BaseMemRefType>
+  FailureOr<BufferLikeType>
   getBufferType(Operation *op, Value value, const BufferizationOptions &options,
                 const BufferizationState &state,
                 SmallVector<Value> &invocationStack) const {
@@ -229,12 +229,13 @@ struct CallOpInterface
     Type resultType =
         funcType.getResult(cast<OpResult>(value).getResultNumber());
     if (auto bufferizedType = dyn_cast<BaseMemRefType>(resultType))
-      return bufferizedType;
+      return cast<BufferLikeType>(bufferizedType);
 
     // Otherwise, call the type converter to compute the bufferized type.
     auto tensorType = cast<TensorType>(resultType);
-    return options.functionArgTypeConverterFn(
-        tensorType, *options.defaultMemorySpaceFn(tensorType), funcOp, options);
+    return cast<BufferLikeType>(options.functionArgTypeConverterFn(
+        tensorType, *options.defaultMemorySpaceFn(tensorType), funcOp,
+        options));
   }
 
   /// All function arguments are writable. It is the responsibility of the
@@ -396,7 +397,7 @@ struct FuncOpInterface
     return getAliasingBranchOpOperands(op, cast<BlockArgument>(value), state);
   }
 
-  FailureOr<BaseMemRefType>
+  FailureOr<BufferLikeType>
   getBufferType(Operation *op, Value value, const BufferizationOptions &options,
                 const BufferizationState &state,
                 SmallVector<Value> &invocationStack) const {
@@ -405,8 +406,8 @@ struct FuncOpInterface
 
     // Function arguments are special.
     if (bbArg.getOwner() == &funcOp.getBody().front())
-      return getBufferizedFunctionArgType(funcOp, bbArg.getArgNumber(),
-                                          options);
+      return cast<BufferLikeType>(
+          getBufferizedFunctionArgType(funcOp, bbArg.getArgNumber(), options));
 
     return OpWithUnstructuredControlFlowBufferizableOpInterfaceExternalModel::
         getBufferType(op, value, options, state, invocationStack);
diff --git a/mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp
index 58562536be61f..d36d91249ed36 100644
--- a/mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp
@@ -274,7 +274,7 @@ struct IfOpInterface
     return success();
   }
 
-  FailureOr<BaseMemRefType>
+  FailureOr<BufferLikeType>
   getBufferType(Operation *op, Value value, const BufferizationOptions &options,
                 const BufferizationState &state,
                 SmallVector<Value> &invocationStack) const {
@@ -313,15 +313,15 @@ struct IfOpInterface
 
     // Best case: Both branches have the exact same buffer type.
     if (thenBufferType == elseBufferType)
-      return thenBufferType;
+      return cast<BufferLikeType>(thenBufferType);
 
     // Memory space mismatch.
     if (thenBufferType.getMemorySpace() != elseBufferType.getMemorySpace())
       return op->emitError("inconsistent memory space on then/else branches");
 
     // Layout maps are different: Promote to fully dynamic layout map.
-    return getMemRefTypeWithFullyDynamicLayout(
-        cast<TensorType>(opResult.getType()), thenBufferType.getMemorySpace());
+    return cast<BufferLikeType>(getMemRefTypeWithFullyDynamicLayout(
+        cast<TensorType>(opResult.getType()), thenBufferType.getMemorySpace()));
   }
 };
 
@@ -392,7 +392,7 @@ struct IndexSwitchOpInterface
     return success();
   }
 
-  FailureOr<BaseMemRefType>
+  FailureOr<BufferLikeType>
   getBufferType(Operation *op, Value value, const BufferizationOptions &options,
                 const BufferizationState &state,
                 SmallVector<Value> &invocationStack) const {
@@ -436,7 +436,7 @@ struct IndexSwitchOpInterface
           cast<TensorType>(value.getType()), bufferType.getMemorySpace());
     }
 
-    return bufferType;
+    return cast<BufferLikeType>(bufferType);
   }
 };
 
@@ -522,13 +522,13 @@ getBbArgReplacements(RewriterBase &rewriter, Block::BlockArgListType bbArgs,
 /// If both buffer types are equal, no casts are needed the computed buffer type
 /// can be used directly. Otherwise, the buffer types can only differ in their
 /// layout map and a cast must be inserted.
-static FailureOr<BaseMemRefType> computeLoopRegionIterArgBufferType(
+static FailureOr<BufferLikeType> computeLoopRegionIterArgBufferType(
     Operation *loopOp, BlockArgument iterArg, Value initArg, Value yieldedValue,
     const BufferizationOptions &options, const BufferizationState &state,
     SmallVector<Value> &invocationStack) {
   // Determine the buffer type of the init_arg.
-  auto initArgBufferType = bufferization::detail::asMemRefType(
-      bufferization::getBufferType(initArg, options, state, invocationStack));
+  auto initArgBufferType =
+      bufferization::getBufferType(initArg, options, state, invocationStack);
   if (failed(initArgBufferType))
     return failure();
 
@@ -547,16 +547,15 @@ static FailureOr<BaseMemRefType> computeLoopRegionIterArgBufferType(
   }
 
   // Compute the buffer type of the yielded value.
-  BaseMemRefType yieldedValueBufferType;
+  BufferLikeType yieldedValueBufferType;
   if (isa<BaseMemRefType>(yieldedValue.getType())) {
     // scf.yield was already bufferized.
-    yieldedValueBufferType = cast<BaseMemRefType>(yieldedValue.getType());
+    yieldedValueBufferType = cast<BufferLikeType>(yieldedValue.getType());
   } else {
     // Note: This typically triggers a recursive call for the buffer type of
     // the iter_arg.
-    auto maybeBufferType =
-        bufferization::detail::asMemRefType(bufferization::getBufferType(
-            yieldedValue, options, state, invocationStack));
+    auto maybeBufferType = bufferization::getBufferType(yieldedValue, options,
+                                                        state, invocationStack);
     if (failed(maybeBufferType))
       return failure();
     yieldedValueBufferType = *maybeBufferType;
@@ -584,8 +583,8 @@ static FailureOr<BaseMemRefType> computeLoopRegionIterArgBufferType(
         "expected same shape");
   }
 #endif // NDEBUG
-  return getMemRefTypeWithFullyDynamicLayout(
-      iterTensorType, yieldedBufferType.getMemorySpace());
+  return cast<BufferLikeType>(getMemRefTypeWithFullyDynamicLayout(
+      iterTensorType, yieldedBufferType.getMemorySpace()));
 }
 
 /// Return `true` if the given loop may have 0 iterations.
@@ -708,7 +707,7 @@ struct ForOpInterface
     return success();
   }
 
-  FailureOr<BaseMemRefType>
+  FailureOr<BufferLikeType>
   getBufferType(Operation *op, Value value, const BufferizationOptions &options,
                 const BufferizationState &state,
                 SmallVector<Value> &invocationStack) const {
@@ -719,12 +718,8 @@ struct ForOpInterface
     if (auto opResult = dyn_cast<OpResult>(value)) {
       // The type of an OpResult must match the corresponding iter_arg type.
       BlockArgument bbArg = forOp.getTiedLoopRegionIterArg(opResult);
-      auto bufferType =
-          bufferization::getBufferType(bbArg, options, state, invocationStack);
-      if (failed(bufferType))
-        return failure();
-      assert(isa<BaseMemRefType>(*bufferType) && "expected memref type");
-      return cast<BaseMemRefType>(*bufferType);
+      return bufferization::getBufferType(bbArg, options, state,
+                                          invocationStack);
     }
 
     // Compute result/argument number.
@@ -1047,7 +1042,7 @@ struct WhileOpInterface
     return success();
   }
 
-  FailureOr<BaseMemRefType>
+  FailureOr<BufferLikeType>
   getBufferType(Operation *op, Value value, const BufferizationOptions &options,
                 const BufferizationState &state,
                 SmallVector<Value> &invocationStack) const {
@@ -1081,10 +1076,10 @@ struct WhileOpInterface
     Value conditionYieldedVal = whileOp.getConditionOp().getArgs()[resultNum];
     if (!isa<TensorType>(conditionYieldedVal.getType())) {
       // scf.condition was already bufferized.
-      return cast<BaseMemRefType>(conditionYieldedVal.getType());
+      return cast<BufferLikeType>(conditionYieldedVal.getType());
     }
-    return bufferization::detail::asMemRefType(bufferization::getBufferType(
-        conditionYieldedVal, options, state, invocationStack));
+    return bufferization::getBufferType(conditionYieldedVal, options, state,
+                                        invocationStack);
   }
 
   /// Assert that yielded values of an scf.while op are equivalent to their
@@ -1303,7 +1298,7 @@ struct ForallOpInterface
     return success();
   }
 
-  FailureOr<BaseMemRefType>
+  FailureOr<BufferLikeType>
   getBufferType(Operation *op, Value value, const BufferizationOptions &options,
                 const BufferizationState &state,
                 SmallVector<Value> &invocationStack) const {
@@ -1312,15 +1307,15 @@ struct ForallOpInterface
     if (auto bbArg = dyn_cast<BlockArgument>(value))
       // A tensor block argument has the same bufferized type as the
       // corresponding output operand.
-      return bufferization::detail::asMemRefType(
-          bufferization::getBufferType(forallOp.getTiedOpOperand(bbArg)->get(),
-                                       options, state, invocationStack));
+      return bufferization::getBufferType(
+          forallOp.getTiedOpOperand(bbArg)->get(), options, state,
+          invocationStack);
 
     // The bufferized result type is the same as the bufferized type of the
     // corresponding output operand.
-    return bufferization::detail::asMemRefType(bufferization::getBufferType(
+    return bufferization::getBufferType(
         forallOp.getOutputs()[cast<OpResult>(value).getResultNumber()], options,
-        state, invocationStack));
+        state, invocationStack);
   }
 
   bool isRepetitiveRegion(Operation *op, unsigned index) const {
diff --git a/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp
index 729c048db4560..829b2ab92ac24 100644
--- a/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp
@@ -49,7 +49,7 @@ struct CastOpInterface
     return {{op->getResult(0), BufferRelation::Equivalent}};
   }
 
-  FailureOr<BaseMemRefType>
+  FailureOr<BufferLikeType>
   getBufferType(Opera...
[truncated]

llvmbot avatar Jun 19 '25 10:06 llvmbot

@matthias-springer (gently pinging)

andrey-golubev avatar Jun 30 '25 09:06 andrey-golubev

I've applied the suggestions and also rebased (it's been a while since this patch was originally implemented). If there are no issues, could you merge it? (I lack commit rights) - Let's wait until CI finishes though.

andrey-golubev avatar Jun 30 '25 11:06 andrey-golubev

Can we wait a few more days? The IREE folks are still integrating the last bufferization changes. They are one of the few projects that use the bufferization extensively and are open-source, so I use them as an indicator to see if there are problems with larger refactorings.

matthias-springer avatar Jun 30 '25 13:06 matthias-springer

Can we wait a few more days? The IREE folks are still integrating the last bufferization changes. They are one of the few projects that use the bufferization extensively and are open-source, so I use them as an indicator to see if there are problems with larger refactorings.

No problem. Please let me know if there are any issues.

andrey-golubev avatar Jun 30 '25 15:06 andrey-golubev

@hanhanW I guess you might be interested to review this one (since you seem to be quite involved with bufferization right now).

andrey-golubev avatar Jul 01 '25 08:07 andrey-golubev

Can we wait a few more days? The IREE folks are still integrating the last bufferization changes. They are one of the few projects that use the bufferization extensively and are open-source, so I use them as an indicator to see if there are problems with larger refactorings.

I'm merging the PR because the concern is addressed. IREE does not carry any bufferization revert now. I'm happy to help if there are issues in our integrate.

hanhanW avatar Jul 02 '25 18:07 hanhanW

LLVM Buildbot has detected a new failure on builder flang-arm64-windows-msvc running on linaro-armv8-windows-msvc-01 while building mlir at step 7 "test-build-unified-tree-check-flang".

Full details are available at: https://lab.llvm.org/buildbot/#/builders/207/builds/3273

Here is the relevant piece of the build log for reference:
Step 7 (test-build-unified-tree-check-flang) failure: test (failure)
******************** TEST 'Flang :: Semantics/OpenMP/loop-association.f90' FAILED ********************
Exit Code: 1

Command Output (stdout):
--
# RUN: at line 1
"C:\Users\tcwg\scoop\apps\python\current\python.exe" C:\Users\tcwg\llvm-worker\flang-arm64-windows-msvc\llvm-project\flang\test\Semantics\OpenMP/../test_errors.py C:\Users\tcwg\llvm-worker\flang-arm64-windows-msvc\llvm-project\flang\test\Semantics\OpenMP\loop-association.f90 c:\users\tcwg\llvm-worker\flang-arm64-windows-msvc\build\bin\flang.exe -fopenmp
# executed command: 'C:\Users\tcwg\scoop\apps\python\current\python.exe' 'C:\Users\tcwg\llvm-worker\flang-arm64-windows-msvc\llvm-project\flang\test\Semantics\OpenMP/../test_errors.py' 'C:\Users\tcwg\llvm-worker\flang-arm64-windows-msvc\llvm-project\flang\test\Semantics\OpenMP\loop-association.f90' 'c:\users\tcwg\llvm-worker\flang-arm64-windows-msvc\build\bin\flang.exe' -fopenmp
# .---command stdout------------
# | PLEASE submit a bug report to https://github.com/llvm/llvm-project/issues/ and include the crash backtrace.
# | Stack dump:
# | 0.	Program arguments: C:\\Users\\tcwg\\llvm-worker\\flang-arm64-windows-msvc\\build\\bin\\flang -fc1 -triple aarch64-pc-windows-msvc19.39.33523 -fsyntax-only -mrelocation-model pic -pic-level 2 -target-cpu generic -target-feature +v8a -target-feature +fp-armv8 -target-feature +neon --dependent-lib=clang_rt.builtins-aarch64.lib -D_MT --dependent-lib=libcmt --dependent-lib=flang_rt.runtime.static.lib -D_MSC_VER=1939 -D_MSC_FULL_VER=193933523 -D_WIN32 -D_M_ARM64=1 -fopenmp -resource-dir C:\\Users\\tcwg\\llvm-worker\\flang-arm64-windows-msvc\\build\\lib\\clang\\21 -mframe-pointer=none -x f95 C:\\Users\\tcwg\\llvm-worker\\flang-arm64-windows-msvc\\llvm-project\\flang\\test\\Semantics\\OpenMP\\loop-association.f90
# | Exception Code: 0xC0000005
# |  #0 0x00007ff66ac32e34 (C:\Users\tcwg\llvm-worker\flang-arm64-windows-msvc\build\bin\flang.exe+0x1a92e34)
# |  #1 0x00007ff66ac30b1c (C:\Users\tcwg\llvm-worker\flang-arm64-windows-msvc\build\bin\flang.exe+0x1a90b1c)
# |  #2 0x00007ff66ac32350 (C:\Users\tcwg\llvm-worker\flang-arm64-windows-msvc\build\bin\flang.exe+0x1a92350)
# |  #3 0x00007ff66ac31a58 (C:\Users\tcwg\llvm-worker\flang-arm64-windows-msvc\build\bin\flang.exe+0x1a91a58)
# |  #4 0x00007ff66ac30fb8 (C:\Users\tcwg\llvm-worker\flang-arm64-windows-msvc\build\bin\flang.exe+0x1a90fb8)
# |  #5 0x00007ff66ac34934 (C:\Users\tcwg\llvm-worker\flang-arm64-windows-msvc\build\bin\flang.exe+0x1a94934)
# |  #6 0x00007ff66ac2e518 (C:\Users\tcwg\llvm-worker\flang-arm64-windows-msvc\build\bin\flang.exe+0x1a8e518)
# |  #7 0x00007ff669a4cbdc (C:\Users\tcwg\llvm-worker\flang-arm64-windows-msvc\build\bin\flang.exe+0x8acbdc)
# |  #8 0x00007ff669a9603c (C:\Users\tcwg\llvm-worker\flang-arm64-windows-msvc\build\bin\flang.exe+0x8f603c)
# |  #9 0x00007ff669b3eef8 (C:\Users\tcwg\llvm-worker\flang-arm64-windows-msvc\build\bin\flang.exe+0x99eef8)
# | #10 0x00007ff669a95478 (C:\Users\tcwg\llvm-worker\flang-arm64-windows-msvc\build\bin\flang.exe+0x8f5478)
# | #11 0x00007ff669238ac0 (C:\Users\tcwg\llvm-worker\flang-arm64-windows-msvc\build\bin\flang.exe+0x98ac0)
# | #12 0x00007ff66924ea88 (C:\Users\tcwg\llvm-worker\flang-arm64-windows-msvc\build\bin\flang.exe+0xaea88)
# | #13 0x00007ff6691a34e8 (C:\Users\tcwg\llvm-worker\flang-arm64-windows-msvc\build\bin\flang.exe+0x34e8)
# | #14 0x00007ff6691a20a4 (C:\Users\tcwg\llvm-worker\flang-arm64-windows-msvc\build\bin\flang.exe+0x20a4)
# | #15 0x00007ff66de5dec8 mlir::detail::FallbackTypeIDResolver::registerImplicitTypeID(class llvm::StringRef) (C:\Users\tcwg\llvm-worker\flang-arm64-windows-msvc\build\bin\flang.exe+0x4cbdec8)
# | #16 0xb93dfff66de5df64
# | flang: error: flang frontend command failed due to signal (use -v to see invocation)
# | flang version 21.0.0git (https://github.com/llvm/llvm-project.git a63f57262898588b576d66e5fd79c0aa64b35f2d)
# | Target: aarch64-pc-windows-msvc
# | Thread model: posix
# | InstalledDir: C:\Users\tcwg\llvm-worker\flang-arm64-windows-msvc\build\bin
# | Build config: +assertions
# | flang: note: diagnostic msg: 
# | ********************
# | 
# | PLEASE ATTACH THE FOLLOWING FILES TO THE BUG REPORT:
# | Preprocessed source(s) and associated run script(s) are located at:
# | flang: note: diagnostic msg: C:\Users\tcwg\AppData\Local\Temp\lit-tmp-ppdfp7wy\loop-association-385c17
# | flang: note: diagnostic msg: C:\Users\tcwg\AppData\Local\Temp\lit-tmp-ppdfp7wy\loop-association-385c17.sh
# | flang: note: diagnostic msg: 
# | 
# | ********************
# | 
# `-----------------------------
# error: command failed with exit status: 1

...

llvm-ci avatar Jul 02 '25 18:07 llvm-ci