
[Bug] Data Type Mismatch (int64 vs int32) in T.match_buffer when Working with Scalar Buffers in TIR

Open • Thrsu opened this issue • 0 comments

I encountered an error when building a TIR-based module that uses T.match_buffer to bind a 0-dimensional (scalar) buffer to a single element of another buffer. The build fails with:

File "/software/tvm/src/tir/transforms/lower_match_buffer.cc", line 222
TVMError: Check failed: arg.dtype() == value.dtype() (int64 vs. int32) : The data type mismatched: int64 vs. int32

If I replace the two lines C1 = T.match_buffer(C0[jj], ()) and C1[()] = T.float32(0) with a direct assignment C0[jj] = T.float32(0), the error goes away and the module builds successfully.

Steps to reproduce

import tvm
from tvm import tir, relax
from tvm.script import ir as I
from tvm.script import tir as T

@I.ir_module
class Module:
    @T.prim_func
    def main():
        # with T.block("root"):
        C = T.alloc_buffer((128, 128))
        for i in range(128):
            with T.block(""):
                vi = T.axis.spatial(128, i)
                T.reads()
                T.writes(C[vi, 0:128])
                C0 = T.match_buffer(C[vi, 0:128], (128))
                for j in range(128):
                    with T.block(""):
                        jj = T.axis.spatial(128, j)
                        T.reads()
                        T.writes(C0[jj])
                        C1 = T.match_buffer(C0[jj], ())
                        C1[()] = T.float32(0)
mod = Module
ex = relax.build(mod, target='llvm')

CC @Lunderberg @tqchen @junrushao

Thrsu, Sep 19 '24 09:09