
Make optimizer agnostic about ordering of associative operations

Open robertknight opened this issue 8 months ago • 3 comments

The GELU fusion will fuse X * 0.5 * (1 + Erf(X / Sqrt(2))) but not X * (1 + Erf(X / Sqrt(2))) * 0.5, even though these are equivalent. See note. More generally, the optimizer won't fuse a subgraph F(X, F(Y, Z)) if its pattern expects F(F(X, Y), Z), where F is an associative operation.

An example where this comes up is that the GELU operator is not fused in this model.

robertknight • Apr 21 '25 13:04

Some notes on canonicalization passes that compilers use to handle this:

  • https://gcc.gnu.org/onlinedocs/gccint/Insn-Canonicalizations.html
  • https://mlir.llvm.org/docs/Canonicalization/
  • https://www.npopov.com/2023/04/10/LLVM-Canonicalization-and-target-independence.html

robertknight • Jul 06 '25 17:07

I'm interested in contributing to this project! The issue description is clear and I have some ideas for implementation. Would love to discuss the approach.

shubhamos-ai • Aug 29 '25 09:08

Fusion is implemented by finding patterns in the graph using the pattern matcher in pattern_matcher.rs. This matcher doesn't understand associativity, so a pattern for Add(X, Add(Y, Z)) won't match a graph that looks like Add(Add(X, Y), Z), for example.
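To make the mismatch concrete, here is a minimal sketch of a purely structural matcher. The `Expr` type and the `sym`, `add` and `matches` helpers are hypothetical and only illustrate the idea; they are not the types or API in pattern_matcher.rs.

```rust
// Hypothetical expression tree for illustration only; this is not rten's
// graph representation or its pattern type.
#[derive(Debug, Clone, PartialEq)]
enum Expr {
    Sym(&'static str),
    Add(Box<Expr>, Box<Expr>),
}

fn sym(name: &'static str) -> Expr {
    Expr::Sym(name)
}

fn add(lhs: Expr, rhs: Expr) -> Expr {
    Expr::Add(Box::new(lhs), Box::new(rhs))
}

/// Purely structural match. A symbol in the pattern matches any subgraph,
/// while `Add` in the pattern only matches an `Add` whose operands match
/// position by position. Nothing here knows that `Add` is associative.
fn matches(pattern: &Expr, graph: &Expr) -> bool {
    match (pattern, graph) {
        (Expr::Sym(_), _) => true,
        (Expr::Add(pl, pr), Expr::Add(gl, gr)) => matches(pl, gl) && matches(pr, gr),
        _ => false,
    }
}

fn main() {
    // Pattern: Add(X, Add(Y, Z))
    let pattern = add(sym("X"), add(sym("Y"), sym("Z")));

    // Same association as the pattern.
    let right_nested = add(sym("a"), add(sym("b"), sym("c")));
    // Same operands, but associated to the left.
    let left_nested = add(add(sym("a"), sym("b")), sym("c"));

    println!("{}", matches(&pattern, &right_nested)); // prints "true"
    println!("{}", matches(&pattern, &left_nested)); // prints "false"
}
```

The second call fails because the pattern's inner `Add(Y, Z)` is compared against the leaf `c`, even though both graphs describe the same sum.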

To solve this, either the pattern matching needs to be made more flexible somehow, or the graph needs to be transformed into a predictable (canonical) form before pattern matching is applied, so that the patterns only need to match that canonical form. For example, the GCC docs linked above mention that associative operations always chain to the left after canonicalization. One complication here compared to a normal programming language is that operators like Add and Mul do broadcasting. I haven't thought about how that will affect what kinds of transformations can be done.
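As a rough sketch of the canonicalization option, the code below rewrites any nesting of `Add` into a left-leaning chain, in the spirit of the GCC rule. The `Node` type and the `input`, `add`, `flatten` and `canonicalize` helpers are hypothetical, not part of rten. The sketch assumes reassociation is always allowed: it ignores broadcasting, intermediate shapes and floating-point rounding, which a real pass would need to consider.

```rust
// Hypothetical graph node for illustration only; not rten's graph type.
#[derive(Debug, Clone, PartialEq)]
enum Node {
    Input(&'static str),
    Add(Box<Node>, Box<Node>),
}

fn input(name: &'static str) -> Node {
    Node::Input(name)
}

fn add(lhs: Node, rhs: Node) -> Node {
    Node::Add(Box::new(lhs), Box::new(rhs))
}

/// Collect the operands of a tree of nested `Add`s in left-to-right order.
fn flatten(node: Node, out: &mut Vec<Node>) {
    match node {
        Node::Add(lhs, rhs) => {
            flatten(*lhs, out);
            flatten(*rhs, out);
        }
        leaf => out.push(leaf),
    }
}

/// Rebuild the operands as a left-leaning chain: Add(Add(Add(a, b), c), d).
/// Assumption: reassociating is valid for every operand involved.
fn canonicalize(node: Node) -> Node {
    let mut operands = Vec::new();
    flatten(node, &mut operands);
    let mut iter = operands.into_iter();
    // `flatten` always pushes at least one operand.
    let first = iter.next().expect("at least one operand");
    iter.fold(first, add)
}

fn main() {
    let right_nested = add(input("x"), add(input("y"), input("z")));
    let left_nested = add(add(input("x"), input("y")), input("z"));

    // Both forms end up as the same left-leaning chain.
    assert_eq!(canonicalize(right_nested), left_nested.clone());
    assert_eq!(canonicalize(left_nested.clone()), left_nested);
    println!("{:?}", left_nested);
}
```

With a pass like this run before fusion, each pattern would only need to be written in the left-chained form.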

I would probably start by doing some research to see if there are good examples of how other machine learning runtimes / compilers solve this, and whether they have an approach that can be adapted.

robertknight • Aug 29 '25 10:08