tvm
tvm copied to clipboard
[Bug] Cannot add loops on top of the root block
Actual behavior
Traceback (most recent call last):
File "/share_container/optfuzz/res/bugs/simple/bug_add_loop.py", line 51, in <module>
mod = tvm.tir.transform.DefaultGPUSchedule()(mod)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/software/tvm-lunder/python/tvm/ir/transform.py", line 238, in __call__
return _ffi_transform_api.RunPass(self, mod)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/software/tvm-lunder/python/tvm/_ffi/_ctypes/packed_func.py", line 240, in __call__
raise_last_ffi_error()
File "/software/tvm-lunder/python/tvm/_ffi/base.py", line 481, in raise_last_ffi_error
raise py_err
ValueError: Traceback (most recent call last):
8: tvm::runtime::PackedFuncObj::Extractor<tvm::runtime::PackedFuncSubObj<tvm::runtime::TypedPackedFunc<tvm::IRModule (tvm::transform::Pass, tvm::IRModule)>::AssignTypedLambda<tvm::transform::{lambda(tvm::transform::Pass, tvm::IRModule)#7}>(tvm::transform::{lambda(tvm::transform::Pass, tvm::IRModule)#7}, std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> >)::{lambda(tvm::runtime::TVMArgs const&, tvm::runtime::TVMRetValue*)#1}> >::Call(tvm::runtime::PackedFuncObj const*, std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> >, tvm::runtime::TVMRetValue)
7: tvm::transform::Pass::operator()(tvm::IRModule) const
6: tvm::transform::Pass::operator()(tvm::IRModule, tvm::transform::PassContext const&) const
5: tvm::transform::ModulePassNode::operator()(tvm::IRModule, tvm::transform::PassContext const&) const
4: _ZN3tvm7runtime13PackedFuncObj9ExtractorINS0_16PackedFuncSubObjIZNS0_15TypedPackedFuncIFNS_8IRModuleES5_NS_9transform11PassContextEEE17AssignTypedLambdaIZNS_3tir9transform18DefaultGPUScheduleEvEUlS5_S7_E_EEvT_EUlRKNS0_7TVMArgsEPNS0_11TVMRetValueEE_EEE4CallEPKS1_SF_SJ_
3: tvm::tir::transform::ThreadBind(tvm::tir::Schedule, tvm::tir::BlockRV const&, long, long)
2: tvm::tir::TracedScheduleNode::AddUnitLoop(tvm::tir::BlockRV const&)
1: tvm::tir::ConcreteScheduleNode::AddUnitLoop(tvm::tir::BlockRV const&)
0: tvm::tir::AddUnitLoop(tvm::tir::ScheduleState, tvm::tir::StmtSRef)
File "/software/tvm-lunder/src/tir/schedule/primitive/loop_transformation.cc", line 1153
ValueError: Check failed: (sref->parent != nullptr) is false: Cannot add loops on top of the root block
Steps to reproduce
import tvm
from tvm import relax
from tvm.script import ir as I
from tvm.script import tir as T
from tvm.script import relax as R
# Reproduction module for the "Cannot add loops on top of the root block" report.
# It holds two TIR PrimFuncs and one Relax entry point; only `min` is called
# from `main`, so `scatter_elements` is dead code.
@I.ir_module
class Module:
    # Reduction over the second axis of a (63, 1) float16 buffer into a (63,)
    # float16 buffer, taking the minimum.
    @T.prim_func(private=True)
    def min(v3_0: T.Buffer((T.int64(63), T.int64(1)), "float16"), v3_0_red: T.Buffer((T.int64(63),), "float16")):
        T.func_attr({"tir.noalias": T.bool(True)})
        # with T.block("root"):
        for ax0, k1 in T.grid(T.int64(63), T.int64(1)):
            with T.block("v3_0_red"):
                # "SR": ax0 is bound as a spatial axis, k1 as a reduction axis.
                v_ax0, v_k1 = T.axis.remap("SR", [ax0, k1])
                T.reads(v3_0[v_ax0, v_k1])
                T.writes(v3_0_red[v_ax0])
                with T.init():
                    # Initial value of the reduction.
                    v3_0_red[v_ax0] = T.float16(65504)
                v3_0_red[v_ax0] = T.min(v3_0_red[v_ax0], v3_0[v_ax0, v_k1])

    # Never referenced by `main`. The function body is a single block whose
    # loops contain no nested T.block, which is the shape the issue discussion
    # identifies as the trigger for the failure in DefaultGPUSchedule.
    @T.prim_func(private=True)
    def scatter_elements(var_x: T.handle, var_indices: T.handle, var_updates: T.handle, out_buf: T.Buffer((T.int64(4), T.int64(4)), "float32")):
        T.func_attr({"tir.noalias": T.bool(True)})
        x = T.match_buffer(var_x, (T.int64(4), T.int64(4)), offset_factor=1)
        indices = T.match_buffer(var_indices, (T.int64(2), T.int64(2)), "int64", offset_factor=1)
        updates = T.match_buffer(var_updates, (T.int64(2), T.int64(2)), offset_factor=1)
        with T.block("scatter_elements_generic"):
            T.reads()
            T.writes()
            # Copy all 16 elements of x into out_buf.
            for i in T.parallel(T.int64(16)):
                out_buf[i // T.int64(4), i % T.int64(4)] = x[i // T.int64(4), i % T.int64(4)]
            # Overwrite selected positions with values from `updates`. The
            # target offset is fused * 4 + index, where a negative index has 4
            # added to it (the `T.Cast(..., index < 0) * 4` term).
            for fused in T.parallel(T.int64(2)):
                for k in range(T.int64(2)):
                    out_buf[(fused * T.int64(4) + (indices[(fused * T.int64(2) + k) // T.int64(2), (fused * T.int64(2) + k) % T.int64(2)] + T.Cast("int64", indices[(fused * T.int64(2) + k) // T.int64(2), (fused * T.int64(2) + k) % T.int64(2)] < T.int64(0)) * T.int64(4))) // T.int64(4), (fused * T.int64(4) + (indices[(fused * T.int64(2) + k) // T.int64(2), (fused * T.int64(2) + k) % T.int64(2)] + T.Cast("int64", indices[(fused * T.int64(2) + k) // T.int64(2), (fused * T.int64(2) + k) % T.int64(2)] < T.int64(0)) * T.int64(4))) % T.int64(4)] = updates[(fused * T.int64(2) + k) // T.int64(2), (fused * T.int64(2) + k) % T.int64(2)]

    # Relax entry point: calls only `cls.min`.
    # NOTE(review): the declared return type is (4, 4) float32 while `lv` is
    # produced with out_sinfo (63,) float16; kept exactly as in the report.
    @R.function
    def main(v3_0: R.Tensor((63, 1), dtype="float16")) -> R.Tensor((4, 4), dtype="float32"):
        R.func_attr({"num_input": 1})
        cls = Module
        with R.dataflow():
            lv = R.call_tir(cls.min, (v3_0,), out_sinfo=R.Tensor((63,), dtype="float16"))
            R.output(lv)
        return lv
# Driver: print the module, apply the default GPU schedule, then build for CUDA.
mod = Module
# Left disabled on purpose. The report states that the crash disappears once
# the unused function is removed; presumably enabling this pass does that.
#mod = tvm.relax.transform.DeadCodeElimination()(mod)
mod.show()
with tvm.target.Target("cuda"):
    # This is the call that raises in the reported traceback.
    mod = tvm.tir.transform.DefaultGPUSchedule()(mod)
ex = relax.build(mod, target='cuda')
cc @Lunderberg @junrushao
If we keep the dead code (i.e., def scatter_elements()...) and call mod = tvm.tir.transform.DefaultGPUSchedule()(mod) before executing the Relax IR, the test case crashes unexpectedly. However, if we remove the dead function, the test case runs fine.
@Lunderberg Can you help me review this issue? Thanks!
Hmm. The PrimFunc definition is a bit odd. The presence of the with T.block means that it is schedulable TIR, but there aren't any with T.block annotations inside the loops themselves. So the body looks like what you would get after the ConvertBlocksToOpaque transform, but DefaultGPUSchedule requires the annotations that exist before that transform.
The reason why it works when the dead function is removed is because DefaultGPUSchedule attempts to schedule all TIR functions, regardless of whether they are actually used.