[Relax] Implement relax.transform.RemoveSymbolicExpressionsInSubroutine
This is a follow-up commit to https://github.com/apache/tvm/pull/16637, which updated relax.transform.FuseOps to provide additional parameters defining symbolic variables required by the fused functions. While this ensures that relax.transform.FuseOps produces well-formed Relax functions, these additional arguments can break some kernel implementations.
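The well-formedness requirement can be made concrete with a minimal pure-Python sketch. This is not TVM code. It uses a hypothetical toy model in which each tensor dimension is either a str naming one symbolic variable or a tuple of names denoting their product. As a simplification, it assumes that a variable is defined only when it appears as a bare dimension of some tensor parameter, and it ignores uses in the function body. Any other variable needs an extra shape parameter, which is the kind of parameter FuseOps now adds. The helper names are invented for illustration.

```python
# Hypothetical toy model: a dim is a str (one symbolic variable)
# or a tuple of str (the product of those variables).

def definable_vars(tensor_params):
    """Variables that appear as a bare dimension of some tensor parameter."""
    return {d for dims in tensor_params.values() for d in dims if isinstance(d, str)}


def vars_needing_shape_param(tensor_params):
    """Variables that occur only inside compound dimensions. No tensor shape
    defines them, so an extra R.Shape-style parameter must."""
    mentioned = set()
    for dims in tensor_params.values():
        for d in dims:
            mentioned.update((d,) if isinstance(d, str) else d)
    return mentioned - definable_vars(tensor_params)


params = {
    "data": [("batch_size", "seq_len"), "hidden_size"],
    "weights": ["hidden_size", "intermediate_size"],
}
print(sorted(vars_needing_shape_param(params)))  # ['batch_size', 'seq_len']
```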
This commit implements a new transform, RemoveSymbolicExpressionsInSubroutine, to resolve this issue. The transform identifies function arguments whose sole purpose is to define symbolic variables that are used only within a larger symbolic expression, when the value of that expression could instead be inferred from a tensor shape.
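The detection rule can be sketched in pure Python under a hypothetical toy model. This is not TVM's implementation. A dimension is a str naming one symbolic variable or a tuple of names denoting their product, and expressions are compared structurally, so ("a", "b") and ("b", "a") count as different. A compound dimension is treated as replaceable when every variable inside it appears nowhere except in identical copies of that expression. R.Shape parameters are left out of the input because they only define variables. The function names are invented for illustration.

```python
def var_names(dim):
    """Symbolic variables referenced by a dim (str = variable, tuple = product)."""
    return {dim} if isinstance(dim, str) else set(dim)


def find_replaceable(tensor_params, other_uses):
    """Return {expr: (param_name, axis)} for each compound dimension whose
    variables are never referenced outside that exact expression.

    tensor_params: param name -> list of dims (shape params excluded).
    other_uses: dims referenced elsewhere, e.g. in the body or return type.
    """
    all_dims = [d for dims in tensor_params.values() for d in dims] + list(other_uses)
    found = {}
    for name, dims in tensor_params.items():
        for axis, dim in enumerate(dims):
            if isinstance(dim, str) or dim in found:
                continue
            inner = var_names(dim)
            # Every occurrence of these variables must be this same expression.
            if all(d == dim or inner.isdisjoint(var_names(d)) for d in all_dims):
                found[dim] = (name, axis)
    return found


params = {
    "data": [("batch_size", "seq_len"), "hidden_size"],
    "weights": ["hidden_size", "intermediate_size"],
}
body_uses = [("batch_size", "seq_len"), "intermediate_size"]
print(find_replaceable(params, body_uses))
# {('batch_size', 'seq_len'): ('data', 0)}
```

If the body also referenced batch_size on its own, the `all(...)` check would fail and nothing would be returned, because the shape parameter would still be needed to define batch_size.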
For example, consider the following Relax function:
```python
from tvm.script import relax as R
from tvm.script import tir as T


@R.function
def func(
    data: R.Tensor(["batch_size * seq_len", "hidden_size"]),
    weights: R.Tensor(["hidden_size", "intermediate_size"]),
    # Present only to define batch_size and seq_len.
    dummy_arg: R.Shape(["batch_size", "seq_len"]),
) -> R.Tensor(["batch_size * seq_len", "intermediate_size"]):
    batch_size = T.int64()
    seq_len = T.int64()
    intermediate_size = T.int64()
    hidden_size = T.int64()
    output: R.Tensor([batch_size * seq_len, intermediate_size]) = R.matmul(data, weights)
    return output
```
The data tensor can be used to infer hidden_size, but it cannot be used to infer batch_size or seq_len individually. The R.Shape parameter exists solely to define batch_size and seq_len, since all symbolic variables must be defined. However, neither batch_size nor seq_len is ever used outside of the expression batch_size * seq_len, and the value of batch_size * seq_len can be inferred from the shape of the data tensor.
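A small arithmetic check shows why the product is recoverable but its factors are not. The sizes are hypothetical: if data has 12 rows, data.shape[0] pins down batch_size * seq_len == 12, yet several (batch_size, seq_len) pairs are consistent with that value.

```python
def factor_pairs(n):
    """All (batch_size, seq_len) pairs of positive integers whose product is n."""
    return [(b, n // b) for b in range(1, n + 1) if n % b == 0]


# Hypothetical: data arrives with shape (12, hidden_size).
print(factor_pairs(12))
# [(1, 12), (2, 6), (3, 4), (4, 3), (6, 2), (12, 1)]
```

Every candidate pair yields the same product, and the product is the only quantity the matmul needs for that axis.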
This new transform identifies cases where such an argument is otherwise unnecessary, and replaces the symbolic expression with a new symbolic variable that is inferred from the tensor shape. This leaves dummy_arg: R.Shape entirely unused, so a later application of relax.transform.RemoveUnusedParameters() can remove the parameter altogether. Applied to func, the transform produces:
```python
@R.function
def func(
    data: R.Tensor(["data_dim0", "hidden_size"]),
    weights: R.Tensor(["hidden_size", "intermediate_size"]),
    dummy_arg: R.Shape(["batch_size", "seq_len"]),
):
    data_dim0 = T.int64()
    intermediate_size = T.int64()
    hidden_size = T.int64()
    output: R.Tensor([data_dim0, intermediate_size]) = R.matmul(data, weights)
    return output
```
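The before/after pair can be mimicked with a hypothetical pure-Python toy model, assuming each dimension is a str (one variable) or a tuple of str (a product). `substitute` swaps each replaceable expression for a fresh variable named after the tensor axis it is read from, mirroring data_dim0 in the example. `unused_shape_params` models only the "is this parameter still referenced?" check; the actual removal is done by relax.transform.RemoveUnusedParameters(). Both helper names and the naming scheme are assumptions made for illustration.

```python
def var_names(dim):
    """Symbolic variables referenced by a dim (str = variable, tuple = product)."""
    return {dim} if isinstance(dim, str) else set(dim)


def substitute(tensor_params, other_uses, replaceable):
    """Replace each compound dim with a fresh variable such as 'data_dim0'.
    replaceable maps expr -> (param_name, axis); the naming is illustrative."""
    fresh = {expr: f"{name}_dim{axis}" for expr, (name, axis) in replaceable.items()}
    new_params = {n: [fresh.get(d, d) for d in dims] for n, dims in tensor_params.items()}
    new_uses = [fresh.get(d, d) for d in other_uses]
    return new_params, new_uses


def unused_shape_params(shape_params, tensor_params, other_uses):
    """Shape params none of whose variables are referenced anywhere else."""
    used = set()
    for d in [*(d for dims in tensor_params.values() for d in dims), *other_uses]:
        used |= var_names(d)
    return [n for n, vs in shape_params.items() if used.isdisjoint(vs)]


params = {
    "data": [("batch_size", "seq_len"), "hidden_size"],
    "weights": ["hidden_size", "intermediate_size"],
}
uses = [("batch_size", "seq_len"), "intermediate_size"]
shapes = {"dummy_arg": ["batch_size", "seq_len"]}

new_params, new_uses = substitute(params, uses, {("batch_size", "seq_len"): ("data", 0)})
print(new_params["data"], new_uses)
# ['data_dim0', 'hidden_size'] ['data_dim0', 'intermediate_size']
print(unused_shape_params(shapes, new_params, new_uses))  # ['dummy_arg']
```

Before substitution, the same check on the original parameters returns an empty list, because batch_size and seq_len are still referenced.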
This transform is intended to be used in the implementation of https://github.com/apache/tvm/pull/16450, as recommended here.