[Bug][VM] Cannot pass a closure to a function call
The following program results in a VM code generation error:
```python
import tvm
import tvm.script
from tvm import relax
from tvm.script import relax as R

@tvm.script.ir_module
class PrintClosure:
    @R.function
    def main():
        @R.function
        def closure():
            return relax.const(1)

        y = relax.print(closure)
        return y

mod = PrintClosure
mod = relax.transform.LambdaLift()(mod)
target = tvm.target.Target("llvm", host="llvm")
ex = relax.vm.build(mod, target)  # error happens on this line
vm = relax.VirtualMachine(ex, tvm.cpu())
ret = vm["main"]()
```
The error is as follows:
```
5: tvm::runtime::TypedPackedFunc<tvm::runtime::Module (tvm::IRModule, tvm::runtime::Optional<tvm::runtime::Module>, tvm::runtime::Array<tvm::runtime::Module, void>, tvm::Target, tvm::runtime::Map<tvm::runtime::String, tvm::runtime::NDArray, void, void>)>::AssignTypedLambda<tvm::runtime::Module (*)(tvm::IRModule, tvm::runtime::Optional<tvm::runtime::Module>, tvm::runtime::Array<tvm::runtime::Module, void>, tvm::Target, tvm::runtime::Map<tvm::runtime::String, tvm::runtime::NDArray, void, void>)>(tvm::runtime::Module (*)(tvm::IRModule, tvm::runtime::Optional<tvm::runtime::Module>, tvm::runtime::Array<tvm::runtime::Module, void>, tvm::Target, tvm::runtime::Map<tvm::runtime::String, tvm::runtime::NDArray, void, void>), std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> >)::{lambda(tvm::runtime::TVMArgs const&, tvm::runtime::TVMRetValue*)#1}::operator()(tvm::runtime::TVMArgs const&, tvm::runtime::TVMRetValue*) const
4: tvm::relax::relax_vm::CodeGen(tvm::IRModule, tvm::runtime::Optional<tvm::runtime::Module>, tvm::runtime::Array<tvm::runtime::Module, void>, tvm::Target, tvm::runtime::Map<tvm::runtime::String, tvm::runtime::NDArray, void, void>)
3: tvm::relax::relax_vm::VMCodeGen::CodeGen(tvm::IRModule)
2: tvm::relax::relax_vm::CodeGenVM::VisitExpr_(tvm::relax::FunctionNode const*)
1: tvm::relax::relax_vm::CodeGenVM::VisitExpr_(tvm::relax::SeqExprNode const*)
0: tvm::relax::ExprFunctor<tvm::runtime::relax_vm::Instruction::Arg (tvm::RelayExpr const&)>::VisitExprDefault_(tvm::runtime::Object const*)
File "[...]/relax/include/tvm/relax/expr_functor.h", line 114
TVMError: Do not have a default for GlobalVar
```
It looks like there's a missing case in VM code generation.
Out of curiosity: what is the expected output here? Is there a difference between the following two?

- `y = relax.print(closure)`
- `y = relax.print(closure())`
Calling the closure would return the integer constant. I was curious to see how closures are represented internally, so I'm not sure what the expected output should be, but we should be able to generate code for it.
Great exploration! In this specific case, the inner function `closure` inside `main` is not actually a closure, since it has no free variables.
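To make the distinction concrete, here is a minimal sketch of free-variable analysis on a hypothetical toy IR. The `Var`/`Const`/`Func` classes and `describe_lifting` are illustrative only, not Relax APIs. It assumes lifting binds a bare global reference when nothing is captured and a `make_closure` call otherwise, which matches the bindings LambdaLift produces for `closure` in the two examples in this thread.

```python
from dataclasses import dataclass
from typing import FrozenSet, Tuple, Union


# Hypothetical toy IR (NOT the Relax AST).
@dataclass(frozen=True)
class Var:
    name: str


@dataclass(frozen=True)
class Const:
    value: int


@dataclass(frozen=True)
class Func:
    params: Tuple[Var, ...]
    body: "Expr"


Expr = Union[Var, Const, Func]


def free_vars(expr: Expr) -> FrozenSet[str]:
    """Names referenced in `expr` that no enclosing Func binds."""
    if isinstance(expr, Var):
        return frozenset({expr.name})
    if isinstance(expr, Const):
        return frozenset()
    if isinstance(expr, Func):
        bound = {p.name for p in expr.params}
        return free_vars(expr.body) - bound
    raise TypeError(f"unknown node: {expr!r}")


def describe_lifting(name: str, fn: Func, lifted: str = "lifted_func_0") -> str:
    """What a lambda-lifting pass would bind `name` to in this toy model."""
    captured = sorted(free_vars(fn))
    if not captured:
        return f"{name} = {lifted}"  # nothing captured: plain global reference
    args = ", ".join(captured)
    return f"{name} = make_closure({lifted}, ({args},))"


no_capture = Func((), Const(1))   # like `closure` in PrintClosure
captures_x = Func((), Var("x"))   # like `closure` in PrintClosure2
print(describe_lifting("closure", no_capture))  # closure = lifted_func_0
print(describe_lifting("closure", captures_x))  # closure = make_closure(lifted_func_0, (x,))
```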
The IRModule after lambda lifting (the LambdaLift pass lifts all local functions, whether or not they are closures):
```python
@tvm.script.ir_module
class Module:
    @R.function
    def main() -> Tuple():
        # block 0
        closure = lifted_func_0
        y: Tuple() = relax.print(closure, format="", attrs_type_key="relax.attrs.PrintAttrs")
        return y

    @R.function
    def lifted_func_0() -> Tensor(None, "int32", ndim = 0):
        return 1
```
The `closure` variable is bound to a GlobalVar (`lifted_func_0`). We need to fix VMCodeGen to handle binding a GlobalVar to a variable; currently, codegen only supports a GlobalVar that appears inline in a CallNode, not one bound to a var.
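A minimal sketch of what handling that binding could look like, on a hypothetical toy IR and instruction set. `ToyCodeGen`, the `load_func` instruction, and the `%N` register syntax are invented for illustration and are not the relax_vm instruction format. When a binding's value is a bare GlobalVar, the toy codegen loads a reference to that function into a register, so later uses of the variable can pass it along like any other value.

```python
from dataclasses import dataclass
from typing import Dict, List, Tuple, Union


# Hypothetical toy IR nodes (NOT relax::GlobalVar / relax::Call).
@dataclass(frozen=True)
class GlobalVar:
    name: str


@dataclass(frozen=True)
class Var:
    name: str


@dataclass(frozen=True)
class Call:
    op: Union[GlobalVar, str]                 # global function or packed-func name
    args: Tuple[Union[Var, GlobalVar], ...]


class ToyCodeGen:
    def __init__(self) -> None:
        self.regs: Dict[str, int] = {}        # variable name -> register index
        self.code: List[str] = []             # emitted toy instructions

    def _bind(self, var: Var) -> int:
        reg = len(self.regs)
        self.regs[var.name] = reg
        return reg

    def _operand(self, arg: Union[Var, GlobalVar]) -> str:
        if isinstance(arg, GlobalVar):
            return f"func:{arg.name}"         # GlobalVar passed inline as an argument
        return f"%{self.regs[arg.name]}"

    def emit_binding(self, var: Var, value: Union[Call, GlobalVar]) -> None:
        if isinstance(value, GlobalVar):
            # The case missing in the bug report: `closure = lifted_func_0`.
            self.code.append(f"%{self._bind(var)} = load_func {value.name}")
        elif isinstance(value, Call):
            callee = value.op.name if isinstance(value.op, GlobalVar) else value.op
            operands = ", ".join(self._operand(a) for a in value.args)
            self.code.append(f"%{self._bind(var)} = call {callee}({operands})")
        else:
            raise TypeError(f"no codegen rule for {type(value).__name__}")


cg = ToyCodeGen()
cg.emit_binding(Var("closure"), GlobalVar("lifted_func_0"))
cg.emit_binding(Var("y"), Call("relax.print", (Var("closure"),)))
print("\n".join(cg.code))
# %0 = load_func lifted_func_0
# %1 = call relax.print(%0)
```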
Here is a case with a real closure, since `closure` captures the free variable `x`:
```python
@tvm.script.ir_module
class PrintClosure2:
    @R.function
    def main(x: Tensor((2, 3), "float32")):
        @R.function
        def closure():
            return x

        y = relax.print(closure)
        return y
```
After lambda lifting:
```python
@tvm.script.ir_module
class Module:
    @R.function
    def main(x: Tensor((2, 3), "float32")) -> Tuple():
        # block 0
        closure: Object = relax.make_closure(lifted_func_0, (x,))
        y: Tuple() = relax.print(closure, format="", attrs_type_key="relax.attrs.PrintAttrs")
        return y

    @R.function
    def lifted_func_0(x1: Tensor((2, 3), "float32")) -> Tensor(None, "float32", ndim = 2):
        return x1
```
The closure is of `ObjectType`, and it's defined here: https://github.com/tlc-pack/relax/blob/relax/include/tvm/runtime/relax_vm/executable.h#L43. This IRModule can be compiled and run.
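Since the lifted function takes the captured value as an explicit parameter (`x1`), a closure at run time only needs to pair a function with its captured values. Here is a minimal Python model of that idea. `ToyClosure`, `FUNC_TABLE`, and `invoke_closure` are hypothetical, not the relax_vm classes, and passing captured values first is an assumption suggested by the lifted signature.

```python
from dataclasses import dataclass
from typing import Any, Callable, Dict, Tuple

# Hypothetical function table standing in for the compiled lifted functions.
FUNC_TABLE: Dict[str, Callable[..., Any]] = {
    # mirrors `def lifted_func_0(x1): return x1` from the lifted module
    "lifted_func_0": lambda x1: x1,
}


@dataclass(frozen=True)
class ToyClosure:
    """A function name paired with the values it captured."""
    func_name: str
    captured: Tuple[Any, ...] = ()


def make_closure(func_name: str, captured: Tuple[Any, ...]) -> ToyClosure:
    return ToyClosure(func_name, tuple(captured))


def invoke_closure(clo: ToyClosure, *args: Any) -> Any:
    # Captured values become the leading parameters of the lifted function,
    # followed by the call's own arguments.
    return FUNC_TABLE[clo.func_name](*clo.captured, *args)


x = "tensor x"  # placeholder for the (2, 3) float32 input
clo = make_closure("lifted_func_0", (x,))
print(invoke_closure(clo))  # tensor x
```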
Great. Also, on a very pedantic point, I would say that a function that does not capture anything should still be represented at run time as a closure, hence my using that term :)
I think our codegen should treat the cases uniformly, too (i.e., compile a reference to a global func into a closure with no captured variables), or at least make that the convention for passing functions to packed funcs. Later compilation passes could get rid of the closure wrapper if it's never used.
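The uniform convention plus a later cleanup could look roughly like this minimal sketch over hypothetical tuple "instructions" (not a Relax data structure). Every global reference is lowered to a closure with zero captures, and a cleanup pass rewrites such a closure into direct calls only when it is never used as an operand, i.e. never escapes. The toy has no return instruction, so a real pass would also treat returned closures as escaping.

```python
from typing import Dict, List, Set, Tuple

# Hypothetical toy instructions (plain tuples):
#   ("make_closure", dst, func_name, captured_vars)
#   ("invoke",       dst, closure_var, arg_vars)
#   ("call",         dst, func_name, arg_vars)    direct call of a global
#   ("packed",       dst, packed_name, arg_vars)  e.g. a print builtin
Instr = Tuple


def lower_global_ref(dst: str, func_name: str) -> Instr:
    """Uniform convention: a reference to a global becomes a 0-capture closure."""
    return ("make_closure", dst, func_name, ())


def strip_trivial_closures(prog: List[Instr]) -> List[Instr]:
    """Turn 0-capture closures that are only ever invoked into direct calls."""
    trivial: Dict[str, str] = {
        ins[1]: ins[2] for ins in prog if ins[0] == "make_closure" and not ins[3]
    }
    # Slot 3 always lists operand variables; a closure appearing there escapes
    # (e.g. it is passed to a packed func), so its wrapper must be kept.
    escaping: Set[str] = {v for ins in prog for v in ins[3]}
    removable = {v: f for v, f in trivial.items() if v not in escaping}

    out: List[Instr] = []
    for ins in prog:
        if ins[0] == "make_closure" and ins[1] in removable:
            continue  # dead once every invoke of it is rewritten below
        if ins[0] == "invoke" and ins[2] in removable:
            out.append(("call", ins[1], removable[ins[2]], ins[3]))
        else:
            out.append(ins)
    return out


called_only = [lower_global_ref("f", "lifted_func_0"), ("invoke", "r", "f", ())]
print(strip_trivial_closures(called_only))  # [('call', 'r', 'lifted_func_0', ())]

escapes = [lower_global_ref("closure", "lifted_func_0"),
           ("packed", "y", "print", ("closure",))]
print(strip_trivial_closures(escapes) == escapes)  # True
```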
For printing closures, we don't need to display anything other than that it's a closure (most functional languages treat them as completely opaque).
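A minimal sketch of opaque printing, assuming a hypothetical runtime closure type `ToyClosure` and a toy print builtin `toy_print` (neither is a relax_vm API). Any closure prints as `<closure>` regardless of which function it wraps or what it captured, while other values print normally.

```python
from dataclasses import dataclass
from typing import Any, Tuple


@dataclass(frozen=True)
class ToyClosure:
    # Hypothetical runtime closure: function name plus captured values.
    func_name: str
    captured: Tuple[Any, ...] = ()


def format_value(value: Any) -> str:
    # Closures are opaque: show neither the function nor the captured values.
    if isinstance(value, ToyClosure):
        return "<closure>"
    return str(value)


def toy_print(value: Any) -> Tuple[()]:
    print(format_value(value))
    return ()  # an empty tuple, like `y: Tuple()` in the lifted modules


toy_print(ToyClosure("lifted_func_0"))                  # prints <closure>
toy_print(ToyClosure("lifted_func_0", ("tensor x",)))   # prints <closure>
toy_print(1)                                            # prints 1
```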