[Bug][VM] Cannot pass a closure to a function call

Open slyubomirsky opened this issue 2 years ago • 4 comments

The following program results in a VM code generation error:

import tvm
import tvm.script
from tvm import relax
from tvm.script import relax as R

@tvm.script.ir_module
class PrintClosure:
    @R.function
    def main():
        @R.function
        def closure():
            return relax.const(1)
        y = relax.print(closure)
        return y

mod = PrintClosure
mod = relax.transform.LambdaLift()(mod)
target = tvm.target.Target("llvm", host="llvm")
ex = relax.vm.build(mod, target)  # error happens on this line
vm = relax.VirtualMachine(ex, tvm.cpu())

ret = vm["main"]()

The error is as follows:

  5: tvm::runtime::TypedPackedFunc<tvm::runtime::Module (tvm::IRModule, tvm::runtime::Optional<tvm::runtime::Module>, tvm::runtime::Array<tvm::runtime::Module, void>, tvm::Target, tvm::runtime::Map<tvm::runtime::String, tvm::runtime::NDArray, void, void>)>::AssignTypedLambda<tvm::runtime::Module (*)(tvm::IRModule, tvm::runtime::Optional<tvm::runtime::Module>, tvm::runtime::Array<tvm::runtime::Module, void>, tvm::Target, tvm::runtime::Map<tvm::runtime::String, tvm::runtime::NDArray, void, void>)>(tvm::runtime::Module (*)(tvm::IRModule, tvm::runtime::Optional<tvm::runtime::Module>, tvm::runtime::Array<tvm::runtime::Module, void>, tvm::Target, tvm::runtime::Map<tvm::runtime::String, tvm::runtime::NDArray, void, void>), std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> >)::{lambda(tvm::runtime::TVMArgs const&, tvm::runtime::TVMRetValue*)#1}::operator()(tvm::runtime::TVMArgs const&, tvm::runtime::TVMRetValue*) const
  4: tvm::relax::relax_vm::CodeGen(tvm::IRModule, tvm::runtime::Optional<tvm::runtime::Module>, tvm::runtime::Array<tvm::runtime::Module, void>, tvm::Target, tvm::runtime::Map<tvm::runtime::String, tvm::runtime::NDArray, void, void>)
  3: tvm::relax::relax_vm::VMCodeGen::CodeGen(tvm::IRModule)
  2: tvm::relax::relax_vm::CodeGenVM::VisitExpr_(tvm::relax::FunctionNode const*)
  1: tvm::relax::relax_vm::CodeGenVM::VisitExpr_(tvm::relax::SeqExprNode const*)
  0: tvm::relax::ExprFunctor<tvm::runtime::relax_vm::Instruction::Arg (tvm::RelayExpr const&)>::VisitExprDefault_(tvm::runtime::Object const*)
  File "[...]/relax/include/tvm/relax/expr_functor.h", line 114
TVMError: Do not have a default for GlobalVar

It looks like VM code generation is missing a case for GlobalVar.

slyubomirsky, Aug 16 '22 01:08

Out of curiosity: what is the expected output here? Is there a difference between the following:

  1. y = relax.print(closure)
  2. y = relax.print(closure())

psrivas2, Aug 16 '22 13:08

Calling the closure would return the int. I was curious to see how closures are represented internally, so I'm not sure what the expected output should be, but we should be able to generate code for it.

slyubomirsky, Aug 16 '22 18:08

Great exploration! In this specific case, the inner function named closure inside main is not actually a closure, since it has no free variables.

IRModule after lambda lifting (the lambda lifting pass lifts all local functions, whether or not they are closures):

@tvm.script.ir_module
class Module:
    @R.function
    def main() -> Tuple():
        # block 0
        closure = lifted_func_0
        y: Tuple() = relax.print(closure, format="", attrs_type_key="relax.attrs.PrintAttrs")
        return y

    @R.function
    def lifted_func_0() -> Tensor(None, "int32", ndim = 0):
        return 1

The closure variable is bound to GlobalVar(lifted_func_0). We need to fix VM codegen to handle binding a GlobalVar to a variable (currently, codegen only supports a GlobalVar that appears inline in a CallNode, not one bound to a variable).
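To make the failure mode concrete, here is a toy sketch in plain Python of a codegen visitor that only accepts a GlobalVar inside a call, and of the extra case that a variable binding would need. Everything in it (ToyCodeGen, the node classes, the "reg:"/"func:" strings, the handle_global_var flag) is hypothetical and made up for illustration; it is not the TVM codegen or its API, and it assumes the supported inline position is the callee of the call.

```python
from dataclasses import dataclass


# Toy IR nodes. These are hypothetical stand-ins, not TVM classes.
@dataclass(frozen=True)
class GlobalVar:
    name: str


@dataclass(frozen=True)
class Var:
    name: str


@dataclass(frozen=True)
class Call:
    op: object
    args: tuple = ()


class ToyCodeGen:
    """Type-dispatching visitor that raises for node types it has no case for."""

    def __init__(self, handle_global_var=False):
        self.handle_global_var = handle_global_var
        self.instructions = []

    def visit(self, expr):
        if isinstance(expr, Var):
            return f"reg:{expr.name}"
        if isinstance(expr, Call):
            # Assumption: a GlobalVar in callee position is inlined as a
            # function name, so it never goes through visit() on its own.
            if isinstance(expr.op, GlobalVar):
                callee = f"func:{expr.op.name}"
            else:
                callee = self.visit(expr.op)
            args = tuple(self.visit(arg) for arg in expr.args)
            self.instructions.append(("call", callee, args))
            return "reg:ret"
        if isinstance(expr, GlobalVar) and self.handle_global_var:
            # The missing case: a GlobalVar used as a plain value.
            return f"func:{expr.name}"
        raise NotImplementedError(
            f"Do not have a default for {type(expr).__name__}"
        )

    def bind(self, var, value):
        source = self.visit(value)
        self.instructions.append(("bind", f"reg:{var.name}", source))


def try_bind(handle_global_var):
    """Emit code for `closure = lifted_func_0` and report what happened."""
    gen = ToyCodeGen(handle_global_var)
    try:
        gen.bind(Var("closure"), GlobalVar("lifted_func_0"))
    except NotImplementedError as err:
        return str(err)
    return gen.instructions
```

With handle_global_var=False the binding falls through to the default case and raises, while a direct call through the GlobalVar still works; with the flag set, the binding produces a single bind instruction.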

A closure case:

@tvm.script.ir_module
class PrintClosure2:
    @R.function
    def main(x: Tensor((2, 3), "float32")):
        @R.function
        def closure():
            return x
        y = relax.print(closure)
        return y

After lambda lifting:

@tvm.script.ir_module
class Module:
    @R.function
    def main(x: Tensor((2, 3), "float32")) -> Tuple():
        # block 0
        closure: Object = relax.make_closure(lifted_func_0, (x,))
        y: Tuple() = relax.print(closure, format="", attrs_type_key="relax.attrs.PrintAttrs")
        return y

    @R.function
    def lifted_func_0(x1: Tensor((2, 3), "float32")) -> Tensor(None, "float32", ndim = 2):
        return x1

The closure is of ObjectType, and it's defined here: https://github.com/tlc-pack/relax/blob/relax/include/tvm/runtime/relax_vm/executable.h#L43. This IRModule can be compiled and run.
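For readers less familiar with lambda lifting, the transformation above can be mimicked in plain Python. This is only a sketch of the idea: the Closure class and the make_closure helper below are hypothetical Python stand-ins (they borrow the name relax.make_closure from the listing but are not that operator), and passing captured values before the call arguments is an assumption of this sketch.

```python
from dataclasses import dataclass
from typing import Any, Callable, Tuple


@dataclass(frozen=True)
class Closure:
    """A lifted function paired with the values it captured."""

    func: Callable[..., Any]
    captured: Tuple[Any, ...]

    def __call__(self, *args):
        # Assumption of this sketch: captured values come first.
        return self.func(*self.captured, *args)


def make_closure(func, captured):
    return Closure(func, tuple(captured))


# Before lifting: the inner function refers to the free variable x.
def main_nested(x):
    def closure():
        return x

    return closure


# After lifting: the free variable becomes an explicit parameter of a
# top-level function, and main packages it up together with x.
def lifted_func_0(x1):
    return x1


def main_lifted(x):
    return make_closure(lifted_func_0, (x,))
```

Calling main_nested(v)() and main_lifted(v)() gives the same value; the lifted version just makes the captured environment explicit.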

YuchenJin, Aug 16 '22 19:08

Great. Also, on a very pedantic point, I would say that a function that does not capture anything should still be represented at run time as a closure, hence my using that term :)

I think our codegen should treat the cases uniformly, too (i.e., compile a reference to a global func into a closure with no captured variables), or at least make that the convention for passing functions to packed funcs. Later compilation passes could get rid of the closure wrapper if it's never used.

For printing closures, we don't need to display anything other than that it's a closure (most functional languages treat them as completely opaque).
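A rough sketch of that convention in plain Python, to show what "uniform" would mean here: a bare reference to a global function gets wrapped as a closure with an empty captured tuple, and printing shows nothing beyond the fact that it is a closure. The names Closure, as_closure and print_value are hypothetical and only for illustration; none of this is the Relax VM's actual representation.

```python
from dataclasses import dataclass
from typing import Any, Callable, Tuple


@dataclass(frozen=True, repr=False)
class Closure:
    """Function plus captured values; captured is empty for global funcs."""

    func: Callable[..., Any]
    captured: Tuple[Any, ...] = ()

    def __call__(self, *args):
        return self.func(*self.captured, *args)

    def __repr__(self):
        # Opaque on purpose: no function name, no captured values.
        return "<closure>"


def lifted_func_0():
    return 1


def as_closure(value):
    """Wrap a bare function reference; leave existing closures untouched."""
    if isinstance(value, Closure):
        return value
    return Closure(value)


def print_value(value):
    """Text a print-like function would show for the given argument."""
    if callable(value):
        return repr(as_closure(value))
    return repr(value)
```

Under this convention the callee never needs to distinguish a lifted function with captures from one without: both arrive as a Closure, and a later pass could drop the wrapper where it is never needed.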

slyubomirsky, Aug 16 '22 19:08