HalideIR
Issue: `substitute` in Simplify makes expressions more complicated
import tvm
def register_mem(scope_tb, max_bits):
    """Register a MemoryInfo provider for the storage scope ``scope_tb``.

    Uses ``tvm.register_func`` to register, under the name
    ``"tvm.info.mem.<scope_tb>"``, a no-argument function that builds a
    ``MemoryInfo`` node with ``unit_bits=16``, ``max_simd_bits=32``,
    ``max_num_bits=max_bits`` and ``head_address=None``.

    Args:
        scope_tb: Storage-scope string, e.g. ``"local.L0v"``.
        max_bits: Value stored in the node's ``max_num_bits`` field.

    Returns:
        None. The registration itself is the only effect.
    """
    registered_name = "tvm.info.mem.%s" % scope_tb

    @tvm.register_func(registered_name)
    def _memory_info():
        # Build a fresh MemoryInfo node on each call; max_bits is captured
        # from the enclosing call.
        return tvm.make.node(
            "MemoryInfo",
            unit_bits=16,
            max_simd_bits=32,
            max_num_bits=max_bits,
            head_address=None,
        )
def test():
    """Reproduce Simplify rewriting a guarded store index into a larger form.

    Builds two loop nests over int32 buffers ``A`` and ``B``, both allocated
    in the ``"local.L0v"`` storage scope. In the second nest the store
    ``B[i*10 + j] = 2`` is guarded by ``j == A[i]``. The statement is printed
    before and after ``tvm.ir_pass.Simplify`` so the two can be compared
    directly.
    """
    scope_tb = "local.L0v"
    # NOTE(review): register_mem() is not called here, so no MemoryInfo is
    # registered for scope_tb. Call register_mem(scope_tb, <max_bits>) if a
    # later pass turns out to need it.
    ib = tvm.ir_builder.create()

    A = ib.allocate("int32", 200, name="A", scope=scope_tb)
    with ib.for_range(0, 10, name="i") as i:
        with ib.for_range(0, 10, name="j") as j:
            A[i * 10 + j] = 1

    B = ib.allocate("int32", 200, name="B", scope=scope_tb)
    with ib.for_range(0, 10, name="i") as i:
        with ib.for_range(0, 10, name="j") as j:
            # Equality guard with a non-constant right-hand side. Simplify may
            # substitute A[i] for j inside the guarded store's index.
            with ib.if_scope(j == A[i]):
                B[i * 10 + j] = 2

    body = ib.get()
    # Print both forms so the repro itself shows the before/after difference.
    print("before")
    print(body)
    print("after")
    print(tvm.ir_pass.Simplify(body))
if __name__ == "__main__":
    # Run the repro only when executed as a script. Importing this module
    # (e.g. from a test runner that also collects test()) should not
    # build and print the IR as a side effect.
    test()
Before Simplify:
// attr [A] storage_scope = "local.L0v"
allocate A[int32 * 200]
for (i, 0, 10) {
for (j, 0, 10) {
A[((i*10) + j)] = 1
}
}
// attr [B] storage_scope = "local.L0v"
allocate B[int32 * 200]
for (j, 0, 10) {
for (j, 0, 10) {
if ((j == A[j])) {
B[((j*10) + j)] = 2
}
}
}
After Simplify, the store index becomes B[((j*10) + A[j])], which is a more complicated expression than the original B[((j*10) + j)]:
// attr [A] storage_scope = "local.L0v"
allocate A[int32 * 200]
for (i, 0, 10) {
for (j, 0, 10) {
A[((i*10) + j)] = 1
}
}
// attr [B] storage_scope = "local.L0v"
allocate B[int32 * 200]
for (j, 0, 10) {
for (j, 0, 10) {
if ((j == A[j])) {
B[((j*10) + A[j])] = 2
}
}
}
diff --git a/src/arithmetic/Simplify.cpp b/src/arithmetic/Simplify.cpp
index 8a0d6e3..0dbb63e 100644
--- a/src/arithmetic/Simplify.cpp
+++ b/src/arithmetic/Simplify.cpp
@@ -3917,7 +3917,7 @@ private:
const Variable *var = eq ? eq->a.as<Variable>() : next.as<Variable>();
if (eq && var) {
- if (!or_chain) {
+ if (!or_chain && is_const(eq->b)) {
then_case = substitute(var, eq->b, then_case);
}
if (!and_chain && eq->b.type().is_bool()) {