Make optimizer agnostic about ordering of associative operations
The GELU fusion will fuse X * 0.5 * (1 + Erf(X / Sqrt(2))) but not X * (1 + Erf(X / Sqrt(2))) * 0.5, even though these are equivalent. See note. More generally the optimizer won't fuse a subgraph F(X, F(Y, Z)) if its pattern expects F(F(X, Y), Z), where F is an associative operation.
An example where this comes up is that the GELU operator is not fused in this model.
Some notes on canonicalization passes that compilers use to handle this:
- https://gcc.gnu.org/onlinedocs/gccint/Insn-Canonicalizations.html
- https://mlir.llvm.org/docs/Canonicalization/
- https://www.npopov.com/2023/04/10/LLVM-Canonicalization-and-target-independence.html
I'm interested in contributing to this project! The issue description is clear and I have some ideas for implementation. Would love to discuss the approach.
Fusion is implemented by finding patterns in the graph using a pattern matcher implemented in pattern_matcher.rs. This matcher doesn't understand associativity, so a pattern for Add(X, Add(Y, Z)) won't match a graph that looks like Add(Add(X, Y), Z), for example.
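To illustrate the limitation, here is a toy purely structural matcher on a minimal expression type. This is a hypothetical sketch, not the actual types or API in pattern_matcher.rs: `Expr` and `matches` are made up for the example.

```rust
// Toy expression tree; the real graph/pattern types are more elaborate.
#[derive(Debug)]
enum Expr {
    Var(&'static str),
    Add(Box<Expr>, Box<Expr>),
}

use Expr::*;

/// Purely structural matching: the tree shapes must line up exactly.
fn matches(pattern: &Expr, graph: &Expr) -> bool {
    match (pattern, graph) {
        // A pattern variable matches any subgraph.
        (Var(_), _) => true,
        // An Add pattern only matches an Add node with matching children.
        (Add(pl, pr), Add(gl, gr)) => matches(pl, gl) && matches(pr, gr),
        _ => false,
    }
}

fn main() {
    // Pattern: Add(X, Add(Y, Z))
    let pattern = Add(
        Box::new(Var("X")),
        Box::new(Add(Box::new(Var("Y")), Box::new(Var("Z")))),
    );
    // Graph: Add(Add(X, Y), Z) -- mathematically equivalent, structurally different.
    let graph = Add(
        Box::new(Add(Box::new(Var("X")), Box::new(Var("Y")))),
        Box::new(Var("Z")),
    );
    // The right child of the pattern is an Add, but the right child of the
    // graph is a leaf, so structural matching fails.
    println!("{}", matches(&pattern, &graph)); // prints "false"
}
```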
To solve this, either the pattern matching needs to be made more flexible somehow, or the graph needs to be transformed into a predictable (canonical) form before pattern matching is applied, with patterns written to match only that canonical form. For example, the GCC docs linked above specify that associative operations always chain to the left after canonicalization.

One complication here compared to a normal programming language is that operators like Add and Mul do broadcasting. I haven't thought through how that affects which transformations are valid.
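As a sketch of the canonicalization route, a pass that re-chains nested operations to the left (following the GCC convention above) could look like the following. This is hypothetical code over a toy `Expr` type, not the project's actual graph representation, and it sidesteps the broadcasting question by assuming scalar adds.

```rust
#[derive(Debug, PartialEq)]
enum Expr {
    Var(&'static str),
    Add(Box<Expr>, Box<Expr>),
}

use Expr::*;

/// Rewrite Add(a, Add(b, c)) into Add(Add(a, b), c) until every Add chain
/// associates to the left. Sound for scalar adds; with tensor broadcasting
/// the regrouped intermediate may have a different shape, so a real pass
/// would need a legality check before each rotation.
fn canonicalize(e: Expr) -> Expr {
    match e {
        Add(a, b) => {
            let mut a = *a;
            let mut b = *b;
            // Rotate right-nested adds to the left:
            // Add(a, Add(bl, br)) -> Add(Add(a, bl), br)
            while let Add(bl, br) = b {
                a = Add(Box::new(a), bl);
                b = *br;
            }
            // Recurse into the (now strictly smaller) children.
            Add(Box::new(canonicalize(a)), Box::new(canonicalize(b)))
        }
        other => other,
    }
}

fn main() {
    // Add(x, Add(y, z)) canonicalizes to the left-chained Add(Add(x, y), z),
    // so a single left-chained pattern now covers both orderings.
    let input = Add(
        Box::new(Var("x")),
        Box::new(Add(Box::new(Var("y")), Box::new(Var("z")))),
    );
    println!("{:?}", canonicalize(input));
}
```

After such a pass, fusion patterns would only need to be written in the left-chained form.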
I would probably start by doing some research to see if there are good examples of how other machine learning runtimes / compilers solve this, and whether they have an approach that can be adapted.