[Bug] RWKV v6 models fail to compile with latest mlc_llm
🐛 Bug
RWKV v6 models fail to compile with the latest mlc_llm.
Edit: It also seems that CI currently only has a compilation test for RWKV v5. Should RWKV v6 be added to the CI tests as well?
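For reference, such a CI check could be as small as a smoke test that chains the same three CLI commands used in the repro steps below. This is only a sketch under the assumption that a plain subprocess-based pytest case would fit the existing suite; the file layout, test name, and paths here are made up for illustration, not the actual CI structure.

# Hypothetical pytest-style smoke test for RWKV v6 compilation.
# Paths, test name, and quantization choice are assumptions for illustration;
# it simply chains the CLI commands from the repro steps below.
import subprocess

def run(cmd):
    # Abort the test on the first non-zero exit code.
    subprocess.run(cmd, check=True)

def test_compile_rwkv6():
    model = "rwkv-6-world-1b6"   # local checkout of RWKV/rwkv-6-world-1b6
    out = "rwkv-6-world-1b6-MLC"
    run(["mlc_llm", "convert_weight", model, "--quantization", "q4f16_1", "-o", out])
    run(["mlc_llm", "gen_config", model, "--quantization", "q4f16_1",
         "--conv-template", "rwkv_world", "-o", out])
    run(["mlc_llm", "compile", f"{out}/mlc-chat-config.json", "--device", "metal",
         "-o", f"{out}/libs/rwkv-6-world-1b6-q4f16_1-metal.so"])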
To Reproduce
Steps to reproduce the behavior:
- Get https://huggingface.co/RWKV/rwkv-6-world-1b6
- $ mlc_llm convert_weight rwkv-6-world-1b6 --quantization q4f16_1 -o rwkv-6-world-1b6-MLC
- $ mlc_llm gen_config rwkv-6-world-1b6 --quantization q4f16_1 --conv-template rwkv_world -o rwkv-6-world-1b6-MLC
- $ mlc_llm compile rwkv-6-world-1b6-MLC/mlc-chat-config.json --device metal --host arm64-apple-darwin -o rwkv-6-world-1b6-MLC/libs/rwkv-6-world-1b6-MLC-q4f16-metal.so (or build with other hosts)
The compile step fails with the following error messages:
$ mlc_llm compile rwkv-6-world-1b6-MLC/mlc-chat-config.json --device metal --host arm64-apple-darwin -o rwkv-6-world-1b6-MLC/libs/rwkv-6-world-1b6-MLC-q4f16-metal.so
[2024-09-02 10:57:53] INFO auto_config.py:70: Found model configuration: rwkv-6-world-1b6-MLC/mlc-chat-config.json
[2024-09-02 10:57:54] INFO auto_device.py:79: Found device: metal:0
[2024-09-02 10:57:54] INFO auto_target.py:78: Found configuration of target device "metal:0": {"thread_warp_size": runtime.BoxInt(32), "max_threads_per_block": runtime.BoxInt(1024), "max_function_args": runtime.BoxInt(31), "max_num_threads": runtime.BoxInt(256), "kind": "metal", "max_shared_memory_per_block": runtime.BoxInt(32768), "tag": "", "keys": ["metal", "gpu"]}
[2024-09-02 10:57:54] INFO auto_target.py:114: Using LLVM triple specified by --host: arm64-apple-darwin
[2024-09-02 10:57:54] INFO auto_config.py:154: Found model type: rwkv6. Use `--model-type` to override.
Compiling with arguments:
--config RWKV6Config(hidden_size=2048, intermediate_size=7168, num_hidden_layers=24, vocab_size=65536, model_version='6_0', tensor_parallel_shards=1, rescale_every=6, head_size=64, layer_norm_epsilon=1e-05, context_window_size=-1, prefill_chunk_size=4096, num_heads=32, max_batch_size=80, kwargs={})
--quantization GroupQuantize(name='q4f16_1', kind='group-quant', group_size=32, quantize_dtype='int4', storage_dtype='uint32', model_dtype='float16', linear_weight_layout='NK', quantize_embedding=True, quantize_final_fc=True, num_elem_per_storage=8, num_storage_per_group=4, max_int_value=7, tensor_parallel_shards=0)
--model-type rwkv6
--target {"thread_warp_size": runtime.BoxInt(32), "host": {"kind": "llvm", "tag": "", "keys": ["arm_cpu", "cpu"], "mtriple": "arm64-apple-darwin"}, "max_threads_per_block": runtime.BoxInt(1024), "max_function_args": runtime.BoxInt(31), "max_num_threads": runtime.BoxInt(256), "kind": "metal", "max_shared_memory_per_block": runtime.BoxInt(32768), "tag": "", "keys": ["metal", "gpu"]}
--opt flashinfer=0;cublas_gemm=0;faster_transformer=0;cudagraph=0;cutlass=0;ipc_allreduce_strategy=NONE
--system-lib-prefix ""
--output rwkv-6-world-1b6-MLC/libs/rwkv-6-world-1b6-MLC-q4f16-metal.so
--overrides context_window_size=None;sliding_window_size=None;prefill_chunk_size=None;attention_sink_size=None;max_batch_size=None;tensor_parallel_shards=None;pipeline_parallel_stages=None
[2024-09-02 10:57:54] INFO compile.py:140: Creating model from: RWKV6Config(hidden_size=2048, intermediate_size=7168, num_hidden_layers=24, vocab_size=65536, model_version='6_0', tensor_parallel_shards=1, rescale_every=6, head_size=64, layer_norm_epsilon=1e-05, context_window_size=-1, prefill_chunk_size=4096, num_heads=32, max_batch_size=80, kwargs={})
[2024-09-02 10:57:54] INFO compile.py:158: Exporting the model to TVM Unity compiler
[2024-09-02 10:57:57] INFO compile.py:164: Running optimizations using TVM Unity
[2024-09-02 10:57:57] INFO compile.py:185: Registering metadata: {'model_type': 'rwkv6', 'quantization': 'q4f16_1', 'context_window_size': -1, 'sliding_window_size': -1, 'attention_sink_size': -1, 'prefill_chunk_size': 4096, 'tensor_parallel_shards': 1, 'pipeline_parallel_stages': 1, 'kv_state_kind': 'rnn_state', 'max_batch_size': 80}
[2024-09-02 10:57:57] INFO pipeline.py:54: Running TVM Relax graph-level optimizations
[2024-09-02 10:57:59] INFO pipeline.py:54: Lowering to TVM TIR kernels
[2024-09-02 10:58:04] INFO pipeline.py:54: Running TVM TIR-level optimizations
[2024-09-02 10:58:22] INFO pipeline.py:54: Running TVM Dlight low-level optimizations
[2024-09-02 10:58:27] INFO pipeline.py:54: Lowering to VM bytecode
[2024-09-02 10:58:30] INFO estimate_memory_usage.py:58: [Memory usage] Function `alloc_embedding_tensor`: 16.00 MB
[2024-09-02 10:58:30] INFO estimate_memory_usage.py:58: [Memory usage] Function `batch_decode`: 106.57 MB
[2024-09-02 10:58:30] INFO estimate_memory_usage.py:58: [Memory usage] Function `batch_prefill`: 293.50 MB
[2024-09-02 10:58:30] INFO estimate_memory_usage.py:58: [Memory usage] Function `batch_verify`: 273.75 MB
[2024-09-02 10:58:30] INFO estimate_memory_usage.py:58: [Memory usage] Function `create_rnn_state`: 0.00 MB
[2024-09-02 10:58:30] INFO estimate_memory_usage.py:58: [Memory usage] Function `decode`: 1.32 MB
[2024-09-02 10:58:30] INFO estimate_memory_usage.py:58: [Memory usage] Function `embed`: 16.00 MB
[2024-09-02 10:58:31] INFO estimate_memory_usage.py:58: [Memory usage] Function `prefill`: 273.75 MB
[2024-09-02 10:58:31] INFO estimate_memory_usage.py:58: [Memory usage] Function `softmax_with_temperature`: 0.00 MB
[2024-09-02 10:58:32] INFO pipeline.py:54: Compiling external modules
[2024-09-02 10:58:32] INFO pipeline.py:54: Compilation complete! Exporting to disk
Traceback (most recent call last):
File "/Users/molly/miniconda3/envs/mlc-llm-latest/bin/mlc_llm", line 8, in <module>
sys.exit(main())
^^^^^^
File "/Users/molly/miniconda3/envs/mlc-llm-latest/lib/python3.11/site-packages/mlc_llm/__main__.py", line 33, in main
cli.main(sys.argv[2:])
File "/Users/molly/miniconda3/envs/mlc-llm-latest/lib/python3.11/site-packages/mlc_llm/cli/compile.py", line 129, in main
compile(
File "/Users/molly/miniconda3/envs/mlc-llm-latest/lib/python3.11/site-packages/mlc_llm/interface/compile.py", line 243, in compile
_compile(args, model_config)
File "/Users/molly/miniconda3/envs/mlc-llm-latest/lib/python3.11/site-packages/mlc_llm/interface/compile.py", line 188, in _compile
args.build_func(
File "/Users/molly/miniconda3/envs/mlc-llm-latest/lib/python3.11/site-packages/mlc_llm/support/auto_target.py", line 311, in build
relax.build(
File "/Users/molly/miniconda3/envs/mlc-llm-latest/lib/python3.11/site-packages/tvm/relax/vm_build.py", line 341, in build
return _vmlink(
^^^^^^^^
File "/Users/molly/miniconda3/envs/mlc-llm-latest/lib/python3.11/site-packages/tvm/relax/vm_build.py", line 247, in _vmlink
lib = tvm.build(
^^^^^^^^^^
File "/Users/molly/miniconda3/envs/mlc-llm-latest/lib/python3.11/site-packages/tvm/driver/build_module.py", line 297, in build
rt_mod_host = _driver_ffi.tir_to_runtime(annotated_mods, target_host)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "tvm/_ffi/_cython/./packed_func.pxi", line 339, in tvm._ffi._cy3.core.PackedFuncBase.__call__
File "tvm/_ffi/_cython/./packed_func.pxi", line 270, in tvm._ffi._cy3.core.FuncCall
File "tvm/_ffi/_cython/./packed_func.pxi", line 259, in tvm._ffi._cy3.core.FuncCall3
File "tvm/_ffi/_cython/./base.pxi", line 185, in tvm._ffi._cy3.core.CHECK_CALL
File "/Users/molly/miniconda3/envs/mlc-llm-latest/lib/python3.11/site-packages/tvm/_ffi/base.py", line 481, in raise_last_ffi_error
raise py_err
tvm._ffi.base.TVMError: Traceback (most recent call last):
Did you forget to bind?
Variable `compute` is directly accessed by host memory (it is not contained in a thread environment or in the function arguments.
Variable `add768` is directly accessed by host memory (it is not contained in a thread environment or in the function arguments.
Variable `compute` is directly accessed by host memory (it is not contained in a thread environment or in the function arguments.
Variable `add768` is directly accessed by host memory (it is not contained in a thread environment or in the function arguments.
File "/Users/catalyst/Workspace/mlc-ai-package-self-runner/_work/package/package/tvm/src/tir/analysis/verify_memory.cc", line 205
RuntimeError: Memory verification failed with the following errors:
# from tvm.script import tir as T
@T.prim_func
def fused_matmul4_tir_tanh2(p_add768: T.handle, model_blocks_0_attention_time_maa_w13: T.Buffer((2048, 160), "float16"), p_output0: T.handle):
T.func_attr({"target": T.target({"host": {"keys": ["arm_cpu", "cpu"], "kind": "llvm", "mtriple": "arm64-apple-darwin", "tag": ""}, "keys": ["metal", "gpu"], "kind": "metal", "max_function_args": 31, "max_num_threads": 256, "max_shared_memory_per_block": 32768, "max_threads_per_block": 1024, "tag": "", "thread_warp_size": 32}), "tir.HoistIfThenElseExprWithBlock": 1, "tir.is_scheduled": 1, "tir.noalias": T.bool(True)})
seq_len = T.int32()
add768 = T.match_buffer(p_add768, (1, seq_len, 2048), "float16")
compute_intermediate = T.match_buffer(p_output0, (1, seq_len, 160), "float16")
add768_pad = T.allocate([(seq_len + 1) // 2 * 4096], "float16", "global")
matmul_intermediate_pad_rf = T.allocate([(seq_len + 1) // 2 * 10240], "float16", "global")
matmul_intermediate_pad_rf_1 = T.allocate([(seq_len + 1) // 2 * 1280], "float16", "global")
matmul_intermediate_pad = T.allocate([(seq_len + 1) // 2 * 320], "float16", "global")
add768_pad_1 = T.allocate([(seq_len + 3) // 4 * 8192], "float16", "global")
matmul_intermediate_pad_rf_2 = T.allocate([(seq_len + 3) // 4 * 20480], "float16", "global")
matmul_intermediate_pad_rf_3 = T.allocate([(seq_len + 3) // 4 * 2560], "float16", "global")
matmul_intermediate_pad_1 = T.allocate([(seq_len + 3) // 4 * 640], "float16", "global")
add768_1 = T.Buffer((seq_len * 2048,), "float16", data=add768.data)
model_blocks_0_attention_time_maa_w13_1 = T.Buffer((327680,), "float16", data=model_blocks_0_attention_time_maa_w13.data)
compute_intermediate_1 = T.Buffer((seq_len * 160,), "float16", data=compute_intermediate.data)
if T.tvm_thread_invariant(seq_len <= 2):
add768_pad_2 = T.Buffer(((seq_len + 1) // 2 * 4096,), "float16", data=add768_pad)
for ax0 in range((seq_len + 1) // 2 * 2):
if ax0 < seq_len:
for ax1 in range(2048):
cse_var_1: T.int32 = ax0 * 2048 + ax1
add768_pad_2[cse_var_1] = add768_1[cse_var_1]
else:
for ax1 in range(2048):
add768_pad_2[ax0 * 2048 + ax1] = T.float16(0.0)
matmul_intermediate_pad_rf_4 = T.Buffer(((seq_len + 1) // 2 * 10240,), "float16", data=matmul_intermediate_pad_rf)
with T.launch_thread("blockIdx.y", (seq_len + 1) // 2) as blockIdx_y:
blockIdx_x = T.launch_thread("blockIdx.x", 3)
threadIdx_x = T.launch_thread("threadIdx.x", 64)
threadIdx_y = T.launch_thread("threadIdx.y", 4)
for ax2_fused_0, ax2_fused_2, ax0_1, ax2_fused_1_ax2_fused_3_fused_1_0 in T.grid(16, 4, 2, 2):
if blockIdx_x * 2 + threadIdx_x // 32 < 5:
if ax2_fused_0 == 0 and ax2_fused_2 == 0:
matmul_intermediate_pad_rf_4[T.Ramp(threadIdx_y * ((seq_len + 1) // 2) * 2560 + ax2_fused_1_ax2_fused_3_fused_1_0 * ((seq_len + 1) // 2) * 1280 + blockIdx_y * 320 + ax0_1 * 160 + blockIdx_x * 64 + threadIdx_x, (seq_len + 1) // 2 * 320, 4)] = T.Broadcast(T.float16(0.0), 4)
matmul_intermediate_pad_rf_4[T.Ramp(threadIdx_y * ((seq_len + 1) // 2) * 2560 + ax2_fused_1_ax2_fused_3_fused_1_0 * ((seq_len + 1) // 2) * 1280 + blockIdx_y * 320 + ax0_1 * 160 + blockIdx_x * 64 + threadIdx_x, (seq_len + 1) // 2 * 320, 4)] = matmul_intermediate_pad_rf_4[T.Ramp(threadIdx_y * ((seq_len + 1) // 2) * 2560 + ax2_fused_1_ax2_fused_3_fused_1_0 * ((seq_len + 1) // 2) * 1280 + blockIdx_y * 320 + ax0_1 * 160 + blockIdx_x * 64 + threadIdx_x, (seq_len + 1) // 2 * 320, 4)] + add768_pad_2[blockIdx_y * 4096 + ax0_1 * 2048 + ax2_fused_0 * 128 + threadIdx_y * 32 + ax2_fused_2 * 8 + ax2_fused_1_ax2_fused_3_fused_1_0 * 4:blockIdx_y * 4096 + ax0_1 * 2048 + ax2_fused_0 * 128 + threadIdx_y * 32 + ax2_fused_2 * 8 + ax2_fused_1_ax2_fused_3_fused_1_0 * 4 + 4] * model_blocks_0_attention_time_maa_w13_1[ax2_fused_0 * 20480 + threadIdx_y * 5120 + ax2_fused_2 * 1280 + ax2_fused_1_ax2_fused_3_fused_1_0 * 640 + blockIdx_x * 64 + threadIdx_x:ax2_fused_0 * 20480 + threadIdx_y * 5120 + ax2_fused_2 * 1280 + ax2_fused_1_ax2_fused_3_fused_1_0 * 640 + blockIdx_x * 64 + threadIdx_x + 640:160]
matmul_intermediate_pad_rf_5 = T.Buffer(((seq_len + 1) // 2 * 1280,), "float16", data=matmul_intermediate_pad_rf_1)
for ax0_0, ax1_fused_0, ax1_fused_1, ax0_1, ax2_fused_1_ax2_fused_3_fused_0, ax2_fused_1_ax2_fused_3_fused_1 in T.grid((seq_len + 1) // 2, 3, 64, 2, 4, 8):
cse_var_4: T.int32 = ax0_0 * 320
cse_var_3: T.int32 = ax0_1 * 160
cse_var_2: T.int32 = ax1_fused_0 * 64
if ax2_fused_1_ax2_fused_3_fused_1 == 0:
matmul_intermediate_pad_rf_5[cse_var_4 + ax2_fused_1_ax2_fused_3_fused_0 * ((seq_len + 1) // 2) * 320 + cse_var_3 + cse_var_2 + ax1_fused_1] = T.float16(0.0)
matmul_intermediate_pad_rf_5[cse_var_4 + ax2_fused_1_ax2_fused_3_fused_0 * ((seq_len + 1) // 2) * 320 + cse_var_3 + cse_var_2 + ax1_fused_1] = matmul_intermediate_pad_rf_5[cse_var_4 + ax2_fused_1_ax2_fused_3_fused_0 * ((seq_len + 1) // 2) * 320 + cse_var_3 + cse_var_2 + ax1_fused_1] + matmul_intermediate_pad_rf_4[ax2_fused_1_ax2_fused_3_fused_0 * ((seq_len + 1) // 2) * 2560 + cse_var_4 + ax2_fused_1_ax2_fused_3_fused_1 * ((seq_len + 1) // 2) * 320 + cse_var_3 + cse_var_2 + ax1_fused_1]
matmul_intermediate_pad_2 = T.Buffer(((seq_len + 1) // 2 * 320,), "float16", data=matmul_intermediate_pad)
for ax0_0, ax1_fused_0, ax1_fused_1, ax0_1, ax2_fused_1_ax2_fused_3_fused_0 in T.grid((seq_len + 1) // 2, 3, 64, 2, 4):
cse_var_8: T.int32 = ax0_0 * 320
cse_var_7: T.int32 = ax0_1 * 160
cse_var_6: T.int32 = ax1_fused_0 * 64
cse_var_5: T.int32 = cse_var_8 + cse_var_7 + cse_var_6 + ax1_fused_1
if ax2_fused_1_ax2_fused_3_fused_0 == 0:
matmul_intermediate_pad_2[cse_var_5] = T.float16(0.0)
matmul_intermediate_pad_2[cse_var_5] = matmul_intermediate_pad_2[cse_var_5] + matmul_intermediate_pad_rf_5[cse_var_8 + ax2_fused_1_ax2_fused_3_fused_0 * ((seq_len + 1) // 2) * 320 + cse_var_7 + cse_var_6 + ax1_fused_1]
for ax0, ax1 in T.grid(seq_len, 160):
cse_var_9: T.int32 = ax0 * 160 + ax1
compute_intermediate_1[cse_var_9] = T.tanh(matmul_intermediate_pad_2[cse_var_9])
else:
if T.tvm_thread_invariant(seq_len <= 8):
add768_pad_2 = T.Buffer(((seq_len + 3) // 4 * 8192,), "float16", data=add768_pad_1)
for ax0 in range((seq_len + 3) // 4 * 4):
if ax0 < seq_len:
for ax1 in range(2048):
cse_var_10: T.int32 = ax0 * 2048 + ax1
add768_pad_2[cse_var_10] = add768_1[cse_var_10]
else:
for ax1 in range(2048):
add768_pad_2[ax0 * 2048 + ax1] = T.float16(0.0)
matmul_intermediate_pad_rf_4 = T.Buffer(((seq_len + 3) // 4 * 20480,), "float16", data=matmul_intermediate_pad_rf_2)
with T.launch_thread("blockIdx.y", (seq_len + 3) // 4) as blockIdx_y:
blockIdx_x = T.launch_thread("blockIdx.x", 3)
threadIdx_x = T.launch_thread("threadIdx.x", 64)
threadIdx_y = T.launch_thread("threadIdx.y", 4)
for ax2_fused_0, ax2_fused_2, ax0_1, ax2_fused_1_ax2_fused_3_fused_1_0 in T.grid(16, 4, 4, 2):
if blockIdx_x * 2 + threadIdx_x // 32 < 5:
if ax2_fused_0 == 0 and ax2_fused_2 == 0:
matmul_intermediate_pad_rf_4[T.Ramp(threadIdx_y * ((seq_len + 3) // 4) * 5120 + ax2_fused_1_ax2_fused_3_fused_1_0 * ((seq_len + 3) // 4) * 2560 + blockIdx_y * 640 + ax0_1 * 160 + blockIdx_x * 64 + threadIdx_x, (seq_len + 3) // 4 * 640, 4)] = T.Broadcast(T.float16(0.0), 4)
matmul_intermediate_pad_rf_4[T.Ramp(threadIdx_y * ((seq_len + 3) // 4) * 5120 + ax2_fused_1_ax2_fused_3_fused_1_0 * ((seq_len + 3) // 4) * 2560 + blockIdx_y * 640 + ax0_1 * 160 + blockIdx_x * 64 + threadIdx_x, (seq_len + 3) // 4 * 640, 4)] = matmul_intermediate_pad_rf_4[T.Ramp(threadIdx_y * ((seq_len + 3) // 4) * 5120 + ax2_fused_1_ax2_fused_3_fused_1_0 * ((seq_len + 3) // 4) * 2560 + blockIdx_y * 640 + ax0_1 * 160 + blockIdx_x * 64 + threadIdx_x, (seq_len + 3) // 4 * 640, 4)] + add768_pad_2[blockIdx_y * 8192 + ax0_1 * 2048 + ax2_fused_0 * 128 + threadIdx_y * 32 + ax2_fused_2 * 8 + ax2_fused_1_ax2_fused_3_fused_1_0 * 4:blockIdx_y * 8192 + ax0_1 * 2048 + ax2_fused_0 * 128 + threadIdx_y * 32 + ax2_fused_2 * 8 + ax2_fused_1_ax2_fused_3_fused_1_0 * 4 + 4] * model_blocks_0_attention_time_maa_w13_1[ax2_fused_0 * 20480 + threadIdx_y * 5120 + ax2_fused_2 * 1280 + ax2_fused_1_ax2_fused_3_fused_1_0 * 640 + blockIdx_x * 64 + threadIdx_x:ax2_fused_0 * 20480 + threadIdx_y * 5120 + ax2_fused_2 * 1280 + ax2_fused_1_ax2_fused_3_fused_1_0 * 640 + blockIdx_x * 64 + threadIdx_x + 640:160]
matmul_intermediate_pad_rf_5 = T.Buffer(((seq_len + 3) // 4 * 2560,), "float16", data=matmul_intermediate_pad_rf_3)
for ax0_0, ax1_fused_0, ax1_fused_1, ax0_1, ax2_fused_1_ax2_fused_3_fused_0, ax2_fused_1_ax2_fused_3_fused_1 in T.grid((seq_len + 3) // 4, 3, 64, 4, 4, 8):
cse_var_13: T.int32 = ax0_0 * 640
cse_var_12: T.int32 = ax0_1 * 160
cse_var_11: T.int32 = ax1_fused_0 * 64
if ax2_fused_1_ax2_fused_3_fused_1 == 0:
matmul_intermediate_pad_rf_5[cse_var_13 + ax2_fused_1_ax2_fused_3_fused_0 * ((seq_len + 3) // 4) * 640 + cse_var_12 + cse_var_11 + ax1_fused_1] = T.float16(0.0)
matmul_intermediate_pad_rf_5[cse_var_13 + ax2_fused_1_ax2_fused_3_fused_0 * ((seq_len + 3) // 4) * 640 + cse_var_12 + cse_var_11 + ax1_fused_1] = matmul_intermediate_pad_rf_5[cse_var_13 + ax2_fused_1_ax2_fused_3_fused_0 * ((seq_len + 3) // 4) * 640 + cse_var_12 + cse_var_11 + ax1_fused_1] + matmul_intermediate_pad_rf_4[ax2_fused_1_ax2_fused_3_fused_0 * ((seq_len + 3) // 4) * 5120 + cse_var_13 + ax2_fused_1_ax2_fused_3_fused_1 * ((seq_len + 3) // 4) * 640 + cse_var_12 + cse_var_11 + ax1_fused_1]
matmul_intermediate_pad_2 = T.Buffer(((seq_len + 3) // 4 * 640,), "float16", data=matmul_intermediate_pad_1)
for ax0_0, ax1_fused_0, ax1_fused_1, ax0_1, ax2_fused_1_ax2_fused_3_fused_0 in T.grid((seq_len + 3) // 4, 3, 64, 4, 4):
cse_var_17: T.int32 = ax0_0 * 640
cse_var_16: T.int32 = ax0_1 * 160
cse_var_15: T.int32 = ax1_fused_0 * 64
cse_var_14: T.int32 = cse_var_17 + cse_var_16 + cse_var_15 + ax1_fused_1
if ax2_fused_1_ax2_fused_3_fused_0 == 0:
matmul_intermediate_pad_2[cse_var_14] = T.float16(0.0)
matmul_intermediate_pad_2[cse_var_14] = matmul_intermediate_pad_2[cse_var_14] + matmul_intermediate_pad_rf_5[cse_var_17 + ax2_fused_1_ax2_fused_3_fused_0 * ((seq_len + 3) // 4) * 640 + cse_var_16 + cse_var_15 + ax1_fused_1]
for ax0, ax1 in T.grid(seq_len, 160):
cse_var_18: T.int32 = ax0 * 160 + ax1
compute_intermediate_1[cse_var_18] = T.tanh(matmul_intermediate_pad_2[cse_var_18])
else:
blockIdx_z = T.launch_thread("blockIdx.z", 1)
matmul_intermediate_reindex_pad_metal_simdgroup = T.allocate([256], "float16", "metal.simdgroup")
add768_reindex_pad_shared = T.allocate([512], "float16", "shared")
model_blocks_0_attention_time_maa_w13_reindex_pad_shared = T.allocate([2048], "float16", "shared")
add768_reindex_pad_shared_metal_simdgroup = T.allocate([128], "float16", "metal.simdgroup")
model_blocks_0_attention_time_maa_w13_reindex_pad_shared_metal_simdgroup = T.allocate([128], "float16", "metal.simdgroup")
blockIdx_x = T.launch_thread("blockIdx.x", (seq_len + 15) // 16)
blockIdx_y = T.launch_thread("blockIdx.y", 3)
threadIdx_x = T.launch_thread("threadIdx.x", 32)
threadIdx_y = T.launch_thread("threadIdx.y", 1)
threadIdx_z = T.launch_thread("threadIdx.z", 4)
for ax1_2_init, ax2_2_init in T.grid(2, 2):
T.make_filled_simdgroup_matrix(matmul_intermediate_reindex_pad_metal_simdgroup, ax1_2_init * 2 + ax2_2_init, T.float32(0.0), 8, 8)
for ax3_0 in range(64):
add768_reindex_pad_shared_1 = T.Buffer((512,), "float16", data=add768_reindex_pad_shared, scope="shared")
add768_reindex_pad_shared_1[threadIdx_z * 128 + threadIdx_x * 4:threadIdx_z * 128 + threadIdx_x * 4 + 4] = T.if_then_else(blockIdx_x * 16 + threadIdx_z * 4 + threadIdx_x // 8 < seq_len, add768_1[blockIdx_x * 32768 + threadIdx_z * 8192 + threadIdx_x // 8 * 2048 + ax3_0 * 32 + threadIdx_x % 8 * 4:blockIdx_x * 32768 + threadIdx_z * 8192 + threadIdx_x // 8 * 2048 + ax3_0 * 32 + threadIdx_x % 8 * 4 + 4], T.Broadcast(T.float16(0.0), 4))
for ax1_ax2_fused_0 in range(4):
model_blocks_0_attention_time_maa_w13_reindex_pad_shared_1 = T.Buffer((2048,), "float16", data=model_blocks_0_attention_time_maa_w13_reindex_pad_shared, scope="shared")
model_blocks_0_attention_time_maa_w13_reindex_pad_shared_1[ax1_ax2_fused_0 * 512 + threadIdx_z * 128 + threadIdx_x * 4:ax1_ax2_fused_0 * 512 + threadIdx_z * 128 + threadIdx_x * 4 + 4] = T.if_then_else(blockIdx_y * 2 + ax1_ax2_fused_0 // 2 < 5, model_blocks_0_attention_time_maa_w13_1[ax3_0 * 5120 + threadIdx_x % 8 * 640 + blockIdx_y * 64 + ax1_ax2_fused_0 * 16 + threadIdx_z * 4 + threadIdx_x // 8:ax3_0 * 5120 + threadIdx_x % 8 * 640 + blockIdx_y * 64 + ax1_ax2_fused_0 * 16 + threadIdx_z * 4 + threadIdx_x // 8 + 640:160], T.Broadcast(T.float16(0.0), 4))
for ax3_1 in range(4):
for ax0_0 in range(2):
T.simdgroup_load(add768_reindex_pad_shared_metal_simdgroup, ax0_0, T.tvm_access_ptr(T.type_annotation("float16"), add768_reindex_pad_shared, ax0_0 * 256 + ax3_1 * 8, 256, 1), 32, 8, 8, T.bool(False))
for ax0_0 in range(2):
T.simdgroup_load(model_blocks_0_attention_time_maa_w13_reindex_pad_shared_metal_simdgroup, ax0_0, T.tvm_access_ptr(T.type_annotation("float16"), model_blocks_0_attention_time_maa_w13_reindex_pad_shared, threadIdx_z * 512 + ax0_0 * 256 + ax3_1 * 8, 256, 1), 32, 8, 8, T.bool(True))
for ax1_2, ax2_2 in T.grid(2, 2):
cse_var_19: T.int32 = ax1_2 * 2 + ax2_2
T.simdgroup_multiply_accumulate(matmul_intermediate_reindex_pad_metal_simdgroup, cse_var_19, add768_reindex_pad_shared_metal_simdgroup, ax1_2, model_blocks_0_attention_time_maa_w13_reindex_pad_shared_metal_simdgroup, ax2_2, matmul_intermediate_reindex_pad_metal_simdgroup, cse_var_19)
for ax1_0, ax2_0 in T.grid(2, 2):
T.simdgroup_store(matmul_intermediate_reindex_pad_metal_simdgroup, ax1_0 * 2 + ax2_0, T.tvm_access_ptr(T.type_annotation("float16"), model_blocks_0_attention_time_maa_w13_reindex_pad_shared, ax1_0 * 512 + threadIdx_z * 16 + ax2_0 * 8, 512, 2), 64, 8, 8, T.bool(False))
for ax1_ax2_fused_0 in range(2):
if blockIdx_x * 16 + ax1_ax2_fused_0 * 8 + threadIdx_z * 2 + threadIdx_x // 16 < seq_len and blockIdx_y * 2 + threadIdx_x % 16 // 8 < 5:
model_blocks_0_attention_time_maa_w13_reindex_pad_shared_1 = T.Buffer((1024,), "float16", data=model_blocks_0_attention_time_maa_w13_reindex_pad_shared, scope="shared")
compute_intermediate_1[blockIdx_x * 2560 + ax1_ax2_fused_0 * 1280 + threadIdx_z * 320 + threadIdx_x // 16 * 160 + blockIdx_y * 64 + threadIdx_x % 16 * 4:blockIdx_x * 2560 + ax1_ax2_fused_0 * 1280 + threadIdx_z * 320 + threadIdx_x // 16 * 160 + blockIdx_y * 64 + threadIdx_x % 16 * 4 + 4] = T.tanh(model_blocks_0_attention_time_maa_w13_reindex_pad_shared_1[ax1_ax2_fused_0 * 512 + threadIdx_z * 128 + threadIdx_x * 4:ax1_ax2_fused_0 * 512 + threadIdx_z * 128 + threadIdx_x * 4 + 4])
Expected behavior
The model library compiles successfully.
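For context, once compilation succeeds the resulting library should be loadable from the Python API along these lines (a minimal sketch assuming the documented MLCEngine interface; the model_lib keyword and the paths are taken from the repro commands above and may need adjusting for the installed version):

# Minimal sketch of what a successful build should enable.
# Assumes the documented mlc_llm.MLCEngine API; paths come from the repro commands.
from mlc_llm import MLCEngine

engine = MLCEngine(
    model="rwkv-6-world-1b6-MLC",
    model_lib="rwkv-6-world-1b6-MLC/libs/rwkv-6-world-1b6-MLC-q4f16-metal.so",
)

# OpenAI-style chat completion as a quick end-to-end check.
response = engine.chat.completions.create(
    messages=[{"role": "user", "content": "Hello!"}],
    stream=False,
)
print(response.choices[0].message.content)
engine.terminate()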
Environment
- Platform: Metal; CUDA
- Operating system: macOS; Arch Linux
- Device: MacBook Air M2; Arch Linux x86_64 with CUDA
- How you installed MLC-LLM: python -m pip install --pre -U -f https://mlc.ai/wheels mlc-llm-nightly mlc-ai-nightly
- How you installed TVM-Unity (pip, source): python -m pip install --pre -U -f https://mlc.ai/wheels mlc-ai-nightly
- Python version (e.g. 3.10): 3.11
- TVM Unity Hash Tag (python -c "import tvm; print('\n'.join(f'{k}: {v}' for k, v in tvm.support.libinfo().items()))", applicable if you compile models): f65b73221a83b2a94c383c5d1b0bfd6d75c69800