tvm icon indicating copy to clipboard operation
tvm copied to clipboard

[Bug] assert_structural_equal(specialize PrimFunc, const PrimFunc) failed

Open HongHongHongL opened this issue 1 year ago • 2 comments

Expected behavior

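`tir.PrimFunc.specialize()` works as expected: the specialized `gemm` is structurally equal to `gemm_const`, so `tvm.ir.assert_structural_equal` passes.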

Actual behavior

# from tvm.script import tir as T

@T.prim_func
def main(A: T.Buffer((4096, 4096), "float16"), B: T.Buffer((4096, 4096), "float16"), C: T.Buffer((4096, 4096), "float16")):
    T.func_attr({"tir.noalias": T.bool(True)})
    # with T.block("root"):
    for m_i, n_i, k_i in T.grid(4096, 4096, 4096):
        with T.block("C"):
            vm, vn, vk = T.axis.remap("SSR", [m_i, n_i, k_i])
            T.reads(A[vm, vk], B[vn, vk])
            T.writes(C[vm, vn])
            with T.init():
                C[vm, vn] = T.float16(0)
            C[vm, vn] = C[vm, vn] + A[vm, vk] * B[vn, vk]
# from tvm.script import tir as T

@T.prim_func
def main(A: T.Buffer((4096, 4096), "float16"), B: T.Buffer((4096, 4096), "float16"), C: T.Buffer((4096, 4096), "float16")):
    T.func_attr({"tir.noalias": T.bool(True)})
    # with T.block("root"):
    for m_i, n_i, k_i in T.grid(4096, 4096, 4096):
        with T.block("C"):
            vm, vn, vk = T.axis.remap("SSR", [m_i, n_i, k_i])
            T.reads(A[vm, vk], B[vn, vk])
            T.writes(C[vm, vn])
            with T.init():
                C[vm, vn] = T.float16(0)
            C[vm, vn] = C[vm, vn] + A[vm, vk] * B[vn, vk]
Traceback (most recent call last):
  File "/mnt/kaiwu-user-honglinzhu/ampere_gemm/issue.py", line 53, in <module>
    tvm.ir.assert_structural_equal(gemm_d, gemm_const)
  File "/mnt/kaiwu-user-honglinzhu/tvm/python/tvm/ir/base.py", line 259, in assert_structural_equal
    _ffi_node_api.StructuralEqual(lhs, rhs, True, map_free_vars)  # type: ignore # pylint: disable=no-member
  File "/mnt/kaiwu-user-honglinzhu/tvm/python/tvm/_ffi/_ctypes/packed_func.py", line 239, in __call__
    raise_last_ffi_error()
  File "/mnt/kaiwu-user-honglinzhu/tvm/python/tvm/_ffi/base.py", line 481, in raise_last_ffi_error
    raise py_err
ValueError: Traceback (most recent call last):
  5: tvm::runtime::PackedFuncObj::Extractor<tvm::runtime::PackedFuncSubObj<tvm::runtime::TypedPackedFunc<bool (tvm::runtime::ObjectRef const&, tvm::runtime::ObjectRef const&, bool, bool)>::AssignTypedLambda<tvm::__mk_TVM3::{lambda(tvm::runtime::ObjectRef const&, tvm::runtime::ObjectRef const&, bool, bool)#1}>(tvm::__mk_TVM3::{lambda(tvm::runtime::ObjectRef const&, tvm::runtime::ObjectRef const&, bool, bool)#1}, std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> >)::{lambda(tvm::runtime::TVMArgs const&, tvm::runtime::TVMRetValue*)#1}> >::Call(tvm::runtime::PackedFuncObj const*, std::allocator<char>, tvm::runtime::TVMArgs const&)
  4: tvm::SEqualHandlerDefault::Impl::Equal(tvm::runtime::ObjectRef const&, tvm::runtime::ObjectRef const&, bool)
  3: tvm::SEqualHandlerDefault::Impl::RunTasks()
  2: tvm::SEqualHandlerDefault::DispatchSEqualReduce(tvm::runtime::ObjectRef const&, tvm::runtime::ObjectRef const&, bool, tvm::runtime::Optional<tvm::ObjectPathPair> const&)
  1: tvm::SEqualHandlerDefault::Impl::CheckResult(bool, tvm::runtime::ObjectRef const&, tvm::runtime::ObjectRef const&, tvm::runtime::Optional<tvm::ObjectPathPair> const&)
  0: _ZN3tvm7runtime6deta
  File "/mnt/kaiwu-user-honglinzhu/tvm/src/node/structural_equal.cc", line 376
ValueError: StructuralEqual check failed, caused by lhs at <root>.buffer_map[a].data.type_annotation.element_type.dtype:

Environment

TVM commit: 90320b2

Steps to reproduce

import tvm
from tvm.script import tir as T

# float16 GEMM with symbolic extents: C[vm, vn] accumulates A[vm, vk] * B[vn, vk].
# B is indexed as (n, k), i.e. the second operand is read row-wise along k.
# The script's __main__ block specializes this function to 4096 x 4096 buffers.
@T.prim_func
def gemm(a: T.handle, b: T.handle, c: T.handle) -> None:

    T.func_attr({"global_symbol": "main", "tir.noalias": True})

    # Symbolic int32 extents shared by the buffer shapes and loop bounds below.
    m = T.var("int32")
    n = T.var("int32")
    k = T.var("int32")

    # Bind each handle parameter to a float16 buffer: A is (m, k), B is (n, k), C is (m, n).
    A = T.match_buffer(a, (m, k), "float16")
    B = T.match_buffer(b, (n, k), "float16")
    C = T.match_buffer(c, (m, n), "float16")

    for m_i in T.serial(m):
        for n_i in T.serial(n):
            for k_i in T.serial(k):
                with T.block("C"):
                    # "SSR": presumably spatial, spatial, reduce for (vm, vn, vk) - confirm against TVMScript docs.
                    vm, vn, vk = T.axis.remap("SSR", [m_i, n_i, k_i])
                    # Inside the T.init() scope the output element is set to float16 zero.
                    with T.init():
                        C[vm, vn] = T.float16(0)
                    C[vm, vn] = C[vm, vn] + A[vm, vk] * B[vn, vk]

# Reference version of `gemm` with every extent written as the constant 4096.
# The script compares the specialized `gemm` against this function.
@T.prim_func
def gemm_const(a: T.handle, b: T.handle, c: T.handle) -> None:

    T.func_attr({"global_symbol": "main", "tir.noalias": True})

    # All three buffers are explicitly declared as 4096 x 4096 float16.
    A = T.match_buffer(a, (4096, 4096), "float16")
    B = T.match_buffer(b, (4096, 4096), "float16")
    C = T.match_buffer(c, (4096, 4096), "float16")

    for m_i in T.serial(4096):
        for n_i in T.serial(4096):
            for k_i in T.serial(4096):
                with T.block("C"):
                    # "SSR": presumably spatial, spatial, reduce for (vm, vn, vk) - confirm against TVMScript docs.
                    vm, vn, vk = T.axis.remap("SSR", [m_i, n_i, k_i])
                    # Inside the T.init() scope the output element is set to float16 zero.
                    with T.init():
                        C[vm, vn] = T.float16(0)
                    C[vm, vn] = C[vm, vn] + A[vm, vk] * B[vn, vk]

if __name__ == "__main__":
    # Take the first two handle parameters of `gemm`; the third (c) is not put in the specialize map.
    data, weight, _ = gemm.params
    # NOTE(review): no dtype is passed to decl_buffer here. Per the follow-up comment in this
    # thread, the dtype then defaults to float32, which does not match the float16 buffers
    # declared in `gemm` / `gemm_const` - the likely cause of the reported dtype mismatch.
    gemm_sp = gemm.specialize(
        {
            data: tvm.tir.decl_buffer((4096, 4096)), weight: tvm.tir.decl_buffer((4096, 4096)),
        }
    )
    # Print both functions for visual comparison, then compare them structurally.
    print(gemm_sp)
    print(gemm_const)
    tvm.ir.assert_structural_equal(gemm_sp, gemm_const)

Triage

  • tir:ir

HongHongHongL avatar Jan 30 '24 08:01 HongHongHongL

@Hzfengsy

HongHongHongL avatar Jan 30 '24 08:01 HongHongHongL

I tried this out, and it looks like the issue is that the dtype has to be specified when creating the `decl_buffer` for specialization. Otherwise the dtype defaults to float32, whereas the const version explicitly specifies float16.

Something like this:

# Pass dtype explicitly so the replacement buffers are float16, like the buffers in `gemm`.
gemm_sp = gemm.specialize(
    {
        data: tvm.tir.decl_buffer((4096, 4096), dtype="float16"),
        weight: tvm.tir.decl_buffer((4096, 4096), dtype="float16"),
    }
)

However, the printed IR for the specialized version is misleading: it shows the dtype as "float16", which I understand is why this issue was created. That probably has to be fixed.

quic-sanirudh avatar Feb 02 '24 14:02 quic-sanirudh