tvm
[Relax] Allow `out_sinfo` to be omitted from `R.call_tir`
Prior to this commit, the Relax type produced by calling a TIR PrimFunc needed to be explicitly specified using the out_sinfo argument. These output shapes are required in order to allocate output tensors during the CallTIRRewrite lowering pass. However, specifying them explicitly, especially in hand-written functions, duplicates information that is already present in the PrimFunc signature, and introduces the potential for inconsistencies.
This commit updates the MakeCallTIR function to infer out_sinfo if not explicitly specified. This inference uses the number of relax arguments to identify output parameters in the signature of the PrimFunc, which then become the return values from R.call_tir. Currently, this inference of out_sinfo occurs when constructing the relax::Call object, after which the out_sinfo is always present in the Relax IR.
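As a rough illustration of the inference described above, the sketch below splits a PrimFunc-like parameter list at the number of Relax arguments and treats the remainder as outputs. It is plain Python rather than the TVM implementation; `infer_out_shapes` and the tuple-based shapes are hypothetical stand-ins for the real signature and struct info.

```python
from typing import List, Tuple

Shape = Tuple[int, ...]


def infer_out_shapes(param_shapes: List[Shape], num_inputs: int) -> List[Shape]:
    """Treat the first `num_inputs` parameters as inputs, the rest as outputs.

    Hypothetical helper: a stand-in for the inference, not TVM's API.
    """
    if num_inputs > len(param_shapes):
        raise ValueError(
            f"callee has {len(param_shapes)} parameters, "
            f"but {num_inputs} arguments were provided"
        )
    outputs = param_shapes[num_inputs:]
    if not outputs:
        raise ValueError("callee has no parameters left to use as outputs")
    return outputs


# Signature resembling add(A: (16,), B: (16,), C: (16,)) called with [A, B]:
# the single trailing parameter C becomes the return value.
inferred = infer_out_shapes([(16,), (16,), (16,)], num_inputs=2)
print(inferred)
```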
Just want to note that it is not always possible to do such inference.
class IRModule:
    @T.prim_func
    def reshape(A: Buffer((2, 4)), B: Buffer((n, m))):
        ...

    def main(A: Buffer((2, 4))):
        lv0 = R.call_tir(reshape, [A], R.Tensor((1, 8)))
For example, the above code is a valid TIR call, but needs the output sinfo to be explicitly specified. Because we have such cases, and call_tir is a lower level function, it is safer to always ask for sinfo, but check its consistency with the corresponding prim_func signature if needed.
> For example, the above code is a valid TIR call, but needs the output sinfo to be explicitly specified. Because we have such cases, and call_tir is a lower level function, it is safer to always ask for sinfo, but check its consistency with the corresponding prim_func signature if needed.
That's a good point, and I agree that we should always be able to explicitly specify the output struct info, as output tensor shapes in TIR may define symbolic shapes. However, I don't think it should be a required argument.
I've added a new test case, based on your example with reshape, to validate the behavior when the output shape cannot be inferred. While the initial implementation did identify this failure and throw an error, the error message wasn't ideal. I've added an earlier check for non-inferable output shapes, so that the error message can direct the user to provide the out_sinfo field.
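A minimal sketch of what such an early check might look like, assuming shapes are tuples whose entries are either integers or strings naming symbolic variables. The helper name `check_inferable` and the error wording are illustrative assumptions, not the code added in this PR.

```python
from typing import Sequence, Union

Dim = Union[int, str]  # a str stands in for a symbolic variable such as "n"


def check_inferable(
    input_shapes: Sequence[Sequence[Dim]],
    output_shapes: Sequence[Sequence[Dim]],
) -> None:
    """Raise if an output shape uses a symbolic variable no input defines.

    Hypothetical helper: it models the idea of the check, not TVM's code.
    """
    known = {d for shape in input_shapes for d in shape if isinstance(d, str)}
    unknown = sorted(
        {
            d
            for shape in output_shapes
            for d in shape
            if isinstance(d, str) and d not in known
        }
    )
    if unknown:
        raise ValueError(
            f"Cannot infer out_sinfo: symbolic variable(s) {unknown} appear "
            "only in output parameters. Please provide out_sinfo explicitly."
        )


# Inferable: the output reuses the variable "n" defined by the input.
check_inferable([("n", 4)], [("n",)])

# Not inferable: the reshape example, where n and m appear only in the output.
try:
    check_inferable([(2, 4)], [("n", "m")])
except ValueError as err:
    message = str(err)
    print(message)
```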
Does the updated check/error messages address your concerns for this PR?
I think this is mainly a design consideration about how we view the intended use of CreateCallTIR, in terms of the different expectations we have of the function's callers. I can see some merit in both auto deduction and the call for explicitness.
Given that call_tir is lower level, having "less automation" here during passes and checking explicitly would ensure correctness, while indeed asking pass writers to do a bit more. It is like explicitly annotating types when writing C++ code versus writing auto. I think encouraging pass writers to explicitly think about the DPS pattern and always provide the return argument helps to reduce uncertainty here. While I can indeed see some merit in automated deduction, given that it is not always possible, I still prefer that we keep the explicitness and provide a good amount of consistency checking.
> I think encouraging pass writers to explicitly think about the DPS pattern and always provide the return argument helps to reduce uncertainty here.
While I think this would be an interesting point to discuss, I don't think it's relevant to this specific change. This PR keeps the exact same out_sinfo in the C++ IR types, and still requires pass writers to explicitly provide the output info. The MakeCallTIR function is not exposed to the back-end C++ API, only through the front-end Python API.
This change is solely in the front-end, for cases where an IRModule is being hand-written. I'd like to make that use-case less error-prone.
Thanks for pointing out the frontend case. I still think being explicit is helpful, and that we should aim for a consistency check with good error messages. Having such an explicit argument makes the "intent" clear; with the explicit sinfo, we can write down the semantics in a clear fashion:
def call_tir(func, args, out_sinfo):
    out = alloc_outputs(out_sinfo)
    func(*args, unpack_outputs(out))
    return out
Omitting out_sinfo, while indeed OK in some cases, is not always possible because it is not always derivable, and the intent is less clear. I know the argument can go the other way, to reduce the amount users have to type. In this particular case, having a good well-formedness check for consistency would help a lot in that direction.
> Having such an explicit argument makes the "intent" clear; with the explicit sinfo, we can write down the semantics in a clear fashion
Good point on the semantics. This change would add an additional step to the user-facing semantics of R.call_tir.
def call_tir(func, args, out_sinfo):
    if out_sinfo is None:
        out_sinfo = infer_out_sinfo(func, args)  # may throw
    out = alloc_outputs(out_sinfo)
    func(*args, unpack_outputs(out))
    return out
I suppose what I'm getting stuck on is the "intent" part. While there are exceptions, in the majority of cases, there's one and only one correct value for out_sinfo. Since the user doesn't have any choice in it, we can't infer any intention from the user about it. On the other hand, if the user has the option of omitting the out_sinfo, then we could distinguish between the intent of "use whichever output is valid" (e.g. R.call_tir(unary_abs, [x])) and "verify and use the output I expect" (e.g. R.call_tir(unary_abs, [x], R.Tensor([16],'float16'))).
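The two intents can be modeled with a small, runnable stand-in for destination-passing style. This is a sketch under simplifying assumptions: Python lists play the role of tensors, the output length is assumed to equal the input length (true for an elementwise function such as `unary_abs`), and `call_dps` is a hypothetical name rather than the real `R.call_tir`.

```python
from typing import Callable, List, Optional


def unary_abs(x: List[float], out: List[float]) -> None:
    """Destination-passing style: write |x| into the caller-provided buffer."""
    for i, value in enumerate(x):
        out[i] = abs(value)


def call_dps(
    func: Callable[[List[float], List[float]], None],
    arg: List[float],
    out_len: Optional[int] = None,
) -> List[float]:
    """Allocate the output, call `func` on it, and return it.

    Assumption: the output length equals the input length, so it can be
    inferred whenever `out_len` is omitted.
    """
    inferred_len = len(arg)
    if out_len is None:
        # Intent: "use whichever output is valid".
        out_len = inferred_len
    elif out_len != inferred_len:
        # Intent: "verify and use the output I expect".
        raise ValueError(f"expected output length {inferred_len}, got {out_len}")
    out = [0.0] * out_len
    func(arg, out)
    return out


implicit = call_dps(unary_abs, [-1.0, 2.0])
explicit = call_dps(unary_abs, [-1.0, 2.0], out_len=2)
print(implicit, explicit)
```

Passing a mismatched `out_len` (for example 3 here) raises a `ValueError`, which is the "verify" half of the second intent.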
> In this particular case, having a good well-formedness check for consistency would help a lot in that direction
Agreed. I think for now, let's put this PR on hold, and I'll update the well-formed checker to verify consistency between the R.call_tir callee and the input/output arguments. (Since that's a change that we both agree on, and covers many of the same error modes.)
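One way to picture such a consistency check, sketched in plain Python rather than as the actual well-formed checker: compare the callee's parameter shapes against the shapes of the inputs followed by the declared outputs, binding symbolic variables as they are encountered. The function `check_call` and its message format are hypothetical, and the sketch deliberately checks only rank and dimensions, not dtypes or element counts.

```python
from typing import Dict, List, Sequence, Union

Dim = Union[int, str]  # a str stands in for a symbolic variable such as "n"


def check_call(
    params: List[Sequence[Dim]],
    args: List[Sequence[int]],
    outs: List[Sequence[int]],
) -> List[str]:
    """Return a list of inconsistencies between a callee and its call site.

    Hypothetical helper: inputs come first, outputs last, as in the DPS
    convention discussed above.
    """
    actual = list(args) + list(outs)
    if len(params) != len(actual):
        return [f"expected {len(params)} buffers, got {len(actual)}"]

    errors: List[str] = []
    bindings: Dict[str, int] = {}
    for index, (param, shape) in enumerate(zip(params, actual)):
        if len(param) != len(shape):
            errors.append(f"buffer {index}: rank mismatch")
            continue
        for dim, value in zip(param, shape):
            if isinstance(dim, str):
                # First use binds the variable; later uses must agree.
                if bindings.setdefault(dim, value) != value:
                    errors.append(
                        f"buffer {index}: {dim} bound to {bindings[dim]}, got {value}"
                    )
            elif dim != value:
                errors.append(f"buffer {index}: expected {dim}, got {value}")
    return errors


# The reshape example: n and m are bound by the explicit output shape.
reshape_errors = check_call([(2, 4), ("n", "m")], [(2, 4)], [(1, 8)])

# A mismatched explicit output shape is reported.
abs_errors = check_call([(16,), (16,)], [(16,)], [(8,)])
print(reshape_errors, abs_errors)
```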
closed in favor of #17285