[Relax] Implement Function.check_for_special_case
If a dynamic model is frequently called with specific arguments or shapes of arguments, performance may be improved by generating specialized versions of the model. Previously, specialized versions of a relax function `func` could be generated using `func.bind_params` and `func.bind_symbolic_vars`. However, use of these specialized versions requires the calling scope to explicitly check the preconditions of each kernel and call the appropriate one.
This commit implements a new utility, `check_for_special_case`, which handles both generating the specialized version and checking whether the special case applies. The function's user-facing signature is unmodified; internally, it delegates to either the original function or the specialized version depending on the result of the check. This allows optimized kernels for specific static shapes to be introduced solely by changing the optimization pipeline, with no changes required in the calling scope.
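As a plain-Python illustration of the dispatch structure described above (this is not the actual TVM API; the names `make_special_case_dispatch`, `generic_sum`, and `sum16` are hypothetical, and the real transform emits a `relax::If` inside the function body rather than a Python wrapper):

```python
def make_special_case_dispatch(func, check, specialized):
    """Wrap `func` so that calls satisfying `check` use the specialized version.

    Conceptual sketch only: the user-facing signature is unchanged, and the
    choice between the generic and specialized implementations happens
    inside the wrapper, invisible to the calling scope.
    """
    def wrapper(*args):
        if check(*args):
            return specialized(*args)
        return func(*args)
    return wrapper


# Example: a "dynamic" function specialized for length-16 inputs.
def generic_sum(xs):
    return sum(xs)


def sum16(xs):
    # Hypothetical hand-optimized kernel valid only for exactly 16 elements.
    assert len(xs) == 16
    return sum(xs)


dispatched = make_special_case_dispatch(
    generic_sum,
    check=lambda xs: len(xs) == 16,
    specialized=sum16,
)
```

The calling scope invokes `dispatched` exactly as it would the original function; only the optimization pipeline needs to know that a specialized kernel exists.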
Currently a proof-of-concept, as it would require testing:

- Verify that `relax::IfNode` supports a `relax::PrimValue` with boolean dtype.
- Handling special cases that check for a relax value (e.g. `R.zeros`), rather than just symbolic variables.
- Improved StructInfo inference when calling a function within a constrained scope. (e.g. It's okay to pass `R.Tensor([N])` to a function expecting `R.Tensor([16])` if it is within a scope of `if N == 16`.)
- Improved TVMScript parsing for `relax::If`. Currently, the use of `Emit` without a previously defined struct info causes `EraseToWellDefined` to remove the known shape.
- Improved LCA handling for the branches. For example, if the `if N == 16` branch produces `R.Tensor([16])` and the `else` branch produces `R.Tensor([N])`, the LCA should be `R.Tensor([N])`. Currently, this results in `R.Tensor(ndim=1)`.
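The LCA item above can be illustrated with a small plain-Python sketch of the desired merge rule for the branch shapes (not TVM code; symbolic dims are represented as strings, and the rule that a static dim merges into the symbolic one assumes the static value came from the branch condition, as in `if N == 16`):

```python
def shape_lca(a, b):
    """Merge the shapes inferred for the two branches of an if/else.

    Each shape is a list of dims: ints for static extents, strings for
    symbolic variables. When the `if N == 16` branch yields [16] and the
    else branch yields ["N"], the merged shape should keep the symbolic
    form ["N"] rather than degrading to "ndim=1, dims unknown".
    """
    if len(a) != len(b):
        return None  # ranks differ: nothing beyond a tensor is known
    merged = []
    for da, db in zip(a, b):
        if da == db:
            merged.append(da)
        elif isinstance(da, str):
            merged.append(da)  # symbolic var subsumes the specialized extent
        elif isinstance(db, str):
            merged.append(db)
        else:
            return None  # incompatible static dims: only ndim is known
    return merged
```

With this rule, `shape_lca([16], ["N"])` yields `["N"]`, matching the behavior the bullet point asks for, while `shape_lca([16], [17])` correctly falls back to "only the rank is known".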
Looks great! It seems like a very clean way to approach the problem.
With https://github.com/apache/tvm/pull/16642 landed, the functionality in this PR can now be compiled and run. Thank you to @csullivan for adding an IRModule transform, so the dynamic check can be inserted as part of an optimization pipeline.
The unit tests in this PR require https://github.com/apache/tvm/pull/16844 to pass, so this PR branch includes those changes in its history. Other than that, everything is ready and working, so I'm marking this PR as ready for review.