
[Relax] Implement Function.check_for_special_case

Open · Lunderberg opened this issue 1 year ago • 3 comments

If a dynamic model is frequently called with specific arguments or shapes of arguments, performance may be improved by generating specialized versions of the model. Previously, specialized versions of a relax function func could be generated using func.bind_params and func.bind_symbolic_vars. However, using these specialized versions requires the calling scope to explicitly check the preconditions of each kernel and call the appropriate one.
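
For reference, a minimal sketch of that previous workflow, using a toy TVMScript module (not taken from this PR):

```python
import tvm
from tvm.script import ir as I, relax as R


@I.ir_module
class Module:
    @R.function
    def main(A: R.Tensor(["N"], "float32")) -> R.Tensor(["N"], "float32"):
        return R.add(A, A)


general = Module["main"]

# Specialize the symbolic shape N to 16.  The specialized function only
# accepts R.Tensor([16]) arguments, so the calling scope must check the
# shape before deciding which version to call.
specialized = general.bind_symbolic_vars({"N": 16})
```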

This commit implements a new utility, check_for_special_case, which handles both generating the specialized version and checking whether the special case applies. The function's user-facing signature is unmodified, while internally it delegates to either the original function or the specialized version depending on the result of the check. This allows optimized kernels for specific static shapes to be introduced solely by changing the optimization pipeline, with no changes required in the calling scope.
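
A hypothetical usage sketch, based on the description above and continuing the toy module from the previous snippet; the exact signature of check_for_special_case may differ in this PR:

```python
# Hypothetical call; the argument format here is assumed, by analogy
# with bind_symbolic_vars, and is not copied from the PR.
func = Module["main"]
Module["main"] = func.check_for_special_case({"N": 16})

# The signature of Module["main"] is unchanged.  Internally, the new
# body dispatches to a version specialized for N == 16 when the check
# passes, and to the original dynamic-shape body otherwise.
```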

Lunderberg · Jan 23 '24 16:01

Currently a proof-of-concept, as it would still require the following testing and improvements:

  1. Verify that relax::IfNode supports a relax::PrimValue with boolean dtype.
  2. Handle special cases that check for a relax value (e.g. R.zeros), rather than just symbolic variables.
  3. Improve StructInfo inference when calling a function within a constrained scope. (e.g. It's okay to pass R.Tensor([N]) to a function expecting R.Tensor([16]) if the call is inside the scope of if N == 16.)
  4. Improve TVMScript parsing for relax::If. Currently, the use of Emit without a previously defined struct info causes EraseToWellDefined to remove the known shape.
  5. Improve the LCA handling for the branches (see the sketch after this list). For example, if the if N == 16 branch produces R.Tensor([16]) and the else branch produces R.Tensor([N]), the LCA should be R.Tensor([N]). Currently, this results in R.Tensor(ndim=1).
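
To make item 5 concrete, here is a hand-constructed sketch of the struct-info question, assuming the struct_info_lca binding under tvm.relax.analysis; it is not taken from this PR's tests:

```python
import tvm
from tvm import relax, tir

n = tir.Var("N", "int64")
then_sinfo = relax.TensorStructInfo([16], "float32")  # "if N == 16" branch
else_sinfo = relax.TensorStructInfo([n], "float32")   # "else" branch

lca = relax.analysis.struct_info_lca(then_sinfo, else_sinfo)
# Currently this yields R.Tensor(ndim=1, dtype="float32").  Since the
# then-branch is only reachable when N == 16, the more precise merged
# type would be R.Tensor([N], "float32").
print(lca)
```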

Lunderberg · Jan 23 '24 16:01

Looks great! It seems like a very clean way to approach the problem.

jwfromm · Jan 23 '24 17:01

With https://github.com/apache/tvm/pull/16642 landed, the functionality in this PR can now be compiled and run. Thank you to @csullivan for adding an IRModule transform, which allows the dynamic check to be inserted as part of an optimization pipeline.
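
As a rough sketch of how such a pass slots into a pipeline (the pass name and body here are hypothetical, not the actual transform from #16642):

```python
import tvm
from tvm import relax


# Hypothetical module pass; the transform added in apache/tvm#16642 may
# differ in name and API.  check_for_special_case's signature is assumed.
@tvm.ir.transform.module_pass(opt_level=0, name="SpecializeMainForN16")
def specialize_main(mod, ctx):
    mod = mod.clone()
    mod["main"] = mod["main"].check_for_special_case({"N": 16})
    return mod


pipeline = tvm.ir.transform.Sequential([
    specialize_main,
    relax.transform.LegalizeOps(),
])
```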

The unit tests in this PR require https://github.com/apache/tvm/pull/16844 in order to pass, so this PR branch includes those changes in its history. Other than that, everything is ready and working, so I'm marking this PR as ready for review.

Lunderberg · Apr 03 '24 19:04