
[Bug] RWKV v6 models fail to compile with latest mlc_llm

Open · MollySophia opened this issue 6 months ago · 1 comment

🐛 Bug

RWKV v6 models fail to compile with latest mlc_llm.

Edit: It also seems that CI currently only has a compile test for RWKV v5. Should RWKV v6 be added to the CI tests as well?

To Reproduce

Steps to reproduce the behavior:

  1. Get https://huggingface.co/RWKV/rwkv-6-world-1b6
  2. $ mlc_llm convert_weight rwkv-6-world-1b6 --quantization q4f16_1 -o rwkv-6-world-1b6-MLC
  3. $ mlc_llm gen_config rwkv-6-world-1b6 --quantization q4f16_1 --conv-template rwkv_world -o rwkv-6-world-1b6-MLC
  4. $ mlc_llm compile rwkv-6-world-1b6-MLC/mlc-chat-config.json --device metal --host arm64-apple-darwin -o rwkv-6-world-1b6-MLC/libs/rwkv-6-world-1b6-MLC-q4f16-metal.so (or build with other hosts)
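For convenience, the same three commands can also be run as a small script (a sketch only; the model directory is assumed to be a local clone of the Hugging Face repo from step 1):

```python
# Sketch of the reproduction steps above; paths, quantization, and output names
# are the ones used in steps 2-4.
import subprocess

MODEL = "rwkv-6-world-1b6"
OUT = f"{MODEL}-MLC"

subprocess.run(["mlc_llm", "convert_weight", MODEL,
                "--quantization", "q4f16_1", "-o", OUT], check=True)
subprocess.run(["mlc_llm", "gen_config", MODEL,
                "--quantization", "q4f16_1", "--conv-template", "rwkv_world",
                "-o", OUT], check=True)
subprocess.run(["mlc_llm", "compile", f"{OUT}/mlc-chat-config.json",
                "--device", "metal", "--host", "arm64-apple-darwin",
                "-o", f"{OUT}/libs/{OUT}-q4f16-metal.so"], check=True)
```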
Running these steps produces the following error messages:
$ mlc_llm compile rwkv-6-world-1b6-MLC/mlc-chat-config.json --device metal --host arm64-apple-darwin -o rwkv-6-world-1b6-MLC/libs/rwkv-6-world-1b6-MLC-q4f16-metal.so
[2024-09-02 10:57:53] INFO auto_config.py:70: Found model configuration: rwkv-6-world-1b6-MLC/mlc-chat-config.json
[2024-09-02 10:57:54] INFO auto_device.py:79: Found device: metal:0
[2024-09-02 10:57:54] INFO auto_target.py:78: Found configuration of target device "metal:0": {"thread_warp_size": runtime.BoxInt(32), "max_threads_per_block": runtime.BoxInt(1024), "max_function_args": runtime.BoxInt(31), "max_num_threads": runtime.BoxInt(256), "kind": "metal", "max_shared_memory_per_block": runtime.BoxInt(32768), "tag": "", "keys": ["metal", "gpu"]}
[2024-09-02 10:57:54] INFO auto_target.py:114: Using LLVM triple specified by --host: arm64-apple-darwin
[2024-09-02 10:57:54] INFO auto_config.py:154: Found model type: rwkv6. Use `--model-type` to override.
Compiling with arguments:
  --config          RWKV6Config(hidden_size=2048, intermediate_size=7168, num_hidden_layers=24, vocab_size=65536, model_version='6_0', tensor_parallel_shards=1, rescale_every=6, head_size=64, layer_norm_epsilon=1e-05, context_window_size=-1, prefill_chunk_size=4096, num_heads=32, max_batch_size=80, kwargs={})
  --quantization    GroupQuantize(name='q4f16_1', kind='group-quant', group_size=32, quantize_dtype='int4', storage_dtype='uint32', model_dtype='float16', linear_weight_layout='NK', quantize_embedding=True, quantize_final_fc=True, num_elem_per_storage=8, num_storage_per_group=4, max_int_value=7, tensor_parallel_shards=0)
  --model-type      rwkv6
  --target          {"thread_warp_size": runtime.BoxInt(32), "host": {"kind": "llvm", "tag": "", "keys": ["arm_cpu", "cpu"], "mtriple": "arm64-apple-darwin"}, "max_threads_per_block": runtime.BoxInt(1024), "max_function_args": runtime.BoxInt(31), "max_num_threads": runtime.BoxInt(256), "kind": "metal", "max_shared_memory_per_block": runtime.BoxInt(32768), "tag": "", "keys": ["metal", "gpu"]}
  --opt             flashinfer=0;cublas_gemm=0;faster_transformer=0;cudagraph=0;cutlass=0;ipc_allreduce_strategy=NONE
  --system-lib-prefix ""
  --output          rwkv-6-world-1b6-MLC/libs/rwkv-6-world-1b6-MLC-q4f16-metal.so
  --overrides       context_window_size=None;sliding_window_size=None;prefill_chunk_size=None;attention_sink_size=None;max_batch_size=None;tensor_parallel_shards=None;pipeline_parallel_stages=None
[2024-09-02 10:57:54] INFO compile.py:140: Creating model from: RWKV6Config(hidden_size=2048, intermediate_size=7168, num_hidden_layers=24, vocab_size=65536, model_version='6_0', tensor_parallel_shards=1, rescale_every=6, head_size=64, layer_norm_epsilon=1e-05, context_window_size=-1, prefill_chunk_size=4096, num_heads=32, max_batch_size=80, kwargs={})
[2024-09-02 10:57:54] INFO compile.py:158: Exporting the model to TVM Unity compiler
[2024-09-02 10:57:57] INFO compile.py:164: Running optimizations using TVM Unity
[2024-09-02 10:57:57] INFO compile.py:185: Registering metadata: {'model_type': 'rwkv6', 'quantization': 'q4f16_1', 'context_window_size': -1, 'sliding_window_size': -1, 'attention_sink_size': -1, 'prefill_chunk_size': 4096, 'tensor_parallel_shards': 1, 'pipeline_parallel_stages': 1, 'kv_state_kind': 'rnn_state', 'max_batch_size': 80}
[2024-09-02 10:57:57] INFO pipeline.py:54: Running TVM Relax graph-level optimizations
[2024-09-02 10:57:59] INFO pipeline.py:54: Lowering to TVM TIR kernels
[2024-09-02 10:58:04] INFO pipeline.py:54: Running TVM TIR-level optimizations
[2024-09-02 10:58:22] INFO pipeline.py:54: Running TVM Dlight low-level optimizations
[2024-09-02 10:58:27] INFO pipeline.py:54: Lowering to VM bytecode
[2024-09-02 10:58:30] INFO estimate_memory_usage.py:58: [Memory usage] Function `alloc_embedding_tensor`: 16.00 MB
[2024-09-02 10:58:30] INFO estimate_memory_usage.py:58: [Memory usage] Function `batch_decode`: 106.57 MB
[2024-09-02 10:58:30] INFO estimate_memory_usage.py:58: [Memory usage] Function `batch_prefill`: 293.50 MB
[2024-09-02 10:58:30] INFO estimate_memory_usage.py:58: [Memory usage] Function `batch_verify`: 273.75 MB
[2024-09-02 10:58:30] INFO estimate_memory_usage.py:58: [Memory usage] Function `create_rnn_state`: 0.00 MB
[2024-09-02 10:58:30] INFO estimate_memory_usage.py:58: [Memory usage] Function `decode`: 1.32 MB
[2024-09-02 10:58:30] INFO estimate_memory_usage.py:58: [Memory usage] Function `embed`: 16.00 MB
[2024-09-02 10:58:31] INFO estimate_memory_usage.py:58: [Memory usage] Function `prefill`: 273.75 MB
[2024-09-02 10:58:31] INFO estimate_memory_usage.py:58: [Memory usage] Function `softmax_with_temperature`: 0.00 MB
[2024-09-02 10:58:32] INFO pipeline.py:54: Compiling external modules
[2024-09-02 10:58:32] INFO pipeline.py:54: Compilation complete! Exporting to disk
Traceback (most recent call last):
  File "/Users/molly/miniconda3/envs/mlc-llm-latest/bin/mlc_llm", line 8, in <module>
    sys.exit(main())
             ^^^^^^
  File "/Users/molly/miniconda3/envs/mlc-llm-latest/lib/python3.11/site-packages/mlc_llm/__main__.py", line 33, in main
    cli.main(sys.argv[2:])
  File "/Users/molly/miniconda3/envs/mlc-llm-latest/lib/python3.11/site-packages/mlc_llm/cli/compile.py", line 129, in main
    compile(
  File "/Users/molly/miniconda3/envs/mlc-llm-latest/lib/python3.11/site-packages/mlc_llm/interface/compile.py", line 243, in compile
    _compile(args, model_config)
  File "/Users/molly/miniconda3/envs/mlc-llm-latest/lib/python3.11/site-packages/mlc_llm/interface/compile.py", line 188, in _compile
    args.build_func(
  File "/Users/molly/miniconda3/envs/mlc-llm-latest/lib/python3.11/site-packages/mlc_llm/support/auto_target.py", line 311, in build
    relax.build(
  File "/Users/molly/miniconda3/envs/mlc-llm-latest/lib/python3.11/site-packages/tvm/relax/vm_build.py", line 341, in build
    return _vmlink(
           ^^^^^^^^
  File "/Users/molly/miniconda3/envs/mlc-llm-latest/lib/python3.11/site-packages/tvm/relax/vm_build.py", line 247, in _vmlink
    lib = tvm.build(
          ^^^^^^^^^^
  File "/Users/molly/miniconda3/envs/mlc-llm-latest/lib/python3.11/site-packages/tvm/driver/build_module.py", line 297, in build
    rt_mod_host = _driver_ffi.tir_to_runtime(annotated_mods, target_host)
                  ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "tvm/_ffi/_cython/./packed_func.pxi", line 339, in tvm._ffi._cy3.core.PackedFuncBase.__call__
  File "tvm/_ffi/_cython/./packed_func.pxi", line 270, in tvm._ffi._cy3.core.FuncCall
  File "tvm/_ffi/_cython/./packed_func.pxi", line 259, in tvm._ffi._cy3.core.FuncCall3
  File "tvm/_ffi/_cython/./base.pxi", line 185, in tvm._ffi._cy3.core.CHECK_CALL
  File "/Users/molly/miniconda3/envs/mlc-llm-latest/lib/python3.11/site-packages/tvm/_ffi/base.py", line 481, in raise_last_ffi_error
    raise py_err
tvm._ffi.base.TVMError: Traceback (most recent call last):
  Did you forget to bind?
    Variable `compute` is directly accessed by host memory (it is not contained in a thread environment or in the function arguments.
    Variable `add768` is directly accessed by host memory (it is not contained in a thread environment or in the function arguments.
    Variable `compute` is directly accessed by host memory (it is not contained in a thread environment or in the function arguments.
    Variable `add768` is directly accessed by host memory (it is not contained in a thread environment or in the function arguments.
  File "/Users/catalyst/Workspace/mlc-ai-package-self-runner/_work/package/package/tvm/src/tir/analysis/verify_memory.cc", line 205
RuntimeError: Memory verification failed with the following errors:
# from tvm.script import tir as T

@T.prim_func
def fused_matmul4_tir_tanh2(p_add768: T.handle, model_blocks_0_attention_time_maa_w13: T.Buffer((2048, 160), "float16"), p_output0: T.handle):
    T.func_attr({"target": T.target({"host": {"keys": ["arm_cpu", "cpu"], "kind": "llvm", "mtriple": "arm64-apple-darwin", "tag": ""}, "keys": ["metal", "gpu"], "kind": "metal", "max_function_args": 31, "max_num_threads": 256, "max_shared_memory_per_block": 32768, "max_threads_per_block": 1024, "tag": "", "thread_warp_size": 32}), "tir.HoistIfThenElseExprWithBlock": 1, "tir.is_scheduled": 1, "tir.noalias": T.bool(True)})
    seq_len = T.int32()
    add768 = T.match_buffer(p_add768, (1, seq_len, 2048), "float16")
    compute_intermediate = T.match_buffer(p_output0, (1, seq_len, 160), "float16")
    add768_pad = T.allocate([(seq_len + 1) // 2 * 4096], "float16", "global")
    matmul_intermediate_pad_rf = T.allocate([(seq_len + 1) // 2 * 10240], "float16", "global")
    matmul_intermediate_pad_rf_1 = T.allocate([(seq_len + 1) // 2 * 1280], "float16", "global")
    matmul_intermediate_pad = T.allocate([(seq_len + 1) // 2 * 320], "float16", "global")
    add768_pad_1 = T.allocate([(seq_len + 3) // 4 * 8192], "float16", "global")
    matmul_intermediate_pad_rf_2 = T.allocate([(seq_len + 3) // 4 * 20480], "float16", "global")
    matmul_intermediate_pad_rf_3 = T.allocate([(seq_len + 3) // 4 * 2560], "float16", "global")
    matmul_intermediate_pad_1 = T.allocate([(seq_len + 3) // 4 * 640], "float16", "global")
    add768_1 = T.Buffer((seq_len * 2048,), "float16", data=add768.data)
    model_blocks_0_attention_time_maa_w13_1 = T.Buffer((327680,), "float16", data=model_blocks_0_attention_time_maa_w13.data)
    compute_intermediate_1 = T.Buffer((seq_len * 160,), "float16", data=compute_intermediate.data)
    if T.tvm_thread_invariant(seq_len <= 2):
        add768_pad_2 = T.Buffer(((seq_len + 1) // 2 * 4096,), "float16", data=add768_pad)
        for ax0 in range((seq_len + 1) // 2 * 2):
            if ax0 < seq_len:
                for ax1 in range(2048):
                    cse_var_1: T.int32 = ax0 * 2048 + ax1
                    add768_pad_2[cse_var_1] = add768_1[cse_var_1]
            else:
                for ax1 in range(2048):
                    add768_pad_2[ax0 * 2048 + ax1] = T.float16(0.0)
        matmul_intermediate_pad_rf_4 = T.Buffer(((seq_len + 1) // 2 * 10240,), "float16", data=matmul_intermediate_pad_rf)
        with T.launch_thread("blockIdx.y", (seq_len + 1) // 2) as blockIdx_y:
            blockIdx_x = T.launch_thread("blockIdx.x", 3)
            threadIdx_x = T.launch_thread("threadIdx.x", 64)
            threadIdx_y = T.launch_thread("threadIdx.y", 4)
            for ax2_fused_0, ax2_fused_2, ax0_1, ax2_fused_1_ax2_fused_3_fused_1_0 in T.grid(16, 4, 2, 2):
                if blockIdx_x * 2 + threadIdx_x // 32 < 5:
                    if ax2_fused_0 == 0 and ax2_fused_2 == 0:
                        matmul_intermediate_pad_rf_4[T.Ramp(threadIdx_y * ((seq_len + 1) // 2) * 2560 + ax2_fused_1_ax2_fused_3_fused_1_0 * ((seq_len + 1) // 2) * 1280 + blockIdx_y * 320 + ax0_1 * 160 + blockIdx_x * 64 + threadIdx_x, (seq_len + 1) // 2 * 320, 4)] = T.Broadcast(T.float16(0.0), 4)
                    matmul_intermediate_pad_rf_4[T.Ramp(threadIdx_y * ((seq_len + 1) // 2) * 2560 + ax2_fused_1_ax2_fused_3_fused_1_0 * ((seq_len + 1) // 2) * 1280 + blockIdx_y * 320 + ax0_1 * 160 + blockIdx_x * 64 + threadIdx_x, (seq_len + 1) // 2 * 320, 4)] = matmul_intermediate_pad_rf_4[T.Ramp(threadIdx_y * ((seq_len + 1) // 2) * 2560 + ax2_fused_1_ax2_fused_3_fused_1_0 * ((seq_len + 1) // 2) * 1280 + blockIdx_y * 320 + ax0_1 * 160 + blockIdx_x * 64 + threadIdx_x, (seq_len + 1) // 2 * 320, 4)] + add768_pad_2[blockIdx_y * 4096 + ax0_1 * 2048 + ax2_fused_0 * 128 + threadIdx_y * 32 + ax2_fused_2 * 8 + ax2_fused_1_ax2_fused_3_fused_1_0 * 4:blockIdx_y * 4096 + ax0_1 * 2048 + ax2_fused_0 * 128 + threadIdx_y * 32 + ax2_fused_2 * 8 + ax2_fused_1_ax2_fused_3_fused_1_0 * 4 + 4] * model_blocks_0_attention_time_maa_w13_1[ax2_fused_0 * 20480 + threadIdx_y * 5120 + ax2_fused_2 * 1280 + ax2_fused_1_ax2_fused_3_fused_1_0 * 640 + blockIdx_x * 64 + threadIdx_x:ax2_fused_0 * 20480 + threadIdx_y * 5120 + ax2_fused_2 * 1280 + ax2_fused_1_ax2_fused_3_fused_1_0 * 640 + blockIdx_x * 64 + threadIdx_x + 640:160]
        matmul_intermediate_pad_rf_5 = T.Buffer(((seq_len + 1) // 2 * 1280,), "float16", data=matmul_intermediate_pad_rf_1)
        for ax0_0, ax1_fused_0, ax1_fused_1, ax0_1, ax2_fused_1_ax2_fused_3_fused_0, ax2_fused_1_ax2_fused_3_fused_1 in T.grid((seq_len + 1) // 2, 3, 64, 2, 4, 8):
            cse_var_4: T.int32 = ax0_0 * 320
            cse_var_3: T.int32 = ax0_1 * 160
            cse_var_2: T.int32 = ax1_fused_0 * 64
            if ax2_fused_1_ax2_fused_3_fused_1 == 0:
                matmul_intermediate_pad_rf_5[cse_var_4 + ax2_fused_1_ax2_fused_3_fused_0 * ((seq_len + 1) // 2) * 320 + cse_var_3 + cse_var_2 + ax1_fused_1] = T.float16(0.0)
            matmul_intermediate_pad_rf_5[cse_var_4 + ax2_fused_1_ax2_fused_3_fused_0 * ((seq_len + 1) // 2) * 320 + cse_var_3 + cse_var_2 + ax1_fused_1] = matmul_intermediate_pad_rf_5[cse_var_4 + ax2_fused_1_ax2_fused_3_fused_0 * ((seq_len + 1) // 2) * 320 + cse_var_3 + cse_var_2 + ax1_fused_1] + matmul_intermediate_pad_rf_4[ax2_fused_1_ax2_fused_3_fused_0 * ((seq_len + 1) // 2) * 2560 + cse_var_4 + ax2_fused_1_ax2_fused_3_fused_1 * ((seq_len + 1) // 2) * 320 + cse_var_3 + cse_var_2 + ax1_fused_1]
        matmul_intermediate_pad_2 = T.Buffer(((seq_len + 1) // 2 * 320,), "float16", data=matmul_intermediate_pad)
        for ax0_0, ax1_fused_0, ax1_fused_1, ax0_1, ax2_fused_1_ax2_fused_3_fused_0 in T.grid((seq_len + 1) // 2, 3, 64, 2, 4):
            cse_var_8: T.int32 = ax0_0 * 320
            cse_var_7: T.int32 = ax0_1 * 160
            cse_var_6: T.int32 = ax1_fused_0 * 64
            cse_var_5: T.int32 = cse_var_8 + cse_var_7 + cse_var_6 + ax1_fused_1
            if ax2_fused_1_ax2_fused_3_fused_0 == 0:
                matmul_intermediate_pad_2[cse_var_5] = T.float16(0.0)
            matmul_intermediate_pad_2[cse_var_5] = matmul_intermediate_pad_2[cse_var_5] + matmul_intermediate_pad_rf_5[cse_var_8 + ax2_fused_1_ax2_fused_3_fused_0 * ((seq_len + 1) // 2) * 320 + cse_var_7 + cse_var_6 + ax1_fused_1]
        for ax0, ax1 in T.grid(seq_len, 160):
            cse_var_9: T.int32 = ax0 * 160 + ax1
            compute_intermediate_1[cse_var_9] = T.tanh(matmul_intermediate_pad_2[cse_var_9])
    else:
        if T.tvm_thread_invariant(seq_len <= 8):
            add768_pad_2 = T.Buffer(((seq_len + 3) // 4 * 8192,), "float16", data=add768_pad_1)
            for ax0 in range((seq_len + 3) // 4 * 4):
                if ax0 < seq_len:
                    for ax1 in range(2048):
                        cse_var_10: T.int32 = ax0 * 2048 + ax1
                        add768_pad_2[cse_var_10] = add768_1[cse_var_10]
                else:
                    for ax1 in range(2048):
                        add768_pad_2[ax0 * 2048 + ax1] = T.float16(0.0)
            matmul_intermediate_pad_rf_4 = T.Buffer(((seq_len + 3) // 4 * 20480,), "float16", data=matmul_intermediate_pad_rf_2)
            with T.launch_thread("blockIdx.y", (seq_len + 3) // 4) as blockIdx_y:
                blockIdx_x = T.launch_thread("blockIdx.x", 3)
                threadIdx_x = T.launch_thread("threadIdx.x", 64)
                threadIdx_y = T.launch_thread("threadIdx.y", 4)
                for ax2_fused_0, ax2_fused_2, ax0_1, ax2_fused_1_ax2_fused_3_fused_1_0 in T.grid(16, 4, 4, 2):
                    if blockIdx_x * 2 + threadIdx_x // 32 < 5:
                        if ax2_fused_0 == 0 and ax2_fused_2 == 0:
                            matmul_intermediate_pad_rf_4[T.Ramp(threadIdx_y * ((seq_len + 3) // 4) * 5120 + ax2_fused_1_ax2_fused_3_fused_1_0 * ((seq_len + 3) // 4) * 2560 + blockIdx_y * 640 + ax0_1 * 160 + blockIdx_x * 64 + threadIdx_x, (seq_len + 3) // 4 * 640, 4)] = T.Broadcast(T.float16(0.0), 4)
                        matmul_intermediate_pad_rf_4[T.Ramp(threadIdx_y * ((seq_len + 3) // 4) * 5120 + ax2_fused_1_ax2_fused_3_fused_1_0 * ((seq_len + 3) // 4) * 2560 + blockIdx_y * 640 + ax0_1 * 160 + blockIdx_x * 64 + threadIdx_x, (seq_len + 3) // 4 * 640, 4)] = matmul_intermediate_pad_rf_4[T.Ramp(threadIdx_y * ((seq_len + 3) // 4) * 5120 + ax2_fused_1_ax2_fused_3_fused_1_0 * ((seq_len + 3) // 4) * 2560 + blockIdx_y * 640 + ax0_1 * 160 + blockIdx_x * 64 + threadIdx_x, (seq_len + 3) // 4 * 640, 4)] + add768_pad_2[blockIdx_y * 8192 + ax0_1 * 2048 + ax2_fused_0 * 128 + threadIdx_y * 32 + ax2_fused_2 * 8 + ax2_fused_1_ax2_fused_3_fused_1_0 * 4:blockIdx_y * 8192 + ax0_1 * 2048 + ax2_fused_0 * 128 + threadIdx_y * 32 + ax2_fused_2 * 8 + ax2_fused_1_ax2_fused_3_fused_1_0 * 4 + 4] * model_blocks_0_attention_time_maa_w13_1[ax2_fused_0 * 20480 + threadIdx_y * 5120 + ax2_fused_2 * 1280 + ax2_fused_1_ax2_fused_3_fused_1_0 * 640 + blockIdx_x * 64 + threadIdx_x:ax2_fused_0 * 20480 + threadIdx_y * 5120 + ax2_fused_2 * 1280 + ax2_fused_1_ax2_fused_3_fused_1_0 * 640 + blockIdx_x * 64 + threadIdx_x + 640:160]
            matmul_intermediate_pad_rf_5 = T.Buffer(((seq_len + 3) // 4 * 2560,), "float16", data=matmul_intermediate_pad_rf_3)
            for ax0_0, ax1_fused_0, ax1_fused_1, ax0_1, ax2_fused_1_ax2_fused_3_fused_0, ax2_fused_1_ax2_fused_3_fused_1 in T.grid((seq_len + 3) // 4, 3, 64, 4, 4, 8):
                cse_var_13: T.int32 = ax0_0 * 640
                cse_var_12: T.int32 = ax0_1 * 160
                cse_var_11: T.int32 = ax1_fused_0 * 64
                if ax2_fused_1_ax2_fused_3_fused_1 == 0:
                    matmul_intermediate_pad_rf_5[cse_var_13 + ax2_fused_1_ax2_fused_3_fused_0 * ((seq_len + 3) // 4) * 640 + cse_var_12 + cse_var_11 + ax1_fused_1] = T.float16(0.0)
                matmul_intermediate_pad_rf_5[cse_var_13 + ax2_fused_1_ax2_fused_3_fused_0 * ((seq_len + 3) // 4) * 640 + cse_var_12 + cse_var_11 + ax1_fused_1] = matmul_intermediate_pad_rf_5[cse_var_13 + ax2_fused_1_ax2_fused_3_fused_0 * ((seq_len + 3) // 4) * 640 + cse_var_12 + cse_var_11 + ax1_fused_1] + matmul_intermediate_pad_rf_4[ax2_fused_1_ax2_fused_3_fused_0 * ((seq_len + 3) // 4) * 5120 + cse_var_13 + ax2_fused_1_ax2_fused_3_fused_1 * ((seq_len + 3) // 4) * 640 + cse_var_12 + cse_var_11 + ax1_fused_1]
            matmul_intermediate_pad_2 = T.Buffer(((seq_len + 3) // 4 * 640,), "float16", data=matmul_intermediate_pad_1)
            for ax0_0, ax1_fused_0, ax1_fused_1, ax0_1, ax2_fused_1_ax2_fused_3_fused_0 in T.grid((seq_len + 3) // 4, 3, 64, 4, 4):
                cse_var_17: T.int32 = ax0_0 * 640
                cse_var_16: T.int32 = ax0_1 * 160
                cse_var_15: T.int32 = ax1_fused_0 * 64
                cse_var_14: T.int32 = cse_var_17 + cse_var_16 + cse_var_15 + ax1_fused_1
                if ax2_fused_1_ax2_fused_3_fused_0 == 0:
                    matmul_intermediate_pad_2[cse_var_14] = T.float16(0.0)
                matmul_intermediate_pad_2[cse_var_14] = matmul_intermediate_pad_2[cse_var_14] + matmul_intermediate_pad_rf_5[cse_var_17 + ax2_fused_1_ax2_fused_3_fused_0 * ((seq_len + 3) // 4) * 640 + cse_var_16 + cse_var_15 + ax1_fused_1]
            for ax0, ax1 in T.grid(seq_len, 160):
                cse_var_18: T.int32 = ax0 * 160 + ax1
                compute_intermediate_1[cse_var_18] = T.tanh(matmul_intermediate_pad_2[cse_var_18])
        else:
            blockIdx_z = T.launch_thread("blockIdx.z", 1)
            matmul_intermediate_reindex_pad_metal_simdgroup = T.allocate([256], "float16", "metal.simdgroup")
            add768_reindex_pad_shared = T.allocate([512], "float16", "shared")
            model_blocks_0_attention_time_maa_w13_reindex_pad_shared = T.allocate([2048], "float16", "shared")
            add768_reindex_pad_shared_metal_simdgroup = T.allocate([128], "float16", "metal.simdgroup")
            model_blocks_0_attention_time_maa_w13_reindex_pad_shared_metal_simdgroup = T.allocate([128], "float16", "metal.simdgroup")
            blockIdx_x = T.launch_thread("blockIdx.x", (seq_len + 15) // 16)
            blockIdx_y = T.launch_thread("blockIdx.y", 3)
            threadIdx_x = T.launch_thread("threadIdx.x", 32)
            threadIdx_y = T.launch_thread("threadIdx.y", 1)
            threadIdx_z = T.launch_thread("threadIdx.z", 4)
            for ax1_2_init, ax2_2_init in T.grid(2, 2):
                T.make_filled_simdgroup_matrix(matmul_intermediate_reindex_pad_metal_simdgroup, ax1_2_init * 2 + ax2_2_init, T.float32(0.0), 8, 8)
            for ax3_0 in range(64):
                add768_reindex_pad_shared_1 = T.Buffer((512,), "float16", data=add768_reindex_pad_shared, scope="shared")
                add768_reindex_pad_shared_1[threadIdx_z * 128 + threadIdx_x * 4:threadIdx_z * 128 + threadIdx_x * 4 + 4] = T.if_then_else(blockIdx_x * 16 + threadIdx_z * 4 + threadIdx_x // 8 < seq_len, add768_1[blockIdx_x * 32768 + threadIdx_z * 8192 + threadIdx_x // 8 * 2048 + ax3_0 * 32 + threadIdx_x % 8 * 4:blockIdx_x * 32768 + threadIdx_z * 8192 + threadIdx_x // 8 * 2048 + ax3_0 * 32 + threadIdx_x % 8 * 4 + 4], T.Broadcast(T.float16(0.0), 4))
                for ax1_ax2_fused_0 in range(4):
                    model_blocks_0_attention_time_maa_w13_reindex_pad_shared_1 = T.Buffer((2048,), "float16", data=model_blocks_0_attention_time_maa_w13_reindex_pad_shared, scope="shared")
                    model_blocks_0_attention_time_maa_w13_reindex_pad_shared_1[ax1_ax2_fused_0 * 512 + threadIdx_z * 128 + threadIdx_x * 4:ax1_ax2_fused_0 * 512 + threadIdx_z * 128 + threadIdx_x * 4 + 4] = T.if_then_else(blockIdx_y * 2 + ax1_ax2_fused_0 // 2 < 5, model_blocks_0_attention_time_maa_w13_1[ax3_0 * 5120 + threadIdx_x % 8 * 640 + blockIdx_y * 64 + ax1_ax2_fused_0 * 16 + threadIdx_z * 4 + threadIdx_x // 8:ax3_0 * 5120 + threadIdx_x % 8 * 640 + blockIdx_y * 64 + ax1_ax2_fused_0 * 16 + threadIdx_z * 4 + threadIdx_x // 8 + 640:160], T.Broadcast(T.float16(0.0), 4))
                for ax3_1 in range(4):
                    for ax0_0 in range(2):
                        T.simdgroup_load(add768_reindex_pad_shared_metal_simdgroup, ax0_0, T.tvm_access_ptr(T.type_annotation("float16"), add768_reindex_pad_shared, ax0_0 * 256 + ax3_1 * 8, 256, 1), 32, 8, 8, T.bool(False))
                    for ax0_0 in range(2):
                        T.simdgroup_load(model_blocks_0_attention_time_maa_w13_reindex_pad_shared_metal_simdgroup, ax0_0, T.tvm_access_ptr(T.type_annotation("float16"), model_blocks_0_attention_time_maa_w13_reindex_pad_shared, threadIdx_z * 512 + ax0_0 * 256 + ax3_1 * 8, 256, 1), 32, 8, 8, T.bool(True))
                    for ax1_2, ax2_2 in T.grid(2, 2):
                        cse_var_19: T.int32 = ax1_2 * 2 + ax2_2
                        T.simdgroup_multiply_accumulate(matmul_intermediate_reindex_pad_metal_simdgroup, cse_var_19, add768_reindex_pad_shared_metal_simdgroup, ax1_2, model_blocks_0_attention_time_maa_w13_reindex_pad_shared_metal_simdgroup, ax2_2, matmul_intermediate_reindex_pad_metal_simdgroup, cse_var_19)
            for ax1_0, ax2_0 in T.grid(2, 2):
                T.simdgroup_store(matmul_intermediate_reindex_pad_metal_simdgroup, ax1_0 * 2 + ax2_0, T.tvm_access_ptr(T.type_annotation("float16"), model_blocks_0_attention_time_maa_w13_reindex_pad_shared, ax1_0 * 512 + threadIdx_z * 16 + ax2_0 * 8, 512, 2), 64, 8, 8, T.bool(False))
            for ax1_ax2_fused_0 in range(2):
                if blockIdx_x * 16 + ax1_ax2_fused_0 * 8 + threadIdx_z * 2 + threadIdx_x // 16 < seq_len and blockIdx_y * 2 + threadIdx_x % 16 // 8 < 5:
                    model_blocks_0_attention_time_maa_w13_reindex_pad_shared_1 = T.Buffer((1024,), "float16", data=model_blocks_0_attention_time_maa_w13_reindex_pad_shared, scope="shared")
                    compute_intermediate_1[blockIdx_x * 2560 + ax1_ax2_fused_0 * 1280 + threadIdx_z * 320 + threadIdx_x // 16 * 160 + blockIdx_y * 64 + threadIdx_x % 16 * 4:blockIdx_x * 2560 + ax1_ax2_fused_0 * 1280 + threadIdx_z * 320 + threadIdx_x // 16 * 160 + blockIdx_y * 64 + threadIdx_x % 16 * 4 + 4] = T.tanh(model_blocks_0_attention_time_maa_w13_reindex_pad_shared_1[ax1_ax2_fused_0 * 512 + threadIdx_z * 128 + threadIdx_x * 4:ax1_ax2_fused_0 * 512 + threadIdx_z * 128 + threadIdx_x * 4 + 4])
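
For context, the failure appears to come from TVM's memory verification pass: the dumped `fused_matmul4_tir_tanh2` kernel still contains plain serial loops that read and write global buffers (`add768`, `compute`) without being bound to any GPU thread axis, which is what the "Did you forget to bind?" message complains about. Below is a minimal sketch (my own illustration, not taken from the mlc_llm kernel) of the same class of error and of what binding looks like with the TVM TIR schedule API:

```python
# Sketch: an unbound serial loop over global memory fails VerifyMemory on a GPU
# target; binding it to blockIdx/threadIdx makes the same kernel acceptable.
import tvm
from tvm.script import tir as T

@T.prim_func
def add_one(A: T.Buffer((1024,), "float16"), B: T.Buffer((1024,), "float16")):
    # Plain serial loop over a global buffer: fine for a CPU target, but a GPU
    # target rejects it because nothing is bound to a thread environment.
    for i in range(1024):
        with T.block("compute"):
            vi = T.axis.spatial(1024, i)
            B[vi] = A[vi] + T.float16(1)

# tvm.build(add_one, target="metal")  # fails with "Did you forget to bind?"

# Binding the loop to GPU thread axes lets the kernel pass memory verification.
sch = tvm.tir.Schedule(add_one)
(i,) = sch.get_loops(sch.get_block("compute"))
bx, tx = sch.split(i, factors=[None, 128])
sch.bind(bx, "blockIdx.x")
sch.bind(tx, "threadIdx.x")
# tvm.build(sch.mod, target="metal")  # now passes memory verification
```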

Expected behavior

The model library compiles successfully.

Environment

  • Platform: Metal; CUDA
  • Operating system: macOS; Arch Linux
  • Device: MacBook Air M2; Arch Linux x86_64 with CUDA
  • How you installed MLC-LLM: python -m pip install --pre -U -f https://mlc.ai/wheels mlc-llm-nightly mlc-ai-nightly
  • How you installed TVM-Unity (pip, source): python -m pip install --pre -U -f https://mlc.ai/wheels mlc-ai-nightly
  • Python version (e.g. 3.10): 3.11
  • TVM Unity Hash Tag (python -c "import tvm; print('\n'.join(f'{k}: {v}' for k, v in tvm.support.libinfo().items()))", applicable if you compile models): f65b73221a83b2a94c383c5d1b0bfd6d75c69800

MollySophia · Sep 02 '24, 03:09