tvm icon indicating copy to clipboard operation
tvm copied to clipboard

[Bug] [Relax] cannot remove the hint_on_device

Open MellowArtisan opened this issue 1 year ago • 5 comments

Actual behavior

Traceback (most recent call last):
  File "/share_container/optfuzz/res/bugs/7bug_assert.py", line 25, in <module>
    tvm.ir.assert_structural_equal(mod_seq, mod)  # assert failed
    ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/software/tvm/python/tvm/ir/base.py", line 256, in assert_structural_equal
    _ffi_node_api.StructuralEqual(lhs, rhs, True, map_free_vars)  # type: ignore # pylint: disable=no-member
    ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/software/tvm/python/tvm/_ffi/_ctypes/packed_func.py", line 239, in __call__
    raise_last_ffi_error()
  File "/software/tvm/python/tvm/_ffi/base.py", line 481, in raise_last_ffi_error
    raise py_err
ValueError: Traceback (most recent call last):
  5: _ZN3tvm7runtime13PackedFuncObj9ExtractorINS0_1
  4: tvm::runtime::TypedPackedFunc<bool (tvm::runtime::ObjectRef const&, tvm::runtime::ObjectRef const&, bool, bool)>::AssignTypedLambda<tvm::{lambda(tvm::runtime::ObjectRef const&, tvm::runtime::ObjectRef const&, bool, bool)#3}>(tvm::{lambda(tvm::runtime::ObjectRef const&, tvm::runtime::ObjectRef const&, bool, bool)#3}, std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> >)::{lambda(tvm::runtime::TVMArgs const&, tvm::runtime::TVMRetValue*)#1}::operator()(tvm::runtime::TVMArgs const&, tvm::runtime::TVMRetValue*) const
  3: _ZN3tvm20SEqualHandlerDefault5EqualERKNS_
  2: tvm::SEqualHandlerDefault::Impl::Equal(tvm::runtime::ObjectRef const&, tvm::runtime::ObjectRef const&, bool)
  1: tvm::SEqualHandlerDefault::Impl::RunTasks()
  0: tvm::SEqualHandlerDefault::Impl::CheckResult(bool, tvm::runtime::ObjectRef const&, tvm::runtime::ObjectRef const&, tvm::runtime::Optional<tvm::ObjectPathPair> const&)
  File "/software/tvm/src/node/structural_equal.cc", line 392
ValueError: StructuralEqual check failed, caused by lhs at <root>.functions[I.GlobalVar("foo")].body.blocks[0].bindings[0].value:
# from tvm.script import ir as I
# from tvm.script import relax as R

@I.ir_module
class Module:
    I.module_attrs({"attr": 10})
    I.module_global_infos({"vdevice": [I.vdevice({"keys": ["cpu"], "kind": "llvm", "mtriple": "x86_64-unknown-linux-gnu", "tag": ""}, 0, "global"), I.vdevice({"arch": "sm_50", "keys": ["cuda", "gpu"], "kind": "cuda", "max_num_threads": 1024, "tag": "", "thread_warp_size": 32}, 0, "global"), I.vdevice({"keys": ["metal", "gpu"], "kind": "metal", "max_function_args": 31, "max_num_threads": 256, "max_shared_memory_per_block": 32768, "max_threads_per_block": 256, "tag": "", "thread_warp_size": 16}, 0, "global"), I.vdevice({"arch": "sm_80", "keys": ["cuda", "gpu"], "kind": "cuda", "max_num_threads": 1024, "tag": "", "thread_warp_size": 32}, 0, "global")]})
    @R.function
    def foo(x: R.Tensor((2, 3), dtype="float32"), y: R.Tensor((2, 3), dtype="float32", vdevice="llvm:0"), z: R.Tensor((2, 3), dtype="float32")) -> R.Tensor((2, 3), dtype="float32", vdevice="llvm:0"):
        with R.dataflow():
            lv0: R.Tensor((2, 3), dtype="float32", vdevice="llvm:0") = y
                                                                       ^
            R.output(lv0)
        return lv0
and rhs at <root>.functions[I.GlobalVar("foo")].body.blocks[0].bindings[0].value:
# from tvm.script import ir as I
# from tvm.script import relax as R

@I.ir_module
class Module:
    I.module_attrs({"attr": 10})
    I.module_global_infos({"vdevice": [I.vdevice({"keys": ["cpu"], "kind": "llvm", "mtriple": "x86_64-unknown-linux-gnu", "tag": ""}, 0, "global"), I.vdevice({"arch": "sm_50", "keys": ["cuda", "gpu"], "kind": "cuda", "max_num_threads": 1024, "tag": "", "thread_warp_size": 32}, 0, "global"), I.vdevice({"keys": ["metal", "gpu"], "kind": "metal", "max_function_args": 31, "max_num_threads": 256, "max_shared_memory_per_block": 32768, "max_threads_per_block": 256, "tag": "", "thread_warp_size": 16}, 0, "global"), I.vdevice({"arch": "sm_80", "keys": ["cuda", "gpu"], "kind": "cuda", "max_num_threads": 1024, "tag": "", "thread_warp_size": 32}, 0, "global")]})
    @R.function
    def foo(x: R.Tensor((2, 3), dtype="float32"), y: R.Tensor((2, 3), dtype="float32", vdevice="llvm:0"), z: R.Tensor((2, 3), dtype="float32")) -> R.Tensor((2, 3), dtype="float32", vdevice="llvm:0"):
        with R.dataflow():
            lv0: R.Tensor((2, 3), dtype="float32", vdevice="llvm:0") = R.hint_on_device(y, R.device(dev_type=1, dev_id=0))
                                                                       ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
            R.output(lv0)
        return lv0

Environment

TVM: 0.17.dev0

Steps to reproduce

import tvm
from tvm import relax
from tvm.script import ir as I
from tvm.script import tir as T
from tvm.script import relax as R

@I.ir_module
class Module:
    I.module_attrs({"attr": 10})
    I.module_global_infos({"vdevice": [I.vdevice({"keys": ["cpu"], "kind": "llvm", "mtriple": "x86_64-unknown-linux-gnu", "tag": ""}, 0, "global"), I.vdevice({"arch": "sm_50", "keys": ["cuda", "gpu"], "kind": "cuda", "max_num_threads": 1024, "tag": "", "thread_warp_size": 32}, 0, "global"), I.vdevice({"keys": ["metal", "gpu"], "kind": "metal", "max_function_args": 31, "max_num_threads": 256, "max_shared_memory_per_block": 32768, "max_threads_per_block": 256, "tag": "", "thread_warp_size": 16}, 0, "global"), I.vdevice({"arch": "sm_80", "keys": ["cuda", "gpu"], "kind": "cuda", "max_num_threads": 1024, "tag": "", "thread_warp_size": 32}, 0, "global")]})

    @R.function
    def foo(x: R.Tensor((2, 3), dtype="float32"), y: R.Tensor((2, 3), dtype="float32"), z: R.Tensor((2, 3), dtype="float32")) -> R.Tensor((2, 3), dtype="float32"):
        cls = Module
        with R.dataflow():
            lv0: R.Tensor((2, 3), dtype="float32") = R.hint_on_device(y, R.device(dev_type=1, dev_id=0))
            R.output(lv0)
        return lv0

mod = Module
mod_seq = tvm.transform.Sequential([relax.transform.RealizeVDevice()])(mod)
mod = relax.transform.RealizeVDevice()(mod)
mod_seq.show()
mod.show()  # cannot remove the 'hint_on_device'
tvm.ir.assert_structural_equal(mod_seq, mod)  # assert failed

cc @junrushao

MellowArtisan avatar Jul 26 '24 17:07 MellowArtisan

Hi all, I found that tvm.transform.Sequential([relax.transform.RealizeVDevice()])(mod) removes the R.hint_on_device() call; however, applying relax.transform.RealizeVDevice()(mod) directly does not remove it.

@Lunderberg @tqchen Why do these two ways of invoking the same pass (i.e., RealizeVDevice) give different optimization results?

MellowArtisan avatar Jul 26 '24 17:07 MellowArtisan

Hmm, it looks even weirder than that: only the first use of RealizeVDevice produces the correct output. The first execution correctly removes the R.hint_on_device, but the second execution does not.

relax.transform.RealizeVDevice()(Module).show(name='First')
relax.transform.RealizeVDevice()(Module).show(name='Second')

Lunderberg avatar Jul 26 '24 18:07 Lunderberg

Aha! The problem is that HintOnDeviceRemover (the first step of RealizeVDevice) is mutating the relax expression in-place, which is not legal. As a result, expressions that are in the input Module are being mutated to have different StructInfo. The second time that RealizeVDevice is applied, its input has been mutated to already include vdevice annotations.

## Running these commands

Module["foo"].show(name="Before")
relax.transform.RealizeVDevice()(Module)
Module["foo"].show(name="OrigAfter")

## Produces the following output

@R.function
def Before(
    x: R.Tensor((2, 3), dtype="float32"),
    y: R.Tensor((2, 3), dtype="float32"),
    z: R.Tensor((2, 3), dtype="float32"),
) -> R.Tensor((2, 3), dtype="float32"):
    R.func_attr({"global_symbol": "foo"})
    with R.dataflow():
        lv0: R.Tensor((2, 3), dtype="float32") = R.hint_on_device(y, R.device(dev_type=1, dev_id=0))
        R.output(lv0)
    return lv0


@R.function
def OrigAfter(
    x: R.Tensor((2, 3), dtype="float32"),
    y: R.Tensor((2, 3), dtype="float32", vdevice="llvm:-1"),
    z: R.Tensor((2, 3), dtype="float32"),
) -> R.Tensor((2, 3), dtype="float32"):
    R.func_attr({"global_symbol": "foo"})
    with R.dataflow():
        lv0: R.Tensor((2, 3), dtype="float32", vdevice="llvm:-1") = R.hint_on_device(
            y, R.device(dev_type=1, dev_id=0)
        )
        R.output(lv0)
    return lv0

The input module should never be modified when running any IRModule transform, so this definitely narrows the bug down to the RealizeVDevice implementation.

Lunderberg avatar Jul 26 '24 19:07 Lunderberg

@Cookiee235 Can you verify the fix implemented in #17213? It removes the in-place mutation from RealizeVDevice, and should resolve this issue.

Edit: Whoops, meant @MellowArtisan . With multiple issues/PRs in-flight, I got them mixed up.

Lunderberg avatar Jul 29 '24 15:07 Lunderberg

> @Cookiee235 Can you verify the fix implemented in #17213? It removes the in-place mutation from RealizeVDevice, and should resolve this issue.

@Lunderberg Yes! This PR fixed the bug. Thanks for your efforts!

Cookiee235 avatar Jul 29 '24 16:07 Cookiee235