
[Draft][Transform] Check for zero-param operators in LiftTransformParams

Open • Lunderberg opened this issue 1 year ago • 1 comment

Prior to this commit, LiftTransformParams would extract all variable bindings that have no runtime dependencies. As a result, expressions such as R.zeros([16], "int32") would be lifted into the parameter transformation, even though they do not depend on any parameters.

This commit updates LiftTransformParams to only output variables that depend on at least one compile-time parameter.
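The difference between the old and new lifting rules can be illustrated with a toy dependency analysis. This is a minimal Python sketch, not TVM code: `Binding` and `liftable_bindings` are hypothetical names, and a function is modeled as a list of bindings in program order, with `x` as a runtime input and `w` as a compile-time parameter.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class Binding:
    var: str    # name of the variable being bound
    op: str     # operator name, e.g. "zeros" or "matmul"
    args: tuple  # names of the variables/inputs this binding reads


def liftable_bindings(bindings, runtime_inputs, params):
    """Return (old_rule, new_rule) lists of liftable variable names.

    old_rule: bindings with no (transitive) runtime dependency.
    new_rule: additionally require a dependency on at least one parameter.
    """
    depends_on_runtime = set(runtime_inputs)
    depends_on_param = set(params)
    for b in bindings:  # program order, so arguments are seen before uses
        if any(a in depends_on_runtime for a in b.args):
            depends_on_runtime.add(b.var)
        if any(a in depends_on_param for a in b.args):
            depends_on_param.add(b.var)
    old_rule = [b.var for b in bindings if b.var not in depends_on_runtime]
    new_rule = [v for v in old_rule if v in depends_on_param]
    return old_rule, new_rule


example = [
    Binding("z", "zeros", ()),            # like R.zeros([16], "int32")
    Binding("wT", "transpose", ("w",)),   # depends only on the weight
    Binding("y", "matmul", ("x", "wT")),  # depends on the runtime input
    Binding("out", "add", ("y", "z")),
]
old_rule, new_rule = liftable_bindings(example, ["x"], ["w"])
# old_rule == ["z", "wT"], new_rule == ["wT"]
```

Under the old rule the zero-param `z` is lifted even though it reads no parameter; under the new rule only `wT` is.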

The unit test for this functionality also found that relax::Call was erroneously calling MarkGraphNode in SEqualReduce and SHashReduce. MarkGraphNode should only be called for nodes that have reference equality, such as relax::Var, and not for composite objects. This caused spurious failures in the unit test, because two instances of R.zeros([16], "int32") were compared by reference equality in StructuralEqual. These extra calls to MarkGraphNode have been removed.
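Why marking a composite node as a graph node breaks comparison can be shown with a toy structural-equality checker. This is a simplified Python sketch, not TVM's SEqualReduce; `Var`, `Call`, `struct_equal`, and the `mark_calls` flag are hypothetical. Graph nodes are paired on first comparison and must map to the same object afterwards, so one shared `zeros` call on the left cannot match two distinct but identical calls on the right. For brevity the sketch only checks the lhs-to-rhs direction of the mapping.

```python
class Var:
    """Node compared by identity: equal only through a consistent mapping."""
    def __init__(self, name):
        self.name = name


class Call:
    """Composite node: an operator applied to a tuple of arguments."""
    def __init__(self, op, args):
        self.op = op
        self.args = tuple(args)


def struct_equal(lhs, rhs, mark_calls=False, mapping=None):
    """Toy structural equality with graph-node marking.

    Vars are always graph nodes. With mark_calls=True, Calls are marked as
    graph nodes too, mimicking the extra MarkGraphNode calls.
    """
    if mapping is None:
        mapping = {}
    is_graph_node = isinstance(lhs, Var) or (mark_calls and isinstance(lhs, Call))
    if is_graph_node and id(lhs) in mapping:
        # Already paired with an rhs node: it must be that exact object.
        return mapping[id(lhs)] is rhs
    if type(lhs) is not type(rhs):
        return False
    if isinstance(lhs, Var):
        mapping[id(lhs)] = rhs
        return True
    if isinstance(lhs, Call):
        equal = (
            lhs.op == rhs.op
            and len(lhs.args) == len(rhs.args)
            and all(struct_equal(a, b, mark_calls, mapping)
                    for a, b in zip(lhs.args, rhs.args))
        )
        if equal and is_graph_node:
            mapping[id(lhs)] = rhs
        return equal
    return lhs == rhs  # leaf values such as ints and dtype strings


zeros = Call("zeros", [16, "int32"])
shared = Call("add", [zeros, zeros])  # one zeros call used twice
separate = Call("add", [Call("zeros", [16, "int32"]),
                        Call("zeros", [16, "int32"])])  # two distinct calls
print(struct_equal(shared, separate))                   # True
print(struct_equal(shared, separate, mark_calls=True))  # False
```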

Lunderberg, Feb 16 '24 19:02

This PR depends on changes made in https://github.com/apache/tvm/pull/16594, and is marked as a draft until it lands.

Lunderberg, Feb 16 '24 19:02

> As a result, expressions such as R.zeros([16], "int32") would be extracted out into the parameter transformation, even though they do not depend on any parameters.

Does this affect the result? If a weight transformation depends on values such as R.zeros, that transformation will no longer be lifted if R.zeros itself is not lifted. Maybe it would be better to check for such dependencies.

vinx13, Feb 21 '24 01:02

> Does this affect the result? If a weight transformation depends on some values like R.zeros, such transformation will no longer be lifted if R.zeros is not lifted, maybe better we can check such dependency

Good question. This case should be handled by allowing zero-param operators to appear in both functions. (See this unit test for how this looks in practice.) The case is never handled explicitly; instead, the correct behavior follows from the overall lifting procedure:

  1. Add every variable binding that doesn't require runtime parameters to the lifted transform_params function.
  2. For every variable binding that is present in transform_params and depends on the model weights, replace it in the original function with the corresponding output of transform_params.
  3. Run dead-code elimination on both functions.

Since the zero-param operators don't depend on runtime parameters, they appear in transform_params. Since they also don't depend on the model weights, they aren't replaced in the original function. If they are only required in one of the two functions, dead-code elimination removes them from the other.
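The three steps above can be sketched on a toy IR in which a function is a list of bindings in program order. This is a hypothetical Python model, not TVM's LiftTransformParams: `Binding`, `transitive`, `dead_code_eliminate`, and `lift_transform_params` are illustrative names, `x` stands for a runtime input, and `w` for a model weight. The example uses a zero-param `zeros` binding that is needed both by a weight transformation and by the runtime computation, which is the case raised in the review.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class Binding:
    var: str
    op: str
    args: tuple  # names of variables/inputs read by this binding


def transitive(bindings, roots):
    """Names in `roots` plus every binding (transitively) computed from them."""
    reached = set(roots)
    for b in bindings:  # program order
        if any(a in reached for a in b.args):
            reached.add(b.var)
    return reached


def dead_code_eliminate(bindings, outputs):
    """Keep only bindings whose results are (transitively) needed by outputs."""
    live, kept = set(outputs), []
    for b in reversed(bindings):
        if b.var in live:
            kept.append(b)
            live.update(b.args)
    return list(reversed(kept))


def lift_transform_params(bindings, outputs, runtime_inputs, weights):
    runtime_dep = transitive(bindings, runtime_inputs)
    weight_dep = transitive(bindings, weights)
    # Step 1: every binding without a runtime dependency goes to transform_params.
    lifted = [b for b in bindings if b.var not in runtime_dep]
    # Step 2: lifted bindings that depend on the weights are replaced in main;
    # their values now arrive as outputs of transform_params.
    replaced = {b.var for b in lifted if b.var in weight_dep}
    main = [b for b in bindings if b.var not in replaced]
    # Step 3: dead-code elimination on both functions.
    main = dead_code_eliminate(main, outputs)
    main_uses = {a for b in main for a in b.args} | set(outputs)
    transform_outputs = sorted(replaced & main_uses)
    transform = dead_code_eliminate(lifted, transform_outputs)
    return transform, transform_outputs, main


func = [
    Binding("z", "zeros", ()),                # zero-param operator
    Binding("w_shift", "add", ("w", "z")),    # weight transform using z
    Binding("y", "matmul", ("x", "w_shift")),
    Binding("out", "add", ("y", "z")),        # runtime code also using z
]
transform, t_out, main = lift_transform_params(func, ["out"], ["x"], ["w"])
print([b.var for b in transform], t_out, [b.var for b in main])
# ['z', 'w_shift'] ['w_shift'] ['z', 'y', 'out']
```

Here `z` ends up in both functions. If `out` no longer read `z`, dead-code elimination would drop `z` from the main function while keeping it in transform_params.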

Lunderberg, Feb 21 '24 04:02

> I realize it's a draft but I had a look anyway. Since the code changes also included those from https://github.com/apache/tvm/pull/16594, it was a little difficult to see what had changed.

Apologies for that. Once #16594 lands, the "Files Changed" tab in this PR should update to show only the changes unique to this PR. In the meantime, this PR branch has its changes in a separate commit (link), where they can be viewed separately from the #16594 changes.

Lunderberg, Feb 22 '24 15:02

Thank you for the changes. The new unit test and the new comment are both helpful.

slyubomirsky, Feb 22 '24 22:02

With #16594 landed, I've rebased this PR on top of it, and it is now ready for review.

Lunderberg, Feb 23 '24 15:02