mlc-llm
[Bug] Check failed: (args.size() == initial_indices_orig.size()) is false
🐛 Bug
To Reproduce
Steps to reproduce the behavior:
- I built a model that uses the conv2d op.
- One of the computation graphs is: permute_dims --> conv2d --> layernorm
- I encountered the following problem during compilation.
I think this problem is caused by the fusion of the permute and conv operators, which, once dl.gpu.Matmul() is applied, results in a mismatch between the buffer shape and the index_map shape.
1. Error log
tvm.error.InternalError: Traceback (most recent call last):
  4: operator()
        at /workspace/tvm-unity/src/tir/schedule/schedule.cc:287
  3: tvm::tir::TracedScheduleNode::TransformLayout(tvm::tir::BlockRV const&, int, tvm::tir::BufferIndexType, tvm::tir::IndexMap const&, tvm::runtime::Optional<tvm::tir::IndexMap> const&, bool)
        at /workspace/tvm-unity/src/tir/schedule/traced_schedule.cc:678
  2: tvm::tir::ConcreteScheduleNode::TransformLayout(tvm::tir::BlockRV const&, int, tvm::tir::BufferIndexType, tvm::tir::IndexMap const&, tvm::runtime::Optional<tvm::tir::IndexMap> const&, bool)
        at /workspace/tvm-unity/src/tir/schedule/concrete_schedule.cc:993
  1: tvm::tir::TransformLayout(tvm::tir::ScheduleState, tvm::tir::StmtSRef const&, int, tvm::tir::BufferIndexType, tvm::tir::IndexMap const&, tvm::runtime::Optional<tvm::tir::IndexMap> const&, bool)
        at /workspace/tvm-unity/src/tir/schedule/primitive/layout_transformation.cc:1160
  0: tvm::tir::LegalizeIndexMapDType(tvm::tir::IndexMap const&, tvm::runtime::Array<tvm::PrimExpr, void> const&)
        at /workspace/tvm-unity/src/tir/schedule/primitive/layout_transformation.cc:1106
  File "/workspace/tvm-unity/src/tir/schedule/primitive/layout_transformation.cc", line 1106
InternalError: Check failed: (args.size() == initial_indices_orig.size()) is false:
2. Other messages
1) The following index map, which takes six input indices, does not seem to match:
T.index_map(lambda i0, i1, i2, i3, i4, i5: (T.int64(0), i1 * T.int64(64) + i2, i3))
2)
with T.block("conv2d_nchw", no_realize=True):
    v_nn = T.axis.spatial(T.int64(1))
    v_ff = T.axis.spatial(T.int64(256))
    v_yy = T.axis.spatial(T.int64(64))
    v_xx = T.axis.spatial(T.int64(64))
    v_rc = T.axis.reduce(T.int64(768))
    v_ry = T.axis.reduce(T.int64(1))
    v_rx = T.axis.reduce(T.int64(1))
    pad_temp = T.Buffer((T.int64(1), T.int64(768), T.int64(64), T.int64(64)), "float16")
    B = T.Buffer((T.int64(256), T.int64(768), T.int64(1), T.int64(1)), "float16")
    T.reads(pad_temp[v_nn, v_rc, v_yy + v_ry, v_xx + v_rx], B[v_ff, v_rc, v_ry, v_rx])
    conv2d_nchw = T.Buffer((T.int64(1), T.int64(256), T.int64(64), T.int64(64)), "float16")
    T.writes(conv2d_nchw[v_nn, v_ff, v_yy, v_xx])
    with T.init():
        conv2d_nchw[v_nn, v_ff, v_yy, v_xx] = T.float16(0)
    conv2d_nchw[v_nn, v_ff, v_yy, v_xx] = conv2d_nchw[v_nn, v_ff, v_yy, v_xx] + pad_temp[v_nn, v_rc, v_yy + v_ry, v_xx + v_rx] * B[v_ff, v_rc, v_ry, v_rx]
3)
@T.prim_func(private=True)
def main(permute_dims161: T.Buffer((T.int64(1), T.int64(768), T.int64(64), T.int64(64)), "float16"), vision_tower_vision_tower_high_neck_0_weight1: T.Buffer((T.int64(256), T.int64(768), T.int64(1), T.int64(1)), "float16"), compute_intermediate: T.Buffer((T.int64(1), T.int64(256), T.int64(64), T.int64(64)), "float32")):
    T.func_attr({"tir.noalias": T.bool(True)})
    # with T.block("root"):
    pad_temp = T.alloc_buffer((T.int64(1), T.int64(768), T.int64(64), T.int64(64)), "float16")
    conv2d_nchw_intermediate = T.alloc_buffer((T.int64(1), T.int64(256), T.int64(64), T.int64(64)), "float16")
    for i0, i1, i2, i3 in T.grid(T.int64(1), T.int64(768), T.int64(64), T.int64(64)):
        with T.block("pad_temp"):
            v_i0, v_i1, v_i2, v_i3 = T.axis.remap("SSSS", [i0, i1, i2, i3])
            T.reads(permute_dims161[v_i0, v_i1, v_i2, v_i3])
            T.writes(pad_temp[v_i0, v_i1, v_i2, v_i3])
            pad_temp[v_i0, v_i1, v_i2, v_i3] = permute_dims161[v_i0, v_i1, v_i2, v_i3]
    for nn, ff, yy, xx, rc, ry, rx in T.grid(T.int64(1), T.int64(256), T.int64(64), T.int64(64), T.int64(768), T.int64(1), T.int64(1)):
        with T.block("conv2d_nchw"):
            v_nn, v_ff, v_yy, v_xx, v_rc, v_ry, v_rx = T.axis.remap("SSSSRRR", [nn, ff, yy, xx, rc, ry, rx])
            T.reads(pad_temp[v_nn, v_rc, v_yy + v_ry, v_xx + v_rx], vision_tower_vision_tower_high_neck_0_weight1[v_ff, v_rc, v_ry, v_rx])
            T.writes(conv2d_nchw_intermediate[v_nn, v_ff, v_yy, v_xx])
            with T.init():
                conv2d_nchw_intermediate[v_nn, v_ff, v_yy, v_xx] = T.float16(0)
            conv2d_nchw_intermediate[v_nn, v_ff, v_yy, v_xx] = conv2d_nchw_intermediate[v_nn, v_ff, v_yy, v_xx] + pad_temp[v_nn, v_rc, v_yy + v_ry, v_xx + v_rx] * vision_tower_vision_tower_high_neck_0_weight1[v_ff, v_rc, v_ry, v_rx]
    for i0, i1, i2, i3 in T.grid(T.int64(1), T.int64(256), T.int64(64), T.int64(64)):
        with T.block("compute"):
            v_i0, v_i1, v_i2, v_i3 = T.axis.remap("SSSS", [i0, i1, i2, i3])
            T.reads(conv2d_nchw_intermediate[v_i0, v_i1, v_i2, v_i3])
            T.writes(compute_intermediate[v_i0, v_i1, v_i2, v_i3])
            compute_intermediate[v_i0, v_i1, v_i2, v_i3] = T.Cast("float32", conv2d_nchw_intermediate[v_i0, v_i1, v_i2, v_i3])

T.index_map(lambda i0, i1, i2, i3, i4, i5: (T.int64(0), i1 * T.int64(64) + i2, i3))
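The failing check compares the size of `args` (presumably the dimensions of the buffer being transformed) with the number of initial indices of the index map. A minimal pure-Python model of that comparison (not TVM's implementation; `check_transform_layout` and `index_map_arity` are hypothetical helpers, and the 6-D shape below is made up for illustration) shows why a six-input map cannot be applied to a buffer of a different rank, such as one shaped like `pad_temp`:

```python
import inspect
from typing import Callable, Sequence, Tuple


def index_map_arity(index_map: Callable) -> int:
    """Number of input indices an index-map function expects."""
    return len(inspect.signature(index_map).parameters)


def check_transform_layout(index_map: Callable, buffer_shape: Sequence[int]) -> Tuple[int, ...]:
    """Toy model of the rank check: one index-map input per buffer dimension.

    Returns the image of the all-zero index if the ranks agree,
    otherwise raises ValueError with the message seen in the log.
    """
    n_inputs = index_map_arity(index_map)
    if n_inputs != len(buffer_shape):
        raise ValueError(
            "Check failed: (args.size() == initial_indices_orig.size()) is false: "
            f"index map takes {n_inputs} indices, buffer has {len(buffer_shape)} dims"
        )
    return tuple(index_map(*([0] * n_inputs)))


def a_index_map(i0, i1, i2, i3, i4, i5):
    # Same expression as the index map in the log, with plain ints instead of T.int64.
    return (0, i1 * 64 + i2, i3)


try:
    check_transform_layout(a_index_map, (1, 768, 64, 64))  # 4-D, shaped like pad_temp
except ValueError as err:
    print(err)

print(check_transform_layout(a_index_map, (1, 64, 64, 768, 1, 1)))  # hypothetical 6-D shape
```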
Expected behavior
Environment
- Platform (e.g. CUDA):
- Operating system (e.g. Ubuntu.):
- Device (e.g. Orin, ...):
- How you installed MLC-LLM (conda, source):
- How you installed TVM-Unity (pip, source):
- Python version (e.g. 3.10):
- GPU driver version (if applicable):
- CUDA/cuDNN version (if applicable):
- TVM Unity Hash Tag (python -c "import tvm; print('\n'.join(f'{k}: {v}' for k, v in tvm.support.libinfo().items()))", applicable if you compile models):
- Any other relevant information:
Additional context
Thanks for reporting. If it is possible to get a minimal repro, that would be helpful. You can do so by dumping out the TVMScript before the transform, minimizing it, and running the transform you mentioned.
@jpf888 I ran into the same problem. Did you solve it?
Hi, this bug can be reproduced as follows:
import tvm
from tvm import dlight as dl
from tvm.relax.frontend import nn
from tvm.relax.frontend.nn import spec


class _Conv2d(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        # 1x1 convolution: the kernel size that triggers the error
        self.conv = nn.modules.Conv2D(
            in_channels=768,
            out_channels=256,
            kernel_size=1,
            padding=0,
            bias=False,
        )

    def forward(self, x: nn.Tensor):
        return self.conv(x)


forward_spec = {
    "forward": {
        "x": spec.Tensor([1, 768, 64, 64], dtype="float32")
    }
}
_conv2d_mod, params = _Conv2d().export_tvm(
    spec=forward_spec,
    debug=True,
)
mod = _conv2d_mod


def _pipeline(mod):
    seq = tvm.transform.Sequential(
        [
            tvm.relax.transform.LegalizeOps(),
            tvm.relax.transform.AnnotateTIROpPattern(),
            tvm.relax.transform.FoldConstant(),
            tvm.relax.transform.FuseOps(),
            tvm.relax.transform.FuseTIR(),
            dl.ApplyDefaultSchedule(
                dl.gpu.Matmul(),
                dl.gpu.GEMV(),
                dl.gpu.Reduction(),
                dl.gpu.GeneralReduction(),
                dl.gpu.Fallback(),
            ),
        ]
    )
    mod = seq(mod)
    return mod


# dlight schedule rules need a target in scope
with tvm.target.Target("nvidia/geforce-rtx-4090", host="llvm"):
    mod = _pipeline(mod)
mod.show()
dl.gpu.Matmul reports the error (after tvm.relax.transform.LegalizeOps()) only when kernel_size is 1. In an MLLM, the image-embedding path may involve exactly this operation (a conv2d with kernel_size 1).
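For reference, a conv2d with kernel_size 1, stride 1 and no padding computes out[f, y, x] = sum over c of W[f, c] * in[c, y, x], which is a matrix product between the weight viewed as [out_channels, in_channels] and the input viewed as [in_channels, H*W]; presumably this is why the block is handed to dl.gpu.Matmul. A small pure-Python sketch checks that equivalence (helper names are hypothetical, and tiny sizes stand in for the 768 -> 256 channels on a 64x64 map):

```python
from typing import List

Tensor3 = List[List[List[float]]]  # [C][H][W], batch size 1


def conv2d_1x1_nchw(x: Tensor3, w: List[List[float]]) -> Tensor3:
    """Direct 1x1 conv2d (stride 1, no padding, no bias). w is [C_out][C_in]."""
    c_in, h, wd = len(x), len(x[0]), len(x[0][0])
    return [
        [[sum(w[f][c] * x[c][y][xx] for c in range(c_in)) for xx in range(wd)]
         for y in range(h)]
        for f in range(len(w))
    ]


def matmul(a: List[List[float]], b: List[List[float]]) -> List[List[float]]:
    """Plain [M, K] x [K, N] matrix product."""
    return [[sum(a[i][k] * b[k][j] for k in range(len(b))) for j in range(len(b[0]))]
            for i in range(len(a))]


def conv2d_1x1_as_matmul(x: Tensor3, w: List[List[float]]) -> Tensor3:
    """Same result via W[C_out, C_in] @ X[C_in, H*W], reshaped back to [C_out][H][W]."""
    c_in, h, wd = len(x), len(x[0]), len(x[0][0])
    x_flat = [[x[c][y][xx] for y in range(h) for xx in range(wd)] for c in range(c_in)]
    out_flat = matmul(w, x_flat)  # [C_out][H*W]
    return [[row[y * wd:(y + 1) * wd] for y in range(h)] for row in out_flat]


# Tiny example: C_in = 3, C_out = 2, H = W = 2.
x = [[[float(c * 4 + y * 2 + xx) for xx in range(2)] for y in range(2)] for c in range(3)]
w = [[1.0, 0.5, -1.0], [2.0, 0.0, 1.0]]
print(conv2d_1x1_nchw(x, w) == conv2d_1x1_as_matmul(x, w))  # True
```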
Hi @senlyu163, it looks like this is a known issue when applying dlight to a conv2d with a kernel size of 1. It arises because the reindex schedule primitive simplifies the index expressions. To address this, I previously created a draft PR. You can merge the relevant changes and modify the normalize_to_matmul function of dlight.
Check out this draft PR: https://github.com/apache/tvm/pull/16440
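To make the effect of that simplification concrete, here is a toy model, not TVM's reindex implementation. It assumes (hypothetically) that the reindexed buffer gets one dimension per block iterator still present in the access indices, and that simplification substitutes 0 for every iterator whose extent is 1. The iterator extents and the pad_temp access pattern come from the conv2d_nchw block in the report; `reindex_dims` is a made-up helper. Under this model the unsimplified buffer has six dimensions, matching the six inputs of the index map in the log, while the simplified one has fewer:

```python
from typing import Dict, List, Sequence

# Block iterators of "conv2d_nchw", in block order, with their extents (from the report).
ITER_EXTENTS: Dict[str, int] = {
    "nn": 1, "ff": 256, "yy": 64, "xx": 64, "rc": 768, "ry": 1, "rx": 1,
}

# pad_temp[v_nn, v_rc, v_yy + v_ry, v_xx + v_rx]: each index is a sum of iterators.
PAD_TEMP_ACCESS: List[List[str]] = [["nn"], ["rc"], ["yy", "ry"], ["xx", "rx"]]


def reindex_dims(access: Sequence[Sequence[str]], extents: Dict[str, int],
                 simplify: bool) -> List[str]:
    """Iterators that become dimensions of the reindexed buffer (toy model)."""
    used = {v for index in access for v in index}
    if simplify:
        # Hypothetical simplification: an iterator with extent 1 is always 0, so it vanishes.
        used = {v for v in used if extents[v] != 1}
    # Keep block order so the result is deterministic.
    return [v for v in extents if v in used]


print(reindex_dims(PAD_TEMP_ACCESS, ITER_EXTENTS, simplify=False))
# ['nn', 'yy', 'xx', 'rc', 'ry', 'rx']
print(reindex_dims(PAD_TEMP_ACCESS, ITER_EXTENTS, simplify=True))
# ['yy', 'xx', 'rc']
```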
The key component related to this issue is the addition of a skip_simplify flag to cache_reindex. You can apply the relevant changes as follows:
from typing import List, Optional

from tvm import tir
from tvm.tir.schedule import BlockRV

# `get_index_map` and `logger` are assumed to be the helpers already available
# in the dlight module where normalize_to_matmul lives.


def normalize_to_matmul(sch: tir.Schedule,
                        main_block: BlockRV,
                        layout: Optional[List[str]] = None) -> Optional[tir.Schedule]:
    if layout is None:
        layout = ["n", "t", "n"]
    block_stmt = sch.get(main_block)

    # Let layout be 'a' to auto infer the layout
    index_maps = get_index_map(block_stmt, layout=layout)
    if index_maps is None:
        logger.debug("Cannot find the appropriate index map for tensorcore")
        return None

    matmul_index_map, a_index_map, b_index_map, c_index_map = index_maps

    # Use `skip_simplify` to avoid the bug in the 1x1 conv
    # (the `skip_simplify` argument requires the changes from the draft PR above)
    block = sch.reindex(main_block, ("read", 0), skip_simplify=True)
    sch.transform_layout(block, ("write", 0), a_index_map)
    block = sch.reindex(main_block, ("read", 1), skip_simplify=True)
    sch.transform_layout(block, ("write", 0), b_index_map)
    block = sch.reindex(main_block, ("write", 0), skip_simplify=True)
    sch.transform_layout(block, ("read", 0), c_index_map)
    sch.transform_block_layout(main_block, matmul_index_map)
    sch.mod["main"] = sch.mod["main"].with_attr("dlight.tensorcore_prenormlized", True)
    return sch
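The A-operand index map from the log ignores i0, i4 and i5 and folds i1 and i2 into one axis. Assuming (hypothetically; the thread does not state this) that its six inputs have extents (1, H, W, C, 1, 1), the pure-Python sketch below checks that the map is a one-to-one relayout onto a [1, H*W, C] grid. It also shows that the map stops being one-to-one if an ignored input had an extent greater than 1, for instance if i4 corresponded to a kernel axis of size 2. `is_bijective_onto` is a made-up helper and the sizes are illustrative:

```python
from functools import partial
from itertools import product
from math import prod
from typing import Callable, Sequence, Tuple


def a_index_map(i0: int, i1: int, i2: int, i3: int, i4: int, i5: int,
                width: int = 64) -> Tuple[int, int, int]:
    """Plain-int version of the A-operand index map from the log (64 -> `width`)."""
    return (0, i1 * width + i2, i3)


def is_bijective_onto(index_map: Callable[..., Tuple[int, ...]],
                      domain: Sequence[int], target: Sequence[int]) -> bool:
    """True if `index_map` sends the grid `domain` one-to-one onto the grid `target`."""
    images = {index_map(*idx) for idx in product(*(range(n) for n in domain))}
    target_points = set(product(*(range(n) for n in target)))
    # One-to-one: no two domain points collide; onto: every target point is hit.
    return len(images) == prod(domain) and images == target_points


# Hypothetical small sizes: H = W = 4, C = 3 (the report has 64 x 64 and 768).
m = partial(a_index_map, width=4)
print(is_bijective_onto(m, (1, 4, 4, 3, 1, 1), (1, 16, 3)))  # True
print(is_bijective_onto(m, (1, 4, 4, 3, 2, 1), (1, 16, 3)))  # False: i4 ignored but has extent 2
```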