tvm icon indicating copy to clipboard operation
tvm copied to clipboard

[Bug] Missing PackedFunc with Specific Transformation Sequence in Relax Module

Open Thrsu opened this issue 5 months ago • 2 comments

I encountered an issue while running a Relax module through a specific sequence of transformations. When FuseTIR() is applied only once (before LambdaLift()), the VM fails to find the PackedFunc `fused_relax_nn_attention_cutlass_gv`. When FuseTIR() is applied a second time, after LambdaLift() and before AllocateWorkspace(), the problem disappears.

Expected behavior

The script should build the module and create the VirtualMachine without errors.

Actual behavior

InternalError: Check failed: (func.defined()) is false: Error: Cannot find PackedFunc fused_relax_nn_attention_cutlass_gv in either Relax VM kernel library, or in TVM runtime PackedFunc registry, or in global Relax functions of the VM executable

Steps to reproduce

The following script reproduces the issue:

import tvm
from tvm import relax

from tvm.script import ir as I
from tvm.script import tir as T
from tvm.script import relax as R

# Minimal reproducer IRModule. It holds one TIR kernel (`attention`) and two
# Relax functions. `entry_b` calls `fused_relax_nn_attention_cutlass`, which
# defines and calls a *nested* Relax function `gv` that wraps the TIR kernel
# with R.call_tir. The nested `gv` is what LambdaLift hoists to module scope;
# the VM error in this report refers to it as
# `fused_relax_nn_attention_cutlass_gv`.
@I.ir_module
class Module:
    # TIR attention kernel over float16 tensors of shape (32, 8, 16, 8).
    # NOTE(review): the axis layout looks like (batch, seq_len, num_heads,
    # head_dim) -- TODO confirm. Stages, in order:
    #   1. transpose q/k/v axes 1<->2, then flatten to (512, 8, 8)
    #   2. scores = Q @ K^T, scaled by 1 / sqrt(8)
    #   3. softmax over the last axis (max-subtracted before exp)
    #   4. out = softmax @ V, reshaped to (32, 16, 8, 8), then transposed back
    #      to (32, 8, 16, 8) into the output buffer T_transpose.
    @T.prim_func(private=True)
    def attention(q_1: T.Buffer((T.int64(32), T.int64(8), T.int64(16), T.int64(8)), "float16"), k_1: T.Buffer((T.int64(32), T.int64(8), T.int64(16), T.int64(8)), "float16"), v_1: T.Buffer((T.int64(32), T.int64(8), T.int64(16), T.int64(8)), "float16"), T_transpose: T.Buffer((T.int64(32), T.int64(8), T.int64(16), T.int64(8)), "float16")):
        T.func_attr({"tir.noalias": T.bool(True)})
        # with T.block("root"):
        # Intermediate buffers, one per stage of the computation below.
        T_transpose_1 = T.alloc_buffer((T.int64(32), T.int64(16), T.int64(8), T.int64(8)), "float16")
        T_reshape = T.alloc_buffer((T.int64(512), T.int64(8), T.int64(8)), "float16")
        T_transpose_2 = T.alloc_buffer((T.int64(32), T.int64(16), T.int64(8), T.int64(8)), "float16")
        T_reshape_1 = T.alloc_buffer((T.int64(512), T.int64(8), T.int64(8)), "float16")
        T_batch_matmul_NT = T.alloc_buffer((T.int64(512), T.int64(8), T.int64(8)), "float16")
        T_divide = T.alloc_buffer((T.int64(512), T.int64(8), T.int64(8)), "float16")
        T_softmax_maxelem = T.alloc_buffer((T.int64(512), T.int64(8)), "float16")
        T_softmax_exp = T.alloc_buffer((T.int64(512), T.int64(8), T.int64(8)), "float16")
        T_softmax_expsum = T.alloc_buffer((T.int64(512), T.int64(8)), "float16")
        T_softmax_norm = T.alloc_buffer((T.int64(512), T.int64(8), T.int64(8)), "float16")
        T_transpose_3 = T.alloc_buffer((T.int64(32), T.int64(16), T.int64(8), T.int64(8)), "float16")
        T_reshape_2 = T.alloc_buffer((T.int64(512), T.int64(8), T.int64(8)), "float16")
        T_batch_matmul_NN = T.alloc_buffer((T.int64(512), T.int64(8), T.int64(8)), "float16")
        T_reshape_3 = T.alloc_buffer((T.int64(32), T.int64(16), T.int64(8), T.int64(8)), "float16")
        # q: swap axes 1 and 2, (32, 8, 16, 8) -> (32, 16, 8, 8).
        for ax0, ax1, ax2, ax3 in T.grid(T.int64(32), T.int64(16), T.int64(8), T.int64(8)):
            with T.block("T_transpose"):
                v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3])
                T.reads(q_1[v_ax0, v_ax2, v_ax1, v_ax3])
                T.writes(T_transpose_1[v_ax0, v_ax1, v_ax2, v_ax3])
                T_transpose_1[v_ax0, v_ax1, v_ax2, v_ax3] = q_1[v_ax0, v_ax2, v_ax1, v_ax3]
        # Flatten the two leading axes: (32, 16, 8, 8) -> (512, 8, 8). Because
        # v_ax1 < 8 and v_ax2 < 8, the index arithmetic reduces to
        # [v_ax0 // 16, v_ax0 % 16, v_ax1, v_ax2].
        for ax0, ax1, ax2 in T.grid(T.int64(512), T.int64(8), T.int64(8)):
            with T.block("T_reshape"):
                v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2])
                T.reads(T_transpose_1[((v_ax2 // T.int64(8) + v_ax1) // T.int64(8) + v_ax0) % T.int64(512) // T.int64(16), ((v_ax2 // T.int64(8) + v_ax1) // T.int64(8) + v_ax0) % T.int64(16), (v_ax2 // T.int64(8) + v_ax1) % T.int64(8), v_ax2 % T.int64(8)])
                T.writes(T_reshape[v_ax0, v_ax1, v_ax2])
                T_reshape[v_ax0, v_ax1, v_ax2] = T_transpose_1[((v_ax2 // T.int64(8) + v_ax1) // T.int64(8) + v_ax0) % T.int64(512) // T.int64(16), ((v_ax2 // T.int64(8) + v_ax1) // T.int64(8) + v_ax0) % T.int64(16), (v_ax2 // T.int64(8) + v_ax1) % T.int64(8), v_ax2 % T.int64(8)]
        # k: same transpose as q.
        for ax0, ax1, ax2, ax3 in T.grid(T.int64(32), T.int64(16), T.int64(8), T.int64(8)):
            with T.block("T_transpose_1"):
                v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3])
                T.reads(k_1[v_ax0, v_ax2, v_ax1, v_ax3])
                T.writes(T_transpose_2[v_ax0, v_ax1, v_ax2, v_ax3])
                T_transpose_2[v_ax0, v_ax1, v_ax2, v_ax3] = k_1[v_ax0, v_ax2, v_ax1, v_ax3]
        # k: same flatten to (512, 8, 8) as q.
        for ax0, ax1, ax2 in T.grid(T.int64(512), T.int64(8), T.int64(8)):
            with T.block("T_reshape_1"):
                v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2])
                T.reads(T_transpose_2[((v_ax2 // T.int64(8) + v_ax1) // T.int64(8) + v_ax0) % T.int64(512) // T.int64(16), ((v_ax2 // T.int64(8) + v_ax1) // T.int64(8) + v_ax0) % T.int64(16), (v_ax2 // T.int64(8) + v_ax1) % T.int64(8), v_ax2 % T.int64(8)])
                T.writes(T_reshape_1[v_ax0, v_ax1, v_ax2])
                T_reshape_1[v_ax0, v_ax1, v_ax2] = T_transpose_2[((v_ax2 // T.int64(8) + v_ax1) // T.int64(8) + v_ax0) % T.int64(512) // T.int64(16), ((v_ax2 // T.int64(8) + v_ax1) // T.int64(8) + v_ax0) % T.int64(16), (v_ax2 // T.int64(8) + v_ax1) % T.int64(8), v_ax2 % T.int64(8)]
        # scores[b, i, j] = sum_k Q[b, i, k] * K[b, j, k], i.e. batched Q @ K^T
        # ("NT": the second operand is indexed transposed).
        for b, i, j, k in T.grid(T.int64(512), T.int64(8), T.int64(8), T.int64(8)):
            with T.block("T_batch_matmul_NT"):
                v_b, v_i, v_j, v_k = T.axis.remap("SSSR", [b, i, j, k])
                T.reads(T_reshape[v_b, v_i, v_k], T_reshape_1[v_b, v_j, v_k])
                T.writes(T_batch_matmul_NT[v_b, v_i, v_j])
                T.block_attr({"layout_free_placeholders": [T_reshape_1]})
                with T.init():
                    T_batch_matmul_NT[v_b, v_i, v_j] = T.float16(0)
                T_batch_matmul_NT[v_b, v_i, v_j] = T_batch_matmul_NT[v_b, v_i, v_j] + T_reshape[v_b, v_i, v_k] * T_reshape_1[v_b, v_j, v_k]
        # Scale the scores by 1 / sqrt(8); 8 is the size of the last axis.
        for ax0, ax1, ax2 in T.grid(T.int64(512), T.int64(8), T.int64(8)):
            with T.block("T_divide"):
                v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2])
                T.reads(T_batch_matmul_NT[v_ax0, v_ax1, v_ax2])
                T.writes(T_divide[v_ax0, v_ax1, v_ax2])
                T_divide[v_ax0, v_ax1, v_ax2] = T_batch_matmul_NT[v_ax0, v_ax1, v_ax2] / T.sqrt(T.float16(8))
        # Softmax over the last axis, step 1: the row maximum. -65504 is the
        # lowest finite float16 value, used as the reduction's initial value.
        for i0, i1, k in T.grid(T.int64(512), T.int64(8), T.int64(8)):
            with T.block("T_softmax_maxelem"):
                v_i0, v_i1, v_k = T.axis.remap("SSR", [i0, i1, k])
                T.reads(T_divide[v_i0, v_i1, v_k])
                T.writes(T_softmax_maxelem[v_i0, v_i1])
                with T.init():
                    T_softmax_maxelem[v_i0, v_i1] = T.float16(-65504)
                T_softmax_maxelem[v_i0, v_i1] = T.max(T_softmax_maxelem[v_i0, v_i1], T_divide[v_i0, v_i1, v_k])
        # Softmax step 2: exp(x - row_max); subtracting the max keeps exp bounded.
        for i0, i1, i2 in T.grid(T.int64(512), T.int64(8), T.int64(8)):
            with T.block("T_softmax_exp"):
                v_i0, v_i1, v_i2 = T.axis.remap("SSS", [i0, i1, i2])
                T.reads(T_divide[v_i0, v_i1, v_i2], T_softmax_maxelem[v_i0, v_i1])
                T.writes(T_softmax_exp[v_i0, v_i1, v_i2])
                T_softmax_exp[v_i0, v_i1, v_i2] = T.exp(T_divide[v_i0, v_i1, v_i2] - T_softmax_maxelem[v_i0, v_i1])
        # Softmax step 3: the row sum of the exponentials.
        for i0, i1, k in T.grid(T.int64(512), T.int64(8), T.int64(8)):
            with T.block("T_softmax_expsum"):
                v_i0, v_i1, v_k = T.axis.remap("SSR", [i0, i1, k])
                T.reads(T_softmax_exp[v_i0, v_i1, v_k])
                T.writes(T_softmax_expsum[v_i0, v_i1])
                with T.init():
                    T_softmax_expsum[v_i0, v_i1] = T.float16(0)
                T_softmax_expsum[v_i0, v_i1] = T_softmax_expsum[v_i0, v_i1] + T_softmax_exp[v_i0, v_i1, v_k]
        # Softmax step 4: normalize each row by its sum.
        for i0, i1, i2 in T.grid(T.int64(512), T.int64(8), T.int64(8)):
            with T.block("T_softmax_norm"):
                v_i0, v_i1, v_i2 = T.axis.remap("SSS", [i0, i1, i2])
                T.reads(T_softmax_exp[v_i0, v_i1, v_i2], T_softmax_expsum[v_i0, v_i1])
                T.writes(T_softmax_norm[v_i0, v_i1, v_i2])
                T.block_attr({"axis": 2})
                T_softmax_norm[v_i0, v_i1, v_i2] = T_softmax_exp[v_i0, v_i1, v_i2] / T_softmax_expsum[v_i0, v_i1]
        # v: same transpose as q and k.
        for ax0, ax1, ax2, ax3 in T.grid(T.int64(32), T.int64(16), T.int64(8), T.int64(8)):
            with T.block("T_transpose_2"):
                v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3])
                T.reads(v_1[v_ax0, v_ax2, v_ax1, v_ax3])
                T.writes(T_transpose_3[v_ax0, v_ax1, v_ax2, v_ax3])
                T_transpose_3[v_ax0, v_ax1, v_ax2, v_ax3] = v_1[v_ax0, v_ax2, v_ax1, v_ax3]
        # v: same flatten to (512, 8, 8) as q and k.
        for ax0, ax1, ax2 in T.grid(T.int64(512), T.int64(8), T.int64(8)):
            with T.block("T_reshape_2"):
                v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2])
                T.reads(T_transpose_3[((v_ax2 // T.int64(8) + v_ax1) // T.int64(8) + v_ax0) % T.int64(512) // T.int64(16), ((v_ax2 // T.int64(8) + v_ax1) // T.int64(8) + v_ax0) % T.int64(16), (v_ax2 // T.int64(8) + v_ax1) % T.int64(8), v_ax2 % T.int64(8)])
                T.writes(T_reshape_2[v_ax0, v_ax1, v_ax2])
                T_reshape_2[v_ax0, v_ax1, v_ax2] = T_transpose_3[((v_ax2 // T.int64(8) + v_ax1) // T.int64(8) + v_ax0) % T.int64(512) // T.int64(16), ((v_ax2 // T.int64(8) + v_ax1) // T.int64(8) + v_ax0) % T.int64(16), (v_ax2 // T.int64(8) + v_ax1) % T.int64(8), v_ax2 % T.int64(8)]
        # out[b, i, j] = sum_k P[b, i, k] * V[b, k, j], i.e. plain batched
        # softmax @ V ("NN": neither operand transposed).
        for b, i, j, k in T.grid(T.int64(512), T.int64(8), T.int64(8), T.int64(8)):
            with T.block("T_batch_matmul_NN"):
                v_b, v_i, v_j, v_k = T.axis.remap("SSSR", [b, i, j, k])
                T.reads(T_softmax_norm[v_b, v_i, v_k], T_reshape_2[v_b, v_k, v_j])
                T.writes(T_batch_matmul_NN[v_b, v_i, v_j])
                T.block_attr({"layout_free_placeholders": [T_reshape_2]})
                with T.init():
                    T_batch_matmul_NN[v_b, v_i, v_j] = T.float16(0)
                T_batch_matmul_NN[v_b, v_i, v_j] = T_batch_matmul_NN[v_b, v_i, v_j] + T_softmax_norm[v_b, v_i, v_k] * T_reshape_2[v_b, v_k, v_j]
        # Un-flatten: (512, 8, 8) -> (32, 16, 8, 8). With v_ax2, v_ax3 < 8 the
        # source index reduces to [v_ax0 * 16 + v_ax1, v_ax2, v_ax3].
        for ax0, ax1, ax2, ax3 in T.grid(T.int64(32), T.int64(16), T.int64(8), T.int64(8)):
            with T.block("T_reshape_3"):
                v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3])
                T.reads(T_batch_matmul_NN[(v_ax0 * T.int64(16) + (v_ax3 // T.int64(8) + v_ax2) // T.int64(8) + v_ax1) % T.int64(512), (v_ax3 // T.int64(8) + v_ax2) % T.int64(8), v_ax3 % T.int64(8)])
                T.writes(T_reshape_3[v_ax0, v_ax1, v_ax2, v_ax3])
                T_reshape_3[v_ax0, v_ax1, v_ax2, v_ax3] = T_batch_matmul_NN[(v_ax0 * T.int64(16) + (v_ax3 // T.int64(8) + v_ax2) // T.int64(8) + v_ax1) % T.int64(512), (v_ax3 // T.int64(8) + v_ax2) % T.int64(8), v_ax3 % T.int64(8)]
        # Swap axes 1 and 2 back, (32, 16, 8, 8) -> (32, 8, 16, 8), writing
        # the final result into the output buffer T_transpose.
        for ax0, ax1, ax2, ax3 in T.grid(T.int64(32), T.int64(8), T.int64(16), T.int64(8)):
            with T.block("T_transpose_3"):
                v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3])
                T.reads(T_reshape_3[v_ax0, v_ax2, v_ax1, v_ax3])
                T.writes(T_transpose[v_ax0, v_ax1, v_ax2, v_ax3])
                T_transpose[v_ax0, v_ax1, v_ax2, v_ax3] = T_reshape_3[v_ax0, v_ax2, v_ax1, v_ax3]

    # Entry point: forwards q, k, v unchanged to the CUTLASS-tagged wrapper
    # below and returns its result.
    @R.function
    def entry_b(q: R.Tensor((32, 8, 16, 8), dtype="float16"), k: R.Tensor((32, 8, 16, 8), dtype="float16"), v: R.Tensor((32, 8, 16, 8), dtype="float16")) -> R.Tensor((32, 8, 16, 8), dtype="float16"):
        cls = Module
        with R.dataflow():
            lv: R.Tensor((32, 8, 16, 8), dtype="float16") = cls.fused_relax_nn_attention_cutlass(q, k, v)
            R.output(lv)
        return lv

    # Wrapper carrying the attributes Codegen="cutlass" and
    # WorkspaceSize=65536. It defines a nested Relax function `gv` (tagged
    # Composite="cutlass.attention", Primitive=1) that calls the TIR kernel,
    # and then invokes it. Per this report, running FuseTIR only before
    # LambdaLift leaves the lifted `gv` unresolvable at VM creation, while a
    # second FuseTIR after LambdaLift avoids the error.
    @R.function
    def fused_relax_nn_attention_cutlass(q: R.Tensor((32, 8, 16, 8), dtype="float16"), k: R.Tensor((32, 8, 16, 8), dtype="float16"), v: R.Tensor((32, 8, 16, 8), dtype="float16")) -> R.Tensor((32, 8, 16, 8), dtype="float16"):
        R.func_attr({"Codegen": "cutlass", "WorkspaceSize": 65536})
        cls = Module
        
        # Nested (local) function. LambdaLift hoists it to module scope; the
        # VM error in this report names the lifted function
        # `fused_relax_nn_attention_cutlass_gv`.
        @R.function
        def gv(q_1: R.Tensor((32, 8, 16, 8), dtype="float16"), k_1: R.Tensor((32, 8, 16, 8), dtype="float16"), v_1: R.Tensor((32, 8, 16, 8), dtype="float16")) -> R.Tensor((32, 8, 16, 8), dtype="float16"):
            R.func_attr({"Composite": "cutlass.attention", "Primitive": 1, "WorkspaceSize": 65536})
            with R.dataflow():
                gv_2 = R.call_tir(cls.attention, (q_1, k_1, v_1), out_sinfo=R.Tensor((32, 8, 16, 8), dtype="float16"))
                R.output(gv_2)
            return gv_2

        gv1: R.Tensor((32, 8, 16, 8), dtype="float16") = gv(q, k, v)
        return gv1

mod = Module

# Failing ordering: FuseTIR runs only once, before LambdaLift hoists the
# nested `gv` out of `fused_relax_nn_attention_cutlass`. Creating the VM
# below then fails with "Cannot find PackedFunc
# fused_relax_nn_attention_cutlass_gv".
crash_pipeline = tvm.transform.Sequential(
    [
        relax.transform.FuseTIR(),
        relax.transform.LambdaLift(),
        relax.transform.AllocateWorkspace(),
    ]
)
mod = crash_pipeline(mod)

# Working ordering, as reported: add a second FuseTIR after LambdaLift and
# before AllocateWorkspace. Swap it in for `crash_pipeline` to compare.
# mod = tvm.transform.Sequential(
#     [
#         relax.transform.FuseTIR(),
#         relax.transform.LambdaLift(),
#         relax.transform.FuseTIR(),
#         relax.transform.AllocateWorkspace(),
#     ]
# )(mod)

# Compile for CPU and load the executable into a Relax VM.
with tvm.transform.PassContext(opt_level=4):
    ex = relax.build(mod, target="llvm")
    vm = relax.VirtualMachine(ex, tvm.cpu())

Any guidance on whether this is a bug or an expected pass-ordering dependency (i.e. FuseTIR must be re-run after LambdaLift) would be greatly appreciated. @Lunderberg

Thrsu avatar Sep 10 '24 06:09 Thrsu