
[Transform] Improve symbolic variable handling in FuseOps

Open Lunderberg opened this issue 1 year ago • 5 comments

Prior to this commit, FuseOps and FuseOpsByPattern exposed a symbolic variable to the fused function as an extra shape argument if it was used within the fused function but could not be inferred from the other parameters' shapes. While this prevents undefined symbolic variables, it can cause issues for downstream use of CodegenJSON, which requires all arguments to be tensors or tuples of tensors.

Frequently, all uses of a non-inferable symbolic shape occur within a symbolic expression that can be inferred. For example, a function that takes arg: R.Tensor([N+1]) and returns R.add(arg, R.const(1)) cannot infer N. However, all occurrences of N occur as part of the expression N+1, and the value of N+1 can be inferred. Therefore, if we replace N+1 with M, the additional ShapeTuple argument isn't required.
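To make the N+1 example concrete, here is a minimal pure-Python sketch of the substitution idea. It does not use TVM or reflect the actual FuseOps implementation: shape expressions are modeled as nested tuples such as ("add", "N", 1), and the names free_vars, substitute, and rewrite_signature, along with the fresh-variable naming scheme "M0", "M1", ..., are hypothetical and chosen only for illustration.

```python
def free_vars(expr):
    """Return the symbolic variable names used in a toy shape expression."""
    if isinstance(expr, str):
        return {expr}
    if isinstance(expr, tuple):
        # expr is (op, arg0, arg1, ...); skip the operator name.
        return set().union(*(free_vars(arg) for arg in expr[1:]))
    return set()  # integer constant


def substitute(expr, mapping):
    """Replace every occurrence of a mapped sub-expression with its fresh variable."""
    if expr in mapping:
        return mapping[expr]
    if isinstance(expr, tuple):
        return (expr[0],) + tuple(substitute(arg, mapping) for arg in expr[1:])
    return expr


def rewrite_signature(param_shapes, body_exprs):
    """Toy model: swap non-inferable compound dimensions for fresh variables.

    Returns the new parameter shapes, the rewritten body expressions, and the
    set of variables that would still need an explicit shape argument.
    """
    # A variable that appears as a bare dimension can be inferred directly.
    inferable = {d for shape in param_shapes for d in shape if isinstance(d, str)}

    # Each compound dimension that mentions a non-inferable variable gets a
    # fresh name (assumed not to clash with existing variable names).
    mapping = {}
    for shape in param_shapes:
        for dim in shape:
            if isinstance(dim, tuple) and dim not in mapping:
                if not free_vars(dim) <= inferable:
                    mapping[dim] = f"M{len(mapping)}"

    new_shapes = [[substitute(d, mapping) for d in shape] for shape in param_shapes]
    new_body = [substitute(e, mapping) for e in body_exprs]

    # Anything still free in the body, and neither inferable nor fresh,
    # must be passed in explicitly.
    still_used = set()
    for expr in new_body:
        still_used |= free_vars(expr)
    exposed = still_used - inferable - set(mapping.values())
    return new_shapes, new_body, exposed


# arg: R.Tensor([N+1]), with the body only ever using N+1.
n_plus_1 = ("add", "N", 1)
shapes, body, exposed = rewrite_signature([[n_plus_1]], [n_plus_1])
```

In this toy model, the parameter shape becomes [M0] and no variable remains to be exposed, which corresponds to dropping the extra ShapeTuple argument.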

In addition, prior to this commit, the CompositeFunctionAnnotator visited the body of functions without the parameters being considered in-scope. As a result, EraseToWellDefined would remove known shapes from the function body's StructInfo.

Lunderberg avatar Jan 22 '24 17:01 Lunderberg

Ideally we don't want to change FuseOps behavior, since there are cases where the expressions are intermediate (e.g. intermediate computations that include values containing exprs like n * 4).

Maybe we should look into composing them instead? Run FuseOps first, then rewrite the signatures.

tqchen avatar Jan 22 '24 18:01 tqchen

I could see having a post-processing pass to update the signature, maybe as an extension of RemoveUnusedParameters. There would still need to be an update to FuseOps to have the fused functions marked as private, since the post-processing step would only be allowed to update the signature of internal functions.

Though, could you expand on what you mean by intermediate expressions? In either case, whether implemented in FuseOps or in a post-processing pass, I think intermediate expressions would be handled correctly. If an expression n*4 can be inferred from the tensor shapes, but n+42 also appears in the fused function, then there would still be a shape expr used to expose n to the fused function.
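As a rough illustration of the n*4 versus n+42 case, the check can be sketched in plain Python. This is not TVM code: expressions are represented as plain strings, and vars_to_expose is a hypothetical helper name used only for this example.

```python
def vars_to_expose(occurrences, inferable_exprs):
    """Return the variables that still need an explicit shape argument.

    occurrences maps each symbolic variable to the expressions in which it
    appears inside the fused function. A variable can be dropped only if
    every one of those expressions can be inferred from the tensor shapes.
    """
    return {
        var
        for var, exprs in occurrences.items()
        if any(expr not in inferable_exprs for expr in exprs)
    }


inferable = {"n * 4"}

# n appears only inside the inferable expression: nothing to expose.
only_inferable = vars_to_expose({"n": ["n * 4"]}, inferable)

# n also appears in n + 42, which cannot be inferred: n stays exposed.
mixed = vars_to_expose({"n": ["n * 4", "n + 42"]}, inferable)
```

Under these assumptions, the first call yields an empty set and the second yields {"n"}, so the shape expr would still be needed in the second case.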

Lunderberg avatar Jan 22 '24 20:01 Lunderberg

Rebased onto main to resolve conflicts.

For the long term, I agree that it would be cleaner and more general-purpose to separate the functionality into three distinct passes:

  1. FuseOps, with the first commit in this PR to preserve symbolic variables in the ret_struct_info.
  2. A not-yet-existing HoistCommonSubexpressions, which would recognize that a symbolic variable is always used within a specific expression, and would hoist it to the calling scope.
  3. RemoveUnusedParameters, applied to remove the no-longer-required R.shape parameter.

Lunderberg avatar Feb 14 '24 16:02 Lunderberg

I've separated the first commit of this PR branch into an independent PR (https://github.com/apache/tvm/pull/16637), as the bugfix it provides is independent of the concerns raised, and does not require the not-yet-implemented HoistCommonSubexpressions transform.

Lunderberg avatar Feb 23 '24 15:02 Lunderberg