tvm
[Draft][Transform] Check for zero-param operators in LiftTransformParams
Prior to this commit, LiftTransformParams would extract all variable bindings that have no runtime dependencies. As a result, expressions such as R.zeros([16], "int32") would be extracted into the parameter transformation, even though they do not depend on any parameters.
This commit updates LiftTransformParams to only output variables that depend on at least one compile-time parameter.
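The new lifting rule can be sketched with a toy dataflow model. This is a minimal illustration under stated assumptions, not TVM's implementation: the binding tuples, the classify helper, and the op names are hypothetical. Bindings are visited in definition order and labeled by what they transitively depend on. Only bindings labeled "param" are eligible to become outputs of the parameter transformation, while a zero-param binding such as an R.zeros call is not.

```python
from typing import Dict, List, Set, Tuple

# Hypothetical toy IR (not TVM's API): (variable name, op name, argument names).
Binding = Tuple[str, str, Tuple[str, ...]]


def classify(bindings: List[Binding], runtime_inputs: Set[str],
             weights: Set[str]) -> Dict[str, str]:
    """Label each binding as 'runtime', 'param', or 'zero-param'."""
    depends_on_runtime: Set[str] = set(runtime_inputs)
    depends_on_weights: Set[str] = set(weights)
    labels: Dict[str, str] = {}
    # Bindings are assumed to be listed in definition (topological) order.
    for var, _op, args in bindings:
        if any(a in depends_on_runtime for a in args):
            depends_on_runtime.add(var)
            labels[var] = "runtime"
        elif any(a in depends_on_weights for a in args):
            depends_on_weights.add(var)
            labels[var] = "param"
        else:
            # No runtime input and no weight, e.g. R.zeros([16], "int32"),
            # whose arguments are all literals in this toy model.
            labels[var] = "zero-param"
    return labels


bindings = [
    ("zeros", "R.zeros", ()),
    ("w_t", "R.permute_dims", ("weight",)),
    ("w_pad", "R.add", ("w_t", "zeros")),
    ("y", "R.matmul", ("x", "w_pad")),
]
labels = classify(bindings, runtime_inputs={"x"}, weights={"weight"})

# Old rule: anything without a runtime dependency was a lifting candidate.
old_candidates = [v for v, label in labels.items() if label != "runtime"]
# New rule: only bindings that depend on at least one compile-time parameter.
new_candidates = [v for v, label in labels.items() if label == "param"]
print(old_candidates)  # ['zeros', 'w_t', 'w_pad']
print(new_candidates)  # ['w_t', 'w_pad']
```

Note that w_pad still counts as "param" even though it also consumes zeros: one weight dependency is enough.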
The unit test for this functionality also found that relax::Call was erroneously calling MarkGraphNode in SEqualReduce and SHashReduce. MarkGraphNode should only be called for nodes that have reference equality, such as relax::Var, and not for composite objects. This caused erroneous failures in the unit test when two instances of R.zeros([16], "int32") were compared by reference equality in StructuralEqual. These extra calls to MarkGraphNode have been removed.
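The failure mode can be illustrated with a toy structural-equality checker. This is a hypothetical sketch, not TVM's SEqualReducer: it assumes that a node marked as a graph node must be paired one-to-one with its counterpart, and that every later occurrence must reuse the same pairing. That is the right rule for variables. When calls are marked the same way, however, a single R.zeros object used twice on one side no longer matches two separate but identical R.zeros objects on the other.

```python
from dataclasses import dataclass
from typing import Any, Dict, Tuple


@dataclass(eq=False)  # eq=False keeps identity-based ==, like an IR node handle
class Var:
    name: str


@dataclass(eq=False)
class Call:
    op: str
    args: Tuple[Any, ...]


def struct_equal(lhs: Any, rhs: Any, mark_calls_as_graph_nodes: bool) -> bool:
    """Toy structural equality; graph nodes must map one-to-one lhs <-> rhs."""
    fwd: Dict[int, Any] = {}
    bwd: Dict[int, Any] = {}

    def map_graph_node(a: Any, b: Any) -> bool:
        # Once a graph node is paired with a partner, every later occurrence
        # must be paired with the same partner, in both directions.
        if id(a) in fwd or id(b) in bwd:
            return fwd.get(id(a)) is b and bwd.get(id(b)) is a
        fwd[id(a)] = b
        bwd[id(b)] = a
        return True

    def visit(a: Any, b: Any) -> bool:
        if type(a) is not type(b):
            return False
        if isinstance(a, Var):
            return map_graph_node(a, b)  # variables are always graph nodes
        if isinstance(a, Call):
            if a.op != b.op or len(a.args) != len(b.args):
                return False
            if mark_calls_as_graph_nodes and not map_graph_node(a, b):
                return False
            return all(visit(x, y) for x, y in zip(a.args, b.args))
        if isinstance(a, tuple):
            return len(a) == len(b) and all(visit(x, y) for x, y in zip(a, b))
        return a == b  # leaf values such as ints and strings

    return visit(lhs, rhs)


def make_zeros() -> Call:
    return Call("R.zeros", ((16,), "int32"))


shared = make_zeros()
lhs = (shared, shared)              # one object used in two places
rhs = (make_zeros(), make_zeros())  # two distinct but identical objects
print(struct_equal(lhs, rhs, mark_calls_as_graph_nodes=True))   # False
print(struct_equal(lhs, rhs, mark_calls_as_graph_nodes=False))  # True
```

Variables keep graph-node semantics in both modes, so (x, x) still fails to match two different variables, which is the behavior reference-equality nodes need.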
This PR depends on changes made in https://github.com/apache/tvm/pull/16594, and is marked as a draft until it lands.
As a result, expressions such as R.zeros([16], "int32") would be extracted out into the parameter transformation, even though they do not depend on any parameters.
Does this affect the result? If a weight transformation depends on some value like R.zeros, that transformation will no longer be lifted if R.zeros is not lifted. Maybe it would be better to check for such dependencies.
Good question. This case should be handled by allowing zero-param operators to appear in both functions. (See this unit test for how this looks in practice.) The case isn't handled explicitly anywhere; instead, it falls out of the overall lifting procedure:
- Add every variable binding that doesn't require runtime parameters to the lifted transform_params function.
- For every variable binding that is present in transform_params and depends on the model weights, replace it with the corresponding output from transform_params.
- Run dead-code elimination on both functions.
Since the zero-param operators don't depend on runtime parameters, they appear in transform_params. Since they don't depend on the model weights, they aren't replaced in the original function. If they are only required in one of the two functions, dead-code elimination removes them from the other.
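The three steps can be sketched in a toy model to show how a zero-param binding ends up in one or both functions. This is a simplified illustration under stated assumptions, not the actual pass: the binding tuples, lift, and dead_code_elim are hypothetical, and "replacing" a binding in the original function is modeled by deleting it so that the function reads the value as an output of transform_params. The usage follows the case raised in review, where a weight transformation consumes R.zeros.

```python
from typing import List, Set, Tuple

# Hypothetical toy IR (not TVM's API): (variable name, op name, argument names).
Binding = Tuple[str, str, Tuple[str, ...]]


def dead_code_elim(bindings: List[Binding], outputs: List[str]) -> List[Binding]:
    """Keep only bindings that transitively feed the given outputs."""
    live: Set[str] = set(outputs)
    kept: List[Binding] = []
    for var, op, args in reversed(bindings):
        if var in live:
            kept.append((var, op, args))
            live.update(args)
    return kept[::-1]


def lift(bindings: List[Binding], runtime_inputs: Set[str], weights: Set[str],
         main_outputs: List[str]) -> Tuple[List[Binding], List[str], List[Binding]]:
    runtime_dep: Set[str] = set(runtime_inputs)
    weight_dep: Set[str] = set(weights)
    transform: List[Binding] = []
    # Step 1: every binding without a runtime dependency goes to transform_params.
    for var, op, args in bindings:
        if any(a in runtime_dep for a in args):
            runtime_dep.add(var)
            continue
        transform.append((var, op, args))
        if any(a in weight_dep for a in args):
            weight_dep.add(var)
    # Step 2: bindings in transform_params that depend on the weights are removed
    # from main, which instead reads them as outputs of transform_params.
    lifted = [var for var, _, _ in transform if var in weight_dep]
    main = [b for b in bindings if b[0] not in lifted]
    # Step 3: dead-code elimination on both functions.
    main = dead_code_elim(main, main_outputs)
    used_by_main = {a for _, _, args in main for a in args} | set(main_outputs)
    transform_outputs = [v for v in lifted if v in used_by_main]
    transform = dead_code_elim(transform, transform_outputs)
    return transform, transform_outputs, main


def names(bs: List[Binding]) -> List[str]:
    return [var for var, _, _ in bs]


# R.zeros feeds both the weight transformation and a runtime computation.
both = [
    ("zeros", "R.zeros", ()),
    ("w_pad", "R.add", ("weight", "zeros")),
    ("y", "R.matmul", ("x", "w_pad")),
    ("out", "R.add", ("y", "zeros")),
]
t, t_out, m = lift(both, {"x"}, {"weight"}, ["out"])
print(names(t), t_out, names(m))  # ['zeros', 'w_pad'] ['w_pad'] ['zeros', 'y', 'out']

# R.zeros feeds only the weight transformation: DCE drops it from main.
weight_only = [
    ("zeros", "R.zeros", ()),
    ("w_pad", "R.add", ("weight", "zeros")),
    ("out", "R.matmul", ("x", "w_pad")),
]
t, t_out, m = lift(weight_only, {"x"}, {"weight"}, ["out"])
print(names(t), t_out, names(m))  # ['zeros', 'w_pad'] ['w_pad'] ['out']
```

In both cases the weight transformation that consumes zeros is still lifted, and zeros itself is never an output of transform_params. When both functions need it, each computes its own copy.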
I realize it's a draft, but I had a look anyway. Since the code changes also included those from https://github.com/apache/tvm/pull/16594, it was a little difficult to see what had changed.
Apologies for that. Once #16594 lands, the "Files Changed" tab in this PR should update to only show the changes unique to this PR. In the meantime, this PR branch has its changes in a separate commit (link), where they can be viewed separately from the #16594 changes.
Thank you for the changes. The new unit test and the new comment are both helpful.
With #16594 landed, I've rebased this PR on top of it, and it is now ready for review.