tvm
tvm copied to clipboard
[Bug] assert_structural_equal(specialize PrimFunc, const PrimFunc) failed
Expected behavior
tir.PrimFunc.specialize() works as expected.
Actual behavior
# from tvm.script import tir as T
@T.prim_func
def main(A: T.Buffer((4096, 4096), "float16"), B: T.Buffer((4096, 4096), "float16"), C: T.Buffer((4096, 4096), "float16")):
T.func_attr({"tir.noalias": T.bool(True)})
# with T.block("root"):
for m_i, n_i, k_i in T.grid(4096, 4096, 4096):
with T.block("C"):
vm, vn, vk = T.axis.remap("SSR", [m_i, n_i, k_i])
T.reads(A[vm, vk], B[vn, vk])
T.writes(C[vm, vn])
with T.init():
C[vm, vn] = T.float16(0)
C[vm, vn] = C[vm, vn] + A[vm, vk] * B[vn, vk]
# from tvm.script import tir as T
@T.prim_func
def main(A: T.Buffer((4096, 4096), "float16"), B: T.Buffer((4096, 4096), "float16"), C: T.Buffer((4096, 4096), "float16")):
T.func_attr({"tir.noalias": T.bool(True)})
# with T.block("root"):
for m_i, n_i, k_i in T.grid(4096, 4096, 4096):
with T.block("C"):
vm, vn, vk = T.axis.remap("SSR", [m_i, n_i, k_i])
T.reads(A[vm, vk], B[vn, vk])
T.writes(C[vm, vn])
with T.init():
C[vm, vn] = T.float16(0)
C[vm, vn] = C[vm, vn] + A[vm, vk] * B[vn, vk]
Traceback (most recent call last):
File "/mnt/kaiwu-user-honglinzhu/ampere_gemm/issue.py", line 53, in <module>
tvm.ir.assert_structural_equal(gemm_d, gemm_const)
File "/mnt/kaiwu-user-honglinzhu/tvm/python/tvm/ir/base.py", line 259, in assert_structural_equal
_ffi_node_api.StructuralEqual(lhs, rhs, True, map_free_vars) # type: ignore # pylint: disable=no-member
File "/mnt/kaiwu-user-honglinzhu/tvm/python/tvm/_ffi/_ctypes/packed_func.py", line 239, in __call__
raise_last_ffi_error()
File "/mnt/kaiwu-user-honglinzhu/tvm/python/tvm/_ffi/base.py", line 481, in raise_last_ffi_error
raise py_err
ValueError: Traceback (most recent call last):
5: tvm::runtime::PackedFuncObj::Extractor<tvm::runtime::PackedFuncSubObj<tvm::runtime::TypedPackedFunc<bool (tvm::runtime::ObjectRef const&, tvm::runtime::ObjectRef const&, bool, bool)>::AssignTypedLambda<tvm::__mk_TVM3::{lambda(tvm::runtime::ObjectRef const&, tvm::runtime::ObjectRef const&, bool, bool)#1}>(tvm::__mk_TVM3::{lambda(tvm::runtime::ObjectRef const&, tvm::runtime::ObjectRef const&, bool, bool)#1}, std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> >)::{lambda(tvm::runtime::TVMArgs const&, tvm::runtime::TVMRetValue*)#1}> >::Call(tvm::runtime::PackedFuncObj const*, std::allocator<char>, tvm::runtime::TVMArgs const&)
4: tvm::SEqualHandlerDefault::Impl::Equal(tvm::runtime::ObjectRef const&, tvm::runtime::ObjectRef const&, bool)
3: tvm::SEqualHandlerDefault::Impl::RunTasks()
2: tvm::SEqualHandlerDefault::DispatchSEqualReduce(tvm::runtime::ObjectRef const&, tvm::runtime::ObjectRef const&, bool, tvm::runtime::Optional<tvm::ObjectPathPair> const&)
1: tvm::SEqualHandlerDefault::Impl::CheckResult(bool, tvm::runtime::ObjectRef const&, tvm::runtime::ObjectRef const&, tvm::runtime::Optional<tvm::ObjectPathPair> const&)
0: _ZN3tvm7runtime6deta
File "/mnt/kaiwu-user-honglinzhu/tvm/src/node/structural_equal.cc", line 376
ValueError: StructuralEqual check failed, caused by lhs at <root>.buffer_map[a].data.type_annotation.element_type.dtype:
Environment
TVM commit: 90320b2
Steps to reproduce
import tvm
from tvm.script import tir as T
@T.prim_func
def gemm(a: T.handle, b: T.handle, c: T.handle) -> None:
T.func_attr({"global_symbol": "main", "tir.noalias": True})
m = T.var("int32")
n = T.var("int32")
k = T.var("int32")
A = T.match_buffer(a, (m, k), "float16")
B = T.match_buffer(b, (n, k), "float16")
C = T.match_buffer(c, (m, n), "float16")
for m_i in T.serial(m):
for n_i in T.serial(n):
for k_i in T.serial(k):
with T.block("C"):
vm, vn, vk = T.axis.remap("SSR", [m_i, n_i, k_i])
with T.init():
C[vm, vn] = T.float16(0)
C[vm, vn] = C[vm, vn] + A[vm, vk] * B[vn, vk]
@T.prim_func
def gemm_const(a: T.handle, b: T.handle, c: T.handle) -> None:
T.func_attr({"global_symbol": "main", "tir.noalias": True})
A = T.match_buffer(a, (4096, 4096), "float16")
B = T.match_buffer(b, (4096, 4096), "float16")
C = T.match_buffer(c, (4096, 4096), "float16")
for m_i in T.serial(4096):
for n_i in T.serial(4096):
for k_i in T.serial(4096):
with T.block("C"):
vm, vn, vk = T.axis.remap("SSR", [m_i, n_i, k_i])
with T.init():
C[vm, vn] = T.float16(0)
C[vm, vn] = C[vm, vn] + A[vm, vk] * B[vn, vk]
if __name__ == "__main__":
data, weight, _ = gemm.params
gemm_sp = gemm.specialize(
{
data: tvm.tir.decl_buffer((4096, 4096)), weight: tvm.tir.decl_buffer((4096, 4096)),
}
)
print(gemm_sp)
print(gemm_const)
tvm.ir.assert_structural_equal(gemm_sp, gemm_const)
Triage
- tir:ir
@Hzfengsy
I tried this out and it looks like the issue is that dtype has to be specified when creating the decl_buffer for specialization. The default dtype would become float32 and the const version directly specified float16.
Something like this
gemm_sp = gemm.specialize(
{
data: tvm.tir.decl_buffer((4096, 4096), dtype="float16"), weight: tvm.tir.decl_buffer((4096, 4096), dtype="float16"),
}
)
However the printed IR for the specialized version is misleading as it shows the dtype as "float16", which I understand is why this issue was created, so that probably has to be fixed.