HalideIR icon indicating copy to clipboard operation
HalideIR copied to clipboard

`substitute` makes expressions more complicated

Open • xqdan opened this issue 5 years ago • 1 comment

import tvm

def register_mem(scope_tb, max_bits):
    """Register a ``MemoryInfo`` provider for the storage scope ``scope_tb``.

    The provider is registered as the global function
    ``"tvm.info.mem.<scope_tb>"``. Each time it is called it builds a new
    ``MemoryInfo`` node with ``unit_bits=16``, ``max_simd_bits=32``,
    ``max_num_bits=max_bits`` and no head address.

    Args:
        scope_tb: Storage scope name, e.g. ``"local.L0v"``.
        max_bits: Value stored as ``max_num_bits`` in the ``MemoryInfo`` node.
    """
    registered_name = "tvm.info.mem.%s" % scope_tb
    # Obtain the registering decorator first, mirroring decorator syntax order.
    register = tvm.register_func(registered_name)

    def mem_info_inp_buffer():
        # Built lazily: a fresh node is created on every lookup of the info.
        return tvm.make.node(
            "MemoryInfo",
            unit_bits=16,
            max_simd_bits=32,
            max_num_bits=max_bits,
            head_address=None,
        )

    register(mem_info_inp_buffer)


def test():
    """Reproduce the ``Simplify`` substitution issue and print the IR.

    Builds two loop nests over buffers ``A`` and ``B`` in the storage scope
    ``"local.L0v"``. The first nest fills ``A``. The second stores into ``B``
    under the guard ``j == A[i]``. The statement is printed before and after
    ``tvm.ir_pass.Simplify`` so the rewritten store index is visible.

    Returns:
        The statement returned by ``tvm.ir_pass.Simplify``.
    """
    scope_tb = "local.L0v"
    # NOTE(review): max_bits is only consumed by register_mem(), which this
    # repro never calls; kept so the scope can be registered if needed.
    max_bits = 1024 * 1024 * 1024

    ib = tvm.ir_builder.create()
    A = ib.allocate("int32", 200, name="A", scope=scope_tb)
    with ib.for_range(0, 10, name="i") as i:
        with ib.for_range(0, 10, name="j") as j:
            A[i*10+j] = 1

    B = ib.allocate("int32", 200, name="B", scope=scope_tb)
    # Reusing the names "i"/"j" here made this nest print as two loops that
    # are both called `j` (the IR builder renames repeated loops named "i"),
    # so the dump was ambiguous. Distinct names keep the printed IR readable.
    with ib.for_range(0, 10, name="i1") as i:
        with ib.for_range(0, 10, name="j1") as j:
            with ib.if_scope(j == A[i]):
                B[i*10+j] = 2

    body = ib.get()
    print(body)
    simplified = tvm.ir_pass.Simplify(body)
    print(simplified)
    return simplified

if __name__ == "__main__":
    # Run the repro only when executed as a script, not as an import side effect.
    test()

Before `Simplify`:

// attr [A] storage_scope = "local.L0v"
allocate A[int32 * 200]
for (i, 0, 10) {
  for (j, 0, 10) {
    A[((i*10) + j)] = 1
  }
}
// attr [B] storage_scope = "local.L0v"
allocate B[int32 * 200]
for (j, 0, 10) {
  for (j, 0, 10) {
    if ((j == A[j])) {
      B[((j*10) + j)] = 2
    }
  }
}

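After `Simplify`, the store becomes `B[((j*10) + A[j])]`, which is a more complicated expression than the original `B[((j*10) + j)]`. The condition `j == A[j]` was used to substitute `A[j]` for the inner loop variable `j` inside the `then` branch. (Both loops of the second nest print as `j`: the outer one was created with the name `"i"` and renamed by the IR builder, so `A[j]` here refers to the outer loop variable.)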

// attr [A] storage_scope = "local.L0v"
allocate A[int32 * 200]
for (i, 0, 10) {
  for (j, 0, 10) {
    A[((i*10) + j)] = 1
  }
}
// attr [B] storage_scope = "local.L0v"
allocate B[int32 * 200]
for (j, 0, 10) {
  for (j, 0, 10) {
    if ((j == A[j])) {
      B[((j*10) + A[j])] = 2
    }
  }
}

xqdan avatar May 08 '19 11:05 xqdan

diff --git a/src/arithmetic/Simplify.cpp b/src/arithmetic/Simplify.cpp
index 8a0d6e3..0dbb63e 100644
--- a/src/arithmetic/Simplify.cpp
+++ b/src/arithmetic/Simplify.cpp
@@ -3917,7 +3917,7 @@ private:
                 const Variable *var = eq ? eq->a.as<Variable>() : next.as<Variable>();
 
                 if (eq && var) {
-                    if (!or_chain) {
+                    if (!or_chain && is_const(eq->b)) {
                         then_case = substitute(var, eq->b, then_case);
                     }
                     if (!and_chain && eq->b.type().is_bool()) {

xqdan avatar May 08 '19 11:05 xqdan