tvm
[Draft][Transform] Check for zero-param operators in LiftTransformParams
Prior to this commit, LiftTransformParams would extract all variable bindings that have no runtime dependencies. As a result, expressions such as R.zeros([16], "int32") would be extracted into the parameter transformation, even though they do not depend on any parameters.
This commit updates LiftTransformParams to output only those variables that depend on at least one compile-time parameter.
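The difference between the old and new rule can be illustrated with a toy model in plain Python. This is a sketch under stated assumptions, not TVM code: `Binding`, `lifted_outputs`, and the `require_weight_dep` flag are hypothetical, a function is modelled as an ordered list of bindings that name the variables they use, and `zeros` stands in for R.zeros([16], "int32").

```python
from dataclasses import dataclass, field


@dataclass
class Binding:
    """Hypothetical stand-in for a Relax variable binding (not a TVM class)."""
    var: str
    op: str
    args: list = field(default_factory=list)  # names of params/vars used


def lifted_outputs(bindings, runtime_params, weights, require_weight_dep=True):
    """Return the vars that would become outputs of transform_params.

    require_weight_dep=False models the old rule (lift anything without a
    runtime dependency); True models the new check for a weight dependency.
    """
    needs_runtime, needs_weights = set(runtime_params), set(weights)
    outputs = []
    for b in bindings:  # assumed to be in definition (topological) order
        if any(a in needs_runtime for a in b.args):
            needs_runtime.add(b.var)  # depends on runtime input: never lifted
            continue
        if any(a in needs_weights for a in b.args):
            needs_weights.add(b.var)  # transitively depends on a weight
        if b.var in needs_weights or not require_weight_dep:
            outputs.append(b.var)
    return outputs


bindings = [
    Binding("w_t", "permute_dims", ["weight"]),
    Binding("zeros", "zeros"),  # zero-param operator, like R.zeros
    Binding("w_pad", "add", ["w_t", "zeros"]),
    Binding("y", "matmul", ["x", "w_pad"]),
]
old = lifted_outputs(bindings, ["x"], ["weight"], require_weight_dep=False)
new = lifted_outputs(bindings, ["x"], ["weight"])
print(old)  # ['w_t', 'zeros', 'w_pad']
print(new)  # ['w_t', 'w_pad']
```

Under the new rule, `zeros` is no longer an output of its own, although it can still be computed inside the parameter transformation as an input to `w_pad`.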
The unit test for this functionality also revealed that relax::Call was erroneously calling MarkGraphNode in SEqualReduce and SHashReduce. MarkGraphNode should only be called for nodes that have reference equality, such as relax::Var, and not for composite objects. This caused spurious failures in the unit test, where two instances of R.zeros([16], "int32") were compared by reference equality in StructuralEqual. These extra calls to MarkGraphNode have been removed.
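To see why marking a composite node as a graph node breaks structural equality, here is a toy model in plain Python. It is only an illustration of the general idea and does not reproduce TVM's SEqualReduce/SHashReduce: the classes `Var` and `Call`, the function `struct_equal`, and its `mark_calls` flag are all hypothetical. Graph nodes must map one-to-one between the two sides, which is right for variables but wrong for calls, because one shared call on the left can legitimately match two separate but identical calls on the right.

```python
class Var:
    """Toy variable: compared by reference, i.e. treated as a graph node."""
    def __init__(self, name):
        self.name = name


class Call:
    """Toy call expression: a composite object compared by its contents."""
    def __init__(self, op, *args):
        self.op = op
        self.args = args


def struct_equal(lhs, rhs, mark_calls=False):
    lhs_to_rhs, rhs_to_lhs = {}, {}

    def mark_graph_node(a, b):
        # Graph nodes must map one-to-one: once `a` is paired with `b`,
        # neither may later be paired with a different node.
        if id(a) in lhs_to_rhs or id(b) in rhs_to_lhs:
            return lhs_to_rhs.get(id(a)) is b and rhs_to_lhs.get(id(b)) is a
        lhs_to_rhs[id(a)] = b
        rhs_to_lhs[id(b)] = a
        return True

    def eq(a, b):
        if type(a) is not type(b):
            return False
        if isinstance(a, Var):
            return mark_graph_node(a, b)
        if isinstance(a, Call):
            # mark_calls=True models the erroneous MarkGraphNode on calls.
            if mark_calls and not mark_graph_node(a, b):
                return False
            return (a.op == b.op and len(a.args) == len(b.args)
                    and all(eq(x, y) for x, y in zip(a.args, b.args)))
        if isinstance(a, tuple):
            return len(a) == len(b) and all(eq(x, y) for x, y in zip(a, b))
        return a == b  # leaf values such as ints and strings

    return eq(lhs, rhs)


z = Call("zeros", (16,), "int32")
shared = (z, z)  # one instance used twice
copies = (Call("zeros", (16,), "int32"), Call("zeros", (16,), "int32"))
print(struct_equal(shared, copies))                   # True
print(struct_equal(shared, copies, mark_calls=True))  # False

x, y1, y2 = Var("x"), Var("y1"), Var("y2")
print(struct_equal(Call("add", x, x), Call("add", y1, y2)))  # False: Vars map one-to-one
```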
This PR depends on changes made in https://github.com/apache/tvm/pull/16594, and is marked as a draft until it lands.
As a result, expressions such as R.zeros([16], "int32") would be extracted out into the parameter transformation, even though they do not depend on any parameters.
Does this affect the result? If a weight transformation depends on values such as R.zeros, that transformation will no longer be lifted if R.zeros itself is not lifted. Maybe it would be better to check for such dependencies.
Good question. This case should still be handled, because zero-param operators are allowed to appear in both functions. (See this unit test for how this looks in practice.) The case isn't handled explicitly; instead, it falls out of the overall lifting procedure:
- Add every variable binding that doesn't require runtime parameters to the lifted transform_params function.
- For every variable binding that is present in transform_params and depends on the model weights, replace it with the output from transform_params.
- Run dead-code elimination on both functions.
Since the zero-param operators don't depend on runtime parameters, they appear in transform_params. Since they don't depend on the model weights, they aren't replaced in the original function. If they are required in only one of the two functions, dead-code elimination removes them from the other.
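The three steps can be sketched on the same kind of toy model to show a zero-param operator ending up in both functions. This is a hedged illustration in plain Python, not the LiftTransformParams implementation: `Binding`, `dead_code_elim`, and `lift_transform_params` are hypothetical, "replace with the output from transform_params" is modelled by dropping the binding from the main function, and the choice to keep only the outputs that the main function still uses is an assumption of this sketch.

```python
from dataclasses import dataclass, field


@dataclass
class Binding:
    """Hypothetical stand-in for a Relax variable binding (not a TVM class)."""
    var: str
    op: str
    args: list = field(default_factory=list)


def dead_code_elim(bindings, roots):
    """Keep only bindings reachable from `roots`, preserving their order."""
    live, kept = set(roots), []
    for b in reversed(bindings):
        if b.var in live:
            kept.append(b)
            live.update(b.args)
    return kept[::-1]


def lift_transform_params(bindings, runtime_params, weights, output):
    needs_runtime, needs_weights = set(runtime_params), set(weights)
    for b in bindings:  # assumed to be in definition (topological) order
        if any(a in needs_runtime for a in b.args):
            needs_runtime.add(b.var)
        if any(a in needs_weights for a in b.args):
            needs_weights.add(b.var)

    # Step 1: every binding without a runtime dependency goes to transform_params.
    transform = [b for b in bindings if b.var not in needs_runtime]
    # Step 2: weight-dependent lifted bindings are replaced in main; here the
    # main function simply receives them as inputs instead of computing them.
    replaced = {b.var for b in transform if b.var in needs_weights}
    main = [b for b in bindings if b.var not in replaced]
    # Step 3: dead-code elimination on both functions.
    main = dead_code_elim(main, [output])
    used_in_main = {a for b in main for a in b.args} | {output}
    outputs = [b.var for b in transform if b.var in replaced and b.var in used_in_main]
    transform = dead_code_elim(transform, outputs)
    return transform, outputs, main


bindings = [
    Binding("w_t", "permute_dims", ["weight"]),
    Binding("zeros", "zeros"),  # needed by a weight transform and at runtime
    Binding("w_pad", "add", ["w_t", "zeros"]),
    Binding("ones", "ones"),  # needed only at runtime
    Binding("y", "matmul", ["x", "w_pad"]),
    Binding("z", "add", ["y", "zeros"]),
    Binding("out", "multiply", ["z", "ones"]),
]
transform, outputs, main = lift_transform_params(bindings, ["x"], ["weight"], "out")
print([b.var for b in transform])  # ['w_t', 'zeros', 'w_pad']
print(outputs)                     # ['w_pad']
print([b.var for b in main])       # ['zeros', 'ones', 'y', 'z', 'out']
```

In this example `zeros` survives in both functions because each one uses it, while `ones` is removed from transform_params by dead-code elimination since only the main function needs it.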
I realize it's a draft, but I had a look anyway. Since the diff also included the changes from https://github.com/apache/tvm/pull/16594, it was a little difficult to see what had changed.
Apologies for that. Once #16594 lands, the "Files Changed" tab in this PR should update to show only the changes unique to this PR. In the meantime, this PR branch keeps its changes in a separate commit (link), where they can be viewed separately from the #16594 changes.
Thank you for the changes. The new unit test and the new comment are both helpful.
With #16594 landed, I've rebased this PR on top of it, and it is now ready for review.