[TIR, analysis] Add expr hash sort in ExprDeepEqual
See issue https://github.com/apache/tvm/issues/10211. The CSE pass cannot handle commutativity, because the arithmetic simplifier may not be able to apply commutativity.
Equality of two expressions (PrimExpr) is determined structurally: both expression trees are traversed in lockstep and compared node by node. If the two trees have the same structure and their leaves (nodes that cannot be traversed further, typically Var) are identical, the expressions are considered equal.
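The structural comparison described above can be sketched on a toy expression tree. This is a minimal Python illustration, assuming a leaf is a variable name (str) and an inner node is a tuple (op, lhs, rhs); it is not TVM's PrimExpr or ExprDeepEqual API.

```python
from typing import Tuple, Union

# Toy stand-in for PrimExpr: a leaf is a variable name (str),
# an inner node is a tuple (op, lhs, rhs).
Expr = Union[str, Tuple[str, "Expr", "Expr"]]


def deep_equal(x: Expr, y: Expr) -> bool:
    """Syntactic equality: walk both trees in lockstep, node by node."""
    if isinstance(x, str) or isinstance(y, str):
        # Leaves match only if both are leaves with the same name.
        return isinstance(x, str) and isinstance(y, str) and x == y
    op_x, lhs_x, rhs_x = x
    op_y, lhs_y, rhs_y = y
    # Same operator, then recurse into the children in order.
    return op_x == op_y and deep_equal(lhs_x, lhs_y) and deep_equal(rhs_x, rhs_y)


abc = ("*", ("*", "a", "b"), "c")  # a * b * c, parsed as (a * b) * c
acb = ("*", ("*", "a", "c"), "b")  # a * c * b
print(deep_equal(abc, ("*", ("*", "a", "b"), "c")))  # True
print(deep_equal(abc, acb))  # False: equal values, different structure
```

The second comparison is exactly the kind of false "not equal" that motivates this PR.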
Before PrimExprs are compared, a series of rewrite rules is applied to handle some algebraic properties; for example, x * y + x * z is rewritten as x * (y + z) to deal with distributivity. Commutativity, however, cannot be handled by such a rewrite rule, because a rule like x * y → y * x can always be applied again and the rewriting would fall into an infinite loop. As a result, some equal expressions are reported as different, such as a * b * c and a * c * b, or (a * b) * c and a * (b * c).
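To see why commutativity cannot simply be added as one more rewrite rule, here is a toy sketch (same hypothetical tuple representation, not TVM's rewriter). The rule x * y -> y * x matches its own output, so rewriting never reaches a fixed point, and a rewriter with a step budget returns whatever term it happened to hold when the budget ran out.

```python
from typing import Callable, Tuple, Union

Expr = Union[str, Tuple[str, "Expr", "Expr"]]


def commute(e: Expr) -> Expr:
    """Rule x * y -> y * x, applied at the root whenever it matches."""
    if isinstance(e, tuple) and e[0] == "*":
        return ("*", e[2], e[1])
    return e


def rewrite(e: Expr, rule: Callable[[Expr], Expr], max_steps: int = 10):
    """Apply `rule` until nothing changes, or give up after `max_steps`."""
    for step in range(max_steps):
        new = rule(e)
        if new == e:
            return e, step, True  # fixed point reached
        e = new
    return e, max_steps, False  # budget exhausted without a fixed point


print(rewrite(("*", "b", "a"), commute))     # (('*', 'b', 'a'), 10, False)
print(rewrite(("*", "b", "a"), commute, 9))  # (('*', 'a', 'b'), 9, False)
```

The result depends only on the parity of the step budget. The usual way around this is ordered rewriting: swap operands only when they are out of order under some fixed total order, which is essentially what the sorting approach in this PR amounts to.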
One way to solve this is to rewrite the expressions before comparing them, sorting the operands of commutative operators according to the StructuralHash of the Var nodes they contain. If two expressions are equal up to commutativity, sorting turns them into expressions with the same structure.
Put differently: assuming the two expressions have the same structure outside their commutative sub-expressions, they are equal modulo commutativity when each pair of corresponding commutative sub-expressions has the same multiset of operands. The commutative sub-expressions can be located in the expression's syntax tree: they are the subtrees whose inner nodes all use the same commutative operator.
An example:
Variables: a, b, c, d, e
Assume their StructuralHash values are ordered a > b > c > d > e.
cse_var_1 = a * b * c + d + e
cse_var_2 = b * a * c + e + d
Sort and rewrite cse_var_1 and cse_var_2: first extract their innermost commutative sub-expressions, a * b * c and b * a * c, and rewrite both as a * b * c; then extract the enclosing sub-expressions, (a * b * c) + d + e and (a * b * c) + e + d, and sort both into (a * b * c) + d + e. Both variables end up as the same expression.
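The sort-and-rewrite idea from the example can be sketched as follows. This is a hedged Python illustration on a toy tuple representation, not the PR's C++ implementation. Two assumptions: repr() of a canonicalized operand is used as the sort key in place of StructuralHash, and + and * are treated as associative and commutative, which holds for integer arithmetic but not exactly for floating point.

```python
from typing import List, Tuple, Union

# Toy stand-in for PrimExpr: variable name (str) or (op, lhs, rhs).
Expr = Union[str, Tuple[str, "Expr", "Expr"]]
# Assumption: these operators are associative and commutative (true for ints).
AC_OPS = {"+", "*"}


def flatten(e: Expr, op: str) -> List[Expr]:
    """Operands of the maximal subtree whose inner nodes all use `op`."""
    if isinstance(e, tuple) and e[0] == op:
        return flatten(e[1], op) + flatten(e[2], op)
    return [e]


def canonicalize(e: Expr) -> Expr:
    """Rewrite e so that operands of + and * appear in a fixed order."""
    if isinstance(e, str):
        return e
    op, lhs, rhs = e
    if op not in AC_OPS:
        return (op, canonicalize(lhs), canonicalize(rhs))
    # Canonicalize operands bottom-up, then sort them; repr() stands in
    # for StructuralHash as the sort key.
    operands = sorted((canonicalize(x) for x in flatten(e, op)), key=repr)
    # Rebuild a left-leaning chain: ((o0 op o1) op o2) ...
    result = operands[0]
    for x in operands[1:]:
        result = (op, result, x)
    return result


def commutative_equal(x: Expr, y: Expr) -> bool:
    return canonicalize(x) == canonicalize(y)


cse_var_1 = ("+", ("+", ("*", ("*", "a", "b"), "c"), "d"), "e")  # a*b*c + d + e
cse_var_2 = ("+", ("+", ("*", ("*", "b", "a"), "c"), "e"), "d")  # b*a*c + e + d
print(commutative_equal(cse_var_1, cse_var_2))  # True
print(canonicalize(cse_var_1))  # ('+', ('+', 'd', 'e'), ('*', ('*', 'a', 'b'), 'c'))
```

The canonical order here ((d + e) + a * b * c) differs from the one in the example, which is fine: any fixed order works. Because repr() gives a total order on these toy terms, equal operands always land in the same position. With a hash as the sort key, two distinct operands with the same hash would keep their input order (Python's sort is stable), which can make equal expressions look different; the final structural comparison still rules out false positives.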
cc @masahi @Hzfengsy @tqchen @FranckQC
I can take a look at this tomorrow.
It would be useful to make this a separate equality comparator rather than changing DeepEqual's behavior (e.g. we could add a CommutativeDeepEqual subclass), as commutative rewriting is something that goes deeper.
Another possibility is to add a canonicalization pass that canonicalizes the expressions before CSE.
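The subclass design suggested here could look roughly like the sketch below: the base comparator stays purely syntactic, and a subclass overrides only a normalization hook. The class names ExprEqual and CommutativeExprEqual are hypothetical Python stand-ins (the PR's comparator is a C++ class, called CommutativeDeepEqual later in the thread), and the tuple representation is a toy.

```python
from typing import List, Tuple, Union

Expr = Union[str, Tuple[str, "Expr", "Expr"]]


class ExprEqual:
    """Purely syntactic equality (stand-in for ExprDeepEqual, left unchanged)."""

    def key(self, e: Expr) -> object:
        return e  # no normalization: compare the trees as written

    def __call__(self, x: Expr, y: Expr) -> bool:
        return self.key(x) == self.key(y)


class CommutativeExprEqual(ExprEqual):
    """Hypothetical subclass: equality modulo commutativity of + and *."""

    AC_OPS = {"+", "*"}

    def _operands(self, e: Expr, op: str) -> List[Expr]:
        if isinstance(e, tuple) and e[0] == op:
            return self._operands(e[1], op) + self._operands(e[2], op)
        return [e]

    def key(self, e: Expr) -> object:
        if isinstance(e, str):
            return e
        op, lhs, rhs = e
        if op not in self.AC_OPS:
            return (op, self.key(lhs), self.key(rhs))
        # n-ary node: operand keys sorted into a fixed order
        keys = sorted((self.key(x) for x in self._operands(e, op)), key=repr)
        return (op, tuple(keys))


lhs = ("+", ("*", "a", "b"), "c")  # a * b + c
rhs = ("+", "c", ("*", "b", "a"))  # c + b * a
print(ExprEqual()(lhs, rhs))             # False
print(CommutativeExprEqual()(lhs, rhs))  # True
```

A canonicalization pass would instead rewrite the expressions themselves once, before CSE, so plain syntactic equality suffices afterwards; the trade-off is that it changes the IR rather than only the comparison.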
Hello!
Thank you for the discussion and the PR. Just a few thoughts here:
1. Indeed, we knew that the existing Analyzer::Simplify() would not deal with commutativity, because it does not implement a normalization procedure that is guaranteed to converge to a normal form (which would indeed imply sorting sub-terms, etc.). All it does is try to "do its best" by rewriting some known patterns, with no guarantee of converging to a normal form. For this reason, commutativity could not be handled in Simplify(), because that would lead to non-terminating rewrite sequences (or, more realistically, return junk, since in practice rewriting stops after N rewrites). It often works fairly well in practice, but there is no guarantee of completeness (it behaves like a heuristic). However, it really must be correct (i.e., the result of Simplify() must really be equivalent to its input according to the algebraic laws).
2. Before trying to make Analyzer::Simplify() able to deal with commutativity, it could be useful to check whether people are in practice facing the issue of many redundant computations that are written differently because of the commutativity of some operators (like + and *). If so, it would be great to see concrete examples. To be honest, I don't even think that many people turn ON the already existing bool identify_equiv_terms of the CSE pass (Pass CommonSubexprElimTIR), which uses Analyzer::Simplify(), which itself does what it can with associativity, neutral elements, etc. These situations are probably pretty rare, and so is commutativity.
Although I designed the pass so that it can potentially identify terms that are equivalent according to any equivalence relation (instead of just the syntactic equality ExprDeepEqual), this might not often be needed in practice. Supporting a custom equivalence relation did not cost much more, so I went for it, in case it becomes useful to someone one day for a particular case, so that nobody would need to write another CSE pass for that. But I don't necessarily think that we should do comparisons modulo equivalence often, and this is why, by default, bool identify_equiv_terms of Pass CommonSubexprElimTIR is set to false.
3. If we decide that Analyzer::Simplify() (or another new Analyzer!) should deal with commutativity, then it should handle commutativity itself, rather than baking commutativity into ExprDeepEqual, which is supposed to be just a deep syntactic equality and is used as such in many places in TVM's codebase (not just the CSE). So it clearly should not be changed (as @tqchen noted too).
4. Remember that properly normalizing terms to deal with commutativity (which indeed includes sorting sub-terms) will likely be computationally expensive, which will make the compilation of ML models longer. Even the smaller amount of work that Analyzer::Simplify() does is already time-consuming, and that is probably why people leave bool identify_equiv_terms of the CSE pass set to false (the default behavior, as I wrote earlier). This might make people even less willing to turn ON bool identify_equiv_terms. Perhaps the pseudo-normalization that Analyzer::Simplify() does is not too bad as a compromise: still usable in practice when needed (i.e., it does not take too long), and it handles most of the needed simplifications (although not commutativity).
5. If all the other algebraic properties (associativity, simplification of neutral elements, etc.) are still handled by the pseudo-normalization of Analyzer::Simplify(), which is not guaranteed to find a normal form, I am not sure that a "normalizer for commutativity" built on top of it would be complete, even just with regard to commutativity. Is it worth making Analyzer::Simplify() slower while still being incomplete?
Thanks!
Hi @FranckQC, thanks for the questions and comments.
I agree with what you said about the cost of Analyzer::Simplify(): its overhead is already high. Changing ExprDeepEqual would not improve its functionality and might slow it down, so making this a subclass, as @tqchen suggested, may be a good approach.
This submission came from my own earlier use case: I needed to check the input and output sizes of an operator containing a reshape operation, and I rewrote deep equal to meet that need. After seeing your issue, I thought this rewrite could be helpful, so I submitted this PR.
Hi @tqchen, thanks a lot for the review! I agree that it would be useful to make this a separate equality comparator; that way it will not affect TVM's existing infrastructure. How about making it a new subclass, as you suggested?
Hi @tqchen, I have implemented the function as a subclass. Would you take a look and merge it if everything looks good?
@FranckQC @masahi, can you help take a look?
Sure, will do on Monday if that's ok :)
Of course.
Hi, @FranckQC @masahi. Is there anything in the code that I need to update?
Hi, @FranckQC @masahi. If nothing needs to be changed in this PR, can you help merge it?
Thanks for contributing to TVM! Please refer to the contributing guidelines https://tvm.apache.org/docs/contribute/ for useful information and tips. Please request code reviews from Reviewers by @-ing them in a comment.
- Built docs for commit 3d9e540c5807cfbc72b4cd6bae6c5adc97ffc835 can be found here.
Generated by tvm-bot
Please update the PR title and description to reflect the current status. In particular, please make it more concise and explain what the goal is clearly.
For example, rather than "Add expr hash sort in ExprDeepEqual", explain why you want to do this.
I was finally able to take some time this week to have a closer look at this PR. It is definitely better than it was before, as it leaves the current DeepEqual unchanged, which was very important for me. That was point 3 in my earlier comment, which I consider addressed, thank you :).
However, I agree with @masahi's comments. The implementation could have a lot more comments and documentation about the variables being used, what the functions do, and what each part of the algorithm does. It would help a lot when reading the code, which increases the confidence one can have in the implementation.
It's great that there are quite a lot of tests; thanks for taking the time to write so many of them. However, I'd also like to see a real usage for this new equivalence relation (that was point 2 in my earlier comment). TVM is a compiler, not a tool for just doing algebraic manipulations of mathematical terms like Matlab, so I would really like to see some natural use cases for this, where it gets used or integrated into a pass, or into something else that ultimately leads to improvements in the code produced when compiling some ML models.
A more minor point: I'd like to know how one is supposed to use this CommutativeDeepEqual together with the Analyzer::Simplify() function, which performs other kinds of simplification (simplification of neutral elements, applying distributivity, etc., but which unfortunately can't handle commutativity, as discussed earlier in the thread), in order to get a function that uses all the available algebraic properties. I imagine it would call Simplify() on both sides and then use this new CommutativeDeepEqual. Would that be enough to be complete?
The reason behind this is the following: I believe most of the people interested in equality modulo commutativity will come here after discovering that Analyzer::Simplify() can't do all the simplifications for them. So when they learn that there is this CommutativeDeepEqual equivalence relation for dealing with commutativity, their first question will likely be "how do I combine both?". So I think demonstrating that could be useful.
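One way to read the "call Simplify() on both sides, then compare modulo commutativity" idea is the toy sketch below. Everything in it is hypothetical: toy_simplify stands in for Analyzer::Simplify() and knows only a few syntactic rules (neutral elements of + and *, and x - x), and commutative_key stands in for a CommutativeDeepEqual-style comparison. It says nothing about how TVM's real simplifier behaves, but it shows how such a composition can miss equalities.

```python
from typing import List, Tuple, Union

# Toy expressions: str = variable, int = constant, tuple = (op, lhs, rhs).
Expr = Union[str, int, Tuple[str, "Expr", "Expr"]]


def toy_simplify(e: Expr) -> Expr:
    """Hypothetical stand-in for Analyzer::Simplify(): syntactic rules only."""
    if not isinstance(e, tuple):
        return e
    op, lhs, rhs = e[0], toy_simplify(e[1]), toy_simplify(e[2])
    if op == "-" and lhs == rhs:  # x - x -> 0, only if syntactically identical
        return 0
    if op == "*" and rhs == 1:    # x * 1 -> x
        return lhs
    if op == "*" and lhs == 1:    # 1 * x -> x
        return rhs
    if op == "+" and rhs == 0:    # x + 0 -> x
        return lhs
    if op == "+" and lhs == 0:    # 0 + x -> x
        return rhs
    return (op, lhs, rhs)


def operands(e: Expr, op: str) -> List[Expr]:
    """Operands of the maximal subtree whose inner nodes all use `op`."""
    if isinstance(e, tuple) and e[0] == op:
        return operands(e[1], op) + operands(e[2], op)
    return [e]


def commutative_key(e: Expr) -> object:
    """Canonical key modulo commutativity and associativity of + and *."""
    if not isinstance(e, tuple):
        return e
    op, lhs, rhs = e
    if op in ("+", "*"):
        keys = sorted((commutative_key(x) for x in operands(e, op)), key=repr)
        return (op, tuple(keys))
    return (op, commutative_key(lhs), commutative_key(rhs))


def equal_modulo_algebra(x: Expr, y: Expr) -> bool:
    # Simplify both sides first, then compare modulo commutativity.
    return commutative_key(toy_simplify(x)) == commutative_key(toy_simplify(y))


print(equal_modulo_algebra(("*", ("*", "a", 1), "b"), ("*", "b", "a")))  # True
print(equal_modulo_algebra(("-", ("+", "a", "b"), ("+", "b", "a")), 0))  # False
```

The second call shows the gap: the x - x rule only fires on syntactically identical operands, so (a + b) - (b + a) is never reduced to 0 even though commutativity makes the two operands equal. In this toy, canonicalizing before or during simplification would close that particular gap, but it would not by itself make the combination complete.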
The most important things for me at this stage of the PR are adding comments to the code and showing a real use case or integration for this. Thank you for your work!