[MLIR][Linalg] Add aggregate ops decomposition pass and softmax decomposition implementation
- Add a decomposition pass that handles complex aggregate ops (e.g., softmax), replacing them with a sequence of non-aggregate linalg named ops. Implementation for softmax follows the lowering semantics of popular frameworks like PyTorch, TensorFlow, and others.
- Make the AggregatedOpInterface return a DecompositionResult, similar to the tiling interface. This is to communicate the decomposition sequence nicely (e.g., useful for transform dialect, see below).
- Rework DecomposeInterfaceOp to return variadic results and use the new decomposition. This removes code duplication between the generalization pass and decomposition implementation - now aggregate ops are decomposed first and then generalized.
Author: Petr Kurapov (kurapov-peter)
Changes
Patch is 35.95 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/97582.diff
11 Files Affected:
- (modified) mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.h (+10)
- (modified) mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td (+1-1)
- (modified) mlir/include/mlir/Dialect/Linalg/Passes.td (+5)
- (modified) mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td (+7-11)
- (modified) mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h (+5)
- (modified) mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp (+82-129)
- (modified) mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp (+15-19)
- (modified) mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt (+1)
- (added) mlir/lib/Dialect/Linalg/Transforms/DecomposeAggregateNamedLinalgOps.cpp (+62)
- (added) mlir/test/Dialect/Linalg/decompose-named-ops.mlir (+107)
- (modified) mlir/test/Dialect/Linalg/transform-op-decompose.mlir (+44-10)
diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.h b/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.h
index 08afdf373f014..3858075fae137 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.h
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.h
@@ -30,6 +30,16 @@ class IteratorTypeAttr;
class LinalgOp;
class GenericOp;
+/// Container for result values of decomposition.
+/// - `decomposedOps` contains operations created by the decomposition that are
+/// returned to the caller for further transformations.
+/// - `decomposedValues` contains the values corresponding to the result of the
+/// aggregate operation.
+struct DecompositionResult {
+ SmallVector<Operation *> decomposedOps;
+ SmallVector<Value> decomposedValues;
+};
+
namespace detail {
/// Implementation of the method that check if given operands
/// can be dropped, i.e. the remaining operands can compute the loop
diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td
index fbf3f19cde0e9..9b1ab20552628 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td
@@ -862,7 +862,7 @@ def AggregatedOpInterface : OpInterface<"AggregatedOpInterface"> {
In other words, the returned vector can be used directly with
`RewriterBase::replaceOp(this, returnedValues)`.
}],
- /*retType=*/"FailureOr<SmallVector<Value>>",
+ /*retType=*/"FailureOr<DecompositionResult>",
/*methodName=*/"decomposeOperation",
/*args=*/(ins
"OpBuilder &":$b),
diff --git a/mlir/include/mlir/Dialect/Linalg/Passes.td b/mlir/include/mlir/Dialect/Linalg/Passes.td
index 0621a9f33ba1e..3031126e582f7 100644
--- a/mlir/include/mlir/Dialect/Linalg/Passes.td
+++ b/mlir/include/mlir/Dialect/Linalg/Passes.td
@@ -94,6 +94,11 @@ def LinalgGeneralizeNamedOpsPass : Pass<"linalg-generalize-named-ops"> {
let dependentDialects = ["linalg::LinalgDialect"];
}
+def LinalgDecomposeAggregateNamedOpsPass : Pass<"linalg-decompose-named-ops"> {
+ let summary = "Decompose complex named ops (e.g., Softmax) into a sequence of linalg named ops";
+ let dependentDialects = ["linalg::LinalgDialect"];
+}
+
def LinalgDetensorizePass : InterfacePass<"linalg-detensorize", "FunctionOpInterface"> {
let summary = "Detensorize linalg ops";
let dependentDialects = [];
diff --git a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
index 93e2c2db729da..2e8e294aa2e4c 100644
--- a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
@@ -1317,25 +1317,21 @@ def ConvertToLoopsOp : Op<Transform_Dialect, "structured.convert_to_loops",
def DecomposeInterfaceOp : Op<Transform_Dialect, "structured.decompose_interface",
[FunctionalStyleTransformOpTrait,
MemoryEffectsOpInterface,
- TransformOpInterface,
- TransformEachOpTrait,
+ DeclareOpInterfaceMethods<TransformOpInterface>,
ReportTrackingListenerFailuresOpTrait]> {
let description = [{
- TODO
+ Decomposes high-level named ops into a sequence of non-aggregate named ops
+ via `AggregatedOpInterface`.
+
+ The operation ignores non-decomposable ops. The return handles point to
+ a sequence of named ops produced by the decomposition.
}];
let arguments = (ins TransformHandleTypeInterface:$target);
- let results = (outs TransformHandleTypeInterface:$transformed);
+ let results = (outs Variadic<TransformHandleTypeInterface>:$transformed);
let assemblyFormat =
"$target attr-dict `:` functional-type(operands, results)";
- let extraClassDeclaration = [{
- ::mlir::DiagnosedSilenceableFailure applyToOne(
- ::mlir::transform::TransformRewriter &rewriter,
- ::mlir::Operation *target,
- ::mlir::transform::ApplyToEachResultList &results,
- ::mlir::transform::TransformState &state);
- }];
}
//===----------------------------------------------------------------------===//
// RewriteInDestinationPassingStyleOp.
diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
index 05e97befdec1f..b0eeb274f71bb 100644
--- a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
@@ -1546,6 +1546,11 @@ void populateLinalgTilingCanonicalizationPatterns(RewritePatternSet &patterns);
/// linalg.generic ops.
void populateLinalgNamedOpsGeneralizationPatterns(RewritePatternSet &patterns);
+/// Populates `patterns` with patterns to decompose high-level aggregate named
+/// ops (e.g., softmax) into a sequence of simpler linalg named ops, defining
+/// the operation semantics.
+void populateDecomposeAggregateNamedOpsPatterns(RewritePatternSet &patterns);
+
/// Linalg decompose convolutions patterns
/// Populates patterns to decompose high-D convolution ops into low-D ones.
diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
index 57d126603ebd7..383f285969ad7 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -2564,116 +2564,41 @@ void SoftmaxOp::getEffects(
// Helper functions for softmax decomposition.
// @{
-
-// Helper function to produce the iterator types (reduction or parallel) and
-// affine maps for the iterators used in the decomposition of softmax.
-// This method creates:
-// If allParallel == true:
-// - iterator type: {parallel, ..., parallel}
-// - affine maps:
-// -- identity with inputRank dimensions.
-// -- (d0, ..., dN) -> (d0, ..., d_dim-1, d_dim+1, ..., dN),
-// where N == inputRank.
-//
-// If allParallel == false:
-// - iterator type at dim(i) == parallel for i != \p dim and
-// dim(dim) == reduction.
-// - affine map:
-// -- identity with inputRank dimensions.
-// -- (d0, ..., dN) -> (d0, ..., d_dim-1, d_dim+1, ..., dN),
-// where N == inputRank.
-static std::tuple<SmallVector<utils::IteratorType>, SmallVector<AffineMap>>
-computeIteratorTypesAndIndexingMaps(OpBuilder &builder, int64_t inputRank,
- int64_t dim, bool allParallel = false) {
- SmallVector<utils::IteratorType> iteratorTypes(inputRank,
- utils::IteratorType::parallel);
- if (!allParallel)
- iteratorTypes[dim] = utils::IteratorType::reduction;
- MLIRContext *ctxt = builder.getContext();
- auto identityMap = AffineMap::getMultiDimIdentityMap(inputRank, ctxt);
- SmallVector<AffineExpr, 2> affineExprs;
- for (int i = 0; i < inputRank; i++) {
- if (i != dim)
- affineExprs.push_back(mlir::getAffineDimExpr(i, ctxt));
- }
- auto reductionMap =
- AffineMap::get(inputRank, /*symbols=*/0, affineExprs, ctxt);
- SmallVector<AffineMap> indexingMaps{identityMap, reductionMap};
- return std::make_tuple(iteratorTypes, indexingMaps);
-}
-
-// Helper function to produce a linalg.generic that computes a reduction on
-// dimension \p dim with the operation type \p T.
-template <typename T>
-static Value reduce(OpBuilder &builder, Location loc, Value input, Value output,
- int64_t dim) {
- auto inputType = cast<ShapedType>(input.getType());
- ArrayRef<int64_t> inputShape = inputType.getShape();
- int64_t inputRank = inputShape.size();
- auto [iteratorTypes, indexingMaps] =
- computeIteratorTypesAndIndexingMaps(builder, inputRank, dim);
- assert(indexingMaps.size() == 2 &&
- "We should have two maps: 1 for the input, 1 for the output");
- assert(indexingMaps[0].isIdentity() && "input map should be identity");
-
- auto genericOp = builder.create<linalg::GenericOp>(
- loc, output.getType(), input, output, indexingMaps, iteratorTypes,
- [&](OpBuilder &b, Location loc, ValueRange args) {
- Value result = b.create<T>(loc, args[0], args[1]);
- b.create<linalg::YieldOp>(loc, result);
- });
- return genericOp.getResult(0);
-}
-
-/// Produce a linalg generic that computes the second step of the softmax
-/// decomposition: res = exp(input - max), where \p max is the max of \p input
-/// on dimension \p dim.
-static Value buildSubAndExpOp(OpBuilder &builder, Location loc, Value input,
- Value max, Value output, int64_t dim) {
- auto inputType = cast<ShapedType>(input.getType());
- ArrayRef<int64_t> inputShape = inputType.getShape();
- int64_t inputRank = inputShape.size();
- auto [iteratorTypes, indexingMaps] = computeIteratorTypesAndIndexingMaps(
- builder, inputRank, dim, /*allParallel=*/true);
- assert(indexingMaps.size() == 2 && "We should have one map for each input");
- assert(indexingMaps[0].isIdentity() && "input map should be identity");
- // Add the affine map for the output argument.
- indexingMaps.push_back(indexingMaps[0]);
- auto genericOp = builder.create<linalg::GenericOp>(
- loc, input.getType(), ValueRange{input, max}, output, indexingMaps,
- iteratorTypes, [&](OpBuilder &b, Location loc, ValueRange args) {
- Value diff = b.create<arith::SubFOp>(loc, args[0], args[1]);
- Value result = b.create<math::ExpOp>(loc, diff);
- b.create<linalg::YieldOp>(loc, result);
- });
- return genericOp.getResult(0);
-}
-
-/// Produce a linalg generic that computes the final step of the softmax
-/// decomposition.
-/// \returns linalg.generic ins(\p numerator, \p denominator) outs(\p output) {
-/// yield n / d
-/// }
-static Value buildDivOp(OpBuilder &builder, Location loc, Value numerator,
- Value denominator, Value output, int64_t dim) {
- auto inputType = cast<ShapedType>(numerator.getType());
- ArrayRef<int64_t> inputShape = inputType.getShape();
- int64_t inputRank = inputShape.size();
- auto [iteratorTypes, indexingMaps] = computeIteratorTypesAndIndexingMaps(
- builder, inputRank, dim, /*allParallel=*/true);
- assert(indexingMaps.size() == 2 &&
- "We should have one map for each input (2)");
- assert(indexingMaps[0].isIdentity() && "Numerator map should be identity");
- // Add the affine map for the output tensor.
- indexingMaps.push_back(indexingMaps[0]);
- auto genericOp = builder.create<linalg::GenericOp>(
- loc, numerator.getType(), ValueRange{numerator, denominator}, output,
- indexingMaps, iteratorTypes,
- [&](OpBuilder &b, Location loc, ValueRange args) {
- Value result = b.create<arith::DivFOp>(loc, args[0], args[1]);
- b.create<linalg::YieldOp>(loc, result);
- });
- return genericOp.getResult(0);
+TypedAttr createInitValueForReduceMaxOp(Type type, OpBuilder &b) {
+ if (isa<FloatType>(type))
+ return b.getFloatAttr(
+ type, APFloat::getSmallest(cast<FloatType>(type).getFloatSemantics()));
+ if (isa<IntegerType>(type))
+ return b.getIntegerAttr(
+ type, APInt::getSignedMinValue(type.getIntOrFloatBitWidth()));
+ return {};
+}
+
+TypedAttr createInitValueForReduceSumOp(Type type, OpBuilder &b) {
+ if (isa<FloatType>(type))
+ return b.getFloatAttr(
+ type, APFloat::getZero(cast<FloatType>(type).getFloatSemantics()));
+ if (isa<IntegerType>(type))
+ return b.getIntegerAttr(type, APInt::getZero(type.getIntOrFloatBitWidth()));
+ return {};
+}
+
+Value createLinalgReduceMaxBody(OpBuilder b, Location loc, ValueRange args,
+ Type elementTy) {
+ if (isa<FloatType>(elementTy))
+ return b.create<arith::MaxNumFOp>(loc, args[0], args[1]);
+ if (isa<IntegerType>(elementTy))
+ return b.create<arith::MaxSIOp>(loc, args[0], args[1]);
+ return {};
+}
+
+Value createLinalgReduceSumBody(OpBuilder &b, Location loc, ValueRange args,
+ Type elementTy) {
+ if (isa<FloatType>(elementTy))
+ return b.create<arith::AddFOp>(loc, args[0], args[1]);
+ if (isa<IntegerType>(elementTy))
+ return b.create<arith::AddIOp>(loc, args[0], args[1]);
+ return {};
}
// @} End helper functions for softmax decomposition.
@@ -2695,7 +2620,7 @@ static Value buildDivOp(OpBuilder &builder, Location loc, Value numerator,
/// 4. Divide z and l. This gives the N-dimensional softmax.
/// softmax = z / l
///
-FailureOr<SmallVector<Value>> SoftmaxOp::decomposeOperation(OpBuilder &b) {
+FailureOr<DecompositionResult> SoftmaxOp::decomposeOperation(OpBuilder &b) {
OpBuilder::InsertionGuard guard(b);
b.setInsertionPoint(*this);
Location loc = getLoc();
@@ -2706,32 +2631,60 @@ FailureOr<SmallVector<Value>> SoftmaxOp::decomposeOperation(OpBuilder &b) {
SmallVector<OpFoldResult> dims = tensor::getMixedSizes(b, loc, input);
Value output = getOutput();
dims.erase(dims.begin() + reductionDim);
+
// Step 1: Compute max along dim.
Value outputReduce = b.create<tensor::EmptyOp>(loc, dims, elementType);
- Value neutralForMaxF = arith::getIdentityValue(arith::AtomicRMWKind::maximumf,
- elementType, b, loc,
- /*useOnlyFiniteValue=*/true);
- Value neutralForMaxFInit =
- b.create<linalg::FillOp>(loc, Value{neutralForMaxF}, outputReduce)
- .result();
- Value max =
- reduce<arith::MaxNumFOp>(b, loc, input, neutralForMaxFInit, reductionDim);
+ auto maxFillValAttr = createInitValueForReduceMaxOp(elementType, b);
+ auto maxFillValue = b.create<arith::ConstantOp>(loc, maxFillValAttr);
+ auto neutralMaxInitOp = b.create<linalg::FillOp>(
+ loc, ValueRange{maxFillValue}, ValueRange{outputReduce});
+ Value neutralForMaxFInit = neutralMaxInitOp.result();
+
+ auto reduceMaxOp = b.create<linalg::ReduceOp>(
+ loc, input, neutralForMaxFInit, reductionDim,
+ [&](OpBuilder &nestedBuilder, Location nestedLoc, ValueRange args) {
+ auto result =
+ createLinalgReduceMaxBody(b, nestedLoc, args, elementType);
+ nestedBuilder.create<linalg::YieldOp>(nestedLoc, result);
+ });
// Step 2: Subtract max from input and exponentiate.
- Value numerator = buildSubAndExpOp(b, loc, input, max, output, reductionDim);
+ auto maxBroadcastOp = b.create<linalg::BroadcastOp>(
+ loc, reduceMaxOp.getResult(0), output, reduceMaxOp.getDimensionsAttr());
+
+ auto subOp = b.create<linalg::SubOp>(
+ loc, ValueRange{input, maxBroadcastOp.getResults().front()},
+ ValueRange{output});
+ auto expOp = b.create<linalg::ExpOp>(loc, ValueRange{subOp.getResult(0)},
+ ValueRange{output});
// Step 3: Compute sum along dim.
- Value zero = arith::getIdentityValue(arith::AtomicRMWKind::addf, elementType,
- b, loc, /*useOnlyFiniteValue=*/true);
- Value zeroInit =
- b.create<linalg::FillOp>(loc, Value{zero}, outputReduce).result();
- Value denominator =
- reduce<arith::AddFOp>(b, loc, numerator, zeroInit, reductionDim);
+ auto sumFillValAttr = createInitValueForReduceSumOp(elementType, b);
+ auto sumFillValue = b.create<arith::ConstantOp>(loc, sumFillValAttr);
+ auto neutralSumInitOp = b.create<linalg::FillOp>(
+ loc, ValueRange{sumFillValue}, ValueRange{outputReduce});
+ auto sumFilledTensor = neutralSumInitOp.result();
+ auto reduceSumOp = b.create<linalg::ReduceOp>(
+ loc, expOp.getResults(), sumFilledTensor, reductionDim,
+ [&](OpBuilder &nestedBuilder, Location nestedLoc, ValueRange args) {
+ auto result =
+ createLinalgReduceSumBody(b, nestedLoc, args, elementType);
+ nestedBuilder.create<linalg::YieldOp>(nestedLoc, result);
+ });
// Step 4: Compute softmax.
- Value result =
- buildDivOp(b, loc, numerator, denominator, output, reductionDim);
- return SmallVector<Value>{result};
+ auto sumBcastOutput = b.create<tensor::EmptyOp>(
+ loc, getOutputOperandType().getShape(), elementType);
+ auto sumBroadcastOp = b.create<linalg::BroadcastOp>(
+ loc, reduceSumOp.getResult(0), sumBcastOutput,
+ reduceSumOp.getDimensionsAttr());
+ auto divOp = b.create<linalg::DivOp>(
+ loc, ValueRange{expOp.getResult(0), sumBroadcastOp.getResults().front()},
+ ValueRange{output});
+ return DecompositionResult{{neutralMaxInitOp, reduceMaxOp, maxBroadcastOp,
+ subOp, expOp, neutralSumInitOp, reduceSumOp,
+ sumBroadcastOp, divOp},
+ {divOp.getResults().front()}};
}
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
index bc02788f9c441..e3f0a18a5ec2c 100644
--- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
+++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
@@ -431,27 +431,23 @@ transform::DecomposeOp::applyToOne(transform::TransformRewriter &rewriter,
// Decompose the target operation if it implements the AggregatedOpInterface.
// Push the decomposed operations (the ones that replaces the values produced by
// \p target) in the `results`.
-DiagnosedSilenceableFailure transform::DecomposeInterfaceOp::applyToOne(
- transform::TransformRewriter &rewriter, Operation *target,
- transform::ApplyToEachResultList &results,
- transform::TransformState &state) {
- auto decomposableOp = dyn_cast<AggregatedOpInterface>(target);
- if (!decomposableOp) {
- failed(rewriter.notifyMatchFailure(target,
- "payload is not a decomposable op"));
- return emitDefaultSilenceableFailure(target);
- }
+DiagnosedSilenceableFailure
+transform::DecomposeInterfaceOp::apply(transform::TransformRewriter &rewriter,
+ TransformResults &transformResults,
+ TransformState &state) {
+ for (auto [i, target] : llvm::enumerate(state.getPayloadOps(getTarget()))) {
+ auto decomposableOp = dyn_cast<AggregatedOpInterface>(target);
+ if (!decomposableOp)
+ continue;
- FailureOr<SmallVector<Value>> maybeNewResults =
- decomposableOp.decomposeOperation(rewriter);
- if (failed(maybeNewResults))
- return emitDefaultSilenceableFailure(target);
+ FailureOr<DecompositionResult> maybeNewResults =
+ decomposableOp.decomposeOperation(rewriter);
+ if (failed(maybeNewResults))
+ return emitDefaultSilenceableFailure(target);
- rewriter.replaceOp(decomposableOp, *maybeNewResults);
- for (Value val : *maybeNewResults) {
- Operation *definition = val.getDefiningOp();
- if (definition)
- results.push_back(definition);
+ rewriter.replaceOp(decomposableOp, maybeNewResults->decomposedValues);
+ transformResults.set(cast<OpResult>(getResult(i)),
+ maybeNewResults->decomposedOps);
}
return DiagnosedSilenceableFailure::success();
}
diff --git a/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt b/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt
index 7e3dc56e0acdc..68582fe6cbad2 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt
@@ -7,6 +7,7 @@ add_mlir_dialect_library(MLIRLinalgTransforms
ConvertConv2DToImg2Col.cpp
DataLayoutPropagation.cpp
DecomposeLinalgOps.cpp
+ DecomposeAggregateNamedLinalgOps.cpp
Detensorize.cpp
DropUnitDims.cpp
ElementwiseOpFusion.cpp
diff --git a/mlir/lib/Dialect/Linalg/Transforms/DecomposeAggregateNamedLinalgOps.cpp b/mlir/lib/Dialect/Linalg/Transforms/DecomposeAggregateNamedLinalgOps.cpp
new file mode 100644
index 0000000000000..e8a5b96d54d34
--- /dev/null
+++ b/mlir/lib/Dialect/Linalg/Transforms/DecomposeAggregateNamedLinalgOps.cpp
@@ -0,0 +1,62 @@
+//===- DecomposeNamedLinalgOps.cpp - Patterns to break up complex ops -----===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Linalg/Passes.h"
+
+#include "mlir/Dialect/Linalg/IR/Linalg.h"
+#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
+#include "mlir/Transforms/Gre...
[truncated]
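For readers following SoftmaxOp::decomposeOperation in the diff above, here is a plain C++ sketch of the same four-step dataflow (fill + reduce max, broadcast + sub + exp, fill + reduce sum, broadcast + div) on a rank-2, row-major tensor reduced over its last dimension. It is an illustration under stated assumptions, not MLIR code: Tensor2D and all helper names are hypothetical, arith.maxnumf is modeled with std::fmax, and the max init is the lowest finite float (what the removed code requested via arith::getIdentityValue with useOnlyFiniteValue=true), not the constant produced by the patch's createInitValueForReduceMaxOp.

```cpp
#include <cassert>
#include <cmath>
#include <cstddef>
#include <limits>
#include <vector>

// Hypothetical row-major [rows x cols] tensor; softmax runs over the last dim.
struct Tensor2D {
  std::size_t rows, cols;
  std::vector<float> data; // rows * cols elements
  float &at(std::size_t r, std::size_t c) { return data[r * cols + c]; }
  float at(std::size_t r, std::size_t c) const { return data[r * cols + c]; }
};

// Models linalg.fill(init) followed by linalg.reduce over the last dim.
template <typename Combine>
std::vector<float> reduceLastDim(const Tensor2D &in, float init, Combine combine) {
  std::vector<float> acc(in.rows, init); // the "fill" of the reduction init
  for (std::size_t r = 0; r < in.rows; ++r)
    for (std::size_t c = 0; c < in.cols; ++c)
      acc[r] = combine(acc[r], in.at(r, c));
  return acc;
}

// Models linalg.broadcast: expands a per-row vector back to rows x cols.
Tensor2D broadcastLastDim(const std::vector<float> &v, std::size_t cols) {
  Tensor2D out{v.size(), cols, std::vector<float>(v.size() * cols)};
  for (std::size_t r = 0; r < out.rows; ++r)
    for (std::size_t c = 0; c < cols; ++c)
      out.at(r, c) = v[r];
  return out;
}

// Models an elementwise binary named op (linalg.sub / linalg.div).
template <typename BinOp>
Tensor2D binary(const Tensor2D &a, const Tensor2D &b, BinOp op) {
  assert(a.rows == b.rows && a.cols == b.cols && "operands need equal shapes");
  Tensor2D out = a;
  for (std::size_t i = 0; i < out.data.size(); ++i)
    out.data[i] = op(a.data[i], b.data[i]);
  return out;
}

// Models linalg.exp.
Tensor2D expOp(const Tensor2D &a) {
  Tensor2D out = a;
  for (float &v : out.data)
    v = std::exp(v);
  return out;
}

Tensor2D softmaxDecomposed(const Tensor2D &x) {
  // Step 1: max along the last dim, seeded with the lowest finite float.
  std::vector<float> mx = reduceLastDim(
      x, std::numeric_limits<float>::lowest(),
      [](float a, float b) { return std::fmax(a, b); });
  // Step 2: broadcast the max, subtract it from the input, exponentiate.
  Tensor2D num = expOp(binary(x, broadcastLastDim(mx, x.cols),
                              [](float a, float b) { return a - b; }));
  // Step 3: sum along the last dim, seeded with 0.
  std::vector<float> sum =
      reduceLastDim(num, 0.0f, [](float a, float b) { return a + b; });
  // Step 4: broadcast the sum and divide.
  return binary(num, broadcastLastDim(sum, x.cols),
                [](float a, float b) { return a / b; });
}
```

Each helper materializes a full result, mirroring the one-named-op-per-step structure; in real IR, later fusion may remove some of these temporaries.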
- The linalg.reduce { maxnum }'s init constant is wrong: it will yield the wrong max if all numbers are <= 0.
Oh, nice catch! Misused the getSmallest API :) Will fix in a bit.
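The reported bug in a few lines: APFloat::getSmallest returns the smallest positive (denormal) value of the format, so a max-reduction seeded with it can never return anything below that value. The sketch below uses std::numeric_limits<float>::denorm_min() as a stand-in for APFloat::getSmallest and std::fmax for arith.maxnumf; reduceMax and referenceMax are hypothetical helpers. An identity such as -inf or the lowest finite value (in LLVM terms, APFloat::getInf(sem, /*Negative=*/true) or APFloat::getLargest(sem, /*Negative=*/true)) gives the right answer; which of these the follow-up fix uses is not shown here.

```cpp
#include <algorithm>
#include <cassert>
#include <cmath>
#include <limits>
#include <vector>

// Models linalg.fill(init) + linalg.reduce { arith.maxnumf } on one row.
float reduceMax(const std::vector<float> &row, float init) {
  float acc = init;
  for (float v : row)
    acc = std::fmax(acc, v);
  return acc;
}

// Reference answer for a non-empty row.
float referenceMax(const std::vector<float> &row) {
  assert(!row.empty() && "max of an empty row is undefined");
  return *std::max_element(row.begin(), row.end());
}

// Analogue of APFloat::getSmallest: smallest *positive* denormal. Wrong as
// a max identity, because it is larger than every non-positive input.
const float kSmallestPositive = std::numeric_limits<float>::denorm_min();

// A correct identity when only finite constants are wanted.
const float kLowestFinite = std::numeric_limits<float>::lowest();
```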
Thanks!
Does this PR change what ops linalg.softmax is decomposed into? As in, do the semantics of linalg.softmax change? Looking at mlir/test/Dialect/Linalg/transform-op-decompose.mlir, I feel that the main difference was:
linalg.fill ops become linalg.generic ops, but that's a minor thing and I think it's the right thing to have.
Also, what's the difference between "transform-op-decompose.mlir" and "decompose-named-ops.mlir"? Is it just "how" the decomposition is "driven"? (TD vs Pass) Couldn't that be one test instead?
Implementation for softmax follows the lowering semantics of popular frameworks like PyTorch, TensorFlow, and others (see https://github.com/intel/graph-compiler/issues/10#issuecomment-2161145033, https://github.com/intel/graph-compiler/issues/10#issuecomment-2162722181, and https://github.com/intel/graph-compiler/issues/10#issuecomment-2161153179).
Thanks for checking and for the extra context. I am just wondering:
- If this claims alignment with e.g. PyTorch (or, more specifically, torch-mlir), shouldn't there be a link to torch-mlir docs/issues/code/test instead?
- Are you saying that this PR is changing the semantics of softmax in Linalg?
now aggregate ops are decomposed first and then generalized
I am a bit confused, there's -linalg-decompose-named-ops and -linalg-generalize-named-ops - which one are you referring to? The first one would only decompose and the latter would only generalise, right?
Does this PR change what ops linalg.softmax is decomposed into? As in, do the semantics of linalg.softmax change? Looking at mlir/test/Dialect/Linalg/transform-op-decompose.mlir, I feel that the main difference was:
The decomposition follows the op description (see LinalgOps.td) and specifies its semantics via the implementation. The implementation ends up generating semantically the same code as the previous decomposition implementation (with minor deviation as you noted). The transform test demonstrates the change in the generalized code.
Also, what's the difference between "transform-op-decompose.mlir" and "decompose-named-ops.mlir"? Is it just "how" the decomposition is "driven"? (TD vs Pass) Couldn't that be one test instead?
The first one tests the new pass. The second one uses the transform interpreter and the decompose_interface op (which happen to partially rely on the same code now).
If this claims alignment with e.g. PyTorch (or, more specifically, torch-mlir), shouldn't there be a link to torch-mlir docs/issues/code/test instead?
The IR presented is the IR you get by lowering PyTorch to torch-mlir.
Are you saying that this PR is changing the semantics of softmax in Linalg?
I'd say it sets it. The implementation follows the op description, so there's no real 'change'.
now aggregate ops are decomposed first and then generalized
I am a bit confused, there's -linalg-decompose-named-ops and -linalg-generalize-named-ops - which one are you referring to?
Decomposition is performed by the newly introduced -linalg-decompose-named-ops (as the name suggests). Generalization is done by the default -linalg-generalize-named-ops.
The first one would only decompose and the latter would only generalise, right?
Correct.
The implementation ends up generating semantically the same code as the previous decomposition implementation (with minor deviation as you noted).
IMO this is key - please add that in the summary.
Also, what's the difference between "transform-op-decompose.mlir" and "decompose-named-ops.mlir"? Is it just "how" the decomposition is "driven"? (TD vs Pass) Couldn't that be one test instead?
The first one tests the new pass. The second one uses the transform interpreter and the decompose_interface op (which happen to partially rely on the same code now).
From what I can tell, both tests verify the decomposition of linalg.softmax:
func.func @softmax(%arg0: tensor<2x16x32xf32>, %dst: tensor<2x16x32xf32>) -> tensor<2x16x32xf32> {
%1 = linalg.softmax dimension(2) ins(%arg0 : tensor<2x16x32xf32>) outs(%dst: tensor<2x16x32xf32>) -> tensor<2x16x32xf32>
return %1 : tensor<2x16x32xf32>
}
Couldn't we re-use the input and the CHECK lines? To avoid duplication.
If this claims alignment with e.g. PyTorch (or, more specifically, torch-mlir), shouldn't there be a link to torch-mlir docs/issues/code/test instead?
The IR presented is the IR you get by lowering PyTorch to torch-mlir.
I know where the IR is coming from.
FWIW (without links to documentation), that IR is just an implementation detail of torch-mlir. In this PR we are discussing an implementation detail of Linalg. Are you saying that the implementation in Linalg should match torch-mlir? Why? What if the implementation in torch-mlir changes?
I'm trying to understand the motivation here and the overall design that we are converging towards.
Are you saying that this PR is changing the semantics of softmax in Linalg?
I'd say it sets it. The implementation follows the op description, so there's no real 'change'.
Key info, please highlight in the summary.
now aggregate ops are decomposed first and then generalized
I am a bit confused, there's -linalg-decompose-named-ops and -linalg-generalize-named-ops - which one are you referring to?
Decomposition is performed by the newly introduced -linalg-decompose-named-ops (as the name suggests). Generalization is done by the default -linalg-generalize-named-ops.
I know what the options are, thanks. To me, your comment implies that -linalg-decompose-named-ops is meant to be followed by -linalg-generalize-named-ops ("aggregate ops are decomposed first and then generalized")? Is that what you had in mind?
The implementation ends up generating semantically the same code as the previous decomposition implementation (with minor deviation as you noted).
IMO this is key - please add that in the summary.
Done.
Also, what's the difference between "transform-op-decompose.mlir" and "decompose-named-ops.mlir"? Is it just "how" the decomposition is "driven"? (TD vs Pass) Couldn't that be one test instead?
The first one tests the new pass. The second one uses the transform interpreter and the decompose_interface op (which happen to partially rely on the same code now).
From what I can tell, both tests verify the decomposition of linalg.softmax:
func.func @softmax(%arg0: tensor<2x16x32xf32>, %dst: tensor<2x16x32xf32>) -> tensor<2x16x32xf32> {
%1 = linalg.softmax dimension(2) ins(%arg0 : tensor<2x16x32xf32>) outs(%dst: tensor<2x16x32xf32>) -> tensor<2x16x32xf32>
return %1 : tensor<2x16x32xf32>
}
Couldn't we re-use the input and the CHECK lines? To avoid duplication.
Do I understand correctly that you suggest having a single lit test with the body of @softmax and the transformation IR, which runs the decomposition both via the pass and via the transform interpreter?
FWIW (without links to documentation), that IR is just an implementation detail of torch-mlir. In this PR we are discussing an implementation detail of Linalg. Are you saying that the implementation in Linalg should match torch-mlir? Why? What if the implementation in torch-mlir changes?
I'm trying to understand the motivation here and the overall design that we are converging towards.
So the goal was to look into what frameworks actually do in their implementations. If it so happens that all of them lower softmax to the same sequence (and this is what we happen to have), we can set it as the default decomposition to avoid reimplementing it. The general idea and direction is to have an intermediate decomposition stage that deals with complex ops (such as softmax) to aid other transformations and analyses. This is also an easy route to adding more named ops, upstream and downstream: rather than implementing all the interfaces (like tiling) for each new op, convert it to simpler ops instead and enjoy all the existing goodness. Note: I'm leaving the question of accumulation out for now; this should be addressed separately.
Are you saying that this PR is changing the semantics of softmax in Linalg?
I'd say it sets it. The implementation follows the op description, so there's no real 'change'.
Key info, please highlight in the summary.
Done.
I know what the options are, thanks. To me, your comment implies that -linalg-decompose-named-ops is meant to be followed by -linalg-generalize-named-ops ("aggregate ops are decomposed first and then generalized")? Is that what you had in mind?
This just describes the change to the transform op decomposition (in other words, it used to produce generics + fills; now it internally produces a sequence of named ops and then runs generalization). There's no strict requirement to run generalization after the decomposition, of course.
Do I understand correctly that you suggest having a single lit test with the body of @softmax and the transformation IR, which runs the decomposition both via the pass and via the transform interpreter?
Yes, something along those lines. Basically, IMO, we should identify a canonical way to test transformations with both "passes" and TD so that we maximise "test case" re-use.
FWIW (without links to documentation), that IR is just an implementation detail of torch-mlir. In this PR we are discussing an implementation detail of Linalg. Are you saying that the implementation in Linalg should match torch-mlir? Why? What if the implementation in torch-mlir changes?
PyTorch & torch-mlir are red herrings. This PR is about softmax.
The current lowering of softmax into generics is, unsurprisingly, the same as both PyTorch and TensorFlow expect it to be. This is related to, but distinct from, lowering softmax to named ops. So, let's close this tangent, as it is irrelevant.
I know what the options are, thanks. To me, your comment implies that -linalg-decompose-named-ops is meant to be followed by -linalg-generalize-named-ops ("aggregate ops are decomposed first and then generalized")? Is that what you had in mind?
Absolutely not.
The test where we lower to named ops was created in this PR. We verify that it does the "right thing".
The old test checks the generic lowering (because it was the only thing we had), so @kurapov-peter added the generalization pass to match the old expectation. This provides us with a clear "apples to apples" comparison, and shows that the old output can no longer be attained. Most importantly, you don't get a mix of named + generics. You either get all named or all generic.
This seems to be a problem to @MaheshRavishankar and I want to understand it better. My guess is that there are pattern matchers that won't work with the generic version of fill (and why we want named ops in the first place).
One option is to make the generalize pattern never generalize some "special" ops, for example `linalg.fill`. Another is to add a flag to the generalize pass that does this; without the flag, the pass converts everything. A third option is to generalize piece-wise downstream. Which one fits depends on what form of the IR the matcher expects.
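Here is a toy model of the "flag in the generalize pass" option. It assumes ops can be identified by name alone. `Op`, `generalize`, and the `keep` parameter are all hypothetical and only illustrate the behavior. This is not MLIR's C++ API, and `keep` is not claimed to be an existing option of `-linalg-generalize-named-ops`.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class Op:
    """Toy stand-in for an IR operation; only its name matters here."""
    name: str


def generalize(ops, keep=frozenset()):
    """Toy generalize pass with an opt-out set of op names (hypothetical).

    Every op becomes "linalg.generic" unless its name is in `keep`;
    ops that are already generic are left as they are.
    """
    return [
        op if op.name in keep or op.name == "linalg.generic" else Op("linalg.generic")
        for op in ops
    ]
```

With an empty `keep` set the model converts everything. Passing `keep={"linalg.fill"}` leaves fills intact, which is the "partial generalization" discussed in this thread.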
This seems to be a problem to @MaheshRavishankar and I want to understand it better. My guess is that there are pattern matchers that won't work with the generic version of `fill` (and why we want named ops in the first place).
I would like to understand the next steps here. @MaheshRavishankar, could you please elaborate ^^^?
Regarding the fill issue, I think customizable decomposition would be a reasonable solution - helps preserve the downstream usage and doesn't hold back the upstream.
Regarding broadcast, I could work on setting the semantics for implicit casting. One thing that is unclear to me though is whether having implicit cast semantics for named ops is beneficial. Wasn't the whole point of named ops to have a very explicit IR that is easy to analyze? In that regard, the absence of implicit casts is actually a good thing (I also don't see how it is ambiguous, could you please clarify?). Is there any real problem with broadcasts except for not being succinct? Wouldn't implicit casting just add unnecessary burden for analyses and transforms to handle various cases of arguments?
There are broadly two things we need in order to make progress here:
- This is really two separate PRs: one that changes the softmax decomposition and one that adds a pass for decomposition. The latter should be easy to land.
- The change to the decomposition of softmax is, IMO, an inefficient lowering of softmax and will require "something" to get the state back. That "something" should be part of the PR that changes the decomposition. It moves from a more succinct representation that Linalg allows to one that is (artificially) hamstrung by the current definitions of the named ops. I don't expect the issue with named ops to be fixed as a precursor (though that would be the right thing to do IMO), but I don't see how we can land this PR without an option to choose how to decompose softmax (with the default being what it is today, and an option to lower to a sequence of named ops). On top of that, adding a generalization that converts everything to `linalg.generic`s is a non-starter IMO. You would be forcing all downstream users to rely heavily on "recognizers" to retrieve the information lost by generalization, while giving them no control over when they want to generalize.
This seems to be a problem to @MaheshRavishankar and I want to understand it better. My guess is that there are pattern matchers that won't work with the generic version of `fill` (and why we want named ops in the first place).
Just to map back to what I said above: we can "recognize" that it's a fill, but that seems like an unnecessary burden on downstream users, because it has been generalized too early without any control. I can go into detail about why I think "fill" is special, but that's a separate issue IMO.
Regarding the `fill` issue, I think customizable decomposition would be a reasonable solution - helps preserve the downstream usage and doesn't hold back the upstream.
Regarding `broadcast`, I could work on setting the semantics for implicit casting. One thing that is unclear to me though is whether having implicit cast semantics for named ops is beneficial. Wasn't the whole point of named ops to have a very explicit IR that is easy to analyze? In that regard, the absence of implicit casts is actually a good thing (I also don't see how it is ambiguous, could you please clarify?). Is there any real problem with broadcasts except for not being succinct? Wouldn't implicit casting just add unnecessary burden for analyses and transforms to handle various cases of arguments?
I want to clarify this. This is NOT implicit broadcasting. This is a completely unambiguous broadcast representation. For example, `linalg.generic` allows you to represent a broadcast-add this way:
```mlir
%2 = linalg.generic {
    indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>,
                     affine_map<(d0, d1) -> (d0)>,
                     affine_map<(d0, d1) -> (d0, d1)>],
    iterator_types = ["parallel", "parallel"]}
    ins(%0, %1 : tensor<?x?xf32>, tensor<?xf32>) outs(%empty : tensor<?x?xf32>) {
  ^bb0(%b0 : f32, %b1 : f32, %b2 : f32):
    %3 = arith.addf %b0, %b1 : f32
    linalg.yield %3 : f32
} -> tensor<?x?xf32>
```
There is nothing ambiguous or implicit about this broadcasting. The problem with named ops is that they force all operands to have the same rank, which is an unnecessary requirement at the Linalg level. The fix is to let named ops use the broadcast representation that Linalg inherently allows. In the name of "explicit broadcasting", we impose an artificial requirement to bring all operands to the same rank, which is unnecessary IMO. Also, it is strictly easier to go from this representation to one that requires all operands to have the same rank (it's essentially a lowering: you break the operation up into multiple ops). Going from a representation where all operands are "broadcasted" to the same rank back to the representation above is, IMO, a lifting.
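The two forms can be made concrete with a small pure-Python model (not MLIR) of an all-parallel 2-D `linalg.generic`. In the model, each indexing map is a function from loop indices to an operand index. `broadcast_add_generic` folds the broadcast into the rank-1 operand's map, as in the IR above. `broadcast_add_named` is the same-rank lowering: it materializes the broadcast first and then adds with identity maps. All helper names are hypothetical.

```python
def read(tensor, index):
    """Read an element of a nested-list tensor at a tuple index."""
    for i in index:
        tensor = tensor[i]
    return tensor


def generic_2d(inputs, maps, body, shape):
    """Toy evaluator for an all-parallel 2-D linalg.generic.

    Each entry of `maps` plays the role of an indexing map: it turns the
    loop indices (d0, d1) into an index tuple for the matching input.
    """
    rows, cols = shape
    out = [[None] * cols for _ in range(rows)]
    for d0 in range(rows):
        for d1 in range(cols):
            args = [read(t, m(d0, d1)) for t, m in zip(inputs, maps)]
            out[d0][d1] = body(*args)
    return out


identity = lambda d0, d1: (d0, d1)  # affine_map<(d0, d1) -> (d0, d1)>
drop_d1 = lambda d0, d1: (d0,)      # affine_map<(d0, d1) -> (d0)>
add = lambda a, b: a + b


def broadcast_add_generic(x, v):
    """One op: the broadcast lives in the rank-1 operand's indexing map."""
    return generic_2d([x, v], [identity, drop_d1], add, (len(x), len(x[0])))


def broadcast_add_named(x, v):
    """Two same-rank steps: explicit broadcast, then elementwise add."""
    shape = (len(x), len(x[0]))
    # Models linalg.broadcast with dimensions = [1]: copy v[d0] along d1.
    vb = generic_2d([v], [drop_d1], lambda a: a, shape)
    # Models linalg.add: operands have equal rank, so only identity maps.
    return generic_2d([x, vb], [identity, identity], add, shape)
```

The reverse direction is the lifting: recognizing that `vb` is only a copy of `v` and folding it back into an indexing map.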
Actually that brings me to maybe a potential solution. You can take the existing lowering for softmax and then add a pass to explicitly split out the broadcast and then generalize. That will get you to the state you want here?
- The change to the decomposition of softmax is, IMO, an inefficient lowering of softmax and will require "something" to get the state back. That "something" should be part of the PR that changes the decomposition. It moves from a more succinct representation that Linalg allows to one that is (artificially) hamstrung by the current definitions of the named ops. I don't expect the issue with named ops to be fixed as a precursor (though that would be the right thing to do IMO), but I don't see how we can land this PR without an option to choose how to decompose softmax (with the default being what it is today, and an option to lower to a sequence of named ops). On top of that, adding a generalization that converts everything to `linalg.generic`s is a non-starter IMO. You would be forcing all downstream users to rely heavily on "recognizers" to retrieve the information lost by generalization, while giving them no control over when they want to generalize.
Ok, I see. So this position is the opposite of what I'm proposing: changing the default decomposition to target named ops (note that this has nothing to do with generalization).
Here I'm summarizing the arguments for preserving the status quo and against it.
Cons of changing default:
- Additional steps are required to reach the same IR state.
- The IR is less succinct: explicit broadcasts are inserted due to named ops requirements.
Pro of changing default:
- Current decomposition is a mix of an actual decomposition to exp, sum, sub, max, and div named ops + partial generalization + some fusion (I'll call the existing one a mixed decomposition here to differentiate between the proposed approach and the existing one). The proposed approach limits decomposition to a single responsibility.
- Separating these three stages is beneficial because you can control the behavior better. For example, after decomposing softmax into a sequence of named ops one can fuse and tile them together with another named op that was not part of softmax. With the current approach, you'd still need the fusion pass run after the mixed decomposition to reach the same state, so pipeline complexity is the same. Moreover, new possibilities open up for pipelines that don't want to generalize the result of the decomposition.
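Below is a pure-Python sketch of the softmax decomposition under discussion: max, sub, exp, sum, and div over the innermost dimension of a 2-D input. `softmax_named_ops` materializes every broadcast, as same-rank named ops require. `softmax_fused_broadcast` instead indexes the row-wise max and sum directly, which is roughly what a generic with a rank-reducing indexing map does. Both functions are illustrative models with hypothetical names, not the PR's C++ implementation.

```python
import math


def reduce_inner(x, fn, init):
    """Reduce each row of a 2-D tensor (a reduction over the last dim)."""
    out = []
    for row in x:
        acc = init
        for v in row:
            acc = fn(acc, v)
        out.append(acc)
    return out


def broadcast_inner(v, n):
    """Materialize a 1-D tensor as 2-D by repeating each element n times."""
    return [[e] * n for e in v]


def elementwise(fn, *xs):
    """Apply fn element-wise across same-shape 2-D tensors."""
    return [[fn(*vals) for vals in zip(*rows)] for rows in zip(*xs)]


def softmax_named_ops(x):
    """One single-purpose step per op, with explicit broadcasts."""
    n = len(x[0])
    m = reduce_inner(x, max, -math.inf)                                  # max
    shifted = elementwise(lambda a, b: a - b, x, broadcast_inner(m, n))  # sub
    e = elementwise(math.exp, shifted)                                   # exp
    s = reduce_inner(e, lambda a, b: a + b, 0.0)                         # sum
    return elementwise(lambda a, b: a / b, e, broadcast_inner(s, n))     # div


def softmax_fused_broadcast(x):
    """Same math; the broadcasts are folded into how m and s are indexed."""
    m = reduce_inner(x, max, -math.inf)
    e = [[math.exp(v - m[i]) for v in row] for i, row in enumerate(x)]
    s = reduce_inner(e, lambda a, b: a + b, 0.0)
    return [[v / s[i] for v in row] for i, row in enumerate(e)]
```

For the same input, both functions produce identical values. They differ only in whether the broadcasts exist as separate steps.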
Mitigating cons n.1: Even though reaching the same result of the mixed decomposition requires additional steps, those are existing upstream transformations. Hence, the downstream changes won't be a significant burden. Mitigating cons n.2: As for broadcasting, as I mentioned earlier, adding implicit casting for named ops is an option. Though I still don't see an actual problem with the "same rank" requirement other than it is "unnecessary", I'm willing to work on it if it proves valuable.
I suggest we make a decision on the direction here @ftynse, @nicolasvasilache, @dcaballe, @rengolin.
Just to map back to what I said above: we can "recognize" that it's a fill, but that seems like an unnecessary burden on downstream users, because it has been generalized too early without any control. I can go into detail about why I think "fill" is special, but that's a separate issue IMO.
I think you missed the point. The proposed decomposition only converts an aggregate operation into a sequence of non-aggregate ones. This has nothing to do with generalization. Downstreams don't need to recognize a fill from its generic form. The solution for you would be to do partial generalization, leaving fills intact.
Actually that brings me to maybe a potential solution. You can take the existing lowering for softmax and then add a pass to explicitly split out the broadcast and then generalize. That will get you to the state you want here?
Same here. Generalized IR with broadcasts is not the target state. The target is a sequence of named ops.
I have been trying to find a way to help land this PR without asking for too much "as a precursor" work, but given that there hasn't been much change in the approach, the real issues IMO are two:
- It seems like a control over decomposition is missing in general. An op might have multiple ways of being decomposed, and the choice should be controllable (either by registering different interfaces or by having an options struct that lets you control the decomposition; I don't know about the exact mechanism). I don't think this PR can land in its current form without introducing such a mechanism from the get-go. We can decide the defaults, but the optionality of decomposition needs to be built (I think this should be done anyway, and I am happy to do it, but it's not a priority for me right now).
- If named ops didn't have (what I consider) outdated semantics for broadcast handling, there would be a path where you could just lower to named ops without structurally changing the lowering. That is another way to make this work. It is also worth doing, but requires a broader agreement on direction.
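Here is one possible shape for the options-struct route, sketched in Python. As suggested earlier in the thread, the default keeps today's behavior. `SoftmaxDecomposition`, `DecompositionOptions`, and `decompose` are hypothetical names. The real mechanism (an options struct, separate interfaces, or plain rewrite patterns) is exactly what remains open here.

```python
from dataclasses import dataclass
from enum import Enum


class SoftmaxDecomposition(Enum):
    """Hypothetical choices for how linalg.softmax gets decomposed."""
    GENERICS = "generics + fills, broadcasts folded into indexing maps"
    NAMED_OPS = "named ops (max, sub, exp, sum, div) with explicit broadcasts"


@dataclass(frozen=True)
class DecompositionOptions:
    """Hypothetical options struct; the default keeps the current lowering."""
    softmax: SoftmaxDecomposition = SoftmaxDecomposition.GENERICS


def decompose(op_name: str, options: DecompositionOptions = DecompositionOptions()) -> str:
    """Return a description of the decomposition selected for op_name."""
    if op_name == "linalg.softmax":
        return options.softmax.value
    raise ValueError(f"no decomposition registered for {op_name}")
```

For example, a user who wants the named-op sequence passes `DecompositionOptions(SoftmaxDecomposition.NAMED_OPS)`. Everyone else gets the existing behavior without changing anything.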
- The change to the decomposition of softmax is, IMO, an inefficient lowering of softmax and will require "something" to get the state back. That "something" should be part of the PR that changes the decomposition. It moves from a more succinct representation that Linalg allows to one that is (artificially) hamstrung by the current definitions of the named ops. I don't expect the issue with named ops to be fixed as a precursor (though that would be the right thing to do IMO), but I don't see how we can land this PR without an option to choose how to decompose softmax (with the default being what it is today, and an option to lower to a sequence of named ops). On top of that, adding a generalization that converts everything to `linalg.generic`s is a non-starter IMO. You would be forcing all downstream users to rely heavily on "recognizers" to retrieve the information lost by generalization, while giving them no control over when they want to generalize.
Ok, I see. So this position is the opposite of what I'm proposing: changing the default decomposition to target named ops (note that this has nothing to do with generalization).
Here I'm summarizing the arguments for preserving the status quo and against it.
Cons of changing default:
- Additional steps are required to reach the same IR state.
- The IR is less succinct: explicit broadcasts are inserted due to named ops requirements.
Pro of changing default:
- Current decomposition is a mix of an actual decomposition to exp, sum, sub, max, and div named ops + partial generalization + some fusion (I'll call the existing one a mixed decomposition here to differentiate between the proposed approach and the existing one). The proposed approach limits decomposition to a single responsibility.
I don't agree with this characterization. If you want to lower to named ops, which you then generalize, representing broadcasts more succinctly is not a fusion IMO. This seems like transplanting ideas from TOSA/torch etc. into Linalg. There, you need broadcast as a separate operation. You don't need that in Linalg. I wouldn't characterize it as a fusion (rather, TOSA/torch artificially force front-ends/lowerings/programmers to introduce a broadcast because they have no mechanism to represent broadcasts effectively).
- Separating these three stages is beneficial because you can control the behavior better. For example, after decomposing softmax into a sequence of named ops one can fuse and tile them together with another named op that was not part of softmax. With the current approach, you'd still need the fusion pass run after the mixed decomposition to reach the same state, so pipeline complexity is the same. Moreover, new possibilities open up for pipelines that don't want to generalize the result of the decomposition.
Mitigating cons n.1: Even though reaching the same result of the mixed decomposition requires additional steps, those are existing upstream transformations. Hence, the downstream changes won't be a significant burden. Mitigating cons n.2: As for broadcasting, as I mentioned earlier, adding implicit casting for named ops is an option. Though I still don't see an actual problem with the "same rank" requirement other than it is "unnecessary", I'm willing to work on it if it proves valuable.
Again, I don't agree that this is implicit casting. It is a completely explicit representation of broadcasting behavior. And if you think the downstream changes needed to get back to the current state are not a significant burden, please add them to your PR, and then let's discuss how to package it from a user perspective. I could very well come along and change the behavior back because it "suits my need", and we would never reach a stable state.
I think you missed the point. The proposed decomposition only converts an aggregate operation into a sequence of non-aggregate ones. This has nothing to do with generalization. Downstreams don't need to recognize a `fill` from its generic form. The solution for you would be to do partial generalization, leaving `fill`s intact.
Again I disagree. Representing broadcasting semantics more succinctly is not about it being an "aggregate" op, but rather there should never have been a need to have a linalg.broadcast operation in the first place IMO (and you are over-indexing on the fill issue. I understand completely how to do partial generalization). There was never an agreement that decomposing a softmax op into named ops as it exists today is the way to go.
I have been trying to find a way to help land this PR without asking for too much "as a precursor" work, but given that there hasn't been much change in the approach, the real issues IMO are two:
- It seems like a control over decomposition is missing in general. An op might have multiple ways of being decomposed, and the choice should be controllable (either by registering different interfaces or by having an options struct that lets you control the decomposition; I don't know about the exact mechanism). I don't think this PR can land in its current form without introducing such a mechanism from the get-go. We can decide the defaults, but the optionality of decomposition needs to be built (I think this should be done anyway, and I am happy to do it, but it's not a priority for me right now).
- If named ops didn't have (what I consider) outdated semantics for broadcast handling, there would be a path where you could just lower to named ops without structurally changing the lowering. That is another way to make this work. It is also worth doing, but requires a broader agreement on direction.
Flyby review of a conversation that appears to be looping.
For number 1 (controllable (de)composition) -- big +1. This is how PyTorch does it, and it is a superpower (and a big part of what makes things attractive there). Basically, you end up with a way to say what your backend prefers, and the framework gets you as close to that as it can. It is a practical way out of the fact that no one agrees on every detail of this stuff (and never will / there is no best).
For number 2 -- I'm watching this same basic discussion loop on multiple threads and PRs (basically the role of named ops and broadcasting). We're using sloppy language (calling it cast, explicit, implicit, fusion, etc.), so I'm not going to add to that. But it is quite clear to me that a couple of opposing viewpoints on this are being ground out patch by patch (with the usual amount of friction that entails). My tendency is to side with Mahesh's viewpoint on this -- not because of upstream/downstream/whatever -- but because that viewpoint is more compatible with all of the transformations that we should be able to do on this opset (and I've lived too many lives with the really substandard backdrop of trying to use the fixed-function op libraries of the frameworks for transformations and optimizations). But if I squint, I can see the value in the "named op everything" viewpoint iff it is part of a holistic design supporting configurable specialization and robust generalization.
I don't want to litigate any of this on a PR like this, but I do think there are a couple of broader discussions here that we'd be better off to have.
I have been trying to find a way to help land this PR without asking for too much "as a precursor" work
I actually don't mind it as long as we agree on the direction. After revisiting our discussion, @MaheshRavishankar, I think we are talking about a similar end state using different languages. I'll try to confirm it below.
An op might have multiple ways of decomposition that should be controllable
Agree. This is a more generic description of what I called partial generalization to reach the same end state as the current decomposition. How about we make this the first step? I can start with an RFC to collect the requirements, and we can team up on the design/implementation.
If named ops didn't have (what I consider) outdated semantics for broadcast handling, there would be a path where you could just lower to named ops without structurally changing the lowering. That is another way to make this work. It is also worth doing, but requires a broader agreement on direction.
Right, this is what I referred to as implicit casting. It is less clear to me whether it is a good thing, but again, I am happy to work on it if there's broad agreement. I might be missing something here, though. I see you both suggesting this as a solution, and yet you disagree with my wording and call it an "explicit representation of broadcasting behavior". This looks contradictory to me. Still, I assume we both think of "named ops can accept tensors of different ranks, and decomposition does not produce an actual `linalg.broadcast`" as the target state, correct?
I can see the value in the "named op everything" viewpoint iff it is part of a holistic design supporting configurable specialization and robust generalization.
@stellaraccident, right, the PR is a naive attempt to go there. I assumed that this was an agreed-upon direction.
2. If named ops didn't have (what I consider) outdated semantics for broadcast handling, there would be a path where you could just lower to named ops without structurally changing the lowering. That is another way to make this work. It is also worth doing, but requires a broader agreement on direction.
This is the key here: define the semantics of the named ops. Landing the PR is secondary.
Can we agree that, IFF the named ops have broadcast/transpose semantics, we should decompose softmax into them instead of into generics?
How we get there is a matter for another RFC, but I want to make sure our efforts there will lead to this decision, agreed on by consensus.
An op might have multiple ways of decomposition that should be controllable
Agree. This is a more generic description of what I called partial generalization to reach the same end state as the current decomposition. How about we make this the first step? I can start with an RFC to collect the requirements, and we can team up on the design/implementation.
Let's leave the named ops discussion aside (there is an ongoing discussion here: https://discourse.llvm.org/t/rfc-transpose-attribute-for-linalg-matmul-operations/80092/36?u=maheshravishankar). But for this PR, let's maybe take a break; an RFC to allow controlling the decomposition would be great. I confess I have no great ideas, just some vague ones, so I am at a loss to really suggest how to do that.
FYI, we're working on a simplification of named ops with affine maps to avoid this problem, and I believe this is the solution for the current problem: https://github.com/plaidml/tpp-mlir/wiki/Linalg-matmul-with-affine-maps#implementation-details
After some thought and discussion, my view on this changed. I tried to write a proposal for what a mechanism for generic decompositions should look like. The more detail I added, the more it resembled regular rewrite patterns. At this point, it makes no sense to me to introduce yet another very similar mechanism (one restricted to a specific interface). If we are not changing the default decomposition, there's not much value in having additional ones upstream. Those can exist as rewrites downstream.
I'm closing this. Please let me know if there's anything I'm missing or there's still interest in additional decompositions.