[Bug] Missing PackedFunc with Specific Transformation Sequence in Relax Module
I encountered an issue while running a Relax module through a specific transformation sequence. When FuseTIR() is applied only once, before LambdaLift(), the VM fails to find the PackedFunc fused_relax_nn_attention_cutlass_gv. However, when FuseTIR() is applied a second time, after LambdaLift() and before AllocateWorkspace(), the problem disappears.
Expected behavior
The script is expected to run successfully without errors.
Actual behavior
InternalError: Check failed: (func.defined()) is false: Error: Cannot find PackedFunc fused_relax_nn_attention_cutlass_gv in either Relax VM kernel library, or in TVM runtime PackedFunc registry, or in global Relax functions of the VM executable
Steps to reproduce
The following script reproduces the issue:
import tvm
from tvm import relax
from tvm.script import ir as I
from tvm.script import tir as T
from tvm.script import relax as R
@I.ir_module
class Module:
    @T.prim_func(private=True)
    def attention(q_1: T.Buffer((T.int64(32), T.int64(8), T.int64(16), T.int64(8)), "float16"), k_1: T.Buffer((T.int64(32), T.int64(8), T.int64(16), T.int64(8)), "float16"), v_1: T.Buffer((T.int64(32), T.int64(8), T.int64(16), T.int64(8)), "float16"), T_transpose: T.Buffer((T.int64(32), T.int64(8), T.int64(16), T.int64(8)), "float16")):
        T.func_attr({"tir.noalias": T.bool(True)})
        # with T.block("root"):
        T_transpose_1 = T.alloc_buffer((T.int64(32), T.int64(16), T.int64(8), T.int64(8)), "float16")
        T_reshape = T.alloc_buffer((T.int64(512), T.int64(8), T.int64(8)), "float16")
        T_transpose_2 = T.alloc_buffer((T.int64(32), T.int64(16), T.int64(8), T.int64(8)), "float16")
        T_reshape_1 = T.alloc_buffer((T.int64(512), T.int64(8), T.int64(8)), "float16")
        T_batch_matmul_NT = T.alloc_buffer((T.int64(512), T.int64(8), T.int64(8)), "float16")
        T_divide = T.alloc_buffer((T.int64(512), T.int64(8), T.int64(8)), "float16")
        T_softmax_maxelem = T.alloc_buffer((T.int64(512), T.int64(8)), "float16")
        T_softmax_exp = T.alloc_buffer((T.int64(512), T.int64(8), T.int64(8)), "float16")
        T_softmax_expsum = T.alloc_buffer((T.int64(512), T.int64(8)), "float16")
        T_softmax_norm = T.alloc_buffer((T.int64(512), T.int64(8), T.int64(8)), "float16")
        T_transpose_3 = T.alloc_buffer((T.int64(32), T.int64(16), T.int64(8), T.int64(8)), "float16")
        T_reshape_2 = T.alloc_buffer((T.int64(512), T.int64(8), T.int64(8)), "float16")
        T_batch_matmul_NN = T.alloc_buffer((T.int64(512), T.int64(8), T.int64(8)), "float16")
        T_reshape_3 = T.alloc_buffer((T.int64(32), T.int64(16), T.int64(8), T.int64(8)), "float16")
        for ax0, ax1, ax2, ax3 in T.grid(T.int64(32), T.int64(16), T.int64(8), T.int64(8)):
            with T.block("T_transpose"):
                v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3])
                T.reads(q_1[v_ax0, v_ax2, v_ax1, v_ax3])
                T.writes(T_transpose_1[v_ax0, v_ax1, v_ax2, v_ax3])
                T_transpose_1[v_ax0, v_ax1, v_ax2, v_ax3] = q_1[v_ax0, v_ax2, v_ax1, v_ax3]
        for ax0, ax1, ax2 in T.grid(T.int64(512), T.int64(8), T.int64(8)):
            with T.block("T_reshape"):
                v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2])
                T.reads(T_transpose_1[((v_ax2 // T.int64(8) + v_ax1) // T.int64(8) + v_ax0) % T.int64(512) // T.int64(16), ((v_ax2 // T.int64(8) + v_ax1) // T.int64(8) + v_ax0) % T.int64(16), (v_ax2 // T.int64(8) + v_ax1) % T.int64(8), v_ax2 % T.int64(8)])
                T.writes(T_reshape[v_ax0, v_ax1, v_ax2])
                T_reshape[v_ax0, v_ax1, v_ax2] = T_transpose_1[((v_ax2 // T.int64(8) + v_ax1) // T.int64(8) + v_ax0) % T.int64(512) // T.int64(16), ((v_ax2 // T.int64(8) + v_ax1) // T.int64(8) + v_ax0) % T.int64(16), (v_ax2 // T.int64(8) + v_ax1) % T.int64(8), v_ax2 % T.int64(8)]
        for ax0, ax1, ax2, ax3 in T.grid(T.int64(32), T.int64(16), T.int64(8), T.int64(8)):
            with T.block("T_transpose_1"):
                v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3])
                T.reads(k_1[v_ax0, v_ax2, v_ax1, v_ax3])
                T.writes(T_transpose_2[v_ax0, v_ax1, v_ax2, v_ax3])
                T_transpose_2[v_ax0, v_ax1, v_ax2, v_ax3] = k_1[v_ax0, v_ax2, v_ax1, v_ax3]
        for ax0, ax1, ax2 in T.grid(T.int64(512), T.int64(8), T.int64(8)):
            with T.block("T_reshape_1"):
                v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2])
                T.reads(T_transpose_2[((v_ax2 // T.int64(8) + v_ax1) // T.int64(8) + v_ax0) % T.int64(512) // T.int64(16), ((v_ax2 // T.int64(8) + v_ax1) // T.int64(8) + v_ax0) % T.int64(16), (v_ax2 // T.int64(8) + v_ax1) % T.int64(8), v_ax2 % T.int64(8)])
                T.writes(T_reshape_1[v_ax0, v_ax1, v_ax2])
                T_reshape_1[v_ax0, v_ax1, v_ax2] = T_transpose_2[((v_ax2 // T.int64(8) + v_ax1) // T.int64(8) + v_ax0) % T.int64(512) // T.int64(16), ((v_ax2 // T.int64(8) + v_ax1) // T.int64(8) + v_ax0) % T.int64(16), (v_ax2 // T.int64(8) + v_ax1) % T.int64(8), v_ax2 % T.int64(8)]
        for b, i, j, k in T.grid(T.int64(512), T.int64(8), T.int64(8), T.int64(8)):
            with T.block("T_batch_matmul_NT"):
                v_b, v_i, v_j, v_k = T.axis.remap("SSSR", [b, i, j, k])
                T.reads(T_reshape[v_b, v_i, v_k], T_reshape_1[v_b, v_j, v_k])
                T.writes(T_batch_matmul_NT[v_b, v_i, v_j])
                T.block_attr({"layout_free_placeholders": [T_reshape_1]})
                with T.init():
                    T_batch_matmul_NT[v_b, v_i, v_j] = T.float16(0)
                T_batch_matmul_NT[v_b, v_i, v_j] = T_batch_matmul_NT[v_b, v_i, v_j] + T_reshape[v_b, v_i, v_k] * T_reshape_1[v_b, v_j, v_k]
        for ax0, ax1, ax2 in T.grid(T.int64(512), T.int64(8), T.int64(8)):
            with T.block("T_divide"):
                v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2])
                T.reads(T_batch_matmul_NT[v_ax0, v_ax1, v_ax2])
                T.writes(T_divide[v_ax0, v_ax1, v_ax2])
                T_divide[v_ax0, v_ax1, v_ax2] = T_batch_matmul_NT[v_ax0, v_ax1, v_ax2] / T.sqrt(T.float16(8))
        for i0, i1, k in T.grid(T.int64(512), T.int64(8), T.int64(8)):
            with T.block("T_softmax_maxelem"):
                v_i0, v_i1, v_k = T.axis.remap("SSR", [i0, i1, k])
                T.reads(T_divide[v_i0, v_i1, v_k])
                T.writes(T_softmax_maxelem[v_i0, v_i1])
                with T.init():
                    T_softmax_maxelem[v_i0, v_i1] = T.float16(-65504)
                T_softmax_maxelem[v_i0, v_i1] = T.max(T_softmax_maxelem[v_i0, v_i1], T_divide[v_i0, v_i1, v_k])
        for i0, i1, i2 in T.grid(T.int64(512), T.int64(8), T.int64(8)):
            with T.block("T_softmax_exp"):
                v_i0, v_i1, v_i2 = T.axis.remap("SSS", [i0, i1, i2])
                T.reads(T_divide[v_i0, v_i1, v_i2], T_softmax_maxelem[v_i0, v_i1])
                T.writes(T_softmax_exp[v_i0, v_i1, v_i2])
                T_softmax_exp[v_i0, v_i1, v_i2] = T.exp(T_divide[v_i0, v_i1, v_i2] - T_softmax_maxelem[v_i0, v_i1])
        for i0, i1, k in T.grid(T.int64(512), T.int64(8), T.int64(8)):
            with T.block("T_softmax_expsum"):
                v_i0, v_i1, v_k = T.axis.remap("SSR", [i0, i1, k])
                T.reads(T_softmax_exp[v_i0, v_i1, v_k])
                T.writes(T_softmax_expsum[v_i0, v_i1])
                with T.init():
                    T_softmax_expsum[v_i0, v_i1] = T.float16(0)
                T_softmax_expsum[v_i0, v_i1] = T_softmax_expsum[v_i0, v_i1] + T_softmax_exp[v_i0, v_i1, v_k]
        for i0, i1, i2 in T.grid(T.int64(512), T.int64(8), T.int64(8)):
            with T.block("T_softmax_norm"):
                v_i0, v_i1, v_i2 = T.axis.remap("SSS", [i0, i1, i2])
                T.reads(T_softmax_exp[v_i0, v_i1, v_i2], T_softmax_expsum[v_i0, v_i1])
                T.writes(T_softmax_norm[v_i0, v_i1, v_i2])
                T.block_attr({"axis": 2})
                T_softmax_norm[v_i0, v_i1, v_i2] = T_softmax_exp[v_i0, v_i1, v_i2] / T_softmax_expsum[v_i0, v_i1]
        for ax0, ax1, ax2, ax3 in T.grid(T.int64(32), T.int64(16), T.int64(8), T.int64(8)):
            with T.block("T_transpose_2"):
                v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3])
                T.reads(v_1[v_ax0, v_ax2, v_ax1, v_ax3])
                T.writes(T_transpose_3[v_ax0, v_ax1, v_ax2, v_ax3])
                T_transpose_3[v_ax0, v_ax1, v_ax2, v_ax3] = v_1[v_ax0, v_ax2, v_ax1, v_ax3]
        for ax0, ax1, ax2 in T.grid(T.int64(512), T.int64(8), T.int64(8)):
            with T.block("T_reshape_2"):
                v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2])
                T.reads(T_transpose_3[((v_ax2 // T.int64(8) + v_ax1) // T.int64(8) + v_ax0) % T.int64(512) // T.int64(16), ((v_ax2 // T.int64(8) + v_ax1) // T.int64(8) + v_ax0) % T.int64(16), (v_ax2 // T.int64(8) + v_ax1) % T.int64(8), v_ax2 % T.int64(8)])
                T.writes(T_reshape_2[v_ax0, v_ax1, v_ax2])
                T_reshape_2[v_ax0, v_ax1, v_ax2] = T_transpose_3[((v_ax2 // T.int64(8) + v_ax1) // T.int64(8) + v_ax0) % T.int64(512) // T.int64(16), ((v_ax2 // T.int64(8) + v_ax1) // T.int64(8) + v_ax0) % T.int64(16), (v_ax2 // T.int64(8) + v_ax1) % T.int64(8), v_ax2 % T.int64(8)]
        for b, i, j, k in T.grid(T.int64(512), T.int64(8), T.int64(8), T.int64(8)):
            with T.block("T_batch_matmul_NN"):
                v_b, v_i, v_j, v_k = T.axis.remap("SSSR", [b, i, j, k])
                T.reads(T_softmax_norm[v_b, v_i, v_k], T_reshape_2[v_b, v_k, v_j])
                T.writes(T_batch_matmul_NN[v_b, v_i, v_j])
                T.block_attr({"layout_free_placeholders": [T_reshape_2]})
                with T.init():
                    T_batch_matmul_NN[v_b, v_i, v_j] = T.float16(0)
                T_batch_matmul_NN[v_b, v_i, v_j] = T_batch_matmul_NN[v_b, v_i, v_j] + T_softmax_norm[v_b, v_i, v_k] * T_reshape_2[v_b, v_k, v_j]
        for ax0, ax1, ax2, ax3 in T.grid(T.int64(32), T.int64(16), T.int64(8), T.int64(8)):
            with T.block("T_reshape_3"):
                v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3])
                T.reads(T_batch_matmul_NN[(v_ax0 * T.int64(16) + (v_ax3 // T.int64(8) + v_ax2) // T.int64(8) + v_ax1) % T.int64(512), (v_ax3 // T.int64(8) + v_ax2) % T.int64(8), v_ax3 % T.int64(8)])
                T.writes(T_reshape_3[v_ax0, v_ax1, v_ax2, v_ax3])
                T_reshape_3[v_ax0, v_ax1, v_ax2, v_ax3] = T_batch_matmul_NN[(v_ax0 * T.int64(16) + (v_ax3 // T.int64(8) + v_ax2) // T.int64(8) + v_ax1) % T.int64(512), (v_ax3 // T.int64(8) + v_ax2) % T.int64(8), v_ax3 % T.int64(8)]
        for ax0, ax1, ax2, ax3 in T.grid(T.int64(32), T.int64(8), T.int64(16), T.int64(8)):
            with T.block("T_transpose_3"):
                v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3])
                T.reads(T_reshape_3[v_ax0, v_ax2, v_ax1, v_ax3])
                T.writes(T_transpose[v_ax0, v_ax1, v_ax2, v_ax3])
                T_transpose[v_ax0, v_ax1, v_ax2, v_ax3] = T_reshape_3[v_ax0, v_ax2, v_ax1, v_ax3]
    @R.function
    def entry_b(q: R.Tensor((32, 8, 16, 8), dtype="float16"), k: R.Tensor((32, 8, 16, 8), dtype="float16"), v: R.Tensor((32, 8, 16, 8), dtype="float16")) -> R.Tensor((32, 8, 16, 8), dtype="float16"):
        cls = Module
        with R.dataflow():
            lv: R.Tensor((32, 8, 16, 8), dtype="float16") = cls.fused_relax_nn_attention_cutlass(q, k, v)
            R.output(lv)
        return lv

    @R.function
    def fused_relax_nn_attention_cutlass(q: R.Tensor((32, 8, 16, 8), dtype="float16"), k: R.Tensor((32, 8, 16, 8), dtype="float16"), v: R.Tensor((32, 8, 16, 8), dtype="float16")) -> R.Tensor((32, 8, 16, 8), dtype="float16"):
        R.func_attr({"Codegen": "cutlass", "WorkspaceSize": 65536})
        cls = Module
        @R.function
        def gv(q_1: R.Tensor((32, 8, 16, 8), dtype="float16"), k_1: R.Tensor((32, 8, 16, 8), dtype="float16"), v_1: R.Tensor((32, 8, 16, 8), dtype="float16")) -> R.Tensor((32, 8, 16, 8), dtype="float16"):
            R.func_attr({"Composite": "cutlass.attention", "Primitive": 1, "WorkspaceSize": 65536})
            with R.dataflow():
                gv_2 = R.call_tir(cls.attention, (q_1, k_1, v_1), out_sinfo=R.Tensor((32, 8, 16, 8), dtype="float16"))
                R.output(gv_2)
            return gv_2
        gv1: R.Tensor((32, 8, 16, 8), dtype="float16") = gv(q, k, v)
        return gv1

mod = Module
# crash
mod = tvm.transform.Sequential([relax.transform.FuseTIR(), relax.transform.LambdaLift(), relax.transform.AllocateWorkspace()])(mod)
# pass
#mod = tvm.transform.Sequential([relax.transform.FuseTIR(), relax.transform.LambdaLift(), relax.transform.FuseTIR(), relax.transform.AllocateWorkspace()])(mod)
with tvm.transform.PassContext(opt_level=4):
    ex = relax.build(mod, target='llvm')
vm = relax.VirtualMachine(ex, tvm.cpu())
Any guidance on whether this is a bug or a known order dependency would be greatly appreciated. @Lunderberg
This is a bit of a bug and a bit of an ordering dependency.
- The LambdaLift pass extracts local lambda functions into the module. However, FuseOps and FuseOpsByPattern use local lambda functions to represent functions that will be replaced with specific kernel invocations.
- The AllocateWorkspace pass adds a new workspace parameter to all top-level functions that have a "WorkspaceSize" attribute, and updates all other functions to provide the new workspace. However, if there is a call from one function with a "WorkspaceSize" attribute to another such function, the caller gets left with a dangling GlobalVar.
- The FuseTIR pass removes the Relax function altogether, replacing it with a PrimFunc, which sidesteps the issue. It only inspects mod->functions, and not local lambda functions, which is why it only had an effect after LambdaLift.
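The dangling reference described above can be pictured with a toy model. The sketch below is plain Python and does not use any TVM API: a module is modelled as a dict from function name to the set of names it calls, and `find_dangling_calls` is a hypothetical helper that reports calls to names no longer defined in the module. It is a simplification (real Relax modules refer to callees through GlobalVar objects rather than strings), but it shows the shape of the failure behind the `Cannot find PackedFunc` error.

```python
from typing import Dict, List, Set, Tuple


def find_dangling_calls(module: Dict[str, Set[str]]) -> List[Tuple[str, str]]:
    """Return sorted (caller, callee) pairs whose callee is not defined in the module."""
    return sorted(
        (caller, callee)
        for caller, callees in module.items()
        for callee in callees
        if callee not in module
    )


# Toy view of the module after LambdaLift: the local lambda `gv` has become
# a top-level function named `fused_relax_nn_attention_cutlass_gv`.
lifted = {
    "entry_b": {"fused_relax_nn_attention_cutlass"},
    "fused_relax_nn_attention_cutlass": {"fused_relax_nn_attention_cutlass_gv"},
    "fused_relax_nn_attention_cutlass_gv": {"attention"},
    "attention": set(),
}

# Simplified model of the failure (an assumption, not the pass's actual code):
# the callee's old definition is gone, but its caller was never updated.
broken = dict(lifted)
del broken["fused_relax_nn_attention_cutlass_gv"]

print(find_dangling_calls(lifted))
print(find_dangling_calls(broken))
```

In this toy model the intact module reports no dangling calls, while the broken one reports the call from `fused_relax_nn_attention_cutlass` to `fused_relax_nn_attention_cutlass_gv`, the same name that appears in the error message.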
There's a couple of options for short-term fixes, and a couple of options for long-term fixes.
- Short-term
  - After calling FuseOpsByPattern and relax.backend.contrib.cutlass.annotate_workspace, immediately call AllocateWorkspace. This ensures that the annotations are correct at the time when they are used.
  - After calling AllocateWorkspace, immediately call relax.transform.RunCodegen. This ensures that the local lambda function is present when the cutlass codegen looks for it.
- Medium-term
  - Update LambdaLift to ignore lambda functions that have the tvm::relax::attr::kPrimitive attribute. This would prevent it from lifting out a function that is intended for use by RunCodegen. This would work, but would be additional cross-talk between otherwise unrelated transforms.
  - Update the way in which AllocateWorkspace locates functions to be updated, with a top-down approach rather than bottom-up. Instead of first updating the functions that require a workspace and then updating their callers, it would start at the externally-exposed functions, and walk along the Relax call graph to find callees that should be updated. This would ensure that all caller/callee pairs are updated at the same time, preventing dangling pointers.
- Long-term
  - Update FuseOpsByPattern to include the workspace. After FuseOpsByPattern, the workspace would be expressed explicitly as an allocation, and the "WorkspaceSize" attribute would never be generated. The various workspaces would then be replaced by a single workspace in the StaticPlanBlockMemory pass.
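The top-down approach proposed for AllocateWorkspace can be sketched on a toy call graph. This is plain Python rather than TVM code, and `plan_workspace_updates`, `calls`, `needs_workspace`, and `entry_points` are hypothetical names introduced only for illustration. The walk starts at the externally exposed functions and records, for every reachable caller, the callees that require a workspace, so a call from one workspace-requiring function to another is planned together with all the others.

```python
from typing import Dict, List, Set


def plan_workspace_updates(
    calls: Dict[str, List[str]],
    needs_workspace: Set[str],
    entry_points: List[str],
) -> Dict[str, List[str]]:
    """Walk the call graph from the entry points and, for each reachable
    caller, list the callees that must be passed a workspace argument."""
    plan: Dict[str, List[str]] = {}
    visited: Set[str] = set()
    stack = list(entry_points)
    while stack:
        caller = stack.pop()
        if caller in visited:
            continue
        visited.add(caller)
        callees = calls.get(caller, [])
        # Record every call site that targets a workspace-requiring function,
        # including calls made from another workspace-requiring function.
        targets = [callee for callee in callees if callee in needs_workspace]
        if targets:
            plan[caller] = targets
        stack.extend(callees)
    return plan


# Toy call graph mirroring the module after LambdaLift.
calls = {
    "entry_b": ["fused_relax_nn_attention_cutlass"],
    "fused_relax_nn_attention_cutlass": ["fused_relax_nn_attention_cutlass_gv"],
    "fused_relax_nn_attention_cutlass_gv": ["attention"],
    "attention": [],
}
needs_workspace = {
    "fused_relax_nn_attention_cutlass",
    "fused_relax_nn_attention_cutlass_gv",
}

plan = plan_workspace_updates(calls, needs_workspace, ["entry_b"])
for caller, targets in plan.items():
    print(caller, "->", targets)
```

Because the plan is keyed by caller, the call site inside `fused_relax_nn_attention_cutlass` shows up alongside the one in `entry_b`. A real implementation would still need to rewrite the Relax functions and their GlobalVars, which this sketch does not attempt.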
Unfortunately, I don't have time to implement the medium/long term solutions at the moment, but could help guide somebody in their implementation if there's interest.
Thank you very much for your thorough analysis and explanation of the root cause of the bug, as well as the detailed guidance on how to address it. Unfortunately, I'm not too familiar with the relax source code, which means I might struggle with submitting a PR to fix this myself. I do hope someone with the right expertise and interest can pick this up.
Thanks again for all your help, and I'm looking forward to seeing this issue tackled by the community!