[Bug] [Relax] cannot remove the hint_on_device
Actual behavior
Traceback (most recent call last):
File "/share_container/optfuzz/res/bugs/7bug_assert.py", line 25, in <module>
tvm.ir.assert_structural_equal(mod_seq, mod) # assert failed
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/software/tvm/python/tvm/ir/base.py", line 256, in assert_structural_equal
_ffi_node_api.StructuralEqual(lhs, rhs, True, map_free_vars) # type: ignore # pylint: disable=no-member
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/software/tvm/python/tvm/_ffi/_ctypes/packed_func.py", line 239, in __call__
raise_last_ffi_error()
File "/software/tvm/python/tvm/_ffi/base.py", line 481, in raise_last_ffi_error
raise py_err
ValueError: Traceback (most recent call last):
5: _ZN3tvm7runtime13PackedFuncObj9ExtractorINS0_1
4: tvm::runtime::TypedPackedFunc<bool (tvm::runtime::ObjectRef const&, tvm::runtime::ObjectRef const&, bool, bool)>::AssignTypedLambda<tvm::{lambda(tvm::runtime::ObjectRef const&, tvm::runtime::ObjectRef const&, bool, bool)#3}>(tvm::{lambda(tvm::runtime::ObjectRef const&, tvm::runtime::ObjectRef const&, bool, bool)#3}, std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> >)::{lambda(tvm::runtime::TVMArgs const&, tvm::runtime::TVMRetValue*)#1}::operator()(tvm::runtime::TVMArgs const&, tvm::runtime::TVMRetValue*) const
3: _ZN3tvm20SEqualHandlerDefault5EqualERKNS_
2: tvm::SEqualHandlerDefault::Impl::Equal(tvm::runtime::ObjectRef const&, tvm::runtime::ObjectRef const&, bool)
1: tvm::SEqualHandlerDefault::Impl::RunTasks()
0: tvm::SEqualHandlerDefault::Impl::CheckResult(bool, tvm::runtime::ObjectRef const&, tvm::runtime::ObjectRef const&, tvm::runtime::Optional<tvm::ObjectPathPair> const&)
File "/software/tvm/src/node/structural_equal.cc", line 392
ValueError: StructuralEqual check failed, caused by lhs at <root>.functions[I.GlobalVar("foo")].body.blocks[0].bindings[0].value:
# from tvm.script import ir as I
# from tvm.script import relax as R
@I.ir_module
class Module:
    I.module_attrs({"attr": 10})
    I.module_global_infos({"vdevice": [I.vdevice({"keys": ["cpu"], "kind": "llvm", "mtriple": "x86_64-unknown-linux-gnu", "tag": ""}, 0, "global"), I.vdevice({"arch": "sm_50", "keys": ["cuda", "gpu"], "kind": "cuda", "max_num_threads": 1024, "tag": "", "thread_warp_size": 32}, 0, "global"), I.vdevice({"keys": ["metal", "gpu"], "kind": "metal", "max_function_args": 31, "max_num_threads": 256, "max_shared_memory_per_block": 32768, "max_threads_per_block": 256, "tag": "", "thread_warp_size": 16}, 0, "global"), I.vdevice({"arch": "sm_80", "keys": ["cuda", "gpu"], "kind": "cuda", "max_num_threads": 1024, "tag": "", "thread_warp_size": 32}, 0, "global")]})

    @R.function
    def foo(x: R.Tensor((2, 3), dtype="float32"), y: R.Tensor((2, 3), dtype="float32", vdevice="llvm:0"), z: R.Tensor((2, 3), dtype="float32")) -> R.Tensor((2, 3), dtype="float32", vdevice="llvm:0"):
        with R.dataflow():
            lv0: R.Tensor((2, 3), dtype="float32", vdevice="llvm:0") = y
                                                                       ^
            R.output(lv0)
        return lv0
and rhs at <root>.functions[I.GlobalVar("foo")].body.blocks[0].bindings[0].value:
# from tvm.script import ir as I
# from tvm.script import relax as R
@I.ir_module
class Module:
    I.module_attrs({"attr": 10})
    I.module_global_infos({"vdevice": [I.vdevice({"keys": ["cpu"], "kind": "llvm", "mtriple": "x86_64-unknown-linux-gnu", "tag": ""}, 0, "global"), I.vdevice({"arch": "sm_50", "keys": ["cuda", "gpu"], "kind": "cuda", "max_num_threads": 1024, "tag": "", "thread_warp_size": 32}, 0, "global"), I.vdevice({"keys": ["metal", "gpu"], "kind": "metal", "max_function_args": 31, "max_num_threads": 256, "max_shared_memory_per_block": 32768, "max_threads_per_block": 256, "tag": "", "thread_warp_size": 16}, 0, "global"), I.vdevice({"arch": "sm_80", "keys": ["cuda", "gpu"], "kind": "cuda", "max_num_threads": 1024, "tag": "", "thread_warp_size": 32}, 0, "global")]})

    @R.function
    def foo(x: R.Tensor((2, 3), dtype="float32"), y: R.Tensor((2, 3), dtype="float32", vdevice="llvm:0"), z: R.Tensor((2, 3), dtype="float32")) -> R.Tensor((2, 3), dtype="float32", vdevice="llvm:0"):
        with R.dataflow():
            lv0: R.Tensor((2, 3), dtype="float32", vdevice="llvm:0") = R.hint_on_device(y, R.device(dev_type=1, dev_id=0))
                                                                       ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
            R.output(lv0)
        return lv0
Environment
TVM: 0.17.dev0
Steps to reproduce
import tvm
from tvm import relax
from tvm.script import ir as I
from tvm.script import tir as T
from tvm.script import relax as R

@I.ir_module
class Module:
    I.module_attrs({"attr": 10})
    I.module_global_infos({"vdevice": [I.vdevice({"keys": ["cpu"], "kind": "llvm", "mtriple": "x86_64-unknown-linux-gnu", "tag": ""}, 0, "global"), I.vdevice({"arch": "sm_50", "keys": ["cuda", "gpu"], "kind": "cuda", "max_num_threads": 1024, "tag": "", "thread_warp_size": 32}, 0, "global"), I.vdevice({"keys": ["metal", "gpu"], "kind": "metal", "max_function_args": 31, "max_num_threads": 256, "max_shared_memory_per_block": 32768, "max_threads_per_block": 256, "tag": "", "thread_warp_size": 16}, 0, "global"), I.vdevice({"arch": "sm_80", "keys": ["cuda", "gpu"], "kind": "cuda", "max_num_threads": 1024, "tag": "", "thread_warp_size": 32}, 0, "global")]})

    @R.function
    def foo(x: R.Tensor((2, 3), dtype="float32"), y: R.Tensor((2, 3), dtype="float32"), z: R.Tensor((2, 3), dtype="float32")) -> R.Tensor((2, 3), dtype="float32"):
        cls = Module
        with R.dataflow():
            lv0: R.Tensor((2, 3), dtype="float32") = R.hint_on_device(y, R.device(dev_type=1, dev_id=0))
            R.output(lv0)
        return lv0

mod = Module
mod_seq = tvm.transform.Sequential([relax.transform.RealizeVDevice()])(mod)
mod = relax.transform.RealizeVDevice()(mod)
mod_seq.show()
mod.show() # cannot remove the 'hint_on_device'
tvm.ir.assert_structural_equal(mod_seq, mod) # assert failed
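As a quick cross-check that the two invocation styles are not themselves the culprit, each invocation can be given an independent copy of the module. This is a sketch of my own (the fresh_copy helper is made up, not a TVM API); it reuses the Module definition above and relies on the real tvm.ir.save_json/tvm.ir.load_json round trip to clone an IRModule. With independent inputs, both styles should produce the same result:

```python
import tvm
from tvm import relax

def fresh_copy(mod):
    # Round-trip through JSON to get a structurally identical but
    # independent IRModule, so one invocation cannot affect the other.
    return tvm.ir.load_json(tvm.ir.save_json(mod))

mod_seq = tvm.transform.Sequential([relax.transform.RealizeVDevice()])(fresh_copy(Module))
mod_direct = relax.transform.RealizeVDevice()(fresh_copy(Module))
# Expected to pass, since each copy starts unmutated.
tvm.ir.assert_structural_equal(mod_seq, mod_direct)
```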
cc @junrushao
Hi all, I found that tvm.transform.Sequential([relax.transform.RealizeVDevice()])(mod) can remove the "R.hint_on_device()" call; however, applying relax.transform.RealizeVDevice()(mod) directly cannot remove it.
@Lunderberg @tqchen Why do different usages of the same pass (i.e., RealizeVDevice) give different optimization results?
Hmm, it's even weirder than that: only the first use of RealizeVDevice produces the correct output. The first execution correctly removes the R.hint_on_device, but the second execution does not.
relax.transform.RealizeVDevice()(Module).show(name='First')
relax.transform.RealizeVDevice()(Module).show(name='Second')
Aha! The problem is that HintOnDeviceRemover (the first step of RealizeVDevice) is mutating the relax expression in-place, which is not legal. As a result, expressions that are in the input Module are being mutated to have different StructInfo. The second time that RealizeVDevice is applied, its input has been mutated to already include vdevice annotations.
## Running these commands
Module["foo"].show(name="Before")
relax.transform.RealizeVDevice()(Module)
Module["foo"].show(name="OrigAfter")
## Produces the following output
@R.function
def Before(
    x: R.Tensor((2, 3), dtype="float32"),
    y: R.Tensor((2, 3), dtype="float32"),
    z: R.Tensor((2, 3), dtype="float32"),
) -> R.Tensor((2, 3), dtype="float32"):
    R.func_attr({"global_symbol": "foo"})
    with R.dataflow():
        lv0: R.Tensor((2, 3), dtype="float32") = R.hint_on_device(y, R.device(dev_type=1, dev_id=0))
        R.output(lv0)
    return lv0


@R.function
def OrigAfter(
    x: R.Tensor((2, 3), dtype="float32"),
    y: R.Tensor((2, 3), dtype="float32", vdevice="llvm:-1"),
    z: R.Tensor((2, 3), dtype="float32"),
) -> R.Tensor((2, 3), dtype="float32"):
    R.func_attr({"global_symbol": "foo"})
    with R.dataflow():
        lv0: R.Tensor((2, 3), dtype="float32", vdevice="llvm:-1") = R.hint_on_device(
            y, R.device(dev_type=1, dev_id=0)
        )
        R.output(lv0)
    return lv0
The input module should never be modified when running any IRModule transform, so this definitely narrows the bug down to the RealizeVDevice implementation.
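As a general guard against this class of bug, a pass can be checked for in-place mutation by structurally hashing its input before and after it runs. A minimal sketch of my own (the assert_transform_is_pure helper is made up, not part of TVM's test utilities; tvm.ir.structural_hash is a real API):

```python
import tvm
from tvm import relax

def assert_transform_is_pure(transform, mod):
    # An IRModule transform must return a new module and leave its input untouched.
    before = tvm.ir.structural_hash(mod)
    transform(mod)  # discard the result; only the side effect on `mod` matters
    after = tvm.ir.structural_hash(mod)
    assert before == after, "transform mutated its input IRModule in-place"

# Before the fix, this should trip the assertion for RealizeVDevice:
assert_transform_is_pure(relax.transform.RealizeVDevice(), Module)
```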
@Cookiee235 Can you verify the fix implemented in #17213? It removes the in-place mutation from RealizeVDevice, and should resolve this issue.
Edit: Whoops, I meant @MellowArtisan. With multiple issues/PRs in flight, I got them mixed up.
> @Cookiee235 Can you verify the fix implemented in #17213? It removes the in-place mutation from RealizeVDevice, and should resolve this issue.
@Lunderberg Yes! This PR fixed the bug. Thanks for your efforts!